Files
mesa-GA/GA_Agent_0925/main.py
2026-03-15 17:15:30 +08:00

217 lines
7.3 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import csv
import json
import os
import random
from datetime import datetime
from deap import tools
from tqdm import tqdm
import matplotlib.pyplot as plt
from GA_Agent_0925.creating import creating
from controller_db import ControllerDB
from evaluate_func import fitness
def _evaluate_invalid(individuals, fitness_cache, controller_db_obj):
    """Assign a fitness to every individual in *individuals* whose fitness is invalid.

    Genes are rounded to 4 decimals to build the cache key, so individuals whose
    genes agree to 4 decimals share one evaluation. On a cache hit the stored
    fitness is reused and no database work happens. On a miss, the sample DB is
    reset and re-prepared, ``fitness`` is called, and the result is cached.

    Args:
        individuals: iterable of DEAP individuals (sequences of numeric genes).
        fitness_cache: dict mapping rounded gene tuples to fitness values; updated in place.
        controller_db_obj: database controller passed through to ``fitness``.
    """
    for ind in individuals:
        if ind.fitness.valid:
            continue
        gene_key = tuple(round(x, 4) for x in ind)
        if gene_key in fitness_cache:
            # NOTE(review): a cache hit does not run ``fitness``, so any attribute
            # ``fitness`` may set on the individual (e.g. ``ga_id``) is not refreshed
            # here -- confirm this is acceptable for the DB cleanup below.
            ind.fitness.values = fitness_cache[gene_key]
            continue
        controller_db_obj.reset_sample_db()
        controller_db_obj.prepare_list_sample()
        fit_val = fitness(ind, controller_db_obj=controller_db_obj)
        ind.fitness.values = fit_val
        fitness_cache[gene_key] = fit_val


# Main genetic-algorithm routine (single process)
def main():
    """Run the single-process genetic algorithm end to end.

    Steps: load ``config.json`` (keys used: ``seed``, ``pop_size``, ``n_gen``,
    ``cx_prob``, ``mut_prob``); reset and prepare the database through
    ``ControllerDB``; evolve ``n_gen`` generations (selection, crossover,
    mutation, elitism); append per-generation results to a TXT log and a CSV
    file; save a convergence plot; write the final result as JSON.

    The "historical best" is the single individual kept by the DEAP hall of
    fame, which orders individuals by their weighted fitness.
    """
    # 1. Load configuration
    with open("config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    random.seed(cfg["seed"])
    print("\n📘 参数配置:")
    for k, v in cfg.items():
        print(f" {k}: {v}")
    print("-" * 40)

    # 2. Initialise the database controller
    controller_db_obj = ControllerDB("without_exp", reset_flag=0)
    controller_db_obj.reset_db(force_drop=True)
    controller_db_obj.prepare_list_sample()

    # 3. Initialise the GA
    toolbox = creating()
    pop = toolbox.population(n=cfg["pop_size"])
    hof = tools.HallOfFame(1)
    best_list = []  # historical-best fitness after each generation (convergence curve)

    # 4. Prepare result files
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)
    txt_result_file = os.path.join(results_dir, "best_individual_each_gen.txt")
    json_result_file = os.path.join(results_dir, "best_result_with_industry.json")
    csv_file = "convergence0119_data.csv"
    exp_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # TXT log header (appended, so several runs share one file)
    with open(txt_result_file, "a", encoding="utf-8") as f:
        f.write(f"\n实验开始时间:{exp_time}\n")
        f.write("每一代最优个体基因参数如下:\n")
    # CSV header is written only when the file does not exist yet
    if not os.path.exists(csv_file):
        with open(csv_file, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow([
                "Timestamp",
                "Generation",
                "Best_Fitness_Percentage"
            ])

    # 5. Fitness cache (avoids re-running the expensive DB-backed evaluation)
    fitness_cache = {}

    # Evaluate the initial population once and seed the hall of fame, so the
    # best initial individual is available to elitism from generation 0 on
    # (previously it was never recorded and could be lost).
    _evaluate_invalid(pop, fitness_cache, controller_db_obj)
    hof.update(pop)

    # 6. Main evolution loop
    for gen in tqdm(range(cfg["n_gen"]), desc="进化中", ncols=90):
        # ---------- Selection ----------
        offspring = toolbox.select(pop, len(pop))
        offspring = list(map(toolbox.clone, offspring))
        # ---------- Crossover ----------
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < cfg["cx_prob"]:
                toolbox.mate(child1, child2)
                del child1.fitness.values, child2.fitness.values
        # ---------- Mutation ----------
        for mutant in offspring:
            if random.random() < cfg["mut_prob"]:
                toolbox.mutate(mutant)
                del mutant.fitness.values
        # ---------- Evaluate new individuals ----------
        _evaluate_invalid(offspring, fitness_cache, controller_db_obj)
        pop[:] = offspring

        # ---------- Elitism ----------
        if len(hof) > 0:
            # Replace the worst individual. ``wvalues`` applies the fitness
            # weights, so "worst" agrees with selBest/HallOfFame for both
            # minimisation and maximisation.
            worst_idx = min(range(len(pop)), key=lambda i: pop[i].fitness.wvalues)
            pop[worst_idx] = toolbox.clone(hof[0])
        hof.update(pop)

        # ---------- Record current-generation best ----------
        best_ind = tools.selBest(pop, 1)[0]
        best_fitness = best_ind.fitness.values[0]

        # ---------- Historical best = hall-of-fame entry ----------
        # (The old ``'prev_best_fitness' not in globals()`` check was always
        # true for a local variable, so the history was never kept.)
        hist_best = hof[0]
        hist_best_fitness = float(hist_best.fitness.values[0])
        hist_best_gene = list(map(float, hist_best))
        hist_best_ga_id = getattr(hist_best, "ga_id", None)
        best_list.append(hist_best_fitness)

        with open(txt_result_file, "a", encoding="utf-8") as f:
            f.write(
                f"{gen + 1} 代 | 最优适应度: {hist_best_fitness:.4f} | 基因: {hist_best_gene}\n"
            )
        # CSV: 0-based generation index and current-generation best fitness
        with open(csv_file, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow([
                exp_time,
                gen,
                best_fitness
            ])

        # ---------- Periodic DB cleanup (keeps the historical best's rows) ----------
        if gen % 3 == 0:
            controller_db_obj.drop_table(
                "without_exp_result",
                keep_ga_id=hist_best_ga_id
            )

    # 7. Final result
    print("\n✅ 进化完成!")
    print(f"🏆 最优个体: {list(map(float, hof[0]))}")
    print(f"🌟 最优适应度: {hof[0].fitness.values[0]:.4f}")

    # 8. Convergence curve (historical best only)
    plt.figure(figsize=(12, 8))
    plt.plot(best_list, linewidth=2, label="Best Fitness")
    plt.xlabel("Generation")
    plt.ylabel("Error")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("convergence0119.png", dpi=300)
    plt.close()

    # 9. Save JSON result
    result_data = {
        "config": cfg,
        "best_individual": list(map(float, hof[0])),
        "best_fitness": float(hof[0].fitness.values[0]),
        "fitness_curve": {
            "best_list": best_list
        },
        "timestamp": exp_time
    }
    with open(json_result_file, "w", encoding="utf-8") as f:
        json.dump(result_data, f, indent=4, ensure_ascii=False)
    print(f"\n💾 最优结果已保存至: {json_result_file}")


if __name__ == "__main__":
    main()