# File-viewer metadata captured with the source: 217 lines · 7.3 KiB · Python
import csv
|
||
import json
|
||
import os
|
||
import random
|
||
from datetime import datetime
|
||
|
||
from deap import tools
|
||
from tqdm import tqdm
|
||
import matplotlib.pyplot as plt
|
||
|
||
from GA_Agent_0925.creating import creating
|
||
from controller_db import ControllerDB
|
||
from evaluate_func import fitness
|
||
|
||
# Genetic-algorithm driver (single process) and its private helpers.


def _load_config(path="config.json"):
    """Read the JSON experiment configuration at *path* and return it as a dict."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _evaluate_invalid(individuals, fitness_cache, controller_db_obj):
    """Assign fitness values to every individual whose fitness is not yet valid.

    Results are memoised in *fitness_cache*, keyed by the genes rounded to
    4 decimals, so individuals that differ only below that precision share a
    single evaluation. On a cache miss the sample DB is reset and re-prepared
    (``reset_sample_db`` then ``prepare_list_sample``) before ``fitness`` is
    called, exactly as the original inline loops did.

    NOTE(review): on a cache hit only ``fitness.values`` is copied. If
    ``fitness()`` also tags the individual (e.g. a ``ga_id`` attribute), a
    cloned/mutated individual served from the cache keeps whatever tag it
    inherited from its parent — confirm this is acceptable for the
    ``keep_ga_id`` clean-up in ``main``.
    """
    # Snapshot the invalid ones first, matching the original list-then-loop order.
    pending = [ind for ind in individuals if not ind.fitness.valid]
    for ind in pending:
        gene_key = tuple(round(x, 4) for x in ind)
        if gene_key in fitness_cache:
            ind.fitness.values = fitness_cache[gene_key]
            continue
        controller_db_obj.reset_sample_db()
        controller_db_obj.prepare_list_sample()
        fit_val = fitness(ind, controller_db_obj=controller_db_obj)
        ind.fitness.values = fit_val
        fitness_cache[gene_key] = fit_val


def _plot_convergence(best_list, path):
    """Plot the per-generation best-so-far fitness curve and save it to *path*."""
    plt.figure(figsize=(12, 8))
    plt.plot(best_list, linewidth=2, label="Best Fitness")
    plt.xlabel("Generation")
    plt.ylabel("Error")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close()


def main():
    """Run the genetic algorithm and persist its results.

    Reads ``config.json`` (keys used: ``seed``, ``pop_size``, ``n_gen``,
    ``cx_prob``, ``mut_prob``), resets the "without_exp" DB controller, then
    evolves the population with selection, crossover, mutation, cached
    fitness evaluation and single-individual elitism. Each generation:

    - the best-so-far individual (hall of fame) is appended to a TXT log;
    - the current population's best fitness is appended to a CSV;
    - every third generation, ``drop_table`` is called on
      ``without_exp_result`` with ``keep_ga_id`` set to the best-so-far
      individual's ``ga_id`` (if it has one).

    Finally it plots the convergence curve (``convergence0119.png``) and dumps
    the best individual plus the curve to ``results/best_result_with_industry.json``.

    The code treats the fitness as an error (lower is better). Comparisons go
    through DEAP ``Fitness`` objects, so they follow whatever weights
    ``creating()`` registered.
    # NOTE(review): assumes creating() registers a minimising fitness — TODO confirm.
    """
    # 1) Load configuration
    cfg = _load_config("config.json")
    random.seed(cfg["seed"])

    print("\n📘 参数配置:")
    for k, v in cfg.items():
        print(f" {k}: {v}")
    print("-" * 40)

    # 2) Initialise the DB controller
    controller_db_obj = ControllerDB("without_exp", reset_flag=0)
    controller_db_obj.reset_db(force_drop=True)
    controller_db_obj.prepare_list_sample()

    # 3) Initialise the GA
    toolbox = creating()
    pop = toolbox.population(n=cfg["pop_size"])
    hof = tools.HallOfFame(1)
    best_list = []  # best-so-far fitness, one entry per generation

    # 4) Prepare result files
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)
    txt_result_file = os.path.join(results_dir, "best_individual_each_gen.txt")
    json_result_file = os.path.join(results_dir, "best_result_with_industry.json")
    csv_file = "convergence0119_data.csv"
    exp_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    with open(txt_result_file, "a", encoding="utf-8") as f:
        f.write(f"\n实验开始时间:{exp_time}\n")
        f.write("每一代最优个体基因参数如下:\n")

    # Write the CSV header only when the file is new (rows are appended across runs).
    if not os.path.exists(csv_file):
        with open(csv_file, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow([
                "Timestamp",
                "Generation",
                "Best_Fitness_Percentage",
            ])

    # 5) Fitness cache: avoids re-running the expensive evaluation for repeated genes
    fitness_cache = {}

    # Evaluate the initial population and seed the hall of fame. Previously the
    # hall of fame never saw generation 0's population, so its best individual
    # could be lost and elitism was inactive in the first generation.
    _evaluate_invalid(pop, fitness_cache, controller_db_obj)
    hof.update(pop)

    # 6) Main evolutionary loop
    for gen in tqdm(range(cfg["n_gen"]), desc="进化中", ncols=90):
        # Selection (clone so variation does not mutate the parents in place)
        offspring = list(map(toolbox.clone, toolbox.select(pop, len(pop))))

        # Crossover on consecutive pairs
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < cfg["cx_prob"]:
                toolbox.mate(child1, child2)
                del child1.fitness.values, child2.fitness.values

        # Mutation
        for mutant in offspring:
            if random.random() < cfg["mut_prob"]:
                toolbox.mutate(mutant)
                del mutant.fitness.values

        # Evaluate only the individuals invalidated by crossover/mutation
        _evaluate_invalid(offspring, fitness_cache, controller_db_obj)
        pop[:] = offspring

        # Elitism: overwrite the worst individual with the best-so-far one.
        # Fitness objects compare by weighted values, so min() is the worst
        # individual whatever the optimisation direction.
        if len(hof) > 0:
            worst_idx = min(range(len(pop)), key=lambda i: pop[i].fitness)
            pop[worst_idx] = toolbox.clone(hof[0])

        hof.update(pop)

        # Record: current-generation best (CSV) and best-so-far (TXT + curve).
        # hof[0] is the historical best, which replaces the broken
        # `'prev_best_fitness' not in globals()` check: a function local is
        # never in globals(), so that check reset the "history best" every generation.
        best_fitness = tools.selBest(pop, 1)[0].fitness.values[0]
        overall_best = hof[0]
        overall_best_fitness = float(overall_best.fitness.values[0])
        overall_best_gene = list(map(float, overall_best))
        overall_best_ga_id = getattr(overall_best, "ga_id", None)
        best_list.append(overall_best_fitness)

        with open(txt_result_file, "a", encoding="utf-8") as f:
            f.write(
                f"第 {gen + 1} 代 | 最优适应度: {overall_best_fitness:.4f} | 基因: {overall_best_gene}\n"
            )

        with open(csv_file, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow([
                exp_time,
                gen,
                best_fitness,
            ])

        # Periodic DB clean-up, keeping the rows of the best-so-far individual
        if gen % 3 == 0:
            controller_db_obj.drop_table(
                "without_exp_result",
                keep_ga_id=overall_best_ga_id,
            )

    # 7) Final report
    print("\n✅ 进化完成!")
    print(f"🏆 最优个体: {list(map(float, hof[0]))}")
    print(f"🌟 最优适应度: {hof[0].fitness.values[0]:.4f}")

    # 8) Convergence curve (best-so-far only)
    _plot_convergence(best_list, "convergence0119.png")

    # 9) Save JSON result
    result_data = {
        "config": cfg,
        "best_individual": list(map(float, hof[0])),
        "best_fitness": float(hof[0].fitness.values[0]),
        "fitness_curve": {
            "best_list": best_list
        },
        "timestamp": exp_time
    }

    with open(json_result_file, "w", encoding="utf-8") as f:
        json.dump(result_data, f, indent=4, ensure_ascii=False)

    print(f"\n💾 最优结果已保存至: {json_result_file}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the GA only when this file is executed as a script, not on import.
    main()
|