mesa-GA/GA_Agent_0925/evaluate_func.py

331 lines
15 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import time
import traceback
import uuid
from datetime import datetime
from multiprocessing import Process
import pandas as pd
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker
from my_model import MyModel
from orm import connection, engine
# 🎯 适应度函数(核心目标函数)
def fitness(individual, controller_db_obj, job=6):
    """Evaluate one GA individual (a vector of ABM parameters).

    The individual's genes are mapped onto ABM parameter names. The ABM is
    then run over every available sample in ``job`` worker processes. The
    products the simulation flags as most vulnerable are compared with the
    hard-coded target set.

    Parameters
    ----------
    individual : sequence
        Genes in the order
        [n_max_trial, prf_size, prf_conn, cap_limit_prob_type,
        cap_limit_level, diff_new_conn, netw_prf_n, s_r, S_r, x, k,
        production_increase_ratio].
        The object must also accept attribute assignment: a ``ga_id``
        attribute is set on it.
    controller_db_obj :
        Database controller handed to every worker process. It must provide
        ``fetch_a_sample`` and ``lock_the_sample``.
    job : int, optional
        Number of worker processes used to run the ABM. Defaults to 6, the
        value that used to be hard-coded.

    Returns
    -------
    tuple[float]
        ``(-error,)``. ``error`` is the size of the symmetric difference
        between the simulated and the target sets, so larger is better.
    """
    # Short random id for this evaluation. It is passed to the workers
    # (and on to the model's end()) and used to query the results back.
    ga_id = str(uuid.uuid4())[:8]
    individual.ga_id = ga_id

    # ---- 1. Map genes to ABM parameter names (positional order matters) ----
    param_names = (
        'n_max_trial', 'prf_size', 'prf_conn', 'cap_limit_prob_type',
        'cap_limit_level', 'diff_new_conn', 'netw_prf_n', 's_r', 'S_r',
        'x', 'k', 'production_increase_ratio',
    )
    # Index explicitly so that a too-short individual still raises IndexError.
    dct_exp = {name: individual[i] for i, name in enumerate(param_names)}

    # The chromosome encodes the capacity-limit distribution as 0 / non-zero.
    dct_exp['cap_limit_prob_type'] = (
        "uniform" if dct_exp['cap_limit_prob_type'] == 0 else "normal"
    )

    print(f"\n 正在执行 GA 个体 {ga_id},参数如下:")

    # ---- 2. Run the ABM in parallel worker processes ----
    do_process(controller_db_obj, ga_id, dct_exp, job)

    # ---- 3. Read back the simulated vulnerable products of this run ----
    simulated_vulnerable_industries = get_vulnerable35_code(connection, ga_id)
    print(simulated_vulnerable_industries)

    # ---- 4. Target product set ----
    target_vulnerable_industries = get_target_vulnerable_industries()

    # ---- 5. Set difference ----
    # The target codes are hard-coded strings. The type of the ids coming
    # from the DB is not visible here, so compare both sides as strings;
    # otherwise 61 and '61' would never match.
    simulated_set = {str(code) for code in simulated_vulnerable_industries}
    target_set = {str(code) for code in target_vulnerable_industries}
    matching = simulated_set & target_set
    extra = simulated_set - target_set       # simulated but not targeted
    missing = target_set - simulated_set     # targeted but not simulated
    # |A symmetric-difference B| == |A - B| + |B - A|
    error = len(extra) + len(missing)

    print(f"符合目标的产业数量: {len(matching)}")
    print(matching)
    print(f"模拟多出的产业数量: {len(extra)}")
    print(f"未覆盖目标产业数量: {len(missing)}")

    # ---- 6. Fitness: higher is better ----
    return (float(-error),)
# 目标产业集合
# Reference table of target products and their industry-chain ids. It is not
# read at runtime. It is kept as provenance for the hand-converted codes in
# _TARGET_PRODUCT_CODES, which were presumably derived from it
# (TODO confirm the mapping).
_TARGET_INDUSTRY_REFERENCE = (
    # (1) Semiconductor equipment
    {"product": "离子注入机", "category": "离子注入设备", "chain_id": 34538},
    {"product": "刻蚀设备 / 湿法刻蚀设备", "category": "刻蚀机", "chain_id": 34529},
    {"product": "沉积设备", "category": "薄膜生长设备CVD/PVD", "chain_id": 34539},
    {"product": "CVD", "category": "薄膜生长设备", "chain_id": 34539},
    {"product": "PVD", "category": "薄膜生长设备", "chain_id": 34539},
    {"product": "CMP", "category": "化学机械抛光设备", "chain_id": 34530},
    {"product": "光刻机", "category": "光刻机", "chain_id": 34533},
    {"product": "涂胶显影机", "category": "涂胶显影设备", "chain_id": 34535},
    {"product": "晶圆清洗设备", "category": "晶圆清洗机", "chain_id": 34531},
    {"product": "测试设备", "category": "测试机", "chain_id": 34554},
    {"product": "外延生长设备", "category": "薄膜生长设备", "chain_id": 34539},
    # (2) Semiconductor materials and chemicals
    {"product": "三氯乙烯", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "丙酮", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "异丙醇", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "其他醇类", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "光刻胶", "category": "光刻胶及配套试剂", "chain_id": 32445},
    {"product": "显影液", "category": "显影液", "chain_id": 46504},
    {"product": "蚀刻液", "category": "蚀刻液", "chain_id": 56341},
    {"product": "光阻去除剂", "category": "光阻去除剂", "chain_id": 32442},
    # (3) Wafer manufacturing
    {"product": "晶圆", "category": "单晶硅片 / 多晶硅片", "chain_id": 32338},
    {"product": "硅衬底", "category": "硅衬底", "chain_id": 36914},
    {"product": "外延片", "category": "硅外延片 / GaN外延片 / SiC外延片等", "chain_id": 32338},
    # (4) Packaging and testing
    {"product": "封装", "category": "IC封装", "chain_id": 10},
    {"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 513742},
    {"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 11},
    # (5) Chips, chip design and EDA
    {"product": "芯片(通用)", "category": "集成电路制造", "chain_id": 317589},
    {"product": "DRAM", "category": "存储芯片 → 集成电路制造", "chain_id": 317589},
    {"product": "GPU", "category": "图形芯片 → 集成电路制造", "chain_id": 317589},
    {"product": "处理器CPU/SoC", "category": "芯片设计", "chain_id": 9},
    {"product": "高频芯片", "category": "芯片设计", "chain_id": 9},
    {"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 9},
    {"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 2717},
    {"product": "先进节点制造设备", "category": "集成电路制造", "chain_id": 317589},
    {"product": "EDA及IP服务", "category": "设计辅助", "chain_id": 2515},
    {"product": "MPW服务", "category": "多项目晶圆流片", "chain_id": 2514},
    {"product": "芯片设计验证", "category": "设计验证", "chain_id": 513738},
    {"product": "过程工艺检测", "category": "制程检测", "chain_id": 513740},
)

# Hand-converted product codes that make up the target vulnerable set.
# Stored as strings; the fitness comparison works on these string values.
_TARGET_PRODUCT_CODES = (
    '61', '68', '65', '66', '59', '71', '77', '7', '38', '95', '58', '90',
    '56', '57', '97', '98', '27', '8', '11', '9', '21', '96', '99', '100',
    '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55',
)


def get_target_vulnerable_industries():
    """Return the hard-coded target set of vulnerable product codes.

    Takes no arguments.

    Returns
    -------
    list[str]
        The 36 target product codes as strings, in a fixed order and without
        duplicates. A new list is built on every call, so callers may mutate
        it without affecting later calls.
    """
    return list(_TARGET_PRODUCT_CODES)
# 从数据库计算脆弱产业集合
def get_vulnerable35_code(engine, ga_id, top_n=35):
    """Return the ``top_n`` most frequently occurring product ids for one GA run.

    The function:

    1. Runs ``SQL_analysis_risk_ga.sql`` with ``ga_id`` bound as a parameter.
    2. Counts how often each (id_firm, id_product) pair occurs.
    3. Sums those counts per firm and per product.
    4. Writes the three tables as CSV to the current working directory.
    5. Returns the ``top_n`` products with the highest totals.

    Parameters
    ----------
    engine :
        SQLAlchemy bind (e.g. an Engine or a Connection) that
        ``sessionmaker(bind=...)`` accepts.
    ga_id : str
        Identifier of the GA individual whose results are queried.
    top_n : int, optional
        Number of products to return. Defaults to 35, the previously
        hard-coded value.

    Returns
    -------
    list
        ``id_product`` values, highest count first (ordering/ties as in
        ``DataFrame.nlargest``).

    Side effects
    ------------
    Writes count_firm_prod.csv, count_firm.csv and count_prod.csv to the
    current working directory and prints progress messages.
    """
    # 1. Load the analysis query; the path is resolved relative to this file.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    sql_file = os.path.abspath(
        os.path.join(base_dir, "..", "GA_Agent_0925", "SQL_analysis_risk_ga.sql")
    )
    with open(sql_file, "r", encoding="utf-8") as f:
        str_sql = f.read()

    print(f"[信息] 正在查询 ga_id={ga_id} 的脆弱产品数据...")

    # 2. Run it in a fresh session. The context manager guarantees the
    #    session is closed, so no session is left dangling.
    Session = sessionmaker(bind=engine)
    with Session() as session:
        result = pd.read_sql(
            sql=text(str_sql),
            con=session.connection(),
            params={"ga_id": ga_id},  # bound parameter, not string-formatted
        )

    # 3. Occurrences of each firm-product pair.
    count_firm_prod = result.value_counts(subset=['id_firm', 'id_product'])
    count_firm_prod.name = 'count'
    count_firm_prod = count_firm_prod.to_frame().reset_index()
    count_firm_prod.to_csv('count_firm_prod.csv', index=False, encoding='utf-8-sig')

    # 4. Total occurrences per firm.
    count_firm = count_firm_prod.groupby('id_firm')['count'].sum().reset_index()
    count_firm.sort_values('count', ascending=False, inplace=True)
    count_firm.to_csv('count_firm.csv', index=False, encoding='utf-8-sig')

    # 5. Total occurrences per product.
    count_prod = count_firm_prod.groupby('id_product')['count'].sum().reset_index()
    count_prod.sort_values('count', ascending=False, inplace=True)
    count_prod.to_csv('count_prod.csv', index=False, encoding='utf-8-sig')

    # 6. The most frequently occurring products.
    vulnerable_products = count_prod.nlargest(top_n, "count")["id_product"].tolist()
    print(f"[信息] ga_id={ga_id} 查询完成,共找到 {len(vulnerable_products)} 个脆弱产品")
    return vulnerable_products
def run_ABM_samples(controller_db_obj, ga_id, dct_exp, str_code="GA"):
    """Fetch one sample, lock it, and run the ABM on it with the GA parameters.

    Parameters
    ----------
    controller_db_obj :
        Database controller exposing ``fetch_a_sample(s_id=...)`` and
        ``lock_the_sample(sample)``.
    ga_id : str
        Id of the GA individual; passed to ``MyModel.end``.
    dct_exp : dict
        GA-chosen parameters. They override same-named columns of the
        sample's experiment row.
    str_code : str, optional
        Label for this worker; currently unused.

    Returns
    -------
    bool
        True when no sample is available, meaning the caller should stop.
        False after a sample was fetched and processed. This includes the
        case where the model raised: the error is printed and the caller
        may try the next sample.
    """
    # 1. Fetch a sample.
    sample_random = controller_db_obj.fetch_a_sample(s_id=None)
    if sample_random is None:
        return True

    # 2. Lock it. NOTE(review): the lock is not released if the model fails
    #    below; confirm that is intended.
    controller_db_obj.lock_the_sample(sample_random)

    # 3. Experiment row columns (minus its primary key), preceded by the
    #    sample itself and its seed; a same-named column overrides them.
    experiment = sample_random.experiment
    dct_exp_new = {column: getattr(experiment, column)
                   for column in experiment.__table__.c.keys()}
    dct_exp_new.pop('id', None)
    dct_exp_new = {'sample': sample_random,
                   'seed': sample_random.seed,
                   **dct_exp_new}

    try:
        # GA parameters take precedence over the experiment's values.
        dct_sample_para = {**dct_exp_new, **dct_exp}
        # The model is run from the directory one level above this file.
        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        os.chdir(project_root)
        abm_model = MyModel(dct_sample_para)
        abm_model.step()
        abm_model.end(ga_id)
    except Exception as e:
        # Best-effort: report and let the caller move on to the next sample.
        print(f"[❌ ABM运行错误] 错误:{e} ")
        traceback.print_exc()
    return False
def do_computation(controller_db_obj, ga_id, dct_exp):
    """Worker-process entry point: process ABM samples until none remain.

    Calls ``run_ABM_samples`` repeatedly and returns once it reports, with a
    truthy value, that there is no sample left.
    """
    worker_pid = os.getpid()
    print(f"[启动] 进程 {worker_pid} 已启动 (PID={worker_pid})")
    finished = False
    while not finished:
        finished = run_ABM_samples(controller_db_obj, ga_id, dct_exp)
def do_process(controller_db, ga_id, dct_exp, job):
    """Run ``do_computation`` in ``job`` parallel processes and wait for all.

    Parameters
    ----------
    controller_db :
        Database controller passed to each worker.
    ga_id : str
        Id of the GA individual being evaluated.
    dct_exp : dict
        GA parameters passed to each worker.
    job : int
        Number of worker processes to start. With 0 or fewer, nothing runs.

    Worker failures are not raised. Any worker that exits with a non-zero
    exit code is reported on stdout, so a partially processed evaluation
    is visible rather than silent.
    """
    workers = []
    for _ in range(job):
        worker = Process(target=do_computation, args=(controller_db, ga_id, dct_exp))
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()

    # After join() every worker has an exit code; non-zero means it died
    # (e.g. an exception outside run_ABM_samples' own error handling).
    failed_codes = [worker.exitcode for worker in workers if worker.exitcode != 0]
    if failed_codes:
        print(f"[警告] ga_id={ga_id}: {len(failed_codes)} worker process(es) "
              f"exited abnormally, exit codes: {failed_codes}")