遗传算法001

This commit is contained in:
Cricial
2025-10-18 16:16:05 +08:00
parent 91f2122b65
commit dfd5c5b32d
49 changed files with 8473 additions and 1825 deletions

View File

@@ -0,0 +1,330 @@
# -*- coding: utf-8 -*- # 文件的编码格式设置为 UTF-8
from __future__ import division # 为了兼容 Python 2 和 3保证除法始终返回浮点数
import multiprocessing
import random # 导入 random 库,用于生成随机数
from deap import base # 从 DEAP 库导入 base 模块,提供一些遗传算法相关的功能
from deap import creator # 从 DEAP 库导入 creator 模块,用于定义个体和适应度
from deap import tools # 从 DEAP 库导入 tools 模块,提供常用的遗传算法工具(如交叉、变异等)
from my_model import MyModel
from sqlalchemy import text
import pandas as pd
from orm import connection
def main():
random.seed(42) # 可复现结果
print("Start of evolution")
ga = creating()
pop = ga.population(n=50)
CXPB, MUTPB, NGEN = 0.5, 0.2, 200
# # 并行计算
# pool = multiprocessing.Pool()
# ga.register("map", pool.map)
# 改为:
ga.register("map", map) # 单进程
# 评估初始种群
fitnesses = list(ga.map(ga.evaluate, pop))
for ind, fit in zip(pop, fitnesses):
ind.fitness.values = fit
print(f"Evaluated {len(pop)} individuals")
best_log = []
for g in range(NGEN):
print(f"-- Generation {g} --")
# 选择并克隆
offspring = list(map(ga.clone, ga.select(pop, len(pop))))
# 交叉与变异
for child1, child2 in zip(offspring[::2], offspring[1::2]):
if random.random() < CXPB:
ga.mate(child1, child2)
del child1.fitness.values
del child2.fitness.values
for mutant in offspring:
if random.random() < MUTPB:
ga.mutate(mutant)
del mutant.fitness.values
# 重新计算失效适应度
invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
fitnesses = list(ga.map(ga.evaluate, invalid_ind))
for ind, fit in zip(invalid_ind, fitnesses):
ind.fitness.values = fit
pop[:] = offspring
# 最优个体
best_ind = tools.selBest(pop, 1)[0]
best_log.append((g, best_ind.fitness.values[0]))
print(f"Best individual {g}: {best_ind}, Fitness: {best_ind.fitness.values[0]:.3f}")
# 写入数据库
result_sql = text(f"""
INSERT INTO ga (generation, stu_beta, stu_nmb, gtu_mgf, gtu_discount, fitness, remark)
VALUES ({g}, {best_ind[0]}, {best_ind[1]}, {best_ind[2]}, {best_ind[3]}, {best_ind.fitness.values[0]}, 'Random2')
""")
with connection.connect() as conn:
conn.execute(result_sql)
conn.commit()
# pool.close()
# pool.join()
pd.DataFrame(best_log, columns=["generation", "fitness"]).to_csv("ga_log.csv", index=False)
print("-- End of (successful) evolution --")
# 目标函数(适应度函数),用于评估个体的适应度
def fitness(individual):
"""
GA 适应度函数:用于评估个体(模型参数)的效果。
目标:
- individual: 遗传算法中的个体参数列表
[n_max_trial, prf_size, prf_conn, cap_limit_prob_type, cap_limit_level,
diff_new_conn, netw_prf_n, s_r, S_r, x, k, production_increase_ratio]
- target_chain_set: 美国打击的产业链编号集合(整数集合)
适应度定义:
- fitness = -error
- error = 脆弱产业集合与 target_chain_set 的差集大小
"""
# 1 将 GA 生成的个体参数传入 ABM 模型
"""
n_iter
g_bom
seed
sample
dct_lst_init_disrupt_firm_prod
remove_t
"""
dct_exp = {
'n_max_trial': individual[0],
'prf_size': individual[1],
'prf_conn': individual[2],
'cap_limit_prob_type': individual[3],
'cap_limit_level': individual[4],
'diff_new_conn': individual[5],
'netw_prf_n': individual[6],
's_r': individual[7],
'S_r': individual[8],
'x': individual[9],
'k': individual[10],
'production_increase_ratio': individual[11]
}
abm_model = MyModel(**dct_exp)
# 2 运行 ABM获取模拟结果的“脆弱产业集合”
abm_model.step()
abm_model.end()
simulated_vulnerable_industries=get_vulnerable100_code(connection)
# 3 获取目标集合(美国打击我们的产业集合)
target_vulnerable_industries = get_target_vulnerable_industries() # list / set
# 4 计算误差(集合差异度)
# 这里可以用 Jaccard 距离、集合交并比、或者简单的匹配数差
set_sim = set(simulated_vulnerable_industries)
set_target = set(target_vulnerable_industries)
error = len(set_sim.symmetric_difference(set_target)) # 差异元素个数
# 5 返回 fitnessGA 目标是最大化)
# 因为我们希望误差越小越好,所以 fitness = -error
return -error,
def creating():
"""
创建遗传算法工具箱,用于优化 ABM 模型参数,使生成的脆弱产业集合
与目标产业集合误差最小化fitness 最大化)。
"""
if "FitnessMax" not in creator.__dict__:
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
if "Individual" not in creator.__dict__:
creator.create("Individual", list, fitness=creator.FitnessMax)
# 定义最大化适应度
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
# 定义个体类
creator.create("Individual", list, fitness=creator.FitnessMax)
toolbox = base.Toolbox()
# 定义每个基因的取值范围 / 类型及默认值
toolbox.register("n_max_trial", random.randint, 50, 500) # 最大尝试次数 [50,500]
toolbox.register("prf_size", random.uniform, 0.0, 1.0) # 是否规模偏好参数 [0,1]
toolbox.register("prf_conn", random.uniform, 0.0, 1.0) # 是否已有连接偏好 [0,1]
toolbox.register("cap_limit_prob_type", random.randint, 0, 2) # 额外产能分布类型 {0:正态,1:均匀,2:指数}
toolbox.register("cap_limit_level", random.uniform, 0.5, 2.0) # 额外产能均值放缩因子 [0.5,2.0]
toolbox.register("diff_new_conn", random.uniform, 0.0, 1.0) # 新供应关系构成概率 [0,1]
toolbox.register("netw_prf_n", random.randint, 1, 10) # 在网络中选择供应商目标数量 [1,10]
toolbox.register("s_r", random.uniform, 0.1, 0.5) # 补货下阈值 [0.1,0.5]
toolbox.register("S_r", random.uniform, 0.5, 1.0) # 补货上阈值 [0.5,1.0]
toolbox.register("x", random.uniform, 0.0, 0.1) # 每周期减少残值 [0.0,0.1]
toolbox.register("k", random.uniform, 0.1, 1.0) # 资源消耗比例 [0.1,1.0]
toolbox.register("production_increase_ratio", random.uniform, 0.5, 2.0) # 产品生产比例 [0.5,2.0]
# 个体由上述基因组成
toolbox.register(
"individual",
tools.initCycle,
creator.Individual,
(
toolbox.n_max_trial,
toolbox.prf_size,
toolbox.prf_conn,
toolbox.cap_limit_prob_type,
toolbox.cap_limit_level,
toolbox.diff_new_conn,
toolbox.netw_prf_n,
toolbox.s_r,
toolbox.S_r,
toolbox.x,
toolbox.k,
toolbox.production_increase_ratio
),
n=1
)
# 种群初始化
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
# 注册 fitness 函数(需要在调用时传入目标产业集合)
toolbox.register("evaluate", fitness) # 可以在 main 中使用 lambda 包装 target_chain_set
# 交叉、变异和选择操作
toolbox.register("mate", tools.cxTwoPoint)
toolbox.register("mutate", tools.mutShuffleIndexes, indpb=0.1)
toolbox.register("select", tools.selTournament, tournsize=3)
return toolbox
def get_target_vulnerable_industries():
"""
获取行业列表中所有产业链编号的集合(整数形式)。
说明:
- 输入的 industry_list 是一个字典列表,每个字典包含:
{"product": 产品名称, "category": 产品类别, "chain_id": 产业链编号}
- 某些 chain_id 可能是复合编号,例如 "11 / 513742",需要拆分成单独整数。
- 输出是一个 set包含所有 chain_id去重、整数形式
参数:
industry_list : list of dict
行业字典列表,每个字典必须包含 "chain_id" 键。
返回:
set
所有产业链编号的整数集合。
"""
industry_list = [
# ① 半导体设备类
{"product": "离子注入机", "category": "离子注入设备", "chain_id": 34538},
{"product": "刻蚀设备 / 湿法刻蚀设备", "category": "刻蚀机", "chain_id": 34529},
{"product": "沉积设备", "category": "薄膜生长设备CVD/PVD", "chain_id": 34539},
{"product": "CVD", "category": "薄膜生长设备", "chain_id": 34539},
{"product": "PVD", "category": "薄膜生长设备", "chain_id": 34539},
{"product": "CMP", "category": "化学机械抛光设备", "chain_id": 34530},
{"product": "光刻机", "category": "光刻机", "chain_id": 34533},
{"product": "涂胶显影机", "category": "涂胶显影设备", "chain_id": 34535},
{"product": "晶圆清洗设备", "category": "晶圆清洗机", "chain_id": 34531},
{"product": "测试设备", "category": "测试机", "chain_id": 34554},
{"product": "外延生长设备", "category": "薄膜生长设备", "chain_id": 34539},
# ② 半导体材料与化学品类
{"product": "三氯乙烯", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
{"product": "丙酮", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
{"product": "异丙醇", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
{"product": "其他醇类", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
{"product": "光刻胶", "category": "光刻胶及配套试剂", "chain_id": 32445},
{"product": "显影液", "category": "显影液", "chain_id": 46504},
{"product": "蚀刻液", "category": "蚀刻液", "chain_id": 56341},
{"product": "光阻去除剂", "category": "光阻去除剂", "chain_id": 32442},
# ③ 晶圆制造类
{"product": "晶圆", "category": "单晶硅片 / 多晶硅片", "chain_id": 32338},
{"product": "硅衬底", "category": "硅衬底", "chain_id": 36914},
{"product": "外延片", "category": "硅外延片 / GaN外延片 / SiC外延片等", "chain_id": 32338},
# ④ 封装与测试类
{"product": "封装", "category": "IC封装", "chain_id": 10},
{"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 513742},
{"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 11},
# ⑤ 芯片与设计EDA类
{"product": "芯片(通用)", "category": "集成电路制造", "chain_id": 317589},
{"product": "DRAM", "category": "存储芯片 → 集成电路制造", "chain_id": 317589},
{"product": "GPU", "category": "图形芯片 → 集成电路制造", "chain_id": 317589},
{"product": "处理器CPU/SoC", "category": "芯片设计", "chain_id": 9},
{"product": "高频芯片", "category": "芯片设计", "chain_id": 9},
{"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 9},
{"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 2717},
{"product": "先进节点制造设备", "category": "集成电路制造", "chain_id": 317589},
{"product": "EDA及IP服务", "category": "设计辅助", "chain_id": 2515},
{"product": "MPW服务", "category": "多项目晶圆流片", "chain_id": 2514},
{"product": "芯片设计验证", "category": "设计验证", "chain_id": 513738},
{"product": "过程工艺检测", "category": "制程检测", "chain_id": 513740}
]
# 提取所有 chain_id并去重
chain_ids = set()
for item in industry_list:
# 如果 chain_id 是字符串包含多个编号,用逗号或斜杠拆分
if isinstance(item["chain_id"], str):
for cid in item["chain_id"].replace("/", ",").split(","):
chain_ids.add(cid.strip())
else:
chain_ids.add(str(item["chain_id"]))
return chain_ids
def get_vulnerable100_code(connection):
"""
计算最脆弱前100产品的 Code 列表(去重)。
参数:
connection: 数据库连接对象,用于执行 SQL
返回:
List[int]: 最脆弱前100产品对应的 Code 列表
"""
# 读取映射表
bom_file = r"../../input_data/input_product_data/BomNodes.csv" # 直接给出路径
mapping_df = pd.read_csv(bom_file)
# 执行 SQL 获取结果
with open("../../SQL_analysis_risk.sql", "r", encoding="utf-8") as f:
str_sql = text(f.read())
result = pd.read_sql(sql=str_sql, con=connection)
# 统计每个 (id_firm, id_product) 出现次数
count_firm_prod = result.value_counts(subset=['id_firm', 'id_product'])
count_firm_prod.name = 'count'
count_firm_prod = count_firm_prod.to_frame().reset_index()
# 统计每个 id_product 的总 count
count_prod = (
count_firm_prod
.groupby("id_product")["count"]
.sum()
.reset_index()
)
# 按 count 升序取最脆弱前100 id_product
vulnerable100_index = count_prod.nsmallest(100, "count")["id_product"].tolist()
# 映射 Index -> Code 并去重
index_to_code = dict(zip(mapping_df["Index"], mapping_df["Code"]))
vulnerable100_code = list({index_to_code[i] for i in vulnerable100_index if i in index_to_code})
return vulnerable100_code
if __name__ == "__main__":
main()