# -*- coding: utf-8 -*- from orm import db_session, engine, Base, ins, connection from orm import Experiment, Sample, Result from sqlalchemy.exc import OperationalError from sqlalchemy import text import yaml import random import pandas as pd import platform import networkx as nx import json import pickle class ControllerDB: is_with_exp: bool dct_parameter = None is_test: bool = None db_name_prefix: str = None reset_flag: int lst_saved_s_id: list def __init__(self, prefix, reset_flag=0): with open('conf_experiment.yaml') as yaml_file: dct_conf_experiment = yaml.full_load(yaml_file) assert prefix in ['test', 'without_exp', 'with_exp'], "db name not in test, without_exp, with_exp" self.is_test = prefix == 'test' self.is_with_exp = False if prefix == 'test' or prefix == 'without_exp' else True self.db_name_prefix = prefix dct_para_in_test = dct_conf_experiment['test'] if self.is_test else dct_conf_experiment['not_test'] self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'], **dct_para_in_test} print(self.dct_parameter) # 0, not reset; 1, reset self; 2, reset all self.reset_flag = reset_flag self.is_exist = False self.lst_saved_s_id = [] self.experiment_data = [] self.batch_size = 2000 # 根据需求设置每批次的大小 def init_tables(self): self.fill_experiment_table() self.fill_sample_table() def fill_experiment_table(self): firm = pd.read_csv("input_data/input_firm_data/Firm_amended.csv") firm['Code'] = firm['Code'].astype('string') firm.fillna(0, inplace=True) # fill dct_lst_init_disrupt_firm_prod # 存储 公司-在供应链结点的位置.. 0 :‘1.1’ list_dct = [] # 存储 公司编码code 和对应的产业链 结点 if self.is_with_exp: # 对于方差分析时候使用 with open('SQL_export_high_risk_setting.sql', 'r') as f: str_sql = text(f.read()) result = pd.read_sql(sql=str_sql, con=connection) result['dct_lst_init_disrupt_firm_prod'] = \ result['dct_lst_init_disrupt_firm_prod'].apply( lambda x: pickle.loads(x)) list_dct = result['dct_lst_init_disrupt_firm_prod'].to_list() else: # 行索引 (index):这一行在数据帧中的索引值。 # 行数据 (row):这一行的数据,是一个 pandas.Series 对象,包含该行的所有列和值。 firm_industry = pd.read_csv("input_data/firm_industry_relation.csv") firm_industry['Firm_Code'] = firm_industry['Firm_Code'].astype('string') for _, row in firm_industry.iterrows(): code = row['Firm_Code'] row = row['Product_Code'] dct = {code: [row]} list_dct.append(dct) # fill g_bom # 结点属性值 相当于 图上点的 原始 产品名称 bom_nodes = pd.read_csv('input_data/input_product_data/BomNodes.csv') bom_nodes['Code'] = bom_nodes['Code'].astype(str) bom_nodes.set_index('Code', inplace=True) # bom_cate_net = pd.read_csv('input_data/input_product_data/BomCateNet.csv', index_col=0) # bom_cate_net.fillna(0, inplace=True) # # 创建 可以多边的有向图 同时 转置操作 使得 上游指向下游结点 也就是 1.1.1 - 1.1 类似这种 # # 将第一列转换为字符串类型 # print("sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss") # print(bom_cate_net.columns) # print(bom_cate_net.index) # 打印行标题(索引) # print(bom_cate_net.iloc[:, 0]) # 打印第一列的内容 # # g_bom = nx.from_pandas_adjacency(bom_cate_net.T, # create_using=nx.MultiDiGraph()) bom_cate_net = pd.read_csv('input_data/input_product_data/合成结点.csv') g_bom = nx.from_pandas_edgelist(bom_cate_net, source='UPID', target='ID', create_using=nx.MultiDiGraph()) # 填充每一个结点 的具体内容 通过 相同的 code 并且通过BomNodes.loc[code].to_dict()字典化 格式类似 格式 { code(0) : {level: 0 ,name: 工业互联网 }} bom_labels_dict = {} for code in g_bom.nodes: try: int_code = int(code) bom_labels_dict[code] = bom_nodes.loc[int_code].to_dict() except KeyError: print(f"节点 {code} 不存在于 bom_nodes 中") # 分配属性 给每一个结点 获得类似 格式:{1: {'label': 'A', 'value': 10}, nx.set_node_attributes(g_bom, bom_labels_dict) # 改为json 格式 g_product_js = json.dumps(nx.adjacency_data(g_bom)) # insert exp df_xv = pd.read_csv( "input_data/" f"xv_{'with_exp' if self.is_with_exp else 'without_exp'}.csv", index_col=None) # read the OA table df_oa = pd.read_csv( "input_data/" f"oa_{'with_exp' if self.is_with_exp else 'without_exp'}.csv", index_col=None) # .shape[1] 列数 .iloc 访问特定的值 而不是标签 df_oa = df_oa.iloc[:, 0:df_xv.shape[1]] # idx_scenario 是 0 指行 idx_init_removal 指 索引 0.. dct_init_removal 键 code 公司 g_product_js 图的json数据 dct_exp_para 解码 全局参数xv- for idx_scenario, row in df_oa.iterrows(): dct_exp_para = {} for idx_col, para_level in enumerate(row): # 处理 NaN 值,替换为默认值(如 0 或其他合适的值) para_level = para_level if not pd.isna(para_level) else 0 # 转换为整数 para_level = int(para_level) dct_exp_para[df_xv.columns[idx_col]] = \ df_xv.iloc[para_level, idx_col] # different initial removal 只会得到 键 和 值 for idx_init_removal, dct_init_removal in enumerate(list_dct): self.add_experiment_1(idx_scenario, idx_init_removal, dct_init_removal, g_product_js, **dct_exp_para) print(f"Inserted experiment for scenario {idx_scenario}, " f"init_removal {idx_init_removal}!") self.finalize_insertion() def add_experiment_1(self, idx_scenario, idx_init_removal, dct_lst_init_disrupt_firm_prod, g_bom, n_max_trial, prf_size, prf_conn, cap_limit_prob_type, cap_limit_level, diff_new_conn, remove_t, netw_prf_n): e = Experiment( idx_scenario=idx_scenario, idx_init_removal=idx_init_removal, n_sample=int(self.dct_parameter['n_sample']), n_iter=int(self.dct_parameter['n_iter']), dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod, g_bom=g_bom, n_max_trial=n_max_trial, prf_size=prf_size, prf_conn=prf_conn, cap_limit_prob_type=cap_limit_prob_type, cap_limit_level=cap_limit_level, diff_new_conn=diff_new_conn, remove_t=remove_t, netw_prf_n=netw_prf_n ) # 这里我们不立即提交,而是先添加到批量保存的队列中 self.experiment_data.append(e) # 当批量数据达到一定数量时再提交 if len(self.experiment_data) >= self.batch_size: self._commit_batch() # 辅助方法:批量提交 def _commit_batch(self): db_session.bulk_save_objects(self.experiment_data) db_session.commit() self.experiment_data.clear() # 清空队列 def finalize_insertion(self): if self.experiment_data: self._commit_batch() # 提交剩余的数据 def fill_sample_table(self): rng = random.Random(self.dct_parameter['meta_seed']) # 根据样本数目 设置 32 位随机整数 lst_seed = [ rng.getrandbits(32) for _ in range(int(self.dct_parameter['n_sample'])) ] lst_exp = db_session.query(Experiment).all() lst_sample = [] for experiment in lst_exp: # idx_sample: 1-50 for idx_sample in range(int(experiment.n_sample)): s = Sample(e_id=experiment.id, idx_sample=idx_sample + 1, seed=lst_seed[idx_sample], is_done_flag=-1) lst_sample.append(s) # 每当达到批量大小时提交一次 if len(lst_sample) >= self.batch_size: db_session.bulk_save_objects(lst_sample) db_session.commit() print(f'Inserted {len(lst_sample)} samples!') lst_sample.clear() # 清空已提交的样本列表 # 提交剩余的样本 if lst_sample: db_session.bulk_save_objects(lst_sample) db_session.commit() print(f'Inserted {len(lst_sample)} samples!') def reset_db(self, force_drop=False): # first, check if tables exist lst_table_obj = [ Base.metadata.tables[str_table] for str_table in ins.get_table_names() if str_table.startswith(self.db_name_prefix) ] self.is_exist = len(lst_table_obj) > 0 if force_drop: self.force_drop_db(lst_table_obj) # while is_exist: # a_table = random.choice(lst_table_obj) # try: # Base.metadata.drop_all(bind=engine, tables=[a_table]) # except KeyError: # pass # except OperationalError: # pass # else: # lst_table_obj.remove(a_table) # print( # f"Table {a_table.name} is dropped " # f"for exp: {self.db_name_prefix}!!!" # ) # finally: # is_exist = len(lst_table_obj) > 0 if self.is_exist: print( f"All tables exist. No need to reset " f"for exp: {self.db_name_prefix}." ) # change the is_done_flag from 0 to -1 # rerun the in-finished tasks self.is_exist_reset_flag_resset_db() # if self.reset_flag > 0: # if self.reset_flag == 2: # sample = db_session.query(Sample).filter( # Sample.is_done_flag == 0) # elif self.reset_flag == 1: # sample = db_session.query(Sample).filter( # Sample.is_done_flag == 0, # Sample.computer_name == platform.node()) # else: # raise ValueError('Wrong reset flag') # if sample.count() > 0: # for s in sample: # qry_result = db_session.query(Result).filter_by( # s_id=s.id) # if qry_result.count() > 0: # db_session.query(Result).filter(s_id=s.id).delete() # db_session.commit() # s.is_done_flag = -1 # db_session.commit() # print(f"Reset the sample id {s.id} flag from 0 to -1") else: # 不存在则重新生成所有的表结构 Base.metadata.create_all(bind=engine) self.init_tables() print( f"All tables are just created and initialized " f"for exp: {self.db_name_prefix}." ) def force_drop_db(self, lst_table_obj): self.is_exist = len(lst_table_obj) > 0 while self.is_exist: a_table = random.choice(lst_table_obj) try: Base.metadata.drop_all(bind=engine, tables=[a_table]) except KeyError: pass except OperationalError: pass else: lst_table_obj.remove(a_table) print( f"Table {a_table.name} is dropped " f"for exp: {self.db_name_prefix}!!!" ) finally: self.is_exist = len(lst_table_obj) > 0 def is_exist_reset_flag_resset_db(self): if self.reset_flag > 0: if self.reset_flag == 2: sample = db_session.query(Sample).filter( Sample.is_done_flag == 0) elif self.reset_flag == 1: sample = db_session.query(Sample).filter( Sample.is_done_flag == 0, Sample.computer_name == platform.node()) else: raise ValueError('Wrong reset flag') if sample.count() > 0: for s in sample: qry_result = db_session.query(Result).filter_by( s_id=s.id) if qry_result.count() > 0: db_session.query(Result).filter(s_id=s.id).delete() db_session.commit() s.is_done_flag = -1 db_session.commit() print(f"Reset the sample id {s.id} flag from 0 to -1") def prepare_list_sample(self): # 为了符合前面 重置表里面存在 重置本机 或者重置全部 或者不重置的部分 这个部分的 关于样本运行也得重新拿出来 # 查找一个风险事件中 50 个样本 res = db_session.execute( text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, " f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id" )).scalar() # 控制 n_sample数量 作为后面的参数 n_sample = 0 if res is None else res print(f'There are a total of {n_sample} samples.') # 查找 is_done_flag = -1 也就是没有运行的 样本 运行后会改为0 res = db_session.execute( text(f"SELECT id FROM {self.db_name_prefix}_sample " f"WHERE is_done_flag = -1" )) for row in res: s_id = row[0] self.lst_saved_s_id.append(s_id) @staticmethod def select_random_sample(lst_s_id): while 1: if len(lst_s_id) == 0: return None s_id = random.choice(lst_s_id) lst_s_id.remove(s_id) res = db_session.query(Sample).filter(Sample.id == int(s_id), Sample.is_done_flag == -1) if res.count() == 1: return res[0] def fetch_a_sample(self, s_id=None): # 由Computation 调用 返回 sample对象 同时给出 2中 指定访问模式 抓取特定的 样本 通过s_id # 默认访问 flag为-1的 lst_saved_s_id if s_id is not None: res = db_session.query(Sample).filter(Sample.id == int(s_id)) if res.count() == 0: return None else: return res[0] sample = self.select_random_sample(self.lst_saved_s_id) if sample is not None: return sample return None @staticmethod def lock_the_sample(sample: Sample): sample.is_done_flag, sample.computer_name = 0, platform.node() db_session.commit() if __name__ == '__main__': print("Testing the database connection...") try: controller_db = ControllerDB('test') Base.metadata.create_all(bind=engine) except Exception as e: print("Failed to connect to the database!") print(e) exit(1)