This commit is contained in:
HaoYizhi 2023-05-15 13:44:21 +08:00
parent 7e8980d0c6
commit 745763ce2e
3 changed files with 34 additions and 21 deletions

View File

@ -6,11 +6,11 @@ from sqlalchemy import text
import yaml import yaml
import random import random
import pandas as pd import pandas as pd
import numpy as np
import platform import platform
import networkx as nx import networkx as nx
import json import json
class ControllerDB: class ControllerDB:
dct_parameter = None dct_parameter = None
is_test: bool = None is_test: bool = None
@ -31,7 +31,8 @@ class ControllerDB:
**dct_para_in_test **dct_para_in_test
} }
print(self.dct_parameter) print(self.dct_parameter)
self.reset_flag = reset_flag # 0, not reset; 1, reset self; 2, reset all # 0, not reset; 1, reset self; 2, reset all
self.reset_flag = reset_flag
self.lst_saved_s_id = [] self.lst_saved_s_id = []
def init_tables(self): def init_tables(self):
@ -73,7 +74,7 @@ class ControllerDB:
# insert exp # insert exp
for idx_exp, dct in enumerate(list_dct): for idx_exp, dct in enumerate(list_dct):
self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'], self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'],
dct, g_product_js) # same g_bom for all exp dct, g_product_js) # same g_bom for all exp
print(f'Inserted experiment for exp {idx_exp}!') print(f'Inserted experiment for exp {idx_exp}!')
def add_experiment_1(self, idx_exp, n_max_trial, def add_experiment_1(self, idx_exp, n_max_trial,
@ -127,16 +128,19 @@ class ControllerDB:
else: else:
lst_table_obj.remove(a_table) lst_table_obj.remove(a_table)
print( print(
f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!" f"Table {a_table.name} is dropped "
f"for exp: {self.db_name_prefix}!!!"
) )
finally: finally:
is_exist = len(lst_table_obj) > 0 is_exist = len(lst_table_obj) > 0
if is_exist: if is_exist:
print( print(
f"All tables exist. No need to reset for exp: {self.db_name_prefix}." f"All tables exist. No need to reset "
f"for exp: {self.db_name_prefix}."
) )
# change the is_done_flag from 0 to -1, to rerun the in-finished tasks # change the is_done_flag from 0 to -1
# rerun the in-finished tasks
if self.reset_flag > 0: if self.reset_flag > 0:
if self.reset_flag == 2: if self.reset_flag == 2:
sample = db_session.query(Sample).filter( sample = db_session.query(Sample).filter(
@ -152,7 +156,7 @@ class ControllerDB:
qry_result = db_session.query(Result).filter_by( qry_result = db_session.query(Result).filter_by(
s_id=s.id) s_id=s.id)
if qry_result.count() > 0: if qry_result.count() > 0:
db_session.query(Result).filter(s_id = s.id).delete() db_session.query(Result).filter(s_id=s.id).delete()
db_session.commit() db_session.commit()
s.is_done_flag = -1 s.is_done_flag = -1
db_session.commit() db_session.commit()
@ -161,20 +165,21 @@ class ControllerDB:
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
self.init_tables() self.init_tables()
print( print(
f"All tables are just created and initialized for exp: {self.db_name_prefix}." f"All tables are just created and initialized "
f"for exp: {self.db_name_prefix}."
) )
def prepare_list_sample(self): def prepare_list_sample(self):
res = db_session.execute( res = db_session.execute(
text(f'''SELECT count(*) FROM {self.db_name_prefix}_sample s, text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
{self.db_name_prefix}_experiment e WHERE s.e_id=e.id ''' f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
)).scalar() )).scalar()
n_sample = 0 if res is None else res n_sample = 0 if res is None else res
print(f'There are a total of {n_sample} samples.') print(f'There are a total of {n_sample} samples.')
res = db_session.execute( res = db_session.execute(
text( text(f"SELECT id FROM {self.db_name_prefix}_sample "
f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1' f"WHERE is_done_flag = -1"
)) ))
for row in res: for row in res:
s_id = row[0] s_id = row[0]
self.lst_saved_s_id.append(s_id) self.lst_saved_s_id.append(s_id)

View File

@ -59,7 +59,8 @@ class Model(ap.Model):
size / sum(lst_succ_firm_size) size / sum(lst_succ_firm_size)
for size in lst_succ_firm_size for size in lst_succ_firm_size
] ]
# select multiple successors based on relative size of this firm # select multiple successors
# based on relative size of this firm
lst_same_prod_firm = Firm['Code'][Firm[product_code] == lst_same_prod_firm = Firm['Code'][Firm[product_code] ==
1].to_list() 1].to_list()
lst_same_prod_firm_size = [ lst_same_prod_firm_size = [
@ -202,7 +203,8 @@ class Model(ap.Model):
if n_up_product_removed == 0: if n_up_product_removed == 0:
continue continue
else: else:
# update a_lst_product_disrupted / dct_lst_disrupt_firm_prod # update a_lst_product_disrupted
# update dct_lst_disrupt_firm_prod
if product not in firm.a_lst_product_disrupted: if product not in firm.a_lst_product_disrupted:
firm.a_lst_product_disrupted.append(product) firm.a_lst_product_disrupted.append(product)
if firm in self.dct_lst_disrupt_firm_prod.keys(): if firm in self.dct_lst_disrupt_firm_prod.keys():
@ -212,7 +214,8 @@ class Model(ap.Model):
self.dct_lst_disrupt_firm_prod[ self.dct_lst_disrupt_firm_prod[
firm] = ap.AgentList( firm] = ap.AgentList(
self.model, [product]) self.model, [product])
# update a_lst_product_removed / dct_list_remove_firm_prod # update a_lst_product_removed
# update dct_list_remove_firm_prod
# mark disrupted firm as removed based conditionally # mark disrupted firm as removed based conditionally
lost_percent = n_up_product_removed / len( lost_percent = n_up_product_removed / len(
product.a_predecessors()) product.a_predecessors())
@ -220,8 +223,9 @@ class Model(ap.Model):
std_size = (firm.revenue_log - min(lst_size) + std_size = (firm.revenue_log - min(lst_size) +
1) / (max(lst_size) - min(lst_size) + 1) 1) / (max(lst_size) - min(lst_size) + 1)
prod_remove = 1 - std_size * (1 - lost_percent) prod_remove = 1 - std_size * (1 - lost_percent)
if self.nprandom.choice( if self.nprandom.choice([True, False],
[True, False], p=[prod_remove, 1 - prod_remove]): p=[prod_remove,
1 - prod_remove]):
firm.a_lst_product_removed.append(product) firm.a_lst_product_removed.append(product)
if firm in self.dct_lst_remove_firm_prod.keys(): if firm in self.dct_lst_remove_firm_prod.keys():
self.dct_lst_remove_firm_prod[firm].append( self.dct_lst_remove_firm_prod[firm].append(

10
orm.py
View File

@ -1,7 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from sqlalchemy import create_engine, inspect from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger, DateTime, PickleType, Boolean, Text from sqlalchemy import (Column, Integer, String, ForeignKey, BigInteger,
DateTime, PickleType, Boolean, Text)
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, Session from sqlalchemy.orm import relationship, Session
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
@ -21,8 +22,11 @@ with open('conf_db_prefix.yaml') as file:
db_name_prefix = dct_conf_db_prefix['db_name_prefix'] db_name_prefix = dct_conf_db_prefix['db_name_prefix']
str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'], dct_conf_db['password'], str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'],
dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name']) dct_conf_db['password'],
dct_conf_db['address'],
dct_conf_db['port'],
dct_conf_db['db_name'])
print('DB is {}:{}/{}'.format(dct_conf_db['address'], print('DB is {}:{}/{}'.format(dct_conf_db['address'],
dct_conf_db['port'], dct_conf_db['db_name'])) dct_conf_db['port'], dct_conf_db['db_name']))