log status of every firm product in every time step + mysql

This commit is contained in:
2023-07-02 20:44:06 +08:00
parent 0965a5daa4
commit 84828272e2
8 changed files with 92 additions and 53 deletions

View File

@@ -17,7 +17,8 @@ class Model(ap.Model):
self.product_network = None # agentpy network
self.firm_network = None # agentpy network
self.firm_prod_network = None # networkx
self.dct_lst_disrupt_firm_prod = self.p.dct_lst_init_disrupt_firm_prod
self.dct_lst_init_disrupt_firm_prod = \
self.p.dct_lst_init_disrupt_firm_prod
# external variable
self.int_n_max_trial = int(self.p.n_max_trial)
@@ -209,19 +210,20 @@ class Model(ap.Model):
self.firm_network.add_agents([firm_agent], [ag_node])
self.a_lst_total_firms = ap.AgentList(self, self.firm_network.agents)
# init dct_lst_disrupt_firm_prod (from string to agent)
# init dct_lst_init_disrupt_firm_prod (from string to agent)
t_dct = {}
for firm_code, lst_product in self.dct_lst_disrupt_firm_prod.items():
for firm_code, lst_product in \
self.dct_lst_init_disrupt_firm_prod.items():
firm = self.a_lst_total_firms.select(
self.a_lst_total_firms.code == firm_code)[0]
t_dct[firm] = self.a_lst_total_products.select([
code in lst_product for code in self.a_lst_total_products.code
])
self.dct_lst_disrupt_firm_prod = t_dct
self.dct_lst_init_disrupt_firm_prod = t_dct
# set the initial firm product that are disrupted
print('\n', '=' * 20, 'step', self.t, '=' * 20)
for firm, a_lst_product in self.dct_lst_disrupt_firm_prod.items():
for firm, a_lst_product in self.dct_lst_init_disrupt_firm_prod.items():
for product in a_lst_product:
assert product in firm.dct_prod_up_prod_stat.keys(), \
f"product {product.code} not in firm {firm.code}"
@@ -231,7 +233,7 @@ class Model(ap.Model):
# proactive strategy
# get all the firm prod affected
for firm, a_lst_product in self.dct_lst_disrupt_firm_prod.items():
for firm, a_lst_product in self.dct_lst_init_disrupt_firm_prod.items():
for product in a_lst_product:
init_node = \
[n for n, v in
@@ -349,8 +351,13 @@ class Model(ap.Model):
firm.size_stat.append((size, self.t))
print(f'in ts {self.t}, reduce {firm.name} size '
f'to {firm.size_stat[-1][0]} due to {prod.code}')
if self.t - ts + 1 == self.remove_t:
lst_is_disrupt = \
[stat == 'D' for stat, _ in
firm.dct_prod_up_prod_stat[prod]['status']
[-1 * self.remove_t:]]
if all(lst_is_disrupt):
# turn disrupted firm into removed firm
# when last self.remove_t times status is all disrupted
firm.dct_prod_up_prod_stat[
prod]['status'].append(('R', self.t))
@@ -358,8 +365,11 @@ class Model(ap.Model):
if self.t > 0:
for firm in self.a_lst_total_firms:
for prod in firm.dct_prod_up_prod_stat.keys():
status, ts = firm.dct_prod_up_prod_stat[prod]['status'][-1]
if status == 'D' and ts != 0:
status, _ = firm.dct_prod_up_prod_stat[prod]['status'][-1]
is_init = \
firm in self.dct_lst_init_disrupt_firm_prod.keys() \
and prod in self.dct_lst_init_disrupt_firm_prod[firm]
if status == 'D' and not is_init:
print("not stop because", firm.name, prod.code)
break
else:
@@ -422,45 +432,30 @@ class Model(ap.Model):
def end(self):
print('/' * 20, 'output', '/' * 20)
for firm in self.a_lst_total_firms:
is_size = False
for prod, dct_status_supply in firm.dct_prod_up_prod_stat.items():
if len(dct_status_supply['status']) > 1:
is_size = True
print(f"{firm.name} {prod.code}:")
print(dct_status_supply['status'])
if is_size:
print(firm.size_stat)
# qry_result = db_session.query(Result).filter_by(s_id=self.sample.id)
# if qry_result.count() == 0:
# lst_result_info = []
# for t, dct in self.lst_dct_lst_disrupt_firm_prod:
# for firm, a_lst_product in dct.items():
# for product in a_lst_product:
# db_r = Result(s_id=self.sample.id,
# id_firm=firm.code,
# id_product=product.code,
# ts=t,
# is_disrupted=True)
# lst_result_info.append(db_r)
# db_session.bulk_save_objects(lst_result_info)
# db_session.commit()
# for t, dct in self.lst_dct_lst_disrupt_firm_prod:
# for firm, a_lst_product in dct.items():
# for product in a_lst_product:
# # only firm disrupted can be removed theoretically
# qry_f_p = db_session.query(Result).filter(
# Result.s_id == self.sample.id,
# Result.id_firm == firm.code,
# Result.id_product == product.code)
# if qry_f_p.count() == 1:
# qry_f_p.update({"is_removed": True})
# db_session.commit()
# self.sample.is_done_flag = 1
# self.sample.computer_name = platform.node()
# self.sample.stop_t = self.int_stop_ts
# db_session.commit()
qry_result = db_session.query(Result).filter_by(s_id=self.sample.id)
if qry_result.count() == 0:
lst_result_info = []
for firm in self.a_lst_total_firms:
for prod, dct_status_supply in \
firm.dct_prod_up_prod_stat.items():
lst_is_normal = [stat == 'N' for stat, _
in dct_status_supply['status']]
if not all(lst_is_normal):
print(f"{firm.name} {prod.code}:")
print(dct_status_supply['status'])
for status, ts in dct_status_supply['status']:
db_r = Result(s_id=self.sample.id,
id_firm=firm.code,
id_product=prod.code,
ts=ts,
status=status)
lst_result_info.append(db_r)
db_session.bulk_save_objects(lst_result_info)
db_session.commit()
self.sample.is_done_flag = 1
self.sample.computer_name = platform.node()
self.sample.stop_t = self.int_stop_ts
db_session.commit()
def draw_network(self):
import matplotlib.pyplot as plt