最新版本,增加修改很多地方,但是还是有报错
This commit is contained in:
		
							parent
							
								
									0285494a43
								
							
						
					
					
						commit
						be701c2fe9
					
				
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							| 
						 | 
				
			
			@ -4,6 +4,9 @@ from mesa import Model
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
from typing import TYPE_CHECKING
 | 
			
		||||
 | 
			
		||||
from model import MyModel
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from controller_db import ControllerDB
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -34,7 +37,8 @@ class Computation:
 | 
			
		|||
        dct_sample_para = {'sample': sample_random,
 | 
			
		||||
                           'seed': sample_random.seed,
 | 
			
		||||
                           **dct_exp}
 | 
			
		||||
        model = Model(dct_sample_para)
 | 
			
		||||
 | 
			
		||||
        model = MyModel(dct_sample_para)
 | 
			
		||||
        for i in range(100):
 | 
			
		||||
            model.step()
 | 
			
		||||
        return False
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										9
									
								
								firm.py
								
								
								
								
							
							
						
						
									
										9
									
								
								firm.py
								
								
								
								
							| 
						 | 
				
			
			@ -1,7 +1,9 @@
 | 
			
		|||
from mesa import Agent
 | 
			
		||||
from model import MyModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FirmAgent(Agent):
 | 
			
		||||
    def __init__(self, unique_id, model, code, type_region, revenue_log, a_lst_product):
 | 
			
		||||
    def __init__(self, unique_id, model, type_region, revenue_log, a_lst_product):
 | 
			
		||||
        # 调用超类的 __init__ 方法
 | 
			
		||||
        super().__init__(unique_id, model)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -10,7 +12,6 @@ class FirmAgent(Agent):
 | 
			
		|||
        self.product_network = self.model.product_network
 | 
			
		||||
 | 
			
		||||
        # 初始化代理自身的属性
 | 
			
		||||
        self.code = code
 | 
			
		||||
        self.type_region = type_region
 | 
			
		||||
 | 
			
		||||
        self.size_stat = []
 | 
			
		||||
| 
						 | 
				
			
			@ -100,7 +101,7 @@ class FirmAgent(Agent):
 | 
			
		|||
                        lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
 | 
			
		||||
                        lst_prob = [size / sum(lst_size) for size in lst_size]
 | 
			
		||||
                        select_alt_supply = \
 | 
			
		||||
                        self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
 | 
			
		||||
                            self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
 | 
			
		||||
                    else:
 | 
			
		||||
                        select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
 | 
			
		||||
                elif len(lst_firm_connect) > 0:
 | 
			
		||||
| 
						 | 
				
			
			@ -176,7 +177,7 @@ class FirmAgent(Agent):
 | 
			
		|||
                            down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
 | 
			
		||||
 | 
			
		||||
    def clean_before_trial(self):
 | 
			
		||||
         self.dct_request_prod_from_firm = {}
 | 
			
		||||
        self.dct_request_prod_from_firm = {}
 | 
			
		||||
 | 
			
		||||
    def clean_before_time_step(self):
 | 
			
		||||
        # Reset the number of trials and candidate suppliers for disrupted products
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										4
									
								
								main.py
								
								
								
								
							
							
						
						
									
										4
									
								
								main.py
								
								
								
								
							| 
						 | 
				
			
			@ -49,10 +49,10 @@ def do_computation(c_db):
 | 
			
		|||
if __name__ == '__main__':
 | 
			
		||||
    # 输入参数
 | 
			
		||||
    parser = argparse.ArgumentParser(description='setting')
 | 
			
		||||
    parser.add_argument('--exp', type=str, default='test')
 | 
			
		||||
    parser.add_argument('--exp', type=str, default='without_exp')
 | 
			
		||||
    parser.add_argument('--job', type=int, default='3')
 | 
			
		||||
    parser.add_argument('--reset_sample', type=int, default='0')
 | 
			
		||||
    parser.add_argument('--reset_db', type=bool, default=False)
 | 
			
		||||
    parser.add_argument('--reset_db', type=bool, default=True)
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    # 几核参与进程
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										27
									
								
								model.py
								
								
								
								
							
							
						
						
									
										27
									
								
								model.py
								
								
								
								
							| 
						 | 
				
			
			@ -14,11 +14,17 @@ from product import ProductAgent
 | 
			
		|||
 | 
			
		||||
class MyModel(Model):
 | 
			
		||||
    def __init__(self, params):
 | 
			
		||||
 | 
			
		||||
        # self.num_agents = N
 | 
			
		||||
        self.is_prf_size = params['is_prf_size']
 | 
			
		||||
        self.prf_conn = params['prf_conn']
 | 
			
		||||
        self.cap_limit_prob_type = params['cap_limit_prob_type']
 | 
			
		||||
        self.cap_limit_level = params['cap_limit_level']
 | 
			
		||||
        self.diff_new_conn = params['diff_new_conn']
 | 
			
		||||
 | 
			
		||||
        # NetworkX 图对象
 | 
			
		||||
        self.t = 0
 | 
			
		||||
        self.network_graph = nx.DiGraph()
 | 
			
		||||
        self.network_graph = nx.MultiDiGraph()
 | 
			
		||||
 | 
			
		||||
        # NetworkGrid 用于管理网格
 | 
			
		||||
        self.grid = NetworkGrid(self.network_graph)
 | 
			
		||||
| 
						 | 
				
			
			@ -26,7 +32,7 @@ class MyModel(Model):
 | 
			
		|||
        self.data_collector = DataCollector(
 | 
			
		||||
            agent_reporters={"Product": "name"}
 | 
			
		||||
        )
 | 
			
		||||
        self.schedule = RandomActivation(self)
 | 
			
		||||
 | 
			
		||||
        self.company_agents = []
 | 
			
		||||
        self.product_agents = []
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -166,20 +172,20 @@ class MyModel(Model):
 | 
			
		|||
    def initialize_agents(self):
 | 
			
		||||
        """ Initialize agents and add them to the model. """
 | 
			
		||||
        for ag_node, attr in self.product_network.nodes(data=True):
 | 
			
		||||
            product = ProductAgent(ag_node, self,code=attr['code'], name=attr['Name'])
 | 
			
		||||
            self.schedule.add(product)
 | 
			
		||||
            self.grid.place_agent(product, ag_node)
 | 
			
		||||
            product = ProductAgent(ag_node, self, name=attr['Name'])
 | 
			
		||||
            self.add_agent(product)
 | 
			
		||||
            # self.grid.place_agent(product, ag_node)
 | 
			
		||||
 | 
			
		||||
        for ag_node, attr in self.firm_network.nodes(data=True):
 | 
			
		||||
            a_lst_product = [agent for agent in self.product_agents if agent.unique_id in attr['Product_Code']]
 | 
			
		||||
            firm_agent = FirmAgent(
 | 
			
		||||
                ag_node, self,
 | 
			
		||||
                code=attr['Code'],
 | 
			
		||||
                type_region=attr['Type_Region'],
 | 
			
		||||
                revenue_log=attr['Revenue_Log'],
 | 
			
		||||
                a_lst_product=[self.schedule.agents[code] for code in attr['Product_Code']]
 | 
			
		||||
                a_lst_product=a_lst_product,
 | 
			
		||||
            )
 | 
			
		||||
            self.schedule.add(firm_agent)
 | 
			
		||||
            self.grid.place_agent(firm_agent, ag_node)
 | 
			
		||||
            self.add_agent(firm_agent)
 | 
			
		||||
            # self.grid.place_agent(firm_agent, ag_node)
 | 
			
		||||
 | 
			
		||||
    def initialize_disruptions(self):
 | 
			
		||||
        """ Initialize disruptions in the network. """
 | 
			
		||||
| 
						 | 
				
			
			@ -195,9 +201,9 @@ class MyModel(Model):
 | 
			
		|||
            self.company_agents.append(agent)
 | 
			
		||||
        elif isinstance(agent, ProductAgent):
 | 
			
		||||
            self.product_agents.append(agent)
 | 
			
		||||
        self.schedule.add(agent)
 | 
			
		||||
 | 
			
		||||
    def step(self):
 | 
			
		||||
        print(f"Running step {self.t}")
 | 
			
		||||
        # 1. Remove edge to customer and disrupt customer up product
 | 
			
		||||
        for firm in self.company_agents:
 | 
			
		||||
            for prod in firm.dct_prod_up_prod_stat.keys():
 | 
			
		||||
| 
						 | 
				
			
			@ -244,4 +250,3 @@ class MyModel(Model):
 | 
			
		|||
 | 
			
		||||
        # Increment the time step
 | 
			
		||||
        self.t += 1
 | 
			
		||||
        self.schedule.step()  # Activate all agents in the scheduler
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,12 +1,11 @@
 | 
			
		|||
from mesa import Agent
 | 
			
		||||
 | 
			
		||||
class ProductAgent(Agent):
 | 
			
		||||
    def __init__(self, unique_id, model, code, name):
 | 
			
		||||
    def __init__(self, unique_id, model, name):
 | 
			
		||||
        # 调用超类的 __init__ 方法
 | 
			
		||||
        super().__init__(unique_id, model)
 | 
			
		||||
 | 
			
		||||
        # 初始化代理属性
 | 
			
		||||
        self.code = code
 | 
			
		||||
        self.name = name
 | 
			
		||||
        self.product_network = self.model.product_network
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue