暂时通过测试

This commit is contained in:
Cricial 2025-01-27 15:18:51 +08:00
parent 0f7f9c1a4b
commit 30e7e56c11
18 changed files with 106 additions and 153 deletions

9
11.py
View File

@ -1,6 +1,9 @@
import pickle import pickle
import os import os
from 查看进度 import visualize_progress
def load_cached_data(file_path): def load_cached_data(file_path):
""" """
从指定的缓存文件加载数据 从指定的缓存文件加载数据
@ -19,6 +22,8 @@ def load_cached_data(file_path):
print(f"Error loading cache from '{file_path}': {e}") print(f"Error loading cache from '{file_path}': {e}")
return {} return {}
# 示例用法
data_dct = load_cached_data("G_Firm_add_edges.pkl")
# 示例用法
# data_dct = load_cached_data("G_Firm_add_edges.pkl")
visualize_progress()

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -44,6 +44,6 @@ class Computation:
model = MyModel(dct_sample_para) model = MyModel(dct_sample_para)
model.step() # 运行仿真一步 model.step() # 运行仿真一步
model.end() # 汇总结果
return False return False

View File

@ -1 +1 @@
db_name_prefix: without_exp db_name_prefix: with_exp

View File

@ -8,5 +8,5 @@ test: # only for test scenarios
n_iter: 100 n_iter: 100
not_test: # normal scenarios not_test: # normal scenarios
n_sample: 5 n_sample: 10
n_iter: 10 n_iter: 50

View File

@ -53,7 +53,6 @@ class ControllerDB:
# fill dct_lst_init_disrupt_firm_prod # fill dct_lst_init_disrupt_firm_prod
# 存储 公司-在供应链结点的位置.. 0 1.1 # 存储 公司-在供应链结点的位置.. 0 1.1
list_dct = [] # 存储 公司编码code 和对应的产业链 结点
if self.is_with_exp: if self.is_with_exp:
# 对于方差分析时候使用 # 对于方差分析时候使用
with open('SQL_export_high_risk_setting.sql', 'r') as f: with open('SQL_export_high_risk_setting.sql', 'r') as f:

85
firm.py
View File

@ -143,42 +143,69 @@ class FirmAgent(Agent):
# f"disrupted supplier of {disrupted_up_prod.code}") # f"disrupted supplier of {disrupted_up_prod.code}")
def seek_alt_supply(self, product): def seek_alt_supply(self, product):
# 检查当前产品的尝试次数是否达到最大值
if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial: if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial:
# 初始化候选供应商列表
if self.dct_n_trial_up_prod_disrupted[product] == 0: if self.dct_n_trial_up_prod_disrupted[product] == 0:
self.dct_cand_alt_supp_up_prod_disrupted[product] = [ self.dct_cand_alt_supp_up_prod_disrupted[product] = [
firm for firm in self.model.company_agents firm for firm in self.model.company_agents if firm.is_prod_in_current_normal(product)
if firm.is_prod_in_current_normal(product)] ]
if self.dct_cand_alt_supp_up_prod_disrupted[product]:
lst_firm_connect = []
if self.is_prf_conn:
for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]:
if self.firm_network.has_edge(self.unique_id, firm.unique_id) or \
self.firm_network.has_edge(firm.unique_id, self.unique_id):
lst_firm_connect.append(firm)
if len(lst_firm_connect) == 0:
if self.is_prf_size:
lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
lst_prob = [size / sum(lst_size) for size in lst_size]
select_alt_supply = \
self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
else:
select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
elif len(lst_firm_connect) > 0:
if self.is_prf_size:
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
else:
select_alt_supply = self.random.choice(lst_firm_connect)
assert select_alt_supply.is_prod_in_current_normal(product) # 如果没有候选供应商,直接退出
if not self.dct_cand_alt_supp_up_prod_disrupted[product]:
# print(f"No valid candidates found for product {product.unique_id}")
return
if product in select_alt_supply.dct_request_prod_from_firm: # 查找与当前企业已连接的候选供应商
select_alt_supply.dct_request_prod_from_firm[product].append(self) lst_firm_connect = []
if self.is_prf_conn:
lst_firm_connect = [
firm for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]
if self.firm_network.has_edge(self.unique_id, firm.unique_id) or
self.firm_network.has_edge(firm.unique_id, self.unique_id)
]
# 如果没有连接的供应商
if not lst_firm_connect:
if self.is_prf_size: # 根据规模加权选择
lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
lst_prob = [size / sum(lst_size) for size in lst_size]
select_alt_supply = self.random.choices(
self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob
)[0]
else: # 随机选择
select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
else: # 如果存在连接的供应商
if self.is_prf_size: # 根据规模加权选择
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
else: # 随机选择
select_alt_supply = self.random.choice(lst_firm_connect)
# 检查选中的供应商是否能够生产产品
if not select_alt_supply.is_prod_in_current_normal(product):
# print(f"Selected supplier {select_alt_supply.unique_id} cannot produce product {product.unique_id}")
# 打印供应商的生产状态字典
#print(f"Supplier production state: {select_alt_supply.dct_prod_up_prod_stat}")
# 检查产品是否存在于生产状态字典中
if product in select_alt_supply.dct_prod_up_prod_stat:
print(
f"Product {product.unique_id} production state: {select_alt_supply.dct_prod_up_prod_stat[product]['p_stat']}")
else: else:
select_alt_supply.dct_request_prod_from_firm[product] = [self] print(f"Product {product.unique_id} not found in supplier production state.")
return
self.dct_n_trial_up_prod_disrupted[product] += 1 # 添加到供应商的请求字典
if product in select_alt_supply.dct_request_prod_from_firm:
select_alt_supply.dct_request_prod_from_firm[product].append(self)
else:
select_alt_supply.dct_request_prod_from_firm[product] = [self]
# 更新尝试次数
self.dct_n_trial_up_prod_disrupted[product] += 1
def handle_request(self): def handle_request(self):
for product, lst_firm in self.dct_request_prod_from_firm.items(): for product, lst_firm in self.dct_request_prod_from_firm.items():

View File

@ -45,7 +45,7 @@ def do_computation(c_db):
exp = Computation(c_db) exp = Computation(c_db)
while 1: while 1:
time.sleep(random.uniform(0, 2)) time.sleep(random.uniform(0, 1))
is_all_done = exp.run() is_all_done = exp.run()
if is_all_done: if is_all_done:
break break
@ -54,7 +54,7 @@ def do_computation(c_db):
if __name__ == '__main__': if __name__ == '__main__':
# 输入参数 # 输入参数
parser = argparse.ArgumentParser(description='setting') parser = argparse.ArgumentParser(description='setting')
parser.add_argument('--exp', type=str, default='without_exp') parser.add_argument('--exp', type=str, default='with_exp')
parser.add_argument('--job', type=int, default='4') parser.add_argument('--job', type=int, default='4')
parser.add_argument('--reset_sample', type=int, default='0') parser.add_argument('--reset_sample', type=int, default='0')
parser.add_argument('--reset_db', type=bool, default=False) parser.add_argument('--reset_db', type=bool, default=False)

View File

@ -54,7 +54,7 @@ class MyModel(Model):
self.cap_limit_level = params['cap_limit_level'] # 产能限制的水平。 self.cap_limit_level = params['cap_limit_level'] # 产能限制的水平。
self.diff_new_conn = params['diff_new_conn'] # 是否允许差异化的新连接。 self.diff_new_conn = params['diff_new_conn'] # 是否允许差异化的新连接。
# 初始化停止时间步,可能是用户通过参数传入 # 初始化停止时间步,可能是用户通过参数传入
self.int_stop_ts = params.get('stop_t', 3) # 默认停止时间为 100 self.int_stop_ts = params.get('n_iter', 3) # 默认停止时间为 100
# 网络初始化 # 网络初始化
self.firm_network = nx.MultiDiGraph() # 企业之间的有向多重图。 self.firm_network = nx.MultiDiGraph() # 企业之间的有向多重图。
@ -624,22 +624,23 @@ class MyModel(Model):
3. 判断资源和设备是否需要采购并处理采购 3. 判断资源和设备是否需要采购并处理采购
4. 资源消耗和产品生产 4. 资源消耗和产品生产
""" """
# 1. 移除客户边并中断客户上游产品 while self.t < self.int_stop_ts: # 使用循环控制时间步
self._remove_disrupted_edges() # 1. 移除客户边并中断客户上游产品
self._disrupt_upstream_products() self._remove_disrupted_edges()
self._disrupt_upstream_products()
# 2. 尝试寻找替代供应链 # 2. 尝试寻找替代供应链
self._trial_process() self._trial_process()
# 3. 判断是否需要采购资源和设备 # 3. 判断是否需要采购资源和设备
self._handle_material_purchase() self._handle_material_purchase()
self._handle_machinery_purchase() self._handle_machinery_purchase()
# 4. 资源消耗和产品生产 # 4. 资源消耗和产品生产
self._consume_resources_and_produce() self._consume_resources_and_produce()
# 增加时间步 # 增加时间步
self.t += 1 self.t += 1
# 子方法定义 # 子方法定义
def _remove_disrupted_edges(self): def _remove_disrupted_edges(self):
@ -815,31 +816,27 @@ class MyModel(Model):
def end(self): def end(self):
""" """
结束模型运行并保存结果 结束模型运行并保存结果
功能: - 如果当前样本的结果未保存则保存所有生产状态为非正常状态的结果
- 检查结果是否已存在避免重复写入 - 更新样本状态为完成并记录相关信息
- 保存企业和产品的生产状态到数据库
- 更新样本的状态为完成并记录停止时间和计算机名称
""" """
# 检查当前样本结果是否已存在 # 检查当前样本结果是否已存在
qry_result = db_session.query(Result).filter_by(s_id=self.sample.id) if not db_session.query(Result).filter_by(s_id=self.sample.id).first():
if qry_result.count() == 0: # 生成需要保存的结果列表
# 收集所有结果 lst_result_info = [
lst_result_info = [] Result(
for firm in self.company_agents: s_id=self.sample.id,
for prod, dct_status_supply in firm.dct_prod_up_prod_stat.items(): id_firm=firm.unique_id,
# 检查产品状态是否都为正常 id_product=prod.unique_id,
if not all(status == 'N' for status, _ in dct_status_supply['p_stat']): ts=ts,
for status, ts in dct_status_supply['p_stat']: status=status
# 创建结果对象 )
lst_result_info.append(Result( for firm in self.company_agents
s_id=self.sample.id, for prod, dct_status_supply in firm.dct_prod_up_prod_stat.items()
id_firm=firm.unique_id, if not all(stat == 'N' for stat, _ in dct_status_supply['p_stat'])
id_product=prod.unique_id, for status, ts in dct_status_supply['p_stat']
ts=ts, ]
status=status
))
# 批量保存结果 # 批量保存结果到数据库
if lst_result_info: if lst_result_info:
db_session.bulk_save_objects(lst_result_info) db_session.bulk_save_objects(lst_result_info)
db_session.commit() db_session.commit()

4
orm.py
View File

@ -98,8 +98,8 @@ class Result(Base):
s_id = Column(Integer, ForeignKey('{}.id'.format( s_id = Column(Integer, ForeignKey('{}.id'.format(
f"{db_name_prefix}_sample")), nullable=False) f"{db_name_prefix}_sample")), nullable=False)
id_firm = Column(String(10), nullable=False) id_firm = Column(String(20), nullable=False)
id_product = Column(String(10), nullable=False) id_product = Column(String(20), nullable=False)
ts = Column(Integer, nullable=False) ts = Column(Integer, nullable=False)
status = Column(String(5), nullable=False) status = Column(String(5), nullable=False)

View File

@ -1,76 +1 @@
s_id,id_firm,id_product,ts s_id,id_firm,id_product,ts
1441,13,2.1.3.4,0
1441,126,2.1.3,1
1441,97,2.1.3,1
1441,106,2.1.3,1
1566,13,2.1.3.7,0
1566,126,2.1.3,1
2073,14,1.3.3.4,0
2073,75,1.3.3,1
2621,85,1.3.1,1
2621,21,1.3.1.3,0
2621,100,1.3.1,1
3386,22,2.1.3.7,0
3386,108,2.1.3,1
4249,23,2.3.1,0
4249,84,2.3,1
4440,25,1.3.1.7,0
4440,100,1.3.1,1
4624,74,2.1.3,1
4624,26,2.1.3.4,0
5015,31,1.3.3.3,0
5015,75,1.3.3,1
5015,97,1.3.3,1
5720,94,1.1,1
5720,36,1.1.1,0
5720,126,1.1,1
7349,80,1.3.4,1
7349,45,1.3.4.2,0
7349,77,1.3.4,1
7399,79,2.1.4.1,1
7399,45,2.1.4.1.1,0
8285,99,1.3.1,1
8285,49,1.3.1.4,0
8601,93,1.3.1,1
8601,50,1.3.1.5,0
8601,85,1.3.1,1
9072,53,1.4.3.4,0
9072,142,1.4.3,1
9382,41,1.4.5,1
9382,53,1.4.5.6,0
10098,99,1.3.3,1
10098,57,1.3.3.3,0
10121,57,2.3.1,0
10121,124,2.3,1
10521,81,1.3.4,1
10521,58,1.3.4.3,0
10521,80,1.3.4,1
11675,93,1.3.1,1
11675,68,1.3.1.3,0
11678,99,1.3.1,1
11678,85,1.3.1,1
11678,68,1.3.1.3,0
12837,126,2.1.3,1
12837,74,2.1.3,1
12837,73,2.1.3,1
12837,79,2.1.3.2,0
13084,108,2.1.3,1
13084,73,2.1.3,1
13084,106,2.1.3,1
13084,79,2.1.3.7,0
13084,148,2.1.3,1
13084,126,2.1.3,1
16647,115,1.1.3,0
16647,94,1.1,1
16903,85,2.1.1,1
16903,117,2.1.1.4,0
17379,119,1.3.1.1,0
17379,100,1.3.1,1
17922,126,2.1.1.5,0
17922,80,2.1.1,1
18824,85,2.1.1,1
18824,131,2.1.1.5,0
19562,135,2.2,0
19562,98,2,1
21447,159,2.1.2,1
21447,149,2.1.2.2,0

1 s_id id_firm id_product ts
1441 13 2.1.3.4 0
1441 126 2.1.3 1
1441 97 2.1.3 1
1441 106 2.1.3 1
1566 13 2.1.3.7 0
1566 126 2.1.3 1
2073 14 1.3.3.4 0
2073 75 1.3.3 1
2621 85 1.3.1 1
2621 21 1.3.1.3 0
2621 100 1.3.1 1
3386 22 2.1.3.7 0
3386 108 2.1.3 1
4249 23 2.3.1 0
4249 84 2.3 1
4440 25 1.3.1.7 0
4440 100 1.3.1 1
4624 74 2.1.3 1
4624 26 2.1.3.4 0
5015 31 1.3.3.3 0
5015 75 1.3.3 1
5015 97 1.3.3 1
5720 94 1.1 1
5720 36 1.1.1 0
5720 126 1.1 1
7349 80 1.3.4 1
7349 45 1.3.4.2 0
7349 77 1.3.4 1
7399 79 2.1.4.1 1
7399 45 2.1.4.1.1 0
8285 99 1.3.1 1
8285 49 1.3.1.4 0
8601 93 1.3.1 1
8601 50 1.3.1.5 0
8601 85 1.3.1 1
9072 53 1.4.3.4 0
9072 142 1.4.3 1
9382 41 1.4.5 1
9382 53 1.4.5.6 0
10098 99 1.3.3 1
10098 57 1.3.3.3 0
10121 57 2.3.1 0
10121 124 2.3 1
10521 81 1.3.4 1
10521 58 1.3.4.3 0
10521 80 1.3.4 1
11675 93 1.3.1 1
11675 68 1.3.1.3 0
11678 99 1.3.1 1
11678 85 1.3.1 1
11678 68 1.3.1.3 0
12837 126 2.1.3 1
12837 74 2.1.3 1
12837 73 2.1.3 1
12837 79 2.1.3.2 0
13084 108 2.1.3 1
13084 73 2.1.3 1
13084 106 2.1.3 1
13084 79 2.1.3.7 0
13084 148 2.1.3 1
13084 126 2.1.3 1
16647 115 1.1.3 0
16647 94 1.1 1
16903 85 2.1.1 1
16903 117 2.1.1.4 0
17379 119 1.3.1.1 0
17379 100 1.3.1 1
17922 126 2.1.1.5 0
17922 80 2.1.1 1
18824 85 2.1.1 1
18824 131 2.1.1.5 0
19562 135 2.2 0
19562 98 2 1
21447 159 2.1.2 1
21447 149 2.1.2.2 0

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

After

Width:  |  Height:  |  Size: 2.0 MiB

View File

@ -28,7 +28,7 @@ def visualize_progress():
data.setdefault(flag, 0) data.setdefault(flag, 0)
# 准备数据 # 准备数据
labels = ['未完成 (-1)', '完成(0)', '等待 (1)'] labels = ['未完成 (-1)', '计算中(0)', '完成 (1)']
values = [data[-1], data[0], data[1]] values = [data[-1], data[0], data[1]]
# 清空之前的绘图内容 # 清空之前的绘图内容
@ -46,7 +46,7 @@ def visualize_progress():
ax.text(i, v + 0.5, str(v), ha='center', fontsize=12) ax.text(i, v + 0.5, str(v), ha='center', fontsize=12)
# 刷新绘图 # 刷新绘图
plt.pause(0.1) # 暂停一段时间以更新图表 plt.pause(1) # 暂停一段时间以更新图表
# 关闭窗口时,停止交互模式 # 关闭窗口时,停止交互模式
plt.ioff() # plt.ioff()