mesa/model.py

135 lines
4.7 KiB
Python
Raw Normal View History

2024-08-24 11:20:13 +08:00
import json
import networkx as nx
import pandas as pd
from mesa import Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector
from firm import FirmAgent
from product import ProductAgent
2024-08-24 11:32:09 +08:00
from scheduler import CustomScheduler
2024-08-24 11:20:13 +08:00
class MyModel(Model):
    """Mesa model of firms and products linked through a bill of materials (BOM).

    Builds three networkx graphs (product BOM network, firm network, firm-product
    network), creates one ``ProductAgent`` per BOM node and one ``FirmAgent`` per
    firm, and applies the configured initial firm-product disruptions.

    Args:
        params: Mapping of run parameters. Keys read here: ``'N'``, ``'width'``,
            ``'height'``, ``'sample'``, ``'n_iter'``,
            ``'dct_lst_init_disrupt_firm_prod'``, ``'n_max_trial'``,
            ``'prf_size'``, ``'remove_t'``, ``'netw_prf_n'`` and ``'g_bom'``
            (a JSON string in networkx adjacency format; every node needs a
            ``'Name'`` attribute, which ``initialize_agents`` reads).
    """

    # Firm table: columns "Code", "Type_Region", "Revenue_Log", followed by
    # product indicator columns (labels from '1' onward; a 1 means the firm
    # makes that product). Class attribute so it can be overridden.
    firm_csv_path = "input_data/Firm_amended.csv"

    def __init__(self, params):
        # Initialise the mesa base class first, so any defaults it assigns
        # cannot overwrite the attributes set below (e.g. self.schedule).
        super().__init__()

        self.num_agents = params['N']
        self.grid = MultiGrid(params['width'], params['height'], True)
        self.schedule = CustomScheduler(self)

        # Run parameters taken from `params`.
        self.sample = params['sample']
        self.int_stop_ts = 0
        self.int_n_iter = int(params['n_iter'])
        # firm (FirmAgent or firm code) -> iterable of products to disrupt at
        # start; firms are resolved to agents in initialize_disruptions().
        self.dct_lst_init_disrupt_firm_prod = params['dct_lst_init_disrupt_firm_prod']
        # External variables.
        self.int_n_max_trial = int(params['n_max_trial'])
        self.is_prf_size = bool(params['prf_size'])
        self.remove_t = int(params['remove_t'])
        self.int_netw_prf_n = int(params['netw_prf_n'])

        self.product_network = None
        self.firm_network = None
        self.firm_prod_network = None
        # Firm code -> FirmAgent; filled in by initialize_agents().
        self.dct_firm_agent = {}

        # Product network: the BOM graph decoded from its adjacency JSON.
        self.product_network = nx.adjacency_graph(json.loads(params['g_bom']))
        self.initialize_firm_network()
        self.initialize_firm_prod_network()
        # Also applies the initial disruptions.
        self.initialize_agents()

        self.datacollector = DataCollector(
            agent_reporters={"Product Code": "code"}
        )

    def _read_firm_table(self):
        """Read the firm CSV, cast ``Code`` to pandas string dtype, fill NaNs with 0."""
        firm = pd.read_csv(self.firm_csv_path)
        firm['Code'] = firm['Code'].astype('string')
        firm.fillna(0, inplace=True)
        return firm

    def initialize_firm_network(self):
        """Build ``self.firm_network`` from the firm CSV.

        One node per firm ``Code`` (as str). Each node carries ``Type_Region``,
        ``Revenue_Log`` and ``Product_Code`` (the list of product column labels
        whose value is 1 for that firm). Edges are delegated to
        ``add_edges_based_on_bom``.
        """
        firm = self._read_firm_table()

        # .copy(): a column is added below, so work on an independent frame
        # rather than a slice of `firm` (avoids pandas' SettingWithCopyWarning).
        firm_attr = firm[["Code", "Type_Region", "Revenue_Log"]].copy()
        product_flags = firm.loc[:, '1':]
        firm_attr['Product_Code'] = [
            row[row == 1].index.to_list() for _, row in product_flags.iterrows()
        ]
        firm_attr.set_index('Code', inplace=True)

        G_Firm = nx.MultiDiGraph()
        G_Firm.add_nodes_from(firm["Code"])
        nx.set_node_attributes(
            G_Firm, {code: firm_attr.loc[code].to_dict() for code in G_Firm.nodes}
        )
        self.add_edges_based_on_bom(G_Firm)
        self.firm_network = G_Firm

    def initialize_firm_prod_network(self):
        """Build ``self.firm_prod_network``: one node per (firm, product) pair.

        Nodes are numbered 0..k-1 and carry ``Firm_Code`` (the firm's ``Code``,
        as str) and ``Product_Code`` (the product column label) attributes, for
        every product column whose value is 1 for that firm.
        """
        firm_prod = self._read_firm_table()
        # Index by the firm Code before stacking. With the default RangeIndex,
        # the first stacked level would be the row position, which matches the
        # firm code only when codes happen to be 0..n-1.
        product_flags = firm_prod.set_index('Code').loc[:, '1':]
        stacked = product_flags.stack().rename_axis(['Firm_Code', 'Product_Code'])
        firm_prod = stacked[stacked == 1].index.to_frame(index=False)
        firm_prod['Firm_Code'] = firm_prod['Firm_Code'].astype('string')

        G_FirmProd = nx.MultiDiGraph()
        G_FirmProd.add_nodes_from(firm_prod.index)
        nx.set_node_attributes(
            G_FirmProd,
            {idx: firm_prod.loc[idx].to_dict() for idx in firm_prod.index},
        )
        self.firm_prod_network = G_FirmProd

    def add_edges_based_on_bom(self, G_Firm):
        """Add edges to ``G_Firm`` derived from the BOM product network.

        TODO: not implemented yet; ``G_Firm`` is currently left without edges.
        """

    def initialize_agents(self):
        """Create and schedule product and firm agents, then apply disruptions.

        One ``ProductAgent`` is created per BOM node and one ``FirmAgent`` per
        firm-network node. Each FirmAgent is also recorded in
        ``self.dct_firm_agent`` under its firm code.
        """
        for node, attr in self.product_network.nodes(data=True):
            product = ProductAgent(node, self, code=node, name=attr['Name'])
            self.schedule.add(product)

        self.dct_firm_agent = {}
        for node, attr in self.firm_network.nodes(data=True):
            firm_agent = FirmAgent(
                node,
                self,
                code=node,
                type_region=attr['Type_Region'],
                revenue_log=attr['Revenue_Log'],
                # TODO: populate from attr['Product_Code'] (the product column
                # labels this firm is flagged with) once those codes are
                # matched to ProductAgent instances.
                a_lst_product=[],
            )
            self.schedule.add(firm_agent)
            self.dct_firm_agent[node] = firm_agent

        self.initialize_disruptions()

    def _resolve_firm_agent(self, firm):
        """Return the FirmAgent for ``firm`` (an agent or a firm code), or None."""
        if isinstance(firm, FirmAgent):
            return firm
        agent = self.dct_firm_agent.get(firm)
        if agent is None:
            # Firm codes are stored as str (initialize_firm_network casts Code
            # to string dtype), so also accept e.g. an int code.
            agent = self.dct_firm_agent.get(str(firm))
        return agent

    def initialize_disruptions(self):
        """Record the configured initial firm-product disruptions.

        For every firm in ``dct_lst_init_disrupt_firm_prod`` and each of its
        products, appends ``('D', self.schedule.steps)`` to the firm agent's
        ``dct_prod_up_prod_stat[product]['p_stat']`` list. Firms may be given as
        FirmAgent instances or as firm codes. Unknown firms are skipped with a
        warning rather than silently ignored.
        """
        import warnings

        for firm, products in self.dct_lst_init_disrupt_firm_prod.items():
            firm_agent = self._resolve_firm_agent(firm)
            if firm_agent is None:
                warnings.warn(
                    f"Initial disruption references unknown firm {firm!r}; ignored.",
                    stacklevel=2,
                )
                continue
            for product in products:
                # NOTE(review): products are used as keys unchanged; confirm
                # they match the keys FirmAgent uses in dct_prod_up_prod_stat.
                firm_agent.dct_prod_up_prod_stat[product]['p_stat'].append(
                    ('D', self.schedule.steps)
                )

    def step(self):
        """Advance the scheduler by one step, then collect agent data."""
        self.schedule.step()
        self.datacollector.collect(self)