import os
import time
import traceback
import uuid
from datetime import datetime
from multiprocessing import Process

import pandas as pd
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker

from my_model import MyModel
from orm import connection, engine

# Ordered ABM parameter names for the GA chromosome: gene i of an individual
# becomes dct_exp[_GA_PARAM_NAMES[i]].
_GA_PARAM_NAMES = (
    'n_max_trial',
    'prf_size',
    'prf_conn',
    'cap_limit_prob_type',
    'cap_limit_level',
    'diff_new_conn',
    'netw_prf_n',
    's_r',
    'S_r',
    'x',
    'k',
    'production_increase_ratio',
)

# Reference table of semiconductor products and their industry-chain ids.
# Kept for provenance only: the target codes below were derived from it by
# hand and nothing in this module reads it at runtime.
INDUSTRY_REFERENCE_LIST = [
    # (1) Semiconductor equipment
    {"product": "离子注入机", "category": "离子注入设备", "chain_id": 34538},
    {"product": "刻蚀设备 / 湿法刻蚀设备", "category": "刻蚀机", "chain_id": 34529},
    {"product": "沉积设备", "category": "薄膜生长设备(CVD/PVD)", "chain_id": 34539},
    {"product": "CVD", "category": "薄膜生长设备", "chain_id": 34539},
    {"product": "PVD", "category": "薄膜生长设备", "chain_id": 34539},
    {"product": "CMP", "category": "化学机械抛光设备", "chain_id": 34530},
    {"product": "光刻机", "category": "光刻机", "chain_id": 34533},
    {"product": "涂胶显影机", "category": "涂胶显影设备", "chain_id": 34535},
    {"product": "晶圆清洗设备", "category": "晶圆清洗机", "chain_id": 34531},
    {"product": "测试设备", "category": "测试机", "chain_id": 34554},
    {"product": "外延生长设备", "category": "薄膜生长设备", "chain_id": 34539},
    # (2) Semiconductor materials and chemicals
    {"product": "三氯乙烯", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "丙酮", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "异丙醇", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "其他醇类", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
    {"product": "光刻胶", "category": "光刻胶及配套试剂", "chain_id": 32445},
    {"product": "显影液", "category": "显影液", "chain_id": 46504},
    {"product": "蚀刻液", "category": "蚀刻液", "chain_id": 56341},
    {"product": "光阻去除剂", "category": "光阻去除剂", "chain_id": 32442},
    # (3) Wafer manufacturing
    {"product": "晶圆", "category": "单晶硅片 / 多晶硅片", "chain_id": 32338},
    {"product": "硅衬底", "category": "硅衬底", "chain_id": 36914},
    {"product": "外延片", "category": "硅外延片 / GaN外延片 / SiC外延片等", "chain_id": 32338},
    # (4) Packaging and testing
    {"product": "封装", "category": "IC封装", "chain_id": 10},
    {"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 513742},
    {"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 11},
    # (5) Chips, chip design and EDA
    {"product": "芯片(通用)", "category": "集成电路制造", "chain_id": 317589},
    {"product": "DRAM", "category": "存储芯片 → 集成电路制造", "chain_id": 317589},
    {"product": "GPU", "category": "图形芯片 → 集成电路制造", "chain_id": 317589},
    {"product": "处理器(CPU/SoC)", "category": "芯片设计", "chain_id": 9},
    {"product": "高频芯片", "category": "芯片设计", "chain_id": 9},
    {"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 9},
    {"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 2717},
    {"product": "先进节点制造设备", "category": "集成电路制造", "chain_id": 317589},
    {"product": "EDA及IP服务", "category": "设计辅助", "chain_id": 2515},
    {"product": "MPW服务", "category": "多项目晶圆流片", "chain_id": 2514},
    {"product": "芯片设计验证", "category": "设计验证", "chain_id": 513738},
    {"product": "过程工艺检测", "category": "制程检测", "chain_id": 513740},
]

# Target vulnerable-industry codes, converted by hand (as strings).
TARGET_VULNERABLE_CODES = (
    '61', '68', '65', '66', '59', '71', '77', '7', '38', '95', '58', '90',
    '56', '57', '97', '98', '27', '8', '11', '9', '21', '96', '99', '100',
    '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55',
)


def _as_code_set(codes):
    """Return *codes* as a set of stripped strings.

    The target codes are strings, while ids read back from the database may
    be integers. Comparing raw values would then never match, so both sides
    are normalised here before any set arithmetic. Integral floats (e.g.
    ``61.0``) are rendered without the decimal part.
    """
    normalized = set()
    for code in codes:
        if isinstance(code, float) and code.is_integer():
            code = int(code)
        normalized.add(str(code).strip())
    return normalized


# Fitness function (the GA objective)
def fitness(individual, controller_db_obj, job=4):
    """Evaluate one GA individual (a candidate set of ABM parameters).

    The genes are mapped to named ABM parameters. The ABM is then run in
    ``job`` worker processes. The resulting set of vulnerable products is
    compared with the target set from ``get_target_vulnerable_industries``.

    Parameters
    ----------
    individual : sequence
        Gene values in the order given by ``_GA_PARAM_NAMES``:
        [n_max_trial, prf_size, prf_conn, cap_limit_prob_type,
        cap_limit_level, diff_new_conn, netw_prf_n, s_r, S_r, x, k,
        production_increase_ratio]. A ``ga_id`` attribute is assigned to it,
        so it must accept attribute assignment (a plain ``list`` does not;
        presumably a DEAP ``creator.Individual`` — verify).
    controller_db_obj :
        Sample controller forwarded to the worker processes. It must provide
        ``fetch_a_sample`` and ``lock_the_sample``.
    job : int, optional
        Number of worker processes used to run ABM samples (default 4).

    Returns
    -------
    tuple[float]
        ``(-error,)``, where ``error`` is the number of codes in the
        symmetric difference between the simulated and target sets. Higher
        fitness is better.
    """
    # Short random id that tags this evaluation's results in the database.
    ga_id = str(uuid.uuid4())[:8]
    individual.ga_id = ga_id

    # 1. Map the chromosome onto named ABM parameters. Indexing (rather than
    #    zip) keeps the IndexError for chromosomes that are too short.
    dct_exp = {name: individual[i] for i, name in enumerate(_GA_PARAM_NAMES)}
    # cap_limit_prob_type is a categorical gene: 0 -> "uniform", else "normal".
    dct_exp['cap_limit_prob_type'] = (
        "uniform" if dct_exp['cap_limit_prob_type'] == 0 else "normal"
    )

    print(f"\n 正在执行 GA 个体 {ga_id},参数如下:")

    # 2. Run the ABM samples in parallel. Blocks until all workers have exited.
    do_process(controller_db_obj, ga_id, dct_exp, job)

    # 3. Read back the simulated vulnerable products for this ga_id.
    simulated_vulnerable_industries = get_vulnerable35_code(connection, ga_id)
    print(simulated_vulnerable_industries)

    # 4. Target industry set.
    target_vulnerable_industries = get_target_vulnerable_industries()

    # 5. Compare both sides as normalised string codes.
    simulated_set = _as_code_set(simulated_vulnerable_industries)
    target_set = _as_code_set(target_vulnerable_industries)
    matching = simulated_set & target_set   # codes hit by the simulation
    extra = simulated_set - target_set      # simulated but not targeted
    missing = target_set - simulated_set    # targeted but not simulated
    # |A ^ B| == |A - B| + |B - A|
    error = len(extra) + len(missing)

    print(f"符合目标的产业数量: {len(matching)}")
    print(matching)
    print(f"模拟多出的产业数量: {len(extra)}")
    print(f"未覆盖目标产业数量: {len(missing)}")

    # 6. Fitness (larger is better).
    return (float(-error),)


# Target industry set
def get_target_vulnerable_industries():
    """Return the target vulnerable-industry codes.

    The codes were converted by hand from ``INDUSTRY_REFERENCE_LIST``. They
    are strings, not integers. A new list is returned on each call, so
    callers may mutate it freely.

    Returns
    -------
    list[str]
        The codes in ``TARGET_VULNERABLE_CODES`` order.
    """
    return list(TARGET_VULNERABLE_CODES)


# Compute the vulnerable-product set from the database
def get_vulnerable35_code(engine, ga_id, top_n=35):
    """Return the ``top_n`` most frequently affected product ids for *ga_id*.

    The query in ``../GA_Agent_0925/SQL_analysis_risk_ga.sql`` (relative to
    this file) is run with ``ga_id`` bound as a parameter. Occurrences are
    then counted per (firm, product), per firm and per product. As a side
    effect, the intermediate counts are written to ``count_firm_prod.csv``,
    ``count_firm.csv`` and ``count_prod.csv`` in the current working
    directory.

    Parameters
    ----------
    engine :
        SQLAlchemy bind passed to ``sessionmaker``.
    ga_id : str
        GA individual id whose results are queried.
    top_n : int, optional
        Number of products to return (default 35).

    Returns
    -------
    list
        Up to ``top_n`` ``id_product`` values, ordered by descending count.
    """
    # 1. Read the SQL template.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    sql_file = os.path.abspath(
        os.path.join(base_dir, "..", "GA_Agent_0925", "SQL_analysis_risk_ga.sql")
    )
    with open(sql_file, "r", encoding="utf-8") as f:
        str_sql = f.read()

    print(f"[信息] 正在查询 ga_id={ga_id} 的脆弱产品数据...")

    # 2. Run the query on a fresh session. The context manager closes it, so
    #    no connection is leaked between GA evaluations.
    Session = sessionmaker(bind=engine)
    with Session() as session:
        result = pd.read_sql(
            sql=text(str_sql),
            con=session.connection(),
            params={"ga_id": ga_id},  # bound parameter, not string-formatted
        )

    # 3. Occurrences of each (firm, product) pair.
    count_firm_prod = result.value_counts(subset=['id_firm', 'id_product'])
    count_firm_prod.name = 'count'
    count_firm_prod = count_firm_prod.to_frame().reset_index()
    count_firm_prod.to_csv('count_firm_prod.csv', index=False, encoding='utf-8-sig')

    # 4. Total occurrences per firm.
    count_firm = count_firm_prod.groupby('id_firm')['count'].sum().reset_index()
    count_firm.sort_values('count', ascending=False, inplace=True)
    count_firm.to_csv('count_firm.csv', index=False, encoding='utf-8-sig')

    # 5. Total occurrences per product.
    count_prod = count_firm_prod.groupby('id_product')['count'].sum().reset_index()
    count_prod.sort_values('count', ascending=False, inplace=True)
    count_prod.to_csv('count_prod.csv', index=False, encoding='utf-8-sig')

    # 6. The most vulnerable products are the ones that occur most often.
    vulnerable_products = count_prod.nlargest(top_n, "count")["id_product"].tolist()
    print(f"[信息] ga_id={ga_id} 查询完成,共找到 {len(vulnerable_products)} 个脆弱产品")

    return vulnerable_products


def run_ABM_samples(controller_db_obj, ga_id, dct_exp, str_code="GA"):
    """Fetch one sample, lock it and run the ABM on it with the GA parameters.

    Parameters
    ----------
    controller_db_obj :
        Sample controller providing ``fetch_a_sample`` and
        ``lock_the_sample``.
    ga_id : str
        Id of the GA individual; passed to ``MyModel.end``.
    dct_exp : dict
        GA parameters. They override the sample's experiment columns with the
        same name.
    str_code : str, optional
        Label for the run; currently unused.

    Returns
    -------
    bool
        ``True`` when no sample is available (the worker should stop).
        ``False`` after a sample was processed. This includes a failed ABM
        run, which is printed rather than raised so the worker can move on.
        NOTE(review): a failed sample stays locked — confirm this is intended.
    """
    # 1. Fetch a sample from the database.
    sample_random = controller_db_obj.fetch_a_sample(s_id=None)
    if sample_random is None:
        return True  # nothing left to run

    # 2. Lock the sample.
    controller_db_obj.lock_the_sample(sample_random)

    # 3. Copy every column of the sample's experiment row, except its own
    #    primary key.
    experiment = sample_random.experiment
    dct_exp_new = {
        column: getattr(experiment, column)
        for column in experiment.__table__.c.keys()
    }
    dct_exp_new.pop('id', None)
    dct_exp_new = {'sample': sample_random, 'seed': sample_random.seed, **dct_exp_new}

    try:
        # GA parameters take precedence over the experiment's stored values.
        dct_sample_para = {**dct_exp_new, **dct_exp}
        # Switch to the project root (parent of this file's directory), where
        # the ABM is run from.
        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        os.chdir(project_root)
        abm_model = MyModel(dct_sample_para)
        abm_model.step()
        abm_model.end(ga_id)
    except Exception as e:
        # Best-effort: report the failure and let the worker continue.
        print(f"[❌ ABM运行错误] 错误:{e} ")
        traceback.print_exc()
    return False


def do_computation(controller_db_obj, ga_id, dct_exp):
    """Worker-process loop: run ABM samples until none remain."""
    pid = os.getpid()
    print(f"[启动] 进程 {pid} 已启动 (PID={pid})")
    while True:
        if run_ABM_samples(controller_db_obj, ga_id, dct_exp):
            break


def do_process(controller_db, ga_id, dct_exp, job):
    """Start ``job`` worker processes running ``do_computation`` and wait for all.

    Worker exit codes are not checked. The arguments are passed to the child
    processes, so under a "spawn" start method they must be picklable
    (NOTE(review): confirm for the controller object).
    """
    process_list = []
    for _ in range(job):
        p = Process(target=do_computation, args=(controller_db, ga_id, dct_exp))
        p.start()
        process_list.append(p)
    for p in process_list:
        p.join()