24 KiB
遗传算法优化ABM模型参数项目(GA_Agent_0925)
项目概述
本项目使用遗传算法(Genetic Algorithm)来优化基于代理的模型(Agent-Based Model)的关键参数,目标是使模型生成的脆弱产业集合与目标产业集合尽可能匹配。项目使用DEAP框架实现遗传算法,集成MySQL数据库进行大规模模拟管理,支持多进程并行计算,具有完整的结果分析和可视化功能。
项目结构
GA_Agent_0925/
├── 核心算法模块
│ ├── main.py # GA演化主循环和结果管理
│ ├── creating.py # DEAP工具箱配置(参数编码、遗传算子)
│ ├── evaluate_func.py # 适应度函数、产业匹配计算
│ ├── orm.py # SQLAlchemy ORM配置、数据库表定义
│ ├── controller_db.py # 数据库生命周期管理、样本操作
│ └── my_model.py # ABM模型调用接口
│
├── 配置文件
│ ├── config.json # GA超参数(种群大小、进化代数等)
│ ├── conf_db.yaml # 数据库连接凭证(主机、端口、用户)
│ ├── conf_db_prefix.yaml # 数据库名前缀配置
│ ├── conf_experiment.yaml # 实验参数(样本数、迭代次数)
│ └── conf.yaml # 其他配置参数
│
├── 数据分析和可视化
│ ├── SQL_analysis_risk_ga.sql # 参数化SQL分析脚本(按ga_id筛选)
│ ├── 多功能.py # 数据分析工具函数
│ └── 绘图.py # 收敛曲线绘制(带移动平均)
│
├── 输出目录
│ ├── results/ # GA演化结果输出
│ ├── risk_ay/ # 风险分析结果
│ └── vulnerable35_match_results/ # 脆弱产业匹配结果
│
└── 输出文件示例(在results/目录下)
├── best_individual_each_gen.txt # 每代最优个体历史记录
├── best_result_with_industry.json # 最终最优解完整结果
└── convergence.png # 收敛曲线图表
核心文件详细说明
1. main.py - GA演化主循环
主要功能:协调整个遗传算法流程,包括初始化、演化循环、结果保存 关键流程(9个阶段):
- 加载配置和数据库初始化
- 创建DEAP工具箱和种群
- 初始化适应度评估和缓存机制
- 开始主演化循环(每代):
- 评估无效个体
- 精英保留
- 选择操作(锦标赛选择)
- 交叉和变异
- 缓存最优个体
- 每3代执行一次数据库清理(删除旧结果记录,保留当前最优)
- 实时写入进度信息
- 最终结果保存到多种格式文件
- 生成收敛曲线图表
- 记录统计信息和时间戳
输入参数:从配置文件读取 输出文件:
best_individual_each_gen.txt- 每代最优个体的基因参数和适应度best_result_with_industry.json- 最终最优解(包含目标产业与匹配产业详情)convergence.png- 收敛过程图表best_result.csv- 最优解的CSV格式
2. evaluate_func.py - 适应度函数(核心优化目标)
主要功能:计算GA个体的适应度值,连接ABM模型和优化目标 关键函数:
fitness(individual)
- 输入:GA个体(12个参数的浮点数组)
- 处理流程:
- 检查缓存(避免重复评估相同参数组合)
- 将基因参数映射到ABM模型参数(参考第4.2节参数映射表)
- 调用ABM模型
do_process()进行模拟 - 提取模拟得到的脆弱产业集合
- 与35个目标产业比对,计算匹配情况
- 使用加权命中率计算适应度:
hit_ratio = Σ(1/rank_i) 其中rank_i为第i个脆弱产业在目标集中的排名 error = 1.0 - hit_ratio fitness = -error (负号因为DEAP最小化) - 缓存结果并返回
- 输出:返回元组(error_value,)
- 缓存机制:使用四舍五入后的参数作为键,防止浮点精度问题
get_target_vulnerable_industries()
- 功能:定义优化目标——35个关键半导体产业
- 产业分类:
- 半导体设备类:光刻机、刻蚀设备、离子注入机等
- 半导体材料类:光刻胶、特种气体、清洗溶剂等
- 晶圆制造类:12英寸晶圆厂、光刻工艺等
- 封装测试类:测试设备、引脚框架等
- 芯片设计类:设计工具、IP库等
- 返回值:35个产业代码的集合
get_vulnerable35_code()
- 功能:从ABM模型输出结果中提取脆弱产业代码
- 处理:解析数据库结果表,识别脆弱产业标识符
- 返回:脆弱产业代码集合
参数映射(ABM模型参数): 详见第4.2节
3. creating.py - DEAP工具箱配置
主要功能:定义GA的参数编码方案、遗传算子配置 关键配置:
基因定义(12个连续或离散参数)
# 参数范围
toolbox.register("n_max_trial", random.randint, 1, 60) # 1-60
toolbox.register("prf_size", random.uniform, 0.0, 1.0) # 0.0-1.0
toolbox.register("prf_conn", random.uniform, 0.0, 1.0) # 0.0-1.0
toolbox.register("cap_limit_prob_type", random.randint, 0, 1) # 0-1
toolbox.register("cap_limit_level", random.randint, 5, 80) # 5-80
toolbox.register("diff_new_conn", random.uniform, 0.0, 1.0) # 0.0-1.0
toolbox.register("netw_prf_n", random.randint, 1, 20) # 1-20
toolbox.register("s_r", random.uniform, 0.01, 0.5) # 0.01-0.5
toolbox.register("S_r", random.uniform, 0.5, 1.0) # 0.5-1.0
toolbox.register("x", random.uniform, 0.0, 1.0) # 0.0-1.0
toolbox.register("k", random.uniform, 0.1, 5.0) # 0.1-5.0
toolbox.register("production_increase_ratio", random.uniform, 1.0, 5.0) # 1.0-5.0
遗传算子配置
- 个体创建:
create_individual()- 随机初始化12个基因 - 种群创建:
create_population(n)- 创建n个个体的种群 - 选择:
tools.selTournament(tournament_size=3)- 锦标赛选择 - 交叉:
tools.cxTwoPoint()- 两点交叉,默认概率0.8 - 变异:
tools.mutShuffleIndexes(indpb=0.2)- 索引洗牌变异,独立变异概率0.2 - 菁英保留:每代自动保留最优个体不参与选择/交叉/变异
4. orm.py - 数据库ORM配置
主要功能:SQLAlchemy ORM配置、表结构定义、数据库连接管理 关键配置:
数据库连接
# 使用NullPool避免连接丢失错误
engine = create_engine(connection_string, poolclass=NullPool)
- 支持MySQL 8.0+
- 数据库名通过
conf_db_prefix.yaml中的前缀确定 - 从
conf_db.yaml读取凭证(主机、端口、用户、密码)
表定义
Experiment 表(实验信息)
id- 主键,实验唯一标识name- 实验名称description- 实验描述created_at- 创建时间
Sample 表(样本信息)
id- 主键experiment_id- 外键,关联Experimentparam_*- ABM参数列(n_max_trial, prf_size等)locked- 样本锁定标志(0/1)locked_at- 锁定时间
Result 表(模拟结果)
id- 主键experiment_id- 外键sample_id- 外键ga_id- GA个体唯一标识(8字符UUID)vulnerable_industry_*- 脆弱产业代码列metric_*- 评估指标列created_at- 结果时间戳
5. controller_db.py - 数据库生命周期管理
主要功能:数据库操作的中间层,样本管理、清理、事务控制 关键方法:
prepare_samples()
- 从数据库中随机获取N个样本
- 对样本进行锁定,防止重复使用
- 返回样本参数列表
get_sample_for_ga(ga_id)
- 为GA个体获取对应的样本
- 从Result表查询已执行结果
cleanup_old_results(keep_best_n)
- 定期清理旧的结果记录
- 保留每个实验的最优结果
- 防止数据库表过大
execute_with_db_session()
- 事务管理
- 支持多进程并发访问
reset_sample_locks()
- 重置所有样本锁定标志
- 允许样本被重新使用
6. my_model.py - ABM模型接口
主要功能:调用外部ABM模型,执行模拟 关键函数:
do_process(sample_params, ga_id, job=6)
- 输入:
sample_params- 样本的ABM参数字典ga_id- GA个体的唯一标识job- 并行进程数(默认6)
- 处理:
- 验证参数有效性
- 启动ABM模型多进程执行
- 实时监控模拟进度
- 收集和整理脆弱产业结果
- 将结果写入数据库Result表
- 输出:
- 返回脆弱产业代码列表
- 同步写入数据库
- 错误处理:
- 参数范围检查
- 进程超时处理
- 数据库事务回滚
功能特性
核心功能
- 遗传算法优化:使用DEAP框架实现完整的遗传算法流程
- 多进程并行计算:支持多进程运行ABM模型,提高计算效率
- 实时结果记录:每代最优个体和适应度实时保存到文件
- 收敛曲线可视化:自动生成算法收敛过程图表
- 产业匹配分析:详细分析模拟结果与目标产业的匹配情况
- 适应度缓存:避免相同参数的重复评估
- 数据库集成:MySQL数据库管理大规模模拟结果
参数优化范围
算法优化以下12个关键ABM模型参数:
n_max_trial- 最大尝试次数 [1, 60]prf_size- 偏好大小 [0.0, 1.0]prf_conn- 偏好连接 [0.0, 1.0]cap_limit_prob_type- 容量限制概率类型 [0, 1]cap_limit_level- 容量限制水平 [5, 80]diff_new_conn- 新连接差异 [0.0, 1.0]netw_prf_n- 网络偏好N [1, 20]s_r- 小r参数 [0.01, 0.5]S_r- 大R参数 [0.5, 1.0]x- X参数 [0.0, 1.0]k- K参数 [0.1, 5.0]production_increase_ratio- 生产增加比率 [1.0, 5.0]
配置文件详细说明
1. config.json - GA超参数配置
文件位置:GA_Agent_0925/config.json
功能:定义遗传算法的执行参数
示例内容:
{
"seed": 42,
"pop_size": 10,
"n_gen": 60,
"cx_prob": 0.8,
"mut_prob": 0.2
}
参数说明:
seed- 随机数种子,确保实验结果可重现性pop_size- 种群大小(每代个体数)n_gen- 进化代数cx_prob- 交叉概率(0-1)mut_prob- 变异概率(0-1)
2. conf_db.yaml - 数据库连接配置
文件位置:GA_Agent_0925/conf_db.yaml
功能:配置MySQL数据库连接信息
示例内容:
host: localhost
port: 3306
user: iiabm_user
password: your_password
说明:
- 确保数据库服务运行在指定主机和端口
- 用户需要有创建表、读写数据的权限
- 密码应妥善保管,不要提交到版本控制
3. conf_db_prefix.yaml - 数据库名前缀配置
文件位置:GA_Agent_0925/conf_db_prefix.yaml
功能:配置数据库名选择策略
配置选项:
without_exp- 不包含实验标识的数据库名with_exp- 包含实验标识的数据库名(带时间戳)test- 测试环境数据库名
4. conf_experiment.yaml - 实验参数配置
文件位置:GA_Agent_0925/conf_experiment.yaml
功能:配置ABM模型实验参数
示例内容:
n_sample: 30
n_iter: 100
meta_seed: 123
n_trial_each_sample: 5
参数说明:
n_sample- 每个GA代数使用的样本数n_iter- 每个样本的迭代次数meta_seed- 元随机数种子n_trial_each_sample- 每个样本的试验次数
5. conf.yaml - 其他全局配置
文件位置:GA_Agent_0925/conf.yaml
功能:系统级配置(日志、输出路径等)
工具脚本详细说明
1. SQL_analysis_risk_ga.sql - 数据库分析SQL脚本
文件位置:GA_Agent_0925/SQL_analysis_risk_ga.sql
主要功能:查询和分析GA优化过程中的结果数据
关键查询:
-- 按GA个体ID筛选结果
SELECT * FROM Result
WHERE ga_id = ?
ORDER BY created_at DESC;
-- 聚合统计脆弱产业
SELECT vulnerable_industry, COUNT(*) as frequency
FROM Result
WHERE experiment_id = ?
GROUP BY vulnerable_industry
ORDER BY frequency DESC;
使用方式:
- 接收参数化输入(ga_id、experiment_id等)
- 与Python脚本配合进行批量查询
- 用于结果分析和验证
2. 多功能.py - 数据分析工具集
文件位置:GA_Agent_0925/多功能.py
主要功能:提供通用数据处理和分析函数
关键函数:
get_vulnerable35_code()
- 从结果中提取脆弱产业代码
- 用于适应度评估
- 处理多种产业代码格式
industry_code_extraction()
- 解析产业编码
- 支持多种编码规范
- 返回标准化的产业代码
analyze_industry_distribution()
- 统计脆弱产业的分布
- 计算各类产业的出现频率
- 生成分析报告
data_cleaning()
- 数据预处理和清理
- 去除重复和异常值
- 格式转换
使用示例:
from 多功能 import get_vulnerable35_code, analyze_industry_distribution
codes = get_vulnerable35_code(result_data)
distribution = analyze_industry_distribution(codes)
3. 绘图.py - 收敛曲线绘制工具
文件位置:GA_Agent_0925/绘图.py
主要功能:可视化GA演化过程
绘制内容:
- 最优适应度曲线 - 每代的最优个体适应度变化
- 平均适应度曲线 - 每代种群的平均适应度
- 移动平均线 - 平滑后的趋势(可选)
- 收敛速度指示 - 显示早期/晚期收敛特征
主要函数:
plot_convergence(history, window_size=5)
- 输入:
history- 每代适应度值列表window_size- 移动平均窗口大小
- 输出:
- 返回图表对象
- 保存为
convergence.png
- 功能:
- 绘制原始数据曲线
- 计算移动平均
- 绘制平滑后的趋势线
- 添加网格和标签
- 保存为高质量图片
plot_multiple_runs(multiple_histories, labels)
- 在同一图表上对比多次运行
- 用于分析算法稳定性
- 显示最好/最差/平均运行结果
配置参数:
- 图表分辨率、线条颜色、字体大小等在脚本中定义
- 支持自定义输出路径
输出目录说明
results/ 目录
自动创建:程序运行时自动创建此目录 内容:
best_individual_each_gen.txt- 每代最优个体的完整历史记录best_result_with_industry.json- 最终最优解和产业匹配详情convergence.png- 收敛曲线可视化图表best_result.csv- 最优解的表格格式
best_individual_each_gen.txt 格式:
实验开始时间:2025-01-15 10:30:45
以下为每一代的最优个体基因参数:
第1代最优基因:[45, 0.632, 0.785, 1, 42, 0.234, 12, 0.15, 0.87, 0.56, 2.1, 3.2]
最优适应度: 0.8500
平均适应度: 0.7234
脆弱产业数: 18
第2代最优基因:[45, 0.632, 0.785, 1, 42, 0.234, 12, 0.15, 0.87, 0.56, 2.1, 3.3]
最优适应度: 0.8650
...
best_result_with_industry.json 格式:
{
"best_individual": [45, 0.632, 0.785, 1, 42, 0.234, 12, 0.15, 0.87, 0.56, 2.1, 3.2],
"best_fitness": 0.8850,
"generation": 60,
"parameter_names": ["n_max_trial", "prf_size", "prf_conn", ...],
"target_vulnerable_industries": [35个产业代码],
"simulated_vulnerable_industries": [模拟得到的脆弱产业],
"matched_industries": [匹配的产业],
"missing_industries": [缺失的产业],
"excess_industries": [多余的产业],
"timestamp": "2025-01-15 11:45:30",
"algorithm_config": {
"seed": 42,
"pop_size": 10,
"n_gen": 60,
"cx_prob": 0.8,
"mut_prob": 0.2
}
}
risk_ay/ 和 vulnerable35_match_results/ 目录
功能:存储中间分析结果和产业匹配详情
risk_ay/- 风险分析相关的数据和图表vulnerable35_match_results/- 脆弱产业匹配的详细分析报告
完整使用指南
第一步:环境配置
1.1 检查依赖包
pip install deap numpy pandas sqlalchemy pymysql pyyaml matplotlib
1.2 配置数据库连接
编辑 conf_db.yaml:
host: your_db_host
port: 3306
user: your_db_user
password: your_db_password
1.3 验证数据库连接
python -c "from orm import engine; engine.connect()"
第二步:参数调优
2.1 GA超参数配置(config.json)
{
"seed": 42,
"pop_size": 10,
"n_gen": 60,
"cx_prob": 0.8,
"mut_prob": 0.2
}
调优建议:
pop_size: 越大收敛越慢但结果越好(建议10-100)n_gen: 代数越多收敛越好但时间越长(建议30-200)cx_prob: 交叉概率通常0.7-0.9mut_prob: 变异概率通常0.1-0.3
2.2 实验参数配置(conf_experiment.yaml)
n_sample: 30
n_iter: 100
meta_seed: 123
n_trial_each_sample: 5
第三步:运行GA优化
3.1 标准运行
cd GA_Agent_0925
python main.py
3.2 监控运行进度
程序输出示例:
=== GA进化进度 ===
代数: 1/60 | 最优适应度: -0.85 | 平均适应度: -0.62 | 脆弱产业: 18/35
代数: 2/60 | 最优适应度: -0.87 | 平均适应度: -0.65 | 脆弱产业: 19/35
...
3.3 暂停和恢复
- 使用 Ctrl+C 暂停程序(中间状态已保存)
- 重新运行
python main.py可以恢复优化
第四步:结果分析
4.1 查看最优解
cat results/best_result_with_industry.json
4.2 分析产业匹配
import json
with open('results/best_result_with_industry.json') as f:
result = json.load(f)
print(f"匹配产业数: {len(result['matched_industries'])}")
print(f"缺失产业数: {len(result['missing_industries'])}")
print(f"多余产业数: {len(result['excess_industries'])}")
4.3 绘制收敛曲线
python 绘图.py
生成 convergence.png 图表
4.4 数据库查询
mysql -h host -u user -p iiabmdb_20250925 < SQL_analysis_risk_ga.sql
适应度函数详解
优化目标
最小化以下误差函数:
误差 = 1.0 - hit_ratio
Hit Ratio 计算
hit_ratio = Σ(1/rank_i) for i in matched_industries
其中 rank_i 是第i个匹配产业在目标35产业中的排名
产业分类定义
35个目标产业分布:
- 光刻机、刻蚀设备、离子注入机等(设备类)
- 光刻胶、特种气体、清洗溶剂等(材料类)
- 12英寸晶圆厂、工艺等(制造类)
- 测试设备、引脚框架等(封装类)
- 设计工具、IP库等(设计类)
评估指标
| 指标 | 说明 | 范围 |
|---|---|---|
| 匹配度 | 模拟脆弱产业与目标重合的数量 | 0-35 |
| Hit Ratio | 加权命中率 | 0.0-1.0 |
| 误差 | 1.0 - Hit Ratio | 0.0-1.0 |
| 适应度 | -误差 | -1.0-0.0 |
数据库操作详解
表关系图
Experiment (1) ----<< (N) Sample
↑ ↑
| |
└──────────────────────┴──→ Result (多条)
每条Result记录:
- experiment_id: 关联实验
- sample_id: 关联样本
- ga_id: 关联GA个体
- 脆弱产业列: 模拟得到的产业
- 指标列: 评估结果
关键SQL查询
查询某GA个体的所有结果
SELECT * FROM Result
WHERE ga_id = '12345678'
ORDER BY created_at;
统计每代产业频率
SELECT vulnerable_industry, COUNT(*) as frequency
FROM Result
WHERE ga_id IN (
SELECT ga_id FROM Result
WHERE generation = 30
)
GROUP BY vulnerable_industry
ORDER BY frequency DESC;
导出最佳结果
SELECT
ga_id,
vulnerable_industry,
error_value,
created_at
FROM Result
WHERE error_value < 0.2
ORDER BY error_value ASC;
多进程并发安全
- 所有数据库操作使用SQLAlchemy会话管理
- 使用NullPool避免连接超时
- Sample表的
locked字段防止并发冲突
故障排除
常见问题与解决方案
| 问题 | 原因 | 解决方案 |
|---|---|---|
| 数据库连接失败 | conf_db.yaml配置错误或数据库离线 | 检查配置、确保数据库运行 |
| 缺少DEAP模块 | 依赖包未安装 | pip install deap |
| 内存不足 | pop_size过大或job过多 | 减少pop_size或job参数 |
| ABM模型错误 | my_model.py实现问题 | 检查my_model.py中do_process()函数 |
| 适应度无法改进 | 参数范围设置不当 | 调整creating.py中的基因范围 |
| 数据库表损坏 | 程序异常中断 | 手动删除表重新初始化 |
| 样本锁定冲突 | 并发访问冲突 | 运行controller_db.reset_sample_locks() |
调试模式
启用详细日志
编辑 main.py 第10行:
import logging
logging.basicConfig(level=logging.DEBUG)
查看详细错误信息
python main.py 2>&1 | tee run.log
单步测试适应度函数
from evaluate_func import fitness
test_individual = [45, 0.632, 0.785, 1, 42, 0.234, 12, 0.15, 0.87, 0.56, 2.1, 3.2]
result = fitness(test_individual)
print(f"适应度值: {result}")
性能优化
加快收敛速度
- 增加交叉概率:
cx_prob: 0.9 - 减少种群大小:
pop_size: 5 - 增加变异概率:
mut_prob: 0.3
减少计算时间
- 降低ABM迭代次数:
n_iter: 50 - 减少样本数:
n_sample: 10 - 增加缓存命中率(使用舍入)
降低内存占用
- 减少
job参数(少于系统核心数) - 定期清理结果表:
cleanup_old_results(keep_best_n=10)
扩展开发指南
添加新的GA参数
步骤1:在creating.py中注册新基因
# 添加新参数 new_param,范围[0, 100]
toolbox.register("new_param", random.randint, 0, 100)
# 更新个体创建函数
def create_individual():
individual = [
# ... 现有12个基因 ...
toolbox.new_param(),
]
return individual
步骤2:在evaluate_func.py中映射参数
def fitness(individual):
# ... 现有代码 ...
params = {
# ... 现有参数 ...
'new_param': int(individual[12]),
}
# 将new_param传给ABM模型
步骤3:更新ABM模型
在 my_model.py 的 do_process() 中使用新参数
修改目标产业集合
编辑evaluate_func.py中的get_target_vulnerable_industries()
def get_target_vulnerable_industries():
return {
# 更新目标产业代码
'IC_001', 'IC_002', # 新增/修改
# ...
}
更换遗传算子
在creating.py中替换算子
# 使用轮盘赌选择替代锦标赛选择
toolbox.register("select", tools.selRoulette)
# 使用均匀交叉替代两点交叉
toolbox.register("mate", tools.cxUniform, indpb=0.5)
# 使用高斯变异替代洗牌变异
toolbox.register("mutate", tools.mutGaussian, mu=0, sigma=0.1, indpb=0.2)
自定义适应度函数
修改fitness计算逻辑
def fitness(individual):
# 自定义适应度计算
# 例如:多目标优化、惩罚项等
simulated = get_simulated_industries(individual)
target = get_target_vulnerable_industries()
# 多目标适应度
match_ratio = len(simulated & target) / len(target)
size_penalty = abs(len(simulated) - len(target)) / len(target)
error = 1.0 - match_ratio + 0.2 * size_penalty
return (-error,)
技术支持与参考
关键文件路径速查表
核心算法: GA_Agent_0925/main.py
参数定义: GA_Agent_0925/creating.py
适应度: GA_Agent_0925/evaluate_func.py
数据库: GA_Agent_0925/orm.py
配置参数: GA_Agent_0925/config.json
输出结果: GA_Agent_0925/results/
许可证
本项目仅供学术研究和半导体供应链韧性研究使用。
技术支持
如有问题或建议,请联系项目维护团队。