mesa-GA/GA_Agent_0925/main.py

179 lines
6.3 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import random
from datetime import datetime
from deap import tools
from sqlalchemy.orm import close_all_sessions
from tqdm import tqdm
import matplotlib.pyplot as plt
from GA_Agent_0925.creating import creating
from GA_Agent_0925.orm import connection
from controller_db import ControllerDB
from evaluate_func import fitness, get_vulnerable35_code, get_target_vulnerable_industries
def _evaluate_invalid(individuals, controller_db_obj):
    """Assign a fitness to every individual that does not have a valid one.

    Before each evaluation the sample tables are reset and prepared again via
    ``controller_db_obj`` (presumably so every individual is scored against a
    fresh sample table -- confirm against ``ControllerDB``).

    Args:
        individuals: Iterable of DEAP individuals (each exposes ``.fitness``).
        controller_db_obj: The ``ControllerDB`` instance handed to ``fitness``.
    """
    for ind in individuals:
        if ind.fitness.valid:
            continue
        controller_db_obj.reset_sample_db()
        controller_db_obj.prepare_list_sample()
        ind.fitness.values = fitness(ind, controller_db_obj=controller_db_obj)


# ==============================
# Genetic algorithm entry point (single process)
# ==============================
def main() -> None:
    """Run the genetic algorithm end to end and persist its results.

    Reads ``config.json`` (keys used here: ``seed``, ``pop_size``, ``n_gen``,
    ``cx_prob``, ``mut_prob``), evolves the population, and writes:

    * ``results/best_individual_each_gen.txt`` -- best genes per generation,
    * ``convergence1.png`` -- best/average fitness curves,
    * ``results/best_result_with_industry.json`` -- final summary.
    """
    # 1. Load configuration.
    with open("config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    random.seed(cfg["seed"])
    print("\n📘 参数配置:")
    for k, v in cfg.items():
        print(f" {k}: {v}")
    print("-" * 40)

    # 2. Initialise the ControllerDB connection and start from a clean schema.
    controller_db_obj = ControllerDB("without_exp", reset_flag=0)
    controller_db_obj.reset_db(force_drop=True)
    # Prepare the sample table.
    controller_db_obj.prepare_list_sample()

    # 3. Initialise the DEAP toolbox, population, hall of fame and statistics.
    toolbox = creating()
    pop = toolbox.population(n=cfg["pop_size"])
    hof = tools.HallOfFame(1)
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", lambda fits: sum(f[0] for f in fits) / len(fits))
    stats.register("max", lambda fits: max(f[0] for f in fits))
    best_list = []
    avg_list = []

    # ============================================================
    # Prepare the files that receive the results.
    # ============================================================
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)
    txt_result_file = os.path.join(results_dir, "best_individual_each_gen.txt")
    json_result_file = os.path.join(results_dir, "best_result_with_industry.json")
    # First line of the text log: experiment start time (date + hour).
    with open(txt_result_file, "w", encoding="utf-8") as f:
        exp_time = datetime.now().strftime("%Y-%m-%d %H")
        f.write(f"实验开始时间(年月日-小时):{exp_time}\n\n")
        f.write("以下为每一代的最优个体基因参数:\n")

    # Evaluate the initial population once, before evolving, and let the hall
    # of fame see it: otherwise the best initial individual could never be
    # reported, and ``hof`` would be empty when ``n_gen`` is 0.
    _evaluate_invalid(pop, controller_db_obj)
    hof.update(pop)

    # ==============================
    # Main evolution loop
    # ==============================
    for gen in tqdm(range(cfg["n_gen"]), desc="进化中", ncols=90):
        # Selection, then work on clones so the parents stay untouched.
        offspring = toolbox.select(pop, len(pop))
        offspring = list(map(toolbox.clone, offspring))

        # Crossover on consecutive pairs; mated children need re-evaluation.
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < cfg["cx_prob"]:
                toolbox.mate(child1, child2)
                del child1.fitness.values, child2.fitness.values

        # Mutation; mutated individuals need re-evaluation.
        for mutant in offspring:
            if random.random() < cfg["mut_prob"]:
                toolbox.mutate(mutant)
                del mutant.fitness.values

        # Re-evaluate only the individuals whose fitness was invalidated.
        _evaluate_invalid(offspring, controller_db_obj)

        # Full generational replacement.
        pop[:] = offspring
        hof.update(pop)
        record = stats.compile(pop)
        best_list.append(record["max"])
        avg_list.append(record["avg"])

        # ============================================================
        # Append this generation's best genes to the text log.
        # ============================================================
        best_ind = tools.selBest(pop, 1)[0]
        best_gene = [float(gene) for gene in best_ind]
        # ``ga_id`` is optional on an individual; fall back to None.
        best_ga_id = getattr(best_ind, "ga_id", None)
        if best_gene:
            log_line = (
                f"{gen + 1} 代最优基因:{best_gene} 最优适应度: {best_ind.fitness.values[0]:.4f}"
            )
        else:
            log_line = "N/A"
        with open(txt_result_file, "a", encoding="utf-8") as f:
            f.write(log_line + "\n")

        # ============================================================
        # Drop the temporary result rows produced by this round, keeping
        # only the current generation's best ga_id.
        # NOTE(review): hof[0] may stem from an earlier generation whose
        # rows were already dropped -- confirm this is acceptable.
        # ============================================================
        controller_db_obj.drop_table("without_exp_result", keep_ga_id=best_ga_id)
        # To remove the whole table instead, call drop_table without keep_ga_id.

    # ==============================
    # Report the best result
    # ==============================
    print("\n✅ 进化完成!")
    print(f"🏆 最优个体: {hof[0]}")
    print(f"🌟 最优适应度: {hof[0].fitness.values[0]:.4f}")

    # Plot the convergence curves.
    plt.figure(figsize=(12, 12))
    plt.plot(best_list, label="Best Fitness", linewidth=2)
    plt.plot(avg_list, label="Average Fitness", linestyle="--")
    plt.title("Genetic Algorithm Convergence")
    plt.xlabel("Generation")
    plt.ylabel("Fitness")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("convergence1.png", dpi=300)
    plt.show()

    # ==============================
    # Industry matching for the best individual
    # ==============================
    print("\n📊 计算最优个体产业匹配情况...")

    # ==============================
    # Save the final summary to the fixed JSON file
    # ==============================
    os.makedirs(results_dir, exist_ok=True)
    result_file = json_result_file
    result_data = {
        "config": cfg,
        "best_individual": list(map(float, hof[0])),
        "best_fitness": float(hof[0].fitness.values[0]),
        "fitness_curve": {
            "best_list": best_list,
            "avg_list": avg_list,
        },
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(result_data, f, indent=4, ensure_ascii=False)
    print(f"\n💾 最优结果已保存至: {result_file}")
# Run the genetic algorithm only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()