遗传算法

This commit is contained in:
Cricial
2026-03-15 17:15:30 +08:00
parent 03ac80715f
commit a69e272e43
27 changed files with 934 additions and 191 deletions

View File

@@ -13,7 +13,6 @@ from my_model import MyModel
from orm import connection, engine
# 🎯 适应度函数(核心目标函数)
def fitness(individual, controller_db_obj):
"""
遗传算法适应度函数:用于评估个体(模型参数)的优劣。
@@ -57,7 +56,7 @@ def fitness(individual, controller_db_obj):
# print(f" {key}: {value}")
# ========== 2⃣ 调用 ABM 模型 ==========
# 并行进程数目
job=4
job=6
do_process(controller_db_obj,ga_id,dct_exp,job)
# ========== 3⃣ 获取数据库连接并提取结果 ==========
simulated_vulnerable_industries = get_vulnerable35_code(connection,ga_id)
@@ -65,25 +64,36 @@ def fitness(individual, controller_db_obj):
# ========== 4⃣ 获取目标产业集合 ==========
target_vulnerable_industries = get_target_vulnerable_industries()
# ========== 5⃣ 计算误差(集合差异度) ==========
set_sim = set(simulated_vulnerable_industries)
set_target = set(target_vulnerable_industries)
error = len(set_sim.symmetric_difference(set_target))
"""
Top-K 加权命中误差(越小越好)
simulated_set = set(simulated_vulnerable_industries)
simulated_vulnerable_industries : list[str]
模型输出的产业排序(风险从高到低)
target_vulnerable_industries : list[str] or set[str]
真实脆弱产业集合(无序)
"""
ranked_list = simulated_vulnerable_industries
target_set = set(target_vulnerable_industries)
matching = simulated_set & target_set # 交集
extra = simulated_set - target_set # 模拟多出的产业
missing = target_set - simulated_set # 未覆盖产业
total_weight = 0.0
hit_weight = 0.0
print(f"符合目标的产业数量: {len(matching)}")
print(matching)
print(f"模拟多出的产业数量: {len(extra)}")
print(f"未覆盖目标产业数量: {len(missing)}")
for rank, industry in enumerate(ranked_list, start=1):
# 权重：排名越靠前越大（w = 1 / rank）
w = 1.0 /rank
total_weight += w
# ========== 6⃣ 返回适应度(越大越好) ==========
return (float(-error),)
if industry in target_set:
hit_weight += w
hit_ratio = hit_weight / total_weight if total_weight > 0 else 0.0
error = 1.0 - hit_ratio # GA 里:越小越好
# ---- 调试信息(强烈建议保留) ----
print(f"加权命中率: {hit_ratio:.4f}")
print(f"加权误差: {error:.4f}")
return (error,)
# 目标产业集合
def get_target_vulnerable_industries():
"""
@@ -156,6 +166,20 @@ def get_target_vulnerable_industries():
'27', '8', '11', '9', '21', '96', '99', '100', '44', '45', '46', '47', '48', '49', '50', '51',
'52', '53', '54', '55'
]
# 修改为25年的产业进行比较
merged_list_24 = [
'32', '33', '34', '35', '36', '37',
'38', '39', '41', '42', '43',
'46', '47', '48', '49',
'51', '52', '53', '54', '55',
'56', '57', '58',
'59', '60', '61', '62', '63', '64', '65', '66',
'68', '70', '71', '73', '74', '78', '79',
'90', '91', '92', '93', '94',
'95', '96', '97', '99', '100', '101',
'102', '103', '104', '105', '106', '107', '108', '109'
]
# # 提取所有 chain_id并去重
# chain_ids = set()
# for item in industry_list:
@@ -166,7 +190,7 @@ def get_target_vulnerable_industries():
# else:
# chain_ids.add(str(item["chain_id"]))
return merged_list
return merged_list_24
# 从数据库计算脆弱产业集合
def get_vulnerable35_code(engine, ga_id):
@@ -221,7 +245,7 @@ def get_vulnerable35_code(engine, ga_id):
# ===============================
# 5⃣ 选出最脆弱的前 N 个产品（出现次数最多）；N 由下面 nlargest 的第一个参数决定
# ===============================
vulnerable100_product = count_prod.nlargest(35, "count")["id_product"].tolist()
vulnerable100_product = count_prod.nlargest(64, "count")["id_product"].tolist()
print(f"[信息] ga_id={ga_id} 查询完成,共找到 {len(vulnerable100_product)} 个脆弱产品")
# ===============================