最新版本,增加修改很多地方,但是还是有报错
This commit is contained in:
		
							parent
							
								
									0285494a43
								
							
						
					
					
						commit
						be701c2fe9
					
				
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							| 
						 | 
					@ -4,6 +4,9 @@ from mesa import Model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import TYPE_CHECKING
 | 
					from typing import TYPE_CHECKING
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from model import MyModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if TYPE_CHECKING:
 | 
					if TYPE_CHECKING:
 | 
				
			||||||
    from controller_db import ControllerDB
 | 
					    from controller_db import ControllerDB
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -34,7 +37,8 @@ class Computation:
 | 
				
			||||||
        dct_sample_para = {'sample': sample_random,
 | 
					        dct_sample_para = {'sample': sample_random,
 | 
				
			||||||
                           'seed': sample_random.seed,
 | 
					                           'seed': sample_random.seed,
 | 
				
			||||||
                           **dct_exp}
 | 
					                           **dct_exp}
 | 
				
			||||||
        model = Model(dct_sample_para)
 | 
					
 | 
				
			||||||
 | 
					        model = MyModel(dct_sample_para)
 | 
				
			||||||
        for i in range(100):
 | 
					        for i in range(100):
 | 
				
			||||||
            model.step()
 | 
					            model.step()
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										9
									
								
								firm.py
								
								
								
								
							
							
						
						
									
										9
									
								
								firm.py
								
								
								
								
							| 
						 | 
					@ -1,7 +1,9 @@
 | 
				
			||||||
from mesa import Agent
 | 
					from mesa import Agent
 | 
				
			||||||
 | 
					from model import MyModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FirmAgent(Agent):
 | 
					class FirmAgent(Agent):
 | 
				
			||||||
    def __init__(self, unique_id, model, code, type_region, revenue_log, a_lst_product):
 | 
					    def __init__(self, unique_id, model, type_region, revenue_log, a_lst_product):
 | 
				
			||||||
        # 调用超类的 __init__ 方法
 | 
					        # 调用超类的 __init__ 方法
 | 
				
			||||||
        super().__init__(unique_id, model)
 | 
					        super().__init__(unique_id, model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,7 +12,6 @@ class FirmAgent(Agent):
 | 
				
			||||||
        self.product_network = self.model.product_network
 | 
					        self.product_network = self.model.product_network
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # 初始化代理自身的属性
 | 
					        # 初始化代理自身的属性
 | 
				
			||||||
        self.code = code
 | 
					 | 
				
			||||||
        self.type_region = type_region
 | 
					        self.type_region = type_region
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.size_stat = []
 | 
					        self.size_stat = []
 | 
				
			||||||
| 
						 | 
					@ -100,7 +101,7 @@ class FirmAgent(Agent):
 | 
				
			||||||
                        lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
 | 
					                        lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
 | 
				
			||||||
                        lst_prob = [size / sum(lst_size) for size in lst_size]
 | 
					                        lst_prob = [size / sum(lst_size) for size in lst_size]
 | 
				
			||||||
                        select_alt_supply = \
 | 
					                        select_alt_supply = \
 | 
				
			||||||
                        self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
 | 
					                            self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
 | 
				
			||||||
                    else:
 | 
					                    else:
 | 
				
			||||||
                        select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
 | 
					                        select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
 | 
				
			||||||
                elif len(lst_firm_connect) > 0:
 | 
					                elif len(lst_firm_connect) > 0:
 | 
				
			||||||
| 
						 | 
					@ -176,7 +177,7 @@ class FirmAgent(Agent):
 | 
				
			||||||
                            down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
 | 
					                            down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def clean_before_trial(self):
 | 
					    def clean_before_trial(self):
 | 
				
			||||||
         self.dct_request_prod_from_firm = {}
 | 
					        self.dct_request_prod_from_firm = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def clean_before_time_step(self):
 | 
					    def clean_before_time_step(self):
 | 
				
			||||||
        # Reset the number of trials and candidate suppliers for disrupted products
 | 
					        # Reset the number of trials and candidate suppliers for disrupted products
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										4
									
								
								main.py
								
								
								
								
							
							
						
						
									
										4
									
								
								main.py
								
								
								
								
							| 
						 | 
					@ -49,10 +49,10 @@ def do_computation(c_db):
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
    # 输入参数
 | 
					    # 输入参数
 | 
				
			||||||
    parser = argparse.ArgumentParser(description='setting')
 | 
					    parser = argparse.ArgumentParser(description='setting')
 | 
				
			||||||
    parser.add_argument('--exp', type=str, default='test')
 | 
					    parser.add_argument('--exp', type=str, default='without_exp')
 | 
				
			||||||
    parser.add_argument('--job', type=int, default='3')
 | 
					    parser.add_argument('--job', type=int, default='3')
 | 
				
			||||||
    parser.add_argument('--reset_sample', type=int, default='0')
 | 
					    parser.add_argument('--reset_sample', type=int, default='0')
 | 
				
			||||||
    parser.add_argument('--reset_db', type=bool, default=False)
 | 
					    parser.add_argument('--reset_db', type=bool, default=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
    # 几核参与进程
 | 
					    # 几核参与进程
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										27
									
								
								model.py
								
								
								
								
							
							
						
						
									
										27
									
								
								model.py
								
								
								
								
							| 
						 | 
					@ -14,11 +14,17 @@ from product import ProductAgent
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MyModel(Model):
 | 
					class MyModel(Model):
 | 
				
			||||||
    def __init__(self, params):
 | 
					    def __init__(self, params):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # self.num_agents = N
 | 
					        # self.num_agents = N
 | 
				
			||||||
 | 
					        self.is_prf_size = params['is_prf_size']
 | 
				
			||||||
 | 
					        self.prf_conn = params['prf_conn']
 | 
				
			||||||
 | 
					        self.cap_limit_prob_type = params['cap_limit_prob_type']
 | 
				
			||||||
 | 
					        self.cap_limit_level = params['cap_limit_level']
 | 
				
			||||||
 | 
					        self.diff_new_conn = params['diff_new_conn']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # NetworkX 图对象
 | 
					        # NetworkX 图对象
 | 
				
			||||||
        self.t = 0
 | 
					        self.t = 0
 | 
				
			||||||
        self.network_graph = nx.DiGraph()
 | 
					        self.network_graph = nx.MultiDiGraph()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # NetworkGrid 用于管理网格
 | 
					        # NetworkGrid 用于管理网格
 | 
				
			||||||
        self.grid = NetworkGrid(self.network_graph)
 | 
					        self.grid = NetworkGrid(self.network_graph)
 | 
				
			||||||
| 
						 | 
					@ -26,7 +32,7 @@ class MyModel(Model):
 | 
				
			||||||
        self.data_collector = DataCollector(
 | 
					        self.data_collector = DataCollector(
 | 
				
			||||||
            agent_reporters={"Product": "name"}
 | 
					            agent_reporters={"Product": "name"}
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.schedule = RandomActivation(self)
 | 
					
 | 
				
			||||||
        self.company_agents = []
 | 
					        self.company_agents = []
 | 
				
			||||||
        self.product_agents = []
 | 
					        self.product_agents = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -166,20 +172,20 @@ class MyModel(Model):
 | 
				
			||||||
    def initialize_agents(self):
 | 
					    def initialize_agents(self):
 | 
				
			||||||
        """ Initialize agents and add them to the model. """
 | 
					        """ Initialize agents and add them to the model. """
 | 
				
			||||||
        for ag_node, attr in self.product_network.nodes(data=True):
 | 
					        for ag_node, attr in self.product_network.nodes(data=True):
 | 
				
			||||||
            product = ProductAgent(ag_node, self,code=attr['code'], name=attr['Name'])
 | 
					            product = ProductAgent(ag_node, self, name=attr['Name'])
 | 
				
			||||||
            self.schedule.add(product)
 | 
					            self.add_agent(product)
 | 
				
			||||||
            self.grid.place_agent(product, ag_node)
 | 
					            # self.grid.place_agent(product, ag_node)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for ag_node, attr in self.firm_network.nodes(data=True):
 | 
					        for ag_node, attr in self.firm_network.nodes(data=True):
 | 
				
			||||||
 | 
					            a_lst_product = [agent for agent in self.product_agents if agent.unique_id in attr['Product_Code']]
 | 
				
			||||||
            firm_agent = FirmAgent(
 | 
					            firm_agent = FirmAgent(
 | 
				
			||||||
                ag_node, self,
 | 
					                ag_node, self,
 | 
				
			||||||
                code=attr['Code'],
 | 
					 | 
				
			||||||
                type_region=attr['Type_Region'],
 | 
					                type_region=attr['Type_Region'],
 | 
				
			||||||
                revenue_log=attr['Revenue_Log'],
 | 
					                revenue_log=attr['Revenue_Log'],
 | 
				
			||||||
                a_lst_product=[self.schedule.agents[code] for code in attr['Product_Code']]
 | 
					                a_lst_product=a_lst_product,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            self.schedule.add(firm_agent)
 | 
					            self.add_agent(firm_agent)
 | 
				
			||||||
            self.grid.place_agent(firm_agent, ag_node)
 | 
					            # self.grid.place_agent(firm_agent, ag_node)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def initialize_disruptions(self):
 | 
					    def initialize_disruptions(self):
 | 
				
			||||||
        """ Initialize disruptions in the network. """
 | 
					        """ Initialize disruptions in the network. """
 | 
				
			||||||
| 
						 | 
					@ -195,9 +201,9 @@ class MyModel(Model):
 | 
				
			||||||
            self.company_agents.append(agent)
 | 
					            self.company_agents.append(agent)
 | 
				
			||||||
        elif isinstance(agent, ProductAgent):
 | 
					        elif isinstance(agent, ProductAgent):
 | 
				
			||||||
            self.product_agents.append(agent)
 | 
					            self.product_agents.append(agent)
 | 
				
			||||||
        self.schedule.add(agent)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def step(self):
 | 
					    def step(self):
 | 
				
			||||||
 | 
					        print(f"Running step {self.t}")
 | 
				
			||||||
        # 1. Remove edge to customer and disrupt customer up product
 | 
					        # 1. Remove edge to customer and disrupt customer up product
 | 
				
			||||||
        for firm in self.company_agents:
 | 
					        for firm in self.company_agents:
 | 
				
			||||||
            for prod in firm.dct_prod_up_prod_stat.keys():
 | 
					            for prod in firm.dct_prod_up_prod_stat.keys():
 | 
				
			||||||
| 
						 | 
					@ -244,4 +250,3 @@ class MyModel(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Increment the time step
 | 
					        # Increment the time step
 | 
				
			||||||
        self.t += 1
 | 
					        self.t += 1
 | 
				
			||||||
        self.schedule.step()  # Activate all agents in the scheduler
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,12 +1,11 @@
 | 
				
			||||||
from mesa import Agent
 | 
					from mesa import Agent
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ProductAgent(Agent):
 | 
					class ProductAgent(Agent):
 | 
				
			||||||
    def __init__(self, unique_id, model, code, name):
 | 
					    def __init__(self, unique_id, model, name):
 | 
				
			||||||
        # 调用超类的 __init__ 方法
 | 
					        # 调用超类的 __init__ 方法
 | 
				
			||||||
        super().__init__(unique_id, model)
 | 
					        super().__init__(unique_id, model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # 初始化代理属性
 | 
					        # 初始化代理属性
 | 
				
			||||||
        self.code = code
 | 
					 | 
				
			||||||
        self.name = name
 | 
					        self.name = name
 | 
				
			||||||
        self.product_network = self.model.product_network
 | 
					        self.product_network = self.model.product_network
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue