This commit is contained in:
HaoYizhi 2023-03-13 21:53:50 +08:00
parent 09b59d8778
commit b935f0ce79
7 changed files with 58 additions and 14 deletions

Binary file not shown.

Binary file not shown.

View File

@ -45,12 +45,15 @@ class ControllerDB:
Firm['Code'] = Firm['Code'].astype('string') Firm['Code'] = Firm['Code'].astype('string')
Firm.fillna(0, inplace=True) Firm.fillna(0, inplace=True)
list_dct = [] list_dct = []
for _, row in Firm.iterrows(): # for _, row in Firm.iterrows():
code = row['Code'] # code = row['Code']
row = row['1':] # row = row['1':]
for product_code in row.index[row == 1].to_list(): # for product_code in row.index[row == 1].to_list():
dct = {code: [product_code]} # dct = {code: [product_code]}
list_dct.append(dct) # list_dct.append(dct)
# # break
# # break
list_dct = [{'140': ['1.4.5.1']}]
for idx_exp, dct in enumerate(list_dct): for idx_exp, dct in enumerate(list_dct):
self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'], self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'],
dct) dct)

View File

@ -5,6 +5,8 @@ import random
import networkx as nx import networkx as nx
from firm import FirmAgent from firm import FirmAgent
from product import ProductAgent from product import ProductAgent
from orm import db_session, Sample, Result
import platform
sample = 0 sample = 0
seed = 0 seed = 0
@ -34,6 +36,7 @@ dct_sample_para = {
class Model(ap.Model): class Model(ap.Model):
def setup(self): def setup(self):
self.sample = self.p.sample self.sample = self.p.sample
self.int_stop_times, self.int_stop_t = 0, None
self.random = random.Random(self.p.seed) self.random = random.Random(self.p.seed)
self.nprandom = np.random.default_rng(self.p.seed) self.nprandom = np.random.default_rng(self.p.seed)
self.int_n_iter = int(self.p.n_iter) self.int_n_iter = int(self.p.n_iter)
@ -191,12 +194,16 @@ class Model(ap.Model):
def update(self): def update(self):
self.a_list_total_firms.clean_before_time_step() self.a_list_total_firms.clean_before_time_step()
# output # output
self.list_dct_list_remove_firm_prod.append((self.t, self.dct_list_remove_firm_prod)) self.list_dct_list_remove_firm_prod.append(
self.list_dct_list_disrupt_firm_prod.append((self.t, self.dct_list_disrupt_firm_prod)) (self.t, self.dct_list_remove_firm_prod))
self.list_dct_list_disrupt_firm_prod.append(
(self.t, self.dct_list_disrupt_firm_prod))
# stop simulation if reached terminal number of iteration # stop simulation if reached terminal number of iteration
if self.t == self.int_n_iter or len( if self.t == self.int_n_iter or len(
self.dct_list_remove_firm_prod) == 0: self.dct_list_remove_firm_prod) == 0:
self.int_stop_times = self.t
print(self.int_stop_times, self.t)
self.stop() self.stop()
def step(self): def step(self):
@ -251,7 +258,9 @@ class Model(ap.Model):
self.dct_list_disrupt_firm_prod = {} self.dct_list_disrupt_firm_prod = {}
for firm in self.a_list_total_firms: for firm in self.a_list_total_firms:
if len(firm.a_list_up_product_removed) > 0: if len(firm.a_list_up_product_removed) > 0:
print(firm.name, 'a_list_up_product_removed', [product.code for product in firm.a_list_up_product_removed]) print(firm.name, 'a_list_up_product_removed', [
product.code for product in firm.a_list_up_product_removed
])
for product in firm.a_list_product: for product in firm.a_list_product:
n_up_product_removed = 0 n_up_product_removed = 0
for up_product_removed in firm.a_list_up_product_removed: for up_product_removed in firm.a_list_up_product_removed:
@ -319,6 +328,38 @@ class Model(ap.Model):
for product in a_list_product: for product in a_list_product:
print(t, firm.name, product.code) print(t, firm.name, product.code)
qry_result = db_session.query(Result).filter_by(s_id=self.sample.id)
if qry_result.count() == 0:
lst_result_info = []
for t, dct in self.list_dct_list_disrupt_firm_prod:
for firm, a_list_product in dct.items():
for product in a_list_product:
# print(t, firm.name, product.code)
db_r = Result(s_id=self.sample.id,
id_firm=firm.code,
id_product=product.code,
ts=t,
is_disrupted=True)
lst_result_info.append(db_r)
db_session.bulk_save_objects(lst_result_info)
db_session.commit()
for t, dct in self.list_dct_list_remove_firm_prod:
for firm, a_list_product in dct.items():
for product in a_list_product:
# print(t, firm.name, product.code)
# only firm disrupted can be removed
qry_f_p = db_session.query(Result).filter(
Result.s_id == self.sample.id,
Result.id_firm == firm.code,
Result.id_product == product.code)
if qry_f_p.count() == 1:
qry_f_p.update({"is_removed": True})
db_session.commit()
self.sample.is_done_flag, self.sample.computer_name = 1, platform.node(
)
self.sample.stop_t = self.int_stop_times
db_session.commit()
def draw_network(self): def draw_network(self):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'SimHei' plt.rcParams['font.sans-serif'] = 'SimHei'

4
orm.py
View File

@ -77,8 +77,8 @@ class Result(Base):
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
s_id = Column(Integer, ForeignKey('{}.id'.format(f"{db_name_prefix}_sample")), nullable=False) s_id = Column(Integer, ForeignKey('{}.id'.format(f"{db_name_prefix}_sample")), nullable=False)
id_firm = Column(Integer, nullable=False) id_firm = Column(String(10), nullable=False)
id_product = Column(Integer, nullable=False) id_product = Column(String(10), nullable=False)
ts = Column(Integer, nullable=False) ts = Column(Integer, nullable=False)
is_disrupted = Column(Boolean, nullable=True) is_disrupted = Column(Boolean, nullable=True)
is_removed = Column(Boolean, nullable=True) is_removed = Column(Boolean, nullable=True)