This commit is contained in:
2023-03-13 19:47:25 +08:00
parent 2f162b970b
commit 09b59d8778
10 changed files with 113 additions and 210 deletions

View File

@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
from orm import db_session, engine, Base, ins
from orm import Experiment, Sample, Product, Firm
from orm import Experiment, Sample, Result
from sqlalchemy.exc import OperationalError
from sqlalchemy import text
import yaml
import random
import pandas as pd
import numpy as np
import platform
@@ -17,6 +18,7 @@ class ControllerDB:
lst_saved_s_id_3: list
lst_saved_s_id_1_2: list
# n_sample_1_2: int
def __init__(self, prefix, reset_flag=0):
@@ -24,88 +26,59 @@ class ControllerDB:
dct_conf_experiment = yaml.full_load(yaml_file)
self.is_test = prefix == 'test'
self.db_name_prefix = prefix
dct_para_in_test = dct_conf_experiment['test'] if self.is_test else dct_conf_experiment['not_test']
self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'],
'experiment': dct_conf_experiment['experiment'],
**dct_conf_experiment['fixed'], **dct_conf_experiment['default'], **dct_para_in_test}
dct_para_in_test = dct_conf_experiment[
'test'] if self.is_test else dct_conf_experiment['not_test']
self.dct_parameter = {
'meta_seed': dct_conf_experiment['meta_seed'],
**dct_para_in_test
}
print(self.dct_parameter)
self.reset_flag = reset_flag # 0, not reset; 1, reset self; 2, reset all
self.lst_saved_s_id_1_2, self.lst_saved_s_id_3 = [], []
self.lst_saved_s_id = []
def init_tables(self):
    """Populate the experiment table first, then the sample table.

    The order matters: ``fill_sample_table`` queries the stored
    ``Experiment`` rows and creates ``Sample`` rows that point at them.
    """
    self.fill_experiment_table()
    self.fill_sample_table()
@staticmethod
def get_lst_of_range(str_range: str):
    """Expand a ``"start,stop,step"`` string into evenly spaced values.

    Args:
        str_range: Three comma-separated numbers, e.g. ``"0,1,0.5"``.

    Returns:
        A list of NumPy floats from ``start`` to ``stop`` (both included),
        holding ``floor((stop - start) / step) + 1`` points as produced by
        ``np.linspace``.

    Raises:
        ValueError: If ``str_range`` does not hold exactly three numeric
            fields, or if the resulting point count is negative.
        ZeroDivisionError: If ``step`` is zero.
    """
    start, stop, step = (float(part) for part in str_range.split(','))
    ratio = (stop - start) / step
    # Binary floats make e.g. (0.3 - 0.0) / 0.1 evaluate to 2.9999999999999996;
    # a bare int() would truncate that to 2 and silently drop a grid point.
    # Snap to the nearest integer when the quotient is within rounding noise,
    # and keep plain truncation for ranges that are genuinely not divisible.
    nearest = round(ratio)
    n_step = nearest if abs(ratio - nearest) < 1e-9 else int(ratio)
    return list(np.linspace(start, stop, num=n_step + 1))
def fill_experiment_table(self):
    """Create one experiment per (firm code, flagged column) pair in the CSV.

    Reads ``Firm_amended.csv`` from the current working directory. For every
    row, each column from the one labelled ``'1'`` onward whose cell equals 1
    yields one experiment whose payload is ``{code: [column_label]}``.
    The column labels are presumably product codes - confirm against the
    CSV schema.

    Requires ``self.dct_parameter['n_max_trial']`` to be present.
    """
    # Use a distinct local name: ``Firm`` is also imported from ``orm`` at
    # the top of the file, and rebinding it here would shadow that import.
    df_firm = pd.read_csv("Firm_amended.csv")
    df_firm['Code'] = df_firm['Code'].astype('string')
    # Empty cells are read as NaN; treat them as "not flagged".
    df_firm.fillna(0, inplace=True)

    list_dct = []
    for _, row in df_firm.iterrows():
        code = row['Code']
        # Label-based slice: keep every column from '1' to the last one.
        flags = row['1':]
        for product_code in flags.index[flags == 1].to_list():
            list_dct.append({code: [product_code]})

    # The experiment index is simply the position in the collected list.
    for idx_exp, dct in enumerate(list_dct):
        self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'],
                              dct)
        print(f'Inserted experiment for exp {idx_exp}!')
def add_experiment_1(self, idx_exp, n_max_trial,
                     dct_list_init_remove_firm_prod):
    """Insert a single ``Experiment`` row and commit it.

    Args:
        idx_exp: Index stored on the experiment row.
        n_max_trial: Stored unchanged in the row's ``n_max_trial`` field.
        dct_list_init_remove_firm_prod: Mapping stored unchanged on the row;
            the caller in this file passes ``{firm_code: [column_label]}``.

    Returns:
        int: The number of experiments inserted, which is always 1.

    ``n_sample`` and ``n_iter`` are read from ``self.dct_parameter``.
    """
    e = Experiment(
        idx_exp=idx_exp,
        n_sample=int(self.dct_parameter['n_sample']),
        n_iter=int(self.dct_parameter['n_iter']),
        n_max_trial=n_max_trial,
        dct_list_init_remove_firm_prod=dct_list_init_remove_firm_prod)
    db_session.add(e)
    db_session.commit()
    # Exactly one row is written per call; there is no batch list to measure.
    return 1
def fill_sample_table(self):
rng = random.Random(self.dct_parameter['meta_seed'])
lst_seed = [rng.getrandbits(32) for _ in range(int(self.dct_parameter['n_sample']))]
lst_seed = [
rng.getrandbits(32)
for _ in range(int(self.dct_parameter['n_sample']))
]
lst_exp = db_session.query(Experiment).all()
lst_sample = []
for experiment in lst_exp:
for idx_sample in range(int(experiment.n_sample)):
s = Sample(e_id=experiment.id,
idx_sample=idx_sample+1,
idx_sample=idx_sample + 1,
seed=lst_seed[idx_sample],
is_done_flag=-1)
lst_sample.append(s)
@@ -115,8 +88,11 @@ class ControllerDB:
def reset_db(self, force_drop=False):
# first, check if tables exist
lst_table_obj = [Base.metadata.tables[str_table] for str_table in ins.get_table_names()
if str_table.startswith(self.db_name_prefix)]
lst_table_obj = [
Base.metadata.tables[str_table]
for str_table in ins.get_table_names()
if str_table.startswith(self.db_name_prefix)
]
is_exist = len(lst_table_obj) > 0
if force_drop:
while is_exist:
@@ -129,51 +105,58 @@ class ControllerDB:
pass
else:
lst_table_obj.remove(a_table)
print(f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!")
print(
f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!"
)
finally:
is_exist = len(lst_table_obj) > 0
if is_exist:
print(f"All tables exist. No need to reset for exp: {self.db_name_prefix}.")
print(
f"All tables exist. No need to reset for exp: {self.db_name_prefix}."
)
# Flip is_done_flag from 0 back to -1 so that unfinished samples are rerun.
if self.reset_flag > 0:
    if self.reset_flag == 2:
        # 2: reset every sample whose flag is 0, whichever machine owns it
        sample = db_session.query(Sample).filter(
            Sample.is_done_flag == 0)
    elif self.reset_flag == 1:
        # 1: reset only the samples recorded under this machine's name
        sample = db_session.query(Sample).filter(
            Sample.is_done_flag == 0,
            Sample.computer_name == platform.node())
    else:
        raise ValueError('Wrong reset flag')
    # Iterating an empty query is a no-op, so no count() pre-check is needed.
    for s in sample:
        # Drop the results already stored for this sample before rerunning.
        # Query.filter() accepts only positional SQL expressions, so
        # filter(s_id=...) raises TypeError; filter_by() is the keyword form.
        db_session.query(Result).filter_by(s_id=s.id).delete()
        db_session.commit()
        s.is_done_flag = -1
        db_session.commit()
        print(f"Reset the sample id {s.id} flag from 0 to -1")
else:
Base.metadata.create_all(bind=engine)
self.init_tables()
print(f"All tables are just created and initialized for exp: {self.db_name_prefix}.")
print(
f"All tables are just created and initialized for exp: {self.db_name_prefix}."
)
def prepare_list_sample(self):
    """Collect the ids of all samples still waiting to run.

    Prints the number of sample rows that join to an experiment row, then
    appends the id of every sample whose ``is_done_flag`` is -1 to
    ``self.lst_saved_s_id``. -1 is the flag value samples are created with
    and reset to elsewhere in this class.

    Table names are built from ``self.db_name_prefix``, which is set from the
    constructor argument rather than from external input.
    """
    res = db_session.execute(
        text(f'''SELECT count(*) FROM {self.db_name_prefix}_sample s,
{self.db_name_prefix}_experiment e WHERE s.e_id=e.id '''
             )).scalar()
    n_sample = 0 if res is None else res
    print(f'There are a total of {n_sample} samples.')

    res = db_session.execute(
        text(
            f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1'
        ))
    # All pending ids go into one list; experiments are now indexed by
    # position, so a split on "idx_exp < 3" no longer means anything.
    self.lst_saved_s_id.extend(row[0] for row in res)
@staticmethod
def select_random_sample(lst_s_id):
@@ -182,7 +165,8 @@ class ControllerDB:
return None
s_id = random.choice(lst_s_id)
lst_s_id.remove(s_id)
res = db_session.query(Sample).filter(Sample.id == int(s_id), Sample.is_done_flag == -1)
res = db_session.query(Sample).filter(Sample.id == int(s_id),
Sample.is_done_flag == -1)
if res.count() == 1:
return res[0]
@@ -194,11 +178,7 @@ class ControllerDB:
else:
return res[0]
sample = self.select_random_sample(self.lst_saved_s_id_1_2)
if sample is not None:
return sample
sample = self.select_random_sample(self.lst_saved_s_id_3)
sample = self.select_random_sample(self.lst_saved_s_id)
if sample is not None:
return sample
@@ -214,7 +194,9 @@ if __name__ == '__main__':
# pprint.pprint(dct_exp_config)
# pprint.pprint(dct_conf_problem)
db = ControllerDB('first')
ratio = db_session.execute('SELECT COUNT(*) / 332750 FROM first_sample s WHERE s.is_done_flag = 1').scalar()
ratio = db_session.execute(
'SELECT COUNT(*) / 332750 FROM first_sample s WHERE s.is_done_flag = 1'
).scalar()
print(ratio)
# db.fill_experiment_table()
# print(db.dct_parameter)