最新版本,增加修改很多地方,但是还是有报错

This commit is contained in:
Cricial 2024-08-24 19:30:16 +08:00
parent 0285494a43
commit be701c2fe9
9 changed files with 29 additions and 20 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -4,6 +4,9 @@ from mesa import Model
from typing import TYPE_CHECKING
from model import MyModel
if TYPE_CHECKING:
from controller_db import ControllerDB
@ -34,7 +37,8 @@ class Computation:
dct_sample_para = {'sample': sample_random,
'seed': sample_random.seed,
**dct_exp}
model = Model(dct_sample_para)
model = MyModel(dct_sample_para)
for i in range(100):
model.step()
return False

View File

@ -1,7 +1,9 @@
from mesa import Agent
from model import MyModel
class FirmAgent(Agent):
def __init__(self, unique_id, model, code, type_region, revenue_log, a_lst_product):
def __init__(self, unique_id, model, type_region, revenue_log, a_lst_product):
# 调用超类的 __init__ 方法
super().__init__(unique_id, model)
@ -10,7 +12,6 @@ class FirmAgent(Agent):
self.product_network = self.model.product_network
# 初始化代理自身的属性
self.code = code
self.type_region = type_region
self.size_stat = []
@ -100,7 +101,7 @@ class FirmAgent(Agent):
lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
lst_prob = [size / sum(lst_size) for size in lst_size]
select_alt_supply = \
self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
else:
select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
elif len(lst_firm_connect) > 0:
@ -176,7 +177,7 @@ class FirmAgent(Agent):
down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
def clean_before_trial(self):
self.dct_request_prod_from_firm = {}
self.dct_request_prod_from_firm = {}
def clean_before_time_step(self):
# Reset the number of trials and candidate suppliers for disrupted products

View File

@ -49,10 +49,10 @@ def do_computation(c_db):
if __name__ == '__main__':
# 输入参数
parser = argparse.ArgumentParser(description='setting')
parser.add_argument('--exp', type=str, default='test')
parser.add_argument('--exp', type=str, default='without_exp')
parser.add_argument('--job', type=int, default='3')
parser.add_argument('--reset_sample', type=int, default='0')
parser.add_argument('--reset_db', type=bool, default=False)
parser.add_argument('--reset_db', type=bool, default=True)
args = parser.parse_args()
# 几核参与进程

View File

@ -14,11 +14,17 @@ from product import ProductAgent
class MyModel(Model):
def __init__(self, params):
# self.num_agents = N
self.is_prf_size = params['is_prf_size']
self.prf_conn = params['prf_conn']
self.cap_limit_prob_type = params['cap_limit_prob_type']
self.cap_limit_level = params['cap_limit_level']
self.diff_new_conn = params['diff_new_conn']
# NetworkX 图对象
self.t = 0
self.network_graph = nx.DiGraph()
self.network_graph = nx.MultiDiGraph()
# NetworkGrid 用于管理网格
self.grid = NetworkGrid(self.network_graph)
@ -26,7 +32,7 @@ class MyModel(Model):
self.data_collector = DataCollector(
agent_reporters={"Product": "name"}
)
self.schedule = RandomActivation(self)
self.company_agents = []
self.product_agents = []
@ -166,20 +172,20 @@ class MyModel(Model):
def initialize_agents(self):
""" Initialize agents and add them to the model. """
for ag_node, attr in self.product_network.nodes(data=True):
product = ProductAgent(ag_node, self,code=attr['code'], name=attr['Name'])
self.schedule.add(product)
self.grid.place_agent(product, ag_node)
product = ProductAgent(ag_node, self, name=attr['Name'])
self.add_agent(product)
# self.grid.place_agent(product, ag_node)
for ag_node, attr in self.firm_network.nodes(data=True):
a_lst_product = [agent for agent in self.product_agents if agent.unique_id in attr['Product_Code']]
firm_agent = FirmAgent(
ag_node, self,
code=attr['Code'],
type_region=attr['Type_Region'],
revenue_log=attr['Revenue_Log'],
a_lst_product=[self.schedule.agents[code] for code in attr['Product_Code']]
a_lst_product=a_lst_product,
)
self.schedule.add(firm_agent)
self.grid.place_agent(firm_agent, ag_node)
self.add_agent(firm_agent)
# self.grid.place_agent(firm_agent, ag_node)
def initialize_disruptions(self):
""" Initialize disruptions in the network. """
@ -195,9 +201,9 @@ class MyModel(Model):
self.company_agents.append(agent)
elif isinstance(agent, ProductAgent):
self.product_agents.append(agent)
self.schedule.add(agent)
def step(self):
print(f"Running step {self.t}")
# 1. Remove edge to customer and disrupt customer up product
for firm in self.company_agents:
for prod in firm.dct_prod_up_prod_stat.keys():
@ -244,4 +250,3 @@ class MyModel(Model):
# Increment the time step
self.t += 1
self.schedule.step() # Activate all agents in the scheduler

View File

@ -1,12 +1,11 @@
from mesa import Agent
class ProductAgent(Agent):
def __init__(self, unique_id, model, code, name):
def __init__(self, unique_id, model, name):
# 调用超类的 __init__ 方法
super().__init__(unique_id, model)
# 初始化代理属性
self.code = code
self.name = name
self.product_network = self.model.product_network