# mesa/controller_db.py (exported from a source viewer; revision timestamp 2024-08-24 11:20:13 +08:00)
# -*- coding: utf-8 -*-
from orm import db_session, engine, Base, ins, connection
from orm import Experiment, Sample, Result
from sqlalchemy.exc import OperationalError
from sqlalchemy import text
import yaml
import random
import pandas as pd
import platform
import networkx as nx
import json
import pickle
class ControllerDB:
    """Create, reset and hand out Experiment/Sample rows for one run mode.

    An instance is bound to one table-name prefix ('test', 'without_exp' or
    'with_exp'). ``reset_db`` inspects/drops the tables whose names start
    with that prefix, and ``prepare_list_sample`` queries
    ``<prefix>_sample`` / ``<prefix>_experiment`` directly.
    """

    is_with_exp: bool
    dct_parameter = None
    is_test: bool = None
    db_name_prefix: str = None
    # 0: do not reset; 1: reset this computer's unfinished samples;
    # 2: reset all unfinished samples.
    reset_flag: int
    # Ids of samples still waiting to run (is_done_flag == -1).
    lst_saved_s_id: list
    # True when tables with this prefix were found in the database.
    is_exist: bool

    def __init__(self, prefix, reset_flag=0):
        """Load the experiment configuration for ``prefix``.

        Args:
            prefix: Table-name prefix; one of 'test', 'without_exp',
                'with_exp'.
            reset_flag: 0 = do not reset; 1 = reset unfinished samples
                locked by this computer; 2 = reset all unfinished samples.

        Raises:
            AssertionError: if ``prefix`` is not one of the allowed values.
        """
        with open('conf_experiment.yaml') as yaml_file:
            dct_conf_experiment = yaml.full_load(yaml_file)
        # NOTE(review): this assert is the only guard on ``prefix``, which is
        # later interpolated into raw SQL; it is stripped under ``python -O``.
        assert prefix in ['test', 'without_exp', 'with_exp'], \
            "db name not in test, without_exp, with_exp"
        self.is_test = prefix == 'test'
        self.is_with_exp = prefix not in ('test', 'without_exp')
        self.db_name_prefix = prefix
        # The YAML has a 'test' and a 'not_test' parameter section plus a
        # shared 'meta_seed'.
        dct_para_in_test = dct_conf_experiment['test'] if self.is_test \
            else dct_conf_experiment['not_test']
        self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'],
                              **dct_para_in_test}
        print(self.dct_parameter)
        self.reset_flag = reset_flag
        self.is_exist = False
        self.lst_saved_s_id = []

    def init_tables(self):
        """Fill the experiment table, then the sample table derived from it."""
        self.fill_experiment_table()
        self.fill_sample_table()

    def fill_experiment_table(self):
        """Insert one Experiment row per (scenario, initial disruption) pair.

        Scenarios are the rows of ``input_data/oa_<mode>.csv``. Each cell is a
        level index into the matching column of ``input_data/xv_<mode>.csv``.
        Initial disruptions come from the high-risk SQL export in 'with_exp'
        mode. Otherwise they are one ``{firm_code: [product_code]}`` per
        cell equal to 1 in ``input_data/Firm_amended.csv``.
        """
        Firm = pd.read_csv("input_data/Firm_amended.csv")
        Firm['Code'] = Firm['Code'].astype('string')
        Firm.fillna(0, inplace=True)

        # Each entry maps a firm code to the list of product codes it loses
        # at the start of a run.
        list_dct = []
        if self.is_with_exp:
            # Reuse the initial disruptions picked by the high-risk query
            # (used for the variance-analysis experiments).
            with open('SQL_export_high_risk_setting.sql', 'r') as f:
                str_sql = text(f.read())
            result = pd.read_sql(sql=str_sql, con=connection)
            # NOTE(security): pickle.loads executes whatever the stored bytes
            # tell it to; only use this against a trusted database.
            result['dct_lst_init_disrupt_firm_prod'] = \
                result['dct_lst_init_disrupt_firm_prod'].apply(pickle.loads)
            list_dct = result['dct_lst_init_disrupt_firm_prod'].to_list()
        else:
            for _, row in Firm.iterrows():
                code = row['Code']
                # Columns from label '1' onward are presumably product-code
                # indicator columns (1 = the firm makes that product).
                row = row['1':]
                for product_code in row.index[row == 1].to_list():
                    list_dct.append({code: [product_code]})

        # Build the product BOM graph. The adjacency matrix is transposed so
        # edges point from upstream to downstream products (e.g. 1.1.1 -> 1.1).
        BomNodes = pd.read_csv('input_data/BomNodes.csv', index_col=0)
        BomNodes.set_index('Code', inplace=True)
        BomCateNet = pd.read_csv('input_data/BomCateNet.csv', index_col=0)
        BomCateNet.fillna(0, inplace=True)
        g_bom = nx.from_pandas_adjacency(BomCateNet.T,
                                         create_using=nx.MultiDiGraph())
        # Attach each node's BomNodes row (e.g. its level and name) as
        # node attributes.
        bom_labels_dict = {code: BomNodes.loc[code].to_dict()
                           for code in g_bom.nodes}
        nx.set_node_attributes(g_bom, bom_labels_dict)
        # Serialised once; the same JSON is stored on every Experiment row.
        g_product_js = json.dumps(nx.adjacency_data(g_bom))

        exp_tag = 'with_exp' if self.is_with_exp else 'without_exp'
        # xv table: one column per experiment parameter, one row per level.
        df_xv = pd.read_csv(f"input_data/xv_{exp_tag}.csv", index_col=None)
        # OA table (presumably an orthogonal-array design): one row per
        # scenario, each cell a level index into df_xv.
        df_oa = pd.read_csv(f"input_data/oa_{exp_tag}.csv", index_col=None)
        # Keep only as many OA columns as there are parameters.
        df_oa = df_oa.iloc[:, 0:df_xv.shape[1]]
        for idx_scenario, row in df_oa.iterrows():
            # Decode this scenario's level indices into parameter values.
            dct_exp_para = {
                df_xv.columns[idx_col]: df_xv.iloc[para_level, idx_col]
                for idx_col, para_level in enumerate(row)
            }
            # Every scenario is crossed with every initial disruption.
            for idx_init_removal, dct_init_removal in enumerate(list_dct):
                self.add_experiment_1(idx_scenario,
                                      idx_init_removal,
                                      dct_init_removal,
                                      g_product_js,
                                      **dct_exp_para)
                print(f"Inserted experiment for scenario {idx_scenario}, "
                      f"init_removal {idx_init_removal}!")

    def add_experiment_1(self, idx_scenario, idx_init_removal,
                         dct_lst_init_disrupt_firm_prod, g_bom,
                         n_max_trial, prf_size, prf_conn,
                         cap_limit_prob_type, cap_limit_level,
                         diff_new_conn, remove_t, netw_prf_n):
        """Insert and commit a single Experiment row.

        ``n_sample`` and ``n_iter`` come from the loaded configuration. The
        remaining keyword parameters are the scenario values decoded from the
        xv table, so the xv CSV's column names must match them exactly.
        ``g_bom`` is the BOM graph as a JSON string.
        """
        e = Experiment(
            idx_scenario=idx_scenario,
            idx_init_removal=idx_init_removal,
            n_sample=int(self.dct_parameter['n_sample']),
            n_iter=int(self.dct_parameter['n_iter']),
            dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod,
            g_bom=g_bom,
            n_max_trial=n_max_trial,
            prf_size=prf_size,
            prf_conn=prf_conn,
            cap_limit_prob_type=cap_limit_prob_type,
            cap_limit_level=cap_limit_level,
            diff_new_conn=diff_new_conn,
            remove_t=remove_t,
            netw_prf_n=netw_prf_n
        )
        db_session.add(e)
        db_session.commit()

    def fill_sample_table(self):
        """Create ``n_sample`` pending samples (is_done_flag = -1) per experiment.

        One list of 32-bit seeds is drawn from ``meta_seed``. Sample k of
        every experiment uses seed k, so experiments share the same seeds.
        Sample indices are 1-based.
        """
        rng = random.Random(self.dct_parameter['meta_seed'])
        lst_seed = [rng.getrandbits(32)
                    for _ in range(int(self.dct_parameter['n_sample']))]
        lst_exp = db_session.query(Experiment).all()
        lst_sample = [
            Sample(e_id=experiment.id,
                   idx_sample=idx_sample + 1,
                   seed=lst_seed[idx_sample],
                   is_done_flag=-1)
            for experiment in lst_exp
            for idx_sample in range(int(experiment.n_sample))
        ]
        db_session.bulk_save_objects(lst_sample)
        db_session.commit()
        print(f'Inserted {len(lst_sample)} samples!')

    def reset_db(self, force_drop=False):
        """Ensure the tables for this prefix exist and are ready to run.

        Tables whose names start with the prefix are looked up first. If
        ``force_drop`` is true they are all dropped. If tables still exist,
        unfinished samples are reset according to ``self.reset_flag``.
        Otherwise every table in the ORM metadata is created and filled.
        """
        lst_table_obj = [
            Base.metadata.tables[str_table]
            for str_table in ins.get_table_names()
            if str_table.startswith(self.db_name_prefix)
        ]
        self.is_exist = len(lst_table_obj) > 0
        if force_drop:
            self.force_drop_db(lst_table_obj)
        if self.is_exist:
            print(
                f"All tables exist. No need to reset "
                f"for exp: {self.db_name_prefix}."
            )
            # Put samples left in progress (flag 0) back to pending (-1).
            self.is_exist_reset_flag_resset_db()
        else:
            Base.metadata.create_all(bind=engine)
            self.init_tables()
            print(
                f"All tables are just created and initialized "
                f"for exp: {self.db_name_prefix}."
            )

    def force_drop_db(self, lst_table_obj):
        """Drop every table in ``lst_table_obj``, emptying the list in place.

        Tables are tried in random order and a failed drop is retried later.
        This is presumably so foreign-key dependents get dropped before the
        tables they reference. Sets ``self.is_exist`` to whether tables remain.
        NOTE(review): a table that can never be dropped makes this loop spin
        forever.
        """
        self.is_exist = len(lst_table_obj) > 0
        while self.is_exist:
            a_table = random.choice(lst_table_obj)
            try:
                Base.metadata.drop_all(bind=engine, tables=[a_table])
            except (KeyError, OperationalError):
                # Could not drop it now; pick again on the next iteration.
                pass
            else:
                lst_table_obj.remove(a_table)
                print(
                    f"Table {a_table.name} is dropped "
                    f"for exp: {self.db_name_prefix}!!!"
                )
            finally:
                self.is_exist = len(lst_table_obj) > 0

    def is_exist_reset_flag_resset_db(self):
        """Return unfinished samples (is_done_flag == 0) to pending (-1).

        A ``reset_flag`` of 0 or less does nothing. A flag of 1 resets only
        samples locked by this computer (``platform.node()``). A flag of 2
        resets all of them. Each reset sample has its Result rows deleted.

        Raises:
            ValueError: if ``reset_flag`` is positive but not 1 or 2.
        """
        if self.reset_flag <= 0:
            return
        if self.reset_flag == 2:
            sample = db_session.query(Sample).filter(
                Sample.is_done_flag == 0)
        elif self.reset_flag == 1:
            sample = db_session.query(Sample).filter(
                Sample.is_done_flag == 0,
                Sample.computer_name == platform.node())
        else:
            raise ValueError('Wrong reset flag')
        # Materialise the rows first so the per-sample commits below do not
        # interleave with an open result iteration.
        for s in sample.all():
            # filter_by() is the keyword form; Query.filter() only accepts
            # SQL expressions and raises TypeError on keyword arguments.
            db_session.query(Result).filter_by(s_id=s.id).delete()
            s.is_done_flag = -1
            # One commit per sample: deletion and flag reset land together.
            db_session.commit()
            print(f"Reset the sample id {s.id} flag from 0 to -1")

    def prepare_list_sample(self):
        """Queue the ids of all samples that have not been started yet.

        Prints the total number of samples joined to an experiment. Then
        appends to ``self.lst_saved_s_id`` the id of every sample whose
        is_done_flag is -1 (a started sample is locked with flag 0).
        """
        # Table names cannot be bound parameters. The prefix comes from the
        # whitelist asserted in __init__.
        res = db_session.execute(
            text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
                 f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
                 )).scalar()
        n_sample = 0 if res is None else res
        print(f'There are a total of {n_sample} samples.')
        res = db_session.execute(
            text(f"SELECT id FROM {self.db_name_prefix}_sample "
                 f"WHERE is_done_flag = -1"
                 ))
        self.lst_saved_s_id.extend(row[0] for row in res)

    @staticmethod
    def select_random_sample(lst_s_id):
        """Pop random ids from ``lst_s_id`` until one is still pending.

        Tried ids are removed from the list. Returns the Sample whose
        is_done_flag is still -1, or None once the list is exhausted.
        """
        while lst_s_id:
            s_id = random.choice(lst_s_id)
            lst_s_id.remove(s_id)
            # Re-check the flag in the DB: the sample may have been locked
            # (presumably by another worker) since the id list was built.
            candidates = db_session.query(Sample).filter(
                Sample.id == int(s_id),
                Sample.is_done_flag == -1).all()
            if len(candidates) == 1:
                return candidates[0]
        return None

    def fetch_a_sample(self, s_id=None):
        """Return a Sample to run, or None if there is none.

        With ``s_id``, that sample is returned whatever its flag. Otherwise
        a random pending sample is drawn from ``self.lst_saved_s_id``.
        """
        if s_id is not None:
            return db_session.query(Sample).filter(
                Sample.id == int(s_id)).first()
        return self.select_random_sample(self.lst_saved_s_id)

    @staticmethod
    def lock_the_sample(sample: Sample):
        """Mark ``sample`` as in progress (flag 0) by this computer and commit."""
        sample.is_done_flag, sample.computer_name = 0, platform.node()
        db_session.commit()
if __name__ == '__main__':
    # Local import: only this script entry point needs sys.
    import sys

    # Smoke test: load the 'test' configuration and create any tables the
    # ORM metadata defines that do not exist yet.
    print("Testing the database connection...")
    try:
        controller_db = ControllerDB('test')
        Base.metadata.create_all(bind=engine)
    except Exception as e:
        # Top-level boundary: report the failure and exit non-zero.
        print("Failed to connect to the database!")
        print(e)
        # sys.exit instead of the site-provided exit(), which is meant for
        # the interactive shell and is missing under `python -S`.
        sys.exit(1)