database
This commit is contained in:
parent
09b59d8778
commit
b935f0ce79
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -45,12 +45,15 @@ class ControllerDB:
|
|||
Firm['Code'] = Firm['Code'].astype('string')
|
||||
Firm.fillna(0, inplace=True)
|
||||
list_dct = []
|
||||
for _, row in Firm.iterrows():
|
||||
code = row['Code']
|
||||
row = row['1':]
|
||||
for product_code in row.index[row == 1].to_list():
|
||||
dct = {code: [product_code]}
|
||||
list_dct.append(dct)
|
||||
# for _, row in Firm.iterrows():
|
||||
# code = row['Code']
|
||||
# row = row['1':]
|
||||
# for product_code in row.index[row == 1].to_list():
|
||||
# dct = {code: [product_code]}
|
||||
# list_dct.append(dct)
|
||||
# # break
|
||||
# # break
|
||||
list_dct = [{'140': ['1.4.5.1']}]
|
||||
for idx_exp, dct in enumerate(list_dct):
|
||||
self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'],
|
||||
dct)
|
||||
|
|
51
model.py
51
model.py
|
@ -5,6 +5,8 @@ import random
|
|||
import networkx as nx
|
||||
from firm import FirmAgent
|
||||
from product import ProductAgent
|
||||
from orm import db_session, Sample, Result
|
||||
import platform
|
||||
|
||||
sample = 0
|
||||
seed = 0
|
||||
|
@ -34,6 +36,7 @@ dct_sample_para = {
|
|||
class Model(ap.Model):
|
||||
def setup(self):
|
||||
self.sample = self.p.sample
|
||||
self.int_stop_times, self.int_stop_t = 0, None
|
||||
self.random = random.Random(self.p.seed)
|
||||
self.nprandom = np.random.default_rng(self.p.seed)
|
||||
self.int_n_iter = int(self.p.n_iter)
|
||||
|
@ -80,7 +83,7 @@ class Model(ap.Model):
|
|||
for succ_product_code in list(G_bom.successors(product_code)):
|
||||
# print(succ_product_code)
|
||||
list_succ_firms = Firm['Code'][Firm[succ_product_code] ==
|
||||
1].to_list()
|
||||
1].to_list()
|
||||
list_revenue_log = [
|
||||
G_Firm.nodes[succ_firm]['Revenue_Log']
|
||||
for succ_firm in list_succ_firms
|
||||
|
@ -191,12 +194,16 @@ class Model(ap.Model):
|
|||
def update(self):
|
||||
self.a_list_total_firms.clean_before_time_step()
|
||||
# output
|
||||
self.list_dct_list_remove_firm_prod.append((self.t, self.dct_list_remove_firm_prod))
|
||||
self.list_dct_list_disrupt_firm_prod.append((self.t, self.dct_list_disrupt_firm_prod))
|
||||
self.list_dct_list_remove_firm_prod.append(
|
||||
(self.t, self.dct_list_remove_firm_prod))
|
||||
self.list_dct_list_disrupt_firm_prod.append(
|
||||
(self.t, self.dct_list_disrupt_firm_prod))
|
||||
|
||||
# stop simulation if reached terminal number of iteration
|
||||
if self.t == self.int_n_iter or len(
|
||||
self.dct_list_remove_firm_prod) == 0:
|
||||
self.int_stop_times = self.t
|
||||
print(self.int_stop_times, self.t)
|
||||
self.stop()
|
||||
|
||||
def step(self):
|
||||
|
@ -251,7 +258,9 @@ class Model(ap.Model):
|
|||
self.dct_list_disrupt_firm_prod = {}
|
||||
for firm in self.a_list_total_firms:
|
||||
if len(firm.a_list_up_product_removed) > 0:
|
||||
print(firm.name, 'a_list_up_product_removed', [product.code for product in firm.a_list_up_product_removed])
|
||||
print(firm.name, 'a_list_up_product_removed', [
|
||||
product.code for product in firm.a_list_up_product_removed
|
||||
])
|
||||
for product in firm.a_list_product:
|
||||
n_up_product_removed = 0
|
||||
for up_product_removed in firm.a_list_up_product_removed:
|
||||
|
@ -307,7 +316,7 @@ class Model(ap.Model):
|
|||
})
|
||||
|
||||
def end(self):
|
||||
print('/'*20, 'output', '/'*20)
|
||||
print('/' * 20, 'output', '/' * 20)
|
||||
print('dct_list_remove_firm_prod')
|
||||
for t, dct in self.list_dct_list_remove_firm_prod:
|
||||
for firm, a_list_product in dct.items():
|
||||
|
@ -319,6 +328,38 @@ class Model(ap.Model):
|
|||
for product in a_list_product:
|
||||
print(t, firm.name, product.code)
|
||||
|
||||
qry_result = db_session.query(Result).filter_by(s_id=self.sample.id)
|
||||
if qry_result.count() == 0:
|
||||
lst_result_info = []
|
||||
for t, dct in self.list_dct_list_disrupt_firm_prod:
|
||||
for firm, a_list_product in dct.items():
|
||||
for product in a_list_product:
|
||||
# print(t, firm.name, product.code)
|
||||
db_r = Result(s_id=self.sample.id,
|
||||
id_firm=firm.code,
|
||||
id_product=product.code,
|
||||
ts=t,
|
||||
is_disrupted=True)
|
||||
lst_result_info.append(db_r)
|
||||
db_session.bulk_save_objects(lst_result_info)
|
||||
db_session.commit()
|
||||
for t, dct in self.list_dct_list_remove_firm_prod:
|
||||
for firm, a_list_product in dct.items():
|
||||
for product in a_list_product:
|
||||
# print(t, firm.name, product.code)
|
||||
# only firm disrupted can be removed
|
||||
qry_f_p = db_session.query(Result).filter(
|
||||
Result.s_id == self.sample.id,
|
||||
Result.id_firm == firm.code,
|
||||
Result.id_product == product.code)
|
||||
if qry_f_p.count() == 1:
|
||||
qry_f_p.update({"is_removed": True})
|
||||
db_session.commit()
|
||||
self.sample.is_done_flag, self.sample.computer_name = 1, platform.node(
|
||||
)
|
||||
self.sample.stop_t = self.int_stop_times
|
||||
db_session.commit()
|
||||
|
||||
def draw_network(self):
|
||||
import matplotlib.pyplot as plt
|
||||
plt.rcParams['font.sans-serif'] = 'SimHei'
|
||||
|
|
4
orm.py
4
orm.py
|
@ -77,8 +77,8 @@ class Result(Base):
|
|||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
s_id = Column(Integer, ForeignKey('{}.id'.format(f"{db_name_prefix}_sample")), nullable=False)
|
||||
|
||||
id_firm = Column(Integer, nullable=False)
|
||||
id_product = Column(Integer, nullable=False)
|
||||
id_firm = Column(String(10), nullable=False)
|
||||
id_product = Column(String(10), nullable=False)
|
||||
ts = Column(Integer, nullable=False)
|
||||
is_disrupted = Column(Boolean, nullable=True)
|
||||
is_removed = Column(Boolean, nullable=True)
|
||||
|
|
Loading…
Reference in New Issue