优化代码中的网络结构构建过程,尚未完成缓存设置和多进程共享内容

This commit is contained in:
Cricial 2025-01-26 23:56:37 +08:00
parent 09622cf33d
commit 13521ff752
22 changed files with 850 additions and 51364 deletions

View File

@ -3,13 +3,6 @@
<component name="CsvFileAttributes">
<option name="attributeMap">
<map>
<entry key="C:\Users\www\Desktop\半导体数据——暂时的数据\我做的数据\汇总数据\BomCateNet.csv">
<value>
<Attribute>
<option name="separator" value="," />
</Attribute>
</value>
</entry>
<entry key="\input_data\device_salvage_values.csv">
<value>
<Attribute>

24
11.py Normal file
View File

@ -0,0 +1,24 @@
import pickle
import os


def load_cached_data(file_path):
    """Load pickled data from ``file_path``.

    Returns the unpickled object on success; returns an empty dict when
    the file is missing or cannot be unpickled.
    NOTE(review): ``pickle.load`` executes arbitrary code — only use on
    trusted cache files.
    """
    if not os.path.exists(file_path):
        print(f"Warning: Cache file '{file_path}' does not exist.")
        return {}
    try:
        with open(file_path, 'rb') as f:
            loaded = pickle.load(f)
    except (pickle.UnpicklingError, FileNotFoundError, EOFError) as e:
        print(f"Error loading cache from '{file_path}': {e}")
        return {}
    print(f"Successfully loaded cache from '{file_path}'.")
    return loaded


# Example usage
data_dct = load_cached_data("G_Firm_add_edges.pkl")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -41,8 +41,6 @@ class Computation:
'seed': sample_random.seed,
**dct_exp}
product_network_test = nx.adjacency_graph(json.loads(dct_sample_para['g_bom']))
model = MyModel(dct_sample_para)
model.step() # 运行仿真一步

View File

@ -8,5 +8,5 @@ test: # only for test scenarios
n_iter: 100
not_test: # normal scenarios
n_sample: 50
n_iter: 100
n_sample: 5
n_iter: 10

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -44,7 +44,7 @@ if __name__ == '__main__':
# 输入参数
parser = argparse.ArgumentParser(description='setting')
parser.add_argument('--exp', type=str, default='without_exp')
parser.add_argument('--job', type=int, default='4')
parser.add_argument('--job', type=int, default='1')
parser.add_argument('--reset_sample', type=int, default='0')
parser.add_argument('--reset_db', type=bool, default=False)

View File

@ -59,11 +59,9 @@ class MyModel(Model):
self.firm_network = nx.MultiDiGraph() # 企业之间的有向多重图。
self.firm_prod_network = nx.MultiDiGraph() # 企业与产品关系的有向多重图。
self.product_network = nx.MultiDiGraph() # 产品之间的有向多重图。
self.G_FirmProd = nx.MultiDiGraph() # 初始化 企业-产品 关系的图
self.G_Firm = nx.MultiDiGraph() # 使用 NetworkX 的有向多重图
# BOM物料清单
self.G_bom = nx.adjacency_graph(json.loads(params['g_bom'])) # 表示 BOM 结构的图。
self.g_bom = nx.adjacency_graph(json.loads(params['g_bom'])) # 表示 BOM 结构的图。
# 随机数生成器
self.nprandom = np.random.default_rng(params['seed']) # 基于固定种子的随机数生成器。
@ -141,88 +139,65 @@ class MyModel(Model):
初始化企业网络处理一个 Code 映射到多个 Index 的情况并缓存所有相关属性
"""
cache_file = "firm_network_cache.pkl"
# 加载企业数据
firm_data = pd.read_csv("input_data/input_firm_data/firm_amended.csv", dtype={'Code': str})
firm_data['Code'] = firm_data['Code'].str.replace('.0', '', regex=False)
# 检查是否存在缓存
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
# 加载企业与产品关系数据
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv", dtype={'Firm_Code': str})
bom_nodes = pd.read_csv("input_data/input_product_data/BomNodes.csv")
# 加载缓存的属性
self.G_Firm = cached_data['G_Firm']
self.firm_product_cache = cached_data['firm_product_cache']
self.firm_relationship_cache = cached_data['firm_relationship_cache']
print("Loaded firm network and related data from cache.")
return
# 构建 Code -> [Index] 的多值映射
code_to_indices = bom_nodes.groupby('Code')['Index'].apply(list).to_dict()
# 如果没有缓存,则从头初始化网络
try:
# 加载企业数据
firm_data = pd.read_csv("input_data/input_firm_data/firm_amended.csv", dtype={'Code': str})
firm_data['Code'] = firm_data['Code'].str.replace('.0', '', regex=False)
# 将 Product_Code 转换为 Product_Indices
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Code'].map(code_to_indices)
# 创建企业网络图
self.G_Firm = nx.Graph()
self.G_Firm.add_nodes_from(firm_data['Code'])
# 检查并处理未映射的 Product_Code
unmapped_products = firm_industry_relation[firm_industry_relation['Product_Indices'].isna()]
if not unmapped_products.empty:
print("Warning: The following Product_Code values could not be mapped to Index:")
print(unmapped_products[['Firm_Code', 'Product_Code']])
# 设置节点属性
firm_attributes = firm_data.set_index('Code').to_dict('index')
nx.set_node_attributes(self.G_Firm, firm_attributes)
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Indices'].apply(
lambda x: x if isinstance(x, list) else []
)
print(f"Initialized G_Firm with {len(self.G_Firm.nodes)} nodes.")
# 按 Firm_Code 分组生成企业的 Product_Code 和 Product_Indices 映射
firm_product = (
firm_industry_relation.groupby('Firm_Code')['Product_Code'].apply(list)
)
firm_product_indices = (
firm_industry_relation.groupby('Firm_Code')['Product_Indices']
.apply(lambda indices: [idx for sublist in indices for idx in sublist])
)
# 加载企业与产品关系数据
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv", dtype={'Firm_Code': str})
bom_nodes = pd.read_csv("input_data/input_product_data/BomNodes.csv")
# 设置企业属性并添加到网络中
firm_attributes = firm_data.copy()
firm_attributes['Product_Indices'] = firm_attributes['Code'].map(firm_product)
firm_attributes['Product_Code'] = firm_attributes['Code'].map(firm_product_indices)
firm_attributes.set_index('Code', inplace=True)
# 构建 Code -> [Index] 的多值映射
code_to_indices = bom_nodes.groupby('Code')['Index'].apply(list).to_dict()
self.firm_network.add_nodes_from(firm_data['Code'])
# 将 Product_Code 转换为 Product_Indices
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Code'].map(code_to_indices)
# 为企业节点分配属性
firm_labels_dict = {code: firm_attributes.loc[code].to_dict() for code in self.firm_network.nodes}
nx.set_node_attributes(self.firm_network, firm_labels_dict)
# 检查并处理未映射的 Product_Code
unmapped_products = firm_industry_relation[firm_industry_relation['Product_Indices'].isna()]
if not unmapped_products.empty:
print("Warning: The following Product_Code values could not be mapped to Index:")
print(unmapped_products[['Firm_Code', 'Product_Code']])
# 构建企业-产品映射缓存
self.firm_product_cache = firm_product_indices.to_dict()
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Indices'].apply(
lambda x: x if isinstance(x, list) else []
)
# 构建企业-产品映射缓存
self.firm_product_cache = (
firm_industry_relation.groupby('Firm_Code')['Product_Indices']
.apply(lambda indices: [idx for sublist in indices for idx in sublist]) # 展平嵌套列表
.to_dict()
)
print(f"Built firm_product_cache with {len(self.firm_product_cache)} entries.")
# 构建企业关系缓存
self.firm_relationship_cache = {
firm: self.compute_firm_relationship(firm, self.firm_product_cache)
for firm in self.firm_product_cache
}
print(f"Built firm_relationship_cache with {len(self.firm_relationship_cache)} entries.")
# 保存所有关键属性到缓存
cached_data = {
'G_Firm': self.G_Firm,
'firm_product_cache': self.firm_product_cache,
'firm_relationship_cache': self.firm_relationship_cache
}
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved firm network and related data to cache.")
except Exception as e:
print(f"Error during network initialization: {e}")
# 构建企业关系缓存
self.firm_relationship_cache = {
firm: self.compute_firm_relationship(firm, self.firm_product_cache)
for firm in self.firm_product_cache
}
def compute_firm_relationship(self, firm, firm_product_cache):
"""计算单个企业的供应链关系"""
lst_pred_product_code = []
for product_code in firm_product_cache[firm]:
lst_pred_product_code += list(self.G_bom.predecessors(product_code))
lst_pred_product_code += list(self.g_bom.predecessors(product_code))
return list(set(lst_pred_product_code)) # 返回唯一值列表
def build_firm_prod_labels_dict(self):
@ -246,54 +221,22 @@ class MyModel(Model):
3. 将产品代码与索引进行映射并为网络节点分配属性
4. 缓存网络和相关数据以加速后续运行
"""
# 加载企业-行业关系数据
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv")
firm_industry_relation['Firm_Code'] = firm_industry_relation['Firm_Code'].astype(str)
firm_industry_relation['Product_Code'] = firm_industry_relation['Product_Code'].apply(lambda x: [x])
cache_file = "firm_product_network_cache.pkl"
# 映射产品代码到索引
firm_industry_relation['Product_Code'] = firm_industry_relation['Product_Code'].apply(
lambda codes: [idx for code in codes for idx in self.id_code.get(str(code), [])]
)
# 检查缓存文件是否存在
if os.path.exists(cache_file):
try:
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
self.G_FirmProd = cached_data['G_FirmProd']
print("Loaded firm-product network from cache.")
return
except Exception as e:
print(f"Error loading cache: {e}. Reinitializing network.")
try:
# 加载企业-行业关系数据
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv")
firm_industry_relation['Firm_Code'] = firm_industry_relation['Firm_Code'].astype(str)
firm_industry_relation['Product_Code'] = firm_industry_relation['Product_Code'].apply(lambda x: [x])
# 创建企业-产品网络图
self.G_FirmProd.add_nodes_from(firm_industry_relation.index)
# 遍历数据行,将产品代码映射到索引,并更新关系表
for index, row in firm_industry_relation.iterrows():
id_index_list = []
for product_code in row['Product_Code']:
if str(product_code) in self.id_code:
id_index_list.extend(self.id_code[str(product_code)])
firm_industry_relation.at[index, 'Product_Code'] = id_index_list
# 为每个节点分配属性
firm_prod_labels_dict = {code: firm_industry_relation.loc[code].to_dict() for code in
firm_industry_relation.index}
nx.set_node_attributes(self.G_FirmProd, firm_prod_labels_dict)
print(f"Initialized G_FirmProd with {len(self.G_FirmProd.nodes)} nodes.")
# 缓存网络和数据
cached_data = {'G_FirmProd': self.G_FirmProd}
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved firm-product network to cache.")
except FileNotFoundError as e:
print(f"Error: {e}. File not found.")
except Exception as e:
print(f"Error initializing firm-product network: {e}")
# 创建企业-产品网络图,同时附带属性
nodes_with_attributes = [
(index, firm_industry_relation.loc[index].to_dict())
for index in firm_industry_relation.index
]
self.firm_prod_network.add_nodes_from(nodes_with_attributes)
def compute_firm_supply_chain(self, firm_industry_relation, g_bom):
"""
@ -311,18 +254,6 @@ class MyModel(Model):
return supply_chain_cache
def add_edges_to_firm_network(self):
"""利用缓存加速企业网络边的添加"""
cache_file = "G_Firm_add_edges.pkl"
# 检查是否存在缓存
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
# 加载缓存的属性
self.G_Firm = cached_data
print("Loaded G_Firm_add_edges cache.")
return
for firm in self.firm_relationship_cache:
lst_pred_product_code = self.firm_relationship_cache[firm]
for pred_product_code in lst_pred_product_code:
@ -330,15 +261,10 @@ class MyModel(Model):
f for f, products in self.firm_product_cache.items()
if pred_product_code in products
]
# 使用缓存,避免重复查询
lst_choose_firm = self.select_firms(lst_pred_firm)
# 添加边
edges = [(pred_firm, firm, {'Product': pred_product_code}) for pred_firm in lst_choose_firm]
self.G_Firm.add_edges_from(edges)
cached_data = self.G_Firm
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved G_Firm_add_edges to cache.")
self.firm_network.add_edges_from(edges)
def select_firms(self, lst_pred_firm):
"""
@ -353,9 +279,9 @@ class MyModel(Model):
valid_firms = []
lst_pred_firm_size = []
for pred_firm in lst_pred_firm:
if pred_firm in self.G_Firm.nodes and 'Revenue_Log' in self.G_Firm.nodes[pred_firm]:
if pred_firm in self.firm_network.nodes and 'Revenue_Log' in self.firm_network.nodes[pred_firm]:
valid_firms.append(pred_firm)
lst_pred_firm_size.append(self.G_Firm.nodes[pred_firm]['Revenue_Log'])
lst_pred_firm_size.append(self.firm_network.nodes[pred_firm]['Revenue_Log'])
# 如果未启用企业规模加权,随机选择
if not self.is_prf_size:
@ -379,20 +305,8 @@ class MyModel(Model):
def add_edges_to_firm_product_network(self, node, pred_product_code, lst_choose_firm):
""" Helper function to add edges to the firm-product network """
cache_file = "G_FirmProd_cache.pkl"
# 检查是否存在缓存
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
# 加载缓存的属性
self.G_FirmProd = cached_data
print("Loaded add_edges_to_firm from cache.")
return
set_node_prod_code = set(self.G_Firm.nodes[node]['Product_Code'])
set_pred_succ_code = set(self.G_bom.successors(pred_product_code))
set_node_prod_code = set(self.firm_network.nodes[node]['Product_Code'])
set_pred_succ_code = set(self.g_bom.successors(pred_product_code))
lst_use_pred_prod_code = list(set_node_prod_code & set_pred_succ_code)
if len(lst_use_pred_prod_code) == 0:
@ -400,7 +314,7 @@ class MyModel(Model):
pred_node_list = []
for pred_firm in lst_choose_firm:
for n, v in self.G_FirmProd.nodes(data=True):
for n, v in self.firm_prod_network.nodes(data=True):
for v1 in v['Product_Code']:
if v1 == pred_product_code and v['Firm_Code'] == pred_firm:
pred_node_list.append(n)
@ -410,7 +324,7 @@ class MyModel(Model):
pred_node = -1
current_node_list = []
for use_pred_prod_code in lst_use_pred_prod_code:
for n, v in self.G_FirmProd.nodes(data=True):
for n, v in self.firm_prod_network.nodes(data=True):
for v1 in v['Product_Code']:
if v1 == use_pred_prod_code and v['Firm_Code'] == node:
current_node_list.append(n)
@ -419,13 +333,7 @@ class MyModel(Model):
else:
current_node = -1
if current_node != -1 and pred_node != -1:
self.G_FirmProd.add_edge(pred_node, current_node)
# 保存所有关键属性到缓存
cached_data = self.G_FirmProd
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved G_FirmProd to cache.")
self.firm_prod_network.add_edge(pred_node, current_node)
def connect_unconnected_nodes(self):
"""
@ -435,34 +343,22 @@ class MyModel(Model):
- 为未连接节点添加边连接到可能的下游企业
- 同时更新 G_FirmProd 网络反映企业与产品的关系
"""
cache_file = "connect_unconnected_nodes_cache.pkl"
# 检查是否存在缓存
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
# 加载缓存的属性
self.G_Firm = cached_data['firm_network']
self.firm_product_cache = cached_data['firm_prod_network']
print("Loaded G_Firm and firm_product_cache from cache.")
return
for node in nx.nodes(self.G_Firm):
for node in nx.nodes(self.firm_network):
# 如果节点没有任何连接,则处理该节点
if self.G_Firm.degree(node) == 0:
if self.firm_network.degree(node) == 0:
# 获取当前节点的产品列表
product_codes = self.G_Firm.nodes[node].get('Product_Code', [])
product_codes = self.firm_network.nodes[node].get('Product_Code', [])
for product_code in product_codes:
# 查找与当前产品相关的 FirmProd 节点
current_node_list = [
n for n, v in self.G_FirmProd.nodes(data=True)
n for n, v in self.firm_prod_network.nodes(data=True)
if v['Firm_Code'] == node and product_code in v['Product_Code']
]
current_node = current_node_list[0] if current_node_list else -1
# 查找当前产品的所有下游产品代码
succ_product_codes = list(self.G_bom.successors(product_code))
succ_product_codes = list(self.g_bom.successors(product_code))
for succ_product_code in succ_product_codes:
# 查找生产下游产品的企业
succ_firms = [
@ -479,7 +375,7 @@ class MyModel(Model):
if self.is_prf_size:
# 基于企业规模选择供应商
succ_firm_sizes = [
self.G_Firm.nodes[succ_firm].get('Revenue_Log', 0)
self.firm_network.nodes[succ_firm].get('Revenue_Log', 0)
for succ_firm in succ_firms
]
if sum(succ_firm_sizes) > 0:
@ -494,29 +390,20 @@ class MyModel(Model):
# 添加边到 G_Firm 图
edges = [(node, firm, {'Product': product_code}) for firm in selected_firms]
self.G_Firm.add_edges_from(edges)
self.firm_network.add_edges_from(edges)
# 更新 G_FirmProd 网络
for succ_firm in selected_firms:
succ_node_list = [
n for n, v in self.G_FirmProd.nodes(data=True)
n for n, v in self.firm_prod_network.nodes(data=True)
if v['Firm_Code'] == succ_firm and succ_product_code in v['Product_Code']
]
succ_node = succ_node_list[0] if succ_node_list else -1
if current_node != -1 and succ_node != -1:
self.G_FirmProd.add_edge(current_node, succ_node)
self.firm_prod_network.add_edge(current_node, succ_node)
# 保存网络数据到样本
self.firm_network = self.G_Firm # 使用 networkx 图对象表示的企业网络
self.firm_prod_network = self.G_FirmProd # 使用 networkx 图对象表示的企业与产品关系网络
cached_data = {
'firm_network': self.firm_network,
'firm_prod_network': self.firm_prod_network,
}
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved firm network and related data to cache.")
self.firm_prod_network = self.firm_prod_network # 使用 networkx 图对象表示的企业与产品关系网络
def initialize_agents(self):
"""
@ -545,17 +432,16 @@ class MyModel(Model):
]
# 获取企业的需求数量和生产输出
demand_quantity = self.data_materials.loc[self.data_materials['Firm_Code'] == ag_node]
production_output = self.data_produced.loc[self.data_produced['Firm_Code'] == ag_node]
demand_quantity = self.data_materials.loc[self.data_materials['Firm_Code'] == int(ag_node)]
production_output = self.data_produced.loc[self.data_produced['Firm_Code'] == int(ag_node)]
# 获取企业的资源信息
# 获取企业的资源信息,同时处理 R、P、C 的情况
try:
R = self.firm_resource_R.loc[int(ag_node)]
P = self.firm_resource_P.loc[int(ag_node)]
P = self.firm_resource_P.get(int(ag_node))
C = self.firm_resource_C.loc[int(ag_node)]
except KeyError:
# 如果资源数据缺失,提供默认值
R, P, C = [], [], []
R, P, C = [], {}, [] # 如果任何资源不存在,返回空列表
# 创建企业代理
firm_agent = FirmAgent(
@ -618,50 +504,38 @@ class MyModel(Model):
- 合并设备数据与设备残值数据
- 按企业分组生成资源列表
"""
try:
# 加载企业的材料、设备和产品数据
data_R = pd.read_csv("input_data/input_firm_data/firms_materials.csv")
data_C = pd.read_csv("input_data/input_firm_data/firms_devices.csv")
data_P = pd.read_csv("input_data/input_firm_data/firms_products.csv")
# 加载企业的材料、设备和产品数据
data_R = pd.read_csv("input_data/input_firm_data/firms_materials.csv")
data_C = pd.read_csv("input_data/input_firm_data/firms_devices.csv")
data_P = pd.read_csv("input_data/input_firm_data/firms_products.csv")
# 加载设备残值数据,并合并到设备数据中
device_salvage_values = pd.read_csv('input_data/device_salvage_values.csv')
self.device_salvage_values = device_salvage_values
# 加载设备残值数据,并合并到设备数据中
device_salvage_values = pd.read_csv('input_data/device_salvage_values.csv')
self.device_salvage_values = device_salvage_values
# 合并设备数据和设备残值
data_merged_C = pd.merge(data_C, device_salvage_values, on='设备id', how='left')
# 合并设备数据和设备残值
data_merged_C = pd.merge(data_C, device_salvage_values, on='设备id', how='left')
# 按企业分组并生成资源列表
firm_resource_R = (
data_R.groupby('Firm_Code')[['材料id', '材料数量']]
.apply(lambda x: x.values.tolist())
.to_dict() # 转换为字典格式,便于快速查询
)
# 按企业分组并生成资源列表
firm_resource_R = (
data_R.groupby('Firm_Code')[['材料id', '材料数量']]
.apply(lambda x: x.values.tolist())
)
firm_resource_C = (
data_merged_C.groupby('Firm_Code')[['设备id', '设备数量', '设备残值']]
.apply(lambda x: x.values.tolist())
.to_dict()
)
firm_resource_C = (
data_merged_C.groupby('Firm_Code')[['设备id', '设备数量', '设备残值']]
.apply(lambda x: x.values.tolist())
)
firm_resource_P = (
data_P.groupby('Firm_Code')[['产品id', '产品数量']]
.apply(lambda x: x.values.tolist())
.to_dict()
)
firm_resource_P = (
data_P.groupby('Firm_Code')[['产品id', '产品数量']]
.apply(lambda x: x.values.tolist())
)
# 将结果存储到模型中
self.firm_resource_R = firm_resource_R
self.firm_resource_C = firm_resource_C
self.firm_resource_P = firm_resource_P
except FileNotFoundError as e:
print(f"Error: Missing input file - {e.filename}")
self.firm_resource_R, self.firm_resource_C, self.firm_resource_P = {}, {}, {}
except Exception as e:
print(f"Error during resource integration: {e}")
self.firm_resource_R, self.firm_resource_C, self.firm_resource_P = {}, {}, {}
# 将结果存储到模型中
self.firm_resource_R = firm_resource_R
self.firm_resource_C = firm_resource_C
self.firm_resource_P = firm_resource_P
def j_comp_consumed_produced(self):
"""
@ -728,7 +602,8 @@ class MyModel(Model):
self._trial_process()
# 3. 判断是否需要采购资源和设备
self._handle_resource_and_machinery_purchase()
self._handle_material_purchase()
self._handle_machinery_purchase()
# 4. 资源消耗和产品生产
self._consume_resources_and_produce()
@ -784,42 +659,107 @@ class MyModel(Model):
for firm in self.company_agents:
firm.clean_before_trial()
def _handle_resource_and_machinery_purchase(self):
"""处理资源和设备的采购。"""
for firm in self.company_agents:
# 判断资源需求
for material_id, material_quantity in firm.R:
if material_quantity <= firm.s_r:
required_quantity = firm.S_r - material_quantity
firm.request_material_purchase(material_id, required_quantity)
def _handle_material_purchase(self):
"""
判断并处理资源的采购
"""
purchase_material_firms = {}
material_list = []
list_seek_material_firm = [] # 每一个收到请求的企业
# 判断设备需求
for device_id, device_quantity, device_salvage in firm.C:
device_salvage -= firm.x
if device_salvage <= 0: # 如果设备残值小于等于 0
device_quantity -= 1
firm.request_device_purchase(device_id, 1)
for firm in self.company_agents:
# 如果 firm.R 为 None跳过处理
if not firm.R:
# print(f"Firm {firm.unique_id} has no resource data.")
continue
# 处理资源需求
for sub_list in firm.R:
if sub_list[1] <= firm.s_r:
required_material_quantity = firm.S_r - sub_list[1]
material_list.append([sub_list[0], required_material_quantity])
purchase_material_firms[firm] = material_list
# 寻源并发送资源请求
for material_firm_key, sub_list_values in purchase_material_firms.items():
for mater_list in sub_list_values:
result = material_firm_key.seek_material_supply(mater_list[0])
if result != -1:
list_seek_material_firm.append(result)
if list_seek_material_firm:
for seek_material_firm in list_seek_material_firm:
seek_material_firm.handle_material_request(mater_list)
for R_list in firm.R:
R_list[1] = firm.S_r
def _handle_machinery_purchase(self):
"""
判断并处理设备的采购
"""
purchase_machinery_firms = {}
machinery_list = []
list_seek_machinery_firm = [] # 每一个收到请求的企业
for firm in self.company_agents:
# 处理设备需求
# 如果 firm.C 为 None跳过处理
if not firm.C:
# print(f"Firm {firm.unique_id} has no equipment to maintain or purchase.")
continue
for sub_list in firm.C:
sub_list[2] -= firm.x # 减少设备残值
if sub_list[2] <= 0: # 残值小于等于 0
sub_list[1] -= 1
required_machinery_quantity = 1
machinery_list.append([sub_list[0], required_machinery_quantity])
purchase_machinery_firms[firm] = machinery_list
# 寻源并发送设备请求
for machinery_firm, sub_list in purchase_machinery_firms.items():
for machi_list in sub_list:
result = machinery_firm.seek_machinery_supply(machi_list[0])
if result != -1:
list_seek_machinery_firm.append(result)
if list_seek_machinery_firm:
for seek_machinery_firm in list_seek_machinery_firm:
seek_machinery_firm.handle_machinery_request(machi_list)
for C_list, C0_list in zip(firm.C, firm.C0):
C_list[1] = C0_list[1] # 恢复初始值
C_list[2] = C0_list[2]
def _consume_resources_and_produce(self):
"""消耗资源并生产产品。"""
"""
消耗资源并生产产品
"""
k = 0.6 # 资源消耗比例
production_increase_ratio = 1.6 # 产品生产比例
for firm in self.company_agents:
# 计算消耗量
# 计算行业的资源消耗
consumed_resources = {}
for industry in firm.indus_i:
for product_id, product_quantity in firm.P.items():
if product_id == industry.unique_id:
consumed_resources[industry] = product_quantity * k
# 计算当前行业的资源消耗量
consumed_quantity = sum(
product[1] * k # product[1] 是产品的数量
for product in firm.P
if product[0] == industry.unique_id # product[0] 是产品的 ID
)
# 将计算的消耗量记录到字典中
consumed_resources[industry] = consumed_quantity
# 消耗资源
for resource_id, resource_quantity in firm.R.items():
for resource in firm.R:
resource_id, resource_quantity = resource[0], resource[1]
for industry, consumed_quantity in consumed_resources.items():
if resource_id in industry.resource_ids:
firm.R[resource_id] -= consumed_quantity
if resource_id == industry.unique_id: # 判断资源是否属于该行业
resource[1] = max(0, resource_quantity - consumed_quantity)
# 生产产品
for product_id, product_quantity in firm.P.items():
firm.P[product_id] = product_quantity * 1.6
for product in firm.P:
product_id, product_quantity = product[0], product[1]
product[1] = product_quantity * production_increase_ratio
# 刷新资源和设备状态
firm.refresh_R()

49
查看进度.py Normal file
View File

@ -0,0 +1,49 @@
from matplotlib import rcParams, pyplot as plt
from sqlalchemy import func

from orm import db_session, Sample


def visualize_progress():
    """
    Visualize the distribution of ``is_done_flag`` values in the Sample table.

    Queries the database for the sample count per flag, fills missing flags
    with zero, and renders a labelled bar chart via matplotlib.
    """
    # Global font settings; SimHei provides the Chinese glyphs used below.
    rcParams['font.family'] = 'SimHei'
    rcParams['font.size'] = 12

    # Count samples grouped by is_done_flag.
    result = db_session.query(
        Sample.is_done_flag, func.count(Sample.id)
    ).group_by(Sample.is_done_flag).all()

    # Convert to a dict and default any absent flag to 0 so indexing below
    # never raises KeyError.
    data = {flag: count for flag, count in result}
    for flag in (-1, 0, 1):
        data.setdefault(flag, 0)

    # NOTE(review): label wording assumes -1 = unfinished, 0 = done,
    # 1 = waiting — confirm against the Sample ORM model's semantics.
    labels = ['未完成 (-1)', '完成(0)', '等待 (1)']
    values = [data[-1], data[0], data[1]]

    # Draw the bar chart.
    plt.figure(figsize=(8, 5))
    plt.bar(labels, values, color=['red', 'orange', 'green'])
    plt.title('任务进度分布', fontsize=16)
    plt.xlabel('任务状态', fontsize=14)
    plt.ylabel('数量', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    # Annotate each bar with its exact count.
    for i, v in enumerate(values):
        plt.text(i, v + 0.5, str(v), ha='center', fontsize=12)

    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    # Guard the call so importing this module does not trigger a database
    # query and a blocking plot window (previously it ran on import).
    visualize_progress()