IIabm/model.py

195 lines
7.9 KiB
Python
Raw Normal View History

2023-02-24 15:16:28 +08:00
import agentpy as ap
import pandas as pd
import numpy as np
import networkx as nx
2023-02-24 17:53:55 +08:00
from firm import FirmAgent
2023-02-24 15:16:28 +08:00
# Default parameters for a single sample run of Model.
sample = 0
seed = 0  # seed for the model's numpy Generator
n_iter = 3  # the model stops once its tick t reaches this value

# Products that start out removed, keyed by firm code:
# {firm_code: [product_code, ...]}
dct_list_init_remove_firm_prod = {0: ['1.4.4'], 2: ['1.1.3']}

# Parameter mapping handed to Model(...); read back via self.p.<name>.
dct_sample_para = dict(
    sample=sample,
    seed=seed,
    n_iter=n_iter,
    dct_list_init_remove_firm_prod=dct_list_init_remove_firm_prod,
)
2023-02-24 15:16:28 +08:00
class Model(ap.Model):
    """Firm network model driven by a product bill-of-materials (BOM) graph.

    ``setup`` reads ``BomNodes.csv``, ``BomCateNet.csv`` and
    ``Firm_amended.csv`` from the working directory, builds a product graph
    and a firm graph, and attaches one ``FirmAgent`` per firm.  Each tick,
    ``step`` asks every firm with removed products to drop the affected
    edges, and ``update`` re-collects the removed (firm, product) pairs and
    stops the run when none remain or ``n_iter`` is reached.
    """

    def setup(self):
        """Read parameters, build both graphs and create the firm agents.

        Parameters read from ``self.p``: ``sample``, ``seed``, ``n_iter`` and
        ``dct_list_init_remove_firm_prod`` (firm code -> list of product
        codes that start out removed).

        Raises:
            ValueError: if an initially removed firm code matches no firm,
                or a listed product is not among that firm's products.
        """
        self.sample = self.p.sample
        self.nprandom = np.random.default_rng(self.p.seed)
        self.dct_list_remove_firm_prod = self.p.dct_list_init_remove_firm_prod
        self.int_n_iter = int(self.p.n_iter)

        G_bom = self._build_bom_graph()
        G_Firm = self._build_firm_graph(G_bom)
        self.firm_network = ap.Network(self, G_Firm)
        self._init_firm_agents()
        self._init_removed_products()

    @staticmethod
    def _build_bom_graph():
        """Return the product graph with per-product attributes attached."""
        BomNodes = pd.read_csv('BomNodes.csv', index_col=0)
        BomNodes.set_index('Code', inplace=True)
        BomCateNet = pd.read_csv('BomCateNet.csv', index_col=0)
        BomCateNet.fillna(0, inplace=True)
        # from_pandas_adjacency adds an edge row -> column for each non-zero
        # cell; after the transpose, u -> v exists where BomCateNet.loc[v, u]
        # is non-zero.
        G_bom = nx.from_pandas_adjacency(BomCateNet.T,
                                         create_using=nx.MultiDiGraph())
        bom_labels_dict = {
            code: BomNodes.loc[code].to_dict() for code in G_bom.nodes
        }
        nx.set_node_attributes(G_bom, bom_labels_dict)
        return G_bom

    def _build_firm_graph(self, G_bom):
        """Return the firm graph with randomly drawn supplier edges.

        For every product a firm makes and every successor of that product
        in ``G_bom``, each firm making the successor gets an edge from the
        firm with probability increasing linearly in its ``Revenue_Log``
        (the largest revenue in the candidate set gets probability 1).
        Edges carry the originating product code as ``'Product'``.
        """
        Firm = pd.read_csv("Firm_amended.csv")
        Firm.fillna(0, inplace=True)
        Firm_attr = Firm.loc[:, ["Code", "Name", "Type_Region",
                                 "Revenue_Log"]].copy()
        # Columns from '1' onwards are product flags; a firm makes the
        # products whose flag equals 1.
        firm_product = [
            row[row == 1].index.to_list()
            for _, row in Firm.loc[:, '1':].iterrows()
        ]
        Firm_attr.loc[:, 'Product_Code'] = firm_product
        # Index by firm code (keeping the 'Code' column, which the agents
        # read) so lookups are by code rather than by row position.
        Firm_attr.set_index('Code', drop=False, inplace=True)

        G_Firm = nx.MultiDiGraph()
        G_Firm.add_nodes_from(Firm["Code"])
        firm_labels_dict = {
            code: Firm_attr.loc[code].to_dict() for code in G_Firm.nodes
        }
        nx.set_node_attributes(G_Firm, firm_labels_dict)

        # add edges to G_Firm according to G_bom
        for node in nx.nodes(G_Firm):
            for product_code in G_Firm.nodes[node]['Product_Code']:
                for succ_product_code in list(G_bom.successors(product_code)):
                    list_succ_firms = Firm.loc[Firm[succ_product_code] == 1,
                                               'Code'].to_list()
                    if not list_succ_firms:
                        continue
                    list_revenue_log = [
                        G_Firm.nodes[succ_firm]['Revenue_Log']
                        for succ_firm in list_succ_firms
                    ]
                    min_rev = min(list_revenue_log)
                    span = max(list_revenue_log) - min_rev + 1
                    list_prob = [(v - min_rev + 1) / span
                                 for v in list_revenue_log]
                    # One draw per candidate, in order, keeps runs
                    # reproducible for a given seed.
                    list_flag = [
                        self.nprandom.choice([1, 0], p=[prob, 1 - prob])
                        for prob in list_prob
                    ]
                    G_Firm.add_edges_from([
                        (node, succ_firm, {'Product': product_code})
                        for succ_firm, flag in zip(list_succ_firms, list_flag)
                        if flag == 1
                    ])
        return G_Firm

    def _init_firm_agents(self):
        """Create one FirmAgent per network node and index them by code."""
        graph = self.firm_network.graph
        for ag_node, attr in graph.nodes(data=True):
            firm_agent = FirmAgent(
                self,
                code=attr['Code'],
                name=attr['Name'],
                type_region=attr['Type_Region'],
                revenue_log=attr['Revenue_Log'],
                list_product=attr['Product_Code'],
                # init capacity as the degree of out edges
                capacity=graph.out_degree(ag_node))
            self.firm_network.add_agents([firm_agent], [ag_node])
        self.a_list_total_firms = ap.AgentList(self, self.firm_network.agents)
        # O(1) code -> agent lookup used by setup and step.
        self._dct_firm_by_code = {
            firm.code: firm for firm in self.a_list_total_firms
        }

    def _get_firm(self, firm_code):
        """Return the firm agent with ``firm_code``; ValueError if absent."""
        try:
            return self._dct_firm_by_code[firm_code]
        except KeyError:
            raise ValueError(f"firm {firm_code} not found") from None

    def _init_removed_products(self):
        """Flag the initially removed products on their firms."""
        for firm_code, list_product in self.dct_list_remove_firm_prod.items():
            firm = self._get_firm(firm_code)
            for product in list_product:
                # explicit check: an assert would vanish under `python -O`
                if product not in firm.list_product:
                    raise ValueError(
                        f"product {product} not in firm {firm_code}")
                firm.dct_product_is_removed[product] = True

    def update(self):
        """Re-collect removed (firm, product) pairs and decide whether to stop.

        Rebuilds ``dct_list_remove_firm_prod`` from every firm's
        ``dct_product_is_removed`` entries whose flag is ``True``, then stops
        the run when ``t`` reaches ``n_iter`` or nothing is removed.
        """
        self.dct_list_remove_firm_prod = {}
        for firm in self.a_list_total_firms:
            for product, flag in firm.dct_product_is_removed.items():
                if flag is True:
                    self.dct_list_remove_firm_prod.setdefault(
                        firm.code, []).append(product)
        # >= rather than == so a limit below the current tick still stops.
        if self.t >= self.int_n_iter or not self.dct_list_remove_firm_prod:
            self.stop()

    def step(self):
        """Process firms with removed products in a random order.

        The dict is rebuilt in shuffled key order (and stored back) before
        calling ``remove_edge_to_customer_if_removed`` for each product.
        """
        dct_key_list = list(self.dct_list_remove_firm_prod.keys())
        self.nprandom.shuffle(dct_key_list)
        self.dct_list_remove_firm_prod = {
            key: self.dct_list_remove_firm_prod[key]
            for key in dct_key_list
        }
        for firm_code, list_product in self.dct_list_remove_firm_prod.items():
            firm = self._get_firm(firm_code)
            for product in list_product:
                firm.remove_edge_to_customer_if_removed(product)

    def end(self):
        """No end-of-run processing."""
        pass

    def draw_network(self):
        """Render the firm network to ``network.png``.

        Node labels are "<Name> <out-degree>", node size is Revenue_Log
        squared, and edges are labelled with the product(s) they carry.
        """
        import matplotlib.pyplot as plt
        plt.rcParams['font.sans-serif'] = 'SimHei'
        graph = self.firm_network.graph
        pos = nx.nx_agraph.graphviz_layout(graph, prog="twopi", args="")
        node_name = nx.get_node_attributes(graph, 'Name')
        node_degree = dict(graph.out_degree())
        node_label = {
            key: f"{node_name[key]} {node_degree[key]}" for key in node_name
        }
        node_size = [
            v**2 for v in nx.get_node_attributes(graph, 'Revenue_Log').values()
        ]
        # Multi(di)graph edge keys are (u, v, k); labels are drawn per (u, v),
        # so join the products of parallel edges instead of letting the last
        # one overwrite the rest.
        dct_pair_products = {}
        for (n1, n2, _), product in nx.get_edge_attributes(graph,
                                                           "Product").items():
            dct_pair_products.setdefault((n1, n2), []).append(str(product))
        edge_label = {
            pair: ", ".join(dict.fromkeys(products))
            for pair, products in dct_pair_products.items()
        }
        fig = plt.figure(figsize=(12, 12), dpi=300)
        nx.draw(graph, pos, node_size=node_size, labels=node_label,
                font_size=6)
        nx.draw_networkx_edge_labels(graph, pos, edge_label, font_size=4)
        plt.savefig("network.png")
        # release the 12x12in/300dpi figure; pyplot keeps it alive otherwise
        plt.close(fig)
# Manual debug run: only when executed as a script, so that importing
# Model from this module does not read the CSVs and run a simulation.
if __name__ == "__main__":
    model = Model(dct_sample_para)
    model.setup()
    model.update()
    model.step()
    # model.draw_network()