遗传算法
This commit is contained in:
@@ -13,7 +13,6 @@ from my_model import MyModel
|
||||
from orm import connection, engine
|
||||
|
||||
|
||||
# 🎯 适应度函数(核心目标函数)
|
||||
def fitness(individual, controller_db_obj):
|
||||
"""
|
||||
遗传算法适应度函数:用于评估个体(模型参数)的优劣。
|
||||
@@ -57,7 +56,7 @@ def fitness(individual, controller_db_obj):
|
||||
# print(f" {key}: {value}")
|
||||
# ========== 2️⃣ 调用 ABM 模型 ==========
|
||||
# 并行进程数目
|
||||
job=4
|
||||
job=6
|
||||
do_process(controller_db_obj,ga_id,dct_exp,job)
|
||||
# ========== 3️⃣ 获取数据库连接并提取结果 ==========
|
||||
simulated_vulnerable_industries = get_vulnerable35_code(connection,ga_id)
|
||||
@@ -65,25 +64,36 @@ def fitness(individual, controller_db_obj):
|
||||
# ========== 4️⃣ 获取目标产业集合 ==========
|
||||
target_vulnerable_industries = get_target_vulnerable_industries()
|
||||
|
||||
# ========== 5️⃣ 计算误差(集合差异度) ==========
|
||||
set_sim = set(simulated_vulnerable_industries)
|
||||
set_target = set(target_vulnerable_industries)
|
||||
error = len(set_sim.symmetric_difference(set_target))
|
||||
"""
|
||||
Top-K 加权命中误差(越小越好)
|
||||
|
||||
simulated_set = set(simulated_vulnerable_industries)
|
||||
simulated_vulnerable_industries : list[str]
|
||||
模型输出的产业排序(风险从高到低)
|
||||
target_vulnerable_industries : list[str] or set[str]
|
||||
真实脆弱产业集合(无序)
|
||||
"""
|
||||
ranked_list = simulated_vulnerable_industries
|
||||
target_set = set(target_vulnerable_industries)
|
||||
|
||||
matching = simulated_set & target_set # 交集
|
||||
extra = simulated_set - target_set # 模拟多出的产业
|
||||
missing = target_set - simulated_set # 未覆盖产业
|
||||
total_weight = 0.0
|
||||
hit_weight = 0.0
|
||||
|
||||
print(f"符合目标的产业数量: {len(matching)}")
|
||||
print(matching)
|
||||
print(f"模拟多出的产业数量: {len(extra)}")
|
||||
print(f"未覆盖目标产业数量: {len(missing)}")
|
||||
for rank, industry in enumerate(ranked_list, start=1):
|
||||
# 权重:排名越靠前越大(1 / rank)
|
||||
w = 1.0 /rank
|
||||
total_weight += w
|
||||
|
||||
# ========== 6️⃣ 返回适应度(越大越好) ==========
|
||||
return (float(-error),)
|
||||
if industry in target_set:
|
||||
hit_weight += w
|
||||
|
||||
hit_ratio = hit_weight / total_weight if total_weight > 0 else 0.0
|
||||
error = 1.0 - hit_ratio # GA 里:越小越好
|
||||
|
||||
# ---- 调试信息(强烈建议保留) ----
|
||||
print(f"加权命中率: {hit_ratio:.4f}")
|
||||
print(f"加权误差: {error:.4f}")
|
||||
|
||||
return (error,)
|
||||
# 目标产业集合
|
||||
def get_target_vulnerable_industries():
|
||||
"""
|
||||
@@ -156,6 +166,20 @@ def get_target_vulnerable_industries():
|
||||
'27', '8', '11', '9', '21', '96', '99', '100', '44', '45', '46', '47', '48', '49', '50', '51',
|
||||
'52', '53', '54', '55'
|
||||
]
|
||||
# 修改为25年的产业进行比较
|
||||
merged_list_24 = [
|
||||
'32', '33', '34', '35', '36', '37',
|
||||
'38', '39', '41', '42', '43',
|
||||
'46', '47', '48', '49',
|
||||
'51', '52', '53', '54', '55',
|
||||
'56', '57', '58',
|
||||
'59', '60', '61', '62', '63', '64', '65', '66',
|
||||
'68', '70', '71', '73', '74', '78', '79',
|
||||
'90', '91', '92', '93', '94',
|
||||
'95', '96', '97', '99', '100', '101',
|
||||
'102', '103', '104', '105', '106', '107', '108', '109'
|
||||
]
|
||||
|
||||
# # 提取所有 chain_id,并去重
|
||||
# chain_ids = set()
|
||||
# for item in industry_list:
|
||||
@@ -166,7 +190,7 @@ def get_target_vulnerable_industries():
|
||||
# else:
|
||||
# chain_ids.add(str(item["chain_id"]))
|
||||
|
||||
return merged_list
|
||||
return merged_list_24
|
||||
|
||||
# 从数据库计算脆弱产业集合
|
||||
def get_vulnerable35_code(engine, ga_id):
|
||||
@@ -221,7 +245,7 @@ def get_vulnerable35_code(engine, ga_id):
|
||||
# ===============================
|
||||
# 5️⃣ 选出最脆弱的前100个产品(出现次数最多)
|
||||
# ===============================
|
||||
vulnerable100_product = count_prod.nlargest(35, "count")["id_product"].tolist()
|
||||
vulnerable100_product = count_prod.nlargest(64, "count")["id_product"].tolist()
|
||||
print(f"[信息] ga_id={ga_id} 查询完成,共找到 {len(vulnerable100_product)} 个脆弱产品")
|
||||
|
||||
# ===============================
|
||||
|
||||
Reference in New Issue
Block a user