遗传算法001
This commit is contained in:
109
GA_Agent_0925/main.py
Normal file
109
GA_Agent_0925/main.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import json
|
||||
import random
|
||||
from deap import tools
|
||||
from sqlalchemy.orm import close_all_sessions
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from GA_Agent_0925.creating import creating
|
||||
from GA_Agent_0925.orm import connection
|
||||
from controller_db import ControllerDB
|
||||
from evaluate_func import fitness, get_vulnerable100_code, get_target_vulnerable_industries
|
||||
|
||||
|
||||
# ==============================
|
||||
# 遗传算法主函数(单进程)
|
||||
# ==============================
|
||||
def main():
|
||||
# 1️⃣ 加载配置
|
||||
with open("config.json", "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
random.seed(cfg["seed"])
|
||||
|
||||
print("\n📘 参数配置:")
|
||||
for k, v in cfg.items():
|
||||
print(f" {k}: {v}")
|
||||
print("-" * 40)
|
||||
|
||||
# 2️⃣ 初始化 ControllerDB(数据库连接)
|
||||
controller_db_obj = ControllerDB("without_exp", reset_flag=0)
|
||||
controller_db_obj.reset_db(force_drop=True)
|
||||
# 准备样本表
|
||||
controller_db_obj.prepare_list_sample()
|
||||
# 2️⃣ 初始化工具箱
|
||||
toolbox = creating()
|
||||
pop = toolbox.population(n=cfg["pop_size"])
|
||||
hof = tools.HallOfFame(1)
|
||||
stats = tools.Statistics(lambda ind: ind.fitness.values)
|
||||
stats.register("avg", lambda fits: sum(f[0] for f in fits) / len(fits))
|
||||
stats.register("max", lambda fits: max(f[0] for f in fits))
|
||||
|
||||
best_list = []
|
||||
avg_list = []
|
||||
|
||||
# ==============================
|
||||
# 主进化循环
|
||||
# ==============================
|
||||
for gen in tqdm(range(cfg["n_gen"]), desc="进化中", ncols=90):
|
||||
# 计算未评估个体适应度
|
||||
invalid_ind = [ind for ind in pop if not ind.fitness.valid]
|
||||
for ind in invalid_ind:
|
||||
controller_db_obj.reset_sample_db()
|
||||
controller_db_obj.prepare_list_sample()
|
||||
ind.fitness.values = fitness(ind, controller_db_obj=controller_db_obj)
|
||||
|
||||
# 选择、交叉、变异
|
||||
offspring = toolbox.select(pop, len(pop))
|
||||
offspring = list(map(toolbox.clone, offspring))
|
||||
|
||||
for child1, child2 in zip(offspring[::2], offspring[1::2]):
|
||||
if random.random() < cfg["cx_prob"]:
|
||||
toolbox.mate(child1, child2)
|
||||
del child1.fitness.values, child2.fitness.values
|
||||
|
||||
for mutant in offspring:
|
||||
if random.random() < cfg["mut_prob"]:
|
||||
toolbox.mutate(mutant)
|
||||
del mutant.fitness.values
|
||||
|
||||
# 更新适应度
|
||||
invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
|
||||
for ind in invalid_ind:
|
||||
controller_db_obj.reset_sample_db()
|
||||
controller_db_obj.prepare_list_sample()
|
||||
ind.fitness.values = fitness(ind, controller_db_obj=controller_db_obj)
|
||||
|
||||
pop[:] = offspring
|
||||
hof.update(pop)
|
||||
|
||||
record = stats.compile(pop)
|
||||
best_list.append(record["max"])
|
||||
avg_list.append(record["avg"])
|
||||
|
||||
# ==============================
|
||||
# 输出最优结果
|
||||
# ==============================
|
||||
print("\n✅ 进化完成!")
|
||||
print(f"🏆 最优个体: {hof[0]}")
|
||||
print(f"🌟 最优适应度: {hof[0].fitness.values[0]:.4f}")
|
||||
|
||||
# 绘制收敛曲线
|
||||
plt.figure(figsize=(8, 5))
|
||||
plt.plot(best_list, label="Best Fitness", linewidth=2)
|
||||
plt.plot(avg_list, label="Average Fitness", linestyle="--")
|
||||
plt.title("Genetic Algorithm Convergence")
|
||||
plt.xlabel("Generation")
|
||||
plt.ylabel("Fitness")
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# ==============================
|
||||
# 最优个体产业匹配
|
||||
# ==============================
|
||||
print("\n📊 计算最优个体产业匹配情况...")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user