# -*- coding: utf-8 -*-
from orm import db_session, engine, Base, ins
from orm import Experiment, Sample, Product, Firm
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
import yaml
import random
import numpy as np
import platform


class ControllerDB:
    """Create, reset and hand out experiment/sample rows stored via the ``orm`` module.

    Parameters are read from ``conf_experiment.yaml`` and flattened into
    ``dct_parameter``. Pending sample ids are cached in ``lst_saved_s_id_1_2`` /
    ``lst_saved_s_id_3`` by ``prepare_list_sample``.
    """
    dct_parameter = None
    is_test: bool = None
    db_name_prefix: str = None
    reset_flag: int

    lst_saved_s_id_3: list
    lst_saved_s_id_1_2: list
    # n_sample_1_2: int

    def __init__(self, prefix, reset_flag=0):
        """Load the experiment configuration.

        :param prefix: table-name prefix; the value ``'test'`` selects the ``test``
            section of the config, anything else selects ``not_test``.
        :param reset_flag: 0, not reset; 1, reset samples locked by this computer;
            2, reset all unfinished samples (see ``reset_db``).
        """
        with open('conf_experiment.yaml') as yaml_file:
            dct_conf_experiment = yaml.full_load(yaml_file)
        self.is_test = prefix == 'test'
        self.db_name_prefix = prefix
        dct_para_in_test = dct_conf_experiment['test'] if self.is_test else dct_conf_experiment['not_test']
        # later sections override earlier ones on duplicate keys
        self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'],
                              'experiment': dct_conf_experiment['experiment'],
                              **dct_conf_experiment['fixed'], **dct_conf_experiment['default'],
                              **dct_para_in_test}
        self.reset_flag = reset_flag  # 0, not reset; 1, reset self; 2, reset all
        self.lst_saved_s_id_1_2, self.lst_saved_s_id_3 = [], []

    def init_tables(self):
        """Populate the experiment table, then the sample table."""
        self.fill_experiment_table()
        self.fill_sample_table()

    @staticmethod
    def get_lst_of_range(str_range: str) -> list:
        """Expand a ``"start,stop,step"`` string into an inclusive list of values.

        The point count is ``(stop - start) / step + 1`` and the values come from
        ``np.linspace(start, stop, num)``, so both end points are included.
        """
        s1, s2, s3 = (float(s) for s in str_range.split(','))
        ratio = (s2 - s1) / s3
        # Float noise can leave an exact step count just below an integer
        # (e.g. (0.3 - 0.1) / 0.1 < 2), which plain int() would truncate by one.
        # Snap near-integers; misaligned ranges still truncate as before.
        n_step = round(ratio) if abs(ratio - round(ratio)) < 1e-9 else int(ratio)
        return list(np.linspace(s1, s2, num=n_step + 1))

    def fill_experiment_table(self):
        """Insert Experiment rows for every active scenario of ``dct_parameter['experiment']``.

        Only scenario 3 is currently active: one experiment per combination of
        beta_developed x tariff_percentage_1 x tariff_percentage_2, with lambda fixed at 0.5.
        """
        # prepare the list of lambda tier
        # NOTE(review): only used by the disabled S1/S2 branches below, but still
        # requires scenario 1 to be present in the config.
        lst_lambda = self.get_lst_of_range(self.dct_parameter['experiment'][1]['range_lambda_tier'])
        # prepare the list of beta_developed
        lst_beta_developed = self.get_lst_of_range(self.dct_parameter['experiment'][3]['range_flt_beta_developed'])
        # prepare the list of tariff_percentage
        lst_tariff = self.get_lst_of_range(self.dct_parameter['experiment'][3]['range_tariff_percentage'])
        # prepare the default values (used by the disabled S1/S2 branches)
        is_eliminated = int(self.dct_parameter['is_eliminated'])
        beta_developed = float(self.dct_parameter['flt_beta_developed'])
        tariff_percentage_1 = tariff_percentage_2 = float(self.dct_parameter['tariff_percentage'])

        for idx_scenario in self.dct_parameter['experiment'].keys():
            n_exp = 0
            # if idx_scenario == 1:
            #     # add S1 experiments
            #     n_exp = self.add_experiment_1(idx_scenario, lst_lambda, is_eliminated,
            #                                   beta_developed, tariff_percentage_1, tariff_percentage_2)
            # if idx_scenario == 2:
            #     # add S2 experiments
            #     n_exp = self.add_experiment_1(idx_scenario, lst_lambda,
            #                                   int(self.dct_parameter['experiment'][idx_scenario]['is_eliminated']),
            #                                   beta_developed, tariff_percentage_1, tariff_percentage_2)
            if idx_scenario == 3:
                # int_eliminated = int(self.dct_parameter['experiment'][idx_scenario-1]['is_eliminated'])
                int_eliminated = is_eliminated  # is_eliminated is 0 at default, so stop eliminating under S3
                for beta_developed in lst_beta_developed:
                    # for beta_developed in [0.5]:  # fix beta as 0.5
                    for tariff_percentage_1 in lst_tariff:
                        for tariff_percentage_2 in lst_tariff:
                            # fix lambda as 0.5
                            n_exp += self.add_experiment_1(idx_scenario, [0.5], int_eliminated,
                                                           beta_developed, tariff_percentage_1,
                                                           tariff_percentage_2)
            print(f'Inserted {n_exp} experiments for exp {idx_scenario}!')

    def add_experiment_1(self, idx_exp, lst_lambda, is_eliminated, flt_beta_developed,
                         tariff_percentage_1: float, tariff_percentage_2: float):
        """Create one Experiment per value in ``lst_lambda``, save and commit them.

        All other fields are copied from ``dct_parameter``.

        :return: number of experiments inserted.
        """
        lst_exp = []
        for lambda_tier in lst_lambda:
            e = Experiment(idx_exp=idx_exp,
                           int_n_country=int(self.dct_parameter['int_n_country']),
                           max_int_n_supplier=int(self.dct_parameter['max_int_n_supplier']),
                           int_n_product=int(self.dct_parameter['int_n_product']),
                           int_n_firm_per_product_per_country=int(
                               self.dct_parameter['int_n_firm_per_product_per_country']),
                           flt_demand_total=float(self.dct_parameter['flt_demand_total']),
                           flt_bm_price_ratio=float(self.dct_parameter['flt_bm_price_ratio']),
                           flt_beta_developing=float(self.dct_parameter['flt_beta_developing']),
                           n_sample=int(self.dct_parameter['n_sample']),
                           n_iter=int(self.dct_parameter['n_iter']),
                           is_eliminated=is_eliminated,
                           flt_beta_developed=flt_beta_developed,
                           tariff_percentage_1=tariff_percentage_1,
                           tariff_percentage_2=tariff_percentage_2,
                           lambda_tier=float(lambda_tier))
            lst_exp.append(e)
        db_session.bulk_save_objects(lst_exp)
        db_session.commit()
        return len(lst_exp)

    def fill_sample_table(self):
        """Insert ``experiment.n_sample`` pending Sample rows (is_done_flag=-1) per Experiment.

        Seeds are drawn from a ``random.Random`` seeded with ``meta_seed``. The k-th
        sample of every experiment therefore gets the same seed.
        """
        rng = random.Random(self.dct_parameter['meta_seed'])
        lst_seed = [rng.getrandbits(32) for _ in range(int(self.dct_parameter['n_sample']))]
        lst_exp = db_session.query(Experiment).all()
        lst_sample = []
        for experiment in lst_exp:
            for idx_sample in range(int(experiment.n_sample)):
                s = Sample(e_id=experiment.id, idx_sample=idx_sample + 1,
                           seed=lst_seed[idx_sample], is_done_flag=-1)
                lst_sample.append(s)
        db_session.bulk_save_objects(lst_sample)
        db_session.commit()
        print(f'Inserted {len(lst_sample)} samples!')

    def reset_db(self, force_drop=False):
        """Create and initialise the tables for this prefix, or reset unfinished samples.

        :param force_drop: drop every existing table whose name starts with
            ``db_name_prefix`` first, then recreate and initialise all tables.

        If the tables exist (and were not dropped) and ``reset_flag > 0``, samples
        with is_done_flag == 0 are put back to -1; see ``_reset_unfinished_samples``.
        """
        # first, check if tables exist
        lst_table_obj = [Base.metadata.tables[str_table] for str_table in ins.get_table_names()
                         if str_table.startswith(self.db_name_prefix)]
        is_exist = len(lst_table_obj) > 0
        if force_drop:
            # Drop tables one at a time in random order; a drop that fails is
            # simply retried on a later iteration.
            # NOTE(review): this never terminates if some table can never be dropped.
            while is_exist:
                a_table = random.choice(lst_table_obj)
                try:
                    Base.metadata.drop_all(bind=engine, tables=[a_table])
                except (KeyError, OperationalError):
                    pass
                else:
                    lst_table_obj.remove(a_table)
                    print(f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!")
                finally:
                    is_exist = len(lst_table_obj) > 0

        if is_exist:
            print(f"All tables exist. No need to reset for exp: {self.db_name_prefix}.")
            # change the is_done_flag from 0 to -1, to rerun the unfinished tasks
            if self.reset_flag > 0:
                self._reset_unfinished_samples()
        else:
            Base.metadata.create_all(bind=engine)
            self.init_tables()
            print(f"All tables are just created and initialized for exp: {self.db_name_prefix}.")

    def _reset_unfinished_samples(self):
        """Put in-progress samples (is_done_flag == 0) back to pending (-1).

        The Product rows linked to each sample, and the Firm rows linked to those
        products, are deleted first. With reset_flag == 2 every in-progress sample is
        reset; with reset_flag == 1 only those whose computer_name is this host.

        :raises ValueError: if reset_flag is neither 1 nor 2.
        """
        if self.reset_flag == 2:
            result = db_session.query(Sample).filter(Sample.is_done_flag == 0)
        elif self.reset_flag == 1:
            result = db_session.query(Sample).filter(Sample.is_done_flag == 0,
                                                     Sample.computer_name == platform.node())
        else:
            raise ValueError('Wrong reset flag')
        for res in result:
            for p in db_session.query(Product).filter_by(s_id=res.id):
                db_session.query(Firm).filter(Firm.p_id == p.id).delete()
                db_session.commit()
                db_session.query(Product).filter(Product.id == p.id).delete()
                db_session.commit()
            res.is_done_flag = -1
            db_session.commit()
            print(f"Reset the task id {res.id} flag from 0 to -1")

    def prepare_list_sample(self):
        """Cache the ids of pending samples (is_done_flag == -1), split by scenario group.

        Ids up to the number of samples belonging to experiments with idx_exp < 3 go to
        ``lst_saved_s_id_1_2``; the rest go to ``lst_saved_s_id_3``.
        NOTE(review): this assumes exp 1/2 sample ids are exactly 1..n_sample_1_2,
        i.e. they were inserted first with contiguous ids. Confirm before changing insert order.
        """
        res = db_session.execute(text(
            f'''SELECT count(*) FROM {self.db_name_prefix}_sample s,
            {self.db_name_prefix}_experiment e WHERE s.e_id=e.id and e.idx_exp < 3''')).scalar()
        n_sample_1_2 = 0 if res is None else res
        print(f'There are {n_sample_1_2} sample for exp 1 and 2.')
        res = db_session.execute(text(f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1'))
        for row in res:
            s_id = row[0]
            if s_id <= n_sample_1_2:
                self.lst_saved_s_id_1_2.append(s_id)
            else:
                self.lst_saved_s_id_3.append(s_id)
        print(f'Left: {len(self.lst_saved_s_id_1_2)} for exp 1 and 2; {len(self.lst_saved_s_id_3)} for exp 3')

    @staticmethod
    def select_random_sample(lst_s_id):
        """Remove random ids from ``lst_s_id`` until one refers to a pending sample.

        Every picked id is removed from the list (mutated in place), whether or not it
        is still pending.

        :return: the Sample with is_done_flag == -1, or None once the list is empty.
        """
        while lst_s_id:
            s_id = random.choice(lst_s_id)
            lst_s_id.remove(s_id)
            # filtering on Sample.id matches at most one row (presumably the primary key)
            sample = db_session.query(Sample).filter(Sample.id == int(s_id),
                                                     Sample.is_done_flag == -1).first()
            if sample is not None:
                return sample
        return None

    def fetch_a_sample(self, s_id=None):
        """Return the Sample with id ``s_id``, or a random pending one if ``s_id`` is None.

        Pending samples of exp 1/2 are preferred over those of exp 3. Returns None
        when nothing matches.
        """
        if s_id is not None:
            return db_session.query(Sample).filter(Sample.id == int(s_id)).first()

        sample = self.select_random_sample(self.lst_saved_s_id_1_2)
        if sample is not None:
            return sample
        return self.select_random_sample(self.lst_saved_s_id_3)

    @staticmethod
    def lock_the_sample(sample: Sample):
        """Mark ``sample`` as in progress (is_done_flag = 0) by this host and commit."""
        sample.is_done_flag, sample.computer_name = 0, platform.node()
        db_session.commit()


if __name__ == '__main__':
    # pprint.pprint(dct_exp_config)
    # pprint.pprint(dct_conf_problem)
    db = ControllerDB('first')
    # NOTE(review): 332750 looks like a hard-coded total sample count; confirm.
    # Depending on the backend, COUNT(*) / <int> may be integer division.
    ratio = db_session.execute(text('SELECT COUNT(*) / 332750 FROM first_sample s WHERE s.is_done_flag = 1')).scalar()
    print(ratio)
    # db.fill_experiment_table()
    # print(db.dct_parameter)
    # db.init_tables()
    # db.fill_sample_table()
    # pprint.pprint(dct_conf_exp)
    # db.update_bi()
    # db.reset_db(force_drop=True)
    # db.prepare_list_sample()
    #
    # for i in range(1000):
    #     if i % 10 == 0:
    #         print(i)
    #         print(len(db.lst_saved_s_id_1_2), len(db.lst_saved_s_id_3))
    #     r = db.fetch_a_sample()
    #     if i % 10 == 0:
    #         print(len(db.lst_saved_s_id_1_2), len(db.lst_saved_s_id_3))
    #         print(r, r.experiment.idx_exp)
    #     if i == 400:
    #         print()
    #         pass