mesa-GA/GA_Agent_0925/main.py

110 lines
3.6 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import random
from deap import tools
from sqlalchemy.orm import close_all_sessions
from tqdm import tqdm
import matplotlib.pyplot as plt
from GA_Agent_0925.creating import creating
from GA_Agent_0925.orm import connection
from controller_db import ControllerDB
from evaluate_func import fitness, get_vulnerable35_code, get_target_vulnerable_industries
# ==============================
# Genetic algorithm entry point (single process)
# ==============================
def _evaluate_invalid(individuals, controller_db_obj):
    """Assign a fitness to every individual whose fitness is not yet valid.

    Before each evaluation the sample tables are reset and re-prepared through
    ``controller_db_obj`` -- presumably so every individual is scored against a
    fresh sample state (confirm against ``ControllerDB``).

    Args:
        individuals: Sequence of DEAP individuals; only those with
            ``ind.fitness.valid`` false are evaluated.
        controller_db_obj: The ``ControllerDB`` instance forwarded to ``fitness``.
    """
    invalid_ind = [ind for ind in individuals if not ind.fitness.valid]
    for ind in invalid_ind:
        controller_db_obj.reset_sample_db()
        controller_db_obj.prepare_list_sample()
        ind.fitness.values = fitness(ind, controller_db_obj=controller_db_obj)


def _plot_convergence(best_list, avg_list):
    """Plot per-generation best and average fitness and show the figure (blocking)."""
    plt.figure(figsize=(8, 5))
    plt.plot(best_list, label="Best Fitness", linewidth=2)
    plt.plot(avg_list, label="Average Fitness", linestyle="--")
    plt.title("Genetic Algorithm Convergence")
    plt.xlabel("Generation")
    plt.ylabel("Fitness")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


def main():
    """Run the single-process genetic algorithm configured by ``config.json``.

    Steps:
      1. Load ``config.json`` from the current working directory, seed
         ``random`` with ``cfg["seed"]`` and print every parameter.
      2. Create ``ControllerDB("without_exp", reset_flag=0)``, call
         ``reset_db(force_drop=False)`` and prepare the sample table.
      3. Build the DEAP toolbox via ``creating()`` and an initial population of
         ``cfg["pop_size"]`` individuals; evaluate it and seed the hall of fame.
      4. Evolve for ``cfg["n_gen"]`` generations: selection, pairwise crossover
         with probability ``cfg["cx_prob"]``, mutation with probability
         ``cfg["mut_prob"]``, re-evaluation of changed individuals, and
         recording of per-generation max/avg fitness.
      5. Print the best individual and plot the convergence curves.

    Config keys read: ``seed``, ``pop_size``, ``n_gen``, ``cx_prob``, ``mut_prob``.

    NOTE(review): the statistics use ``fitness.values[0]`` and label the max as
    "best", which assumes a single-objective, maximising fitness -- confirm the
    weights defined in ``creating()``.
    """
    # 1. Load the configuration and seed the RNG for reproducibility.
    with open("config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    random.seed(cfg["seed"])
    print("\n📘 参数配置:")
    for k, v in cfg.items():
        print(f" {k}: {v}")
    print("-" * 40)

    try:
        # 2. Initialise the ControllerDB database connection.
        controller_db_obj = ControllerDB("without_exp", reset_flag=0)
        controller_db_obj.reset_db(force_drop=False)
        # Prepare the sample table.
        controller_db_obj.prepare_list_sample()

        # 3. Initialise the toolbox, population and bookkeeping.
        toolbox = creating()
        pop = toolbox.population(n=cfg["pop_size"])
        hof = tools.HallOfFame(1)
        stats = tools.Statistics(lambda ind: ind.fitness.values)
        stats.register("avg", lambda fits: sum(f[0] for f in fits) / len(fits))
        stats.register("max", lambda fits: max(f[0] for f in fits))
        best_list = []
        avg_list = []

        # Evaluate the initial population before evolving and feed it to the
        # hall of fame: otherwise the best gen-0 individual can be destroyed by
        # variation without ever being recorded, and hof[0] below raises
        # IndexError when n_gen == 0. Evaluating here keeps the same order of
        # random() calls as evaluating at the top of the first generation.
        _evaluate_invalid(pop, controller_db_obj)
        hof.update(pop)

        # ==============================
        # Main evolution loop
        # ==============================
        for _ in tqdm(range(cfg["n_gen"]), desc="进化中", ncols=90):
            # Selection; clone so crossover/mutation do not alter the parents.
            offspring = toolbox.select(pop, len(pop))
            offspring = list(map(toolbox.clone, offspring))
            # Crossover on consecutive pairs; invalidate changed fitnesses.
            for child1, child2 in zip(offspring[::2], offspring[1::2]):
                if random.random() < cfg["cx_prob"]:
                    toolbox.mate(child1, child2)
                    del child1.fitness.values, child2.fitness.values
            # Mutation; invalidate the mutant's fitness.
            for mutant in offspring:
                if random.random() < cfg["mut_prob"]:
                    toolbox.mutate(mutant)
                    del mutant.fitness.values
            # Re-evaluate only the individuals whose fitness was invalidated.
            _evaluate_invalid(offspring, controller_db_obj)
            pop[:] = offspring
            hof.update(pop)
            record = stats.compile(pop)
            best_list.append(record["max"])
            avg_list.append(record["avg"])

        # ==============================
        # Report the best result
        # ==============================
        print("\n✅ 进化完成!")
        print(f"🏆 最优个体: {hof[0]}")
        print(f"🌟 最优适应度: {hof[0].fitness.values[0]:.4f}")
        # Convergence curves.
        _plot_convergence(best_list, avg_list)

        # ==============================
        # Industry matching for the best individual
        # ==============================
        print("\n📊 计算最优个体产业匹配情况...")
        # TODO(review): this step is not implemented; get_vulnerable35_code and
        # get_target_vulnerable_industries are imported but never used here.
    finally:
        # Release any SQLAlchemy sessions opened via ControllerDB, even on error.
        close_all_sessions()
if __name__ == "__main__":
    # Launch the GA only when executed as a script, not when imported.
    main()