遗传算法001
This commit is contained in:
21
GA_Agent_0925/SQL_analysis_risk_ga.sql
Normal file
21
GA_Agent_0925/SQL_analysis_risk_ga.sql
Normal file
@@ -0,0 +1,21 @@
|
||||
SELECT *
|
||||
FROM (
|
||||
SELECT s_id, id_firm, id_product, MIN(ts) AS ts
|
||||
FROM iiabmdb_20250925.without_exp_result
|
||||
WHERE `status` = 'D'
|
||||
AND ga_id = :ga_id
|
||||
GROUP BY s_id, id_firm, id_product
|
||||
) AS s_disrupt
|
||||
WHERE s_id IN (
|
||||
SELECT s_id
|
||||
FROM (
|
||||
SELECT s_id, id_firm, id_product, MIN(ts) AS ts
|
||||
FROM iiabmdb_20250925.without_exp_result
|
||||
WHERE `status` = 'D'
|
||||
AND ga_id = :ga_id
|
||||
GROUP BY s_id, id_firm, id_product
|
||||
) AS t
|
||||
GROUP BY s_id
|
||||
HAVING COUNT(*) > 1
|
||||
)
|
||||
ORDER BY s_id;
|
||||
BIN
GA_Agent_0925/__pycache__/controller_db.cpython-38.pyc
Normal file
BIN
GA_Agent_0925/__pycache__/controller_db.cpython-38.pyc
Normal file
Binary file not shown.
BIN
GA_Agent_0925/__pycache__/creating.cpython-38.pyc
Normal file
BIN
GA_Agent_0925/__pycache__/creating.cpython-38.pyc
Normal file
Binary file not shown.
BIN
GA_Agent_0925/__pycache__/evaluate_func.cpython-38.pyc
Normal file
BIN
GA_Agent_0925/__pycache__/evaluate_func.cpython-38.pyc
Normal file
Binary file not shown.
BIN
GA_Agent_0925/__pycache__/orm.cpython-38.pyc
Normal file
BIN
GA_Agent_0925/__pycache__/orm.cpython-38.pyc
Normal file
Binary file not shown.
136
GA_Agent_0925/best_result_with_industry.json
Normal file
136
GA_Agent_0925/best_result_with_industry.json
Normal file
@@ -0,0 +1,136 @@
|
||||
{
|
||||
"best_individual": [
|
||||
291,
|
||||
0.24373090607513836,
|
||||
0.5512650768804697,
|
||||
1,
|
||||
0.7859155564218925,
|
||||
0.5993775986748999,
|
||||
3,
|
||||
0.4456107737714353,
|
||||
0.6381237014110205,
|
||||
0.07900061820031135,
|
||||
0.4734481811962107,
|
||||
1.9013725905237802
|
||||
],
|
||||
"best_fitness": -9999.0,
|
||||
"best_per_gen": [
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0
|
||||
],
|
||||
"avg_per_gen": [
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0,
|
||||
-9999.0
|
||||
],
|
||||
"industry_matching": {
|
||||
"matching": [],
|
||||
"extra": [],
|
||||
"missing": [
|
||||
"2515",
|
||||
"34533",
|
||||
"10",
|
||||
"34539",
|
||||
"34529",
|
||||
"513740",
|
||||
"9",
|
||||
"34530",
|
||||
"513742"
|
||||
]
|
||||
}
|
||||
}
|
||||
12
GA_Agent_0925/conf_experiment.yaml
Normal file
12
GA_Agent_0925/conf_experiment.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
# read by ControllerDB
|
||||
|
||||
# run settings
|
||||
meta_seed: 2
|
||||
|
||||
test: # only for test scenarios
|
||||
n_sample: 1
|
||||
n_iter: 100
|
||||
|
||||
not_test: # normal scenarios
|
||||
n_sample: 5
|
||||
n_iter: 10
|
||||
15
GA_Agent_0925/config.json
Normal file
15
GA_Agent_0925/config.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"pop_size": 5,
|
||||
"n_gen": 5,
|
||||
"n_var": 12,
|
||||
"bound_min": -5,
|
||||
"bound_max": 5,
|
||||
"cx_prob": 0.5,
|
||||
"mut_prob": 0.2,
|
||||
"cx_alpha": 0.5,
|
||||
"mut_sigma": 0.1,
|
||||
"mut_indpb": 0.2,
|
||||
"tourn_size": 3,
|
||||
"n_jobs": 1,
|
||||
"seed": 42
|
||||
}
|
||||
393
GA_Agent_0925/controller_db.py
Normal file
393
GA_Agent_0925/controller_db.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from orm import db_session, engine, Base, ins, connection
|
||||
from orm import Experiment, Sample, Result
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy import text
|
||||
import yaml
|
||||
import random
|
||||
import pandas as pd
|
||||
import platform
|
||||
import networkx as nx
|
||||
import json
|
||||
import pickle
|
||||
|
||||
class ControllerDB:
|
||||
is_with_exp: bool
|
||||
dct_parameter = None
|
||||
is_test: bool = None
|
||||
db_name_prefix: str = None
|
||||
reset_flag: int
|
||||
|
||||
lst_saved_s_id: list
|
||||
|
||||
def __init__(self, prefix, reset_flag=0):
|
||||
with open('conf_experiment.yaml') as yaml_file:
|
||||
dct_conf_experiment = yaml.full_load(yaml_file)
|
||||
assert prefix in ['test', 'without_exp', 'with_exp'], "db name not in test, without_exp, with_exp"
|
||||
|
||||
self.is_test = prefix == 'test'
|
||||
self.is_with_exp = False if prefix == 'test' or prefix == 'without_exp' else True
|
||||
self.db_name_prefix = prefix
|
||||
dct_para_in_test = dct_conf_experiment['test'] if self.is_test else dct_conf_experiment['not_test']
|
||||
self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'], **dct_para_in_test}
|
||||
|
||||
# print(self.dct_parameter)
|
||||
# 0, not reset; 1, reset self; 2, reset all
|
||||
self.reset_flag = reset_flag
|
||||
self.is_exist = False
|
||||
self.lst_saved_s_id = []
|
||||
|
||||
self.experiment_data = []
|
||||
self.batch_size = 5000
|
||||
# 根据需求设置每批次的大小
|
||||
|
||||
def init_tables(self):
|
||||
self.fill_experiment_table()
|
||||
self.fill_sample_table()
|
||||
|
||||
def fill_experiment_table(self):
|
||||
firm = pd.read_csv("../input_data/input_firm_data/Firm_amended.csv")
|
||||
firm['Code'] = firm['Code'].astype('string')
|
||||
firm.fillna(0, inplace=True)
|
||||
|
||||
# fill dct_lst_init_disrupt_firm_prod
|
||||
# 存储 公司-在供应链结点的位置.. 0 :‘1.1’
|
||||
if self.is_with_exp:
|
||||
# 对于方差分析时候使用
|
||||
with open('../SQL_export_high_risk_setting.sql', 'r') as f:
|
||||
str_sql = text(f.read())
|
||||
result = pd.read_sql(sql=str_sql, con=connection)
|
||||
result['dct_lst_init_disrupt_firm_prod'] = \
|
||||
result['dct_lst_init_disrupt_firm_prod'].apply(
|
||||
lambda x: pickle.loads(x))
|
||||
list_dct = result['dct_lst_init_disrupt_firm_prod'].to_list()
|
||||
else:
|
||||
# 行索引 (index):这一行在数据帧中的索引值。
|
||||
# 行数据 (row):这一行的数据,是一个 pandas.Series 对象,包含该行的所有列和值。
|
||||
|
||||
# 读取企业与产品关系数据
|
||||
firm_industry = pd.read_csv("../input_data/firm_industry_relation.csv")
|
||||
firm_industry['Firm_Code'] = firm_industry['Firm_Code'].astype('string')
|
||||
|
||||
# 假设已从 BOM 数据构建了 code_to_indices
|
||||
bom_nodes = pd.read_csv("../input_data/input_product_data/BomNodes.csv")
|
||||
code_to_indices = bom_nodes.groupby('Code')['Index'].apply(list).to_dict()
|
||||
|
||||
# 初始化存储映射结果的列表
|
||||
list_dct = []
|
||||
|
||||
# 遍历 firm_industry 数据
|
||||
for _, row in firm_industry.iterrows():
|
||||
firm_code = row['Firm_Code'] # 企业代码
|
||||
product_code = row['Product_Code'] # 原始产品代码
|
||||
|
||||
# 使用 code_to_indices 映射 Product_Code 到 Product_Indices
|
||||
mapped_indices = code_to_indices.get(product_code, []) # 如果找不到则返回空列表
|
||||
|
||||
# 构建企业到产品索引的映射
|
||||
dct = {firm_code: mapped_indices}
|
||||
list_dct.append(dct)
|
||||
|
||||
# fill g_bom
|
||||
# 结点属性值 相当于 图上点的 原始 产品名称
|
||||
bom_nodes = pd.read_csv('../input_data/input_product_data/BomNodes.csv')
|
||||
bom_nodes['Code'] = bom_nodes['Code'].astype(str)
|
||||
bom_nodes.set_index('Index', inplace=True)
|
||||
|
||||
bom_cate_net = pd.read_csv('../input_data/input_product_data/合成结点.csv')
|
||||
g_bom = nx.from_pandas_edgelist(bom_cate_net, source='UPID', target='ID', create_using=nx.MultiDiGraph())
|
||||
# 填充每一个结点 的具体内容 通过 相同的 code 并且通过BomNodes.loc[code].to_dict()字典化 格式类似 格式 { code(0) : {level: 0 ,name: 工业互联网 }}
|
||||
bom_labels_dict = {}
|
||||
for index in g_bom.nodes:
|
||||
try:
|
||||
bom_labels_dict[index] = bom_nodes.loc[index].to_dict()
|
||||
# print(bom_labels_dict[index])
|
||||
except KeyError:
|
||||
print(f"节点 {index} 不存在于 bom_nodes 中")
|
||||
# 分配属性 给每一个结点 获得类似 格式:{1: {'label': 'A', 'value': 10},
|
||||
nx.set_node_attributes(g_bom, bom_labels_dict)
|
||||
# 改为json 格式
|
||||
g_product_js = json.dumps(nx.adjacency_data(g_bom))
|
||||
|
||||
# insert exp
|
||||
df_xv = pd.read_csv(
|
||||
"../input_data/"
|
||||
f"xv_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
|
||||
index_col=None)
|
||||
# read the OA table
|
||||
df_oa = pd.read_csv(
|
||||
"../input_data/"
|
||||
f"oa_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
|
||||
index_col=None)
|
||||
# .shape[1] 列数 .iloc 访问特定的值 而不是标签
|
||||
df_oa = df_oa.iloc[:, 0:df_xv.shape[1]]
|
||||
|
||||
# idx_scenario 是 0 指行 idx_init_removal 指 索引 0.. dct_init_removal 键 code 公司 g_product_js 图的json数据 dct_exp_para 解码 全局参数xv-
|
||||
for idx_scenario, row in df_oa.iterrows():
|
||||
dct_exp_para = {}
|
||||
for idx_col, para_level in enumerate(row):
|
||||
# 处理 NaN 值,替换为默认值(如 0 或其他合适的值)
|
||||
para_level = para_level if not pd.isna(para_level) else 0
|
||||
# 转换为整数
|
||||
para_level = int(para_level)
|
||||
dct_exp_para[df_xv.columns[idx_col]] = \
|
||||
df_xv.iloc[para_level, idx_col]
|
||||
# different initial removal 只会得到 键 和 值
|
||||
for idx_init_removal, dct_init_removal in enumerate(list_dct):
|
||||
self.add_experiment_1(idx_scenario,
|
||||
idx_init_removal,
|
||||
dct_init_removal,
|
||||
g_product_js,
|
||||
**dct_exp_para)
|
||||
print(f"Inserted experiment for scenario {idx_scenario}, "
|
||||
f"init_removal {idx_init_removal}!")
|
||||
self.finalize_insertion()
|
||||
|
||||
def add_experiment_1(self, idx_scenario, idx_init_removal,
|
||||
dct_lst_init_disrupt_firm_prod, g_bom,
|
||||
n_max_trial, prf_size, prf_conn,
|
||||
cap_limit_prob_type, cap_limit_level,
|
||||
diff_new_conn, remove_t, netw_prf_n):
|
||||
e = Experiment(
|
||||
idx_scenario=idx_scenario,
|
||||
idx_init_removal=idx_init_removal,
|
||||
n_sample=int(self.dct_parameter['n_sample']),
|
||||
n_iter=int(self.dct_parameter['n_iter']),
|
||||
dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod,
|
||||
g_bom=g_bom,
|
||||
n_max_trial=n_max_trial,
|
||||
prf_size=prf_size,
|
||||
prf_conn=prf_conn,
|
||||
cap_limit_prob_type=cap_limit_prob_type,
|
||||
cap_limit_level=cap_limit_level,
|
||||
diff_new_conn=diff_new_conn,
|
||||
remove_t=remove_t,
|
||||
netw_prf_n=netw_prf_n
|
||||
)
|
||||
# 这里我们不立即提交,而是先添加到批量保存的队列中
|
||||
self.experiment_data.append(e)
|
||||
|
||||
# 当批量数据达到一定数量时再提交
|
||||
if len(self.experiment_data) >= self.batch_size:
|
||||
self._commit_batch()
|
||||
|
||||
# 辅助方法:批量提交
|
||||
def _commit_batch(self):
|
||||
db_session.bulk_save_objects(self.experiment_data)
|
||||
db_session.commit()
|
||||
self.experiment_data.clear() # 清空队列
|
||||
|
||||
def finalize_insertion(self):
|
||||
if self.experiment_data:
|
||||
self._commit_batch() # 提交剩余的数据
|
||||
|
||||
def fill_sample_table(self):
|
||||
rng = random.Random(self.dct_parameter['meta_seed'])
|
||||
# 根据样本数目 设置 32 位随机整数
|
||||
lst_seed = [
|
||||
rng.getrandbits(32)
|
||||
for _ in range(int(self.dct_parameter['n_sample']))
|
||||
]
|
||||
lst_exp = db_session.query(Experiment).all()
|
||||
|
||||
lst_sample = []
|
||||
for experiment in lst_exp:
|
||||
# idx_sample: 1-50
|
||||
for idx_sample in range(int(experiment.n_sample)):
|
||||
s = Sample(e_id=experiment.id,
|
||||
idx_sample=idx_sample + 1,
|
||||
seed=lst_seed[idx_sample],
|
||||
is_done_flag=-1)
|
||||
lst_sample.append(s)
|
||||
# 每当达到批量大小时提交一次
|
||||
if len(lst_sample) >= self.batch_size:
|
||||
db_session.bulk_save_objects(lst_sample)
|
||||
db_session.commit()
|
||||
print(f'Inserted {len(lst_sample)} samples!')
|
||||
lst_sample.clear() # 清空已提交的样本列表
|
||||
|
||||
# 提交剩余的样本
|
||||
if lst_sample:
|
||||
db_session.bulk_save_objects(lst_sample)
|
||||
db_session.commit()
|
||||
print(f'Inserted {len(lst_sample)} samples!')
|
||||
|
||||
def reset_db(self, force_drop=False):
|
||||
# first, check if tables exist
|
||||
lst_table_obj = [
|
||||
Base.metadata.tables[str_table]
|
||||
for str_table in ins.get_table_names()
|
||||
if str_table.startswith(self.db_name_prefix)
|
||||
]
|
||||
self.is_exist = len(lst_table_obj) > 0
|
||||
if force_drop:
|
||||
self.force_drop_db(lst_table_obj)
|
||||
# while is_exist:
|
||||
# a_table = random.choice(lst_table_obj)
|
||||
# try:
|
||||
# Base.metadata.drop_all(bind=engine, tables=[a_table])
|
||||
# except KeyError:
|
||||
# pass
|
||||
# except OperationalError:
|
||||
# pass
|
||||
# else:
|
||||
# lst_table_obj.remove(a_table)
|
||||
# print(
|
||||
# f"Table {a_table.name} is dropped "
|
||||
# f"for exp: {self.db_name_prefix}!!!"
|
||||
# )
|
||||
# finally:
|
||||
# is_exist = len(lst_table_obj) > 0
|
||||
|
||||
if self.is_exist:
|
||||
print(
|
||||
f"All tables exist. No need to reset "
|
||||
f"for exp: {self.db_name_prefix}."
|
||||
)
|
||||
# change the is_done_flag from 0 to -1
|
||||
# rerun the in-finished tasks
|
||||
self.is_exist_reset_flag_resset_db()
|
||||
# if self.reset_flag > 0:
|
||||
# if self.reset_flag == 2:
|
||||
# sample = db_session.query(Sample).filter(
|
||||
# Sample.is_done_flag == 0)
|
||||
# elif self.reset_flag == 1:
|
||||
# sample = db_session.query(Sample).filter(
|
||||
# Sample.is_done_flag == 0,
|
||||
# Sample.computer_name == platform.node())
|
||||
# else:
|
||||
# raise ValueError('Wrong reset flag')
|
||||
# if sample.count() > 0:
|
||||
# for s in sample:
|
||||
# qry_result = db_session.query(Result).filter_by(
|
||||
# s_id=s.id)
|
||||
# if qry_result.count() > 0:
|
||||
# db_session.query(Result).filter(s_id=s.id).delete()
|
||||
# db_session.commit()
|
||||
# s.is_done_flag = -1
|
||||
# db_session.commit()
|
||||
# print(f"Reset the sample id {s.id} flag from 0 to -1")
|
||||
|
||||
else:
|
||||
# 不存在则重新生成所有的表结构
|
||||
Base.metadata.create_all(bind=engine)
|
||||
self.init_tables()
|
||||
print(
|
||||
f"All tables are just created and initialized "
|
||||
f"for exp: {self.db_name_prefix}."
|
||||
)
|
||||
|
||||
def force_drop_db(self, lst_table_obj):
|
||||
self.is_exist = len(lst_table_obj) > 0
|
||||
while self.is_exist:
|
||||
a_table = random.choice(lst_table_obj)
|
||||
try:
|
||||
Base.metadata.drop_all(bind=engine, tables=[a_table])
|
||||
except KeyError:
|
||||
pass
|
||||
except OperationalError:
|
||||
pass
|
||||
else:
|
||||
lst_table_obj.remove(a_table)
|
||||
print(
|
||||
f"Table {a_table.name} is dropped "
|
||||
f"for exp: {self.db_name_prefix}!!!"
|
||||
)
|
||||
finally:
|
||||
self.is_exist = len(lst_table_obj) > 0
|
||||
|
||||
def is_exist_reset_flag_resset_db(self):
|
||||
if self.reset_flag > 0:
|
||||
if self.reset_flag == 2:
|
||||
sample = db_session.query(Sample).filter(
|
||||
Sample.is_done_flag == 0)
|
||||
elif self.reset_flag == 1:
|
||||
sample = db_session.query(Sample).filter(
|
||||
Sample.is_done_flag == 0,
|
||||
Sample.computer_name == platform.node())
|
||||
else:
|
||||
raise ValueError('Wrong reset flag')
|
||||
if sample.count() > 0:
|
||||
for s in sample:
|
||||
qry_result = db_session.query(Result).filter_by(
|
||||
s_id=s.id)
|
||||
if qry_result.count() > 0:
|
||||
db_session.query(Result).filter(s_id=s.id).delete()
|
||||
db_session.commit()
|
||||
s.is_done_flag = -1
|
||||
db_session.commit()
|
||||
print(f"Reset the sample id {s.id} flag from 0 to -1")
|
||||
|
||||
def prepare_list_sample(self):
|
||||
# 为了符合前面 重置表里面存在 重置本机 或者重置全部 或者不重置的部分 这个部分的 关于样本运行也得重新拿出来
|
||||
# 查找一个风险事件中 50 个样本
|
||||
res = db_session.execute(
|
||||
text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
|
||||
f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
|
||||
)).scalar()
|
||||
# 控制 n_sample数量 作为后面的参数
|
||||
n_sample = 0 if res is None else res
|
||||
# print(f'There are a total of {n_sample} samples.')
|
||||
# 查找 is_done_flag = -1 也就是没有运行的 样本 运行后会改为0
|
||||
res = db_session.execute(
|
||||
text(f"SELECT id FROM {self.db_name_prefix}_sample "
|
||||
f"WHERE is_done_flag = -1"
|
||||
))
|
||||
for row in res:
|
||||
s_id = row[0]
|
||||
self.lst_saved_s_id.append(s_id)
|
||||
|
||||
@staticmethod
|
||||
def select_random_sample(lst_s_id):
|
||||
temp_lst = lst_s_id[:] # 复制列表
|
||||
while temp_lst:
|
||||
s_id = random.choice(temp_lst)
|
||||
temp_lst.remove(s_id) # 从临时列表删除
|
||||
res = db_session.query(Sample).filter(
|
||||
Sample.id == int(s_id),
|
||||
Sample.is_done_flag == -1
|
||||
)
|
||||
if res.count() == 1:
|
||||
return res[0]
|
||||
# 尝试完所有样本都没找到
|
||||
return None
|
||||
|
||||
def fetch_a_sample(self, s_id=None):
|
||||
# 由Computation 调用 返回 sample对象 同时给出 2中 指定访问模式 抓取特定的 样本 通过s_id
|
||||
# 默认访问 flag为-1的 lst_saved_s_id
|
||||
if s_id is not None:
|
||||
res = db_session.query(Sample).filter(Sample.id == int(s_id))
|
||||
if res.count() == 0:
|
||||
return None
|
||||
else:
|
||||
return res[0]
|
||||
|
||||
sample = self.select_random_sample(self.lst_saved_s_id)
|
||||
if sample is not None:
|
||||
return sample
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def lock_the_sample(sample: Sample):
|
||||
sample.is_done_flag, sample.computer_name = 0, platform.node()
|
||||
db_session.commit()
|
||||
|
||||
def reset_sample_db(self):
|
||||
"""
|
||||
将 iiabmdb_20250925.without_exp_sample 表中
|
||||
所有样本的 is_done_flag 更新为 -1
|
||||
"""
|
||||
sql = text("UPDATE iiabmdb_20250925.without_exp_sample SET is_done_flag = -1")
|
||||
db_session.execute(sql)
|
||||
db_session.commit()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("Testing the database connection...")
|
||||
try:
|
||||
controller_db = ControllerDB('test')
|
||||
Base.metadata.create_all(bind=engine)
|
||||
except Exception as e:
|
||||
print("Failed to connect to the database!")
|
||||
print(e)
|
||||
exit(1)
|
||||
4571
GA_Agent_0925/count_dcp.csv
Normal file
4571
GA_Agent_0925/count_dcp.csv
Normal file
File diff suppressed because it is too large
Load Diff
204
GA_Agent_0925/count_firm.csv
Normal file
204
GA_Agent_0925/count_firm.csv
Normal file
@@ -0,0 +1,204 @@
|
||||
id_firm,count
|
||||
214851100,1722
|
||||
3111603340,1654
|
||||
70634828,1548
|
||||
25980377,643
|
||||
340093034,616
|
||||
395736790,607
|
||||
3330358736,567
|
||||
29223617,564
|
||||
303926772,494
|
||||
532328014,488
|
||||
331545755,339
|
||||
2337727838,333
|
||||
2326722141,332
|
||||
3191869223,272
|
||||
728969035,240
|
||||
2327605629,237
|
||||
591350440,232
|
||||
517675473,226
|
||||
2336923756,223
|
||||
16116663,205
|
||||
2349705416,196
|
||||
16210433,168
|
||||
471121089,168
|
||||
2349179532,167
|
||||
1452048,22
|
||||
515770253,15
|
||||
2349345463,10
|
||||
495782506,10
|
||||
300186799,10
|
||||
3312358902,10
|
||||
33822284,10
|
||||
2728939,10
|
||||
420984285,10
|
||||
6,10
|
||||
37873062,10
|
||||
3226664625,10
|
||||
1,10
|
||||
80158773,10
|
||||
78979697,10
|
||||
8,10
|
||||
169978927,9
|
||||
5849940,9
|
||||
3392803162,8
|
||||
11807506,8
|
||||
79938367,8
|
||||
5971532,8
|
||||
2424229017,8
|
||||
314846874,7
|
||||
14913649,6
|
||||
3462551351,5
|
||||
872394725,5
|
||||
35404067,5
|
||||
805940123,5
|
||||
3362063909,5
|
||||
3358892171,5
|
||||
862404568,5
|
||||
6333996,5
|
||||
9032550,5
|
||||
367669349,5
|
||||
3268669333,5
|
||||
950849442,5
|
||||
3226232,5
|
||||
3203980088,5
|
||||
3195293647,5
|
||||
31732840,5
|
||||
961017,5
|
||||
3151377261,5
|
||||
3147958370,5
|
||||
9620005,5
|
||||
10437056,5
|
||||
907433543,5
|
||||
382080545,5
|
||||
598808584,5
|
||||
676597455,5
|
||||
640700057,5
|
||||
644252759,5
|
||||
596368303,5
|
||||
59234665,5
|
||||
648145286,5
|
||||
668539285,5
|
||||
675729777,5
|
||||
578803019,5
|
||||
3118140206,5
|
||||
551856519,5
|
||||
543470507,5
|
||||
688155470,5
|
||||
385766513,5
|
||||
71271700,5
|
||||
733657390,5
|
||||
737770776,5
|
||||
507827038,5
|
||||
756272716,5
|
||||
758879940,5
|
||||
4607820,5
|
||||
441623911,5
|
||||
771821595,5
|
||||
410030851,5
|
||||
38852110,5
|
||||
562681526,5
|
||||
9746245,5
|
||||
26487185,5
|
||||
197362120,5
|
||||
2333843479,5
|
||||
25685135,5
|
||||
247297633,5
|
||||
2448521375,5
|
||||
2353549582,5
|
||||
2353389310,5
|
||||
2352421906,5
|
||||
205960791,5
|
||||
2351592628,5
|
||||
225958786,5
|
||||
2350443114,5
|
||||
2310406050,5
|
||||
2310534839,5
|
||||
2349742676,5
|
||||
2349588257,5
|
||||
2349349655,5
|
||||
2311838590,5
|
||||
2348894245,5
|
||||
2314659369,5
|
||||
2316150629,5
|
||||
2345982379,5
|
||||
2326520912,5
|
||||
2337952436,5
|
||||
203314437,5
|
||||
2334430421,5
|
||||
193814549,5
|
||||
157087137,5
|
||||
1160497810,5
|
||||
1171244159,5
|
||||
3010580773,5
|
||||
29954548,5
|
||||
2989649772,5
|
||||
2978926070,5
|
||||
1208566436,5
|
||||
178452970,5
|
||||
1476953321,5
|
||||
1444449910,5
|
||||
280281699,5
|
||||
16715045,5
|
||||
26895145,5
|
||||
286335813,5
|
||||
28667694,5
|
||||
742704658,4
|
||||
2326903290,4
|
||||
2326655246,4
|
||||
1698501971,4
|
||||
2323069589,4
|
||||
1605495,4
|
||||
696450846,4
|
||||
1247902451,4
|
||||
1253552935,4
|
||||
664591135,4
|
||||
863973253,4
|
||||
1651310523,4
|
||||
8114841,4
|
||||
3135349256,4
|
||||
3145389278,4
|
||||
2348987001,4
|
||||
3420061649,4
|
||||
3031766093,4
|
||||
3196033145,4
|
||||
265133300,4
|
||||
2350544061,4
|
||||
3378606529,4
|
||||
3011933107,4
|
||||
493002466,4
|
||||
290636928,4
|
||||
518871190,4
|
||||
2347561020,4
|
||||
3222664794,4
|
||||
2344471631,4
|
||||
3220049148,4
|
||||
28665295,3
|
||||
78576577,3
|
||||
2341774429,3
|
||||
808524154,3
|
||||
2944593082,3
|
||||
1524794108,3
|
||||
29452962,3
|
||||
3222821993,3
|
||||
13854344,3
|
||||
778745779,3
|
||||
340603317,3
|
||||
762501019,3
|
||||
27042865,3
|
||||
5979030,3
|
||||
189427260,3
|
||||
3429928077,3
|
||||
2382390052,3
|
||||
395739442,3
|
||||
2349746655,3
|
||||
466148111,3
|
||||
643954924,3
|
||||
618469306,3
|
||||
23421122,3
|
||||
2962064709,3
|
||||
308365582,2
|
||||
1717102128,1
|
||||
2959520478,1
|
||||
3449575456,1
|
||||
2346894985,1
|
||||
|
266
GA_Agent_0925/count_firm_prod.csv
Normal file
266
GA_Agent_0925/count_firm_prod.csv
Normal file
@@ -0,0 +1,266 @@
|
||||
id_firm,id_product,count
|
||||
340093034,95,616
|
||||
395736790,95,607
|
||||
29223617,95,564
|
||||
303926772,99,494
|
||||
532328014,99,488
|
||||
331545755,90,339
|
||||
2337727838,90,333
|
||||
2326722141,90,332
|
||||
3191869223,91,272
|
||||
728969035,93,240
|
||||
2327605629,94,237
|
||||
591350440,91,232
|
||||
517675473,92,226
|
||||
2336923756,91,223
|
||||
2349705416,94,196
|
||||
3111603340,53,193
|
||||
3111603340,55,192
|
||||
214851100,54,189
|
||||
214851100,55,188
|
||||
214851100,52,188
|
||||
16116663,92,187
|
||||
214851100,53,187
|
||||
214851100,50,187
|
||||
214851100,51,184
|
||||
3111603340,54,182
|
||||
3111603340,52,181
|
||||
70634828,53,179
|
||||
3111603340,51,178
|
||||
70634828,52,178
|
||||
70634828,54,178
|
||||
3111603340,50,178
|
||||
70634828,51,177
|
||||
70634828,55,171
|
||||
16210433,92,168
|
||||
471121089,93,168
|
||||
70634828,50,168
|
||||
2349179532,93,167
|
||||
25980377,39,115
|
||||
25980377,38,112
|
||||
25980377,43,108
|
||||
214851100,49,106
|
||||
25980377,41,105
|
||||
214851100,47,104
|
||||
3330358736,43,104
|
||||
25980377,40,102
|
||||
3330358736,38,102
|
||||
214851100,46,102
|
||||
25980377,42,101
|
||||
3111603340,47,100
|
||||
214851100,44,98
|
||||
3111603340,45,96
|
||||
3330358736,39,93
|
||||
214851100,48,93
|
||||
3330358736,41,92
|
||||
3330358736,40,92
|
||||
214851100,45,91
|
||||
70634828,47,90
|
||||
3111603340,46,89
|
||||
3111603340,49,89
|
||||
70634828,45,88
|
||||
3111603340,44,86
|
||||
3111603340,48,85
|
||||
3330358736,42,84
|
||||
70634828,46,81
|
||||
70634828,44,80
|
||||
70634828,48,77
|
||||
70634828,49,76
|
||||
515770253,9,15
|
||||
1452048,9,12
|
||||
6,10,10
|
||||
495782506,19,10
|
||||
420984285,16,10
|
||||
37873062,9,10
|
||||
33822284,9,10
|
||||
8,37,10
|
||||
80158773,69,10
|
||||
1,10,10
|
||||
16116663,11,10
|
||||
3392803162,9,8
|
||||
79938367,9,8
|
||||
5971532,9,8
|
||||
14913649,9,6
|
||||
950849442,9,5
|
||||
16116663,10,5
|
||||
382080545,9,5
|
||||
385766513,24,5
|
||||
38852110,10,5
|
||||
410030851,27,5
|
||||
169978927,66,5
|
||||
441623911,79,5
|
||||
4607820,9,5
|
||||
16715045,10,5
|
||||
507827038,10,5
|
||||
9620005,9,5
|
||||
543470507,8,5
|
||||
551856519,33,5
|
||||
562681526,25,5
|
||||
578803019,10,5
|
||||
367669349,31,5
|
||||
35404067,9,5
|
||||
3462551351,13,5
|
||||
214851100,7,5
|
||||
3226664625,28,5
|
||||
3268669333,67,5
|
||||
3312358902,59,5
|
||||
3312358902,79,5
|
||||
2310406050,60,5
|
||||
225958786,11,5
|
||||
205960791,63,5
|
||||
178452970,25,5
|
||||
203314437,22,5
|
||||
197362120,15,5
|
||||
193814549,33,5
|
||||
3358892171,64,5
|
||||
3362063909,10,5
|
||||
2349588257,10,5
|
||||
5849940,26,5
|
||||
59234665,65,5
|
||||
907433543,10,5
|
||||
78979697,74,5
|
||||
733657390,10,5
|
||||
737770776,34,5
|
||||
756272716,32,5
|
||||
758879940,65,5
|
||||
771821595,31,5
|
||||
78979697,61,5
|
||||
961017,23,5
|
||||
70634828,7,5
|
||||
1171244159,32,5
|
||||
1160497810,12,5
|
||||
805940123,72,5
|
||||
862404568,11,5
|
||||
872394725,70,5
|
||||
9032550,34,5
|
||||
71271700,27,5
|
||||
1208566436,62,5
|
||||
596368303,12,5
|
||||
668539285,17,5
|
||||
3226232,10,5
|
||||
598808584,15,5
|
||||
6333996,66,5
|
||||
640700057,9,5
|
||||
644252759,61,5
|
||||
648145286,35,5
|
||||
675729777,10,5
|
||||
1444449910,20,5
|
||||
676597455,68,5
|
||||
688155470,30,5
|
||||
157087137,73,5
|
||||
1476953321,10,5
|
||||
1452048,30,5
|
||||
1452048,11,5
|
||||
3226664625,13,5
|
||||
2349349655,9,5
|
||||
9746245,97,5
|
||||
2424229017,26,5
|
||||
2334430421,73,5
|
||||
2337952436,24,5
|
||||
2352421906,17,5
|
||||
2353389310,15,5
|
||||
2353549582,23,5
|
||||
3010580773,9,5
|
||||
300186799,11,5
|
||||
300186799,10,5
|
||||
29954548,18,5
|
||||
2989649772,71,5
|
||||
2978926070,8,5
|
||||
2448521375,25,5
|
||||
10437056,9,5
|
||||
28667694,10,5
|
||||
247297633,18,5
|
||||
286335813,72,5
|
||||
280281699,60,5
|
||||
2728939,71,5
|
||||
2728939,63,5
|
||||
25685135,11,5
|
||||
26895145,9,5
|
||||
26487185,62,5
|
||||
2349345463,35,5
|
||||
2345982379,67,5
|
||||
2333843479,70,5
|
||||
2348894245,11,5
|
||||
2326520912,10,5
|
||||
2350443114,74,5
|
||||
3118140206,68,5
|
||||
3147958370,12,5
|
||||
314846874,11,5
|
||||
2349742676,33,5
|
||||
3111603340,7,5
|
||||
2349345463,59,5
|
||||
2311838590,97,5
|
||||
3151377261,29,5
|
||||
31732840,29,5
|
||||
2310534839,64,5
|
||||
3195293647,28,5
|
||||
2316150629,10,5
|
||||
3203980088,10,5
|
||||
2314659369,20,5
|
||||
2351592628,10,5
|
||||
11807506,9,4
|
||||
664591135,9,4
|
||||
11807506,36,4
|
||||
3145389278,9,4
|
||||
2347561020,9,4
|
||||
742704658,9,4
|
||||
3196033145,9,4
|
||||
265133300,9,4
|
||||
696450846,9,4
|
||||
3220049148,9,4
|
||||
1247902451,9,4
|
||||
1253552935,9,4
|
||||
2344471631,9,4
|
||||
290636928,9,4
|
||||
1605495,9,4
|
||||
3031766093,9,4
|
||||
3011933107,9,4
|
||||
169978927,9,4
|
||||
863973253,9,4
|
||||
2326903290,9,4
|
||||
2323069589,9,4
|
||||
1698501971,9,4
|
||||
493002466,9,4
|
||||
8114841,9,4
|
||||
1651310523,9,4
|
||||
2350544061,9,4
|
||||
2348987001,9,4
|
||||
518871190,9,4
|
||||
3420061649,9,4
|
||||
5849940,36,4
|
||||
3378606529,9,4
|
||||
3222664794,9,4
|
||||
2326655246,9,4
|
||||
3135349256,9,4
|
||||
778745779,9,3
|
||||
762501019,9,3
|
||||
2349746655,10,3
|
||||
808524154,9,3
|
||||
2382390052,9,3
|
||||
2424229017,9,3
|
||||
78576577,9,3
|
||||
3222821993,9,3
|
||||
16116663,9,3
|
||||
2962064709,9,3
|
||||
189427260,9,3
|
||||
340603317,9,3
|
||||
3429928077,9,3
|
||||
395739442,9,3
|
||||
466148111,9,3
|
||||
2341774429,9,3
|
||||
13854344,9,3
|
||||
23421122,9,3
|
||||
5979030,9,3
|
||||
618469306,9,3
|
||||
29452962,9,3
|
||||
2944593082,9,3
|
||||
643954924,9,3
|
||||
28665295,9,3
|
||||
1524794108,9,3
|
||||
27042865,9,3
|
||||
308365582,9,2
|
||||
314846874,9,2
|
||||
2959520478,9,1
|
||||
1717102128,9,1
|
||||
3449575456,9,1
|
||||
2346894985,9,1
|
||||
|
73
GA_Agent_0925/count_prod.csv
Normal file
73
GA_Agent_0925/count_prod.csv
Normal file
@@ -0,0 +1,73 @@
|
||||
id_product,count
|
||||
95,1787
|
||||
90,1004
|
||||
99,982
|
||||
91,727
|
||||
92,581
|
||||
93,575
|
||||
53,559
|
||||
55,551
|
||||
54,549
|
||||
52,547
|
||||
51,539
|
||||
50,533
|
||||
94,433
|
||||
9,338
|
||||
47,294
|
||||
45,275
|
||||
46,272
|
||||
49,271
|
||||
44,264
|
||||
48,255
|
||||
38,214
|
||||
43,212
|
||||
39,208
|
||||
41,197
|
||||
40,194
|
||||
42,185
|
||||
10,113
|
||||
11,45
|
||||
15,15
|
||||
12,15
|
||||
33,15
|
||||
7,15
|
||||
25,15
|
||||
69,10
|
||||
74,10
|
||||
68,10
|
||||
70,10
|
||||
71,10
|
||||
72,10
|
||||
73,10
|
||||
19,10
|
||||
79,10
|
||||
8,10
|
||||
66,10
|
||||
18,10
|
||||
17,10
|
||||
16,10
|
||||
13,10
|
||||
97,10
|
||||
67,10
|
||||
32,10
|
||||
65,10
|
||||
64,10
|
||||
34,10
|
||||
35,10
|
||||
37,10
|
||||
31,10
|
||||
30,10
|
||||
29,10
|
||||
28,10
|
||||
27,10
|
||||
26,10
|
||||
24,10
|
||||
23,10
|
||||
20,10
|
||||
59,10
|
||||
60,10
|
||||
62,10
|
||||
63,10
|
||||
61,10
|
||||
36,8
|
||||
22,5
|
||||
|
56
GA_Agent_0925/creating.py
Normal file
56
GA_Agent_0925/creating.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import random
|
||||
from deap import creator, base, tools
|
||||
from evaluate_func import fitness
|
||||
|
||||
def creating():
|
||||
if "FitnessMax" not in creator.__dict__:
|
||||
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
|
||||
if "Individual" not in creator.__dict__:
|
||||
creator.create("Individual", list, fitness=creator.FitnessMax)
|
||||
|
||||
toolbox = base.Toolbox()
|
||||
|
||||
# 基因注册
|
||||
toolbox.register("n_max_trial", random.randint, 1, 40)
|
||||
toolbox.register("prf_size", random.uniform, 0.0, 1.0)
|
||||
toolbox.register("prf_conn", random.uniform, 0.0, 1.0)
|
||||
toolbox.register("cap_limit_prob_type", random.randint, 0, 1)
|
||||
toolbox.register("cap_limit_level", random.randint, 5, 50)
|
||||
toolbox.register("diff_new_conn", random.uniform, 0.0, 1.0)
|
||||
toolbox.register("netw_prf_n", random.randint, 1, 20)
|
||||
toolbox.register("s_r", random.uniform, 0.05, 0.5)
|
||||
toolbox.register("S_r", random.uniform, 0.5, 1.0)
|
||||
toolbox.register("x", random.uniform, 0.0, 1)
|
||||
toolbox.register("k", random.uniform, 0.05, 2.0)
|
||||
toolbox.register("production_increase_ratio", random.uniform, 0.5, 2.0)
|
||||
|
||||
# 个体与种群注册
|
||||
toolbox.register(
|
||||
"individual",
|
||||
tools.initCycle,
|
||||
creator.Individual,
|
||||
(
|
||||
toolbox.n_max_trial,
|
||||
toolbox.prf_size,
|
||||
toolbox.prf_conn,
|
||||
toolbox.cap_limit_prob_type,
|
||||
toolbox.cap_limit_level,
|
||||
toolbox.diff_new_conn,
|
||||
toolbox.netw_prf_n,
|
||||
toolbox.s_r,
|
||||
toolbox.S_r,
|
||||
toolbox.x,
|
||||
toolbox.k,
|
||||
toolbox.production_increase_ratio
|
||||
),
|
||||
n=1
|
||||
)
|
||||
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
|
||||
|
||||
# 遗传算子
|
||||
toolbox.register("evaluate", fitness)
|
||||
toolbox.register("mate", tools.cxTwoPoint)
|
||||
toolbox.register("mutate", tools.mutShuffleIndexes, indpb=0.1)
|
||||
toolbox.register("select", tools.selTournament, tournsize=3)
|
||||
|
||||
return toolbox
|
||||
327
GA_Agent_0925/evaluate_func.py
Normal file
327
GA_Agent_0925/evaluate_func.py
Normal file
@@ -0,0 +1,327 @@
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from multiprocessing import Process
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from my_model import MyModel
|
||||
from orm import connection, engine
|
||||
|
||||
|
||||
# 🎯 适应度函数(核心目标函数)
|
||||
def fitness(individual, controller_db_obj, n_jobs=6):
    """GA fitness: run the ABM with the individual's parameters and score it.

    The score is the negated symmetric difference between the set of
    vulnerable industries produced by the simulation and the fixed target
    set, so a perfect match scores 0 and larger mismatches are more
    negative (DEAP maximises this value).

    Parameters:
        individual: chromosome list
            [n_max_trial, prf_size, prf_conn, cap_limit_prob_type,
             cap_limit_level, diff_new_conn, netw_prf_n, s_r, S_r, x, k,
             production_increase_ratio]
        controller_db_obj: ControllerDB used by the workers to fetch/lock
            simulation samples.
        n_jobs: number of parallel worker processes (default 6 -- the value
            that was previously hard-coded; kept as a keyword so existing
            callers are unaffected).

    Returns:
        tuple[float]: 1-tuple ``(-error,)`` as required by DEAP.
    """
    # Tag this evaluation with a short unique id so DB rows can later be
    # filtered per GA individual.
    ga_id = str(uuid.uuid4())[:8]
    individual.ga_id = ga_id  # Individual subclasses list, so attributes are allowed

    # ---- 1. Map the chromosome onto named ABM parameters ----
    dct_exp = {
        'n_max_trial': individual[0],
        'prf_size': individual[1],
        'prf_conn': individual[2],
        'cap_limit_prob_type': individual[3],
        'cap_limit_level': individual[4],
        'diff_new_conn': individual[5],
        'netw_prf_n': individual[6],
        's_r': individual[7],
        'S_r': individual[8],
        'x': individual[9],
        'k': individual[10],
        'production_increase_ratio': individual[11]
    }
    # Decode the 0/1 gene into the categorical value the ABM expects.
    dct_exp['cap_limit_prob_type'] = (
        "uniform" if dct_exp['cap_limit_prob_type'] == 0 else "normal")

    print(f"\n 正在执行 GA 个体 {ga_id},参数如下:")

    # ---- 2. Run the ABM over all pending samples in parallel ----
    do_process(controller_db_obj, ga_id, dct_exp, n_jobs)

    # ---- 3. Extract the simulated vulnerable-industry set for this run ----
    simulated_vulnerable_industries = get_vulnerable100_code(connection, ga_id)
    print(simulated_vulnerable_industries)

    # ---- 4. Target industry set ----
    target_vulnerable_industries = get_target_vulnerable_industries()

    # ---- 5. Error = size of the symmetric difference ----
    # (The original built the same two sets twice under different names;
    # compute them once here.)
    simulated_set = set(simulated_vulnerable_industries)
    target_set = set(target_vulnerable_industries)
    error = len(simulated_set.symmetric_difference(target_set))

    matching = simulated_set & target_set   # correctly reproduced industries
    extra = simulated_set - target_set      # simulated but not targeted
    missing = target_set - simulated_set    # targeted but not reproduced

    print(f"符合目标的产业数量: {len(matching)}")
    print(matching)
    print(f"模拟多出的产业数量: {len(extra)}")
    print(f"未覆盖目标产业数量: {len(missing)}")

    # ---- 6. DEAP maximises, so return the negated error as a 1-tuple ----
    return (float(-error),)
|
||||
# 目标产业集合
|
||||
def get_target_vulnerable_industries():
    """Return the target vulnerable-industry indices as a list of strings.

    FIX: the previous docstring described an ``industry_list`` parameter and
    claimed an integer-set return value, but the function takes no arguments
    and returns a fixed, hand-converted list of product *indices* (strings).
    The large product -> chain_id reference table and the abandoned automatic
    chain_id extraction that used to live here were dead code (never used to
    compute the result) and have been removed; see version history for the
    raw semiconductor product/chain_id mapping these indices were derived
    from by hand.

    Returns:
        list[str]: product indices compared against the simulated
        vulnerable-product indices in ``fitness`` (order preserved from the
        original hand-converted list).
    """
    return ['100', '58', '61', '9', '7', '98', '57', '8', '65', '68', '66', '38',
            '90', '21', '96', '71', '27', '74', '99', '95', '11', '77', '59', '56', '97']
|
||||
# 从数据库计算脆弱产业集合
|
||||
def get_vulnerable100_code(engine, ga_id):
    """Return the ids of the 100 most vulnerable products for one GA run.

    Runs SQL_analysis_risk_ga.sql (disruption records filtered by ga_id),
    counts disruptions per (firm, product) pair, aggregates per firm and per
    product, writes the three aggregates to CSV files in the CWD as a side
    effect, and returns the 100 products with the highest disruption counts.

    FIX: the original built a ``sessionmaker``/``Session()`` pair at the top
    that was never used and never closed (a connection leak), then built a
    second sessionmaker for the actual query. A single context-managed
    session is used now. The large block of commented-out DCP (disruption
    causing probability) analysis was dead code and has been removed.

    Parameters:
        engine: SQLAlchemy engine to bind the session to.
        ga_id: short GA-run identifier; bound to the :ga_id SQL parameter.

    Returns:
        list: ``id_product`` values of the top-100 products by count.
    """
    # 1. Load the parameterised SQL (path resolved relative to this file).
    base_dir = os.path.dirname(os.path.abspath(__file__))
    sql_file = os.path.abspath(
        os.path.join(base_dir, "..", "GA_Agent_0925", "SQL_analysis_risk_ga.sql"))
    with open(sql_file, "r", encoding="utf-8") as f:
        str_sql = f.read()
    print(f"[信息] 正在查询 ga_id={ga_id} 的脆弱产品数据...")

    # 2. Execute the query with a fresh, automatically-closed session.
    Session = sessionmaker(bind=engine)
    with Session() as session:
        result = pd.read_sql(
            sql=text(str_sql),
            con=session.connection(),
            params={"ga_id": ga_id}
        )

    # 3. Disruption count per (firm, product) pair.
    count_firm_prod = result.value_counts(subset=['id_firm', 'id_product'])
    count_firm_prod.name = 'count'
    count_firm_prod = count_firm_prod.to_frame().reset_index()
    count_firm_prod.to_csv('count_firm_prod.csv', index=False, encoding='utf-8-sig')

    # 4. Total count per firm (diagnostic CSV only).
    count_firm = count_firm_prod.groupby('id_firm')['count'].sum().reset_index()
    count_firm.sort_values('count', ascending=False, inplace=True)
    count_firm.to_csv('count_firm.csv', index=False, encoding='utf-8-sig')

    # 5. Total count per product.
    count_prod = count_firm_prod.groupby('id_product')['count'].sum().reset_index()
    count_prod.sort_values('count', ascending=False, inplace=True)
    count_prod.to_csv('count_prod.csv', index=False, encoding='utf-8-sig')

    # 6. The 100 most frequently disrupted products are "most vulnerable".
    vulnerable100_product = count_prod.nlargest(100, "count")["id_product"].tolist()
    print(f"[信息] ga_id={ga_id} 查询完成,共找到 {len(vulnerable100_product)} 个脆弱产品")

    return vulnerable100_product
|
||||
|
||||
def run_ABM_samples(controller_db_obj,ga_id,dct_exp, str_code="GA",):
    """Fetch one random pending sample from the DB, lock it, and run one ABM simulation.

    Parameters:
        controller_db_obj: ControllerDB object used to fetch and lock samples.
        ga_id: short GA-run identifier, forwarded to the model's end() so
            results are tagged per GA individual.
        dct_exp: dict of GA-chosen parameters; overrides the stored
            experiment values (it is merged last into the parameter dict).
        str_code: optional label for logging (currently unused in the body).

    Returns:
        True when no sample was available (everything done);
        None (falsy) after successfully processing one sample.
    """
    # 1. Fetch a random pending sample from the database.
    sample_random = controller_db_obj.fetch_a_sample(s_id=None)  # a specific s_id could be passed if needed
    if sample_random is None:
        return True  # no samples left -> signal the worker loop to stop
    # 2. Lock the sample so other worker processes skip it.
    controller_db_obj.lock_the_sample(sample_random)
    # 3. Copy every column of the linked experiment row into a plain dict.
    dct_exp_new = {column: getattr(sample_random.experiment, column)
                   for column in sample_random.experiment.__table__.c.keys()}
    # Drop the primary key: it must not be forwarded as a model parameter.
    dct_exp_new.pop('id', None)
    dct_exp_new = {'sample': sample_random,
                   'seed': sample_random.seed,
                   **dct_exp_new}
    try:
        # GA parameters (dct_exp) override the experiment's stored values
        # because they are merged last.
        dct_sample_para = {
            **dct_exp_new,
            **dct_exp}
        # Switch CWD to the project root; presumably MyModel loads its input
        # files via relative paths -- TODO confirm. NOTE(review): chdir is a
        # process-wide side effect repeated for every sample.
        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        os.chdir(project_root)

        abm_model = MyModel(dct_sample_para)
        abm_model.step()      # drive the simulation (single call per MyModel's contract)
        abm_model.end(ga_id)  # persist results, tagged with this GA run's id
    except Exception as e:
        # Best-effort: log the failure and keep the worker alive for other samples.
        print(f"[❌ ABM运行错误] 错误:{e} ")
        traceback.print_exc()
|
||||
|
||||
def do_computation(controller_db_obj,ga_id,dct_exp,):
    """Worker loop: keep running ABM samples until the DB has none left.

    Runs in a child process; exits when run_ABM_samples reports that no
    pending sample remains.
    """
    worker_pid = os.getpid()
    print(f"[启动] 进程 {worker_pid} 已启动 (PID={worker_pid})")
    while True:
        no_samples_left = run_ABM_samples(controller_db_obj, ga_id, dct_exp)
        if no_samples_left:
            break
|
||||
|
||||
def do_process(controller_db,ga_id,dct_exp,job):
    """Spawn ``job`` worker processes running do_computation and wait for all.

    Each worker pulls samples from the shared DB queue until none remain,
    so the call blocks until the whole batch for this ga_id is simulated.
    """
    workers = [
        Process(target=do_computation, args=(controller_db, ga_id, dct_exp,))
        for _ in range(job)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
|
||||
BIN
GA_Agent_0925/g_bom.pkl
Normal file
BIN
GA_Agent_0925/g_bom.pkl
Normal file
Binary file not shown.
@@ -1,116 +0,0 @@
|
||||
import numpy as np # 引入NumPy库,用于高效的数值计算
|
||||
from pySOT.utils import round_vars # 引入用于四舍五入的函数
|
||||
from typing import TYPE_CHECKING # 引入类型检查工具
|
||||
if TYPE_CHECKING:
|
||||
from policy import Policy
|
||||
|
||||
class GeneticAlgorithm:
    """Real-coded GA with tournament selection, arithmetic (blend) crossover
    and Gaussian mutation, driven externally one evaluation at a time.

    Ask/tell protocol:
      - ``select_next_point()`` returns the next unevaluated individual;
      - ``receive_sim_value(v)`` reports its objective value (minimised);
      - once a full generation is evaluated, the next one is bred automatically.
    """

    def __init__(self, the_policy: 'Policy'):
        # Problem definition taken from the Policy object.
        self.n_variables = the_policy.dim            # number of decision variables
        self.lower_boundary = the_policy.lb          # per-variable lower bounds
        self.upper_boundary = the_policy.ub          # per-variable upper bounds
        self.integer_variables = the_policy.int_var  # indices of integer variables

        # GA hyper-parameters.
        self.sigma = 0.2                             # mutation std as a fraction of each variable's range
        self.p_mutation = 1.0 / the_policy.dim       # per-gene mutation probability
        self.tournament_size = 5
        self.p_cross = 0.9                           # per-pair crossover probability

        pop_size = the_policy.arr_init_doe_points.shape[0]  # initial DOE population size
        self.lst_value = the_policy.lst_y_init_doe_points   # objective values of the initial points

        # Pairwise crossover below needs an even population: if the DOE size
        # is odd, add one random individual and evaluate it immediately.
        if pop_size % 2 == 1:
            arr_random = np.random.rand(1, self.n_variables)
            arr_one_random = self.lower_boundary + arr_random * (self.upper_boundary - self.lower_boundary)
            self.lst_value.append(the_policy.eval(arr_one_random[0, :], is_init_points=True))
            self.population = np.vstack((the_policy.arr_init_doe_points, arr_one_random))
        else:
            self.population = np.copy(the_policy.arr_init_doe_points)

        self.n_individuals = self.population.shape[0]
        assert self.n_individuals == pop_size or self.n_individuals == pop_size + 1, 'Wrong pop size'

        # Snap integer variables to integral values inside their bounds.
        if len(self.integer_variables) > 0:
            self.population[:, self.integer_variables] = np.round(self.population[:, self.integer_variables])
            for i in self.integer_variables:
                ind = np.where(self.population[:, i] < self.lower_boundary[i])  # pushed below lb by rounding
                self.population[ind, i] += 1
                ind = np.where(self.population[:, i] > self.upper_boundary[i])  # pushed above ub by rounding
                self.population[ind, i] -= 1

        self.ind, self.best_individual, self.best_value = None, None, None  # best-so-far trackers
        self.pop_next, self.lst_pop_next_is_evaluated = None, None          # next generation + evaluated flags
        self.update_info()  # record the initial best and breed generation 1

    def update_info(self):
        """Record the current best individual, then breed the next generation.

        NOTE(review): ``self.population`` is never replaced by ``pop_next``
        after a generation completes, so from the second generation on the
        best-individual lookup and tournament selection index the *old*
        population with the *new* objective values. This looks like a bug --
        confirm whether the caller is expected to update ``population``.
        """
        self.ind = np.argmin(self.lst_value)  # minimisation: smallest value wins
        self.best_individual = np.copy(self.population[self.ind, :])
        self.best_value = self.lst_value[self.ind]

        self.pop_next, self.lst_pop_next_is_evaluated = self._generate_next_population()
        self.lst_value = []  # cleared; refilled by receive_sim_value calls

    def _generate_next_population(self):
        """Breed one generation: tournament selection, blend crossover,
        masked Gaussian mutation, box clipping, integer rounding.

        Returns:
            (np.ndarray, list[bool]): the new population and a parallel list
            of "already evaluated" flags, all initialised to False.
        """
        # Tournament selection: each row holds the candidate indices of one tournament.
        competitors = np.random.randint(0, self.n_individuals, (self.n_individuals, self.tournament_size))
        ind = np.argmin(np.array(self.lst_value)[competitors], axis=1)  # winning column per tournament
        winner_indices = np.zeros(self.n_individuals, dtype=int)
        for i in range(self.tournament_size):
            winner_indices[np.where(ind == i)] = competitors[np.where(ind == i), i]

        # Split the winners into two parent groups (fancy indexing yields copies).
        parent1 = self.population[winner_indices[0: self.n_individuals // 2], :]
        parent2 = self.population[winner_indices[self.n_individuals // 2: self.n_individuals], :]

        # Arithmetic (blend) crossover on a random subset of the pairs.
        cross = np.where(np.random.rand(self.n_individuals // 2) < self.p_cross)[0]
        nn = len(cross)
        alpha = np.random.rand(nn, 1)  # one blend coefficient per crossing pair

        parent1_new = np.multiply(alpha, parent1[cross, :]) + np.multiply(1 - alpha, parent2[cross, :])
        parent2_new = np.multiply(alpha, parent2[cross, :]) + np.multiply(1 - alpha, parent1[cross, :])
        parent1[cross, :] = parent1_new
        parent2[cross, :] = parent2_new
        arr_new_population = np.concatenate((parent1, parent2))

        # Gaussian mutation, scaled to each variable's range and masked so each
        # gene mutates with probability p_mutation.
        scale_factors = self.sigma * (self.upper_boundary - self.lower_boundary)
        perturbation = np.random.randn(self.n_individuals, self.n_variables)
        perturbation = np.multiply(perturbation, scale_factors)
        perturbation = np.multiply(
            perturbation, (np.random.rand(self.n_individuals, self.n_variables) < self.p_mutation)
        )

        arr_new_population += perturbation
        # Clip back into the box constraints.
        arr_new_population = np.maximum(np.reshape(self.lower_boundary, (1, self.n_variables)), arr_new_population)
        arr_new_population = np.minimum(np.reshape(self.upper_boundary, (1, self.n_variables)), arr_new_population)

        # Re-round integer variables (round_vars also respects the bounds).
        if len(self.integer_variables) > 0:
            arr_new_population = round_vars(arr_new_population, self.integer_variables, self.lower_boundary,
                                            self.upper_boundary)

        assert arr_new_population.shape[0] == self.n_individuals, 'Wrong arr_new_population shape'
        return arr_new_population, [False] * self.n_individuals  # all flagged unevaluated

    def select_next_point(self):
        """Return the first not-yet-evaluated individual of the pending
        generation, or None (implicitly) when all are evaluated."""
        for idx_ind, is_evaluated in enumerate(self.lst_pop_next_is_evaluated):
            if not is_evaluated:
                return self.pop_next[idx_ind, :]

    def receive_sim_value(self, the_value):
        """Record an objective value for the next pending individual.

        When the value just recorded completes the generation (the last flag
        was flipped), update_info() is called to promote the best solution
        and breed the next generation.
        """
        self.lst_value.append(the_value)
        idx_ind = 0
        for idx_ind, is_evaluated in enumerate(self.lst_pop_next_is_evaluated):
            if not is_evaluated:  # the first pending individual owns this value
                self.lst_pop_next_is_evaluated[idx_ind] = True
                break
        if idx_ind == len(self.lst_pop_next_is_evaluated) - 1:
            assert idx_ind == self.n_individuals - 1, 'Wrong index'  # every flag must now be True
            self.update_info()  # generation complete -> breed the next one
|
||||
109
GA_Agent_0925/main.py
Normal file
109
GA_Agent_0925/main.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import json
|
||||
import random
|
||||
from deap import tools
|
||||
from sqlalchemy.orm import close_all_sessions
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from GA_Agent_0925.creating import creating
|
||||
from GA_Agent_0925.orm import connection
|
||||
from controller_db import ControllerDB
|
||||
from evaluate_func import fitness, get_vulnerable100_code, get_target_vulnerable_industries
|
||||
|
||||
|
||||
# ==============================
|
||||
# 遗传算法主函数(单进程)
|
||||
# ==============================
|
||||
def main():
    """Single-process GA driver: evolve ABM parameters and plot convergence.

    Loads config.json, prepares the sample database, evolves a DEAP
    population for cfg["n_gen"] generations (evaluating each individual via
    the ABM-backed ``fitness``), then reports the hall-of-fame individual
    and shows a convergence plot.
    """
    # 1. Load configuration.
    with open("config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)

    random.seed(cfg["seed"])

    print("\n📘 参数配置:")
    for k, v in cfg.items():
        print(f"  {k}: {v}")
    print("-" * 40)

    # 2. Database controller and sample tables.
    controller_db_obj = ControllerDB("without_exp", reset_flag=0)
    controller_db_obj.reset_db(force_drop=True)
    controller_db_obj.prepare_list_sample()

    # 3. DEAP toolbox, population, statistics.
    toolbox = creating()
    pop = toolbox.population(n=cfg["pop_size"])
    hof = tools.HallOfFame(1)
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", lambda fits: sum(f[0] for f in fits) / len(fits))
    stats.register("max", lambda fits: max(f[0] for f in fits))

    best_list = []
    avg_list = []

    def evaluate_invalid(individuals):
        """Evaluate every individual whose fitness is stale.

        Resets and re-prepares the sample tables before each evaluation so
        every GA individual runs against a fresh batch of simulation
        samples. (FIX: this loop was previously duplicated verbatim for the
        initial population and for the offspring.)
        """
        for ind in individuals:
            if ind.fitness.valid:
                continue
            controller_db_obj.reset_sample_db()
            controller_db_obj.prepare_list_sample()
            ind.fitness.values = fitness(ind, controller_db_obj=controller_db_obj)

    # 4. Main evolutionary loop.
    for gen in tqdm(range(cfg["n_gen"]), desc="进化中", ncols=90):
        evaluate_invalid(pop)

        # Selection, crossover, mutation.
        offspring = toolbox.select(pop, len(pop))
        offspring = list(map(toolbox.clone, offspring))

        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < cfg["cx_prob"]:
                toolbox.mate(child1, child2)
                # Invalidate: mated children must be re-evaluated.
                del child1.fitness.values, child2.fitness.values

        for mutant in offspring:
            if random.random() < cfg["mut_prob"]:
                toolbox.mutate(mutant)
                del mutant.fitness.values

        evaluate_invalid(offspring)

        pop[:] = offspring
        hof.update(pop)

        record = stats.compile(pop)
        best_list.append(record["max"])
        avg_list.append(record["avg"])

    # 5. Report the best solution.
    print("\n✅ 进化完成!")
    print(f"🏆 最优个体: {hof[0]}")
    print(f"🌟 最优适应度: {hof[0].fitness.values[0]:.4f}")

    # Convergence plot.
    plt.figure(figsize=(8, 5))
    plt.plot(best_list, label="Best Fitness", linewidth=2)
    plt.plot(avg_list, label="Average Fitness", linestyle="--")
    plt.title("Genetic Algorithm Convergence")
    plt.xlabel("Generation")
    plt.ylabel("Fitness")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Industry matching for the best individual (not yet implemented).
    print("\n📊 计算最优个体产业匹配情况...")


if __name__ == "__main__":
    main()
|
||||
117
GA_Agent_0925/orm.py
Normal file
117
GA_Agent_0925/orm.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from sqlalchemy import create_engine, inspect, Inspector, Float
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import (Column, Integer, DECIMAL, String, ForeignKey,
|
||||
BigInteger, DateTime, PickleType, Boolean, Text)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship, Session
|
||||
from sqlalchemy.pool import NullPool
|
||||
import yaml
|
||||
|
||||
# Select local vs remote DB credentials from ../conf_db.yaml
# (paths are relative to the process CWD, which is expected to be
# a subdirectory of the project root -- confirm when running elsewhere).
with open('../conf_db.yaml') as file:
    dct_conf_db_all = yaml.full_load(file)
is_local_db = dct_conf_db_all['is_local_db']
if is_local_db:
    dct_conf_db = dct_conf_db_all['local']
else:
    dct_conf_db = dct_conf_db_all['remote']

# Table-name prefix shared by every ORM table (e.g. "<prefix>_experiment").
with open('../conf_db_prefix.yaml') as file:
    dct_conf_db_prefix = yaml.full_load(file)
db_name_prefix = dct_conf_db_prefix['db_name_prefix']

# MySQL connection URL assembled from the selected credentials.
str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'],
                                            dct_conf_db['password'],
                                            dct_conf_db['address'],
                                            dct_conf_db['port'],
                                            dct_conf_db['db_name'])
# print('DB is {}:{}/{}'.format(dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name']))

# must be null pool to avoid connection lost error
engine = create_engine(str_login, poolclass=NullPool)
# Module-level shared connection and inspector, used by the analysis scripts.
connection = engine.connect()
ins: Inspector = inspect(engine)

Base = declarative_base()

# Module-level session bound to the engine (shared by importers).
db_session = Session(bind=engine)
|
||||
|
||||
|
||||
class Experiment(Base):
    """One experiment configuration: scenario indices plus the fixed and
    variable ABM parameters shared by all of its samples."""
    __tablename__ = f"{db_name_prefix}_experiment"
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Indices identifying the scenario and the initial-removal setup.
    idx_scenario = Column(Integer, nullable=False)
    idx_init_removal = Column(Integer, nullable=False)

    # fixed parameters
    n_sample = Column(Integer, nullable=False)  # number of Monte-Carlo samples
    n_iter = Column(Integer, nullable=False)    # iterations per sample

    # variables
    dct_lst_init_disrupt_firm_prod = Column(PickleType, nullable=False)  # initial firm-product disruptions
    g_bom = Column(Text(4294000000), nullable=False)  # serialized BOM graph (LONGTEXT-sized)

    n_max_trial = Column(Integer, nullable=False)
    prf_size = Column(Boolean, nullable=False)
    prf_conn = Column(Boolean, nullable=False)
    cap_limit_prob_type = Column(String(16), nullable=False)  # e.g. "uniform" / "normal"
    cap_limit_level = Column(DECIMAL(8, 4), nullable=False)
    diff_new_conn = Column(DECIMAL(8, 4), nullable=False)
    remove_t = Column(Integer, nullable=False)
    netw_prf_n = Column(Integer, nullable=False)

    # One experiment owns many samples.
    sample = relationship(
        'Sample', back_populates='experiment', lazy='dynamic')

    def __repr__(self):
        return f'<Experiment: {self.id}>'
|
||||
|
||||
|
||||
class Sample(Base):
    """One seeded Monte-Carlo run of an Experiment, tracked by a done-flag
    so worker processes can claim pending samples."""
    __tablename__ = f"{db_name_prefix}_sample"
    id = Column(Integer, primary_key=True, autoincrement=True)
    e_id = Column(Integer, ForeignKey('{}.id'.format(
        f"{db_name_prefix}_experiment")), nullable=False)

    idx_sample = Column(Integer, nullable=False)  # sample index within its experiment
    seed = Column(BigInteger, nullable=False)     # RNG seed for this run
    # -1, waiting; 0, running; 1, done
    is_done_flag = Column(Integer, nullable=False)
    computer_name = Column(String(64), nullable=True)  # worker host that claimed the sample
    ts_done = Column(DateTime(timezone=True), onupdate=func.now())  # auto-set on row update
    stop_t = Column(Integer, nullable=True)  # tick at which the simulation stopped

    g_firm = Column(Text(4294000000), nullable=True)  # serialized firm network (LONGTEXT-sized)

    experiment = relationship(
        'Experiment', back_populates='sample', uselist=False)
    result = relationship('Result', back_populates='sample', lazy='dynamic')

    def __repr__(self):
        return f'<Sample id: {self.id}>'
|
||||
|
||||
|
||||
class Result(Base):
    """One firm-product status record emitted by a simulation sample
    (e.g. status 'D' marks a disruption at tick ``ts``)."""
    __tablename__ = f"{db_name_prefix}_result"
    id = Column(Integer, primary_key=True, autoincrement=True)
    s_id = Column(Integer, ForeignKey('{}.id'.format(
        f"{db_name_prefix}_sample")), nullable=False)

    id_firm = Column(String(20), nullable=False)
    id_product = Column(String(20), nullable=False)
    ts = Column(Integer, nullable=False)        # simulation tick of the event
    status = Column(String(5), nullable=False)  # e.g. 'D' = disrupted (per the GA risk SQL)

    sample = relationship('Sample', back_populates='result', uselist=False)

    # GA-run id: tags which genetic-algorithm evaluation produced this row,
    # so results can be filtered per GA individual.
    ga_id = Column(String(50), nullable=True)

    def __repr__(self):
        return f'<Product id: {self.id}>'
|
||||
|
||||
if __name__ == '__main__':
    # Recreate the schema when run directly.
    # BUG FIX: MetaData.drop_all()/create_all() need a bind; the declarative
    # Base here is unbound, so the original no-argument calls raised
    # (UnboundExecutionError) instead of touching the database.
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)
|
||||
@@ -1,23 +0,0 @@
|
||||
from sqlalchemy import text

import pandas as pd

from orm import connection

# Load the risk-analysis SQL and run it against the shared module connection.
with open("../SQL_analysis_risk.sql", "r", encoding="utf-8") as f:
    str_sql = text(f.read())

result = pd.read_sql(sql=str_sql, con=connection)
# Disruption count per (firm, product) pair.
count_firm_prod = result.value_counts(subset=['id_firm', 'id_product'])
count_firm_prod.name = 'count'
count_firm_prod = count_firm_prod.to_frame().reset_index()

# Total count per product, most frequent first.
count_prod = count_firm_prod.groupby('id_product')['count'].sum()
count_prod = count_prod.to_frame().reset_index()
count_prod.sort_values('count', inplace=True, ascending=False)
print(count_prod)
# The 100 most vulnerable products (highest disruption counts).
top100 = count_prod.head(100)['id_product'].tolist()
|
||||
|
||||
211
GA_Agent_0925/多功能.py
Normal file
211
GA_Agent_0925/多功能.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import networkx as nx
|
||||
from sqlalchemy import text
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from orm import connection
|
||||
|
||||
# """
|
||||
# 计算最脆弱前100产品的 Code 列表(去重)。
|
||||
# """
|
||||
# bom_file = r"../input_data/input_product_data/BomNodes.csv"
|
||||
# mapping_df = pd.read_csv(bom_file)
|
||||
#
|
||||
# with open("../SQL_analysis_risk.sql", "r", encoding="utf-8") as f:
|
||||
# str_sql = text(f.read())
|
||||
#
|
||||
# result = pd.read_sql(sql=str_sql, con=connection)
|
||||
#
|
||||
# count_firm_prod = result.value_counts(subset=['id_firm', 'id_product'])
|
||||
# count_firm_prod.name = 'count'
|
||||
# count_firm_prod = count_firm_prod.to_frame().reset_index()
|
||||
#
|
||||
# count_prod = (
|
||||
# count_firm_prod.groupby("id_product")["count"].sum().reset_index()
|
||||
# )
|
||||
#
|
||||
# vulnerable100_index = count_prod.nsmallest(100, "count")["id_product"].tolist()
|
||||
# # 确保 index_to_code 的 key 都是 int
|
||||
# index_to_code = {int(k): v for k, v in zip(mapping_df["Index"], mapping_df["Code"])}
|
||||
#
|
||||
# # vulnerable100_index 也转成 int
|
||||
# vulnerable100_index_int = [int(i) for i in vulnerable100_index]
|
||||
#
|
||||
# # 获取 code
|
||||
# vulnerable100_code = [index_to_code[i] for i in vulnerable100_index_int if i in index_to_code]
|
||||
#
|
||||
# print(vulnerable100_code)
|
||||
|
||||
|
||||
# GA run whose disruption results we want to analyse.
ga_id = "c943f2c6"

# Read the parameterised risk query; keep the raw string around so it can
# be previewed.
with open("SQL_analysis_risk_ga.sql", "r", encoding="utf-8") as f:
    raw_sql = f.read()
str_sql = text(raw_sql)

# Bug fix: `print(str_sql[:300])` raised TypeError because sqlalchemy's
# TextClause does not support slicing — preview the raw SQL string instead.
print(raw_sql[:300])
print(f"[信息] 正在查询 ga_id={ga_id} 的脆弱产品数据...")

# Execute the query with :ga_id supplied as a bound parameter (no string
# interpolation into the SQL text).
result = pd.read_sql(
    sql=str_sql,
    con=connection,
    params={"ga_id": ga_id},
)
|
||||
# --- Disruption count per (firm, product) pair ---------------------------
count_firm_prod = (
    result.value_counts(subset=['id_firm', 'id_product'])
    .rename('count')
    .reset_index()
)
count_firm_prod.to_csv('count_firm_prod.csv', index=False, encoding='utf-8-sig')

# --- Per-firm totals, descending -----------------------------------------
count_firm = count_firm_prod.groupby('id_firm')['count'].sum().reset_index()
count_firm = count_firm.sort_values('count', ascending=False)
count_firm.to_csv('count_firm.csv', index=False, encoding='utf-8-sig')

# --- Per-product totals, descending --------------------------------------
count_prod = count_firm_prod.groupby('id_product')['count'].sum().reset_index()
count_prod = count_prod.sort_values('count', ascending=False)
count_prod.to_csv('count_prod.csv', index=False, encoding='utf-8-sig')
|
||||
|
||||
# --- Top-100 most vulnerable products (highest disruption counts) --------
vulnerable100_product = (
    count_prod.nlargest(100, "count")["id_product"].tolist()
)
print(f"[信息] ga_id={ga_id} 查询完成,共找到 {len(vulnerable100_product)} 个脆弱产品")

# --- Keep only the records belonging to those products -------------------
vulnerable_mask = result['id_product'].isin(vulnerable100_product)
result_vulnerable100 = result[vulnerable_mask].copy()
print(f"[信息] 筛选后剩余记录数: {len(result_vulnerable100)}")
|
||||
|
||||
# --- Build DCP (Disruption Causing Probability) pairs --------------------
# Within each sample run, every node disrupted at an earlier timestep is
# paired as a potential cause ("up") of every node disrupted at a later
# timestep ("down").
result_dcp_list = []
for sid, group in result_vulnerable100.groupby('s_id'):
    latest_ts = max(group['ts'])
    for ts_down in range(latest_ts, 0, -1):
        down = group.loc[group['ts'] == ts_down, ['id_firm', 'id_product']]
        for ts_up in range(ts_down - 1, -1, -1):
            up = group.loc[group['ts'] == ts_up, ['id_firm', 'id_product']]
            for _, up_row in up.iterrows():
                for _, down_row in down.iterrows():
                    result_dcp_list.append([sid] + up_row.tolist() + down_row.tolist())

# Materialise the pair list as a DataFrame.
result_dcp = pd.DataFrame(result_dcp_list, columns=[
    's_id', 'up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product'
])
|
||||
|
||||
# --- Frequency of each up→down pair across all samples -------------------
dcp_key_cols = ['up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product']
count_dcp = (
    result_dcp.value_counts(subset=dcp_key_cols)
    .rename('count')
    .reset_index()
)

# Persist and display the result.
count_dcp.to_csv('count_dcp.csv', index=False, encoding='utf-8-sig')
print(count_dcp)
print(type(vulnerable100_product[0]))
|
||||
|
||||
# industry_list = [
|
||||
# # ① 半导体设备类
|
||||
# {"product": "离子注入机", "category": "离子注入设备", "chain_id": 34538},
|
||||
# {"product": "刻蚀设备 / 湿法刻蚀设备", "category": "刻蚀机", "chain_id": 34529},
|
||||
# {"product": "沉积设备", "category": "薄膜生长设备(CVD/PVD)", "chain_id": 34539},
|
||||
# {"product": "CVD", "category": "薄膜生长设备", "chain_id": 34539},
|
||||
# {"product": "PVD", "category": "薄膜生长设备", "chain_id": 34539},
|
||||
# {"product": "CMP", "category": "化学机械抛光设备", "chain_id": 34530},
|
||||
# {"product": "光刻机", "category": "光刻机", "chain_id": 34533},
|
||||
# {"product": "涂胶显影机", "category": "涂胶显影设备", "chain_id": 34535},
|
||||
# {"product": "晶圆清洗设备", "category": "晶圆清洗机", "chain_id": 34531},
|
||||
# {"product": "测试设备", "category": "测试机", "chain_id": 34554},
|
||||
# {"product": "外延生长设备", "category": "薄膜生长设备", "chain_id": 34539},
|
||||
#
|
||||
# # ② 半导体材料与化学品类
|
||||
# {"product": "三氯乙烯", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
|
||||
# {"product": "丙酮", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
|
||||
# {"product": "异丙醇", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
|
||||
# {"product": "其他醇类", "category": "清洗溶剂 → 通用湿电子化学品", "chain_id": 32438},
|
||||
# {"product": "光刻胶", "category": "光刻胶及配套试剂", "chain_id": 32445},
|
||||
# {"product": "显影液", "category": "显影液", "chain_id": 46504},
|
||||
# {"product": "蚀刻液", "category": "蚀刻液", "chain_id": 56341},
|
||||
# {"product": "光阻去除剂", "category": "光阻去除剂", "chain_id": 32442},
|
||||
#
|
||||
# # ③ 晶圆制造类
|
||||
# {"product": "晶圆", "category": "单晶硅片 / 多晶硅片", "chain_id": 32338},
|
||||
# {"product": "硅衬底", "category": "硅衬底", "chain_id": 36914},
|
||||
# {"product": "外延片", "category": "硅外延片 / GaN外延片 / SiC外延片等", "chain_id": 32338},
|
||||
#
|
||||
# # ④ 封装与测试类
|
||||
# {"product": "封装", "category": "IC封装", "chain_id": 10},
|
||||
# {"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 513742},
|
||||
# {"product": "测试", "category": "芯片测试 / 晶圆测试", "chain_id": 11},
|
||||
#
|
||||
# # ⑤ 芯片与设计EDA类
|
||||
# {"product": "芯片(通用)", "category": "集成电路制造", "chain_id": 317589},
|
||||
# {"product": "DRAM", "category": "存储芯片 → 集成电路制造", "chain_id": 317589},
|
||||
# {"product": "GPU", "category": "图形芯片 → 集成电路制造", "chain_id": 317589},
|
||||
# {"product": "处理器(CPU/SoC)", "category": "芯片设计", "chain_id": 9},
|
||||
# {"product": "高频芯片", "category": "芯片设计", "chain_id": 9},
|
||||
# {"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 9},
|
||||
# {"product": "光子芯片(含激光)", "category": "芯片设计 / 功率半导体器件", "chain_id": 2717},
|
||||
# {"product": "先进节点制造设备", "category": "集成电路制造", "chain_id": 317589},
|
||||
# {"product": "EDA及IP服务", "category": "设计辅助", "chain_id": 2515},
|
||||
# {"product": "MPW服务", "category": "多项目晶圆流片", "chain_id": 2514},
|
||||
# {"product": "芯片设计验证", "category": "设计验证", "chain_id": 513738},
|
||||
# {"product": "过程工艺检测", "category": "制程检测", "chain_id": 513740}
|
||||
# ]
|
||||
# # 提取所有 chain_id,并去重
|
||||
# chain_ids = set()
|
||||
# for item in industry_list:
|
||||
# # 如果 chain_id 是字符串包含多个编号,用逗号或斜杠拆分
|
||||
# if isinstance(item["chain_id"], str):
|
||||
# for cid in item["chain_id"].replace("/", ",").split(","):
|
||||
# chain_ids.add(cid.strip())
|
||||
# else:
|
||||
# chain_ids.add(str(item["chain_id"]))
|
||||
# print(list(chain_ids))
|
||||
# fill g_bom
|
||||
# 结点属性值 相当于 图上点的 原始 产品名称
|
||||
# bom_nodes = pd.read_csv('../input_data/input_product_data/BomNodes.csv')
|
||||
# bom_nodes['Code'] = bom_nodes['Code'].astype(str)
|
||||
# bom_nodes.set_index('Index', inplace=True)
|
||||
#
|
||||
# bom_cate_net = pd.read_csv('../input_data/input_product_data/合成结点.csv')
|
||||
# g_bom = nx.from_pandas_edgelist(bom_cate_net, source='UPID', target='ID', create_using=nx.MultiDiGraph())
|
||||
# # 填充每一个结点 的具体内容 通过 相同的 code 并且通过BomNodes.loc[code].to_dict()字典化 格式类似 格式 { code(0) : {level: 0 ,name: 工业互联网 }}
|
||||
# bom_labels_dict = {}
|
||||
# for index in g_bom.nodes:
|
||||
# try:
|
||||
# bom_labels_dict[index] = bom_nodes.loc[index].to_dict()
|
||||
# # print(bom_labels_dict[index])
|
||||
# except KeyError:
|
||||
# print(f"节点 {index} 不存在于 bom_nodes 中")
|
||||
# # 分配属性 给每一个结点 获得类似 格式:{1: {'label': 'A', 'value': 10},
|
||||
# nx.set_node_attributes(g_bom, bom_labels_dict)
|
||||
# # 改为json 格式
|
||||
# g_product_js = json.dumps(nx.adjacency_data(g_bom))
|
||||
# # 假设 g_bom 是你的 NetworkX 图
|
||||
# g_product_data = nx.adjacency_data(g_bom)
|
||||
#
|
||||
# # 保存为 pkl 文件
|
||||
# with open("g_bom.pkl", "wb") as f:
|
||||
# pickle.dump(g_product_data, f)
|
||||
#
|
||||
# print("✅ 图数据已保存为 g_bom.pkl")
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*- # 文件的编码格式设置为 UTF-8
|
||||
from __future__ import division # 为了兼容 Python 2 和 3,保证除法始终返回浮点数
|
||||
|
||||
import multiprocessing
|
||||
import random # 导入 random 库,用于生成随机数
|
||||
|
||||
from deap import base # 从 DEAP 库导入 base 模块,提供一些遗传算法相关的功能
|
||||
@@ -11,6 +12,77 @@ from my_model import MyModel
|
||||
from sqlalchemy import text
|
||||
import pandas as pd
|
||||
from orm import connection
|
||||
|
||||
def main():
    """Run the genetic algorithm and log the per-generation best individual.

    Evolves a population of 50 individuals for NGEN generations using
    two-point crossover and mutation, writes each generation's best
    individual to the `ga` table, and dumps the fitness history to
    ga_log.csv.
    """
    random.seed(42)  # fixed seed for reproducible runs
    print("Start of evolution")

    ga = creating()
    pop = ga.population(n=50)
    # Crossover probability, mutation probability, generation count.
    CXPB, MUTPB, NGEN = 0.5, 0.2, 200

    # Parallel evaluation is disabled for now; keep the single-process map.
    # pool = multiprocessing.Pool()
    # ga.register("map", pool.map)
    ga.register("map", map)

    # Evaluate the initial population.
    fitnesses = list(ga.map(ga.evaluate, pop))
    for ind, fit in zip(pop, fitnesses):
        ind.fitness.values = fit
    print(f"Evaluated {len(pop)} individuals")

    # Parameterised INSERT built once: values are bound at execution time,
    # never interpolated into the SQL text (fixes the injection-prone
    # f-string version and avoids float-formatting surprises).
    insert_sql = text(
        "INSERT INTO ga (generation, stu_beta, stu_nmb, gtu_mgf, "
        "gtu_discount, fitness, remark) "
        "VALUES (:generation, :stu_beta, :stu_nmb, :gtu_mgf, "
        ":gtu_discount, :fitness, 'Random2')"
    )

    best_log = []  # (generation, best fitness) history for ga_log.csv

    for g in range(NGEN):
        print(f"-- Generation {g} --")

        # Selection + cloning so crossover/mutation never alias parents.
        offspring = list(map(ga.clone, ga.select(pop, len(pop))))

        # Crossover on adjacent pairs; affected fitnesses become invalid.
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < CXPB:
                ga.mate(child1, child2)
                del child1.fitness.values
                del child2.fitness.values

        # Mutation; affected fitnesses become invalid.
        for mutant in offspring:
            if random.random() < MUTPB:
                ga.mutate(mutant)
                del mutant.fitness.values

        # Re-evaluate only individuals whose fitness was invalidated.
        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
        fitnesses = list(ga.map(ga.evaluate, invalid_ind))
        for ind, fit in zip(invalid_ind, fitnesses):
            ind.fitness.values = fit

        # Generational replacement.
        pop[:] = offspring

        best_ind = tools.selBest(pop, 1)[0]
        best_log.append((g, best_ind.fitness.values[0]))
        print(f"Best individual {g}: {best_ind}, Fitness: {best_ind.fitness.values[0]:.3f}")

        # Persist this generation's best individual with bound parameters.
        with connection.connect() as conn:
            conn.execute(insert_sql, {
                "generation": g,
                "stu_beta": best_ind[0],
                "stu_nmb": best_ind[1],
                "gtu_mgf": best_ind[2],
                "gtu_discount": best_ind[3],
                "fitness": best_ind.fitness.values[0],
            })
            conn.commit()

    # pool.close()
    # pool.join()

    pd.DataFrame(best_log, columns=["generation", "fitness"]).to_csv("ga_log.csv", index=False)
    print("-- End of (successful) evolution --")
|
||||
|
||||
# 目标函数(适应度函数),用于评估个体的适应度
|
||||
def fitness(individual):
|
||||
"""
|
||||
@@ -77,6 +149,10 @@ def creating():
|
||||
创建遗传算法工具箱,用于优化 ABM 模型参数,使生成的脆弱产业集合
|
||||
与目标产业集合误差最小化(fitness 最大化)。
|
||||
"""
|
||||
if "FitnessMax" not in creator.__dict__:
|
||||
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
|
||||
if "Individual" not in creator.__dict__:
|
||||
creator.create("Individual", list, fitness=creator.FitnessMax)
|
||||
# 定义最大化适应度
|
||||
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
|
||||
# 定义个体类
|
||||
@@ -124,7 +200,7 @@ def creating():
|
||||
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
|
||||
|
||||
# 注册 fitness 函数(需要在调用时传入目标产业集合)
|
||||
# toolbox.register("evaluate", fitness) # 可以在 main 中使用 lambda 包装 target_chain_set
|
||||
toolbox.register("evaluate", fitness) # 可以在 main 中使用 lambda 包装 target_chain_set
|
||||
|
||||
# 交叉、变异和选择操作
|
||||
toolbox.register("mate", tools.cxTwoPoint)
|
||||
@@ -133,76 +209,6 @@ def creating():
|
||||
|
||||
return toolbox
|
||||
|
||||
def main():
    """Run a 500-generation GA loop and report the best individual per
    generation.

    NOTE(review): this file defines `main` twice; whichever definition
    appears later in the module shadows the earlier one — confirm which
    is intended before relying on either.
    """
    # Build the GA toolbox (individuals, operators, evaluate).
    ga = creating()

    # Initial population of 50 individuals.
    pop = ga.population(n=50)

    # Crossover probability, mutation probability, generation count.
    CXPB, MUTPB, NGEN = 0.5, 0.2, 500

    print("Start of evolution")

    # Evaluate the whole initial population.
    fitnesses = list(map(ga.evaluate, pop))
    for ind, fit in zip(pop, fitnesses):
        ind.fitness.values = fit

    print(" Evaluated %i individuals" % len(pop))
    # my_sql = Sql()  # Sql instance for database interaction (disabled)

    # Main evolution loop.
    for g in range(NGEN):
        print("-- Generation %i --" % g)

        # Select the next generation's individuals...
        offspring = ga.select(pop, len(pop))
        # ...and clone them so operators do not mutate the parents in place.
        offspring = list(map(ga.clone, offspring))

        # Apply crossover to adjacent pairs of offspring.
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            # Mate two individuals with probability CXPB.
            if random.random() < CXPB:
                ga.mate(child1, child2)

                # Crossed-over children need their fitness recomputed.
                del child1.fitness.values
                del child2.fitness.values

        for mutant in offspring:
            # Mutate an individual with probability MUTPB.
            if random.random() < MUTPB:
                ga.mutate(mutant)
                del mutant.fitness.values

        # Re-evaluate only the individuals whose fitness was invalidated.
        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
        fitnesses = map(ga.evaluate, invalid_ind)
        for ind, fit in zip(invalid_ind, fitnesses):
            ind.fitness.values = fit

        print(" Evaluated %i individuals" % len(invalid_ind))

        # Full generational replacement.
        pop[:] = offspring

        # Collect all fitness values.
        # NOTE(review): `fits` is computed but never used below — dead code?
        fits = [ind.fitness.values[0] for ind in pop]

        # Report the current best individual.
        best_ind = tools.selBest(pop, 1)[0]
        print("Best individual is %s, %s" % (best_ind, best_ind.fitness.values))

        # Build the INSERT for the best individual.
        # NOTE(review): values are interpolated straight into the SQL string
        # (not parameterised); the insert call itself is commented out.
        result_string = '''INSERT INTO ga (generation, stu_beta, stu_nmb, gtu_mgf, gtu_discount, fitness, remark)
        VALUES ({}, {}, {}, {}, {}, {}, 'Random2')'''.format(g, best_ind[0], best_ind[1], best_ind[2], best_ind[3], best_ind.fitness.values[0])
        # my_sql.insert_one_row_and_return_new_id(result_string)

    print("-- End of (successful) evolution --")
|
||||
|
||||
def get_target_vulnerable_industries():
|
||||
"""
|
||||
获取行业列表中所有产业链编号的集合(整数形式)。
|
||||
@@ -280,11 +286,6 @@ def get_target_vulnerable_industries():
|
||||
|
||||
return chain_ids
|
||||
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import text # 用于 SQL 查询
|
||||
|
||||
|
||||
def get_vulnerable100_code(connection):
|
||||
"""
|
||||
计算最脆弱前100产品的 Code 列表(去重)。
|
||||
@@ -294,11 +295,11 @@ def get_vulnerable100_code(connection):
|
||||
List[int]: 最脆弱前100产品对应的 Code 列表
|
||||
"""
|
||||
# 读取映射表
|
||||
bom_file = r"../input_data/input_product_data/BomNodes.csv" # 直接给出路径
|
||||
bom_file = r"../../input_data/input_product_data/BomNodes.csv" # 直接给出路径
|
||||
mapping_df = pd.read_csv(bom_file)
|
||||
|
||||
# 执行 SQL 获取结果
|
||||
with open("../SQL_analysis_risk.sql", "r", encoding="utf-8") as f:
|
||||
with open("../../SQL_analysis_risk.sql", "r", encoding="utf-8") as f:
|
||||
str_sql = text(f.read())
|
||||
|
||||
result = pd.read_sql(sql=str_sql, con=connection)
|
||||
43
GA_Agent_0925/进度.py
Normal file
43
GA_Agent_0925/进度.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from matplotlib import rcParams, pyplot as plt
from sqlalchemy import func
from orm import db_session, Sample

# Figure objects are created once at module level; visualize_progress()
# mutates them in place so the same window refreshes instead of reopening.
plt.ion()  # interactive mode: drawing calls return immediately
fig, ax = plt.subplots(figsize=(8, 5))
rcParams['font.family'] = 'Microsoft YaHei'  # CJK-capable font for labels
rcParams['font.size'] = 12

# One bar per task state: -1 = not done, 0 = computing, 1 = done
# (per the label text below).
labels = ['未完成 (-1)', '计算中(0)', '完成 (1)']
initial_values = [0, 0, 0]
bars = ax.bar(labels, initial_values, color=['red', 'orange', 'green'])
# One text handle above each bar; heights and texts are updated together.
value_texts = [ax.text(bar.get_x() + bar.get_width()/2, 0, '0',
                       ha='center', va='bottom', fontsize=12)
               for bar in bars]

ax.set_title('任务进度分布', fontsize=16)
ax.set_xlabel('任务状态', fontsize=14)
ax.set_ylabel('数量', fontsize=14)
ax.tick_params(axis='both', labelsize=12)
|
||||
|
||||
def visualize_progress():
    """Refresh the live bar chart with the current distribution of
    Sample.is_done_flag values from the database."""
    # Query flag -> row count, defaulting missing flags to zero.
    flag_counts = dict(
        db_session.query(Sample.is_done_flag, func.count(Sample.id))
        .group_by(Sample.is_done_flag)
        .all()
    )
    values = [flag_counts.get(flag, 0) for flag in (-1, 0, 1)]

    # Sync each bar's height and its value label.
    for column, height, label in zip(bars, values, value_texts):
        column.set_height(height)
        label.set_y(height + 0.5)
        label.set_text(str(height))

    plt.draw()
    plt.pause(0.1)  # yield to the GUI event loop so the chart repaints
|
||||
Reference in New Issue
Block a user