330 lines
15 KiB
Python
330 lines
15 KiB
Python
import os
|
||
import time
|
||
import traceback
|
||
import uuid
|
||
from datetime import datetime
|
||
from multiprocessing import Process
|
||
|
||
import pandas as pd
|
||
from sqlalchemy import text
|
||
from sqlalchemy.orm import sessionmaker
|
||
|
||
from my_model import MyModel
|
||
from orm import connection, engine
|
||
|
||
|
||
# 🎯 适应度函数(核心目标函数)
|
||
def fitness(individual, controller_db_obj, job=4):
    """Genetic-algorithm fitness: score one individual (a set of ABM parameters).

    The individual's genes are mapped to ABM parameters, the ABM is run in
    ``job`` worker processes, and the "vulnerable industry" set read back from
    the database is compared with the hand-curated target set.

    Parameters
    ----------
    individual : sequence
        Gene list, in order:
        [n_max_trial, prf_size, prf_conn, cap_limit_prob_type, cap_limit_level,
         diff_new_conn, netw_prf_n, s_r, S_r, x, k, production_increase_ratio].
        It must also accept attribute assignment, because a ``ga_id``
        attribute is attached to it.
    controller_db_obj :
        Controller forwarded to the worker processes (it must provide
        ``fetch_a_sample`` / ``lock_the_sample`` as used by run_ABM_samples).
    job : int, optional
        Number of parallel ABM worker processes (default 4, as before).

    Returns
    -------
    tuple[float]
        ``(-error,)``, where ``error`` is the size of the symmetric difference
        between the simulated and target sets. Larger fitness is better.
    """
    # Short random id that tags this evaluation's results in the database.
    ga_id = str(uuid.uuid4())[:8]
    individual.ga_id = ga_id  # bind the id to the individual

    # ---- 1. Map genes to ABM parameters (order must match the gene layout) ----
    gene_names = (
        'n_max_trial', 'prf_size', 'prf_conn', 'cap_limit_prob_type',
        'cap_limit_level', 'diff_new_conn', 'netw_prf_n', 's_r', 'S_r',
        'x', 'k', 'production_increase_ratio',
    )
    # Index explicitly so a too-short individual fails loudly (IndexError)
    # instead of being silently truncated as zip() would do.
    dct_exp = {name: individual[i] for i, name in enumerate(gene_names)}
    # Gene value 0 selects the "uniform" capacity-limit distribution,
    # any other value selects "normal".
    dct_exp['cap_limit_prob_type'] = (
        "uniform" if dct_exp['cap_limit_prob_type'] == 0 else "normal"
    )

    print(f"\n 正在执行 GA 个体 {ga_id},参数如下:")

    # ---- 2. Run the ABM in `job` processes (do_process joins them all) ----
    do_process(controller_db_obj, ga_id, dct_exp, job)

    # ---- 3. Read the simulated vulnerable set back from the database ----
    simulated_vulnerable_industries = get_vulnerable35_code(connection, ga_id)
    print(simulated_vulnerable_industries)

    # ---- 4. Target set ----
    target_vulnerable_industries = get_target_vulnerable_industries()

    # ---- 5. Error = size of the set difference ----
    # Compare as strings: the target codes are string literals, while the
    # simulated ids come from a DB column whose type is not visible here
    # (possibly integers; verify). Without normalisation, an int/str mix
    # would never match and every individual would get the same error.
    simulated_set = {str(code) for code in simulated_vulnerable_industries}
    target_set = {str(code) for code in target_vulnerable_industries}

    matching = simulated_set & target_set   # intersection
    extra = simulated_set - target_set      # simulated but not in the target
    missing = target_set - simulated_set    # target codes not reproduced
    # Equivalent to len(simulated_set ^ target_set).
    error = len(extra) + len(missing)

    print(f"符合目标的产业数量: {len(matching)}")
    print(matching)
    print(f"模拟多出的产业数量: {len(extra)}")
    print(f"未覆盖目标产业数量: {len(missing)}")

    # ---- 6. Fitness (larger is better), as a DEAP-style 1-tuple ----
    return (float(-error),)
|
||
# 目标产业集合
|
||
# Reference table from which the target codes below were converted by hand
# (the original code labelled that list "手工转换", i.e. manual conversion).
# It is NOT read at runtime and is kept only to document provenance.
_INDUSTRY_REFERENCE = (
    # ① Semiconductor equipment
    {"product": "离子注入机", "category": "离子注入设备", "chain_id": 34538},
    {"product": "刻蚀设备 / 湿法刻蚀设备", "category": "刻蚀机", "chain_id": 34529},
    {"product": "沉积设备", "category": "薄膜生长设备(CVD/PVD)", "chain_id": 34539},
    {"product": "CVD", "category": "薄膜生长设备", "chain_id": 34539},
    {"product": "PVD", "category": "薄膜生长设备", "chain_id": 34539},
    {"product": "CMP", "category": "化学机械抛光设备", "chain_id": 34530},
    {"product": "光刻机", "category": "光刻机", "chain_id": 34533},
    {"product": "涂胶显影机", "category": "涂胶显影设备", "chain_id": 34535},
    {"product": "晶圆清洗设备", "category": "晶圆清洗机", "chain_id": 34531},
    {"product": "测试设备", "category": "测试机", "chain_id": 34554},
    {"product": "外延生长设备", "category": "薄膜生长设备", "chain_id": 34539},

    # ② Semiconductor materials and chemicals
    {"product": "三氯乙烯", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "丙酮", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "异丙醇", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "其他醇类", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "光刻胶", "category": "光刻胶及配套试剂", "chain_id": 32445},
    {"product": "显影液", "category": "显影液", "chain_id": 46504},
    {"product": "蚀刻液", "category": "蚀刻液", "chain_id": 56341},
    {"product": "光阻去除剂", "category": "光阻去除剂", "chain_id": 32442},

    # ③ Wafer manufacturing
    {"product": "晶圆", "category": "单晶硅片 / 多晶硅片", "chain_id": 32338},
    {"product": "硅衬底", "category": "硅衬底", "chain_id": 36914},
    {"product": "外延片", "category": "硅外延片 / GaN外延片 / SiC外延片等", "chain_id": 32338},

    # ④ Packaging and testing
    {"product": "封装", "category": "IC封装", "chain_id": 10},
    {"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 513742},
    {"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 11},

    # ⑤ Chips and design / EDA
    {"product": "芯片(通用)", "category": "集成电路制造", "chain_id": 317589},
    {"product": "DRAM", "category": "存储芯片 → 集成电路制造", "chain_id": 317589},
    {"product": "GPU", "category": "图形芯片 → 集成电路制造", "chain_id": 317589},
    {"product": "处理器(CPU/SoC)", "category": "芯片设计", "chain_id": 9},
    {"product": "高频芯片", "category": "芯片设计", "chain_id": 9},
    {"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 9},
    {"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 2717},
    {"product": "先进节点制造设备", "category": "集成电路制造", "chain_id": 317589},
    {"product": "EDA及IP服务", "category": "设计辅助", "chain_id": 2515},
    {"product": "MPW服务", "category": "多项目晶圆流片", "chain_id": 2514},
    {"product": "芯片设计验证", "category": "设计验证", "chain_id": 513738},
    {"product": "过程工艺检测", "category": "制程检测", "chain_id": 513740},
)

# Hand-converted target product codes (strings) that the simulated
# vulnerable set is compared against.
_TARGET_VULNERABLE_CODES = (
    '61', '68', '65', '66', '59', '71', '77', '7', '38', '95', '58', '90', '56', '57', '97', '98',
    '27', '8', '11', '9', '21', '96', '99', '100', '44', '45', '46', '47', '48', '49', '50', '51',
    '52', '53', '54', '55',
)


def get_target_vulnerable_industries():
    """Return the hand-curated target list of vulnerable product codes.

    Takes no arguments. The codes were converted by hand from the reference
    table ``_INDUSTRY_REFERENCE``; that table is not used at runtime.

    Returns
    -------
    list[str]
        A new list on every call (36 unique string codes), so callers may
        mutate it freely.

    NOTE(review): there are 36 target codes, while get_vulnerable35_code
    returns at most 35 products by default, so an error of 0 appears
    unreachable. Confirm this is intended.
    """
    return list(_TARGET_VULNERABLE_CODES)
|
||
|
||
# 从数据库计算脆弱产业集合
|
||
def get_vulnerable35_code(engine, ga_id, top_n=35):
    """Return the ``top_n`` most frequently occurring product ids for one GA run.

    Runs the SQL in ``../GA_Agent_0925/SQL_analysis_risk_ga.sql`` (relative to
    this file) with the bound parameter ``ga_id``. It then counts how often
    each (id_firm, id_product) pair appears in the result, aggregates those
    counts per firm and per product, and returns the ids of the ``top_n``
    products with the highest totals.

    Side effects: writes ``count_firm_prod.csv``, ``count_firm.csv`` and
    ``count_prod.csv`` (UTF-8 with BOM) into the current working directory.

    Parameters
    ----------
    engine :
        SQLAlchemy bind (engine or connection) for the session.
    ga_id : str
        GA individual id; passed to the SQL as ``:ga_id``.
    top_n : int, optional
        Number of products to return (default 35, matching the name).

    Returns
    -------
    list
        ``id_product`` values ordered by descending count (``nlargest``
        keeps the first occurrence among ties).
    """
    # ---- 1. Load the SQL text ----
    base_dir = os.path.dirname(os.path.abspath(__file__))
    sql_file = os.path.abspath(
        os.path.join(base_dir, "..", "GA_Agent_0925", "SQL_analysis_risk_ga.sql")
    )
    with open(sql_file, "r", encoding="utf-8") as f:
        str_sql = f.read()

    print(f"[信息] 正在查询 ga_id={ga_id} 的脆弱产品数据...")

    # ---- 2. Query with a fresh session, closed by the context manager ----
    # (The previous version also created an extra Session() that was never
    # used or closed.)
    Session = sessionmaker(bind=engine)
    with Session() as session:
        result = pd.read_sql(
            sql=text(str_sql),
            con=session.connection(),
            params={"ga_id": ga_id},  # bound parameter, not string-formatted
        )

    # ---- 3. Occurrences of each (firm, product) pair ----
    count_firm_prod = result.value_counts(subset=['id_firm', 'id_product'])
    count_firm_prod.name = 'count'
    count_firm_prod = count_firm_prod.to_frame().reset_index()
    count_firm_prod.to_csv('count_firm_prod.csv', index=False, encoding='utf-8-sig')

    # ---- 4. Total occurrences per firm ----
    count_firm = count_firm_prod.groupby('id_firm')['count'].sum().reset_index()
    count_firm.sort_values('count', ascending=False, inplace=True)
    count_firm.to_csv('count_firm.csv', index=False, encoding='utf-8-sig')

    # ---- 5. Total occurrences per product ----
    count_prod = count_firm_prod.groupby('id_product')['count'].sum().reset_index()
    count_prod.sort_values('count', ascending=False, inplace=True)
    count_prod.to_csv('count_prod.csv', index=False, encoding='utf-8-sig')

    # ---- 6. The top_n most frequent products ----
    vulnerable_products = count_prod.nlargest(top_n, "count")["id_product"].tolist()
    print(f"[信息] ga_id={ga_id} 查询完成,共找到 {len(vulnerable_products)} 个脆弱产品")

    return vulnerable_products
|
||
|
||
def run_ABM_samples(controller_db_obj, ga_id, dct_exp, str_code="GA",):
    """Fetch one available sample, lock it, and run the ABM model on it.

    Parameters
    ----------
    controller_db_obj :
        Controller providing ``fetch_a_sample(s_id=None)`` and
        ``lock_the_sample(sample)``.
    ga_id : str
        GA individual id, passed to ``MyModel.end``.
    dct_exp : dict
        GA parameters. They override the experiment's columns of the same name.
    str_code : str, optional
        Label kept for interface compatibility. It is currently unused.

    Returns
    -------
    bool
        True when no sample is available (the caller should stop).
        False after a sample was processed, including when the model raised:
        such errors are printed and swallowed so the worker keeps going.
    """
    # 1. Fetch one sample; None means there is nothing left to run.
    sample_random = controller_db_obj.fetch_a_sample(s_id=None)
    if sample_random is None:
        return True

    # 2. Lock it.
    controller_db_obj.lock_the_sample(sample_random)

    # 3. All experiment columns except the primary key 'id'.
    experiment = sample_random.experiment
    dct_exp_new = {column: getattr(experiment, column)
                   for column in experiment.__table__.c.keys()}
    dct_exp_new.pop('id', None)
    dct_exp_new = {'sample': sample_random,
                   'seed': sample_random.seed,
                   **dct_exp_new}

    try:
        # GA parameters win over experiment columns on key collisions.
        dct_sample_para = {**dct_exp_new, **dct_exp}

        # The model is run from the parent directory of this file.
        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        os.chdir(project_root)

        abm_model = MyModel(dct_sample_para)
        abm_model.step()
        abm_model.end(ga_id)
    except Exception as e:
        # Best-effort: report and continue with the next sample.
        print(f"[❌ ABM运行错误] 错误:{e} ")
        traceback.print_exc()

    # Explicit False (previously an implicit None) to match the documented contract.
    return False
|
||
|
||
def do_computation(controller_db_obj, ga_id, dct_exp,):
    """Worker loop: keep running ABM samples until none are left.

    Announces the worker's process id once, then repeatedly calls
    run_ABM_samples until it reports (truthy return) that there is nothing
    more to do.
    """
    worker_pid = os.getpid()
    print(f"[启动] 进程 {worker_pid} 已启动 (PID={worker_pid})")

    finished = False
    while not finished:
        finished = run_ABM_samples(controller_db_obj, ga_id, dct_exp,)
|
||
|
||
def do_process(controller_db, ga_id, dct_exp, job):
    """Start ``job`` worker processes running do_computation and wait for all.

    Each worker is started as soon as it is created. The call returns only
    after every worker has been joined.
    """
    def _launch():
        # Create one worker and start it immediately.
        worker = Process(target=do_computation, args=(controller_db, ga_id, dct_exp,))
        worker.start()
        return worker

    workers = [_launch() for _ in range(job)]

    for worker in workers:
        worker.join()
|