代理类

This commit is contained in:
Cricial 2024-08-24 11:32:09 +08:00
parent 1f643c64e4
commit 6b87ca7a63
2 changed files with 35 additions and 7 deletions

View File

@ -9,13 +9,14 @@ from mesa.datacollection import DataCollector
from firm import FirmAgent from firm import FirmAgent
from product import ProductAgent from product import ProductAgent
from scheduler import CustomScheduler
class MyModel(Model): class MyModel(Model):
def __init__(self, params): def __init__(self, params):
# self.num_agents = params['N'] self.num_agents = params['N']
# self.grid = MultiGrid(params['width'], params['height'], True) self.grid = MultiGrid(params['width'], params['height'], True)
# self.schedule = RandomActivation(self) self.schedule = CustomScheduler(self)
# Initialize parameters from `params` # Initialize parameters from `params`
self.sample = params['sample'] self.sample = params['sample']
@ -56,19 +57,19 @@ class MyModel(Model):
firm = pd.read_csv("input_data/Firm_amended.csv") firm = pd.read_csv("input_data/Firm_amended.csv")
firm['Code'] = firm['Code'].astype('string') firm['Code'] = firm['Code'].astype('string')
firm.fillna(0, inplace=True) firm.fillna(0, inplace=True)
Firm_attr = firm[["Code", "Type_Region", "Revenue_Log"]] firm_attr = firm[["Code", "Type_Region", "Revenue_Log"]]
firm_product = [] firm_product = []
for _, row in firm.loc[:, '1':].iterrows(): for _, row in firm.loc[:, '1':].iterrows():
firm_product.append(row[row == 1].index.to_list()) firm_product.append(row[row == 1].index.to_list())
Firm_attr['Product_Code'] = firm_product firm_attr['Product_Code'] = firm_product
Firm_attr.set_index('Code', inplace=True) firm_attr.set_index('Code', inplace=True)
G_Firm = nx.MultiDiGraph() G_Firm = nx.MultiDiGraph()
G_Firm.add_nodes_from(firm["Code"]) G_Firm.add_nodes_from(firm["Code"])
# Add node attributes # Add node attributes
firm_labels_dict = {} firm_labels_dict = {}
for code in G_Firm.nodes: for code in G_Firm.nodes:
firm_labels_dict[code] = Firm_attr.loc[code].to_dict() firm_labels_dict[code] = firm_attr.loc[code].to_dict()
nx.set_node_attributes(G_Firm, firm_labels_dict) nx.set_node_attributes(G_Firm, firm_labels_dict)
# Add edges based on BOM graph # Add edges based on BOM graph

27
scheduler.py Normal file
View File

@ -0,0 +1,27 @@
from mesa.time import BaseScheduler
from firm import FirmAgent
from product import ProductAgent
class CustomScheduler(BaseScheduler):
def __init__(self, model):
super().__init__(model)
self.company_agents = []
self.product_agents = []
def add_agent(self, agent):
if isinstance(agent, FirmAgent):
self.company_agents.append(agent)
elif isinstance(agent, ProductAgent):
self.product_agents.append(agent)
super().add_agent(agent)
def step(self):
# First, activate all company agents
for agent in self.company_agents:
agent.step()
# Then, activate all product agents
for agent in self.product_agents:
agent.step()