代理类
This commit is contained in:
		
							parent
							
								
									1f643c64e4
								
							
						
					
					
						commit
						6b87ca7a63
					
				
							
								
								
									
										15
									
								
								model.py
								
								
								
								
							
							
						
						
									
										15
									
								
								model.py
								
								
								
								
							| 
						 | 
				
			
			@ -9,13 +9,14 @@ from mesa.datacollection import DataCollector
 | 
			
		|||
 | 
			
		||||
from firm import FirmAgent
 | 
			
		||||
from product import ProductAgent
 | 
			
		||||
from scheduler import CustomScheduler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MyModel(Model):
 | 
			
		||||
    def __init__(self, params):
 | 
			
		||||
        # self.num_agents = params['N']
 | 
			
		||||
        # self.grid = MultiGrid(params['width'], params['height'], True)
 | 
			
		||||
        # self.schedule = RandomActivation(self)
 | 
			
		||||
        self.num_agents = params['N']
 | 
			
		||||
        self.grid = MultiGrid(params['width'], params['height'], True)
 | 
			
		||||
        self.schedule = CustomScheduler(self)
 | 
			
		||||
 | 
			
		||||
        # Initialize parameters from `params`
 | 
			
		||||
        self.sample = params['sample']
 | 
			
		||||
| 
						 | 
				
			
			@ -56,19 +57,19 @@ class MyModel(Model):
 | 
			
		|||
        firm = pd.read_csv("input_data/Firm_amended.csv")
 | 
			
		||||
        firm['Code'] = firm['Code'].astype('string')
 | 
			
		||||
        firm.fillna(0, inplace=True)
 | 
			
		||||
        Firm_attr = firm[["Code", "Type_Region", "Revenue_Log"]]
 | 
			
		||||
        firm_attr = firm[["Code", "Type_Region", "Revenue_Log"]]
 | 
			
		||||
        firm_product = []
 | 
			
		||||
        for _, row in firm.loc[:, '1':].iterrows():
 | 
			
		||||
            firm_product.append(row[row == 1].index.to_list())
 | 
			
		||||
        Firm_attr['Product_Code'] = firm_product
 | 
			
		||||
        Firm_attr.set_index('Code', inplace=True)
 | 
			
		||||
        firm_attr['Product_Code'] = firm_product
 | 
			
		||||
        firm_attr.set_index('Code', inplace=True)
 | 
			
		||||
        G_Firm = nx.MultiDiGraph()
 | 
			
		||||
        G_Firm.add_nodes_from(firm["Code"])
 | 
			
		||||
 | 
			
		||||
        # Add node attributes
 | 
			
		||||
        firm_labels_dict = {}
 | 
			
		||||
        for code in G_Firm.nodes:
 | 
			
		||||
            firm_labels_dict[code] = Firm_attr.loc[code].to_dict()
 | 
			
		||||
            firm_labels_dict[code] = firm_attr.loc[code].to_dict()
 | 
			
		||||
        nx.set_node_attributes(G_Firm, firm_labels_dict)
 | 
			
		||||
 | 
			
		||||
        # Add edges based on BOM graph
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,27 @@
 | 
			
		|||
from mesa.time import BaseScheduler
 | 
			
		||||
 | 
			
		||||
from firm import FirmAgent
 | 
			
		||||
from product import ProductAgent
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CustomScheduler(BaseScheduler):
 | 
			
		||||
    def __init__(self, model):
 | 
			
		||||
        super().__init__(model)
 | 
			
		||||
        self.company_agents = []
 | 
			
		||||
        self.product_agents = []
 | 
			
		||||
 | 
			
		||||
    def add_agent(self, agent):
 | 
			
		||||
        if isinstance(agent, FirmAgent):
 | 
			
		||||
            self.company_agents.append(agent)
 | 
			
		||||
        elif isinstance(agent, ProductAgent):
 | 
			
		||||
            self.product_agents.append(agent)
 | 
			
		||||
        super().add_agent(agent)
 | 
			
		||||
 | 
			
		||||
    def step(self):
 | 
			
		||||
        # First, activate all company agents
 | 
			
		||||
        for agent in self.company_agents:
 | 
			
		||||
            agent.step()
 | 
			
		||||
 | 
			
		||||
        # Then, activate all product agents
 | 
			
		||||
        for agent in self.product_agents:
 | 
			
		||||
            agent.step()
 | 
			
		||||
		Loading…
	
		Reference in New Issue