IIabm/model.py

325 lines
15 KiB
Python
Raw Normal View History

2023-02-24 15:16:28 +08:00
import agentpy as ap
import pandas as pd
import numpy as np
2023-03-06 22:38:57 +08:00
import random
2023-02-24 15:16:28 +08:00
import networkx as nx
2023-02-24 17:53:55 +08:00
from firm import FirmAgent
2023-02-27 22:02:46 +08:00
from product import ProductAgent
2023-03-14 18:53:00 +08:00
from orm import db_session, Result
2023-03-13 21:53:50 +08:00
import platform
2023-02-24 15:16:28 +08:00
class Model(ap.Model):
    """Agent-based model of product removals spreading through a firm network.

    ``setup`` builds a product (BOM) graph and a firm graph from CSV files,
    wraps them in agentpy networks and applies the initial removals given in
    ``self.p.dct_lst_init_remove_firm_prod``.  ``step`` lets affected firms
    look for alternative suppliers and then decides which products become
    disrupted / removed.  ``update`` records per-step snapshots and ``end``
    prints them and writes them to the database.

    Parameters read from ``self.p``: ``sample``, ``seed``, ``n_iter``,
    ``n_max_trial`` and ``dct_lst_init_remove_firm_prod``.
    """

    def setup(self):
        """Build both networks, create the agents and apply the initial shock.

        Reads ``BomNodes.csv``, ``BomCateNet.csv`` and ``Firm_amended.csv``
        from the current working directory.
        """
        self.sample = self.p.sample
        self.int_stop_times, self.int_stop_t = 0, None
        self.random = random.Random(self.p.seed)
        self.nprandom = np.random.default_rng(self.p.seed)
        self.int_n_iter = int(self.p.n_iter)
        self.int_n_max_trial = int(self.p.n_max_trial)
        # Still keyed by firm code (strings) here; converted to agents below.
        self.dct_lst_remove_firm_prod = self.p.dct_lst_init_remove_firm_prod

        # init graph bom
        bom_nodes = pd.read_csv('BomNodes.csv', index_col=0)
        bom_nodes.set_index('Code', inplace=True)
        bom_cate_net = pd.read_csv('BomCateNet.csv', index_col=0)
        bom_cate_net.fillna(0, inplace=True)
        g_bom = nx.from_pandas_adjacency(bom_cate_net.T,
                                         create_using=nx.MultiDiGraph())
        bom_labels_dict = {
            code: bom_nodes.loc[code].to_dict()
            for code in g_bom.nodes
        }
        nx.set_node_attributes(g_bom, bom_labels_dict)

        # init graph firm
        firm_df = pd.read_csv("Firm_amended.csv")
        firm_df['Code'] = firm_df['Code'].astype('string')
        firm_df.fillna(0, inplace=True)
        firm_attr = firm_df.loc[:, [
            "Code", "Name", "Type_Region", "Revenue_Log"
        ]]
        # Columns from '1' onwards are per-product flags; a value of 1 marks
        # a product code the firm produces.
        firm_product = [
            row[row == 1].index.to_list()
            for _, row in firm_df.loc[:, '1':].iterrows()
        ]
        firm_attr.loc[:, 'Product_Code'] = firm_product
        firm_attr.set_index('Code', inplace=True)
        g_firm = nx.MultiDiGraph()
        g_firm.add_nodes_from(firm_df["Code"])
        firm_labels_dict = {
            code: firm_attr.loc[code].to_dict()
            for code in g_firm.nodes
        }
        nx.set_node_attributes(g_firm, firm_labels_dict)

        # Producers of a product do not change while edges are added, so the
        # DataFrame filter is evaluated once per product code.
        dct_producers = {}

        def producers_of(code):
            """Return (firm codes, Revenue_Log values) of firms making code."""
            if code not in dct_producers:
                lst_firm = firm_df['Code'][firm_df[code] == 1].to_list()
                lst_size = [
                    g_firm.nodes[f]['Revenue_Log'] for f in lst_firm
                ]
                dct_producers[code] = (lst_firm, lst_size)
            return dct_producers[code]

        # add edge to G_firm according to G_bom
        for node in nx.nodes(g_firm):
            for product_code in g_firm.nodes[node]['Product_Code']:
                for succ_product_code in list(
                        g_bom.successors(product_code)):
                    # for each product of a certain firm
                    # get each successor (finished product) of this product
                    # get a list of firm producing this successor
                    lst_succ_firm, lst_succ_firm_size = producers_of(
                        succ_product_code)
                    if not lst_succ_firm:
                        # Nobody produces this successor: there is no firm
                        # to link to, and sampling from an empty population
                        # would raise.
                        continue
                    total_succ_size = sum(lst_succ_firm_size)
                    lst_prob = [
                        size / total_succ_size
                        for size in lst_succ_firm_size
                    ]
                    # select multiple successors based on relative size of
                    # this firm among all firms making the same product
                    _, lst_same_prod_firm_size = producers_of(product_code)
                    share = g_firm.nodes[node]['Revenue_Log'] / sum(
                        lst_same_prod_firm_size)
                    # at least one
                    n_succ_firm = max(1, round(share * len(lst_succ_firm)))
                    lst_choose_firm = self.nprandom.choice(lst_succ_firm,
                                                           n_succ_firm,
                                                           p=lst_prob)
                    # nprandom.choice may have duplicates.  dict.fromkeys
                    # keeps first-seen order, whereas a set of strings is
                    # ordered by per-process randomised hashes and would make
                    # the edge order differ between runs with the same seed.
                    lst_choose_firm = list(dict.fromkeys(lst_choose_firm))
                    g_firm.add_edges_from([(node, succ_firm, {
                        'Product': product_code
                    }) for succ_firm in lst_choose_firm])

        self.firm_network = ap.Network(self, g_firm)
        self.product_network = ap.Network(self, g_bom)

        # init product
        for ag_node, attr in self.product_network.graph.nodes(data=True):
            product = ProductAgent(self,
                                   code=ag_node.label,
                                   name=attr['Name'])
            self.product_network.add_agents([product], [ag_node])
        self.a_lst_total_products = ap.AgentList(
            self, self.product_network.agents)

        # init firm
        for ag_node, attr in self.firm_network.graph.nodes(data=True):
            firm_agent = FirmAgent(
                self,
                code=ag_node.label,
                name=attr['Name'],
                type_region=attr['Type_Region'],
                revenue_log=attr['Revenue_Log'],
                a_lst_product=self.a_lst_total_products.select([
                    code in attr['Product_Code']
                    for code in self.a_lst_total_products.code
                ]))
            # init extra capacity based on discrete uniform distribution
            for product in firm_agent.a_lst_product:
                firm_agent.dct_prod_capacity[
                    product] = self.nprandom.integers(
                        firm_agent.revenue_log / 5,
                        firm_agent.revenue_log / 5 + 2)
            self.firm_network.add_agents([firm_agent], [ag_node])
        self.a_lst_total_firms = ap.AgentList(self,
                                              self.firm_network.agents)

        # init dct_list_remove_firm_prod (from string to agent)
        t_dct = {}
        for firm_code, lst_product in self.dct_lst_remove_firm_prod.items():
            firm = self.a_lst_total_firms.select(
                self.a_lst_total_firms.code == firm_code)[0]
            t_dct[firm] = self.a_lst_total_products.select([
                code in lst_product
                for code in self.a_lst_total_products.code
            ])
        self.dct_lst_remove_firm_prod = t_dct
        # Shallow copy so the two dicts are never the same object.
        self.dct_lst_disrupt_firm_prod = dict(t_dct)

        # init output
        self.lst_dct_lst_remove_firm_prod = []
        self.lst_dct_lst_disrupt_firm_prod = []

        # set the initial firm product that are removed
        for firm, a_lst_product in self.dct_lst_remove_firm_prod.items():
            for product in a_lst_product:
                assert product in firm.a_lst_product, \
                    f"product {product.code} not in firm {firm.code}"
                firm.a_lst_product_removed.append(product)

        # draw network: call self.draw_network() here when a plot is wanted

    def update(self):
        """Record this step's removed/disrupted snapshots and check stopping.

        The simulation stops when ``self.t`` equals ``n_iter`` or when no
        firm-product pair is currently marked as removed.
        """
        self.a_lst_total_firms.clean_before_time_step()

        # output
        self.lst_dct_lst_remove_firm_prod.append(
            (self.t, self.dct_lst_remove_firm_prod))
        self.lst_dct_lst_disrupt_firm_prod.append(
            (self.t, self.dct_lst_disrupt_firm_prod))

        # stop simulation if reached terminal number of iteration
        if self.t == self.int_n_iter or not self.dct_lst_remove_firm_prod:
            self.int_stop_times = self.t
            self.stop()

    def _add_firm_product(self, dct, firm, product):
        """Append product to dct[firm], creating the AgentList if needed."""
        if firm in dct:
            dct[firm].append(product)
        else:
            dct[firm] = ap.AgentList(self.model, [product])

    def step(self):
        """Propagate current removals and derive the next disrupted/removed sets.

        1. Each firm with a removed product cuts the matching customer edges.
        2. For up to ``n_max_trial`` rounds, firms missing an upstream product
           seek alternative supply and requested firms handle the requests.
        3. For every product whose upstream input is still missing, the
           product is marked disrupted and, with a size-dependent
           probability, removed.
        """
        print('\n', '=' * 20, 'step', self.t, '=' * 20)
        print(
            'dct_list_remove_firm_prod', {
                key.name: value.code
                for key, value in self.dct_lst_remove_firm_prod.items()
            })

        # remove_edge_to_cus_and_cus_up_prod
        for firm, a_lst_product in self.dct_lst_remove_firm_prod.items():
            for product in a_lst_product:
                firm.remove_edge_to_cus_remove_cus_up_prod(product)

        for n_trial in range(self.int_n_max_trial):
            print('=' * 10, 'trial', n_trial, '=' * 10)
            # seek_alt_supply
            # shuffle self.a_lst_total_firms
            self.a_lst_total_firms = self.a_lst_total_firms.shuffle()
            for firm in self.a_lst_total_firms:
                if len(firm.a_lst_up_product_removed) > 0:
                    firm.seek_alt_supply()

            # handle_request
            # shuffle self.a_lst_total_firms
            self.a_lst_total_firms = self.a_lst_total_firms.shuffle()
            for firm in self.a_lst_total_firms:
                if len(firm.dct_request_prod_from_firm) > 0:
                    firm.handle_request()

            # reset dct_request_prod_from_firm
            self.a_lst_total_firms.clean_before_trial()
            # do not use:
            # self.a_lst_total_firms.dct_request_prod_from_firm = {} why?

        # based on a_lst_up_product_removed
        # update a_lst_product_disrupted / a_lst_product_removed
        # update dct_lst_disrupt_firm_prod / dct_lst_remove_firm_prod
        self.dct_lst_remove_firm_prod = {}
        self.dct_lst_disrupt_firm_prod = {}

        # Size range over all firms; nothing in the loop below touches
        # revenue_log, so it is computed once instead of once per product.
        lst_size = list(self.a_lst_total_firms.revenue_log)
        size_min, size_max = ((min(lst_size), max(lst_size))
                              if lst_size else (0, 0))

        for firm in self.a_lst_total_firms:
            if len(firm.a_lst_up_product_removed) == 0:
                continue
            print(firm.name, 'a_lst_up_product_removed', [
                product.code for product in firm.a_lst_up_product_removed
            ])
            for product in firm.a_lst_product:
                # number of missing upstream products that feed this product
                n_up_product_removed = sum(
                    product in up_product_removed.a_successors()
                    for up_product_removed in firm.a_lst_up_product_removed)
                if n_up_product_removed == 0:
                    continue

                # update a_lst_product_disrupted / dct_lst_disrupt_firm_prod
                # NOTE(review): the dict update is taken to be nested under
                # this check (only newly disrupted products are recorded);
                # confirm against the repository.
                if product not in firm.a_lst_product_disrupted:
                    firm.a_lst_product_disrupted.append(product)
                    self._add_firm_product(self.dct_lst_disrupt_firm_prod,
                                           firm, product)

                # update a_lst_product_removed / dct_list_remove_firm_prod
                # mark disrupted firm as removed based conditionally
                lost_percent = n_up_product_removed / len(
                    product.a_predecessors())
                # firm size rescaled against the smallest and largest firm;
                # the +1 terms keep the result above zero
                std_size = (firm.revenue_log - size_min +
                            1) / (size_max - size_min + 1)
                prod_remove = 1 - std_size * (1 - lost_percent)
                if self.nprandom.choice([True, False],
                                        p=[prod_remove, 1 - prod_remove]):
                    firm.a_lst_product_removed.append(product)
                    self._add_firm_product(self.dct_lst_remove_firm_prod,
                                           firm, product)

        print(
            'dct_list_remove_firm_prod', {
                key.name: value.code
                for key, value in self.dct_lst_remove_firm_prod.items()
            })

    @staticmethod
    def _print_firm_products(title, lst_t_dct):
        """Print title, then one '(t, firm name, product code)' line each."""
        print(title)
        for t, dct in lst_t_dct:
            for firm, a_lst_product in dct.items():
                for product in a_lst_product:
                    print(t, firm.name, product.code)

    def end(self):
        """Print the recorded history and store it as ``Result`` rows.

        Rows are written only when the sample has no ``Result`` rows yet:
        one row per recorded disrupted (step, firm, product), after which
        rows matching a removed (firm, product) get ``is_removed`` set.
        """
        print('/' * 20, 'output', '/' * 20)
        self._print_firm_products('dct_list_remove_firm_prod',
                                  self.lst_dct_lst_remove_firm_prod)
        self._print_firm_products('dct_lst_disrupt_firm_prod',
                                  self.lst_dct_lst_disrupt_firm_prod)

        qry_result = db_session.query(Result).filter_by(s_id=self.sample.id)
        if qry_result.count() == 0:
            lst_result_info = [
                Result(s_id=self.sample.id,
                       id_firm=firm.code,
                       id_product=product.code,
                       ts=t,
                       is_disrupted=True)
                for t, dct in self.lst_dct_lst_disrupt_firm_prod
                for firm, a_lst_product in dct.items()
                for product in a_lst_product
            ]
            db_session.bulk_save_objects(lst_result_info)
            db_session.commit()

            for t, dct in self.lst_dct_lst_remove_firm_prod:
                for firm, a_lst_product in dct.items():
                    for product in a_lst_product:
                        # only firm disrupted can be removed theoretically
                        qry_f_p = db_session.query(Result).filter(
                            Result.s_id == self.sample.id,
                            Result.id_firm == firm.code,
                            Result.id_product == product.code)
                        # The flag is set only when exactly one row matches.
                        if qry_f_p.count() == 1:
                            qry_f_p.update({"is_removed": True})
                            db_session.commit()

            # NOTE(review): marking the sample as done is taken to be nested
            # under the "no results yet" guard; confirm against the
            # repository.
            self.sample.is_done_flag = 1
            self.sample.computer_name = platform.node()
            self.sample.stop_t = self.int_stop_times
            db_session.commit()

    def draw_network(self):
        """Draw the firm network with graphviz and save it as ``network.png``.

        Nodes are labelled with firm name and out-degree and sized by the
        square of ``Revenue_Log``; edges are labelled with their ``Product``.
        """
        import matplotlib.pyplot as plt
        plt.rcParams['font.sans-serif'] = 'SimHei'
        graph = self.firm_network.graph
        pos = nx.nx_agraph.graphviz_layout(graph, prog="twopi", args="")

        node_name = nx.get_node_attributes(graph, 'Name')
        node_degree = dict(graph.out_degree())
        node_label = {
            key: f"{name} {node_degree[key]}"
            for key, name in node_name.items()
        }
        node_size = [
            size**2
            for size in nx.get_node_attributes(graph,
                                               'Revenue_Log').values()
        ]
        # multi(di)graphs, the keys are 3-tuples; parallel edges therefore
        # share one (n1, n2) label and the last one wins
        edge_label = {
            (n1, n2): label
            for (n1, n2, _), label in nx.get_edge_attributes(
                graph, "Product").items()
        }

        plt.figure(figsize=(12, 12), dpi=300)
        nx.draw(graph,
                pos,
                node_size=node_size,
                labels=node_label,
                font_size=6)
        nx.draw_networkx_edge_labels(graph, pos, edge_label, font_size=4)
        plt.savefig("network.png")