This commit is contained in:
HaoYizhi 2023-05-15 13:44:21 +08:00
parent 7e8980d0c6
commit 745763ce2e
3 changed files with 34 additions and 21 deletions

View File

@ -6,11 +6,11 @@ from sqlalchemy import text
import yaml
import random
import pandas as pd
import numpy as np
import platform
import networkx as nx
import json
class ControllerDB:
dct_parameter = None
is_test: bool = None
@ -31,7 +31,8 @@ class ControllerDB:
**dct_para_in_test
}
print(self.dct_parameter)
self.reset_flag = reset_flag # 0, not reset; 1, reset self; 2, reset all
# 0, not reset; 1, reset self; 2, reset all
self.reset_flag = reset_flag
self.lst_saved_s_id = []
def init_tables(self):
@ -73,7 +74,7 @@ class ControllerDB:
# insert exp
for idx_exp, dct in enumerate(list_dct):
self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'],
dct, g_product_js) # same g_bom for all exp
dct, g_product_js) # same g_bom for all exp
print(f'Inserted experiment for exp {idx_exp}!')
def add_experiment_1(self, idx_exp, n_max_trial,
@ -127,16 +128,19 @@ class ControllerDB:
else:
lst_table_obj.remove(a_table)
print(
f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!"
f"Table {a_table.name} is dropped "
f"for exp: {self.db_name_prefix}!!!"
)
finally:
is_exist = len(lst_table_obj) > 0
if is_exist:
print(
f"All tables exist. No need to reset for exp: {self.db_name_prefix}."
f"All tables exist. No need to reset "
f"for exp: {self.db_name_prefix}."
)
# change the is_done_flag from 0 to -1, to rerun the in-finished tasks
# change the is_done_flag from 0 to -1
# rerun the in-finished tasks
if self.reset_flag > 0:
if self.reset_flag == 2:
sample = db_session.query(Sample).filter(
@ -152,7 +156,7 @@ class ControllerDB:
qry_result = db_session.query(Result).filter_by(
s_id=s.id)
if qry_result.count() > 0:
db_session.query(Result).filter(s_id = s.id).delete()
db_session.query(Result).filter(s_id=s.id).delete()
db_session.commit()
s.is_done_flag = -1
db_session.commit()
@ -161,20 +165,21 @@ class ControllerDB:
Base.metadata.create_all(bind=engine)
self.init_tables()
print(
f"All tables are just created and initialized for exp: {self.db_name_prefix}."
f"All tables are just created and initialized "
f"for exp: {self.db_name_prefix}."
)
def prepare_list_sample(self):
res = db_session.execute(
text(f'''SELECT count(*) FROM {self.db_name_prefix}_sample s,
{self.db_name_prefix}_experiment e WHERE s.e_id=e.id '''
text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
)).scalar()
n_sample = 0 if res is None else res
print(f'There are a total of {n_sample} samples.')
res = db_session.execute(
text(
f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1'
))
text(f"SELECT id FROM {self.db_name_prefix}_sample "
f"WHERE is_done_flag = -1"
))
for row in res:
s_id = row[0]
self.lst_saved_s_id.append(s_id)

View File

@ -59,7 +59,8 @@ class Model(ap.Model):
size / sum(lst_succ_firm_size)
for size in lst_succ_firm_size
]
# select multiple successors based on relative size of this firm
# select multiple successors
# based on relative size of this firm
lst_same_prod_firm = Firm['Code'][Firm[product_code] ==
1].to_list()
lst_same_prod_firm_size = [
@ -202,7 +203,8 @@ class Model(ap.Model):
if n_up_product_removed == 0:
continue
else:
# update a_lst_product_disrupted / dct_lst_disrupt_firm_prod
# update a_lst_product_disrupted
# update dct_lst_disrupt_firm_prod
if product not in firm.a_lst_product_disrupted:
firm.a_lst_product_disrupted.append(product)
if firm in self.dct_lst_disrupt_firm_prod.keys():
@ -212,7 +214,8 @@ class Model(ap.Model):
self.dct_lst_disrupt_firm_prod[
firm] = ap.AgentList(
self.model, [product])
# update a_lst_product_removed / dct_list_remove_firm_prod
# update a_lst_product_removed
# update dct_list_remove_firm_prod
# mark disrupted firm as removed based conditionally
lost_percent = n_up_product_removed / len(
product.a_predecessors())
@ -220,8 +223,9 @@ class Model(ap.Model):
std_size = (firm.revenue_log - min(lst_size) +
1) / (max(lst_size) - min(lst_size) + 1)
prod_remove = 1 - std_size * (1 - lost_percent)
if self.nprandom.choice(
[True, False], p=[prod_remove, 1 - prod_remove]):
if self.nprandom.choice([True, False],
p=[prod_remove,
1 - prod_remove]):
firm.a_lst_product_removed.append(product)
if firm in self.dct_lst_remove_firm_prod.keys():
self.dct_lst_remove_firm_prod[firm].append(

10
orm.py
View File

@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger, DateTime, PickleType, Boolean, Text
from sqlalchemy import (Column, Integer, String, ForeignKey, BigInteger,
DateTime, PickleType, Boolean, Text)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, Session
from sqlalchemy.pool import NullPool
@ -21,8 +22,11 @@ with open('conf_db_prefix.yaml') as file:
db_name_prefix = dct_conf_db_prefix['db_name_prefix']
str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'], dct_conf_db['password'],
dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name'])
str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'],
dct_conf_db['password'],
dct_conf_db['address'],
dct_conf_db['port'],
dct_conf_db['db_name'])
print('DB is {}:{}/{}'.format(dct_conf_db['address'],
dct_conf_db['port'], dct_conf_db['db_name']))