遗传算法001

This commit is contained in:
Cricial
2025-10-18 16:16:05 +08:00
parent 91f2122b65
commit dfd5c5b32d
49 changed files with 8473 additions and 1825 deletions

View File

@@ -0,0 +1,327 @@
import os
import time
import traceback
import uuid
from datetime import datetime
from multiprocessing import Process
import pandas as pd
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker
from my_model import MyModel
from orm import connection, engine
# 🎯 适应度函数(核心目标函数)
def fitness(individual, controller_db_obj):
"""
遗传算法适应度函数:用于评估个体(模型参数)的优劣。
参数:
individual : list
个体参数列表:
[n_max_trial, prf_size, prf_conn, cap_limit_prob_type, cap_limit_level,
diff_new_conn, netw_prf_n, s_r, S_r, x, k, production_increase_ratio]
目标:
使 ABM 模型生成的“脆弱产业集合”与目标产业集合尽可能相似。
- fitness = -error
- error = 模拟结果集合与目标集合的差异度(越小越好)
"""
# 生成唯一 GA ID
ga_id = str(uuid.uuid4())[:8] # 简短随机ID
individual.ga_id = ga_id # 将 ga_id 绑定到个体上
# ========== 1⃣ 生成参数字典 ==========
dct_exp = {
'n_max_trial': individual[0],
'prf_size': individual[1],
'prf_conn': individual[2],
'cap_limit_prob_type': individual[3],
'cap_limit_level': individual[4],
'diff_new_conn': individual[5],
'netw_prf_n': individual[6],
's_r': individual[7],
'S_r': individual[8],
'x': individual[9],
'k': individual[10],
'production_increase_ratio': individual[11]
}
# 将 GA 染色体的值映射为实际 ABM 参数
if dct_exp['cap_limit_prob_type'] == 0:
dct_exp['cap_limit_prob_type'] = "uniform" # 类型A例如 uniform
else:
dct_exp['cap_limit_prob_type'] = "normal" # 类型B例如 normal
# 打印 GA ID 和参数
print(f"\n 正在执行 GA 个体 {ga_id},参数如下:")
# for key, value in dct_exp.items():
# print(f" {key}: {value}")
# ========== 2⃣ 调用 ABM 模型 ==========
# 并行进程数目
job=6
do_process(controller_db_obj,ga_id,dct_exp,job)
# ========== 3⃣ 获取数据库连接并提取结果 ==========
simulated_vulnerable_industries = get_vulnerable100_code(connection,ga_id)
print(simulated_vulnerable_industries)
# ========== 4⃣ 获取目标产业集合 ==========
target_vulnerable_industries = get_target_vulnerable_industries()
# ========== 5⃣ 计算误差(集合差异度) ==========
set_sim = set(simulated_vulnerable_industries)
set_target = set(target_vulnerable_industries)
error = len(set_sim.symmetric_difference(set_target))
simulated_set = set(simulated_vulnerable_industries)
target_set = set(target_vulnerable_industries)
matching = simulated_set & target_set # 交集
extra = simulated_set - target_set # 模拟多出的产业
missing = target_set - simulated_set # 未覆盖产业
print(f"符合目标的产业数量: {len(matching)}")
print(matching)
print(f"模拟多出的产业数量: {len(extra)}")
print(f"未覆盖目标产业数量: {len(missing)}")
# ========== 6⃣ 返回适应度(越大越好) ==========
return (float(-error),)
# 目标产业集合
def get_target_vulnerable_industries():
"""
获取行业列表中所有产业链编号的集合(整数形式)。
说明:
- 输入的 industry_list 是一个字典列表,每个字典包含:
{"product": 产品名称, "category": 产品类别, "chain_id": 产业链编号}
- 某些 chain_id 可能是复合编号,例如 "11 / 513742",需要拆分成单独整数。
- 输出是一个 set包含所有 chain_id去重、整数形式
参数:
industry_list : list of dict
行业字典列表,每个字典必须包含 "chain_id" 键。
返回:
set
所有产业链编号的整数集合。
"""
industry_list = [
# ① 半导体设备类
{"product": "离子注入机", "category": "离子注入设备", "chain_id": 34538},
{"product": "刻蚀设备 / 湿法刻蚀设备", "category": "刻蚀机", "chain_id": 34529},
{"product": "沉积设备", "category": "薄膜生长设备CVD/PVD", "chain_id": 34539},
{"product": "CVD", "category": "薄膜生长设备", "chain_id": 34539},
{"product": "PVD", "category": "薄膜生长设备", "chain_id": 34539},
{"product": "CMP", "category": "化学机械抛光设备", "chain_id": 34530},
{"product": "光刻机", "category": "光刻机", "chain_id": 34533},
{"product": "涂胶显影机", "category": "涂胶显影设备", "chain_id": 34535},
{"product": "晶圆清洗设备", "category": "晶圆清洗机", "chain_id": 34531},
{"product": "测试设备", "category": "测试机", "chain_id": 34554},
{"product": "外延生长设备", "category": "薄膜生长设备", "chain_id": 34539},
# ② 半导体材料与化学品类
{"product": "三氯乙烯", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
{"product": "丙酮", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
{"product": "异丙醇", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
{"product": "其他醇类", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
{"product": "光刻胶", "category": "光刻胶及配套试剂", "chain_id": 32445},
{"product": "显影液", "category": "显影液", "chain_id": 46504},
{"product": "蚀刻液", "category": "蚀刻液", "chain_id": 56341},
{"product": "光阻去除剂", "category": "光阻去除剂", "chain_id": 32442},
# ③ 晶圆制造类
{"product": "晶圆", "category": "单晶硅片 / 多晶硅片", "chain_id": 32338},
{"product": "硅衬底", "category": "硅衬底", "chain_id": 36914},
{"product": "外延片", "category": "硅外延片 / GaN外延片 / SiC外延片等", "chain_id": 32338},
# ④ 封装与测试类
{"product": "封装", "category": "IC封装", "chain_id": 10},
{"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 513742},
{"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 11},
# ⑤ 芯片与设计EDA类
{"product": "芯片(通用)", "category": "集成电路制造", "chain_id": 317589},
{"product": "DRAM", "category": "存储芯片 → 集成电路制造", "chain_id": 317589},
{"product": "GPU", "category": "图形芯片 → 集成电路制造", "chain_id": 317589},
{"product": "处理器CPU/SoC", "category": "芯片设计", "chain_id": 9},
{"product": "高频芯片", "category": "芯片设计", "chain_id": 9},
{"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 9},
{"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 2717},
{"product": "先进节点制造设备", "category": "集成电路制造", "chain_id": 317589},
{"product": "EDA及IP服务", "category": "设计辅助", "chain_id": 2515},
{"product": "MPW服务", "category": "多项目晶圆流片", "chain_id": 2514},
{"product": "芯片设计验证", "category": "设计验证", "chain_id": 513738},
{"product": "过程工艺检测", "category": "制程检测", "chain_id": 513740}
]
# 手工转换
industry_list_index = vulnerable100_index = \
['100', '58', '61', '9', '7', '98', '57', '8', '65', '68', '66', '38',
'90', '21', '96', '71', '27', '74', '99', '95', '11', '77', '59', '56', '97']
# # 提取所有 chain_id并去重
# chain_ids = set()
# for item in industry_list:
# # 如果 chain_id 是字符串包含多个编号,用逗号或斜杠拆分
# if isinstance(item["chain_id"], str):
# for cid in item["chain_id"].replace("/", ",").split(","):
# chain_ids.add(cid.strip())
# else:
# chain_ids.add(str(item["chain_id"]))
return industry_list_index
# 从数据库计算脆弱产业集合
def get_vulnerable100_code(engine, ga_id):
"""
计算最脆弱前100产品的 Code 列表(去重),只针对指定 ga_id。
"""
# 生成新的 session
Session = sessionmaker(bind=engine)
session = Session()
# 1⃣ 读取 SQL 文件
base_dir = os.path.dirname(os.path.abspath(__file__))
sql_file = os.path.join(base_dir, "..", "GA_Agent_0925", "SQL_analysis_risk_ga.sql")
sql_file = os.path.abspath(sql_file)
with open(sql_file, "r", encoding="utf-8") as f:
str_sql = f.read() # 注意这里是 str_sql不是 tr_sql
print(f"[信息] 正在查询 ga_id={ga_id} 的脆弱产品数据...")
# 2⃣ 执行 SQL 查询
# 2⃣ 新建 connection
# 2⃣ 新建 session每次查询都用新的 session/connection
Session = sessionmaker(bind=engine)
with Session() as session:
# 使用 session.connection() 来保证 SQLAlchemy 执行原生 SQL
result = pd.read_sql(
sql=text(str_sql),
con=session.connection(),
params={"ga_id": ga_id} # 绑定参数
)
# ===============================
# 2⃣ 统计每个企业-产品组合出现次数
# ===============================
count_firm_prod = result.value_counts(subset=['id_firm', 'id_product'])
count_firm_prod.name = 'count'
count_firm_prod = count_firm_prod.to_frame().reset_index()
count_firm_prod.to_csv('count_firm_prod.csv', index=False, encoding='utf-8-sig')
# ===============================
# 3⃣ 统计每个企业出现的总次数
# ===============================
count_firm = count_firm_prod.groupby('id_firm')['count'].sum().reset_index()
count_firm.sort_values('count', ascending=False, inplace=True)
count_firm.to_csv('count_firm.csv', index=False, encoding='utf-8-sig')
# ===============================
# 4⃣ 统计每个产品出现的总次数
# ===============================
count_prod = count_firm_prod.groupby('id_product')['count'].sum().reset_index()
count_prod.sort_values('count', ascending=False, inplace=True)
count_prod.to_csv('count_prod.csv', index=False, encoding='utf-8-sig')
# ===============================
# 5⃣ 选出最脆弱的前100个产品出现次数最多
# ===============================
vulnerable100_product = count_prod.nlargest(100, "count")["id_product"].tolist()
print(f"[信息] ga_id={ga_id} 查询完成,共找到 {len(vulnerable100_product)} 个脆弱产品")
# ===============================
# # 6⃣ 过滤 result只保留前100脆弱产品
# # ===============================
# result_vulnerable100 = result[result['id_product'].isin(vulnerable100_product)].copy()
# print(f"[信息] 筛选后剩余记录数: {len(result_vulnerable100)}")
# # ===============================
# # 7⃣ 构造 DCPDisruption Causing Probability
# # ===============================
# result_dcp_list = []
# for sid, group in result_vulnerable100.groupby('s_id'):
# ts_start = max(group['ts'])
# while ts_start >= 1:
# ts_end = ts_start - 1
# while ts_end >= 0:
# up = group.loc[group['ts'] == ts_end, ['id_firm', 'id_product']]
# down = group.loc[group['ts'] == ts_start, ['id_firm', 'id_product']]
# for _, up_row in up.iterrows():
# for _, down_row in down.iterrows():
# result_dcp_list.append([sid] + up_row.tolist() + down_row.tolist())
# ts_end -= 1
# ts_start -= 1
#
# # 转换为 DataFrame
# result_dcp = pd.DataFrame(result_dcp_list, columns=[
# 's_id', 'up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product'
# ])
#
# # ===============================
# # 8⃣ 统计 DCP 出现次数
# # ===============================
# count_dcp = result_dcp.value_counts(
# subset=['up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product']
# ).reset_index(name='count')
#
# # 保存文件
# count_dcp.to_csv('count_dcp.csv', index=False, encoding='utf-8-sig')
# 输出结果
return vulnerable100_product
def run_ABM_samples(controller_db_obj,ga_id,dct_exp, str_code="GA",):
"""
从数据库获取一个随机样本,锁定它,然后运行模型仿真。
参数:
controller_db: ControllerDB 对象
str_code: 可选标识,用于打印
返回:
True 如果没有可用样本
False 如果成功运行
"""
# 1. 从数据库获取一个随机样本
sample_random = controller_db_obj.fetch_a_sample(s_id=None) # s_id 可根据需要传入
if sample_random is None:
#print(ga_id+"无样本")
return True # 没有样本,返回 True 表示结束
# 2. 锁定该样本
controller_db_obj.lock_the_sample(sample_random)
# print(f"Pid {pid} ({str_code}) is running sample {sample_random.id} at {datetime.now()}")
# print(f"Pid {pid} ({str_code})")
# print(f"[信息] 当前正在运行的 GA 个体 ID: {ga_id}, 时间: {datetime.now()}")
# 3. 获取 experiment 的所有列及其值
dct_exp_new = {column: getattr(sample_random.experiment, column)
for column in sample_random.experiment.__table__.c.keys()}
# 删除不需要的主键 id
dct_exp_new.pop('id', None)
dct_exp_new = {'sample': sample_random,
'seed': sample_random.seed,
**dct_exp_new}
try:
dct_sample_para = {
**dct_exp_new,
**dct_exp}
# 切换工作目录到项目根目录或 ABM 需要的目录
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
os.chdir(project_root)
abm_model = MyModel(dct_sample_para)
abm_model.step()
abm_model.end(ga_id)
except Exception as e:
print(f"[❌ ABM运行错误] 错误:{e} ")
traceback.print_exc()
def do_computation(controller_db_obj,ga_id,dct_exp,):
"""每个进程执行 ABM 样本运行"""
pid = os.getpid()
print(f"[启动] 进程 {pid} 已启动 (PID={pid})")
while 1:
# time.sleep(random.uniform(0, 1))
is_all_done = run_ABM_samples(controller_db_obj,ga_id,dct_exp,)
if is_all_done:
break
def do_process(controller_db,ga_id,dct_exp,job):
process_list = []
for i in range(job):
p = Process(target=do_computation, args=(controller_db,ga_id,dct_exp,))
p.start()
process_list.append(p)
for i in process_list:
i.join()