遗传算法

This commit is contained in:
Cricial
2026-03-15 17:15:30 +08:00
parent 03ac80715f
commit a69e272e43
27 changed files with 934 additions and 191 deletions

View File

@@ -1,178 +1,216 @@
import csv
import json
import os
import random
from datetime import datetime
from deap import tools
from sqlalchemy.orm import close_all_sessions
from tqdm import tqdm
import matplotlib.pyplot as plt
from GA_Agent_0925.creating import creating
from GA_Agent_0925.orm import connection
from controller_db import ControllerDB
from evaluate_func import fitness, get_vulnerable35_code, get_target_vulnerable_industries
from evaluate_func import fitness
# ==============================
# 遗传算法主函数(单进程)
# ==============================
def main():
    """Run the single-process genetic algorithm and persist its results.

    Steps:
      1. Load ``config.json`` and seed ``random`` with ``cfg["seed"]``.
      2. Reset the database controller and prepare the sample table.
      3. Build the DEAP toolbox, the initial population and a 1-slot
         hall of fame.
      4. Evolve for ``cfg["n_gen"]`` generations (select, mate, mutate,
         elitism), caching fitness values by rounded gene tuple.
      5. Log the best-so-far individual per generation to a TXT file,
         the per-generation best fitness to a CSV file, and finally write
         a convergence plot and a JSON summary.

    The fitness value is treated as an error (smaller is better), as the
    comparisons in the original code do.
    NOTE(review): this presumes ``creating()`` registers a minimising
    fitness so that ``hof[0]`` / ``tools.selBest`` agree - confirm.
    """
    # ------------------------------------------------------------------
    # 1. Load configuration
    # ------------------------------------------------------------------
    with open("config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    random.seed(cfg["seed"])

    print("\n📘 参数配置:")
    for k, v in cfg.items():
        print(f" {k}: {v}")
    print("-" * 40)

    # ------------------------------------------------------------------
    # 2. Initialise the database controller
    # ------------------------------------------------------------------
    controller_db_obj = ControllerDB("without_exp", reset_flag=0)
    controller_db_obj.reset_db(force_drop=True)
    # Prepare the sample table.
    controller_db_obj.prepare_list_sample()

    # ------------------------------------------------------------------
    # 3. Initialise the GA
    # ------------------------------------------------------------------
    toolbox = creating()
    pop = toolbox.population(n=cfg["pop_size"])
    hof = tools.HallOfFame(1)
    best_list = []  # best-so-far fitness, one entry per generation

    # ------------------------------------------------------------------
    # 4. Prepare result files
    # ------------------------------------------------------------------
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)
    txt_result_file = os.path.join(results_dir, "best_individual_each_gen.txt")
    json_result_file = os.path.join(results_dir, "best_result_with_industry.json")
    csv_file = "convergence0119_data.csv"
    exp_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # The TXT log is appended to, so earlier experiments are preserved.
    with open(txt_result_file, "a", encoding="utf-8") as f:
        f.write(f"\n实验开始时间:{exp_time}\n")
        f.write("每一代最优个体基因参数如下:\n")

    # Write the CSV header only when the file does not exist yet.
    if not os.path.exists(csv_file):
        with open(csv_file, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow([
                "Timestamp",
                "Generation",
                "Best_Fitness_Percentage",
            ])

    # ------------------------------------------------------------------
    # 5. Fitness cache: rounded gene tuple -> fitness values
    # ------------------------------------------------------------------
    fitness_cache = {}

    # Best-so-far bookkeeping. Initialised here rather than probed through
    # globals(): these are locals, so a globals() lookup never finds them
    # and the history would be overwritten on every generation.
    prev_best_fitness = None
    prev_best_gene = None
    prev_best_ga_id = None

    # ------------------------------------------------------------------
    # 6. Main evolution loop
    # ------------------------------------------------------------------
    for gen in tqdm(range(cfg["n_gen"]), desc="进化中", ncols=90):
        # ---------- evaluate the current population ----------
        _evaluate_individuals(
            [ind for ind in pop if not ind.fitness.valid],
            fitness_cache,
            controller_db_obj,
        )

        # ---------- selection ----------
        offspring = toolbox.select(pop, len(pop))
        offspring = list(map(toolbox.clone, offspring))

        # ---------- crossover ----------
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < cfg["cx_prob"]:
                toolbox.mate(child1, child2)
                # Mated children no longer match their stored fitness.
                del child1.fitness.values, child2.fitness.values

        # ---------- mutation ----------
        for mutant in offspring:
            if random.random() < cfg["mut_prob"]:
                toolbox.mutate(mutant)
                del mutant.fitness.values

        # ---------- evaluate the new individuals ----------
        _evaluate_individuals(
            [ind for ind in offspring if not ind.fitness.valid],
            fitness_cache,
            controller_db_obj,
        )
        pop[:] = offspring

        # ---------- elitism ----------
        # Replace the worst individual (largest error) with the best one
        # seen in earlier generations, if there is one yet.
        if len(hof) > 0:
            worst_idx = max(
                range(len(pop)), key=lambda i: pop[i].fitness.values[0]
            )
            pop[worst_idx] = toolbox.clone(hof[0])

        hof.update(pop)

        # ---------- record this generation's best ----------
        best_ind = tools.selBest(pop, 1)[0]
        best_fitness = best_ind.fitness.values[0]  # error: smaller is better
        best_gene = list(map(float, best_ind))
        # NOTE(review): ga_id is presumably attached during fitness();
        # individuals served from the cache may not carry one - confirm.
        best_ga_id = getattr(best_ind, "ga_id", None)

        # ---------- maintain the best-so-far ----------
        if prev_best_fitness is None or best_fitness < prev_best_fitness:
            prev_best_fitness = best_fitness
            prev_best_gene = best_gene
            prev_best_ga_id = best_ga_id
        best_list.append(float(prev_best_fitness))

        # TXT log: best-so-far after this generation.
        with open(txt_result_file, "a", encoding="utf-8") as f:
            f.write(
                f"{gen + 1} | 最优适应度: {prev_best_fitness:.4f} | 基因: {prev_best_gene}\n"
            )

        # CSV log: this generation's best fitness.
        with open(csv_file, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow([exp_time, gen, best_fitness])

        # ---------- periodic DB cleanup ----------
        # Every third generation drop the result table, keeping only the
        # rows that belong to the best-so-far individual.
        if gen % 3 == 0:
            controller_db_obj.drop_table(
                "without_exp_result",
                keep_ga_id=prev_best_ga_id,
            )

    # ------------------------------------------------------------------
    # 7. Print the final result
    # ------------------------------------------------------------------
    print("\n✅ 进化完成!")
    print(f"🏆 最优个体: {list(map(float, hof[0]))}")
    print(f"🌟 最优适应度: {hof[0].fitness.values[0]:.4f}")

    # ------------------------------------------------------------------
    # 8. Convergence curve (best only)
    # ------------------------------------------------------------------
    plt.figure(figsize=(12, 8))
    plt.plot(best_list, linewidth=2, label="Best Fitness")
    plt.xlabel("Generation")
    plt.ylabel("Error")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("convergence0119.png", dpi=300)
    plt.close()  # release the figure instead of blocking on a window

    # ------------------------------------------------------------------
    # 9. Save the JSON summary
    # ------------------------------------------------------------------
    result_data = {
        "config": cfg,
        "best_individual": list(map(float, hof[0])),
        "best_fitness": float(hof[0].fitness.values[0]),
        "fitness_curve": {
            "best_list": best_list,
        },
        "timestamp": exp_time,
    }
    with open(json_result_file, "w", encoding="utf-8") as f:
        json.dump(result_data, f, indent=4, ensure_ascii=False)
    print(f"\n💾 最优结果已保存至: {json_result_file}")


def _evaluate_individuals(individuals, fitness_cache, controller_db_obj):
    """Assign fitness values to ``individuals``, reusing cached results.

    The cache key is the gene vector rounded to 4 decimals. On a cache
    miss the sample DB is reset and re-prepared before ``fitness`` is
    called, and the result is stored in ``fitness_cache``.
    """
    for ind in individuals:
        gene_key = tuple(round(x, 4) for x in ind)
        if gene_key in fitness_cache:
            ind.fitness.values = fitness_cache[gene_key]
            continue
        controller_db_obj.reset_sample_db()
        controller_db_obj.prepare_list_sample()
        fit_val = fitness(ind, controller_db_obj=controller_db_obj)
        ind.fitness.values = fit_val
        fitness_cache[gene_key] = fit_val
# Run the GA only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()