IIabm/controller_db.py

267 lines
10 KiB
Python
Raw Normal View History

2023-03-12 12:02:01 +08:00
# -*- coding: utf-8 -*-
from orm import db_session, engine, Base, ins
2023-03-13 19:47:25 +08:00
from orm import Experiment, Sample, Result
2023-03-12 12:02:01 +08:00
from sqlalchemy.exc import OperationalError
2023-03-12 22:21:39 +08:00
from sqlalchemy import text
2023-03-12 12:02:01 +08:00
import yaml
import random
2023-03-13 19:47:25 +08:00
import pandas as pd
2023-03-12 12:02:01 +08:00
import platform
2023-03-16 15:07:43 +08:00
import networkx as nx
import json
2023-05-21 17:05:06 +08:00
import pickle
2023-03-12 12:02:01 +08:00
2023-05-15 13:44:21 +08:00
2023-03-12 12:02:01 +08:00
class ControllerDB:
    """Create, reset and hand out the experiment/sample rows of one run.

    The controller is bound to one table-name prefix (``test``,
    ``without_exp`` or ``with_exp``).  It fills the experiment and sample
    tables from CSV/SQL inputs, resets unfinished samples, and serves
    pending samples to workers.

    ``is_done_flag`` values used by this class: ``-1`` is written when a
    sample is created or reset (pending), ``0`` is written by
    ``lock_the_sample`` (taken by a machine).
    """

    dct_parameter = None
    is_test: bool = None
    is_with_exp: bool
    db_name_prefix: str = None
    reset_flag: int
    lst_saved_s_id: list

    def __init__(self, prefix, reset_flag=0):
        """Load ``conf_experiment.yaml`` and remember the run settings.

        Args:
            prefix: One of ``'test'``, ``'without_exp'``, ``'with_exp'``;
                also used as the table-name prefix.
            reset_flag: 0, not reset; 1, reset self; 2, reset all.
        """
        with open('conf_experiment.yaml') as yaml_file:
            dct_conf_experiment = yaml.full_load(yaml_file)
        assert prefix in ['test', 'without_exp', 'with_exp'], \
            "db name not in test, without_exp, with_exp"

        self.is_test = prefix == 'test'
        self.is_with_exp = prefix not in ('test', 'without_exp')
        self.db_name_prefix = prefix

        # The YAML file holds one parameter set for test runs and one for
        # real runs; 'meta_seed' is shared by both.
        conf_key = 'test' if self.is_test else 'not_test'
        self.dct_parameter = {
            'meta_seed': dct_conf_experiment['meta_seed'],
            **dct_conf_experiment[conf_key]
        }
        print(self.dct_parameter)

        # 0, not reset; 1, reset self; 2, reset all
        self.reset_flag = reset_flag
        self.lst_saved_s_id = []

    def init_tables(self):
        """Fill the experiment table first, then the sample table."""
        self.fill_experiment_table()
        self.fill_sample_table()

    def fill_experiment_table(self):
        """Insert one experiment per (scenario, initial removal) pair.

        Scenarios are the rows of the ``oa_*.csv`` table; every cell is a
        row position into the matching column of ``xv_*.csv``, which
        holds the actual parameter values.
        """
        list_dct = self._load_init_disruptions()
        g_product_js = self._build_bom_json()
        df_xv, df_oa = self._read_design_tables()

        for idx_scenario, row in df_oa.iterrows():
            # Translate the level indices of this scenario into values.
            dct_exp_para = {
                df_xv.columns[idx_col]: df_xv.iloc[para_level, idx_col]
                for idx_col, para_level in enumerate(row)
            }
            # different initial removal
            for idx_init_removal, dct_init_removal in enumerate(list_dct):
                self.add_experiment_1(idx_scenario,
                                      idx_init_removal,
                                      dct_init_removal,
                                      g_product_js,
                                      **dct_exp_para)
                print(f"Inserted experiment for scenario {idx_scenario}, "
                      f"init_removal {idx_init_removal}!")

    def _load_init_disruptions(self):
        """Return the list of ``{firm_code: [product_code]}`` removals."""
        df_firm = pd.read_csv("Firm_amended.csv")
        df_firm['Code'] = df_firm['Code'].astype('string')
        df_firm.fillna(0, inplace=True)

        if self.is_with_exp:
            with open('SQL_export_high_risk_setting.sql', 'r') as f:
                str_sql = f.read()
            result = pd.read_sql(sql=str_sql, con=engine)
            # NOTE(review): pickle.loads can execute arbitrary code; this
            # is only safe while the database content is trusted.
            list_dct = result['dct_lst_init_disrupt_firm_prod'].apply(
                pickle.loads).to_list()
        else:
            list_dct = []
            for _, row in df_firm.iterrows():
                code = row['Code']
                # Columns from label '1' onward are product indicators.
                products = row['1':]
                list_dct.extend(
                    {code: [product_code]}
                    for product_code in
                    products.index[products == 1].to_list())

        # NOTE(review): hard-coded single-case override.  The scraped
        # source lost its indentation; this is assumed to sit at function
        # level, i.e. to replace the list built by either branch above.
        # Confirm against the repository before relying on it.
        # list_dct = [{'140': ['1.4.5.1']}]
        # list_dct = [{'133': ['1.4.4.1']}]
        # list_dct = [{'2': ['1.1.3']}]
        # list_dct = [{'135': ['1.3.2.1']}]
        list_dct = [{'79': ['2.1.3.4']}]
        # list_dct = [{'99': ['1.3.3']}]
        # list_dct = [{'41': ['1.4.5']}]
        return list_dct

    @staticmethod
    def _build_bom_json():
        """Build the BOM graph from the CSV files and return it as JSON."""
        df_nodes = pd.read_csv('BomNodes.csv', index_col=0)
        df_nodes.set_index('Code', inplace=True)
        df_net = pd.read_csv('BomCateNet.csv', index_col=0)
        df_net.fillna(0, inplace=True)

        g_bom = nx.from_pandas_adjacency(df_net.T,
                                         create_using=nx.MultiDiGraph())
        # Attach the BomNodes row of every node as its attributes.
        bom_labels_dict = {
            code: df_nodes.loc[code].to_dict() for code in g_bom.nodes
        }
        nx.set_node_attributes(g_bom, bom_labels_dict)
        return json.dumps(nx.adjacency_data(g_bom))

    def _read_design_tables(self):
        """Return ``(df_xv, df_oa)`` for the current experiment kind."""
        suffix = 'with_exp' if self.is_with_exp else 'without_exp'
        df_xv = pd.read_csv(f"xv_{suffix}.csv", index_col=None)
        # read the OA table
        df_oa = pd.read_csv(f"oa_{suffix}.csv", index_col=None)
        # Keep only as many OA columns as there are parameters in xv.
        df_oa = df_oa.iloc[:, 0:df_xv.shape[1]]
        return df_xv, df_oa

    def add_experiment_1(self, idx_scenario, idx_init_removal,
                         dct_lst_init_disrupt_firm_prod, g_bom,
                         n_max_trial, prf_size, prf_conn,
                         cap_limit_prob_type, cap_limit_level,
                         diff_new_conn,
                         proactive_ratio, remove_t, netw_prf_n):
        """Insert and commit a single ``Experiment`` row.

        ``n_sample`` and ``n_iter`` come from the loaded configuration;
        every other column is taken from the arguments unchanged.
        """
        e = Experiment(
            idx_scenario=idx_scenario,
            idx_init_removal=idx_init_removal,
            n_sample=int(self.dct_parameter['n_sample']),
            n_iter=int(self.dct_parameter['n_iter']),
            dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod,
            g_bom=g_bom,
            n_max_trial=n_max_trial,
            prf_size=prf_size,
            prf_conn=prf_conn,
            cap_limit_prob_type=cap_limit_prob_type,
            cap_limit_level=cap_limit_level,
            diff_new_conn=diff_new_conn,
            proactive_ratio=proactive_ratio,
            remove_t=remove_t,
            netw_prf_n=netw_prf_n
        )
        db_session.add(e)
        db_session.commit()

    def fill_sample_table(self):
        """Insert ``n_sample`` pending samples for every experiment.

        Seeds are drawn once from a generator seeded with ``meta_seed``,
        so sample ``k`` of every experiment shares the same seed.
        """
        rng = random.Random(self.dct_parameter['meta_seed'])
        lst_seed = [
            rng.getrandbits(32)
            for _ in range(int(self.dct_parameter['n_sample']))
        ]

        lst_sample = [
            Sample(e_id=experiment.id,
                   idx_sample=idx_sample + 1,
                   seed=lst_seed[idx_sample],
                   is_done_flag=-1)
            for experiment in db_session.query(Experiment).all()
            for idx_sample in range(int(experiment.n_sample))
        ]
        db_session.bulk_save_objects(lst_sample)
        db_session.commit()
        print(f'Inserted {len(lst_sample)} samples!')

    def reset_db(self, force_drop=False):
        """Create the tables if missing, otherwise reset unfinished samples.

        Args:
            force_drop: When true, first drop every existing table whose
                name starts with the prefix; the tables are then created
                and filled again.

        Raises:
            ValueError: If tables exist and ``reset_flag`` is positive but
                neither 1 nor 2.
        """
        # first, check if tables exist
        lst_table_obj = [
            Base.metadata.tables[str_table]
            for str_table in ins.get_table_names()
            if str_table.startswith(self.db_name_prefix)
        ]
        is_exist = len(lst_table_obj) > 0

        if force_drop:
            # A failed drop is retried with another random pick until
            # every table is gone.
            # NOTE(review): a table that can never be dropped makes this
            # loop spin forever.
            while is_exist:
                a_table = random.choice(lst_table_obj)
                try:
                    Base.metadata.drop_all(bind=engine, tables=[a_table])
                except KeyError:
                    pass
                except OperationalError:
                    pass
                else:
                    lst_table_obj.remove(a_table)
                    print(
                        f"Table {a_table.name} is dropped "
                        f"for exp: {self.db_name_prefix}!!!"
                    )
                finally:
                    is_exist = len(lst_table_obj) > 0

        if is_exist:
            print(
                f"All tables exist. No need to reset "
                f"for exp: {self.db_name_prefix}."
            )
            # change the is_done_flag from 0 to -1
            # rerun the unfinished tasks
            if self.reset_flag > 0:
                if self.reset_flag == 2:
                    sample = db_session.query(Sample).filter(
                        Sample.is_done_flag == 0)
                elif self.reset_flag == 1:
                    # Only samples taken by this machine.
                    sample = db_session.query(Sample).filter(
                        Sample.is_done_flag == 0,
                        Sample.computer_name == platform.node())
                else:
                    raise ValueError('Wrong reset flag')

                # Materialise first: the loop commits on every iteration.
                for s in sample.all():
                    # filter_by() takes keyword arguments; filter() does
                    # not and would raise TypeError here.
                    db_session.query(Result).filter_by(s_id=s.id).delete()
                    s.is_done_flag = -1
                    db_session.commit()
                    print(f"Reset the sample id {s.id} flag from 0 to -1")
        else:
            Base.metadata.create_all(bind=engine)
            self.init_tables()
            print(
                f"All tables are just created and initialized "
                f"for exp: {self.db_name_prefix}."
            )

    def prepare_list_sample(self):
        """Print the sample count and cache the ids of pending samples.

        The ids of all samples with ``is_done_flag = -1`` are appended to
        ``self.lst_saved_s_id``.
        """
        # Table names cannot be bound as parameters; the prefix is checked
        # against a fixed list in __init__.
        res = db_session.execute(
            text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
                 f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
                 )).scalar()
        n_sample = 0 if res is None else res
        print(f'There are a total of {n_sample} samples.')

        res = db_session.execute(
            text(f"SELECT id FROM {self.db_name_prefix}_sample "
                 f"WHERE is_done_flag = -1"
                 ))
        self.lst_saved_s_id.extend(row[0] for row in res)

    @staticmethod
    def select_random_sample(lst_s_id):
        """Pop random ids from ``lst_s_id`` until one is still pending.

        The list is modified in place: every id tried is removed from it.

        Returns:
            The ``Sample`` whose ``is_done_flag`` is still -1, or ``None``
            once the list is exhausted.
        """
        while lst_s_id:
            s_id = random.choice(lst_s_id)
            lst_s_id.remove(s_id)
            res = db_session.query(Sample).filter(Sample.id == int(s_id),
                                                  Sample.is_done_flag == -1)
            if res.count() == 1:
                return res[0]
        return None

    def fetch_a_sample(self, s_id=None):
        """Return the sample with id ``s_id``, or a random pending one.

        Returns:
            A ``Sample``, or ``None`` when the id does not exist or no
            cached id is pending any more.
        """
        if s_id is not None:
            return db_session.query(Sample).filter(
                Sample.id == int(s_id)).first()
        return self.select_random_sample(self.lst_saved_s_id)

    @staticmethod
    def lock_the_sample(sample: "Sample"):
        """Mark ``sample`` as taken by this machine and commit."""
        sample.is_done_flag = 0
        sample.computer_name = platform.node()
        db_session.commit()