Compare commits

...

4 Commits

Author SHA1 Message Date
Cricial 30e7e56c11 暂时通过测试 2025-01-27 15:18:51 +08:00
Cricial 0f7f9c1a4b 优化查看进度 2025-01-27 01:15:15 +08:00
Cricial 3a46d09b8e 优化 干扰列表 2025-01-27 01:04:21 +08:00
Cricial 13521ff752 优化代码中的网络结构构建过程,未完成缓存设置 和 多进程共享内容 2025-01-26 23:56:37 +08:00
31 changed files with 1067 additions and 51527 deletions

View File

@ -3,7 +3,7 @@
<component name="CsvFileAttributes">
<option name="attributeMap">
<map>
<entry key="C:\Users\www\Desktop\半导体数据——暂时的数据\我做的数据\汇总数据\BomCateNet.csv">
<entry key="C:\Users\www\Desktop\python项目\数据\抽样第3次数据\firm_amended.csv">
<value>
<Attribute>
<option name="separator" value="," />

29
11.py Normal file
View File

@ -0,0 +1,29 @@
import pickle
import os
from 查看进度 import visualize_progress
def load_cached_data(file_path):
    """Load pickled data from the given cache file.

    Parameters
    ----------
    file_path : str
        Path to the pickle cache file.

    Returns
    -------
    object
        The unpickled payload (a dict in this project's usage), or an
        empty dict when the file is missing or cannot be read/unpickled.
    """
    # Explicit existence check gives a clearer warning than letting
    # open() raise; the except below still covers the race where the
    # file disappears between this check and the open.
    if not os.path.exists(file_path):
        print(f"Warning: Cache file '{file_path}' does not exist.")
        return {}
    try:
        # SECURITY NOTE: pickle.load executes arbitrary code from the
        # file — only load caches this program wrote itself.
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        print(f"Successfully loaded cache from '{file_path}'.")
        return data
    except (pickle.UnpicklingError, EOFError, OSError) as e:
        # OSError subsumes FileNotFoundError and also covers permission
        # errors; EOFError covers a truncated cache file.
        print(f"Error loading cache from '{file_path}': {e}")
        return {}
# 示例用法
# data_dct = load_cached_data("G_Firm_add_edges.pkl")
visualize_progress()

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -41,11 +41,9 @@ class Computation:
'seed': sample_random.seed,
**dct_exp}
product_network_test = nx.adjacency_graph(json.loads(dct_sample_para['g_bom']))
model = MyModel(dct_sample_para)
model.step() # 运行仿真一步
model.end() # 汇总结果
return False

View File

@ -1 +1 @@
db_name_prefix: without_exp
db_name_prefix: with_exp

View File

@ -8,5 +8,5 @@ test: # only for test scenarios
n_iter: 100
not_test: # normal scenarios
n_sample: 50
n_iter: 100
n_sample: 10
n_iter: 50

Binary file not shown.

View File

@ -53,7 +53,6 @@ class ControllerDB:
# fill dct_lst_init_disrupt_firm_prod
# 存储 公司-在供应链结点的位置.. 0 1.1
list_dct = [] # 存储 公司编码code 和对应的产业链 结点
if self.is_with_exp:
# 对于方差分析时候使用
with open('SQL_export_high_risk_setting.sql', 'r') as f:
@ -67,12 +66,27 @@ class ControllerDB:
# 行索引 (index):这一行在数据帧中的索引值。
# 行数据 (row):这一行的数据,是一个 pandas.Series 对象,包含该行的所有列和值。
# 读取企业与产品关系数据
firm_industry = pd.read_csv("input_data/firm_industry_relation.csv")
firm_industry['Firm_Code'] = firm_industry['Firm_Code'].astype('string')
# 假设已从 BOM 数据构建了 code_to_indices
bom_nodes = pd.read_csv("input_data/input_product_data/BomNodes.csv")
code_to_indices = bom_nodes.groupby('Code')['Index'].apply(list).to_dict()
# 初始化存储映射结果的列表
list_dct = []
# 遍历 firm_industry 数据
for _, row in firm_industry.iterrows():
code = row['Firm_Code']
row = row['Product_Code']
dct = {code: [row]}
firm_code = row['Firm_Code'] # 企业代码
product_code = row['Product_Code'] # 原始产品代码
# 使用 code_to_indices 映射 Product_Code 到 Product_Indices
mapped_indices = code_to_indices.get(product_code, []) # 如果找不到则返回空列表
# 构建企业到产品索引的映射
dct = {firm_code: mapped_indices}
list_dct.append(dct)
# fill g_bom

95
firm.py
View File

@ -32,13 +32,13 @@ class FirmAgent(Agent):
# 包括 产品时间
self.P1 = {0: P}
# 企业i的供应商
self.upper_i = [agent for u, v in self.firm_network.in_edges(self.unique_id)
for agent in self.model.company_agents if agent.unique_id == u]
self.upper_i = [self.model.agent_map[u] for u, v in self.firm_network.in_edges(self.unique_id)
if u in self.model.agent_map]
# 企业i的客户
self.downer_i = [agent for u, v in self.firm_network.out_edges(self.unique_id)
for agent in self.model.company_agents if agent.unique_id == u]
self.downer_i = [self.model.agent_map[v] for u, v in self.firm_network.out_edges(self.unique_id)
if v in self.model.agent_map]
# 设备c的数量 (总量) 使用这个来判断设备数量
#self.n_equip_c = n_equip_c
# self.n_equip_c = n_equip_c
# 设备c产量 根据设备量进行估算
self.c_yield = production_output
# 消耗材料量 根据设备量进行估算 { }
@ -143,42 +143,69 @@ class FirmAgent(Agent):
# f"disrupted supplier of {disrupted_up_prod.code}")
def seek_alt_supply(self, product):
# 检查当前产品的尝试次数是否达到最大值
if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial:
# 初始化候选供应商列表
if self.dct_n_trial_up_prod_disrupted[product] == 0:
self.dct_cand_alt_supp_up_prod_disrupted[product] = [
firm for firm in self.model.company_agents
if firm.is_prod_in_current_normal(product)]
if self.dct_cand_alt_supp_up_prod_disrupted[product]:
lst_firm_connect = []
if self.is_prf_conn:
for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]:
if self.firm_network.has_edge(self.unique_id, firm.unique_id) or \
self.firm_network.has_edge(firm.unique_id, self.unique_id):
lst_firm_connect.append(firm)
if len(lst_firm_connect) == 0:
if self.is_prf_size:
lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
lst_prob = [size / sum(lst_size) for size in lst_size]
select_alt_supply = \
self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
else:
select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
elif len(lst_firm_connect) > 0:
if self.is_prf_size:
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
else:
select_alt_supply = self.random.choice(lst_firm_connect)
firm for firm in self.model.company_agents if firm.is_prod_in_current_normal(product)
]
assert select_alt_supply.is_prod_in_current_normal(product)
# 如果没有候选供应商,直接退出
if not self.dct_cand_alt_supp_up_prod_disrupted[product]:
# print(f"No valid candidates found for product {product.unique_id}")
return
if product in select_alt_supply.dct_request_prod_from_firm:
select_alt_supply.dct_request_prod_from_firm[product].append(self)
# 查找与当前企业已连接的候选供应商
lst_firm_connect = []
if self.is_prf_conn:
lst_firm_connect = [
firm for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]
if self.firm_network.has_edge(self.unique_id, firm.unique_id) or
self.firm_network.has_edge(firm.unique_id, self.unique_id)
]
# 如果没有连接的供应商
if not lst_firm_connect:
if self.is_prf_size: # 根据规模加权选择
lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
lst_prob = [size / sum(lst_size) for size in lst_size]
select_alt_supply = self.random.choices(
self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob
)[0]
else: # 随机选择
select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
else: # 如果存在连接的供应商
if self.is_prf_size: # 根据规模加权选择
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
else: # 随机选择
select_alt_supply = self.random.choice(lst_firm_connect)
# 检查选中的供应商是否能够生产产品
if not select_alt_supply.is_prod_in_current_normal(product):
# print(f"Selected supplier {select_alt_supply.unique_id} cannot produce product {product.unique_id}")
# 打印供应商的生产状态字典
#print(f"Supplier production state: {select_alt_supply.dct_prod_up_prod_stat}")
# 检查产品是否存在于生产状态字典中
if product in select_alt_supply.dct_prod_up_prod_stat:
print(
f"Product {product.unique_id} production state: {select_alt_supply.dct_prod_up_prod_stat[product]['p_stat']}")
else:
select_alt_supply.dct_request_prod_from_firm[product] = [self]
print(f"Product {product.unique_id} not found in supplier production state.")
return
self.dct_n_trial_up_prod_disrupted[product] += 1
# 添加到供应商的请求字典
if product in select_alt_supply.dct_request_prod_from_firm:
select_alt_supply.dct_request_prod_from_firm[product].append(self)
else:
select_alt_supply.dct_request_prod_from_firm[product] = [self]
# 更新尝试次数
self.dct_n_trial_up_prod_disrupted[product] += 1
def handle_request(self):
for product, lst_firm in self.dct_request_prod_from_firm.items():

BIN
firm_network.pkl Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

13
main.py
View File

@ -3,10 +3,14 @@ import random
import time
from multiprocessing import Process
import argparse
from matplotlib import pyplot as plt
from computation import Computation
from sqlalchemy.orm import close_all_sessions
import yaml
from controller_db import ControllerDB
from 查看进度 import visualize_progress
def controll_db_and_process(exp_argument, reset_sample_argument, reset_db_argument):
@ -30,6 +34,13 @@ def do_process(target: object, controller_db: ControllerDB, ):
for i in process_list:
i.join()
# 所有子进程完成后刷新最终进度
visualize_progress()
# 显示最终进度后关闭图表
plt.show()
def do_computation(c_db):
exp = Computation(c_db)
@ -43,7 +54,7 @@ def do_computation(c_db):
if __name__ == '__main__':
# 输入参数
parser = argparse.ArgumentParser(description='setting')
parser.add_argument('--exp', type=str, default='without_exp')
parser.add_argument('--exp', type=str, default='with_exp')
parser.add_argument('--job', type=int, default='4')
parser.add_argument('--reset_sample', type=int, default='0')
parser.add_argument('--reset_db', type=bool, default=False)

View File

@ -42,6 +42,7 @@ class MyModel(Model):
- seed (int): 随机种子的值用于确保实验的可重复性
"""
# 仿真参数
self.agent_map = None
self.firm_prod_labels_dict = None
self.firm_relationship_cache = None
self.firm_product_cache = None
@ -53,17 +54,15 @@ class MyModel(Model):
self.cap_limit_level = params['cap_limit_level'] # 产能限制的水平。
self.diff_new_conn = params['diff_new_conn'] # 是否允许差异化的新连接。
# 初始化停止时间步,可能是用户通过参数传入
self.int_stop_ts = params.get('stop_t', 3) # 默认停止时间为 100
self.int_stop_ts = params.get('n_iter', 3) # 默认停止时间为 100
# 网络初始化
self.firm_network = nx.MultiDiGraph() # 企业之间的有向多重图。
self.firm_prod_network = nx.MultiDiGraph() # 企业与产品关系的有向多重图。
self.product_network = nx.MultiDiGraph() # 产品之间的有向多重图。
self.G_FirmProd = nx.MultiDiGraph() # 初始化 企业-产品 关系的图
self.G_Firm = nx.MultiDiGraph() # 使用 NetworkX 的有向多重图
# BOM物料清单
self.G_bom = nx.adjacency_graph(json.loads(params['g_bom'])) # 表示 BOM 结构的图。
self.g_bom = nx.adjacency_graph(json.loads(params['g_bom'])) # 表示 BOM 结构的图。
# 随机数生成器
self.nprandom = np.random.default_rng(params['seed']) # 基于固定种子的随机数生成器。
@ -87,12 +86,22 @@ class MyModel(Model):
self.company_agents = [] # 初始化公司代理列表
# 初始化模型的网络和代理
# 检查缓存是否存在
cache_file = "firm_network.pkl"
if os.path.exists(cache_file):
# 从缓存加载 firm_network
with open(cache_file, 'rb') as f:
self.firm_network = pickle.load(f)
print("Loaded firm network from cache.")
else:
# 执行完整的初始化流程
self.initialize_product_network(params)
self.initialize_firm_network()
self.build_firm_prod_labels_dict()
self.initialize_firm_product_network()
self.add_edges_to_firm_network()
self.connect_unconnected_nodes()
self.initialize_product_network(params) # 初始化产品网络。
self.initialize_firm_network() # 初始化企业网络。
self.build_firm_prod_labels_dict() # 构建企业与产品的映射关系字典
self.initialize_firm_product_network() # 初始化企业与产品的网络。
self.add_edges_to_firm_network() # 添加企业之间的边。
self.connect_unconnected_nodes() # 连接未连接的节点。
self.resource_integration()
self.j_comp_consumed_produced()
self.initialize_agents() # 初始化代理。
@ -141,88 +150,65 @@ class MyModel(Model):
初始化企业网络处理一个 Code 映射到多个 Index 的情况并缓存所有相关属性
"""
cache_file = "firm_network_cache.pkl"
# 加载企业数据
firm_data = pd.read_csv("input_data/input_firm_data/firm_amended.csv", dtype={'Code': str})
firm_data['Code'] = firm_data['Code'].str.replace('.0', '', regex=False)
# 检查是否存在缓存
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
# 加载企业与产品关系数据
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv", dtype={'Firm_Code': str})
bom_nodes = pd.read_csv("input_data/input_product_data/BomNodes.csv")
# 加载缓存的属性
self.G_Firm = cached_data['G_Firm']
self.firm_product_cache = cached_data['firm_product_cache']
self.firm_relationship_cache = cached_data['firm_relationship_cache']
print("Loaded firm network and related data from cache.")
return
# 构建 Code -> [Index] 的多值映射
code_to_indices = bom_nodes.groupby('Code')['Index'].apply(list).to_dict()
# 如果没有缓存,则从头初始化网络
try:
# 加载企业数据
firm_data = pd.read_csv("input_data/input_firm_data/firm_amended.csv", dtype={'Code': str})
firm_data['Code'] = firm_data['Code'].str.replace('.0', '', regex=False)
# 将 Product_Code 转换为 Product_Indices
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Code'].map(code_to_indices)
# 创建企业网络图
self.G_Firm = nx.Graph()
self.G_Firm.add_nodes_from(firm_data['Code'])
# 检查并处理未映射的 Product_Code
unmapped_products = firm_industry_relation[firm_industry_relation['Product_Indices'].isna()]
if not unmapped_products.empty:
print("Warning: The following Product_Code values could not be mapped to Index:")
print(unmapped_products[['Firm_Code', 'Product_Code']])
# 设置节点属性
firm_attributes = firm_data.set_index('Code').to_dict('index')
nx.set_node_attributes(self.G_Firm, firm_attributes)
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Indices'].apply(
lambda x: x if isinstance(x, list) else []
)
print(f"Initialized G_Firm with {len(self.G_Firm.nodes)} nodes.")
# 按 Firm_Code 分组生成企业的 Product_Code 和 Product_Indices 映射
firm_product = (
firm_industry_relation.groupby('Firm_Code')['Product_Code'].apply(list)
)
firm_product_indices = (
firm_industry_relation.groupby('Firm_Code')['Product_Indices']
.apply(lambda indices: [idx for sublist in indices for idx in sublist])
)
# 加载企业与产品关系数据
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv", dtype={'Firm_Code': str})
bom_nodes = pd.read_csv("input_data/input_product_data/BomNodes.csv")
# 设置企业属性并添加到网络中
firm_attributes = firm_data.copy()
firm_attributes['Product_Indices'] = firm_attributes['Code'].map(firm_product)
firm_attributes['Product_Code'] = firm_attributes['Code'].map(firm_product_indices)
firm_attributes.set_index('Code', inplace=True)
# 构建 Code -> [Index] 的多值映射
code_to_indices = bom_nodes.groupby('Code')['Index'].apply(list).to_dict()
self.firm_network.add_nodes_from(firm_data['Code'])
# 将 Product_Code 转换为 Product_Indices
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Code'].map(code_to_indices)
# 为企业节点分配属性
firm_labels_dict = {code: firm_attributes.loc[code].to_dict() for code in self.firm_network.nodes}
nx.set_node_attributes(self.firm_network, firm_labels_dict)
# 检查并处理未映射的 Product_Code
unmapped_products = firm_industry_relation[firm_industry_relation['Product_Indices'].isna()]
if not unmapped_products.empty:
print("Warning: The following Product_Code values could not be mapped to Index:")
print(unmapped_products[['Firm_Code', 'Product_Code']])
# 构建企业-产品映射缓存
self.firm_product_cache = firm_product_indices.to_dict()
firm_industry_relation['Product_Indices'] = firm_industry_relation['Product_Indices'].apply(
lambda x: x if isinstance(x, list) else []
)
# 构建企业-产品映射缓存
self.firm_product_cache = (
firm_industry_relation.groupby('Firm_Code')['Product_Indices']
.apply(lambda indices: [idx for sublist in indices for idx in sublist]) # 展平嵌套列表
.to_dict()
)
print(f"Built firm_product_cache with {len(self.firm_product_cache)} entries.")
# 构建企业关系缓存
self.firm_relationship_cache = {
firm: self.compute_firm_relationship(firm, self.firm_product_cache)
for firm in self.firm_product_cache
}
print(f"Built firm_relationship_cache with {len(self.firm_relationship_cache)} entries.")
# 保存所有关键属性到缓存
cached_data = {
'G_Firm': self.G_Firm,
'firm_product_cache': self.firm_product_cache,
'firm_relationship_cache': self.firm_relationship_cache
}
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved firm network and related data to cache.")
except Exception as e:
print(f"Error during network initialization: {e}")
# 构建企业关系缓存
self.firm_relationship_cache = {
firm: self.compute_firm_relationship(firm, self.firm_product_cache)
for firm in self.firm_product_cache
}
def compute_firm_relationship(self, firm, firm_product_cache):
"""计算单个企业的供应链关系"""
lst_pred_product_code = []
for product_code in firm_product_cache[firm]:
lst_pred_product_code += list(self.G_bom.predecessors(product_code))
lst_pred_product_code += list(self.g_bom.predecessors(product_code))
return list(set(lst_pred_product_code)) # 返回唯一值列表
def build_firm_prod_labels_dict(self):
@ -246,54 +232,22 @@ class MyModel(Model):
3. 将产品代码与索引进行映射并为网络节点分配属性
4. 缓存网络和相关数据以加速后续运行
"""
# 加载企业-行业关系数据
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv")
firm_industry_relation['Firm_Code'] = firm_industry_relation['Firm_Code'].astype(str)
firm_industry_relation['Product_Code'] = firm_industry_relation['Product_Code'].apply(lambda x: [x])
cache_file = "firm_product_network_cache.pkl"
# 映射产品代码到索引
firm_industry_relation['Product_Code'] = firm_industry_relation['Product_Code'].apply(
lambda codes: [idx for code in codes for idx in self.id_code.get(str(code), [])]
)
# 检查缓存文件是否存在
if os.path.exists(cache_file):
try:
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
self.G_FirmProd = cached_data['G_FirmProd']
print("Loaded firm-product network from cache.")
return
except Exception as e:
print(f"Error loading cache: {e}. Reinitializing network.")
try:
# 加载企业-行业关系数据
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv")
firm_industry_relation['Firm_Code'] = firm_industry_relation['Firm_Code'].astype(str)
firm_industry_relation['Product_Code'] = firm_industry_relation['Product_Code'].apply(lambda x: [x])
# 创建企业-产品网络图
self.G_FirmProd.add_nodes_from(firm_industry_relation.index)
# 遍历数据行,将产品代码映射到索引,并更新关系表
for index, row in firm_industry_relation.iterrows():
id_index_list = []
for product_code in row['Product_Code']:
if str(product_code) in self.id_code:
id_index_list.extend(self.id_code[str(product_code)])
firm_industry_relation.at[index, 'Product_Code'] = id_index_list
# 为每个节点分配属性
firm_prod_labels_dict = {code: firm_industry_relation.loc[code].to_dict() for code in
firm_industry_relation.index}
nx.set_node_attributes(self.G_FirmProd, firm_prod_labels_dict)
print(f"Initialized G_FirmProd with {len(self.G_FirmProd.nodes)} nodes.")
# 缓存网络和数据
cached_data = {'G_FirmProd': self.G_FirmProd}
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved firm-product network to cache.")
except FileNotFoundError as e:
print(f"Error: {e}. File not found.")
except Exception as e:
print(f"Error initializing firm-product network: {e}")
# 创建企业-产品网络图,同时附带属性
nodes_with_attributes = [
(index, firm_industry_relation.loc[index].to_dict())
for index in firm_industry_relation.index
]
self.firm_prod_network.add_nodes_from(nodes_with_attributes)
def compute_firm_supply_chain(self, firm_industry_relation, g_bom):
"""
@ -311,18 +265,6 @@ class MyModel(Model):
return supply_chain_cache
def add_edges_to_firm_network(self):
"""利用缓存加速企业网络边的添加"""
cache_file = "G_Firm_add_edges.pkl"
# 检查是否存在缓存
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
# 加载缓存的属性
self.G_Firm = cached_data
print("Loaded G_Firm_add_edges cache.")
return
for firm in self.firm_relationship_cache:
lst_pred_product_code = self.firm_relationship_cache[firm]
for pred_product_code in lst_pred_product_code:
@ -330,15 +272,10 @@ class MyModel(Model):
f for f, products in self.firm_product_cache.items()
if pred_product_code in products
]
# 使用缓存,避免重复查询
lst_choose_firm = self.select_firms(lst_pred_firm)
# 添加边
edges = [(pred_firm, firm, {'Product': pred_product_code}) for pred_firm in lst_choose_firm]
self.G_Firm.add_edges_from(edges)
cached_data = self.G_Firm
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved G_Firm_add_edges to cache.")
self.firm_network.add_edges_from(edges)
def select_firms(self, lst_pred_firm):
"""
@ -353,9 +290,9 @@ class MyModel(Model):
valid_firms = []
lst_pred_firm_size = []
for pred_firm in lst_pred_firm:
if pred_firm in self.G_Firm.nodes and 'Revenue_Log' in self.G_Firm.nodes[pred_firm]:
if pred_firm in self.firm_network.nodes and 'Revenue_Log' in self.firm_network.nodes[pred_firm]:
valid_firms.append(pred_firm)
lst_pred_firm_size.append(self.G_Firm.nodes[pred_firm]['Revenue_Log'])
lst_pred_firm_size.append(self.firm_network.nodes[pred_firm]['Revenue_Log'])
# 如果未启用企业规模加权,随机选择
if not self.is_prf_size:
@ -379,20 +316,8 @@ class MyModel(Model):
def add_edges_to_firm_product_network(self, node, pred_product_code, lst_choose_firm):
""" Helper function to add edges to the firm-product network """
cache_file = "G_FirmProd_cache.pkl"
# 检查是否存在缓存
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
# 加载缓存的属性
self.G_FirmProd = cached_data
print("Loaded add_edges_to_firm from cache.")
return
set_node_prod_code = set(self.G_Firm.nodes[node]['Product_Code'])
set_pred_succ_code = set(self.G_bom.successors(pred_product_code))
set_node_prod_code = set(self.firm_network.nodes[node]['Product_Code'])
set_pred_succ_code = set(self.g_bom.successors(pred_product_code))
lst_use_pred_prod_code = list(set_node_prod_code & set_pred_succ_code)
if len(lst_use_pred_prod_code) == 0:
@ -400,7 +325,7 @@ class MyModel(Model):
pred_node_list = []
for pred_firm in lst_choose_firm:
for n, v in self.G_FirmProd.nodes(data=True):
for n, v in self.firm_prod_network.nodes(data=True):
for v1 in v['Product_Code']:
if v1 == pred_product_code and v['Firm_Code'] == pred_firm:
pred_node_list.append(n)
@ -410,7 +335,7 @@ class MyModel(Model):
pred_node = -1
current_node_list = []
for use_pred_prod_code in lst_use_pred_prod_code:
for n, v in self.G_FirmProd.nodes(data=True):
for n, v in self.firm_prod_network.nodes(data=True):
for v1 in v['Product_Code']:
if v1 == use_pred_prod_code and v['Firm_Code'] == node:
current_node_list.append(n)
@ -419,13 +344,7 @@ class MyModel(Model):
else:
current_node = -1
if current_node != -1 and pred_node != -1:
self.G_FirmProd.add_edge(pred_node, current_node)
# 保存所有关键属性到缓存
cached_data = self.G_FirmProd
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved G_FirmProd to cache.")
self.firm_prod_network.add_edge(pred_node, current_node)
def connect_unconnected_nodes(self):
"""
@ -435,34 +354,22 @@ class MyModel(Model):
- 为未连接节点添加边连接到可能的下游企业
- 同时更新 G_FirmProd 网络反映企业与产品的关系
"""
cache_file = "connect_unconnected_nodes_cache.pkl"
# 检查是否存在缓存
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
# 加载缓存的属性
self.G_Firm = cached_data['firm_network']
self.firm_product_cache = cached_data['firm_prod_network']
print("Loaded G_Firm and firm_product_cache from cache.")
return
for node in nx.nodes(self.G_Firm):
for node in nx.nodes(self.firm_network):
# 如果节点没有任何连接,则处理该节点
if self.G_Firm.degree(node) == 0:
if self.firm_network.degree(node) == 0:
# 获取当前节点的产品列表
product_codes = self.G_Firm.nodes[node].get('Product_Code', [])
product_codes = self.firm_network.nodes[node].get('Product_Code', [])
for product_code in product_codes:
# 查找与当前产品相关的 FirmProd 节点
current_node_list = [
n for n, v in self.G_FirmProd.nodes(data=True)
n for n, v in self.firm_prod_network.nodes(data=True)
if v['Firm_Code'] == node and product_code in v['Product_Code']
]
current_node = current_node_list[0] if current_node_list else -1
# 查找当前产品的所有下游产品代码
succ_product_codes = list(self.G_bom.successors(product_code))
succ_product_codes = list(self.g_bom.successors(product_code))
for succ_product_code in succ_product_codes:
# 查找生产下游产品的企业
succ_firms = [
@ -479,7 +386,7 @@ class MyModel(Model):
if self.is_prf_size:
# 基于企业规模选择供应商
succ_firm_sizes = [
self.G_Firm.nodes[succ_firm].get('Revenue_Log', 0)
self.firm_network.nodes[succ_firm].get('Revenue_Log', 0)
for succ_firm in succ_firms
]
if sum(succ_firm_sizes) > 0:
@ -494,29 +401,24 @@ class MyModel(Model):
# 添加边到 G_Firm 图
edges = [(node, firm, {'Product': product_code}) for firm in selected_firms]
self.G_Firm.add_edges_from(edges)
self.firm_network.add_edges_from(edges)
# 更新 G_FirmProd 网络
for succ_firm in selected_firms:
succ_node_list = [
n for n, v in self.G_FirmProd.nodes(data=True)
n for n, v in self.firm_prod_network.nodes(data=True)
if v['Firm_Code'] == succ_firm and succ_product_code in v['Product_Code']
]
succ_node = succ_node_list[0] if succ_node_list else -1
if current_node != -1 and succ_node != -1:
self.G_FirmProd.add_edge(current_node, succ_node)
self.firm_prod_network.add_edge(current_node, succ_node)
# 保存网络数据到样本
self.firm_network = self.G_Firm # 使用 networkx 图对象表示的企业网络
self.firm_prod_network = self.G_FirmProd # 使用 networkx 图对象表示的企业与产品关系网络
cached_data = {
'firm_network': self.firm_network,
'firm_prod_network': self.firm_prod_network,
}
# 保存构建完成的 firm_network 到缓存
cache_file = "firm_network.pkl"
os.makedirs("cache", exist_ok=True)
with open(cache_file, 'wb') as f:
pickle.dump(cached_data, f)
print("Saved firm network and related data to cache.")
pickle.dump(self.firm_network, f)
# print("Firm network has been saved to cache.")
def initialize_agents(self):
"""
@ -545,18 +447,18 @@ class MyModel(Model):
]
# 获取企业的需求数量和生产输出
demand_quantity = self.data_materials.loc[self.data_materials['Firm_Code'] == ag_node]
production_output = self.data_produced.loc[self.data_produced['Firm_Code'] == ag_node]
demand_quantity = self.data_materials.loc[self.data_materials['Firm_Code'] == int(ag_node)]
production_output = self.data_produced.loc[self.data_produced['Firm_Code'] == int(ag_node)]
# 获取企业的资源信息
# 获取企业的资源信息,同时处理 R、P、C 的情况
try:
R = self.firm_resource_R.loc[int(ag_node)]
P = self.firm_resource_P.loc[int(ag_node)]
P = self.firm_resource_P.get(int(ag_node))
C = self.firm_resource_C.loc[int(ag_node)]
except KeyError:
# 如果资源数据缺失,提供默认值
R, P, C = [], [], []
R, P, C = [], {}, [] # 如果任何资源不存在,返回空列表
# 在模型初始化时,构建 unique_id -> agent 的快速映射字典
self.agent_map = {agent.unique_id: agent for agent in self.company_agents}
# 创建企业代理
firm_agent = FirmAgent(
unique_id=ag_node,
@ -580,14 +482,25 @@ class MyModel(Model):
- 更新公司与产品的生产状态为干扰状态
"""
# 构建公司与受干扰产品的映射字典
disruption_mapping = {
firm: [
product for product in self.product_agents
if product.unique_id in lst_product
disruption_mapping = {}
for firm_code, lst_product_indices in self.dct_lst_init_disrupt_firm_prod.items():
# 查找企业对象
firm = next((f for f in self.company_agents if f.unique_id == firm_code), None)
if not firm:
print(f"Warning: Firm {firm_code} not found. Skipping.")
continue
# 查找有效的产品代理
valid_products = [
product for product in self.product_agents if product.unique_id in lst_product_indices
]
for firm_code, lst_product in self.dct_lst_init_disrupt_firm_prod.items()
if (firm := next((f for f in self.company_agents if f.unique_id == firm_code), None))
}
if not valid_products:
print(f"Warning: No valid products found for Firm {firm_code}. Skipping.")
continue
# 更新映射
disruption_mapping[firm] = valid_products
# 更新干扰字典
self.dct_lst_init_disrupt_firm_prod = disruption_mapping
@ -595,10 +508,13 @@ class MyModel(Model):
# 设置初始干扰状态
for firm, disrupted_products in disruption_mapping.items():
for product in disrupted_products:
# 确保产品在公司的生产状态中
# 检查产品是否在企业的生产状态中
if product not in firm.dct_prod_up_prod_stat:
raise ValueError(
f"Product {product.unique_id} not found in firm {firm.unique_id}'s production status.")
print(
f"Warning: Product {product.unique_id} not found in firm "
f"{firm.unique_id}'s production status. Skipping."
)
continue
# 更新产品状态为干扰状态,并记录干扰时间
firm.dct_prod_up_prod_stat[product]['p_stat'].append(('D', self.t))
@ -618,50 +534,38 @@ class MyModel(Model):
- 合并设备数据与设备残值数据
- 按企业分组生成资源列表
"""
try:
# 加载企业的材料、设备和产品数据
data_R = pd.read_csv("input_data/input_firm_data/firms_materials.csv")
data_C = pd.read_csv("input_data/input_firm_data/firms_devices.csv")
data_P = pd.read_csv("input_data/input_firm_data/firms_products.csv")
# 加载企业的材料、设备和产品数据
data_R = pd.read_csv("input_data/input_firm_data/firms_materials.csv")
data_C = pd.read_csv("input_data/input_firm_data/firms_devices.csv")
data_P = pd.read_csv("input_data/input_firm_data/firms_products.csv")
# 加载设备残值数据,并合并到设备数据中
device_salvage_values = pd.read_csv('input_data/device_salvage_values.csv')
self.device_salvage_values = device_salvage_values
# 加载设备残值数据,并合并到设备数据中
device_salvage_values = pd.read_csv('input_data/device_salvage_values.csv')
self.device_salvage_values = device_salvage_values
# 合并设备数据和设备残值
data_merged_C = pd.merge(data_C, device_salvage_values, on='设备id', how='left')
# 合并设备数据和设备残值
data_merged_C = pd.merge(data_C, device_salvage_values, on='设备id', how='left')
# 按企业分组并生成资源列表
firm_resource_R = (
data_R.groupby('Firm_Code')[['材料id', '材料数量']]
.apply(lambda x: x.values.tolist())
.to_dict() # 转换为字典格式,便于快速查询
)
# 按企业分组并生成资源列表
firm_resource_R = (
data_R.groupby('Firm_Code')[['材料id', '材料数量']]
.apply(lambda x: x.values.tolist())
)
firm_resource_C = (
data_merged_C.groupby('Firm_Code')[['设备id', '设备数量', '设备残值']]
.apply(lambda x: x.values.tolist())
.to_dict()
)
firm_resource_C = (
data_merged_C.groupby('Firm_Code')[['设备id', '设备数量', '设备残值']]
.apply(lambda x: x.values.tolist())
)
firm_resource_P = (
data_P.groupby('Firm_Code')[['产品id', '产品数量']]
.apply(lambda x: x.values.tolist())
.to_dict()
)
firm_resource_P = (
data_P.groupby('Firm_Code')[['产品id', '产品数量']]
.apply(lambda x: x.values.tolist())
)
# 将结果存储到模型中
self.firm_resource_R = firm_resource_R
self.firm_resource_C = firm_resource_C
self.firm_resource_P = firm_resource_P
except FileNotFoundError as e:
print(f"Error: Missing input file - {e.filename}")
self.firm_resource_R, self.firm_resource_C, self.firm_resource_P = {}, {}, {}
except Exception as e:
print(f"Error during resource integration: {e}")
self.firm_resource_R, self.firm_resource_C, self.firm_resource_P = {}, {}, {}
# 将结果存储到模型中
self.firm_resource_R = firm_resource_R
self.firm_resource_C = firm_resource_C
self.firm_resource_P = firm_resource_P
def j_comp_consumed_produced(self):
"""
@ -720,21 +624,23 @@ class MyModel(Model):
3. 判断资源和设备是否需要采购并处理采购
4. 资源消耗和产品生产
"""
# 1. 移除客户边并中断客户上游产品
self._remove_disrupted_edges()
self._disrupt_upstream_products()
while self.t < self.int_stop_ts: # 使用循环控制时间步
# 1. 移除客户边并中断客户上游产品
self._remove_disrupted_edges()
self._disrupt_upstream_products()
# 2. 尝试寻找替代供应链
self._trial_process()
# 2. 尝试寻找替代供应链
self._trial_process()
# 3. 判断是否需要采购资源和设备
self._handle_resource_and_machinery_purchase()
# 3. 判断是否需要采购资源和设备
self._handle_material_purchase()
self._handle_machinery_purchase()
# 4. 资源消耗和产品生产
self._consume_resources_and_produce()
# 4. 资源消耗和产品生产
self._consume_resources_and_produce()
# 增加时间步
self.t += 1
# 增加时间步
self.t += 1
# 子方法定义
def _remove_disrupted_edges(self):
@ -784,76 +690,153 @@ class MyModel(Model):
for firm in self.company_agents:
firm.clean_before_trial()
def _handle_resource_and_machinery_purchase(self):
"""处理资源和设备的采购。"""
for firm in self.company_agents:
# 判断资源需求
for material_id, material_quantity in firm.R:
if material_quantity <= firm.s_r:
required_quantity = firm.S_r - material_quantity
firm.request_material_purchase(material_id, required_quantity)
def _handle_material_purchase(self):
    """Scan firms for low material stock and restock from suppliers.

    A material whose on-hand quantity has fallen to the reorder point
    ``firm.s_r`` triggers a request for enough units to refill up to the
    order-up-to level ``firm.S_r``. Demands are collected first, then a
    supplier is sought per demand via ``firm.seek_material_supply``;
    when one is found it handles the request and the firm's own stock
    record is reset to ``S_r``.
    """
    # Phase 1: collect (resource_id, required_quantity) demands per firm,
    # so supplier lookups happen only after the full scan.
    purchase_material_firms = {}
    for firm in self.company_agents:
        if not firm.R:  # skip firms that hold no materials
            continue
        for resource_id, resource_quantity in firm.R:
            if resource_quantity <= firm.s_r:  # at/below reorder point
                required_quantity = firm.S_r - resource_quantity
                purchase_material_firms.setdefault(firm, []).append(
                    (resource_id, required_quantity))
    # Phase 2: resolve each demand — find a supplier and refill the stock.
    for firm, material_requests in purchase_material_firms.items():
        for resource_id, required_quantity in material_requests:
            supplier = firm.seek_material_supply(resource_id)
            if supplier != -1:  # -1 signals "no supplier found"
                supplier.handle_material_request([resource_id, required_quantity])
                # Reset the firm's own record for this material to S_r.
                for resource in firm.R:
                    if resource[0] == resource_id:
                        resource[1] = firm.S_r
def _handle_machinery_purchase(self):
"""
判断并处理设备的采购
"""
# 存储需要采购设备的企业及其需求
purchase_machinery_firms = {}
# 遍历所有企业,检查设备需求
for firm in self.company_agents:
if not firm.C: # 跳过没有设备的企业
continue
# 检查设备残值,记录需要补充的设备
for equipment in firm.C:
equipment_id, equipment_quantity, equipment_salvage = equipment
equipment_salvage -= firm.x # 减少设备残值
if equipment_salvage <= 0: # 如果残值小于等于 0
equipment_quantity -= 1
required_quantity = 1 # 需要补充的设备数量
if firm not in purchase_machinery_firms:
purchase_machinery_firms[firm] = []
purchase_machinery_firms[firm].append((equipment_id, required_quantity))
# 寻找供应商并处理设备补充
for firm, machinery_requests in purchase_machinery_firms.items():
for equipment_id, required_quantity in machinery_requests:
# 寻找供应商
supplier = firm.seek_machinery_supply(equipment_id)
if supplier != -1: # 如果找到供应商
# 供应商处理设备请求
supplier.handle_machinery_request([equipment_id, required_quantity])
# 恢复企业的设备数量和残值
for equipment, initial_equipment in zip(firm.C, firm.C0):
if equipment[0] == equipment_id:
equipment[1] = initial_equipment[1] # 恢复数量
equipment[2] = initial_equipment[2] # 恢复残值
def _consume_resources_and_produce(self):
"""消耗资源并生产产品。"""
"""
消耗资源并生产产品
"""
k = 0.6 # 资源消耗比例
production_increase_ratio = 1.6 # 产品生产比例
# 遍历每个企业
for firm in self.company_agents:
# 计算消耗量
consumed_resources = {}
for industry in firm.indus_i:
for product_id, product_quantity in firm.P.items():
if product_id == industry.unique_id:
consumed_resources[industry] = product_quantity * k
# 计算资源消耗
consumed_resources = self._calculate_consumed_resources(firm, k)
# 消耗资源
for resource_id, resource_quantity in firm.R.items():
for industry, consumed_quantity in consumed_resources.items():
if resource_id in industry.resource_ids:
firm.R[resource_id] -= consumed_quantity
self._consume_resources(firm, consumed_resources)
# 生产产品
for product_id, product_quantity in firm.P.items():
firm.P[product_id] = product_quantity * 1.6
self._produce_products(firm, production_increase_ratio)
# 刷新资源和设备状态
firm.refresh_R()
firm.refresh_C()
firm.refresh_P()
def _calculate_consumed_resources(self, firm, k):
"""
计算企业的资源消耗量
"""
consumed_resources = {}
for industry in firm.indus_i:
consumed_quantity = sum(
product[1] * k
for product in firm.P
if product[0] == industry.unique_id
)
consumed_resources[industry.unique_id] = consumed_quantity
return consumed_resources
def _consume_resources(self, firm, consumed_resources):
"""
消耗企业的资源
"""
for resource in firm.R:
resource_id, resource_quantity = resource[0], resource[1]
if resource_id in consumed_resources:
resource[1] = max(0, resource_quantity - consumed_resources[resource_id])
def _produce_products(self, firm, production_increase_ratio):
"""
生产企业的产品
"""
for product in firm.P:
product[1] *= production_increase_ratio
def end(self):
    """
    Finish the model run and persist the results.

    Runs only if no Result rows exist yet for this sample, which avoids
    duplicate writes on re-runs.  For every firm/product whose supply
    status history (``dct_prod_up_prod_stat[...]['p_stat']``) contains
    any non-normal (!= 'N') entry, one Result row is collected per
    recorded (status, timestep) pair, then all rows are bulk-inserted
    in a single commit.  (NOTE(review): stale pre-refactor loop lines
    interleaved in this block are dropped in favour of the
    comprehension-based version.)
    """
    # Skip entirely if results for this sample were already written.
    if not db_session.query(Result).filter_by(s_id=self.sample.id).first():
        # One row per recorded status of every disrupted product.
        lst_result_info = [
            Result(
                s_id=self.sample.id,
                id_firm=firm.unique_id,
                id_product=prod.unique_id,
                ts=ts,
                status=status
            )
            for firm in self.company_agents
            for prod, dct_status_supply in firm.dct_prod_up_prod_stat.items()
            if not all(stat == 'N' for stat, _ in dct_status_supply['p_stat'])
            for status, ts in dct_status_supply['p_stat']
        ]
        # Bulk-insert in one transaction.
        if lst_result_info:
            db_session.bulk_save_objects(lst_result_info)
            db_session.commit()

4
orm.py
View File

@ -98,8 +98,8 @@ class Result(Base):
s_id = Column(Integer, ForeignKey('{}.id'.format(
f"{db_name_prefix}_sample")), nullable=False)
id_firm = Column(String(10), nullable=False)
id_product = Column(String(10), nullable=False)
id_firm = Column(String(20), nullable=False)
id_product = Column(String(20), nullable=False)
ts = Column(Integer, nullable=False)
status = Column(String(5), nullable=False)

View File

@ -1,76 +1 @@
s_id,id_firm,id_product,ts
1441,13,2.1.3.4,0
1441,126,2.1.3,1
1441,97,2.1.3,1
1441,106,2.1.3,1
1566,13,2.1.3.7,0
1566,126,2.1.3,1
2073,14,1.3.3.4,0
2073,75,1.3.3,1
2621,85,1.3.1,1
2621,21,1.3.1.3,0
2621,100,1.3.1,1
3386,22,2.1.3.7,0
3386,108,2.1.3,1
4249,23,2.3.1,0
4249,84,2.3,1
4440,25,1.3.1.7,0
4440,100,1.3.1,1
4624,74,2.1.3,1
4624,26,2.1.3.4,0
5015,31,1.3.3.3,0
5015,75,1.3.3,1
5015,97,1.3.3,1
5720,94,1.1,1
5720,36,1.1.1,0
5720,126,1.1,1
7349,80,1.3.4,1
7349,45,1.3.4.2,0
7349,77,1.3.4,1
7399,79,2.1.4.1,1
7399,45,2.1.4.1.1,0
8285,99,1.3.1,1
8285,49,1.3.1.4,0
8601,93,1.3.1,1
8601,50,1.3.1.5,0
8601,85,1.3.1,1
9072,53,1.4.3.4,0
9072,142,1.4.3,1
9382,41,1.4.5,1
9382,53,1.4.5.6,0
10098,99,1.3.3,1
10098,57,1.3.3.3,0
10121,57,2.3.1,0
10121,124,2.3,1
10521,81,1.3.4,1
10521,58,1.3.4.3,0
10521,80,1.3.4,1
11675,93,1.3.1,1
11675,68,1.3.1.3,0
11678,99,1.3.1,1
11678,85,1.3.1,1
11678,68,1.3.1.3,0
12837,126,2.1.3,1
12837,74,2.1.3,1
12837,73,2.1.3,1
12837,79,2.1.3.2,0
13084,108,2.1.3,1
13084,73,2.1.3,1
13084,106,2.1.3,1
13084,79,2.1.3.7,0
13084,148,2.1.3,1
13084,126,2.1.3,1
16647,115,1.1.3,0
16647,94,1.1,1
16903,85,2.1.1,1
16903,117,2.1.1.4,0
17379,119,1.3.1.1,0
17379,100,1.3.1,1
17922,126,2.1.1.5,0
17922,80,2.1.1,1
18824,85,2.1.1,1
18824,131,2.1.1.5,0
19562,135,2.2,0
19562,98,2,1
21447,159,2.1.2,1
21447,149,2.1.2.2,0

1 s_id id_firm id_product ts
1441 13 2.1.3.4 0
1441 126 2.1.3 1
1441 97 2.1.3 1
1441 106 2.1.3 1
1566 13 2.1.3.7 0
1566 126 2.1.3 1
2073 14 1.3.3.4 0
2073 75 1.3.3 1
2621 85 1.3.1 1
2621 21 1.3.1.3 0
2621 100 1.3.1 1
3386 22 2.1.3.7 0
3386 108 2.1.3 1
4249 23 2.3.1 0
4249 84 2.3 1
4440 25 1.3.1.7 0
4440 100 1.3.1 1
4624 74 2.1.3 1
4624 26 2.1.3.4 0
5015 31 1.3.3.3 0
5015 75 1.3.3 1
5015 97 1.3.3 1
5720 94 1.1 1
5720 36 1.1.1 0
5720 126 1.1 1
7349 80 1.3.4 1
7349 45 1.3.4.2 0
7349 77 1.3.4 1
7399 79 2.1.4.1 1
7399 45 2.1.4.1.1 0
8285 99 1.3.1 1
8285 49 1.3.1.4 0
8601 93 1.3.1 1
8601 50 1.3.1.5 0
8601 85 1.3.1 1
9072 53 1.4.3.4 0
9072 142 1.4.3 1
9382 41 1.4.5 1
9382 53 1.4.5.6 0
10098 99 1.3.3 1
10098 57 1.3.3.3 0
10121 57 2.3.1 0
10121 124 2.3 1
10521 81 1.3.4 1
10521 58 1.3.4.3 0
10521 80 1.3.4 1
11675 93 1.3.1 1
11675 68 1.3.1.3 0
11678 99 1.3.1 1
11678 85 1.3.1 1
11678 68 1.3.1.3 0
12837 126 2.1.3 1
12837 74 2.1.3 1
12837 73 2.1.3 1
12837 79 2.1.3.2 0
13084 108 2.1.3 1
13084 73 2.1.3 1
13084 106 2.1.3 1
13084 79 2.1.3.7 0
13084 148 2.1.3 1
13084 126 2.1.3 1
16647 115 1.1.3 0
16647 94 1.1 1
16903 85 2.1.1 1
16903 117 2.1.1.4 0
17379 119 1.3.1.1 0
17379 100 1.3.1 1
17922 126 2.1.1.5 0
17922 80 2.1.1 1
18824 85 2.1.1 1
18824 131 2.1.1.5 0
19562 135 2.2 0
19562 98 2 1
21447 159 2.1.2 1
21447 149 2.1.2.2 0

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

After

Width:  |  Height:  |  Size: 2.0 MiB

52
查看进度.py Normal file
View File

@ -0,0 +1,52 @@
from matplotlib import rcParams, pyplot as plt
from sqlalchemy import func
from orm import db_session, Sample
# 创建全局绘图对象和轴
fig, ax = plt.subplots(figsize=(8, 5))
plt.ion() # 启用交互模式
def visualize_progress():
    """
    Visualize the distribution of ``is_done_flag`` as a live bar chart.

    Counts samples per flag value (-1 pending, 0 running, 1 done) from
    the database, redraws the shared module-level axes, annotates each
    bar with its count, and pauses briefly so the interactive figure
    refreshes.
    """
    # Global font setup; SimHei renders the Chinese labels correctly.
    rcParams['font.family'] = 'SimHei'
    rcParams['font.size'] = 12

    # Count samples grouped by is_done_flag.
    rows = db_session.query(
        Sample.is_done_flag, func.count(Sample.id)
    ).group_by(Sample.is_done_flag).all()
    counts = dict(rows)

    # Fixed display order: pending / running / done; absent flags are 0.
    labels = ['未完成 (-1)', '计算中(0)', '完成 (1)']
    values = [counts.get(flag, 0) for flag in (-1, 0, 1)]

    # Redraw the shared axes from scratch.
    ax.clear()
    ax.bar(labels, values, color=['red', 'orange', 'green'])
    ax.set_title('任务进度分布', fontsize=16)
    ax.set_xlabel('任务状态', fontsize=14)
    ax.set_ylabel('数量', fontsize=14)
    ax.tick_params(axis='both', labelsize=12)

    # Annotate each bar with its exact count.
    for pos, count in enumerate(values):
        ax.text(pos, count + 0.5, str(count), ha='center', fontsize=12)

    # Let the event loop render the updated figure.
    plt.pause(1)
# 关闭窗口时,停止交互模式
# plt.ioff()