import json
import os
import random
from datetime import datetime

from deap import tools
from sqlalchemy.orm import close_all_sessions
from tqdm import tqdm
import matplotlib.pyplot as plt

from GA_Agent_0925.creating import creating
from GA_Agent_0925.orm import connection
from controller_db import ControllerDB
from evaluate_func import fitness, get_vulnerable35_code, get_target_vulnerable_industries

# Directory that receives every result artefact written by a run.
RESULTS_DIR = "results"


def _load_config(path="config.json"):
    """Load and return the experiment configuration from the JSON file *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _print_config(cfg):
    """Echo every configuration key/value pair to stdout."""
    print("\n📘 参数配置:")
    for k, v in cfg.items():
        print(f" {k}: {v}")
    print("-" * 40)


def _write_txt_header(txt_result_file):
    """Create (truncating) the per-generation TXT log and write its header.

    The header records the experiment start time at hour resolution.
    """
    with open(txt_result_file, "w", encoding="utf-8") as f:
        exp_time = datetime.now().strftime("%Y-%m-%d %H")
        f.write(f"实验开始时间(年月日-小时):{exp_time}\n\n")
        f.write("以下为每一代的最优个体基因参数:\n")


def _evaluate_invalid(individuals, controller_db_obj):
    """Assign a fitness to every individual whose fitness is not yet valid.

    The sample tables are reset and re-prepared before each evaluation,
    presumably so every individual is scored against freshly prepared
    sample data (TODO: confirm against ControllerDB).
    """
    invalid_ind = [ind for ind in individuals if not ind.fitness.valid]
    for ind in invalid_ind:
        controller_db_obj.reset_sample_db()
        controller_db_obj.prepare_list_sample()
        ind.fitness.values = fitness(ind, controller_db_obj=controller_db_obj)


def _vary(toolbox, pop, cx_prob, mut_prob):
    """Select, clone, cross over and mutate *pop*; return the offspring list.

    Children modified by crossover or mutation have their fitness deleted so
    they are re-evaluated afterwards.
    """
    offspring = list(map(toolbox.clone, toolbox.select(pop, len(pop))))
    # Pair consecutive individuals (0-1, 2-3, ...) for crossover.
    for child1, child2 in zip(offspring[::2], offspring[1::2]):
        if random.random() < cx_prob:
            toolbox.mate(child1, child2)
            del child1.fitness.values, child2.fitness.values
    for mutant in offspring:
        if random.random() < mut_prob:
            toolbox.mutate(mutant)
            del mutant.fitness.values
    return offspring


def _log_generation_best(txt_result_file, gen, pop):
    """Append generation *gen*'s best gene and fitness to the TXT log.

    Returns:
        The ``ga_id`` attribute of that best individual, or ``None`` when it
        has none (presumably ``ga_id`` is attached during ``fitness`` —
        TODO confirm).
    """
    best_ind = tools.selBest(pop, 1)[0]
    best_gene = list(map(float, best_ind))
    with open(txt_result_file, "a", encoding="utf-8") as f:
        f.write(
            (f"第 {gen + 1} 代最优基因:{best_gene} 最优适应度: {best_ind.fitness.values[0]:.4f}"
             if best_gene else "N/A") + "\n"
        )
    return getattr(best_ind, "ga_id", None)


def _plot_convergence(best_list, avg_list):
    """Plot best/average fitness per generation, save it as PNG and show it."""
    plt.figure(figsize=(12, 12))
    plt.plot(best_list, label="Best Fitness", linewidth=2)
    plt.plot(avg_list, label="Average Fitness", linestyle="--")
    plt.title("Genetic Algorithm Convergence")
    plt.xlabel("Generation")
    plt.ylabel("Fitness")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("convergence1.png", dpi=300)
    plt.show()


def _save_results(result_file, cfg, hof, best_list, avg_list):
    """Write config, overall best individual and fitness curves to *result_file*."""
    os.makedirs(os.path.dirname(result_file) or ".", exist_ok=True)
    result_data = {
        "config": cfg,
        "best_individual": list(map(float, hof[0])),
        "best_fitness": float(hof[0].fitness.values[0]),
        "fitness_curve": {
            "best_list": best_list,
            "avg_list": avg_list
        },
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(result_data, f, indent=4, ensure_ascii=False)
    print(f"\n💾 最优结果已保存至: {result_file}")


# ==============================
# Genetic algorithm entry point (single process)
# ==============================
def main():
    """Run the single-process genetic-algorithm experiment end to end.

    Steps:
      1. Load ``config.json`` and seed ``random`` (keys read: ``seed``,
         ``pop_size``, ``n_gen``, ``cx_prob``, ``mut_prob``).
      2. Reset the ``without_exp`` database and prepare the sample table.
      3. Evolve the population for ``n_gen`` generations. Each generation's
         best gene is logged to ``results/best_individual_each_gen.txt``.
      4. Plot the convergence curve to ``convergence1.png``.
      5. Write the overall best individual to
         ``results/best_result_with_industry.json``.
    """
    # 1. Load configuration.
    cfg = _load_config()
    random.seed(cfg["seed"])
    _print_config(cfg)

    # 2. Initialise ControllerDB (database connection).
    controller_db_obj = ControllerDB("without_exp", reset_flag=0)
    try:
        controller_db_obj.reset_db(force_drop=True)
        # Prepare the sample table.
        controller_db_obj.prepare_list_sample()

        # 3. Initialise the DEAP toolbox, population and statistics.
        toolbox = creating()
        pop = toolbox.population(n=cfg["pop_size"])
        hof = tools.HallOfFame(1)
        stats = tools.Statistics(lambda ind: ind.fitness.values)
        stats.register("avg", lambda fits: sum(f[0] for f in fits) / len(fits))
        stats.register("max", lambda fits: max(f[0] for f in fits))
        best_list = []
        avg_list = []

        # Result files: per-generation TXT log and final JSON summary.
        os.makedirs(RESULTS_DIR, exist_ok=True)
        txt_result_file = os.path.join(RESULTS_DIR, "best_individual_each_gen.txt")
        result_file = os.path.join(RESULTS_DIR, "best_result_with_industry.json")
        _write_txt_header(txt_result_file)

        # Evaluate the initial population and let the hall of fame see it,
        # as DEAP's eaSimple does. Without this, an initial elite destroyed
        # by variation is never recorded, and with n_gen == 0 the hall of
        # fame stays empty so hof[0] below would raise IndexError.
        _evaluate_invalid(pop, controller_db_obj)
        hof.update(pop)

        # ==============================
        # Main evolution loop
        # ==============================
        for gen in tqdm(range(cfg["n_gen"]), desc="进化中", ncols=90):
            offspring = _vary(toolbox, pop, cfg["cx_prob"], cfg["mut_prob"])
            _evaluate_invalid(offspring, controller_db_obj)

            pop[:] = offspring
            hof.update(pop)
            record = stats.compile(pop)
            best_list.append(record["max"])
            avg_list.append(record["avg"])

            best_ga_id = _log_generation_best(txt_result_file, gen, pop)

            # Drop this round's temporary result rows, keeping the current
            # generation's best ga_id.
            # NOTE(review): hof[0] may come from an earlier generation, so its
            # rows can be dropped here. Confirm this is intended.
            # To drop the whole table instead:
            #     controller_db_obj.drop_table("without_exp_result")
            controller_db_obj.drop_table("without_exp_result", keep_ga_id=best_ga_id)
    finally:
        # Release any SQLAlchemy sessions opened during the run.
        close_all_sessions()

    # ==============================
    # Report the best result
    # ==============================
    print("\n✅ 进化完成!")
    print(f"🏆 最优个体: {hof[0]}")
    print(f"🌟 最优适应度: {hof[0].fitness.values[0]:.4f}")

    _plot_convergence(best_list, avg_list)

    # NOTE(review): this section only prints a message. No industry matching
    # is performed here, and get_vulnerable35_code /
    # get_target_vulnerable_industries are imported but unused. TODO.
    print("\n📊 计算最优个体产业匹配情况...")

    _save_results(result_file, cfg, hof, best_list, avg_list)


if __name__ == "__main__":
    main()