This commit is contained in:
HaoYizhi 2023-03-13 21:53:50 +08:00
parent 09b59d8778
commit b935f0ce79
7 changed files with 58 additions and 14 deletions

Binary file not shown.

Binary file not shown.

View File

@ -45,12 +45,15 @@ class ControllerDB:
Firm['Code'] = Firm['Code'].astype('string')
Firm.fillna(0, inplace=True)
list_dct = []
for _, row in Firm.iterrows():
code = row['Code']
row = row['1':]
for product_code in row.index[row == 1].to_list():
dct = {code: [product_code]}
list_dct.append(dct)
# for _, row in Firm.iterrows():
# code = row['Code']
# row = row['1':]
# for product_code in row.index[row == 1].to_list():
# dct = {code: [product_code]}
# list_dct.append(dct)
# # break
# # break
list_dct = [{'140': ['1.4.5.1']}]
for idx_exp, dct in enumerate(list_dct):
self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'],
dct)

View File

@ -5,6 +5,8 @@ import random
import networkx as nx
from firm import FirmAgent
from product import ProductAgent
from orm import db_session, Sample, Result
import platform
sample = 0
seed = 0
@ -34,6 +36,7 @@ dct_sample_para = {
class Model(ap.Model):
def setup(self):
self.sample = self.p.sample
self.int_stop_times, self.int_stop_t = 0, None
self.random = random.Random(self.p.seed)
self.nprandom = np.random.default_rng(self.p.seed)
self.int_n_iter = int(self.p.n_iter)
@ -80,7 +83,7 @@ class Model(ap.Model):
for succ_product_code in list(G_bom.successors(product_code)):
# print(succ_product_code)
list_succ_firms = Firm['Code'][Firm[succ_product_code] ==
1].to_list()
1].to_list()
list_revenue_log = [
G_Firm.nodes[succ_firm]['Revenue_Log']
for succ_firm in list_succ_firms
@ -191,12 +194,16 @@ class Model(ap.Model):
def update(self):
self.a_list_total_firms.clean_before_time_step()
# output
self.list_dct_list_remove_firm_prod.append((self.t, self.dct_list_remove_firm_prod))
self.list_dct_list_disrupt_firm_prod.append((self.t, self.dct_list_disrupt_firm_prod))
self.list_dct_list_remove_firm_prod.append(
(self.t, self.dct_list_remove_firm_prod))
self.list_dct_list_disrupt_firm_prod.append(
(self.t, self.dct_list_disrupt_firm_prod))
# stop simulation if reached terminal number of iteration
if self.t == self.int_n_iter or len(
self.dct_list_remove_firm_prod) == 0:
self.int_stop_times = self.t
print(self.int_stop_times, self.t)
self.stop()
def step(self):
@ -245,13 +252,15 @@ class Model(ap.Model):
# self.a_list_total_firms.dct_request_prod_from_firm = {} why?
# based on a_list_up_product_removed,
# update a_list_product_disrupted / a_list_product_removed
# update a_list_product_disrupted / a_list_product_removed
# update dct_list_disrupt_firm_prod / dct_list_remove_firm_prod
self.dct_list_remove_firm_prod = {}
self.dct_list_disrupt_firm_prod = {}
for firm in self.a_list_total_firms:
if len(firm.a_list_up_product_removed) > 0:
print(firm.name, 'a_list_up_product_removed', [product.code for product in firm.a_list_up_product_removed])
print(firm.name, 'a_list_up_product_removed', [
product.code for product in firm.a_list_up_product_removed
])
for product in firm.a_list_product:
n_up_product_removed = 0
for up_product_removed in firm.a_list_up_product_removed:
@ -307,7 +316,7 @@ class Model(ap.Model):
})
def end(self):
print('/'*20, 'output', '/'*20)
print('/' * 20, 'output', '/' * 20)
print('dct_list_remove_firm_prod')
for t, dct in self.list_dct_list_remove_firm_prod:
for firm, a_list_product in dct.items():
@ -319,6 +328,38 @@ class Model(ap.Model):
for product in a_list_product:
print(t, firm.name, product.code)
qry_result = db_session.query(Result).filter_by(s_id=self.sample.id)
if qry_result.count() == 0:
lst_result_info = []
for t, dct in self.list_dct_list_disrupt_firm_prod:
for firm, a_list_product in dct.items():
for product in a_list_product:
# print(t, firm.name, product.code)
db_r = Result(s_id=self.sample.id,
id_firm=firm.code,
id_product=product.code,
ts=t,
is_disrupted=True)
lst_result_info.append(db_r)
db_session.bulk_save_objects(lst_result_info)
db_session.commit()
for t, dct in self.list_dct_list_remove_firm_prod:
for firm, a_list_product in dct.items():
for product in a_list_product:
# print(t, firm.name, product.code)
# only firm disrupted can be removed
qry_f_p = db_session.query(Result).filter(
Result.s_id == self.sample.id,
Result.id_firm == firm.code,
Result.id_product == product.code)
if qry_f_p.count() == 1:
qry_f_p.update({"is_removed": True})
db_session.commit()
self.sample.is_done_flag, self.sample.computer_name = 1, platform.node(
)
self.sample.stop_t = self.int_stop_times
db_session.commit()
def draw_network(self):
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'SimHei'

4
orm.py
View File

@ -77,8 +77,8 @@ class Result(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
s_id = Column(Integer, ForeignKey('{}.id'.format(f"{db_name_prefix}_sample")), nullable=False)
id_firm = Column(Integer, nullable=False)
id_product = Column(Integer, nullable=False)
id_firm = Column(String(10), nullable=False)
id_product = Column(String(10), nullable=False)
ts = Column(Integer, nullable=False)
is_disrupted = Column(Boolean, nullable=True)
is_removed = Column(Boolean, nullable=True)