diff --git a/__pycache__/computation.cpython-38.pyc b/__pycache__/computation.cpython-38.pyc index 2049456..2a37f85 100644 Binary files a/__pycache__/computation.cpython-38.pyc and b/__pycache__/computation.cpython-38.pyc differ diff --git a/__pycache__/controller_db.cpython-38.pyc b/__pycache__/controller_db.cpython-38.pyc index 03cb1ba..8295673 100644 Binary files a/__pycache__/controller_db.cpython-38.pyc and b/__pycache__/controller_db.cpython-38.pyc differ diff --git a/__pycache__/model.cpython-38.pyc b/__pycache__/model.cpython-38.pyc index 19ab1d4..d18096c 100644 Binary files a/__pycache__/model.cpython-38.pyc and b/__pycache__/model.cpython-38.pyc differ diff --git a/__pycache__/orm.cpython-38.pyc b/__pycache__/orm.cpython-38.pyc index ed229af..8e2fae3 100644 Binary files a/__pycache__/orm.cpython-38.pyc and b/__pycache__/orm.cpython-38.pyc differ diff --git a/controller_db.py b/controller_db.py index c726afa..6f648b2 100644 --- a/controller_db.py +++ b/controller_db.py @@ -45,12 +45,15 @@ class ControllerDB: Firm['Code'] = Firm['Code'].astype('string') Firm.fillna(0, inplace=True) list_dct = [] - for _, row in Firm.iterrows(): - code = row['Code'] - row = row['1':] - for product_code in row.index[row == 1].to_list(): - dct = {code: [product_code]} - list_dct.append(dct) + # for _, row in Firm.iterrows(): + # code = row['Code'] + # row = row['1':] + # for product_code in row.index[row == 1].to_list(): + # dct = {code: [product_code]} + # list_dct.append(dct) + # # break + # # break + list_dct = [{'140': ['1.4.5.1']}] for idx_exp, dct in enumerate(list_dct): self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'], dct) diff --git a/model.py b/model.py index 0e946b7..93221d1 100644 --- a/model.py +++ b/model.py @@ -5,6 +5,8 @@ import random import networkx as nx from firm import FirmAgent from product import ProductAgent +from orm import db_session, Sample, Result +import platform sample = 0 seed = 0 @@ -34,6 +36,7 @@ dct_sample_para = { class Model(ap.Model): def setup(self): self.sample = self.p.sample + self.int_stop_times, self.int_stop_t = 0, None self.random = random.Random(self.p.seed) self.nprandom = np.random.default_rng(self.p.seed) self.int_n_iter = int(self.p.n_iter) @@ -80,7 +83,7 @@ class Model(ap.Model): for succ_product_code in list(G_bom.successors(product_code)): # print(succ_product_code) list_succ_firms = Firm['Code'][Firm[succ_product_code] == - 1].to_list() + 1].to_list() list_revenue_log = [ G_Firm.nodes[succ_firm]['Revenue_Log'] for succ_firm in list_succ_firms @@ -191,12 +194,16 @@ class Model(ap.Model): def update(self): self.a_list_total_firms.clean_before_time_step() # output - self.list_dct_list_remove_firm_prod.append((self.t, self.dct_list_remove_firm_prod)) - self.list_dct_list_disrupt_firm_prod.append((self.t, self.dct_list_disrupt_firm_prod)) + self.list_dct_list_remove_firm_prod.append( + (self.t, self.dct_list_remove_firm_prod)) + self.list_dct_list_disrupt_firm_prod.append( + (self.t, self.dct_list_disrupt_firm_prod)) # stop simulation if reached terminal number of iteration if self.t == self.int_n_iter or len( self.dct_list_remove_firm_prod) == 0: + self.int_stop_times = self.t + print(self.int_stop_times, self.t) self.stop() def step(self): @@ -245,13 +252,15 @@ class Model(ap.Model): # self.a_list_total_firms.dct_request_prod_from_firm = {} why? # based on a_list_up_product_removed, - # update a_list_product_disrupted / a_list_product_removed + # update a_list_product_disrupted / a_list_product_removed # update dct_list_disrupt_firm_prod / dct_list_remove_firm_prod self.dct_list_remove_firm_prod = {} self.dct_list_disrupt_firm_prod = {} for firm in self.a_list_total_firms: if len(firm.a_list_up_product_removed) > 0: - print(firm.name, 'a_list_up_product_removed', [product.code for product in firm.a_list_up_product_removed]) + print(firm.name, 'a_list_up_product_removed', [ + product.code for product in firm.a_list_up_product_removed + ]) for product in firm.a_list_product: n_up_product_removed = 0 for up_product_removed in firm.a_list_up_product_removed: @@ -307,7 +316,7 @@ class Model(ap.Model): }) def end(self): - print('/'*20, 'output', '/'*20) + print('/' * 20, 'output', '/' * 20) print('dct_list_remove_firm_prod') for t, dct in self.list_dct_list_remove_firm_prod: for firm, a_list_product in dct.items(): @@ -319,6 +328,38 @@ class Model(ap.Model): for product in a_list_product: print(t, firm.name, product.code) + qry_result = db_session.query(Result).filter_by(s_id=self.sample.id) + if qry_result.count() == 0: + lst_result_info = [] + for t, dct in self.list_dct_list_disrupt_firm_prod: + for firm, a_list_product in dct.items(): + for product in a_list_product: + # print(t, firm.name, product.code) + db_r = Result(s_id=self.sample.id, + id_firm=firm.code, + id_product=product.code, + ts=t, + is_disrupted=True) + lst_result_info.append(db_r) + db_session.bulk_save_objects(lst_result_info) + db_session.commit() + for t, dct in self.list_dct_list_remove_firm_prod: + for firm, a_list_product in dct.items(): + for product in a_list_product: + # print(t, firm.name, product.code) + # only firm disrupted can be removed + qry_f_p = db_session.query(Result).filter( + Result.s_id == self.sample.id, + Result.id_firm == firm.code, + Result.id_product == product.code) + if qry_f_p.count() == 1: + qry_f_p.update({"is_removed": True}) + db_session.commit() + self.sample.is_done_flag, self.sample.computer_name = 1, platform.node( + ) + self.sample.stop_t = self.int_stop_times + db_session.commit() + def draw_network(self): import matplotlib.pyplot as plt plt.rcParams['font.sans-serif'] = 'SimHei' diff --git a/orm.py b/orm.py index 94283ff..2bb0594 100644 --- a/orm.py +++ b/orm.py @@ -77,8 +77,8 @@ class Result(Base): id = Column(Integer, primary_key=True, autoincrement=True) s_id = Column(Integer, ForeignKey('{}.id'.format(f"{db_name_prefix}_sample")), nullable=False) - id_firm = Column(Integer, nullable=False) - id_product = Column(Integer, nullable=False) + id_firm = Column(String(10), nullable=False) + id_product = Column(String(10), nullable=False) ts = Column(Integer, nullable=False) is_disrupted = Column(Boolean, nullable=True) is_removed = Column(Boolean, nullable=True)