Compare commits
4 Commits
09622cf33d
...
30e7e56c11
Author | SHA1 | Date |
---|---|---|
|
30e7e56c11 | |
|
0f7f9c1a4b | |
|
3a46d09b8e | |
|
13521ff752 |
|
@ -3,7 +3,7 @@
|
|||
<component name="CsvFileAttributes">
|
||||
<option name="attributeMap">
|
||||
<map>
|
||||
<entry key="C:\Users\www\Desktop\半导体数据——暂时的数据\我做的数据\汇总数据\BomCateNet.csv">
|
||||
<entry key="C:\Users\www\Desktop\python项目\数据\抽样第3次数据\firm_amended.csv">
|
||||
<value>
|
||||
<Attribute>
|
||||
<option name="separator" value="," />
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
import pickle
|
||||
import os
|
||||
|
||||
from 查看进度 import visualize_progress
|
||||
|
||||
|
||||
def load_cached_data(file_path):
|
||||
"""
|
||||
从指定的缓存文件加载数据。
|
||||
如果文件不存在或加载失败,则返回空字典。
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"Warning: Cache file '{file_path}' does not exist.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
print(f"Successfully loaded cache from '{file_path}'.")
|
||||
return data
|
||||
except (pickle.UnpicklingError, FileNotFoundError, EOFError) as e:
|
||||
print(f"Error loading cache from '{file_path}': {e}")
|
||||
return {}
|
||||
|
||||
|
||||
# 示例用法
|
||||
# data_dct = load_cached_data("G_Firm_add_edges.pkl")
|
||||
|
||||
visualize_progress()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -41,11 +41,9 @@ class Computation:
|
|||
'seed': sample_random.seed,
|
||||
**dct_exp}
|
||||
|
||||
product_network_test = nx.adjacency_graph(json.loads(dct_sample_para['g_bom']))
|
||||
|
||||
model = MyModel(dct_sample_para)
|
||||
|
||||
model.step() # 运行仿真一步
|
||||
model.end() # 汇总结果
|
||||
|
||||
return False
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
db_name_prefix: without_exp
|
||||
db_name_prefix: with_exp
|
||||
|
|
|
@ -8,5 +8,5 @@ test: # only for test scenarios
|
|||
n_iter: 100
|
||||
|
||||
not_test: # normal scenarios
|
||||
n_sample: 50
|
||||
n_iter: 100
|
||||
n_sample: 10
|
||||
n_iter: 50
|
||||
|
|
Binary file not shown.
|
@ -53,7 +53,6 @@ class ControllerDB:
|
|||
|
||||
# fill dct_lst_init_disrupt_firm_prod
|
||||
# 存储 公司-在供应链结点的位置.. 0 :‘1.1’
|
||||
list_dct = [] # 存储 公司编码code 和对应的产业链 结点
|
||||
if self.is_with_exp:
|
||||
# 对于方差分析时候使用
|
||||
with open('SQL_export_high_risk_setting.sql', 'r') as f:
|
||||
|
@ -67,12 +66,27 @@ class ControllerDB:
|
|||
# 行索引 (index):这一行在数据帧中的索引值。
|
||||
# 行数据 (row):这一行的数据,是一个 pandas.Series 对象,包含该行的所有列和值。
|
||||
|
||||
# 读取企业与产品关系数据
|
||||
firm_industry = pd.read_csv("input_data/firm_industry_relation.csv")
|
||||
firm_industry['Firm_Code'] = firm_industry['Firm_Code'].astype('string')
|
||||
|
||||
# 假设已从 BOM 数据构建了 code_to_indices
|
||||
bom_nodes = pd.read_csv("input_data/input_product_data/BomNodes.csv")
|
||||
code_to_indices = bom_nodes.groupby('Code')['Index'].apply(list).to_dict()
|
||||
|
||||
# 初始化存储映射结果的列表
|
||||
list_dct = []
|
||||
|
||||
# 遍历 firm_industry 数据
|
||||
for _, row in firm_industry.iterrows():
|
||||
code = row['Firm_Code']
|
||||
row = row['Product_Code']
|
||||
dct = {code: [row]}
|
||||
firm_code = row['Firm_Code'] # 企业代码
|
||||
product_code = row['Product_Code'] # 原始产品代码
|
||||
|
||||
# 使用 code_to_indices 映射 Product_Code 到 Product_Indices
|
||||
mapped_indices = code_to_indices.get(product_code, []) # 如果找不到则返回空列表
|
||||
|
||||
# 构建企业到产品索引的映射
|
||||
dct = {firm_code: mapped_indices}
|
||||
list_dct.append(dct)
|
||||
|
||||
# fill g_bom
|
||||
|
|
95
firm.py
95
firm.py
|
@ -32,13 +32,13 @@ class FirmAgent(Agent):
|
|||
# 包括 产品时间
|
||||
self.P1 = {0: P}
|
||||
# 企业i的供应商
|
||||
self.upper_i = [agent for u, v in self.firm_network.in_edges(self.unique_id)
|
||||
for agent in self.model.company_agents if agent.unique_id == u]
|
||||
self.upper_i = [self.model.agent_map[u] for u, v in self.firm_network.in_edges(self.unique_id)
|
||||
if u in self.model.agent_map]
|
||||
# 企业i的客户
|
||||
self.downer_i = [agent for u, v in self.firm_network.out_edges(self.unique_id)
|
||||
for agent in self.model.company_agents if agent.unique_id == u]
|
||||
self.downer_i = [self.model.agent_map[v] for u, v in self.firm_network.out_edges(self.unique_id)
|
||||
if v in self.model.agent_map]
|
||||
# 设备c的数量 (总量) 使用这个来判断设备数量
|
||||
#self.n_equip_c = n_equip_c
|
||||
# self.n_equip_c = n_equip_c
|
||||
# 设备c产量 根据设备量进行估算
|
||||
self.c_yield = production_output
|
||||
# 消耗材料量 根据设备量进行估算 { }
|
||||
|
@ -143,42 +143,69 @@ class FirmAgent(Agent):
|
|||
# f"disrupted supplier of {disrupted_up_prod.code}")
|
||||
|
||||
def seek_alt_supply(self, product):
|
||||
# 检查当前产品的尝试次数是否达到最大值
|
||||
if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial:
|
||||
# 初始化候选供应商列表
|
||||
if self.dct_n_trial_up_prod_disrupted[product] == 0:
|
||||
self.dct_cand_alt_supp_up_prod_disrupted[product] = [
|
||||
firm for firm in self.model.company_agents
|
||||
if firm.is_prod_in_current_normal(product)]
|
||||
if self.dct_cand_alt_supp_up_prod_disrupted[product]:
|
||||
lst_firm_connect = []
|
||||
if self.is_prf_conn:
|
||||
for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]:
|
||||
if self.firm_network.has_edge(self.unique_id, firm.unique_id) or \
|
||||
self.firm_network.has_edge(firm.unique_id, self.unique_id):
|
||||
lst_firm_connect.append(firm)
|
||||
if len(lst_firm_connect) == 0:
|
||||
if self.is_prf_size:
|
||||
lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
|
||||
lst_prob = [size / sum(lst_size) for size in lst_size]
|
||||
select_alt_supply = \
|
||||
self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
|
||||
else:
|
||||
select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
|
||||
elif len(lst_firm_connect) > 0:
|
||||
if self.is_prf_size:
|
||||
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
|
||||
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
|
||||
select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
|
||||
else:
|
||||
select_alt_supply = self.random.choice(lst_firm_connect)
|
||||
firm for firm in self.model.company_agents if firm.is_prod_in_current_normal(product)
|
||||
]
|
||||
|
||||
assert select_alt_supply.is_prod_in_current_normal(product)
|
||||
# 如果没有候选供应商,直接退出
|
||||
if not self.dct_cand_alt_supp_up_prod_disrupted[product]:
|
||||
# print(f"No valid candidates found for product {product.unique_id}")
|
||||
return
|
||||
|
||||
if product in select_alt_supply.dct_request_prod_from_firm:
|
||||
select_alt_supply.dct_request_prod_from_firm[product].append(self)
|
||||
# 查找与当前企业已连接的候选供应商
|
||||
lst_firm_connect = []
|
||||
if self.is_prf_conn:
|
||||
lst_firm_connect = [
|
||||
firm for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]
|
||||
if self.firm_network.has_edge(self.unique_id, firm.unique_id) or
|
||||
self.firm_network.has_edge(firm.unique_id, self.unique_id)
|
||||
]
|
||||
|
||||
# 如果没有连接的供应商
|
||||
if not lst_firm_connect:
|
||||
if self.is_prf_size: # 根据规模加权选择
|
||||
lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
|
||||
lst_prob = [size / sum(lst_size) for size in lst_size]
|
||||
select_alt_supply = self.random.choices(
|
||||
self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob
|
||||
)[0]
|
||||
else: # 随机选择
|
||||
select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
|
||||
else: # 如果存在连接的供应商
|
||||
if self.is_prf_size: # 根据规模加权选择
|
||||
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
|
||||
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
|
||||
select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
|
||||
else: # 随机选择
|
||||
select_alt_supply = self.random.choice(lst_firm_connect)
|
||||
|
||||
# 检查选中的供应商是否能够生产产品
|
||||
if not select_alt_supply.is_prod_in_current_normal(product):
|
||||
# print(f"Selected supplier {select_alt_supply.unique_id} cannot produce product {product.unique_id}")
|
||||
|
||||
# 打印供应商的生产状态字典
|
||||
#print(f"Supplier production state: {select_alt_supply.dct_prod_up_prod_stat}")
|
||||
|
||||
# 检查产品是否存在于生产状态字典中
|
||||
if product in select_alt_supply.dct_prod_up_prod_stat:
|
||||
print(
|
||||
f"Product {product.unique_id} production state: {select_alt_supply.dct_prod_up_prod_stat[product]['p_stat']}")
|
||||
else:
|
||||
select_alt_supply.dct_request_prod_from_firm[product] = [self]
|
||||
print(f"Product {product.unique_id} not found in supplier production state.")
|
||||
return
|
||||
|
||||
self.dct_n_trial_up_prod_disrupted[product] += 1
|
||||
# 添加到供应商的请求字典
|
||||
if product in select_alt_supply.dct_request_prod_from_firm:
|
||||
select_alt_supply.dct_request_prod_from_firm[product].append(self)
|
||||
else:
|
||||
select_alt_supply.dct_request_prod_from_firm[product] = [self]
|
||||
|
||||
# 更新尝试次数
|
||||
self.dct_n_trial_up_prod_disrupted[product] += 1
|
||||
|
||||
def handle_request(self):
|
||||
for product, lst_firm in self.dct_request_prod_from_firm.items():
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
13
main.py
13
main.py
|
@ -3,10 +3,14 @@ import random
|
|||
import time
|
||||
from multiprocessing import Process
|
||||
import argparse
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from computation import Computation
|
||||
from sqlalchemy.orm import close_all_sessions
|
||||
import yaml
|
||||
from controller_db import ControllerDB
|
||||
from 查看进度 import visualize_progress
|
||||
|
||||
|
||||
def controll_db_and_process(exp_argument, reset_sample_argument, reset_db_argument):
|
||||
|
@ -30,6 +34,13 @@ def do_process(target: object, controller_db: ControllerDB, ):
|
|||
|
||||
for i in process_list:
|
||||
i.join()
|
||||
|
||||
# 所有子进程完成后刷新最终进度
|
||||
visualize_progress()
|
||||
|
||||
# 显示最终进度后关闭图表
|
||||
plt.show()
|
||||
|
||||
def do_computation(c_db):
|
||||
exp = Computation(c_db)
|
||||
|
||||
|
@ -43,7 +54,7 @@ def do_computation(c_db):
|
|||
if __name__ == '__main__':
|
||||
# 输入参数
|
||||
parser = argparse.ArgumentParser(description='setting')
|
||||
parser.add_argument('--exp', type=str, default='without_exp')
|
||||
parser.add_argument('--exp', type=str, default='with_exp')
|
||||
parser.add_argument('--job', type=int, default='4')
|
||||
parser.add_argument('--reset_sample', type=int, default='0')
|
||||
parser.add_argument('--reset_db', type=bool, default=False)
|
||||
|
|
603
my_model.py
603
my_model.py
|
@ -42,6 +42,7 @@ class MyModel(Model):
|
|||
- seed (int): 随机种子的值,用于确保实验的可重复性。
|
||||
"""
|
||||
# 仿真参数
|
||||
self.agent_map = None
|
||||
self.firm_prod_labels_dict = None
|
||||
self.firm_relationship_cache = None
|
||||
self.firm_product_cache = None
|
||||
|
@ -53,17 +54,15 @@ class MyModel(Model):
|
|||
self.cap_limit_level = params['cap_limit_level'] # 产能限制的水平。
|
||||
self.diff_new_conn = params['diff_new_conn'] # 是否允许差异化的新连接。
|
||||
# 初始化停止时间步,可能是用户通过参数传入
|
||||
self.int_stop_ts = params.get('stop_t', 3) # 默认停止时间为 100
|
||||
self.int_stop_ts = params.get('n_iter', 3) # 默认停止时间为 100
|
||||
|
||||
# 网络初始化
|
||||
self.firm_network = nx.MultiDiGraph() # 企业之间的有向多重图。
|
||||
self.firm_prod_network = nx.MultiDiGraph() # 企业与产品关系的有向多重图。
|
||||
self.product_network = nx.MultiDiGraph() # 产品之间的有向多重图。
|
||||
self.G_FirmProd = nx.MultiDiGraph() # 初始化 企业-产品 关系的图
|
||||
self.G_Firm = nx.MultiDiGraph() # 使用 NetworkX 的有向多重图
|
||||
|
||||
# BOM(物料清单)图
|
||||
self.G_bom = nx.adjacency_graph(json.loads(params['g_bom'])) # 表示 BOM 结构的图。
|
||||
self.g_bom = nx.adjacency_graph(json.loads(params['g_bom'])) # 表示 BOM 结构的图。
|
||||
|
||||
# 随机数生成器
|
||||
self.nprandom = np.random.default_rng(params['seed']) # 基于固定种子的随机数生成器。
|
||||
|
@ -87,12 +86,22 @@ class MyModel(Model):
|
|||
self.company_agents = [] # 初始化公司代理列表
|
||||
|
||||
# 初始化模型的网络和代理
|
||||
# 检查缓存是否存在
|
||||
cache_file = "firm_network.pkl"
|
||||
if os.path.exists(cache_file):
|
||||
# 从缓存加载 firm_network
|
||||
with open(cache_file, 'rb') as f:
|
||||
self.firm_network = pickle.load(f)
|
||||
print("Loaded firm network from cache.")
|
||||
else:
|
||||
# 执行完整的初始化流程
|
||||
self.initialize_product_network(params)
|
||||
self.initialize_firm_network()
|
||||
self.build_firm_prod_labels_dict()
|
||||
self.initialize_firm_product_network()
|
||||
self.add_edges_to_firm_network()
|
||||
self.connect_unconnected_nodes()
|
||||
self.initialize_product_network(params) # 初始化产品网络。
|
||||
self.initialize_firm_network() # 初始化企业网络。
|
||||
self.build_firm_prod_labels_dict() # 构建企业与产品的映射关系字典
|
||||
self.initialize_firm_product_network() # 初始化企业与产品的网络。
|
||||
self.add_edges_to_firm_network() # 添加企业之间的边。
|
||||
self.connect_unconnected_nodes() # 连接未连接的节点。
|
||||
self.resource_integration()
|
||||
self.j_comp_consumed_produced()
|
||||
self.initialize_agents() # 初始化代理。
|
||||
|
@ -141,88 +150,65 @@ class MyModel(Model):
|
|||
初始化企业网络,处理一个 Code 映射到多个 Index 的情况,并缓存所有相关属性。
|
||||
"""
|
||||
|
||||
cache_file = "firm_network_cache.pkl"
|
||||
# 加载企业数据
|
||||
firm_data = pd.read_csv("input_data/input_firm_data/firm_amended.csv", dtype={'Code': str})
|
||||
firm_data['Code'] = firm_data['Code'].str.replace('.0', '', regex=False)
|
||||
|
||||
# 检查是否存在缓存
|
||||
if os.path.exists(cache_file):
|
||||
with open(cache_file, 'rb') as f:
|
||||
cached_data = pickle.load(f)
|
||||
# 加载企业与产品关系数据
|
||||
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv", dtype={'Firm_Code': str})
|
||||
bom_nodes = pd.read_csv("input_data/input_product_data/BomNodes.csv")
|
||||
|
||||
# 加载缓存的属性
|
||||
self.G_Firm = cached_data['G_Firm']
|
||||
self.firm_product_cache = cached_data['firm_product_cache']
|
||||
self.firm_relationship_cache = cached_data['firm_relationship_cache']
|
||||
print("Loaded firm network and related data from cache.")
|
||||
return
|
||||
# 构建 Code -> [Index] 的多值映射
|
||||
code_to_indices = bom_nodes.groupby('Code')['Index'].apply(list).to_dict()
|
||||
|
||||
# 如果没有缓存,则从头初始化网络
|
||||
try:
|
||||
# 加载企业数据
|
||||
firm_data = pd.read_csv("input_data/input_firm_data/firm_amended.csv", dtype={'Code': str})
|
||||
firm_data['Code'] = firm_data['Code'].str.replace('.0', '', regex=False)
|
||||
# 将 Product_Code 转换为 Product_Indices
|
||||
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Code'].map(code_to_indices)
|
||||
|
||||
# 创建企业网络图
|
||||
self.G_Firm = nx.Graph()
|
||||
self.G_Firm.add_nodes_from(firm_data['Code'])
|
||||
# 检查并处理未映射的 Product_Code
|
||||
unmapped_products = firm_industry_relation[firm_industry_relation['Product_Indices'].isna()]
|
||||
if not unmapped_products.empty:
|
||||
print("Warning: The following Product_Code values could not be mapped to Index:")
|
||||
print(unmapped_products[['Firm_Code', 'Product_Code']])
|
||||
|
||||
# 设置节点属性
|
||||
firm_attributes = firm_data.set_index('Code').to_dict('index')
|
||||
nx.set_node_attributes(self.G_Firm, firm_attributes)
|
||||
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Indices'].apply(
|
||||
lambda x: x if isinstance(x, list) else []
|
||||
)
|
||||
|
||||
print(f"Initialized G_Firm with {len(self.G_Firm.nodes)} nodes.")
|
||||
# 按 Firm_Code 分组生成企业的 Product_Code 和 Product_Indices 映射
|
||||
firm_product = (
|
||||
firm_industry_relation.groupby('Firm_Code')['Product_Code'].apply(list)
|
||||
)
|
||||
firm_product_indices = (
|
||||
firm_industry_relation.groupby('Firm_Code')['Product_Indices']
|
||||
.apply(lambda indices: [idx for sublist in indices for idx in sublist])
|
||||
)
|
||||
|
||||
# 加载企业与产品关系数据
|
||||
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv", dtype={'Firm_Code': str})
|
||||
bom_nodes = pd.read_csv("input_data/input_product_data/BomNodes.csv")
|
||||
# 设置企业属性并添加到网络中
|
||||
firm_attributes = firm_data.copy()
|
||||
firm_attributes['Product_Indices'] = firm_attributes['Code'].map(firm_product)
|
||||
firm_attributes['Product_Code'] = firm_attributes['Code'].map(firm_product_indices)
|
||||
firm_attributes.set_index('Code', inplace=True)
|
||||
|
||||
# 构建 Code -> [Index] 的多值映射
|
||||
code_to_indices = bom_nodes.groupby('Code')['Index'].apply(list).to_dict()
|
||||
self.firm_network.add_nodes_from(firm_data['Code'])
|
||||
|
||||
# 将 Product_Code 转换为 Product_Indices
|
||||
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Code'].map(code_to_indices)
|
||||
# 为企业节点分配属性
|
||||
firm_labels_dict = {code: firm_attributes.loc[code].to_dict() for code in self.firm_network.nodes}
|
||||
nx.set_node_attributes(self.firm_network, firm_labels_dict)
|
||||
|
||||
# 检查并处理未映射的 Product_Code
|
||||
unmapped_products = firm_industry_relation[firm_industry_relation['Product_Indices'].isna()]
|
||||
if not unmapped_products.empty:
|
||||
print("Warning: The following Product_Code values could not be mapped to Index:")
|
||||
print(unmapped_products[['Firm_Code', 'Product_Code']])
|
||||
# 构建企业-产品映射缓存
|
||||
self.firm_product_cache = firm_product_indices.to_dict()
|
||||
|
||||
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Indices'].apply(
|
||||
lambda x: x if isinstance(x, list) else []
|
||||
)
|
||||
|
||||
# 构建企业-产品映射缓存
|
||||
self.firm_product_cache = (
|
||||
firm_industry_relation.groupby('Firm_Code')['Product_Indices']
|
||||
.apply(lambda indices: [idx for sublist in indices for idx in sublist]) # 展平嵌套列表
|
||||
.to_dict()
|
||||
)
|
||||
print(f"Built firm_product_cache with {len(self.firm_product_cache)} entries.")
|
||||
|
||||
# 构建企业关系缓存
|
||||
self.firm_relationship_cache = {
|
||||
firm: self.compute_firm_relationship(firm, self.firm_product_cache)
|
||||
for firm in self.firm_product_cache
|
||||
}
|
||||
print(f"Built firm_relationship_cache with {len(self.firm_relationship_cache)} entries.")
|
||||
|
||||
# 保存所有关键属性到缓存
|
||||
cached_data = {
|
||||
'G_Firm': self.G_Firm,
|
||||
'firm_product_cache': self.firm_product_cache,
|
||||
'firm_relationship_cache': self.firm_relationship_cache
|
||||
}
|
||||
with open(cache_file, 'wb') as f:
|
||||
pickle.dump(cached_data, f)
|
||||
print("Saved firm network and related data to cache.")
|
||||
except Exception as e:
|
||||
print(f"Error during network initialization: {e}")
|
||||
# 构建企业关系缓存
|
||||
self.firm_relationship_cache = {
|
||||
firm: self.compute_firm_relationship(firm, self.firm_product_cache)
|
||||
for firm in self.firm_product_cache
|
||||
}
|
||||
|
||||
def compute_firm_relationship(self, firm, firm_product_cache):
|
||||
"""计算单个企业的供应链关系"""
|
||||
lst_pred_product_code = []
|
||||
for product_code in firm_product_cache[firm]:
|
||||
lst_pred_product_code += list(self.G_bom.predecessors(product_code))
|
||||
lst_pred_product_code += list(self.g_bom.predecessors(product_code))
|
||||
return list(set(lst_pred_product_code)) # 返回唯一值列表
|
||||
|
||||
def build_firm_prod_labels_dict(self):
|
||||
|
@ -246,54 +232,22 @@ class MyModel(Model):
|
|||
3. 将产品代码与索引进行映射,并为网络节点分配属性。
|
||||
4. 缓存网络和相关数据以加速后续运行。
|
||||
"""
|
||||
# 加载企业-行业关系数据
|
||||
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv")
|
||||
firm_industry_relation['Firm_Code'] = firm_industry_relation['Firm_Code'].astype(str)
|
||||
firm_industry_relation['Product_Code'] = firm_industry_relation['Product_Code'].apply(lambda x: [x])
|
||||
|
||||
cache_file = "firm_product_network_cache.pkl"
|
||||
# 映射产品代码到索引
|
||||
firm_industry_relation['Product_Code'] = firm_industry_relation['Product_Code'].apply(
|
||||
lambda codes: [idx for code in codes for idx in self.id_code.get(str(code), [])]
|
||||
)
|
||||
|
||||
# 检查缓存文件是否存在
|
||||
if os.path.exists(cache_file):
|
||||
try:
|
||||
with open(cache_file, 'rb') as f:
|
||||
cached_data = pickle.load(f)
|
||||
self.G_FirmProd = cached_data['G_FirmProd']
|
||||
print("Loaded firm-product network from cache.")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error loading cache: {e}. Reinitializing network.")
|
||||
|
||||
try:
|
||||
# 加载企业-行业关系数据
|
||||
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv")
|
||||
firm_industry_relation['Firm_Code'] = firm_industry_relation['Firm_Code'].astype(str)
|
||||
firm_industry_relation['Product_Code'] = firm_industry_relation['Product_Code'].apply(lambda x: [x])
|
||||
|
||||
# 创建企业-产品网络图
|
||||
self.G_FirmProd.add_nodes_from(firm_industry_relation.index)
|
||||
|
||||
# 遍历数据行,将产品代码映射到索引,并更新关系表
|
||||
for index, row in firm_industry_relation.iterrows():
|
||||
id_index_list = []
|
||||
for product_code in row['Product_Code']:
|
||||
if str(product_code) in self.id_code:
|
||||
id_index_list.extend(self.id_code[str(product_code)])
|
||||
firm_industry_relation.at[index, 'Product_Code'] = id_index_list
|
||||
|
||||
# 为每个节点分配属性
|
||||
firm_prod_labels_dict = {code: firm_industry_relation.loc[code].to_dict() for code in
|
||||
firm_industry_relation.index}
|
||||
nx.set_node_attributes(self.G_FirmProd, firm_prod_labels_dict)
|
||||
|
||||
print(f"Initialized G_FirmProd with {len(self.G_FirmProd.nodes)} nodes.")
|
||||
|
||||
# 缓存网络和数据
|
||||
cached_data = {'G_FirmProd': self.G_FirmProd}
|
||||
with open(cache_file, 'wb') as f:
|
||||
pickle.dump(cached_data, f)
|
||||
print("Saved firm-product network to cache.")
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print(f"Error: {e}. File not found.")
|
||||
except Exception as e:
|
||||
print(f"Error initializing firm-product network: {e}")
|
||||
# 创建企业-产品网络图,同时附带属性
|
||||
nodes_with_attributes = [
|
||||
(index, firm_industry_relation.loc[index].to_dict())
|
||||
for index in firm_industry_relation.index
|
||||
]
|
||||
self.firm_prod_network.add_nodes_from(nodes_with_attributes)
|
||||
|
||||
def compute_firm_supply_chain(self, firm_industry_relation, g_bom):
|
||||
"""
|
||||
|
@ -311,18 +265,6 @@ class MyModel(Model):
|
|||
return supply_chain_cache
|
||||
|
||||
def add_edges_to_firm_network(self):
|
||||
"""利用缓存加速企业网络边的添加"""
|
||||
cache_file = "G_Firm_add_edges.pkl"
|
||||
# 检查是否存在缓存
|
||||
if os.path.exists(cache_file):
|
||||
with open(cache_file, 'rb') as f:
|
||||
cached_data = pickle.load(f)
|
||||
|
||||
# 加载缓存的属性
|
||||
self.G_Firm = cached_data
|
||||
print("Loaded G_Firm_add_edges cache.")
|
||||
return
|
||||
|
||||
for firm in self.firm_relationship_cache:
|
||||
lst_pred_product_code = self.firm_relationship_cache[firm]
|
||||
for pred_product_code in lst_pred_product_code:
|
||||
|
@ -330,15 +272,10 @@ class MyModel(Model):
|
|||
f for f, products in self.firm_product_cache.items()
|
||||
if pred_product_code in products
|
||||
]
|
||||
# 使用缓存,避免重复查询
|
||||
lst_choose_firm = self.select_firms(lst_pred_firm)
|
||||
# 添加边
|
||||
edges = [(pred_firm, firm, {'Product': pred_product_code}) for pred_firm in lst_choose_firm]
|
||||
self.G_Firm.add_edges_from(edges)
|
||||
cached_data = self.G_Firm
|
||||
with open(cache_file, 'wb') as f:
|
||||
pickle.dump(cached_data, f)
|
||||
print("Saved G_Firm_add_edges to cache.")
|
||||
self.firm_network.add_edges_from(edges)
|
||||
|
||||
def select_firms(self, lst_pred_firm):
|
||||
"""
|
||||
|
@ -353,9 +290,9 @@ class MyModel(Model):
|
|||
valid_firms = []
|
||||
lst_pred_firm_size = []
|
||||
for pred_firm in lst_pred_firm:
|
||||
if pred_firm in self.G_Firm.nodes and 'Revenue_Log' in self.G_Firm.nodes[pred_firm]:
|
||||
if pred_firm in self.firm_network.nodes and 'Revenue_Log' in self.firm_network.nodes[pred_firm]:
|
||||
valid_firms.append(pred_firm)
|
||||
lst_pred_firm_size.append(self.G_Firm.nodes[pred_firm]['Revenue_Log'])
|
||||
lst_pred_firm_size.append(self.firm_network.nodes[pred_firm]['Revenue_Log'])
|
||||
|
||||
# 如果未启用企业规模加权,随机选择
|
||||
if not self.is_prf_size:
|
||||
|
@ -379,20 +316,8 @@ class MyModel(Model):
|
|||
|
||||
def add_edges_to_firm_product_network(self, node, pred_product_code, lst_choose_firm):
|
||||
""" Helper function to add edges to the firm-product network """
|
||||
|
||||
cache_file = "G_FirmProd_cache.pkl"
|
||||
# 检查是否存在缓存
|
||||
if os.path.exists(cache_file):
|
||||
with open(cache_file, 'rb') as f:
|
||||
cached_data = pickle.load(f)
|
||||
|
||||
# 加载缓存的属性
|
||||
self.G_FirmProd = cached_data
|
||||
print("Loaded add_edges_to_firm from cache.")
|
||||
return
|
||||
|
||||
set_node_prod_code = set(self.G_Firm.nodes[node]['Product_Code'])
|
||||
set_pred_succ_code = set(self.G_bom.successors(pred_product_code))
|
||||
set_node_prod_code = set(self.firm_network.nodes[node]['Product_Code'])
|
||||
set_pred_succ_code = set(self.g_bom.successors(pred_product_code))
|
||||
lst_use_pred_prod_code = list(set_node_prod_code & set_pred_succ_code)
|
||||
|
||||
if len(lst_use_pred_prod_code) == 0:
|
||||
|
@ -400,7 +325,7 @@ class MyModel(Model):
|
|||
|
||||
pred_node_list = []
|
||||
for pred_firm in lst_choose_firm:
|
||||
for n, v in self.G_FirmProd.nodes(data=True):
|
||||
for n, v in self.firm_prod_network.nodes(data=True):
|
||||
for v1 in v['Product_Code']:
|
||||
if v1 == pred_product_code and v['Firm_Code'] == pred_firm:
|
||||
pred_node_list.append(n)
|
||||
|
@ -410,7 +335,7 @@ class MyModel(Model):
|
|||
pred_node = -1
|
||||
current_node_list = []
|
||||
for use_pred_prod_code in lst_use_pred_prod_code:
|
||||
for n, v in self.G_FirmProd.nodes(data=True):
|
||||
for n, v in self.firm_prod_network.nodes(data=True):
|
||||
for v1 in v['Product_Code']:
|
||||
if v1 == use_pred_prod_code and v['Firm_Code'] == node:
|
||||
current_node_list.append(n)
|
||||
|
@ -419,13 +344,7 @@ class MyModel(Model):
|
|||
else:
|
||||
current_node = -1
|
||||
if current_node != -1 and pred_node != -1:
|
||||
self.G_FirmProd.add_edge(pred_node, current_node)
|
||||
# 保存所有关键属性到缓存
|
||||
|
||||
cached_data = self.G_FirmProd
|
||||
with open(cache_file, 'wb') as f:
|
||||
pickle.dump(cached_data, f)
|
||||
print("Saved G_FirmProd to cache.")
|
||||
self.firm_prod_network.add_edge(pred_node, current_node)
|
||||
|
||||
def connect_unconnected_nodes(self):
|
||||
"""
|
||||
|
@ -435,34 +354,22 @@ class MyModel(Model):
|
|||
- 为未连接节点添加边,连接到可能的下游企业。
|
||||
- 同时更新 G_FirmProd 网络,反映企业与产品的关系。
|
||||
"""
|
||||
cache_file = "connect_unconnected_nodes_cache.pkl"
|
||||
|
||||
# 检查是否存在缓存
|
||||
if os.path.exists(cache_file):
|
||||
with open(cache_file, 'rb') as f:
|
||||
cached_data = pickle.load(f)
|
||||
|
||||
# 加载缓存的属性
|
||||
self.G_Firm = cached_data['firm_network']
|
||||
self.firm_product_cache = cached_data['firm_prod_network']
|
||||
print("Loaded G_Firm and firm_product_cache from cache.")
|
||||
return
|
||||
|
||||
for node in nx.nodes(self.G_Firm):
|
||||
for node in nx.nodes(self.firm_network):
|
||||
# 如果节点没有任何连接,则处理该节点
|
||||
if self.G_Firm.degree(node) == 0:
|
||||
if self.firm_network.degree(node) == 0:
|
||||
# 获取当前节点的产品列表
|
||||
product_codes = self.G_Firm.nodes[node].get('Product_Code', [])
|
||||
product_codes = self.firm_network.nodes[node].get('Product_Code', [])
|
||||
for product_code in product_codes:
|
||||
# 查找与当前产品相关的 FirmProd 节点
|
||||
current_node_list = [
|
||||
n for n, v in self.G_FirmProd.nodes(data=True)
|
||||
n for n, v in self.firm_prod_network.nodes(data=True)
|
||||
if v['Firm_Code'] == node and product_code in v['Product_Code']
|
||||
]
|
||||
current_node = current_node_list[0] if current_node_list else -1
|
||||
|
||||
# 查找当前产品的所有下游产品代码
|
||||
succ_product_codes = list(self.G_bom.successors(product_code))
|
||||
succ_product_codes = list(self.g_bom.successors(product_code))
|
||||
for succ_product_code in succ_product_codes:
|
||||
# 查找生产下游产品的企业
|
||||
succ_firms = [
|
||||
|
@ -479,7 +386,7 @@ class MyModel(Model):
|
|||
if self.is_prf_size:
|
||||
# 基于企业规模选择供应商
|
||||
succ_firm_sizes = [
|
||||
self.G_Firm.nodes[succ_firm].get('Revenue_Log', 0)
|
||||
self.firm_network.nodes[succ_firm].get('Revenue_Log', 0)
|
||||
for succ_firm in succ_firms
|
||||
]
|
||||
if sum(succ_firm_sizes) > 0:
|
||||
|
@ -494,29 +401,24 @@ class MyModel(Model):
|
|||
|
||||
# 添加边到 G_Firm 图
|
||||
edges = [(node, firm, {'Product': product_code}) for firm in selected_firms]
|
||||
self.G_Firm.add_edges_from(edges)
|
||||
self.firm_network.add_edges_from(edges)
|
||||
|
||||
# 更新 G_FirmProd 网络
|
||||
for succ_firm in selected_firms:
|
||||
succ_node_list = [
|
||||
n for n, v in self.G_FirmProd.nodes(data=True)
|
||||
n for n, v in self.firm_prod_network.nodes(data=True)
|
||||
if v['Firm_Code'] == succ_firm and succ_product_code in v['Product_Code']
|
||||
]
|
||||
succ_node = succ_node_list[0] if succ_node_list else -1
|
||||
if current_node != -1 and succ_node != -1:
|
||||
self.G_FirmProd.add_edge(current_node, succ_node)
|
||||
self.firm_prod_network.add_edge(current_node, succ_node)
|
||||
|
||||
# 保存网络数据到样本
|
||||
self.firm_network = self.G_Firm # 使用 networkx 图对象表示的企业网络
|
||||
self.firm_prod_network = self.G_FirmProd # 使用 networkx 图对象表示的企业与产品关系网络
|
||||
|
||||
cached_data = {
|
||||
'firm_network': self.firm_network,
|
||||
'firm_prod_network': self.firm_prod_network,
|
||||
}
|
||||
# 保存构建完成的 firm_network 到缓存
|
||||
cache_file = "firm_network.pkl"
|
||||
os.makedirs("cache", exist_ok=True)
|
||||
with open(cache_file, 'wb') as f:
|
||||
pickle.dump(cached_data, f)
|
||||
print("Saved firm network and related data to cache.")
|
||||
pickle.dump(self.firm_network, f)
|
||||
# print("Firm network has been saved to cache.")
|
||||
|
||||
def initialize_agents(self):
|
||||
"""
|
||||
|
@ -545,18 +447,18 @@ class MyModel(Model):
|
|||
]
|
||||
|
||||
# 获取企业的需求数量和生产输出
|
||||
demand_quantity = self.data_materials.loc[self.data_materials['Firm_Code'] == ag_node]
|
||||
production_output = self.data_produced.loc[self.data_produced['Firm_Code'] == ag_node]
|
||||
demand_quantity = self.data_materials.loc[self.data_materials['Firm_Code'] == int(ag_node)]
|
||||
production_output = self.data_produced.loc[self.data_produced['Firm_Code'] == int(ag_node)]
|
||||
|
||||
# 获取企业的资源信息
|
||||
# 获取企业的资源信息,同时处理 R、P、C 的情况
|
||||
try:
|
||||
R = self.firm_resource_R.loc[int(ag_node)]
|
||||
P = self.firm_resource_P.loc[int(ag_node)]
|
||||
P = self.firm_resource_P.get(int(ag_node))
|
||||
C = self.firm_resource_C.loc[int(ag_node)]
|
||||
except KeyError:
|
||||
# 如果资源数据缺失,提供默认值
|
||||
R, P, C = [], [], []
|
||||
|
||||
R, P, C = [], {}, [] # 如果任何资源不存在,返回空列表
|
||||
# 在模型初始化时,构建 unique_id -> agent 的快速映射字典
|
||||
self.agent_map = {agent.unique_id: agent for agent in self.company_agents}
|
||||
# 创建企业代理
|
||||
firm_agent = FirmAgent(
|
||||
unique_id=ag_node,
|
||||
|
@ -580,14 +482,25 @@ class MyModel(Model):
|
|||
- 更新公司与产品的生产状态为干扰状态。
|
||||
"""
|
||||
# 构建公司与受干扰产品的映射字典
|
||||
disruption_mapping = {
|
||||
firm: [
|
||||
product for product in self.product_agents
|
||||
if product.unique_id in lst_product
|
||||
disruption_mapping = {}
|
||||
|
||||
for firm_code, lst_product_indices in self.dct_lst_init_disrupt_firm_prod.items():
|
||||
# 查找企业对象
|
||||
firm = next((f for f in self.company_agents if f.unique_id == firm_code), None)
|
||||
if not firm:
|
||||
print(f"Warning: Firm {firm_code} not found. Skipping.")
|
||||
continue
|
||||
|
||||
# 查找有效的产品代理
|
||||
valid_products = [
|
||||
product for product in self.product_agents if product.unique_id in lst_product_indices
|
||||
]
|
||||
for firm_code, lst_product in self.dct_lst_init_disrupt_firm_prod.items()
|
||||
if (firm := next((f for f in self.company_agents if f.unique_id == firm_code), None))
|
||||
}
|
||||
if not valid_products:
|
||||
print(f"Warning: No valid products found for Firm {firm_code}. Skipping.")
|
||||
continue
|
||||
|
||||
# 更新映射
|
||||
disruption_mapping[firm] = valid_products
|
||||
|
||||
# 更新干扰字典
|
||||
self.dct_lst_init_disrupt_firm_prod = disruption_mapping
|
||||
|
@ -595,10 +508,13 @@ class MyModel(Model):
|
|||
# 设置初始干扰状态
|
||||
for firm, disrupted_products in disruption_mapping.items():
|
||||
for product in disrupted_products:
|
||||
# 确保产品在公司的生产状态中
|
||||
# 检查产品是否在企业的生产状态中
|
||||
if product not in firm.dct_prod_up_prod_stat:
|
||||
raise ValueError(
|
||||
f"Product {product.unique_id} not found in firm {firm.unique_id}'s production status.")
|
||||
print(
|
||||
f"Warning: Product {product.unique_id} not found in firm "
|
||||
f"{firm.unique_id}'s production status. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
# 更新产品状态为干扰状态,并记录干扰时间
|
||||
firm.dct_prod_up_prod_stat[product]['p_stat'].append(('D', self.t))
|
||||
|
@ -618,50 +534,38 @@ class MyModel(Model):
|
|||
- 合并设备数据与设备残值数据。
|
||||
- 按企业分组生成资源列表。
|
||||
"""
|
||||
try:
|
||||
# 加载企业的材料、设备和产品数据
|
||||
data_R = pd.read_csv("input_data/input_firm_data/firms_materials.csv")
|
||||
data_C = pd.read_csv("input_data/input_firm_data/firms_devices.csv")
|
||||
data_P = pd.read_csv("input_data/input_firm_data/firms_products.csv")
|
||||
# 加载企业的材料、设备和产品数据
|
||||
data_R = pd.read_csv("input_data/input_firm_data/firms_materials.csv")
|
||||
data_C = pd.read_csv("input_data/input_firm_data/firms_devices.csv")
|
||||
data_P = pd.read_csv("input_data/input_firm_data/firms_products.csv")
|
||||
|
||||
# 加载设备残值数据,并合并到设备数据中
|
||||
device_salvage_values = pd.read_csv('input_data/device_salvage_values.csv')
|
||||
self.device_salvage_values = device_salvage_values
|
||||
# 加载设备残值数据,并合并到设备数据中
|
||||
device_salvage_values = pd.read_csv('input_data/device_salvage_values.csv')
|
||||
self.device_salvage_values = device_salvage_values
|
||||
|
||||
# 合并设备数据和设备残值
|
||||
data_merged_C = pd.merge(data_C, device_salvage_values, on='设备id', how='left')
|
||||
# 合并设备数据和设备残值
|
||||
data_merged_C = pd.merge(data_C, device_salvage_values, on='设备id', how='left')
|
||||
|
||||
# 按企业分组并生成资源列表
|
||||
firm_resource_R = (
|
||||
data_R.groupby('Firm_Code')[['材料id', '材料数量']]
|
||||
.apply(lambda x: x.values.tolist())
|
||||
.to_dict() # 转换为字典格式,便于快速查询
|
||||
)
|
||||
# 按企业分组并生成资源列表
|
||||
firm_resource_R = (
|
||||
data_R.groupby('Firm_Code')[['材料id', '材料数量']]
|
||||
.apply(lambda x: x.values.tolist())
|
||||
)
|
||||
|
||||
firm_resource_C = (
|
||||
data_merged_C.groupby('Firm_Code')[['设备id', '设备数量', '设备残值']]
|
||||
.apply(lambda x: x.values.tolist())
|
||||
.to_dict()
|
||||
)
|
||||
firm_resource_C = (
|
||||
data_merged_C.groupby('Firm_Code')[['设备id', '设备数量', '设备残值']]
|
||||
.apply(lambda x: x.values.tolist())
|
||||
)
|
||||
|
||||
firm_resource_P = (
|
||||
data_P.groupby('Firm_Code')[['产品id', '产品数量']]
|
||||
.apply(lambda x: x.values.tolist())
|
||||
.to_dict()
|
||||
)
|
||||
firm_resource_P = (
|
||||
data_P.groupby('Firm_Code')[['产品id', '产品数量']]
|
||||
.apply(lambda x: x.values.tolist())
|
||||
)
|
||||
|
||||
# 将结果存储到模型中
|
||||
self.firm_resource_R = firm_resource_R
|
||||
self.firm_resource_C = firm_resource_C
|
||||
self.firm_resource_P = firm_resource_P
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print(f"Error: Missing input file - {e.filename}")
|
||||
self.firm_resource_R, self.firm_resource_C, self.firm_resource_P = {}, {}, {}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during resource integration: {e}")
|
||||
self.firm_resource_R, self.firm_resource_C, self.firm_resource_P = {}, {}, {}
|
||||
# 将结果存储到模型中
|
||||
self.firm_resource_R = firm_resource_R
|
||||
self.firm_resource_C = firm_resource_C
|
||||
self.firm_resource_P = firm_resource_P
|
||||
|
||||
def j_comp_consumed_produced(self):
|
||||
"""
|
||||
|
@ -720,21 +624,23 @@ class MyModel(Model):
|
|||
3. 判断资源和设备是否需要采购,并处理采购。
|
||||
4. 资源消耗和产品生产。
|
||||
"""
|
||||
# 1. 移除客户边并中断客户上游产品
|
||||
self._remove_disrupted_edges()
|
||||
self._disrupt_upstream_products()
|
||||
while self.t < self.int_stop_ts: # 使用循环控制时间步
|
||||
# 1. 移除客户边并中断客户上游产品
|
||||
self._remove_disrupted_edges()
|
||||
self._disrupt_upstream_products()
|
||||
|
||||
# 2. 尝试寻找替代供应链
|
||||
self._trial_process()
|
||||
# 2. 尝试寻找替代供应链
|
||||
self._trial_process()
|
||||
|
||||
# 3. 判断是否需要采购资源和设备
|
||||
self._handle_resource_and_machinery_purchase()
|
||||
# 3. 判断是否需要采购资源和设备
|
||||
self._handle_material_purchase()
|
||||
self._handle_machinery_purchase()
|
||||
|
||||
# 4. 资源消耗和产品生产
|
||||
self._consume_resources_and_produce()
|
||||
# 4. 资源消耗和产品生产
|
||||
self._consume_resources_and_produce()
|
||||
|
||||
# 增加时间步
|
||||
self.t += 1
|
||||
# 增加时间步
|
||||
self.t += 1
|
||||
|
||||
# 子方法定义
|
||||
def _remove_disrupted_edges(self):
|
||||
|
@ -784,76 +690,153 @@ class MyModel(Model):
|
|||
for firm in self.company_agents:
|
||||
firm.clean_before_trial()
|
||||
|
||||
def _handle_resource_and_machinery_purchase(self):
|
||||
"""处理资源和设备的采购。"""
|
||||
for firm in self.company_agents:
|
||||
# 判断资源需求
|
||||
for material_id, material_quantity in firm.R:
|
||||
if material_quantity <= firm.s_r:
|
||||
required_quantity = firm.S_r - material_quantity
|
||||
firm.request_material_purchase(material_id, required_quantity)
|
||||
def _handle_material_purchase(self):
|
||||
"""
|
||||
判断并处理资源的采购。
|
||||
"""
|
||||
# 存储需要采购资源的企业及其需求
|
||||
purchase_material_firms = {}
|
||||
|
||||
# 判断设备需求
|
||||
for device_id, device_quantity, device_salvage in firm.C:
|
||||
device_salvage -= firm.x
|
||||
if device_salvage <= 0: # 如果设备残值小于等于 0
|
||||
device_quantity -= 1
|
||||
firm.request_device_purchase(device_id, 1)
|
||||
# 遍历所有企业,检查资源需求
|
||||
for firm in self.company_agents:
|
||||
if not firm.R: # 跳过没有资源的企业
|
||||
continue
|
||||
|
||||
# 遍历资源列表,检查哪些资源需要补货
|
||||
for resource_id, resource_quantity in firm.R:
|
||||
if resource_quantity <= firm.s_r: # 如果资源低于阈值,记录需求
|
||||
required_quantity = firm.S_r - resource_quantity
|
||||
if firm not in purchase_material_firms:
|
||||
purchase_material_firms[firm] = []
|
||||
purchase_material_firms[firm].append((resource_id, required_quantity))
|
||||
|
||||
# 寻找供应商并处理补货
|
||||
for firm, material_requests in purchase_material_firms.items():
|
||||
for resource_id, required_quantity in material_requests:
|
||||
# 寻找供应商
|
||||
supplier = firm.seek_material_supply(resource_id)
|
||||
if supplier != -1: # 如果找到供应商
|
||||
# 供应商处理资源请求
|
||||
supplier.handle_material_request([resource_id, required_quantity])
|
||||
# 更新当前企业的资源数量
|
||||
for resource in firm.R:
|
||||
if resource[0] == resource_id:
|
||||
resource[1] = firm.S_r
|
||||
|
||||
def _handle_machinery_purchase(self):
|
||||
"""
|
||||
判断并处理设备的采购。
|
||||
"""
|
||||
# 存储需要采购设备的企业及其需求
|
||||
purchase_machinery_firms = {}
|
||||
|
||||
# 遍历所有企业,检查设备需求
|
||||
for firm in self.company_agents:
|
||||
if not firm.C: # 跳过没有设备的企业
|
||||
continue
|
||||
|
||||
# 检查设备残值,记录需要补充的设备
|
||||
for equipment in firm.C:
|
||||
equipment_id, equipment_quantity, equipment_salvage = equipment
|
||||
equipment_salvage -= firm.x # 减少设备残值
|
||||
if equipment_salvage <= 0: # 如果残值小于等于 0
|
||||
equipment_quantity -= 1
|
||||
required_quantity = 1 # 需要补充的设备数量
|
||||
if firm not in purchase_machinery_firms:
|
||||
purchase_machinery_firms[firm] = []
|
||||
purchase_machinery_firms[firm].append((equipment_id, required_quantity))
|
||||
|
||||
# 寻找供应商并处理设备补充
|
||||
for firm, machinery_requests in purchase_machinery_firms.items():
|
||||
for equipment_id, required_quantity in machinery_requests:
|
||||
# 寻找供应商
|
||||
supplier = firm.seek_machinery_supply(equipment_id)
|
||||
if supplier != -1: # 如果找到供应商
|
||||
# 供应商处理设备请求
|
||||
supplier.handle_machinery_request([equipment_id, required_quantity])
|
||||
# 恢复企业的设备数量和残值
|
||||
for equipment, initial_equipment in zip(firm.C, firm.C0):
|
||||
if equipment[0] == equipment_id:
|
||||
equipment[1] = initial_equipment[1] # 恢复数量
|
||||
equipment[2] = initial_equipment[2] # 恢复残值
|
||||
|
||||
def _consume_resources_and_produce(self):
|
||||
"""消耗资源并生产产品。"""
|
||||
"""
|
||||
消耗资源并生产产品。
|
||||
"""
|
||||
k = 0.6 # 资源消耗比例
|
||||
production_increase_ratio = 1.6 # 产品生产比例
|
||||
|
||||
# 遍历每个企业
|
||||
for firm in self.company_agents:
|
||||
# 计算消耗量
|
||||
consumed_resources = {}
|
||||
for industry in firm.indus_i:
|
||||
for product_id, product_quantity in firm.P.items():
|
||||
if product_id == industry.unique_id:
|
||||
consumed_resources[industry] = product_quantity * k
|
||||
# 计算资源消耗
|
||||
consumed_resources = self._calculate_consumed_resources(firm, k)
|
||||
|
||||
# 消耗资源
|
||||
for resource_id, resource_quantity in firm.R.items():
|
||||
for industry, consumed_quantity in consumed_resources.items():
|
||||
if resource_id in industry.resource_ids:
|
||||
firm.R[resource_id] -= consumed_quantity
|
||||
self._consume_resources(firm, consumed_resources)
|
||||
|
||||
# 生产产品
|
||||
for product_id, product_quantity in firm.P.items():
|
||||
firm.P[product_id] = product_quantity * 1.6
|
||||
self._produce_products(firm, production_increase_ratio)
|
||||
|
||||
# 刷新资源和设备状态
|
||||
firm.refresh_R()
|
||||
firm.refresh_C()
|
||||
firm.refresh_P()
|
||||
|
||||
def _calculate_consumed_resources(self, firm, k):
|
||||
"""
|
||||
计算企业的资源消耗量。
|
||||
"""
|
||||
consumed_resources = {}
|
||||
for industry in firm.indus_i:
|
||||
consumed_quantity = sum(
|
||||
product[1] * k
|
||||
for product in firm.P
|
||||
if product[0] == industry.unique_id
|
||||
)
|
||||
consumed_resources[industry.unique_id] = consumed_quantity
|
||||
return consumed_resources
|
||||
|
||||
def _consume_resources(self, firm, consumed_resources):
|
||||
"""
|
||||
消耗企业的资源。
|
||||
"""
|
||||
for resource in firm.R:
|
||||
resource_id, resource_quantity = resource[0], resource[1]
|
||||
if resource_id in consumed_resources:
|
||||
resource[1] = max(0, resource_quantity - consumed_resources[resource_id])
|
||||
|
||||
def _produce_products(self, firm, production_increase_ratio):
|
||||
"""
|
||||
生产企业的产品。
|
||||
"""
|
||||
for product in firm.P:
|
||||
product[1] *= production_increase_ratio
|
||||
|
||||
def end(self):
|
||||
"""
|
||||
结束模型运行并保存结果。
|
||||
功能:
|
||||
- 检查结果是否已存在,避免重复写入。
|
||||
- 保存企业和产品的生产状态到数据库。
|
||||
- 更新样本的状态为完成,并记录停止时间和计算机名称。
|
||||
- 如果当前样本的结果未保存,则保存所有生产状态为非正常状态的结果。
|
||||
- 更新样本状态为完成,并记录相关信息。
|
||||
"""
|
||||
# 检查当前样本结果是否已存在
|
||||
qry_result = db_session.query(Result).filter_by(s_id=self.sample.id)
|
||||
if qry_result.count() == 0:
|
||||
# 收集所有结果
|
||||
lst_result_info = []
|
||||
for firm in self.company_agents:
|
||||
for prod, dct_status_supply in firm.dct_prod_up_prod_stat.items():
|
||||
# 检查产品状态是否都为正常
|
||||
if not all(status == 'N' for status, _ in dct_status_supply['p_stat']):
|
||||
for status, ts in dct_status_supply['p_stat']:
|
||||
# 创建结果对象
|
||||
lst_result_info.append(Result(
|
||||
s_id=self.sample.id,
|
||||
id_firm=firm.unique_id,
|
||||
id_product=prod.unique_id,
|
||||
ts=ts,
|
||||
status=status
|
||||
))
|
||||
if not db_session.query(Result).filter_by(s_id=self.sample.id).first():
|
||||
# 生成需要保存的结果列表
|
||||
lst_result_info = [
|
||||
Result(
|
||||
s_id=self.sample.id,
|
||||
id_firm=firm.unique_id,
|
||||
id_product=prod.unique_id,
|
||||
ts=ts,
|
||||
status=status
|
||||
)
|
||||
for firm in self.company_agents
|
||||
for prod, dct_status_supply in firm.dct_prod_up_prod_stat.items()
|
||||
if not all(stat == 'N' for stat, _ in dct_status_supply['p_stat'])
|
||||
for status, ts in dct_status_supply['p_stat']
|
||||
]
|
||||
|
||||
# 批量保存结果
|
||||
# 批量保存结果到数据库
|
||||
if lst_result_info:
|
||||
db_session.bulk_save_objects(lst_result_info)
|
||||
db_session.commit()
|
||||
|
|
4
orm.py
4
orm.py
|
@ -98,8 +98,8 @@ class Result(Base):
|
|||
s_id = Column(Integer, ForeignKey('{}.id'.format(
|
||||
f"{db_name_prefix}_sample")), nullable=False)
|
||||
|
||||
id_firm = Column(String(10), nullable=False)
|
||||
id_product = Column(String(10), nullable=False)
|
||||
id_firm = Column(String(20), nullable=False)
|
||||
id_product = Column(String(20), nullable=False)
|
||||
ts = Column(Integer, nullable=False)
|
||||
status = Column(String(5), nullable=False)
|
||||
|
||||
|
|
|
@ -1,76 +1 @@
|
|||
s_id,id_firm,id_product,ts
|
||||
1441,13,2.1.3.4,0
|
||||
1441,126,2.1.3,1
|
||||
1441,97,2.1.3,1
|
||||
1441,106,2.1.3,1
|
||||
1566,13,2.1.3.7,0
|
||||
1566,126,2.1.3,1
|
||||
2073,14,1.3.3.4,0
|
||||
2073,75,1.3.3,1
|
||||
2621,85,1.3.1,1
|
||||
2621,21,1.3.1.3,0
|
||||
2621,100,1.3.1,1
|
||||
3386,22,2.1.3.7,0
|
||||
3386,108,2.1.3,1
|
||||
4249,23,2.3.1,0
|
||||
4249,84,2.3,1
|
||||
4440,25,1.3.1.7,0
|
||||
4440,100,1.3.1,1
|
||||
4624,74,2.1.3,1
|
||||
4624,26,2.1.3.4,0
|
||||
5015,31,1.3.3.3,0
|
||||
5015,75,1.3.3,1
|
||||
5015,97,1.3.3,1
|
||||
5720,94,1.1,1
|
||||
5720,36,1.1.1,0
|
||||
5720,126,1.1,1
|
||||
7349,80,1.3.4,1
|
||||
7349,45,1.3.4.2,0
|
||||
7349,77,1.3.4,1
|
||||
7399,79,2.1.4.1,1
|
||||
7399,45,2.1.4.1.1,0
|
||||
8285,99,1.3.1,1
|
||||
8285,49,1.3.1.4,0
|
||||
8601,93,1.3.1,1
|
||||
8601,50,1.3.1.5,0
|
||||
8601,85,1.3.1,1
|
||||
9072,53,1.4.3.4,0
|
||||
9072,142,1.4.3,1
|
||||
9382,41,1.4.5,1
|
||||
9382,53,1.4.5.6,0
|
||||
10098,99,1.3.3,1
|
||||
10098,57,1.3.3.3,0
|
||||
10121,57,2.3.1,0
|
||||
10121,124,2.3,1
|
||||
10521,81,1.3.4,1
|
||||
10521,58,1.3.4.3,0
|
||||
10521,80,1.3.4,1
|
||||
11675,93,1.3.1,1
|
||||
11675,68,1.3.1.3,0
|
||||
11678,99,1.3.1,1
|
||||
11678,85,1.3.1,1
|
||||
11678,68,1.3.1.3,0
|
||||
12837,126,2.1.3,1
|
||||
12837,74,2.1.3,1
|
||||
12837,73,2.1.3,1
|
||||
12837,79,2.1.3.2,0
|
||||
13084,108,2.1.3,1
|
||||
13084,73,2.1.3,1
|
||||
13084,106,2.1.3,1
|
||||
13084,79,2.1.3.7,0
|
||||
13084,148,2.1.3,1
|
||||
13084,126,2.1.3,1
|
||||
16647,115,1.1.3,0
|
||||
16647,94,1.1,1
|
||||
16903,85,2.1.1,1
|
||||
16903,117,2.1.1.4,0
|
||||
17379,119,1.3.1.1,0
|
||||
17379,100,1.3.1,1
|
||||
17922,126,2.1.1.5,0
|
||||
17922,80,2.1.1,1
|
||||
18824,85,2.1.1,1
|
||||
18824,131,2.1.1.5,0
|
||||
19562,135,2.2,0
|
||||
19562,98,2,1
|
||||
21447,159,2.1.2,1
|
||||
21447,149,2.1.2.2,0
|
||||
|
|
|
Binary file not shown.
Before Width: | Height: | Size: 1.1 MiB After Width: | Height: | Size: 2.0 MiB |
|
@ -0,0 +1,52 @@
|
|||
from matplotlib import rcParams, pyplot as plt
|
||||
from sqlalchemy import func
|
||||
from orm import db_session, Sample
|
||||
|
||||
# 创建全局绘图对象和轴
|
||||
fig, ax = plt.subplots(figsize=(8, 5))
|
||||
plt.ion() # 启用交互模式
|
||||
|
||||
def visualize_progress():
|
||||
"""
|
||||
可视化 `is_done_flag` 的分布,动态更新进度条。
|
||||
"""
|
||||
|
||||
# 设置全局字体
|
||||
rcParams['font.family'] = 'SimHei' # 黑体,适用于中文
|
||||
rcParams['font.size'] = 12
|
||||
|
||||
# 查询数据库中各 is_done_flag 的数量
|
||||
result = db_session.query(
|
||||
Sample.is_done_flag, func.count(Sample.id)
|
||||
).group_by(Sample.is_done_flag).all()
|
||||
|
||||
# 转换为字典
|
||||
data = {flag: count for flag, count in result}
|
||||
|
||||
# 填充缺失的标志为 0
|
||||
for flag in [-1, 0, 1]:
|
||||
data.setdefault(flag, 0)
|
||||
|
||||
# 准备数据
|
||||
labels = ['未完成 (-1)', '计算中(0)', '完成 (1)']
|
||||
values = [data[-1], data[0], data[1]]
|
||||
|
||||
# 清空之前的绘图内容
|
||||
ax.clear()
|
||||
|
||||
# 创建柱状图
|
||||
ax.bar(labels, values, color=['red', 'orange', 'green'])
|
||||
ax.set_title('任务进度分布', fontsize=16)
|
||||
ax.set_xlabel('任务状态', fontsize=14)
|
||||
ax.set_ylabel('数量', fontsize=14)
|
||||
ax.tick_params(axis='both', labelsize=12)
|
||||
|
||||
# 显示具体数量
|
||||
for i, v in enumerate(values):
|
||||
ax.text(i, v + 0.5, str(v), ha='center', fontsize=12)
|
||||
|
||||
# 刷新绘图
|
||||
plt.pause(1) # 暂停一段时间以更新图表
|
||||
|
||||
# 关闭窗口时,停止交互模式
|
||||
# plt.ioff()
|
Loading…
Reference in New Issue