2024-08-24 11:20:13 +08:00
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
from orm import db_session, engine, Base, ins, connection
|
|
|
|
|
from orm import Experiment, Sample, Result
|
|
|
|
|
from sqlalchemy.exc import OperationalError
|
|
|
|
|
from sqlalchemy import text
|
|
|
|
|
import yaml
|
|
|
|
|
import random
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import platform
|
|
|
|
|
import networkx as nx
|
|
|
|
|
import json
|
|
|
|
|
import pickle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControllerDB:
    """Manage the experiment/sample database: table creation, filling, reset."""

    # Declarations for instance state (assigned in __init__):
    is_with_exp: bool            # True only for the 'with_exp' prefix
    dct_parameter = None         # merged parameters from conf_experiment.yaml
    is_test: bool = None         # True when prefix == 'test'
    db_name_prefix: str = None   # one of 'test', 'without_exp', 'with_exp'
    reset_flag: int              # 0: no reset; 1: reset own; 2: reset all

    lst_saved_s_id: list         # ids of samples whose is_done_flag == -1

    def __init__(self, prefix, reset_flag=0):
        """Load run configuration and initialise batching state.

        :param prefix: table-name prefix; must be 'test', 'without_exp'
            or 'with_exp'.
        :param reset_flag: 0 = no reset, 1 = reset this machine's
            unfinished samples, 2 = reset all unfinished samples.
        """
        with open('conf_experiment.yaml') as yaml_file:
            dct_conf_experiment = yaml.full_load(yaml_file)

        assert prefix in ['test', 'without_exp', 'with_exp'], "db name not in test, without_exp, with_exp"

        self.is_test = prefix == 'test'
        # Only the 'with_exp' prefix runs the experiment variant.
        self.is_with_exp = prefix == 'with_exp'
        self.db_name_prefix = prefix
        dct_para_in_test = (dct_conf_experiment['test']
                            if self.is_test
                            else dct_conf_experiment['not_test'])

        self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'],
                              **dct_para_in_test}

        print(self.dct_parameter)

        # 0, not reset; 1, reset self; 2, reset all
        self.reset_flag = reset_flag
        self.is_exist = False
        self.lst_saved_s_id = []

        # Queue of Experiment objects awaiting a bulk insert.
        self.experiment_data = []
        self.batch_size = 999  # rows per bulk commit
|
2024-08-24 11:20:13 +08:00
|
|
|
|
|
|
|
|
|
def init_tables(self):
|
|
|
|
|
self.fill_experiment_table()
|
|
|
|
|
self.fill_sample_table()
|
|
|
|
|
|
|
|
|
|
def fill_experiment_table(self):
|
2024-09-29 16:41:34 +08:00
|
|
|
|
firm = pd.read_csv("input_data/input_firm_data/Firm_amended.csv")
|
2024-09-18 16:59:32 +08:00
|
|
|
|
firm['Code'] = firm['Code'].astype('string')
|
|
|
|
|
firm.fillna(0, inplace=True)
|
2024-08-24 11:20:13 +08:00
|
|
|
|
|
|
|
|
|
# fill dct_lst_init_disrupt_firm_prod
|
|
|
|
|
# 存储 公司-在供应链结点的位置.. 0 :‘1.1’
|
2024-10-21 17:41:50 +08:00
|
|
|
|
list_dct = [] # 存储 公司编码code 和对应的产业链 结点
|
2024-08-24 11:20:13 +08:00
|
|
|
|
if self.is_with_exp:
|
2024-10-21 17:41:50 +08:00
|
|
|
|
# 对于方差分析时候使用
|
2024-08-24 11:20:13 +08:00
|
|
|
|
with open('SQL_export_high_risk_setting.sql', 'r') as f:
|
|
|
|
|
str_sql = text(f.read())
|
|
|
|
|
result = pd.read_sql(sql=str_sql, con=connection)
|
|
|
|
|
result['dct_lst_init_disrupt_firm_prod'] = \
|
|
|
|
|
result['dct_lst_init_disrupt_firm_prod'].apply(
|
|
|
|
|
lambda x: pickle.loads(x))
|
|
|
|
|
list_dct = result['dct_lst_init_disrupt_firm_prod'].to_list()
|
|
|
|
|
else:
|
|
|
|
|
# 行索引 (index):这一行在数据帧中的索引值。
|
|
|
|
|
# 行数据 (row):这一行的数据,是一个 pandas.Series 对象,包含该行的所有列和值。
|
2024-09-18 16:59:32 +08:00
|
|
|
|
|
2024-10-21 17:41:50 +08:00
|
|
|
|
firm_industry = pd.read_csv("input_data/firm_industry_relation.csv")
|
2024-09-20 09:26:39 +08:00
|
|
|
|
firm_industry['Firm_Code'] = firm_industry['Firm_Code'].astype('string')
|
|
|
|
|
for _, row in firm_industry.iterrows():
|
|
|
|
|
code = row['Firm_Code']
|
|
|
|
|
row = row['Product_Code']
|
|
|
|
|
dct = {code: [row]}
|
|
|
|
|
list_dct.append(dct)
|
2024-08-24 11:20:13 +08:00
|
|
|
|
|
|
|
|
|
# fill g_bom
|
|
|
|
|
# 结点属性值 相当于 图上点的 原始 产品名称
|
2024-10-21 17:41:50 +08:00
|
|
|
|
bom_nodes = pd.read_csv('input_data/input_product_data/BomNodes.csv')
|
|
|
|
|
bom_nodes['Code'] = bom_nodes['Code'].astype(str)
|
2024-09-18 16:59:32 +08:00
|
|
|
|
bom_nodes.set_index('Code', inplace=True)
|
2024-10-21 17:41:50 +08:00
|
|
|
|
# bom_cate_net = pd.read_csv('input_data/input_product_data/BomCateNet.csv', index_col=0)
|
|
|
|
|
# bom_cate_net.fillna(0, inplace=True)
|
|
|
|
|
# # 创建 可以多边的有向图 同时 转置操作 使得 上游指向下游结点 也就是 1.1.1 - 1.1 类似这种
|
|
|
|
|
# # 将第一列转换为字符串类型
|
|
|
|
|
# print("sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss")
|
|
|
|
|
# print(bom_cate_net.columns)
|
|
|
|
|
# print(bom_cate_net.index) # 打印行标题(索引)
|
|
|
|
|
# print(bom_cate_net.iloc[:, 0]) # 打印第一列的内容
|
|
|
|
|
#
|
|
|
|
|
# g_bom = nx.from_pandas_adjacency(bom_cate_net.T,
|
|
|
|
|
# create_using=nx.MultiDiGraph())
|
|
|
|
|
|
|
|
|
|
bom_cate_net = pd.read_csv('input_data/input_product_data/合成结点.csv')
|
|
|
|
|
g_bom = nx.from_pandas_edgelist(bom_cate_net, source='UPID', target='ID', create_using=nx.MultiDiGraph())
|
2024-08-24 11:20:13 +08:00
|
|
|
|
# 填充每一个结点 的具体内容 通过 相同的 code 并且通过BomNodes.loc[code].to_dict()字典化 格式类似 格式 { code(0) : {level: 0 ,name: 工业互联网 }}
|
|
|
|
|
bom_labels_dict = {}
|
|
|
|
|
for code in g_bom.nodes:
|
2024-10-21 17:41:50 +08:00
|
|
|
|
try:
|
|
|
|
|
int_code = int(code)
|
|
|
|
|
bom_labels_dict[code] = bom_nodes.loc[int_code].to_dict()
|
|
|
|
|
except KeyError:
|
|
|
|
|
print(f"节点 {code} 不存在于 bom_nodes 中")
|
2024-08-24 11:20:13 +08:00
|
|
|
|
# 分配属性 给每一个结点 获得类似 格式:{1: {'label': 'A', 'value': 10},
|
|
|
|
|
nx.set_node_attributes(g_bom, bom_labels_dict)
|
|
|
|
|
# 改为json 格式
|
|
|
|
|
g_product_js = json.dumps(nx.adjacency_data(g_bom))
|
|
|
|
|
|
|
|
|
|
# insert exp
|
|
|
|
|
df_xv = pd.read_csv(
|
|
|
|
|
"input_data/"
|
|
|
|
|
f"xv_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
|
|
|
|
|
index_col=None)
|
|
|
|
|
# read the OA table
|
|
|
|
|
df_oa = pd.read_csv(
|
|
|
|
|
"input_data/"
|
|
|
|
|
f"oa_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
|
|
|
|
|
index_col=None)
|
|
|
|
|
# .shape[1] 列数 .iloc 访问特定的值 而不是标签
|
|
|
|
|
df_oa = df_oa.iloc[:, 0:df_xv.shape[1]]
|
2024-08-24 16:13:37 +08:00
|
|
|
|
|
2024-08-24 11:20:13 +08:00
|
|
|
|
# idx_scenario 是 0 指行 idx_init_removal 指 索引 0.. dct_init_removal 键 code 公司 g_product_js 图的json数据 dct_exp_para 解码 全局参数xv-
|
|
|
|
|
for idx_scenario, row in df_oa.iterrows():
|
|
|
|
|
dct_exp_para = {}
|
|
|
|
|
for idx_col, para_level in enumerate(row):
|
2024-08-24 16:13:37 +08:00
|
|
|
|
# 处理 NaN 值,替换为默认值(如 0 或其他合适的值)
|
|
|
|
|
para_level = para_level if not pd.isna(para_level) else 0
|
|
|
|
|
# 转换为整数
|
|
|
|
|
para_level = int(para_level)
|
2024-08-24 11:20:13 +08:00
|
|
|
|
dct_exp_para[df_xv.columns[idx_col]] = \
|
|
|
|
|
df_xv.iloc[para_level, idx_col]
|
|
|
|
|
# different initial removal 只会得到 键 和 值
|
|
|
|
|
for idx_init_removal, dct_init_removal in enumerate(list_dct):
|
|
|
|
|
self.add_experiment_1(idx_scenario,
|
|
|
|
|
idx_init_removal,
|
|
|
|
|
dct_init_removal,
|
|
|
|
|
g_product_js,
|
|
|
|
|
**dct_exp_para)
|
|
|
|
|
print(f"Inserted experiment for scenario {idx_scenario}, "
|
|
|
|
|
f"init_removal {idx_init_removal}!")
|
2024-10-21 17:41:50 +08:00
|
|
|
|
self.finalize_insertion()
|
2024-08-24 11:20:13 +08:00
|
|
|
|
|
|
|
|
|
def add_experiment_1(self, idx_scenario, idx_init_removal,
|
|
|
|
|
dct_lst_init_disrupt_firm_prod, g_bom,
|
|
|
|
|
n_max_trial, prf_size, prf_conn,
|
|
|
|
|
cap_limit_prob_type, cap_limit_level,
|
|
|
|
|
diff_new_conn, remove_t, netw_prf_n):
|
|
|
|
|
e = Experiment(
|
|
|
|
|
idx_scenario=idx_scenario,
|
|
|
|
|
idx_init_removal=idx_init_removal,
|
|
|
|
|
n_sample=int(self.dct_parameter['n_sample']),
|
|
|
|
|
n_iter=int(self.dct_parameter['n_iter']),
|
|
|
|
|
dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod,
|
|
|
|
|
g_bom=g_bom,
|
|
|
|
|
n_max_trial=n_max_trial,
|
|
|
|
|
prf_size=prf_size,
|
|
|
|
|
prf_conn=prf_conn,
|
|
|
|
|
cap_limit_prob_type=cap_limit_prob_type,
|
|
|
|
|
cap_limit_level=cap_limit_level,
|
|
|
|
|
diff_new_conn=diff_new_conn,
|
|
|
|
|
remove_t=remove_t,
|
|
|
|
|
netw_prf_n=netw_prf_n
|
|
|
|
|
)
|
2024-10-21 17:41:50 +08:00
|
|
|
|
# 这里我们不立即提交,而是先添加到批量保存的队列中
|
|
|
|
|
self.experiment_data.append(e)
|
|
|
|
|
|
|
|
|
|
# 当批量数据达到一定数量时再提交
|
|
|
|
|
if len(self.experiment_data) >= self.batch_size:
|
|
|
|
|
self._commit_batch()
|
|
|
|
|
|
|
|
|
|
# 辅助方法:批量提交
|
|
|
|
|
def _commit_batch(self):
|
|
|
|
|
db_session.bulk_save_objects(self.experiment_data)
|
2024-08-24 11:20:13 +08:00
|
|
|
|
db_session.commit()
|
2024-10-21 17:41:50 +08:00
|
|
|
|
self.experiment_data.clear() # 清空队列
|
|
|
|
|
|
|
|
|
|
def finalize_insertion(self):
|
|
|
|
|
if self.experiment_data:
|
|
|
|
|
self._commit_batch() # 提交剩余的数据
|
2024-08-24 11:20:13 +08:00
|
|
|
|
|
|
|
|
|
def fill_sample_table(self):
|
|
|
|
|
rng = random.Random(self.dct_parameter['meta_seed'])
|
|
|
|
|
# 根据样本数目 设置 32 位随机整数
|
|
|
|
|
lst_seed = [
|
|
|
|
|
rng.getrandbits(32)
|
|
|
|
|
for _ in range(int(self.dct_parameter['n_sample']))
|
|
|
|
|
]
|
|
|
|
|
lst_exp = db_session.query(Experiment).all()
|
|
|
|
|
|
|
|
|
|
lst_sample = []
|
|
|
|
|
for experiment in lst_exp:
|
|
|
|
|
# idx_sample: 1-50
|
|
|
|
|
for idx_sample in range(int(experiment.n_sample)):
|
|
|
|
|
s = Sample(e_id=experiment.id,
|
|
|
|
|
idx_sample=idx_sample + 1,
|
|
|
|
|
seed=lst_seed[idx_sample],
|
|
|
|
|
is_done_flag=-1)
|
|
|
|
|
lst_sample.append(s)
|
2024-10-21 17:41:50 +08:00
|
|
|
|
# 每当达到批量大小时提交一次
|
|
|
|
|
if len(lst_sample) >= self.batch_size:
|
|
|
|
|
db_session.bulk_save_objects(lst_sample)
|
|
|
|
|
db_session.commit()
|
|
|
|
|
print(f'Inserted {len(lst_sample)} samples!')
|
|
|
|
|
lst_sample.clear() # 清空已提交的样本列表
|
|
|
|
|
|
|
|
|
|
# 提交剩余的样本
|
|
|
|
|
if lst_sample:
|
|
|
|
|
db_session.bulk_save_objects(lst_sample)
|
|
|
|
|
db_session.commit()
|
2024-08-24 11:20:13 +08:00
|
|
|
|
print(f'Inserted {len(lst_sample)} samples!')
|
|
|
|
|
|
|
|
|
|
def reset_db(self, force_drop=False):
|
|
|
|
|
# first, check if tables exist
|
|
|
|
|
lst_table_obj = [
|
|
|
|
|
Base.metadata.tables[str_table]
|
|
|
|
|
for str_table in ins.get_table_names()
|
|
|
|
|
if str_table.startswith(self.db_name_prefix)
|
|
|
|
|
]
|
|
|
|
|
self.is_exist = len(lst_table_obj) > 0
|
|
|
|
|
if force_drop:
|
|
|
|
|
self.force_drop_db(lst_table_obj)
|
|
|
|
|
# while is_exist:
|
|
|
|
|
# a_table = random.choice(lst_table_obj)
|
|
|
|
|
# try:
|
|
|
|
|
# Base.metadata.drop_all(bind=engine, tables=[a_table])
|
|
|
|
|
# except KeyError:
|
|
|
|
|
# pass
|
|
|
|
|
# except OperationalError:
|
|
|
|
|
# pass
|
|
|
|
|
# else:
|
|
|
|
|
# lst_table_obj.remove(a_table)
|
|
|
|
|
# print(
|
|
|
|
|
# f"Table {a_table.name} is dropped "
|
|
|
|
|
# f"for exp: {self.db_name_prefix}!!!"
|
|
|
|
|
# )
|
|
|
|
|
# finally:
|
|
|
|
|
# is_exist = len(lst_table_obj) > 0
|
|
|
|
|
|
|
|
|
|
if self.is_exist:
|
|
|
|
|
print(
|
|
|
|
|
f"All tables exist. No need to reset "
|
|
|
|
|
f"for exp: {self.db_name_prefix}."
|
|
|
|
|
)
|
|
|
|
|
# change the is_done_flag from 0 to -1
|
|
|
|
|
# rerun the in-finished tasks
|
|
|
|
|
self.is_exist_reset_flag_resset_db()
|
|
|
|
|
# if self.reset_flag > 0:
|
|
|
|
|
# if self.reset_flag == 2:
|
|
|
|
|
# sample = db_session.query(Sample).filter(
|
|
|
|
|
# Sample.is_done_flag == 0)
|
|
|
|
|
# elif self.reset_flag == 1:
|
|
|
|
|
# sample = db_session.query(Sample).filter(
|
|
|
|
|
# Sample.is_done_flag == 0,
|
|
|
|
|
# Sample.computer_name == platform.node())
|
|
|
|
|
# else:
|
|
|
|
|
# raise ValueError('Wrong reset flag')
|
|
|
|
|
# if sample.count() > 0:
|
|
|
|
|
# for s in sample:
|
|
|
|
|
# qry_result = db_session.query(Result).filter_by(
|
|
|
|
|
# s_id=s.id)
|
|
|
|
|
# if qry_result.count() > 0:
|
|
|
|
|
# db_session.query(Result).filter(s_id=s.id).delete()
|
|
|
|
|
# db_session.commit()
|
|
|
|
|
# s.is_done_flag = -1
|
|
|
|
|
# db_session.commit()
|
|
|
|
|
# print(f"Reset the sample id {s.id} flag from 0 to -1")
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# 不存在则重新生成所有的表结构
|
|
|
|
|
Base.metadata.create_all(bind=engine)
|
|
|
|
|
self.init_tables()
|
|
|
|
|
print(
|
|
|
|
|
f"All tables are just created and initialized "
|
|
|
|
|
f"for exp: {self.db_name_prefix}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def force_drop_db(self, lst_table_obj):
|
|
|
|
|
self.is_exist = len(lst_table_obj) > 0
|
|
|
|
|
while self.is_exist:
|
|
|
|
|
a_table = random.choice(lst_table_obj)
|
|
|
|
|
try:
|
|
|
|
|
Base.metadata.drop_all(bind=engine, tables=[a_table])
|
|
|
|
|
except KeyError:
|
|
|
|
|
pass
|
|
|
|
|
except OperationalError:
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
lst_table_obj.remove(a_table)
|
|
|
|
|
print(
|
|
|
|
|
f"Table {a_table.name} is dropped "
|
|
|
|
|
f"for exp: {self.db_name_prefix}!!!"
|
|
|
|
|
)
|
|
|
|
|
finally:
|
|
|
|
|
self.is_exist = len(lst_table_obj) > 0
|
|
|
|
|
|
|
|
|
|
def is_exist_reset_flag_resset_db(self):
|
|
|
|
|
if self.reset_flag > 0:
|
|
|
|
|
if self.reset_flag == 2:
|
|
|
|
|
sample = db_session.query(Sample).filter(
|
|
|
|
|
Sample.is_done_flag == 0)
|
|
|
|
|
elif self.reset_flag == 1:
|
|
|
|
|
sample = db_session.query(Sample).filter(
|
|
|
|
|
Sample.is_done_flag == 0,
|
|
|
|
|
Sample.computer_name == platform.node())
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError('Wrong reset flag')
|
|
|
|
|
if sample.count() > 0:
|
|
|
|
|
for s in sample:
|
|
|
|
|
qry_result = db_session.query(Result).filter_by(
|
|
|
|
|
s_id=s.id)
|
|
|
|
|
if qry_result.count() > 0:
|
|
|
|
|
db_session.query(Result).filter(s_id=s.id).delete()
|
|
|
|
|
db_session.commit()
|
|
|
|
|
s.is_done_flag = -1
|
|
|
|
|
db_session.commit()
|
|
|
|
|
print(f"Reset the sample id {s.id} flag from 0 to -1")
|
|
|
|
|
|
|
|
|
|
def prepare_list_sample(self):
|
|
|
|
|
# 为了符合前面 重置表里面存在 重置本机 或者重置全部 或者不重置的部分 这个部分的 关于样本运行也得重新拿出来
|
|
|
|
|
# 查找一个风险事件中 50 个样本
|
|
|
|
|
res = db_session.execute(
|
|
|
|
|
text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
|
|
|
|
|
f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
|
|
|
|
|
)).scalar()
|
|
|
|
|
# 控制 n_sample数量 作为后面的参数
|
|
|
|
|
n_sample = 0 if res is None else res
|
|
|
|
|
print(f'There are a total of {n_sample} samples.')
|
|
|
|
|
# 查找 is_done_flag = -1 也就是没有运行的 样本 运行后会改为0
|
|
|
|
|
res = db_session.execute(
|
|
|
|
|
text(f"SELECT id FROM {self.db_name_prefix}_sample "
|
|
|
|
|
f"WHERE is_done_flag = -1"
|
|
|
|
|
))
|
|
|
|
|
for row in res:
|
|
|
|
|
s_id = row[0]
|
|
|
|
|
self.lst_saved_s_id.append(s_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def select_random_sample(lst_s_id):
|
|
|
|
|
while 1:
|
|
|
|
|
if len(lst_s_id) == 0:
|
|
|
|
|
return None
|
|
|
|
|
s_id = random.choice(lst_s_id)
|
|
|
|
|
lst_s_id.remove(s_id)
|
|
|
|
|
res = db_session.query(Sample).filter(Sample.id == int(s_id),
|
|
|
|
|
Sample.is_done_flag == -1)
|
|
|
|
|
if res.count() == 1:
|
|
|
|
|
return res[0]
|
|
|
|
|
|
|
|
|
|
def fetch_a_sample(self, s_id=None):
|
|
|
|
|
# 由Computation 调用 返回 sample对象 同时给出 2中 指定访问模式 抓取特定的 样本 通过s_id
|
|
|
|
|
# 默认访问 flag为-1的 lst_saved_s_id
|
|
|
|
|
if s_id is not None:
|
|
|
|
|
res = db_session.query(Sample).filter(Sample.id == int(s_id))
|
|
|
|
|
if res.count() == 0:
|
|
|
|
|
return None
|
|
|
|
|
else:
|
|
|
|
|
return res[0]
|
|
|
|
|
|
|
|
|
|
sample = self.select_random_sample(self.lst_saved_s_id)
|
|
|
|
|
if sample is not None:
|
|
|
|
|
return sample
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def lock_the_sample(sample: Sample):
|
|
|
|
|
sample.is_done_flag, sample.computer_name = 0, platform.node()
|
|
|
|
|
db_session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a test-prefixed controller and create the schema.
    print("Testing the database connection...")
    try:
        controller_db = ControllerDB('test')
        Base.metadata.create_all(bind=engine)
    except Exception as e:
        # Configuration or connection problems surface here.
        print("Failed to connect to the database!")
        print(e)
        exit(1)
|