diff --git a/.vscode/launch.json b/.vscode/launch.json index 3f1ca37..92afbfc 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -8,7 +8,7 @@ "name": "Python: Current File", "type": "python", "request": "launch", - "program": "C:\\Users\\ASUS\\OneDrive\\Project\\ScrAbm\\Dissertation\\IIabm\\main.py", + "program": "C:\\Users\\ASUS\\OneDrive\\Project\\ScrAbm\\Dissertation\\IIabm\\model.py", "console": "integratedTerminal", "justMyCode": true } diff --git a/__pycache__/computation.cpython-38.pyc b/__pycache__/computation.cpython-38.pyc index e6acd9c..2049456 100644 Binary files a/__pycache__/computation.cpython-38.pyc and b/__pycache__/computation.cpython-38.pyc differ diff --git a/__pycache__/controller_db.cpython-38.pyc b/__pycache__/controller_db.cpython-38.pyc index fb53429..5947bc5 100644 Binary files a/__pycache__/controller_db.cpython-38.pyc and b/__pycache__/controller_db.cpython-38.pyc differ diff --git a/__pycache__/firm.cpython-38.pyc b/__pycache__/firm.cpython-38.pyc index 3b52083..033af6f 100644 Binary files a/__pycache__/firm.cpython-38.pyc and b/__pycache__/firm.cpython-38.pyc differ diff --git a/__pycache__/orm.cpython-38.pyc b/__pycache__/orm.cpython-38.pyc index 84a240d..ff67101 100644 Binary files a/__pycache__/orm.cpython-38.pyc and b/__pycache__/orm.cpython-38.pyc differ diff --git a/controller_db.py b/controller_db.py index a9c02c7..fab7273 100644 --- a/controller_db.py +++ b/controller_db.py @@ -2,6 +2,7 @@ from orm import db_session, engine, Base, ins from orm import Experiment, Sample, Product, Firm from sqlalchemy.exc import OperationalError +from sqlalchemy import text import yaml import random import numpy as np @@ -72,7 +73,7 @@ class ControllerDB: print(f'Inserted {n_exp} experiments for exp {idx_scenario}!') def add_experiment_1(self, idx_exp, lst_lambda, is_eliminated, flt_beta_developed, - tariff_percentage_1: float, tariff_percentage_2: float): + tariff_percentage_1, tariff_percentage_2): lst_exp = [] for lambda_tier in lst_lambda: e = Experiment(idx_exp=idx_exp, @@ -156,16 +157,16 @@ class ControllerDB: db_session.commit() print(f"Reset the task id {res.id} flag from 0 to -1") else: - Base.metadata.create_all() + Base.metadata.create_all(bind=engine) self.init_tables() print(f"All tables are just created and initialized for exp: {self.db_name_prefix}.") def prepare_list_sample(self): - res = db_session.execute(f'''SELECT count(*) FROM {self.db_name_prefix}_sample s, - {self.db_name_prefix}_experiment e WHERE s.e_id=e.id and e.idx_exp < 3''').scalar() + res = db_session.execute(text(f'''SELECT count(*) FROM {self.db_name_prefix}_sample s, + {self.db_name_prefix}_experiment e WHERE s.e_id=e.id and e.idx_exp < 3''')).scalar() n_sample_1_2 = 0 if res is None else res print(f'There are {n_sample_1_2} sample for exp 1 and 2.') - res = db_session.execute(f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1') + res = db_session.execute(text(f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1')) for row in res: s_id = row[0] if s_id <= n_sample_1_2: diff --git a/model.py b/model.py index 4a50423..3a34490 100644 --- a/model.py +++ b/model.py @@ -171,6 +171,11 @@ class Model(ap.Model): for code in self.a_list_total_products.code ]) self.dct_list_remove_firm_prod = t_dct + self.dct_list_disrupt_firm_prod = t_dct + + # init output + self.list_dct_list_remove_firm_prod = [] + self.list_dct_list_disrupt_firm_prod = [] # set the initial firm product that are removed for firm, a_list_product in self.dct_list_remove_firm_prod.items(): @@ -184,6 +189,10 @@ class Model(ap.Model): def update(self): self.a_list_total_firms.clean_before_time_step() + # output + self.list_dct_list_remove_firm_prod.append((self.t, self.dct_list_remove_firm_prod)) + self.list_dct_list_disrupt_firm_prod.append((self.t, self.dct_list_disrupt_firm_prod)) + # stop simulation if reached terminal number of iteration if self.t == self.int_n_iter or len( self.dct_list_remove_firm_prod) == 0: @@ -235,8 +244,10 @@ class Model(ap.Model): # self.a_list_total_firms.dct_request_prod_from_firm = {} why? # based on a_list_up_product_removed, - # update a_list_product_disrupted / a_list_product_removed / dct_list_remove_firm_prod + # update a_list_product_disrupted / a_list_product_removed + # update dct_list_disrupt_firm_prod / dct_list_remove_firm_prod self.dct_list_remove_firm_prod = {} + self.dct_list_disrupt_firm_prod = {} for firm in self.a_list_total_firms: if len(firm.a_list_up_product_removed) > 0: print(firm.name, 'a_list_up_product_removed', [product.code for product in firm.a_list_up_product_removed]) @@ -248,9 +259,16 @@ class Model(ap.Model): if n_up_product_removed == 0: continue else: - # update a_list_product_disrupted + # update a_list_product_disrupted / dct_list_disrupt_firm_prod if product not in firm.a_list_product_disrupted: firm.a_list_product_disrupted.append(product) + if firm in self.dct_list_disrupt_firm_prod.keys(): + self.dct_list_disrupt_firm_prod[firm].append( + product) + else: + self.dct_list_disrupt_firm_prod[ + firm] = ap.AgentList( + self.model, [product]) # update a_list_product_removed / dct_list_remove_firm_prod lost_percent = n_up_product_removed / len( product.a_predecessors()) @@ -259,9 +277,9 @@ class Model(ap.Model): 1) / (max(list_revenue_log) - min(list_revenue_log) + 1) p_remove = 1 - std_size * (1 - lost_percent) - # flag = self.nprandom.choice([1, 0], - # p=[p_remove, 1 - p_remove]) - flag = 1 + flag = self.nprandom.choice([1, 0], + p=[p_remove, 1 - p_remove]) + # flag = 1 if flag == 1: firm.a_list_product_removed.append(product) # if firm in @@ -288,7 +306,17 @@ class Model(ap.Model): }) def end(self): - pass + print('/'*20, 'output', '/'*20) + print('dct_list_remove_firm_prod') + for t, dct in self.list_dct_list_remove_firm_prod: + for firm, a_list_product in dct.items(): + for product in a_list_product: + print(t, firm.name, product.code) + print('dct_list_disrupt_firm_prod') + for t, dct in self.list_dct_list_disrupt_firm_prod: + for firm, a_list_product in dct.items(): + for product in a_list_product: + print(t, firm.name, product.code) def draw_network(self): import matplotlib.pyplot as plt @@ -324,5 +352,5 @@ class Model(ap.Model): plt.savefig("network.png") -# model = Model(dct_sample_para) -# model.run() +model = Model(dct_sample_para) +model.run() diff --git a/orm.py b/orm.py index dd683c0..a3926ae 100644 --- a/orm.py +++ b/orm.py @@ -28,7 +28,7 @@ print('DB is {}:{}/{}'.format(dct_conf_db['address'], dct_conf_db['port'], dct_c engine = create_engine(str_login, poolclass=NullPool) # must be null pool to avoid connection lost error ins = inspect(engine) -Base = declarative_base(constructor=engine) +Base = declarative_base() db_session = Session(bind=engine)