# -*- coding: utf-8 -*-
from orm import db_session, engine, Base, ins
from orm import Experiment, Sample, Result
from sqlalchemy.exc import OperationalError
from sqlalchemy import text
import yaml
import random
import pandas as pd
import platform
import networkx as nx
import json
import pickle
import sys


class ControllerDB:
    """Create, populate and reset the experiment / sample / result tables.

    The ``prefix`` given to the constructor selects which table group this
    controller works on (table names are matched and built from it) and which
    parameter set is read from ``conf_experiment.yaml``.
    """

    dct_parameter = None
    is_test: bool = None
    is_with_exp: bool = None
    db_name_prefix: str = None
    reset_flag: int
    lst_saved_s_id: list

    def __init__(self, prefix, reset_flag=0):
        """Load the configuration for ``prefix``.

        :param prefix: one of ``'test'``, ``'without_exp'``, ``'with_exp'``.
            Also interpolated as the table-name prefix into raw SQL in
            ``prepare_list_sample``, which is why it is whitelisted here.
        :param reset_flag: 0 = no reset; 1 = reset unfinished samples locked
            by this computer; 2 = reset all unfinished samples
            (see ``reset_db``).
        :raises AssertionError: if ``prefix`` is not one of the allowed values.
        """
        with open('conf_experiment.yaml') as yaml_file:
            dct_conf_experiment = yaml.full_load(yaml_file)

        assert prefix in ['test', 'without_exp', 'with_exp'], \
            "db name not in test, without_exp, with_exp"

        self.is_test = prefix == 'test'
        self.is_with_exp = prefix not in ('test', 'without_exp')
        self.db_name_prefix = prefix

        dct_para_in_test = dct_conf_experiment[
            'test' if self.is_test else 'not_test']
        self.dct_parameter = {
            'meta_seed': dct_conf_experiment['meta_seed'],
            **dct_para_in_test
        }
        print(self.dct_parameter)

        # 0, not reset; 1, reset self; 2, reset all
        self.reset_flag = reset_flag
        self.lst_saved_s_id = []

    def init_tables(self):
        """Fill the experiment table, then the sample table built from it."""
        self.fill_experiment_table()
        self.fill_sample_table()

    def fill_experiment_table(self):
        """Insert one Experiment per (scenario row, initial disruption) pair."""
        list_dct = self._build_init_disrupt_list()
        g_product_js = self._build_bom_graph_json()

        for idx_scenario, dct_exp_para in self._iter_scenario_params():
            # different initial removal
            for idx_init_removal, dct_init_removal in enumerate(list_dct):
                self.add_experiment_1(idx_scenario,
                                      idx_init_removal,
                                      dct_init_removal,
                                      g_product_js,
                                      **dct_exp_para)
                print(f"Inserted experiment for scenario {idx_scenario}, "
                      f"init_removal {idx_init_removal}!")

    def _build_init_disrupt_list(self):
        """Return the list of initial-disruption dicts used per experiment.

        In the without-experiment case each entry is ``{firm_code:
        [product_code]}``, one per cell equal to 1 in the product columns
        (from column ``'1'`` onward) of ``Firm_amended.csv``. In the
        with-experiment case the entries are unpickled from the result of
        ``SQL_export_high_risk_setting.sql``.
        """
        if self.is_with_exp:
            with open('SQL_export_high_risk_setting.sql', 'r') as f:
                str_sql = f.read()
            result = pd.read_sql(sql=str_sql, con=engine)
            # NOTE(review): pickle.loads can execute arbitrary code on a
            # crafted payload; acceptable only if the queried database is
            # trusted (presumably written by this project) - confirm.
            return result['dct_lst_init_disrupt_firm_prod'].apply(
                pickle.loads).to_list()

        Firm = pd.read_csv("input_data/Firm_amended.csv")
        Firm['Code'] = Firm['Code'].astype('string')
        Firm.fillna(0, inplace=True)
        list_dct = []
        for _, row in Firm.iterrows():
            code = row['Code']
            row_prod = row['1':]
            list_dct.extend(
                {code: [product_code]}
                for product_code in row_prod.index[row_prod == 1].to_list())
        return list_dct

    @staticmethod
    def _build_bom_graph_json():
        """Build the BOM MultiDiGraph with node attributes; return it as JSON.

        Edges come from the transposed ``BomCateNet.csv`` adjacency matrix
        (missing entries treated as 0); each node's attributes are its row in
        ``BomNodes.csv`` indexed by ``Code``.
        """
        BomNodes = pd.read_csv('input_data/BomNodes.csv', index_col=0)
        BomNodes.set_index('Code', inplace=True)
        BomCateNet = pd.read_csv('input_data/BomCateNet.csv', index_col=0)
        BomCateNet.fillna(0, inplace=True)
        g_bom = nx.from_pandas_adjacency(BomCateNet.T,
                                         create_using=nx.MultiDiGraph())
        bom_labels_dict = {code: BomNodes.loc[code].to_dict()
                           for code in g_bom.nodes}
        nx.set_node_attributes(g_bom, bom_labels_dict)
        return json.dumps(nx.adjacency_data(g_bom))

    def _iter_scenario_params(self):
        """Yield ``(idx_scenario, {parameter_name: value})`` per OA row.

        ``xv_*.csv`` has one column per parameter where row ``i`` holds the
        value of level ``i``; each ``oa_*.csv`` cell is a level index into the
        matching ``xv`` column. Extra OA columns beyond the ``xv`` width are
        ignored.
        """
        suffix = 'with_exp' if self.is_with_exp else 'without_exp'
        df_xv = pd.read_csv(f"input_data/xv_{suffix}.csv", index_col=None)
        # read the OA table (presumably an orthogonal array design)
        df_oa = pd.read_csv(f"input_data/oa_{suffix}.csv", index_col=None)
        df_oa = df_oa.iloc[:, 0:df_xv.shape[1]]
        for idx_scenario, row in df_oa.iterrows():
            yield idx_scenario, {
                df_xv.columns[idx_col]: df_xv.iloc[para_level, idx_col]
                for idx_col, para_level in enumerate(row)
            }

    def add_experiment_1(self, idx_scenario, idx_init_removal,
                         dct_lst_init_disrupt_firm_prod, g_bom,
                         n_max_trial, prf_size, prf_conn,
                         cap_limit_prob_type, cap_limit_level,
                         diff_new_conn, remove_t, netw_prf_n):
        """Insert and commit one Experiment row.

        ``n_sample`` and ``n_iter`` come from the loaded configuration; every
        other column is taken from the arguments.
        """
        e = Experiment(
            idx_scenario=idx_scenario,
            idx_init_removal=idx_init_removal,
            n_sample=int(self.dct_parameter['n_sample']),
            n_iter=int(self.dct_parameter['n_iter']),
            dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod,
            g_bom=g_bom,
            n_max_trial=n_max_trial,
            prf_size=prf_size,
            prf_conn=prf_conn,
            cap_limit_prob_type=cap_limit_prob_type,
            cap_limit_level=cap_limit_level,
            diff_new_conn=diff_new_conn,
            remove_t=remove_t,
            netw_prf_n=netw_prf_n
        )
        db_session.add(e)
        db_session.commit()

    def fill_sample_table(self):
        """Insert ``n_sample`` pending Sample rows for every Experiment.

        Seeds are drawn once from ``meta_seed``, so the sample with the same
        ``idx_sample`` (1-based) shares its seed across all experiments.
        New samples start with ``is_done_flag = -1``.
        """
        rng = random.Random(self.dct_parameter['meta_seed'])
        lst_seed = [
            rng.getrandbits(32)
            for _ in range(int(self.dct_parameter['n_sample']))
        ]
        lst_sample = [
            Sample(e_id=experiment.id,
                   idx_sample=idx_sample + 1,
                   seed=lst_seed[idx_sample],
                   is_done_flag=-1)
            for experiment in db_session.query(Experiment).all()
            for idx_sample in range(int(experiment.n_sample))
        ]
        db_session.bulk_save_objects(lst_sample)
        db_session.commit()
        print(f'Inserted {len(lst_sample)} samples!')

    def reset_db(self, force_drop=False):
        """Create and fill this prefix's tables, or reset unfinished samples.

        :param force_drop: if True, first drop every table whose name starts
            with ``db_name_prefix`` (which forces re-creation below).

        If at least one prefixed table exists (and was not dropped),
        unfinished samples are reset according to ``reset_flag``. Otherwise
        ``Base.metadata.create_all`` is run and the tables are initialized.
        :raises ValueError: if tables exist and ``reset_flag`` is > 0 but not
            1 or 2.
        """
        # first, check if tables exist
        lst_table_obj = [
            Base.metadata.tables[str_table]
            for str_table in ins.get_table_names()
            if str_table.startswith(self.db_name_prefix)
        ]
        if force_drop:
            # Drop in random order and retry failures; presumably a drop
            # fails while another table still depends on it - confirm.
            # NOTE(review): this never terminates if a table cannot be
            # dropped at all.
            while lst_table_obj:
                a_table = random.choice(lst_table_obj)
                try:
                    Base.metadata.drop_all(bind=engine, tables=[a_table])
                except (KeyError, OperationalError):
                    pass
                else:
                    lst_table_obj.remove(a_table)
                    print(
                        f"Table {a_table.name} is dropped "
                        f"for exp: {self.db_name_prefix}!!!"
                    )
        is_exist = len(lst_table_obj) > 0

        if is_exist:
            print(
                f"All tables exist. No need to reset "
                f"for exp: {self.db_name_prefix}."
            )
            # change the is_done_flag from 0 to -1
            # rerun the in-finished tasks
            if self.reset_flag > 0:
                self._reset_unfinished_samples()
        else:
            Base.metadata.create_all(bind=engine)
            self.init_tables()
            print(
                f"All tables are just created and initialized "
                f"for exp: {self.db_name_prefix}."
            )

    def _reset_unfinished_samples(self):
        """Delete partial results and set ``is_done_flag`` 0 -> -1.

        ``reset_flag == 2`` resets every sample with flag 0;
        ``reset_flag == 1`` only those whose ``computer_name`` is this host.
        :raises ValueError: for any other ``reset_flag``.
        """
        if self.reset_flag == 2:
            query = db_session.query(Sample).filter(
                Sample.is_done_flag == 0)
        elif self.reset_flag == 1:
            query = db_session.query(Sample).filter(
                Sample.is_done_flag == 0,
                Sample.computer_name == platform.node())
        else:
            raise ValueError('Wrong reset flag')

        # Materialize first: the loop commits and changes the very column
        # the query filters on.
        for s in query.all():
            # Query.filter() accepts only SQL expressions; filter_by() is
            # the keyword form (filter(s_id=...) raised TypeError).
            db_session.query(Result).filter_by(s_id=s.id).delete()
            s.is_done_flag = -1
            db_session.commit()
            print(f"Reset the sample id {s.id} flag from 0 to -1")

    def prepare_list_sample(self):
        """Print the total sample count and queue ids of pending samples.

        Ids with ``is_done_flag = -1`` are appended to ``lst_saved_s_id``;
        the list is not cleared first, so repeated calls queue ids again.
        """
        # db_name_prefix is whitelisted in __init__, so interpolating it
        # into the table names is not user-controlled SQL.
        res = db_session.execute(
            text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
                 f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
                 )).scalar()
        n_sample = 0 if res is None else res
        print(f'There are a total of {n_sample} samples.')
        res = db_session.execute(
            text(f"SELECT id FROM {self.db_name_prefix}_sample "
                 f"WHERE is_done_flag = -1"
                 ))
        self.lst_saved_s_id.extend(row[0] for row in res)

    @staticmethod
    def select_random_sample(lst_s_id):
        """Pop random ids from ``lst_s_id`` until one is still pending.

        Mutates ``lst_s_id`` (each tried id is removed). Returns the first
        Sample found with ``is_done_flag == -1``, or None once the list is
        exhausted.
        """
        while lst_s_id:
            s_id = random.choice(lst_s_id)
            lst_s_id.remove(s_id)
            # id is presumably the primary key, so at most one row matches.
            sample = db_session.query(Sample).filter(
                Sample.id == int(s_id),
                Sample.is_done_flag == -1).first()
            if sample is not None:
                return sample
        return None

    def fetch_a_sample(self, s_id=None):
        """Return the Sample with id ``s_id``, or a random pending one.

        Returns None if ``s_id`` matches nothing, or if no queued id is still
        pending.
        """
        if s_id is not None:
            return db_session.query(Sample).filter(
                Sample.id == int(s_id)).first()
        return self.select_random_sample(self.lst_saved_s_id)

    @staticmethod
    def lock_the_sample(sample: Sample):
        """Mark ``sample`` as in progress (flag 0) on this host and commit."""
        sample.is_done_flag, sample.computer_name = 0, platform.node()
        db_session.commit()


if __name__ == '__main__':
    print("Testing the database connection...")
    try:
        controller_db = ControllerDB('test')
        Base.metadata.create_all(bind=engine)
    except Exception as e:
        print("Failed to connect to the database!")
        print(e)
        # sys.exit, not the site-injected exit(), which may be missing.
        sys.exit(1)