This commit is contained in:
HaoYizhi 2023-03-13 19:47:25 +08:00
parent 2f162b970b
commit 09b59d8778
10 changed files with 113 additions and 210 deletions

2
.vscode/launch.json vendored
View File

@ -8,7 +8,7 @@
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "C:\\Users\\ASUS\\OneDrive\\Project\\ScrAbm\\Dissertation\\IIabm\\model.py",
"program": "C:\\Users\\ASUS\\OneDrive\\Project\\ScrAbm\\Dissertation\\IIabm\\main.py",
"console": "integratedTerminal",
"justMyCode": true
}

Binary file not shown.

Binary file not shown.

View File

@ -3,36 +3,12 @@
# run settings
meta_seed: 0
fixed: # unchanged all the time
int_n_country: 2
max_int_n_supplier: 3 # make firms heterogeneous
flt_bm_price_ratio: 20.0
flt_beta_developing: 0.5
test: # only for test scenarios
int_n_product: 12
int_n_firm_per_product_per_country: 2
flt_demand_total: 1000.0
n_sample: 5
n_sample: 1
n_iter: 100
n_max_trial: 3
not_test: # normal scenarios
int_n_product: 50
int_n_firm_per_product_per_country: 10
flt_demand_total: 10000.0
n_sample: 50
n_iter: 10000
default:
is_eliminated: 0 # add when all positive profits and keep max n supplier; remove the worst when all negative wealth
flt_beta_developed: 0.5 # benchmarks flt_beta_developed
tariff_percentage: 0
experiment:
1:
range_lambda_tier: 0, 1, 0.1 # describe the network. 0: chain 1: one iter
2:
is_eliminated: 1
3:
range_flt_beta_developed: 0.5, 0.9, 0.1
range_tariff_percentage: 0, 1, 0.1
n_iter: 100
n_max_trial: 3

View File

@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
from orm import db_session, engine, Base, ins
from orm import Experiment, Sample, Product, Firm
from orm import Experiment, Sample, Result
from sqlalchemy.exc import OperationalError
from sqlalchemy import text
import yaml
import random
import pandas as pd
import numpy as np
import platform
@ -17,6 +18,7 @@ class ControllerDB:
lst_saved_s_id_3: list
lst_saved_s_id_1_2: list
# n_sample_1_2: int
def __init__(self, prefix, reset_flag=0):
@ -24,88 +26,59 @@ class ControllerDB:
dct_conf_experiment = yaml.full_load(yaml_file)
self.is_test = prefix == 'test'
self.db_name_prefix = prefix
dct_para_in_test = dct_conf_experiment['test'] if self.is_test else dct_conf_experiment['not_test']
self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'],
'experiment': dct_conf_experiment['experiment'],
**dct_conf_experiment['fixed'], **dct_conf_experiment['default'], **dct_para_in_test}
dct_para_in_test = dct_conf_experiment[
'test'] if self.is_test else dct_conf_experiment['not_test']
self.dct_parameter = {
'meta_seed': dct_conf_experiment['meta_seed'],
**dct_para_in_test
}
print(self.dct_parameter)
self.reset_flag = reset_flag # 0, not reset; 1, reset self; 2, reset all
self.lst_saved_s_id_1_2, self.lst_saved_s_id_3 = [], []
self.lst_saved_s_id = []
def init_tables(self):
self.fill_experiment_table()
self.fill_sample_table()
@staticmethod
def get_lst_of_range(str_range: str):
s1, s2, s3 = tuple(str_range.split(','))
return list(np.linspace(float(s1), float(s2), num=int((float(s2) - float(s1)) / float(s3)) + 1))
def fill_experiment_table(self):
# prepare the list of lambda tier
lst_lambda = self.get_lst_of_range(self.dct_parameter['experiment'][1]['range_lambda_tier'])
# prepare the list of alpha_2nd_country
lst_beta_developed = self.get_lst_of_range(self.dct_parameter['experiment'][3]['range_flt_beta_developed'])
# prepare the list of tariff_percentage
lst_tariff = self.get_lst_of_range(self.dct_parameter['experiment'][3]['range_tariff_percentage'])
# prepare the default values
is_eliminated = int(self.dct_parameter['is_eliminated'])
beta_developed = float(self.dct_parameter['flt_beta_developed'])
tariff_percentage_1 = tariff_percentage_2 = float(self.dct_parameter['tariff_percentage'])
for idx_scenario in self.dct_parameter['experiment'].keys():
n_exp = 0
# if idx_scenario == 1: # add S1 experiments
# n_exp = self.add_experiment_1(idx_scenario, lst_lambda, is_eliminated,
# beta_developed, tariff_percentage_1, tariff_percentage_2)
# if idx_scenario == 2: # add S2 experiments
# n_exp = self.add_experiment_1(idx_scenario, lst_lambda,
# int(self.dct_parameter['experiment'][idx_scenario]['is_eliminated']),
# beta_developed, tariff_percentage_1, tariff_percentage_2)
if idx_scenario == 3:
# int_eliminated = int(self.dct_parameter['experiment'][idx_scenario-1]['is_eliminated'])
int_eliminated = is_eliminated # is_eliminated is 0 at default, so stop eliminating under S3
for beta_developed in lst_beta_developed:
# for beta_developed in [0.5]: # fix beta as 0.5
for tariff_percentage_1 in lst_tariff:
for tariff_percentage_2 in lst_tariff:
# fix lambda as 0.5
n_exp += self.add_experiment_1(idx_scenario, [0.5], int_eliminated,
beta_developed, tariff_percentage_1, tariff_percentage_2)
print(f'Inserted {n_exp} experiments for exp {idx_scenario}!')
Firm = pd.read_csv("Firm_amended.csv")
Firm['Code'] = Firm['Code'].astype('string')
Firm.fillna(0, inplace=True)
list_dct = []
for _, row in Firm.iterrows():
code = row['Code']
row = row['1':]
for product_code in row.index[row == 1].to_list():
dct = {code: [product_code]}
list_dct.append(dct)
for idx_exp, dct in enumerate(list_dct):
self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'],
dct)
print(f'Inserted experiment for exp {idx_exp}!')
def add_experiment_1(self, idx_exp, lst_lambda, is_eliminated, flt_beta_developed,
tariff_percentage_1, tariff_percentage_2):
lst_exp = []
for lambda_tier in lst_lambda:
e = Experiment(idx_exp=idx_exp,
int_n_country=int(self.dct_parameter['int_n_country']),
max_int_n_supplier=int(self.dct_parameter['max_int_n_supplier']),
int_n_product=int(self.dct_parameter['int_n_product']),
int_n_firm_per_product_per_country=int(
self.dct_parameter['int_n_firm_per_product_per_country']),
flt_demand_total=float(self.dct_parameter['flt_demand_total']),
flt_bm_price_ratio=float(self.dct_parameter['flt_bm_price_ratio']),
flt_beta_developing=float(self.dct_parameter['flt_beta_developing']),
n_sample=int(self.dct_parameter['n_sample']),
n_iter=int(self.dct_parameter['n_iter']),
is_eliminated=is_eliminated,
flt_beta_developed=flt_beta_developed,
tariff_percentage_1=tariff_percentage_1,
tariff_percentage_2=tariff_percentage_2,
lambda_tier=float(lambda_tier))
lst_exp.append(e)
db_session.bulk_save_objects(lst_exp)
def add_experiment_1(self, idx_exp, n_max_trial,
dct_list_init_remove_firm_prod):
e = Experiment(
idx_exp=idx_exp,
n_sample=int(self.dct_parameter['n_sample']),
n_iter=int(self.dct_parameter['n_iter']),
n_max_trial=n_max_trial,
dct_list_init_remove_firm_prod=dct_list_init_remove_firm_prod)
db_session.add(e)
db_session.commit()
return len(lst_exp)
def fill_sample_table(self):
rng = random.Random(self.dct_parameter['meta_seed'])
lst_seed = [rng.getrandbits(32) for _ in range(int(self.dct_parameter['n_sample']))]
lst_seed = [
rng.getrandbits(32)
for _ in range(int(self.dct_parameter['n_sample']))
]
lst_exp = db_session.query(Experiment).all()
lst_sample = []
for experiment in lst_exp:
for idx_sample in range(int(experiment.n_sample)):
s = Sample(e_id=experiment.id,
idx_sample=idx_sample+1,
idx_sample=idx_sample + 1,
seed=lst_seed[idx_sample],
is_done_flag=-1)
lst_sample.append(s)
@ -115,8 +88,11 @@ class ControllerDB:
def reset_db(self, force_drop=False):
# first, check if tables exist
lst_table_obj = [Base.metadata.tables[str_table] for str_table in ins.get_table_names()
if str_table.startswith(self.db_name_prefix)]
lst_table_obj = [
Base.metadata.tables[str_table]
for str_table in ins.get_table_names()
if str_table.startswith(self.db_name_prefix)
]
is_exist = len(lst_table_obj) > 0
if force_drop:
while is_exist:
@ -129,51 +105,58 @@ class ControllerDB:
pass
else:
lst_table_obj.remove(a_table)
print(f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!")
print(
f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!"
)
finally:
is_exist = len(lst_table_obj) > 0
if is_exist:
print(f"All tables exist. No need to reset for exp: {self.db_name_prefix}.")
print(
f"All tables exist. No need to reset for exp: {self.db_name_prefix}."
)
# change the is_done_flag from 0 to -1, to rerun the in-finished tasks
if self.reset_flag > 0:
if self.reset_flag == 2:
result = db_session.query(Sample).filter(Sample.is_done_flag == 0)
sample = db_session.query(Sample).filter(
Sample.is_done_flag == 0)
elif self.reset_flag == 1:
result = db_session.query(Sample).filter(Sample.is_done_flag == 0,
Sample.computer_name == platform.node())
sample = db_session.query(Sample).filter(
Sample.is_done_flag == 0,
Sample.computer_name == platform.node())
else:
raise ValueError('Wrong reset flag')
if result.count() > 0:
for res in result:
qry_product = db_session.query(Product).filter_by(s_id=res.id)
if qry_product.count() > 0:
for p in qry_product:
db_session.query(Firm).filter(Firm.p_id == p.id).delete()
db_session.commit()
db_session.query(Product).filter(Product.id == p.id).delete()
db_session.commit()
res.is_done_flag = -1
if sample.count() > 0:
for s in sample:
qry_result = db_session.query(Result).filter_by(
s_id=s.id)
if qry_result.count() > 0:
db_session.query(Result).filter(s_id = s.id).delete()
db_session.commit()
s.is_done_flag = -1
db_session.commit()
print(f"Reset the task id {res.id} flag from 0 to -1")
print(f"Reset the sample id {s.id} flag from 0 to -1")
else:
Base.metadata.create_all(bind=engine)
self.init_tables()
print(f"All tables are just created and initialized for exp: {self.db_name_prefix}.")
print(
f"All tables are just created and initialized for exp: {self.db_name_prefix}."
)
def prepare_list_sample(self):
res = db_session.execute(text(f'''SELECT count(*) FROM {self.db_name_prefix}_sample s,
{self.db_name_prefix}_experiment e WHERE s.e_id=e.id and e.idx_exp < 3''')).scalar()
n_sample_1_2 = 0 if res is None else res
print(f'There are {n_sample_1_2} sample for exp 1 and 2.')
res = db_session.execute(text(f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1'))
res = db_session.execute(
text(f'''SELECT count(*) FROM {self.db_name_prefix}_sample s,
{self.db_name_prefix}_experiment e WHERE s.e_id=e.id '''
)).scalar()
n_sample = 0 if res is None else res
print(f'There are a total of {n_sample} samples.')
res = db_session.execute(
text(
f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1'
))
for row in res:
s_id = row[0]
if s_id <= n_sample_1_2:
self.lst_saved_s_id_1_2.append(s_id)
else:
self.lst_saved_s_id_3.append(s_id)
print(f'Left: {len(self.lst_saved_s_id_1_2)} for exp 1 and 2; {len(self.lst_saved_s_id_3)} for exp 3')
self.lst_saved_s_id.append(s_id)
@staticmethod
def select_random_sample(lst_s_id):
@ -182,7 +165,8 @@ class ControllerDB:
return None
s_id = random.choice(lst_s_id)
lst_s_id.remove(s_id)
res = db_session.query(Sample).filter(Sample.id == int(s_id), Sample.is_done_flag == -1)
res = db_session.query(Sample).filter(Sample.id == int(s_id),
Sample.is_done_flag == -1)
if res.count() == 1:
return res[0]
@ -194,11 +178,7 @@ class ControllerDB:
else:
return res[0]
sample = self.select_random_sample(self.lst_saved_s_id_1_2)
if sample is not None:
return sample
sample = self.select_random_sample(self.lst_saved_s_id_3)
sample = self.select_random_sample(self.lst_saved_s_id)
if sample is not None:
return sample
@ -214,7 +194,9 @@ if __name__ == '__main__':
# pprint.pprint(dct_exp_config)
# pprint.pprint(dct_conf_problem)
db = ControllerDB('first')
ratio = db_session.execute('SELECT COUNT(*) / 332750 FROM first_sample s WHERE s.is_done_flag = 1').scalar()
ratio = db_session.execute(
'SELECT COUNT(*) / 332750 FROM first_sample s WHERE s.is_done_flag = 1'
).scalar()
print(ratio)
# db.fill_experiment_table()
# print(db.dct_parameter)

View File

@ -36,7 +36,8 @@ if __name__ == '__main__':
from controller_db import ControllerDB
controller_db = ControllerDB(args.exp, reset_flag=args.reset)
controller_db.reset_db()
# controller_db.reset_db()
controller_db.reset_db(force_drop=True)
controller_db.prepare_list_sample()

View File

@ -16,10 +16,10 @@ n_iter = 10
# 2: ['1.1.3']
# }
dct_list_init_remove_firm_prod = {
140: ['1.4.5.1'],
135: ['1.3.2.1'],
133: ['1.4.4.1'],
2: ['1.1.3']
'140': ['1.4.5.1'],
'135': ['1.3.2.1'],
'133': ['1.4.4.1'],
'2': ['1.1.3']
}
n_max_trial = 5
dct_sample_para = {
@ -56,13 +56,14 @@ class Model(ap.Model):
# init graph firm
Firm = pd.read_csv("Firm_amended.csv")
Firm['Code'] = Firm['Code'].astype('string')
Firm.fillna(0, inplace=True)
Firm_attr = Firm.loc[:, ["Code", "Name", "Type_Region", "Revenue_Log"]]
firm_product = []
for _, row in Firm.loc[:, '1':].iterrows():
firm_product.append(row[row == 1].index.to_list())
Firm_attr.loc[:, 'Product_Code'] = firm_product
Firm_attr.set_index('Code')
Firm_attr.set_index('Code', inplace=True)
G_Firm = nx.MultiDiGraph()
G_Firm.add_nodes_from(Firm["Code"])
@ -78,7 +79,7 @@ class Model(ap.Model):
# print(product_code)
for succ_product_code in list(G_bom.successors(product_code)):
# print(succ_product_code)
list_succ_firms = Firm.index[Firm[succ_product_code] ==
list_succ_firms = Firm['Code'][Firm[succ_product_code] ==
1].to_list()
list_revenue_log = [
G_Firm.nodes[succ_firm]['Revenue_Log']
@ -132,7 +133,7 @@ class Model(ap.Model):
for ag_node, attr in self.firm_network.graph.nodes(data=True):
firm_agent = FirmAgent(
self,
code=attr['Code'],
code=ag_node.label,
name=attr['Name'],
type_region=attr['Type_Region'],
revenue_log=attr['Revenue_Log'],
@ -352,5 +353,5 @@ class Model(ap.Model):
plt.savefig("network.png")
model = Model(dct_sample_para)
model.run()
# model = Model(dct_sample_para)
# model.run()

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 MiB

After

Width:  |  Height:  |  Size: 2.8 MiB

81
orm.py
View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger, DECIMAL, DateTime, Text
from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger, DateTime, PickleType, Boolean
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, Session
from sqlalchemy.pool import NullPool
@ -40,22 +40,12 @@ class Experiment(Base):
idx_exp = Column(Integer, nullable=False)
# fixed parameters
int_n_country = Column(Integer, nullable=False)
max_int_n_supplier = Column(Integer, nullable=False) # uni(1, max), random parameter 1 of firm
int_n_product = Column(Integer, nullable=False)
int_n_firm_per_product_per_country = Column(Integer, nullable=False)
flt_demand_total = Column(DECIMAL(10, 2), nullable=False) # tri(0, total_demand, mean), to compute random para a
flt_bm_price_ratio = Column(DECIMAL(10, 2), nullable=False) # benchmark value of b, same for both countries
flt_beta_developing = Column(DECIMAL(10, 2), nullable=False) # benchmark value of c(beta), for developing countries
n_sample = Column(Integer, nullable=False)
n_iter = Column(Integer, nullable=False)
# variables
is_eliminated = Column(Integer, nullable=False)
flt_beta_developed = Column(DECIMAL(10, 2), nullable=False) # larger, for developed countries
lambda_tier = Column(DECIMAL(10, 2), nullable=False)
tariff_percentage_1 = Column(DECIMAL(10, 2), nullable=False)
tariff_percentage_2 = Column(DECIMAL(10, 2), nullable=False)
n_max_trial = Column(Integer, nullable=False)
dct_list_init_remove_firm_prod = Column(PickleType, nullable=False)
sample = relationship('Sample', back_populates='experiment', lazy='dynamic')
@ -75,77 +65,30 @@ class Sample(Base):
ts_done = Column(DateTime(timezone=True), onupdate=func.now())
stop_t = Column(Integer, nullable=True)
c1_wealth = Column(DECIMAL(20, 2), nullable=True) # country 1, developing countries
c2_wealth = Column(DECIMAL(20, 2), nullable=True) # country 2, developed countries
c1_wealth_dgt = Column(Integer, nullable=True)
c2_wealth_dgt = Column(Integer, nullable=True)
c1_tariff = Column(DECIMAL(20, 2), nullable=True) # country 1, developing countries
c2_tariff = Column(DECIMAL(20, 2), nullable=True) # country 2, developed countries
c1_tariff_dgt = Column(Integer, nullable=True)
c2_tariff_dgt = Column(Integer, nullable=True)
c1_n_firms = Column(Integer, nullable=True)
c2_n_firms = Column(Integer, nullable=True)
c1_n_positive_firms = Column(Integer, nullable=True)
c2_n_positive_firms = Column(Integer, nullable=True)
network = Column(Text(4294000000), nullable=True)
network_order = Column(Text(4294000000), nullable=True)
network_country = Column(Text(4294000000), nullable=True)
experiment = relationship('Experiment', back_populates='sample', uselist=False)
product = relationship('Product', back_populates='sample', lazy='dynamic')
result = relationship('Result', back_populates='sample', lazy='dynamic')
def __repr__(self):
return f'<Sample id: {self.id}>'
class Product(Base):
__tablename__ = f"{db_name_prefix}_product"
class Result(Base):
__tablename__ = f"{db_name_prefix}_result"
id = Column(Integer, primary_key=True, autoincrement=True)
s_id = Column(Integer, ForeignKey('{}.id'.format(f"{db_name_prefix}_sample")), nullable=False)
int_name = Column(Integer, nullable=False)
int_tier = Column(Integer, nullable=False)
n_up_products = Column(Integer, nullable=False)
n_peer_products = Column(Integer, nullable=False)
n_positive_firms = Column(Integer, nullable=False)
n_all_firms = Column(Integer, nullable=False)
gini_acc_demand_per_age = Column(DECIMAL(10, 2), nullable=False)
gini_acc_wealth_per_age = Column(DECIMAL(10, 2), nullable=False)
gini_acc_demand_per_age_all = Column(DECIMAL(10, 2), nullable=False)
gini_acc_wealth_per_age_all = Column(DECIMAL(10, 2), nullable=False)
# lst_n_positive_firms = Column(Text(4294000000), nullable=False)
# lst_n_all_firms = Column(Text(4294000000), nullable=False)
# lst_gini_acc_demand_per_age = Column(Text(4294000000), nullable=False)
# lst_gini_acc_wealth_per_age = Column(Text(4294000000), nullable=False)
# lst_gini_acc_demand_per_age_all = Column(Text(4294000000), nullable=False)
# lst_gini_acc_wealth_per_age_all = Column(Text(4294000000), nullable=False)
id_firm = Column(Integer, nullable=False)
id_product = Column(Integer, nullable=False)
ts = Column(Integer, nullable=False)
is_disrupted = Column(Boolean, nullable=True)
is_removed = Column(Boolean, nullable=True)
sample = relationship('Sample', back_populates='product', uselist=False)
firm = relationship('Firm', back_populates='product', lazy='dynamic')
sample = relationship('Sample', back_populates='result', uselist=False)
def __repr__(self):
return f'<Product id: {self.id}>'
class Firm(Base):
__tablename__ = f"{db_name_prefix}_firm"
id = Column(Integer, primary_key=True, autoincrement=True)
p_id = Column(Integer, ForeignKey('{}.id'.format(f"{db_name_prefix}_product")), nullable=False)
idx_firm = Column(Integer, nullable=False)
int_n_supplier = Column(Integer, nullable=False)
flt_fix_cost = Column(DECIMAL(20, 2), nullable=False)
flt_q_star = Column(DECIMAL(20, 2), nullable=False)
acc_demand_per_age = Column(DECIMAL(20, 2), nullable=False)
acc_wealth_per_age = Column(DECIMAL(20, 2), nullable=False)
std_demand_per_age = Column(DECIMAL(20, 2), nullable=False)
product = relationship('Product', back_populates='firm', uselist=False)
def __repr__(self):
return f'<Firm id: {self.id}>'
if __name__ == '__main__':
Base.metadata.drop_all()
Base.metadata.create_all()