238 lines
11 KiB
Python
238 lines
11 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
from orm import db_session, engine, Base, ins
|
||
|
from orm import Experiment, Sample, Product, Firm
|
||
|
from sqlalchemy.exc import OperationalError
|
||
|
import yaml
|
||
|
import random
|
||
|
import numpy as np
|
||
|
import platform
|
||
|
|
||
|
|
||
|
class ControllerDB:
    """Build and serve the simulation task tables of one experiment database.

    Tables handled here share the name prefix ``db_name_prefix`` (e.g.
    ``<prefix>_experiment`` and ``<prefix>_sample``). Sample rows carry an
    ``is_done_flag``: -1 when created or reset (waiting to be run), 0 once a
    machine has locked it via ``lock_the_sample``. ``reset_db`` treats 0 as an
    unfinished run and puts it back to -1.
    """

    # merged parameters from conf_experiment.yaml (see __init__)
    dct_parameter = None
    # True when the prefix is 'test'; selects the 'test' parameter overrides
    is_test: bool = None
    # table-name prefix of this experiment
    db_name_prefix: str = None
    # 0: do not reset; 1: reset this machine's unfinished samples; 2: reset all unfinished samples
    reset_flag: int

    # pending sample ids that are not part of exp 1/2 (drawn after the 1/2 queue)
    lst_saved_s_id_3: list
    # pending sample ids of experiments with idx_exp < 3 (drawn first)
    lst_saved_s_id_1_2: list

    def __init__(self, prefix, reset_flag=0):
        """Load ``conf_experiment.yaml`` from the current working directory.

        :param prefix: table-name prefix of the experiment; ``'test'`` selects the
            ``test`` parameter overrides, any other value selects ``not_test``.
        :param reset_flag: see the ``reset_flag`` class attribute.
        """
        with open('conf_experiment.yaml', encoding='utf-8') as yaml_file:
            dct_conf_experiment = yaml.full_load(yaml_file)

        self.is_test = prefix == 'test'
        self.db_name_prefix = prefix
        dct_para_in_test = dct_conf_experiment['test' if self.is_test else 'not_test']
        # Later entries win: 'fixed' < 'default' < test/not_test overrides.
        self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'],
                              'experiment': dct_conf_experiment['experiment'],
                              **dct_conf_experiment['fixed'],
                              **dct_conf_experiment['default'],
                              **dct_para_in_test}
        self.reset_flag = reset_flag  # 0, not reset; 1, reset self; 2, reset all
        self.lst_saved_s_id_1_2, self.lst_saved_s_id_3 = [], []

    def init_tables(self):
        """Fill the (empty) experiment table, then one sample row set per experiment."""
        self.fill_experiment_table()
        self.fill_sample_table()

    @staticmethod
    def get_lst_of_range(str_range: str):
        """Expand a ``"start,stop,step"`` string into an inclusive, evenly spaced list.

        ``"0,1,0.25"`` -> ``[0.0, 0.25, 0.5, 0.75, 1.0]``. The number of points is
        ``floor((stop - start) / step) + 1``; ``np.linspace`` spreads them evenly
        from ``start`` to ``stop`` (both included).

        :raises ValueError: if the string does not hold exactly three numbers.
        """
        s1, s2, s3 = (float(s) for s in str_range.split(','))
        flt_n_interval = (s2 - s1) / s3
        int_nearest = round(flt_n_interval)
        # Floating-point division can land just below a whole number
        # (0.3 / 0.1 == 2.9999999999999996); snap such values instead of truncating
        # them, which would silently drop the last grid point.
        if abs(flt_n_interval - int_nearest) < 1e-9:
            n_interval = int_nearest
        else:
            n_interval = int(flt_n_interval)
        return list(np.linspace(s1, s2, num=n_interval + 1))

    def fill_experiment_table(self):
        """Insert the Experiment rows of every configured scenario.

        Only scenario 3 currently produces rows: it sweeps ``flt_beta_developed``
        and both tariff percentages over the ranges configured under
        ``experiment[3]``, with ``lambda_tier`` fixed at 0.5 and ``is_eliminated``
        taken from the merged parameters. Every scenario key still gets a
        progress line (0 for the disabled ones).
        """
        dct_experiment = self.dct_parameter['experiment']
        for idx_scenario in dct_experiment.keys():
            n_exp = 0
            # NOTE: scenarios 1 and 2 (sweeps of experiment[1]['range_lambda_tier'])
            # are disabled; only scenario 3 inserts experiments.
            if idx_scenario == 3:
                dct_scenario = dct_experiment[idx_scenario]
                lst_beta_developed = self.get_lst_of_range(dct_scenario['range_flt_beta_developed'])
                lst_tariff = self.get_lst_of_range(dct_scenario['range_tariff_percentage'])
                # the default is_eliminated (original note: 0, i.e. no elimination under S3)
                int_eliminated = int(self.dct_parameter['is_eliminated'])
                for beta_developed in lst_beta_developed:
                    for tariff_percentage_1 in lst_tariff:
                        for tariff_percentage_2 in lst_tariff:
                            # lambda_tier is fixed at 0.5 for S3
                            n_exp += self.add_experiment_1(idx_scenario, [0.5], int_eliminated,
                                                           beta_developed, tariff_percentage_1,
                                                           tariff_percentage_2)
            print(f'Inserted {n_exp} experiments for exp {idx_scenario}!')

    def add_experiment_1(self, idx_exp, lst_lambda, is_eliminated, flt_beta_developed,
                         tariff_percentage_1: float, tariff_percentage_2: float):
        """Bulk-insert one Experiment row per value in ``lst_lambda`` and commit.

        All rows share every other column: model sizes and constants come from
        ``dct_parameter``, the scenario-specific values from the arguments.

        :return: number of rows inserted.
        """
        dct_para = self.dct_parameter
        # columns identical for every lambda_tier value, converted once
        dct_common = dict(
            idx_exp=idx_exp,
            int_n_country=int(dct_para['int_n_country']),
            max_int_n_supplier=int(dct_para['max_int_n_supplier']),
            int_n_product=int(dct_para['int_n_product']),
            int_n_firm_per_product_per_country=int(dct_para['int_n_firm_per_product_per_country']),
            flt_demand_total=float(dct_para['flt_demand_total']),
            flt_bm_price_ratio=float(dct_para['flt_bm_price_ratio']),
            flt_beta_developing=float(dct_para['flt_beta_developing']),
            n_sample=int(dct_para['n_sample']),
            n_iter=int(dct_para['n_iter']),
            is_eliminated=is_eliminated,
            flt_beta_developed=flt_beta_developed,
            tariff_percentage_1=tariff_percentage_1,
            tariff_percentage_2=tariff_percentage_2,
        )
        lst_exp = [Experiment(**dct_common, lambda_tier=float(lambda_tier)) for lambda_tier in lst_lambda]
        db_session.bulk_save_objects(lst_exp)
        db_session.commit()
        return len(lst_exp)

    def fill_sample_table(self):
        """Insert ``n_sample`` Sample rows (``is_done_flag = -1``) for every Experiment.

        Seeds are drawn from ``random.Random(meta_seed)``; sample k of every
        experiment receives the same seed.
        """
        lst_exp = db_session.query(Experiment).all()
        # Draw enough seeds for the largest n_sample stored in the table (normally the
        # configured one). The leading seeds are identical either way because they come
        # from the same seeded stream; this only prevents an IndexError.
        n_seed = max([int(self.dct_parameter['n_sample'])] + [int(e.n_sample) for e in lst_exp])
        rng = random.Random(self.dct_parameter['meta_seed'])
        lst_seed = [rng.getrandbits(32) for _ in range(n_seed)]

        lst_sample = []
        for experiment in lst_exp:
            for idx_sample in range(int(experiment.n_sample)):
                lst_sample.append(Sample(e_id=experiment.id,
                                         idx_sample=idx_sample + 1,
                                         seed=lst_seed[idx_sample],
                                         is_done_flag=-1))
        db_session.bulk_save_objects(lst_sample)
        db_session.commit()
        print(f'Inserted {len(lst_sample)} samples!')

    def reset_db(self, force_drop=False):
        """Create, reuse, or rebuild this experiment's tables.

        * ``force_drop=True``: drop every existing table whose name starts with the
          prefix and which the ORM metadata knows, then recreate and refill them.
        * Tables exist: keep them. If ``reset_flag`` > 0, clear the results of
          unfinished samples and set them back to -1 (see ``_reset_unfinished_samples``).
        * No tables: create and fill them.

        :raises ValueError: if ``reset_flag`` is positive but neither 1 nor 2.
        """
        # Skip DB tables the ORM metadata does not define: indexing
        # Base.metadata.tables with them would raise KeyError.
        lst_table_obj = [Base.metadata.tables[str_table] for str_table in ins.get_table_names()
                         if str_table.startswith(self.db_name_prefix)
                         and str_table in Base.metadata.tables]
        is_exist = len(lst_table_obj) > 0

        if force_drop:
            # Drop one random table at a time and retry on failure, so tables that
            # others depend on presumably succeed once their dependents are gone.
            # NOTE(review): a table that can never be dropped makes this loop forever.
            while lst_table_obj:
                a_table = random.choice(lst_table_obj)
                try:
                    Base.metadata.drop_all(bind=engine, tables=[a_table])
                except (KeyError, OperationalError):
                    continue
                lst_table_obj.remove(a_table)
                print(f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!")
            is_exist = False

        if is_exist:
            print(f"All tables exist. No need to reset for exp: {self.db_name_prefix}.")
            if self.reset_flag > 0:
                self._reset_unfinished_samples()
        else:
            # bind explicitly, like drop_all above, instead of relying on a bound MetaData
            Base.metadata.create_all(bind=engine)
            self.init_tables()
            print(f"All tables are just created and initialized for exp: {self.db_name_prefix}.")

    def _reset_unfinished_samples(self):
        """Delete Product/Firm rows of samples at ``is_done_flag == 0`` and set them to -1.

        ``reset_flag == 2`` resets all such samples; ``reset_flag == 1`` resets only
        those whose ``computer_name`` equals this machine's ``platform.node()``.

        :raises ValueError: for any other ``reset_flag``.
        """
        if self.reset_flag == 2:
            qry_sample = db_session.query(Sample).filter(Sample.is_done_flag == 0)
        elif self.reset_flag == 1:
            qry_sample = db_session.query(Sample).filter(Sample.is_done_flag == 0,
                                                         Sample.computer_name == platform.node())
        else:
            raise ValueError('Wrong reset flag')

        # materialize the list first; the loop below commits
        for sample in qry_sample.all():
            for product in db_session.query(Product).filter_by(s_id=sample.id).all():
                # remove the product's Firm rows (linked via p_id) before the Product row
                db_session.query(Firm).filter(Firm.p_id == product.id).delete()
                db_session.query(Product).filter(Product.id == product.id).delete()
            sample.is_done_flag = -1
            # one commit per sample: its cleanup and flag reset land together
            db_session.commit()
            print(f"Reset the task id {sample.id} flag from 0 to -1")

    def prepare_list_sample(self):
        """Queue the ids of all pending samples (``is_done_flag == -1``).

        Samples whose experiment has ``idx_exp < 3`` are appended to
        ``lst_saved_s_id_1_2``; all others go to ``lst_saved_s_id_3``. The
        classification uses the experiment join rather than assuming that
        exp 1/2 samples occupy ids 1..n. The lists are appended to, not cleared.
        """
        str_sample = f'{self.db_name_prefix}_sample'
        str_experiment = f'{self.db_name_prefix}_experiment'

        n_sample_1_2 = db_session.execute(
            f'SELECT count(*) FROM {str_sample} s, {str_experiment} e '
            f'WHERE s.e_id=e.id and e.idx_exp < 3').scalar() or 0
        print(f'There are {n_sample_1_2} sample for exp 1 and 2.')

        res = db_session.execute(
            f'SELECT s.id, e.idx_exp FROM {str_sample} s '
            f'LEFT JOIN {str_experiment} e ON s.e_id = e.id '
            f'WHERE s.is_done_flag = -1')
        for s_id, idx_exp in res:
            if idx_exp is not None and idx_exp < 3:
                self.lst_saved_s_id_1_2.append(s_id)
            else:
                self.lst_saved_s_id_3.append(s_id)
        print(f'Left: {len(self.lst_saved_s_id_1_2)} for exp 1 and 2; {len(self.lst_saved_s_id_3)} for exp 3')

    @staticmethod
    def select_random_sample(lst_s_id):
        """Draw random ids from ``lst_s_id`` until one is still pending.

        Every drawn id is removed from ``lst_s_id``, whether or not its sample is
        still at ``is_done_flag == -1``. A sample may already be at 0, e.g. after
        ``lock_the_sample``, possibly from another machine.

        :return: the first pending Sample found, or None once the list is exhausted.
        """
        while lst_s_id:
            # swap the random pick to the end and pop it: O(1) instead of list.remove's O(n)
            idx = random.randrange(len(lst_s_id))
            lst_s_id[idx], lst_s_id[-1] = lst_s_id[-1], lst_s_id[idx]
            s_id = lst_s_id.pop()
            sample = db_session.query(Sample).filter(Sample.id == int(s_id),
                                                     Sample.is_done_flag == -1).first()
            if sample is not None:
                return sample
        return None

    def fetch_a_sample(self, s_id=None):
        """Return the Sample with id ``s_id``, or else a random pending one.

        Without ``s_id``, the exp 1/2 queue is drained before the exp 3 queue.

        :return: a Sample, or None if nothing matches / both queues are empty.
        """
        if s_id is not None:
            return db_session.query(Sample).filter(Sample.id == int(s_id)).first()

        for lst_s_id in (self.lst_saved_s_id_1_2, self.lst_saved_s_id_3):
            sample = self.select_random_sample(lst_s_id)
            if sample is not None:
                return sample
        return None

    @staticmethod
    def lock_the_sample(sample: 'Sample'):
        """Mark ``sample`` as taken (``is_done_flag = 0``) by this machine and commit."""
        sample.is_done_flag, sample.computer_name = 0, platform.node()
        db_session.commit()
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Print the fraction of the 'first' experiment's samples whose is_done_flag is 1.
    db = ControllerDB('first')
    str_sample_table = f'{db.db_name_prefix}_sample'
    n_done = db_session.execute(
        f'SELECT COUNT(*) FROM {str_sample_table} s WHERE s.is_done_flag = 1').scalar()
    # The total used to be hard-coded (332750) and divided inside SQL. Count it instead,
    # so the ratio follows the actual table size, and divide in Python so back-ends that
    # do integer division (e.g. SQLite) do not truncate the ratio to 0.
    n_total = db_session.execute(f'SELECT COUNT(*) FROM {str_sample_table}').scalar()
    ratio = n_done / n_total if n_total else 0.0
    print(ratio)
|