# Listing metadata (from the code viewer): 247 lines, 10 KiB, Python
import agentpy as ap
|
|
import pandas as pd
|
|
import numpy as np
|
|
import networkx as nx
|
|
from firm import FirmAgent
|
|
from product import ProductAgent
|
|
|
|
# Run parameters for a single sample; they are bundled into
# ``dct_sample_para`` below and read back through ``self.p`` in the model.
sample, seed = 0, 0
n_iter = 3
n_max_trial = 2

# Maps a firm code to the list of product codes initially removed from it.
dct_list_init_remove_firm_prod = {
    0: ['1.4.4'],
    2: ['1.1.3'],
}

# Parameter dictionary handed to the model constructor (key order kept).
dct_sample_para = dict(
    sample=sample,
    seed=seed,
    n_iter=n_iter,
    n_max_trial=n_max_trial,
    dct_list_init_remove_firm_prod=dct_list_init_remove_firm_prod,
)
|
|
|
|
|
|
class Model(ap.Model):
    """Model of product removals on a firm network derived from a BOM graph.

    ``setup`` builds a product graph from ``BomNodes.csv`` /
    ``BomCateNet.csv`` and a firm graph from ``Firm_amended.csv``, creates
    one ``ProductAgent`` per product node and one ``FirmAgent`` per firm
    node, and marks the initially removed (firm, product) pairs. ``step``
    then lets firms drop edges for removed products and runs a bounded
    number of trials of ``seek_alt_supply`` / ``handle_request``.

    Parameters read from ``self.p``: ``sample``, ``seed``, ``n_iter``,
    ``n_max_trial`` and ``dct_list_init_remove_firm_prod``.
    """

    def setup(self):
        """Read parameters, build both networks and create all agents."""
        self.sample = self.p.sample
        self.nprandom = np.random.default_rng(self.p.seed)
        self.int_n_iter = int(self.p.n_iter)
        self.int_n_max_trial = int(self.p.n_max_trial)
        # Starts as {firm code: [product code, ...]}; converted to agents
        # by _resolve_initial_removals() at the end of setup.
        self.dct_list_remove_firm_prod = self.p.dct_list_init_remove_firm_prod

        G_bom = self._load_bom_graph()
        G_Firm, Firm = self._load_firm_graph()
        self._add_supply_edges(G_Firm, G_bom, Firm)

        self.firm_network = ap.Network(self, G_Firm)
        self.product_network = ap.Network(self, G_bom)

        # Products first: firm agents are built from the product agent list.
        self._create_product_agents()
        self._create_firm_agents()
        self._resolve_initial_removals()

    @staticmethod
    def _load_bom_graph():
        """Return the product MultiDiGraph with BomNodes rows as node data."""
        BomNodes = pd.read_csv('BomNodes.csv', index_col=0)
        BomNodes.set_index('Code', inplace=True)
        BomCateNet = pd.read_csv('BomCateNet.csv', index_col=0)
        BomCateNet.fillna(0, inplace=True)

        # The adjacency matrix is transposed before the graph is built.
        G_bom = nx.from_pandas_adjacency(BomCateNet.T,
                                         create_using=nx.MultiDiGraph())
        nx.set_node_attributes(
            G_bom,
            {code: BomNodes.loc[code].to_dict() for code in G_bom.nodes})
        return G_bom

    @staticmethod
    def _load_firm_graph():
        """Return ``(G_Firm, Firm)``: an edgeless firm graph and the raw table.

        Each node carries ``Code``, ``Name``, ``Type_Region``,
        ``Revenue_Log`` and ``Product_Code`` (the list of product columns,
        from column ``'1'`` onwards, whose value is 1 for that firm).
        """
        Firm = pd.read_csv("Firm_amended.csv")
        Firm.fillna(0, inplace=True)

        # Explicit copy so the new column is not written into a slice of
        # ``Firm``.
        Firm_attr = Firm.loc[
            :, ["Code", "Name", "Type_Region", "Revenue_Log"]].copy()
        Firm_attr['Product_Code'] = [
            row[row == 1].index.to_list()
            for _, row in Firm.loc[:, '1':].iterrows()
        ]
        # Index by firm code so the ``.loc[code]`` lookups below are by code
        # rather than by row position. ``drop=False`` keeps the 'Code'
        # column because the node attribute ``attr['Code']`` is read when
        # firm agents are created.
        # NOTE(review): assumes firm codes are unique - confirm.
        Firm_attr.set_index('Code', drop=False, inplace=True)

        G_Firm = nx.MultiDiGraph()
        G_Firm.add_nodes_from(Firm["Code"])
        nx.set_node_attributes(
            G_Firm,
            {code: Firm_attr.loc[code].to_dict() for code in G_Firm.nodes})
        return G_Firm, Firm

    def _add_supply_edges(self, G_Firm, G_bom, Firm):
        """Randomly add firm-to-firm edges following successors in ``G_bom``.

        For every product of every firm and every successor product in
        ``G_bom``, each firm flagged for the successor product receives an
        edge from the current firm with probability
        ``(rev - min + 1) / (max - min + 1)``, where ``rev`` is its
        ``Revenue_Log`` and min/max are taken over the candidate firms.
        Edges carry the attribute ``Product`` set to the current firm's
        product code.
        """
        for node in nx.nodes(G_Firm):
            for product_code in G_Firm.nodes[node]['Product_Code']:
                for succ_product_code in list(G_bom.successors(product_code)):
                    # Select candidates by their Code value (not by row
                    # position) so they match the graph's node keys.
                    list_succ_firms = Firm.loc[
                        Firm[succ_product_code] == 1, 'Code'].to_list()
                    if not list_succ_firms:
                        # Nothing to link; also avoids min()/max() on an
                        # empty list.
                        continue

                    list_revenue_log = [
                        G_Firm.nodes[succ_firm]['Revenue_Log']
                        for succ_firm in list_succ_firms
                    ]
                    # Hoisted: the bounds are the same for every candidate.
                    rev_min = min(list_revenue_log)
                    rev_max = max(list_revenue_log)

                    list_added_edges = []
                    for succ_firm, rev in zip(list_succ_firms,
                                              list_revenue_log):
                        prob = (rev - rev_min + 1) / (rev_max - rev_min + 1)
                        # One draw per candidate, in candidate order.
                        flag = self.nprandom.choice([1, 0],
                                                    p=[prob, 1 - prob])
                        if flag == 1:
                            list_added_edges.append(
                                (node, succ_firm, {'Product': product_code}))
                    G_Firm.add_edges_from(list_added_edges)

    def _create_product_agents(self):
        """Attach one ProductAgent to every node of the product network."""
        for ag_node, attr in self.product_network.graph.nodes(data=True):
            product_agent = ProductAgent(self,
                                         code=ag_node.label,
                                         name=attr['Name'])
            self.product_network.add_agents([product_agent], [ag_node])
        self.a_list_total_products = ap.AgentList(
            self, self.product_network.agents)

    def _create_firm_agents(self):
        """Attach one FirmAgent to every firm node and set its capacities."""
        for ag_node, attr in self.firm_network.graph.nodes(data=True):
            firm_agent = FirmAgent(
                self,
                code=attr['Code'],
                name=attr['Name'],
                type_region=attr['Type_Region'],
                revenue_log=attr['Revenue_Log'],
                a_list_product=self.a_list_total_products.select([
                    code in attr['Product_Code']
                    for code in self.a_list_total_products.code
                ]))

            # Capacity of a product = number of this firm's outgoing edges
            # whose 'Product' attribute equals that product's code.
            list_edge_products = [
                edge[-1]
                for edge in self.firm_network.graph.out_edges(
                    ag_node, keys=True, data='Product')
            ]
            for product in firm_agent.a_list_product:
                firm_agent.dct_prod_capacity[product] = \
                    list_edge_products.count(product.code)

            self.firm_network.add_agents([firm_agent], [ag_node])
        self.a_list_total_firms = ap.AgentList(self, self.firm_network.agents)

    def _resolve_initial_removals(self):
        """Convert the removal parameter from codes to agents and apply it.

        Raises:
            IndexError: if a firm code in the parameter matches no firm.
            AssertionError: if a listed product is not among the firm's
                products.
        """
        t_dct = {}
        for firm_code, list_product in self.dct_list_remove_firm_prod.items():
            firm = self.a_list_total_firms.select(
                self.a_list_total_firms.code == firm_code)[0]
            t_dct[firm] = self.a_list_total_products.select([
                code in list_product
                for code in self.a_list_total_products.code
            ])
        self.dct_list_remove_firm_prod = t_dct

        # Record the initially removed products on each firm.
        for firm, a_list_product in self.dct_list_remove_firm_prod.items():
            for product in a_list_product:
                assert product in firm.a_list_product, \
                    f"product {product.code} not in firm {firm.code}"
                firm.a_list_product_removed.append(product)

    def update(self):
        """Rebuild the removal map from firm state and stop when finished.

        The run stops once ``self.t`` equals the configured iteration count
        or no firm has any removed product left.
        """
        self.dct_list_remove_firm_prod = {}
        for firm in self.a_list_total_firms:
            if len(firm.a_list_product_removed) > 0:
                self.dct_list_remove_firm_prod[
                    firm] = firm.a_list_product_removed

        if self.t == self.int_n_iter or len(
                self.dct_list_remove_firm_prod) == 0:
            self.stop()

    def step(self):
        """Process removals in random order, then run the trial rounds."""
        # Randomise firm order with the model's numpy generator, and each
        # firm's product order with the list's own shuffle().
        # NOTE(review): relies on shuffle() returning the list (agentpy's
        # AgentList.shuffle does); a plain list would store None here.
        dct_key_list = list(self.dct_list_remove_firm_prod.keys())
        self.nprandom.shuffle(dct_key_list)
        self.dct_list_remove_firm_prod = {
            key: self.dct_list_remove_firm_prod[key].shuffle()
            for key in dct_key_list
        }

        # Each firm handles the removal of its products (FirmAgent method).
        for firm, a_list_product in self.dct_list_remove_firm_prod.items():
            for product in a_list_product:
                firm.remove_edge_to_cus_and_cus_up_prod(product)

        for n_trial in range(self.int_n_max_trial):
            print('='*20, n_trial, '='*20)

            # Firms with removed upstream products look for alternatives.
            for firm in self.a_list_total_firms:
                if len(firm.a_list_up_product_removed) > 0:
                    firm.seek_alt_supply()

            # Firms that received requests process them.
            for firm in self.a_list_total_firms:
                if len(firm.dct_request_prod_from_firm) > 0:
                    firm.handle_request()

            # Per-trial reset is delegated to each agent. Do not replace it
            # with ``self.a_list_total_firms.dct_request_prod_from_firm = {}``
            # - presumably that would hand the same dict object to every
            # firm (TODO confirm against agentpy's AgentList attribute
            # setting).
            self.a_list_total_firms.clean_before_trial()

    def end(self):
        """No end-of-run processing."""
        pass

    def draw_network(self):
        """Draw the firm network and save it as ``network.png``.

        Nodes are labelled "<Name> <out-degree>" and sized by
        ``Revenue_Log`` squared; edges are labelled with their ``Product``
        attribute. Layout comes from Graphviz ``twopi``.
        """
        import matplotlib.pyplot as plt

        graph = self.firm_network.graph
        plt.rcParams['font.sans-serif'] = 'SimHei'
        pos = nx.nx_agraph.graphviz_layout(graph, prog="twopi", args="")

        node_name = nx.get_node_attributes(graph, 'Name')
        node_degree = dict(graph.out_degree())
        node_label = {
            key: f"{name} {node_degree[key]}"
            for key, name in node_name.items()
        }
        node_size = [
            x**2
            for x in nx.get_node_attributes(graph, 'Revenue_Log').values()
        ]

        # On a multigraph the keys are (u, v, key) 3-tuples; dropping the
        # key keeps only the last label seen for each (u, v) pair.
        edge_label = {
            (n1, n2): label
            for (n1, n2, _), label in nx.get_edge_attributes(
                graph, "Product").items()
        }

        plt.figure(figsize=(12, 12), dpi=300)
        nx.draw(graph,
                pos,
                node_size=node_size,
                labels=node_label,
                font_size=6)
        nx.draw_networkx_edge_labels(graph, pos, edge_label, font_size=4)
        plt.savefig("network.png")
        # Release the large figure once it has been written to disk.
        plt.close()
|
|
|
|
|
|
# Build the model from the sample parameters and drive one setup/update/step
# cycle by hand, in that order.
model = Model(dct_sample_para)
for _phase in (model.setup, model.update, model.step):
    _phase()
# Optional: model.draw_network() saves the firm network to network.png.
|