# IIabm/controller_db.py (224 lines, 7.9 KiB, Python)
# -*- coding: utf-8 -*-
from orm import db_session, engine, Base, ins
from orm import Experiment, Sample, Result
from sqlalchemy.exc import OperationalError
from sqlalchemy import text
import yaml
import random
import pandas as pd
import numpy as np
import platform
class ControllerDB:
    """Create, populate and hand out work from one run's experiment tables.

    Run parameters come from ``conf_experiment.yaml``. The instance can
    (re)create and fill the ``Experiment`` and ``Sample`` tables. It can also
    reset unfinished samples, and it serves pending samples
    (``is_done_flag == -1``) to worker processes.
    """

    dct_parameter = None
    is_test: bool = None
    db_name_prefix: str = None
    reset_flag: int
    lst_saved_s_id: list

    def __init__(self, prefix, reset_flag=0):
        """Load the run configuration.

        :param prefix: run name. ``'test'`` selects the ``test`` section of
            ``conf_experiment.yaml``; any other value selects ``not_test``.
            The prefix is also used to match table names in ``reset_db`` and
            to build table names in the raw SQL of ``prepare_list_sample``.
        :param reset_flag: 0 = do not reset; 1 = reset unfinished samples
            locked by this computer; 2 = reset all unfinished samples.
        """
        with open('conf_experiment.yaml') as yaml_file:
            dct_conf_experiment = yaml.full_load(yaml_file)
        self.is_test = prefix == 'test'
        self.db_name_prefix = prefix
        section = 'test' if self.is_test else 'not_test'
        # meta_seed is shared; the section's keys are merged on top of it
        self.dct_parameter = {
            'meta_seed': dct_conf_experiment['meta_seed'],
            **dct_conf_experiment[section],
        }
        print(self.dct_parameter)
        self.reset_flag = reset_flag
        self.lst_saved_s_id = []

    def init_tables(self):
        """Fill the Experiment table, then the Sample table derived from it."""
        self.fill_experiment_table()
        self.fill_sample_table()

    def fill_experiment_table(self):
        """Insert one Experiment row per initially removed (firm, product) pair.

        Candidate pairs are read from ``Firm_amended.csv``: for each row, every
        column from label ``'1'`` onward whose value is 1 becomes one
        ``{firm_code: [product_code]}`` entry.
        """
        df_firm = pd.read_csv("Firm_amended.csv")
        df_firm['Code'] = df_firm['Code'].astype('string')
        df_firm.fillna(0, inplace=True)
        list_dct = []
        for _, row in df_firm.iterrows():
            code = row['Code']
            row_product = row['1':]  # product columns only
            for product_code in row_product.index[row_product == 1].to_list():
                list_dct.append({code: [product_code]})
        # NOTE(review): debugging override. This discards the list built above
        # and runs a single hard-coded experiment. Other values tried before:
        # [{'140': ['1.4.5.1']}], [{'2': ['1.1.3']}], [{'135': ['1.3.2.1']}]
        list_dct = [{'133': ['1.4.4.1']}]
        for idx_exp, dct in enumerate(list_dct):
            self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'],
                                  dct)
            print(f'Inserted experiment for exp {idx_exp}!')

    def add_experiment_1(self, idx_exp, n_max_trial,
                         dct_lst_init_remove_firm_prod):
        """Insert and commit one Experiment row.

        ``n_sample`` and ``n_iter`` come from the loaded configuration.
        """
        e = Experiment(
            idx_exp=idx_exp,
            n_sample=int(self.dct_parameter['n_sample']),
            n_iter=int(self.dct_parameter['n_iter']),
            n_max_trial=n_max_trial,
            dct_lst_init_remove_firm_prod=dct_lst_init_remove_firm_prod)
        db_session.add(e)
        db_session.commit()

    def fill_sample_table(self):
        """Insert ``n_sample`` pending Sample rows for every Experiment.

        One 32-bit seed per sample index is drawn from ``meta_seed``. The same
        seed list is reused for every experiment. ``idx_sample`` is stored
        1-based, and every sample starts with ``is_done_flag = -1``.
        """
        rng = random.Random(self.dct_parameter['meta_seed'])
        lst_seed = [
            rng.getrandbits(32)
            for _ in range(int(self.dct_parameter['n_sample']))
        ]
        lst_sample = [
            Sample(e_id=experiment.id,
                   idx_sample=idx_sample + 1,
                   seed=lst_seed[idx_sample],
                   is_done_flag=-1)
            for experiment in db_session.query(Experiment).all()
            for idx_sample in range(int(experiment.n_sample))
        ]
        db_session.bulk_save_objects(lst_sample)
        db_session.commit()
        print(f'Inserted {len(lst_sample)} samples!')

    def reset_db(self, force_drop=False):
        """Make sure this run's tables exist and are ready to use.

        :param force_drop: first drop every table whose name starts with
            ``db_name_prefix``. This forces re-creation and re-filling.

        If the tables exist, unfinished samples are reset according to
        ``reset_flag``. Otherwise all tables are created and filled.
        """
        # Tables in the database that belong to this run, as metadata objects
        lst_table_obj = [
            Base.metadata.tables[str_table]
            for str_table in ins.get_table_names()
            if str_table.startswith(self.db_name_prefix)
        ]
        if force_drop:
            # Drop one randomly chosen table at a time and retry on failure.
            # NOTE(review): presumably failures come from dependencies between
            # tables, so a random order eventually succeeds. This loops forever
            # if one drop keeps failing.
            while lst_table_obj:
                a_table = random.choice(lst_table_obj)
                try:
                    Base.metadata.drop_all(bind=engine, tables=[a_table])
                except (KeyError, OperationalError):
                    continue
                lst_table_obj.remove(a_table)
                print(
                    f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!"
                )
        if lst_table_obj:
            print(
                f"All tables exist. No need to reset for exp: {self.db_name_prefix}."
            )
            if self.reset_flag > 0:
                self._reset_unfinished_samples()
        else:
            Base.metadata.create_all(bind=engine)
            self.init_tables()
            print(
                f"All tables are just created and initialized for exp: {self.db_name_prefix}."
            )

    def _reset_unfinished_samples(self):
        """Set ``is_done_flag`` from 0 back to -1 so unfinished samples rerun.

        Any Result rows such a sample already wrote are deleted first. With
        ``reset_flag == 1`` only samples locked by this computer are reset;
        with ``reset_flag == 2`` all of them are.

        :raises ValueError: for any other positive ``reset_flag``.
        """
        if self.reset_flag == 2:
            qry_sample = db_session.query(Sample).filter(
                Sample.is_done_flag == 0)
        elif self.reset_flag == 1:
            qry_sample = db_session.query(Sample).filter(
                Sample.is_done_flag == 0,
                Sample.computer_name == platform.node())
        else:
            raise ValueError('Wrong reset flag')
        # Materialise first: the loop commits and rewrites the filtered column.
        for s in qry_sample.all():
            qry_result = db_session.query(Result).filter_by(s_id=s.id)
            if qry_result.count() > 0:
                # Query.filter() accepts no keyword arguments; reuse filter_by
                qry_result.delete()
                db_session.commit()
            s.is_done_flag = -1
            db_session.commit()
            print(f"Reset the sample id {s.id} flag from 0 to -1")

    def prepare_list_sample(self):
        """Print the total sample count and cache the ids of pending samples.

        Ids of samples with ``is_done_flag = -1`` are appended to
        ``self.lst_saved_s_id``. Table names are built from
        ``db_name_prefix``, which comes from the constructor rather than
        from user input.
        """
        res = db_session.execute(
            text(f'''SELECT count(*) FROM {self.db_name_prefix}_sample s,
            {self.db_name_prefix}_experiment e WHERE s.e_id=e.id '''
                 )).scalar()
        n_sample = 0 if res is None else res
        print(f'There are a total of {n_sample} samples.')
        res = db_session.execute(
            text(
                f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1'
            ))
        self.lst_saved_s_id.extend(row[0] for row in res)

    @staticmethod
    def select_random_sample(lst_s_id):
        """Pop random ids from ``lst_s_id`` until one is still pending.

        Every id tried is removed from the caller's list, whether or not it
        is still pending.

        :returns: the first Sample still at ``is_done_flag == -1``, or None
            once the list is exhausted.
        """
        while True:
            if not lst_s_id:
                return None
            s_id = random.choice(lst_s_id)
            lst_s_id.remove(s_id)
            res = db_session.query(Sample).filter(Sample.id == int(s_id),
                                                  Sample.is_done_flag == -1)
            if res.count() == 1:
                return res[0]

    def fetch_a_sample(self, s_id=None):
        """Return the Sample with id ``s_id``, or a random pending one.

        :returns: a Sample, or None when none matches or none is pending.
        """
        if s_id is not None:
            return db_session.query(Sample).filter(
                Sample.id == int(s_id)).first()
        return self.select_random_sample(self.lst_saved_s_id)

    @staticmethod
    def lock_the_sample(sample: "Sample"):
        """Mark ``sample`` as in progress (flag 0) on this computer and commit."""
        sample.is_done_flag, sample.computer_name = 0, platform.node()
        db_session.commit()
if __name__ == '__main__':
    db = ControllerDB('first')
    # Print the finished fraction. 332750 is a hard-coded expected total
    # sample count. NOTE(review): confirm it still matches the Sample table.
    # Session.execute requires text() for raw SQL, as used elsewhere in this file.
    ratio = db_session.execute(
        text(f'SELECT COUNT(*) / 332750 FROM {db.db_name_prefix}_sample s '
             'WHERE s.is_done_flag = 1')
    ).scalar()
    print(ratio)
    # Other entry points, kept for manual use:
    # db.reset_db(force_drop=True)
    # db.init_tables()
    # db.prepare_list_sample()
    # print(db.fetch_a_sample())