diff --git a/11.py b/11.py index 1288772..1d6ca55 100644 --- a/11.py +++ b/11.py @@ -1,6 +1,9 @@ import pickle import os +from 查看进度 import visualize_progress + + def load_cached_data(file_path): """ 从指定的缓存文件加载数据。 @@ -19,6 +22,8 @@ def load_cached_data(file_path): print(f"Error loading cache from '{file_path}': {e}") return {} -# 示例用法 -data_dct = load_cached_data("G_Firm_add_edges.pkl") +# 示例用法 +# data_dct = load_cached_data("G_Firm_add_edges.pkl") + +visualize_progress() diff --git a/__pycache__/computation.cpython-38.pyc b/__pycache__/computation.cpython-38.pyc index 2255337..07c59ca 100644 Binary files a/__pycache__/computation.cpython-38.pyc and b/__pycache__/computation.cpython-38.pyc differ diff --git a/__pycache__/controller_db.cpython-38.pyc b/__pycache__/controller_db.cpython-38.pyc index 9ca2a33..774db44 100644 Binary files a/__pycache__/controller_db.cpython-38.pyc and b/__pycache__/controller_db.cpython-38.pyc differ diff --git a/__pycache__/firm.cpython-38.pyc b/__pycache__/firm.cpython-38.pyc index d913b3f..082f5bf 100644 Binary files a/__pycache__/firm.cpython-38.pyc and b/__pycache__/firm.cpython-38.pyc differ diff --git a/__pycache__/my_model.cpython-38.pyc b/__pycache__/my_model.cpython-38.pyc index 2048625..c98e8ea 100644 Binary files a/__pycache__/my_model.cpython-38.pyc and b/__pycache__/my_model.cpython-38.pyc differ diff --git a/__pycache__/orm.cpython-38.pyc b/__pycache__/orm.cpython-38.pyc index 9a0e013..1cff520 100644 Binary files a/__pycache__/orm.cpython-38.pyc and b/__pycache__/orm.cpython-38.pyc differ diff --git a/__pycache__/查看进度.cpython-38.pyc b/__pycache__/查看进度.cpython-38.pyc index 7de7e92..a1855b8 100644 Binary files a/__pycache__/查看进度.cpython-38.pyc and b/__pycache__/查看进度.cpython-38.pyc differ diff --git a/computation.py b/computation.py index fc8ce99..58d5e7d 100644 --- a/computation.py +++ b/computation.py @@ -44,6 +44,6 @@ class Computation: model = MyModel(dct_sample_para) model.step() # 运行仿真一步 + model.end() # 汇总结果 return False - diff --git a/conf_db_prefix.yaml b/conf_db_prefix.yaml index 0e48fd8..1391188 100644 --- a/conf_db_prefix.yaml +++ b/conf_db_prefix.yaml @@ -1 +1 @@ -db_name_prefix: without_exp +db_name_prefix: with_exp diff --git a/conf_experiment.yaml b/conf_experiment.yaml index 173d772..ab75a91 100644 --- a/conf_experiment.yaml +++ b/conf_experiment.yaml @@ -8,5 +8,5 @@ test: # only for test scenarios n_iter: 100 not_test: # normal scenarios - n_sample: 5 - n_iter: 10 + n_sample: 10 + n_iter: 50 diff --git a/controller_db.py b/controller_db.py index f733fd7..54e03a8 100644 --- a/controller_db.py +++ b/controller_db.py @@ -53,7 +53,6 @@ class ControllerDB: # fill dct_lst_init_disrupt_firm_prod # 存储 公司-在供应链结点的位置.. 0 :‘1.1’ - list_dct = [] # 存储 公司编码code 和对应的产业链 结点 if self.is_with_exp: # 对于方差分析时候使用 with open('SQL_export_high_risk_setting.sql', 'r') as f: diff --git a/firm.py b/firm.py index 72e8f9c..740f2f8 100644 --- a/firm.py +++ b/firm.py @@ -143,42 +143,69 @@ class FirmAgent(Agent): # f"disrupted supplier of {disrupted_up_prod.code}") def seek_alt_supply(self, product): + # 检查当前产品的尝试次数是否达到最大值 if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial: + # 初始化候选供应商列表 if self.dct_n_trial_up_prod_disrupted[product] == 0: self.dct_cand_alt_supp_up_prod_disrupted[product] = [ - firm for firm in self.model.company_agents - if firm.is_prod_in_current_normal(product)] - if self.dct_cand_alt_supp_up_prod_disrupted[product]: - lst_firm_connect = [] - if self.is_prf_conn: - for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]: - if self.firm_network.has_edge(self.unique_id, firm.unique_id) or \ - self.firm_network.has_edge(firm.unique_id, self.unique_id): - lst_firm_connect.append(firm) - if len(lst_firm_connect) == 0: - if self.is_prf_size: - lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]] - lst_prob = [size / sum(lst_size) for size in lst_size] - select_alt_supply = \ - self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0] - else: - select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product]) - elif len(lst_firm_connect) > 0: - if self.is_prf_size: - lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect] - lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size] - select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0] - else: - select_alt_supply = self.random.choice(lst_firm_connect) + firm for firm in self.model.company_agents if firm.is_prod_in_current_normal(product) + ] - assert select_alt_supply.is_prod_in_current_normal(product) + # 如果没有候选供应商,直接退出 + if not self.dct_cand_alt_supp_up_prod_disrupted[product]: + # print(f"No valid candidates found for product {product.unique_id}") + return - if product in select_alt_supply.dct_request_prod_from_firm: - select_alt_supply.dct_request_prod_from_firm[product].append(self) + # 查找与当前企业已连接的候选供应商 + lst_firm_connect = [] + if self.is_prf_conn: + lst_firm_connect = [ + firm for firm in self.dct_cand_alt_supp_up_prod_disrupted[product] + if self.firm_network.has_edge(self.unique_id, firm.unique_id) or + self.firm_network.has_edge(firm.unique_id, self.unique_id) + ] + + # 如果没有连接的供应商 + if not lst_firm_connect: + if self.is_prf_size: # 根据规模加权选择 + lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]] + lst_prob = [size / sum(lst_size) for size in lst_size] + select_alt_supply = self.random.choices( + self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob + )[0] + else: # 随机选择 + select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product]) + else: # 如果存在连接的供应商 + if self.is_prf_size: # 根据规模加权选择 + lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect] + lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size] + select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0] + else: # 随机选择 + select_alt_supply = self.random.choice(lst_firm_connect) + + # 检查选中的供应商是否能够生产产品 + if not select_alt_supply.is_prod_in_current_normal(product): + # print(f"Selected supplier {select_alt_supply.unique_id} cannot produce product {product.unique_id}") + + # 打印供应商的生产状态字典 + #print(f"Supplier production state: {select_alt_supply.dct_prod_up_prod_stat}") + + # 检查产品是否存在于生产状态字典中 + if product in select_alt_supply.dct_prod_up_prod_stat: + print( + f"Product {product.unique_id} production state: {select_alt_supply.dct_prod_up_prod_stat[product]['p_stat']}") else: - select_alt_supply.dct_request_prod_from_firm[product] = [self] + print(f"Product {product.unique_id} not found in supplier production state.") + return - self.dct_n_trial_up_prod_disrupted[product] += 1 + # 添加到供应商的请求字典 + if product in select_alt_supply.dct_request_prod_from_firm: + select_alt_supply.dct_request_prod_from_firm[product].append(self) + else: + select_alt_supply.dct_request_prod_from_firm[product] = [self] + + # 更新尝试次数 + self.dct_n_trial_up_prod_disrupted[product] += 1 def handle_request(self): for product, lst_firm in self.dct_request_prod_from_firm.items(): diff --git a/main.py b/main.py index fb74878..2266452 100644 --- a/main.py +++ b/main.py @@ -45,7 +45,7 @@ def do_computation(c_db): exp = Computation(c_db) while 1: - time.sleep(random.uniform(0, 2)) + time.sleep(random.uniform(0, 1)) is_all_done = exp.run() if is_all_done: break @@ -54,7 +54,7 @@ def do_computation(c_db): if __name__ == '__main__': # 输入参数 parser = argparse.ArgumentParser(description='setting') - parser.add_argument('--exp', type=str, default='without_exp') + parser.add_argument('--exp', type=str, default='with_exp') parser.add_argument('--job', type=int, default='4') parser.add_argument('--reset_sample', type=int, default='0') parser.add_argument('--reset_db', type=bool, default=False) diff --git a/my_model.py b/my_model.py index 6f8a075..e27dd0f 100644 --- a/my_model.py +++ b/my_model.py @@ -54,7 +54,7 @@ class MyModel(Model): self.cap_limit_level = params['cap_limit_level'] # 产能限制的水平。 self.diff_new_conn = params['diff_new_conn'] # 是否允许差异化的新连接。 # 初始化停止时间步,可能是用户通过参数传入 - self.int_stop_ts = params.get('stop_t', 3) # 默认停止时间为 100 + self.int_stop_ts = params.get('n_iter', 3) # 默认停止时间为 100 # 网络初始化 self.firm_network = nx.MultiDiGraph() # 企业之间的有向多重图。 @@ -624,22 +624,23 @@ class MyModel(Model): 3. 判断资源和设备是否需要采购,并处理采购。 4. 资源消耗和产品生产。 """ - # 1. 移除客户边并中断客户上游产品 - self._remove_disrupted_edges() - self._disrupt_upstream_products() + while self.t < self.int_stop_ts: # 使用循环控制时间步 + # 1. 移除客户边并中断客户上游产品 + self._remove_disrupted_edges() + self._disrupt_upstream_products() - # 2. 尝试寻找替代供应链 - self._trial_process() + # 2. 尝试寻找替代供应链 + self._trial_process() - # 3. 判断是否需要采购资源和设备 - self._handle_material_purchase() - self._handle_machinery_purchase() + # 3. 判断是否需要采购资源和设备 + self._handle_material_purchase() + self._handle_machinery_purchase() - # 4. 资源消耗和产品生产 - self._consume_resources_and_produce() + # 4. 资源消耗和产品生产 + self._consume_resources_and_produce() - # 增加时间步 - self.t += 1 + # 增加时间步 + self.t += 1 # 子方法定义 def _remove_disrupted_edges(self): @@ -815,31 +816,27 @@ class MyModel(Model): def end(self): """ 结束模型运行并保存结果。 - 功能: - - 检查结果是否已存在,避免重复写入。 - - 保存企业和产品的生产状态到数据库。 - - 更新样本的状态为完成,并记录停止时间和计算机名称。 + - 如果当前样本的结果未保存,则保存所有生产状态为非正常状态的结果。 + - 更新样本状态为完成,并记录相关信息。 """ # 检查当前样本结果是否已存在 - qry_result = db_session.query(Result).filter_by(s_id=self.sample.id) - if qry_result.count() == 0: - # 收集所有结果 - lst_result_info = [] - for firm in self.company_agents: - for prod, dct_status_supply in firm.dct_prod_up_prod_stat.items(): - # 检查产品状态是否都为正常 - if not all(status == 'N' for status, _ in dct_status_supply['p_stat']): - for status, ts in dct_status_supply['p_stat']: - # 创建结果对象 - lst_result_info.append(Result( - s_id=self.sample.id, - id_firm=firm.unique_id, - id_product=prod.unique_id, - ts=ts, - status=status - )) + if not db_session.query(Result).filter_by(s_id=self.sample.id).first(): + # 生成需要保存的结果列表 + lst_result_info = [ + Result( + s_id=self.sample.id, + id_firm=firm.unique_id, + id_product=prod.unique_id, + ts=ts, + status=status + ) + for firm in self.company_agents + for prod, dct_status_supply in firm.dct_prod_up_prod_stat.items() + if not all(stat == 'N' for stat, _ in dct_status_supply['p_stat']) + for status, ts in dct_status_supply['p_stat'] + ] - # 批量保存结果 + # 批量保存结果到数据库 if lst_result_info: db_session.bulk_save_objects(lst_result_info) db_session.commit() diff --git a/orm.py b/orm.py index b96d88f..ffe995f 100644 --- a/orm.py +++ b/orm.py @@ -98,8 +98,8 @@ class Result(Base): s_id = Column(Integer, ForeignKey('{}.id'.format( f"{db_name_prefix}_sample")), nullable=False) - id_firm = Column(String(10), nullable=False) - id_product = Column(String(10), nullable=False) + id_firm = Column(String(20), nullable=False) + id_product = Column(String(20), nullable=False) ts = Column(Integer, nullable=False) status = Column(String(5), nullable=False) diff --git a/output_result/risk/count.csv b/output_result/risk/count.csv index c650bce..e877bba 100644 --- a/output_result/risk/count.csv +++ b/output_result/risk/count.csv @@ -1,76 +1 @@ s_id,id_firm,id_product,ts -1441,13,2.1.3.4,0 -1441,126,2.1.3,1 -1441,97,2.1.3,1 -1441,106,2.1.3,1 -1566,13,2.1.3.7,0 -1566,126,2.1.3,1 -2073,14,1.3.3.4,0 -2073,75,1.3.3,1 -2621,85,1.3.1,1 -2621,21,1.3.1.3,0 -2621,100,1.3.1,1 -3386,22,2.1.3.7,0 -3386,108,2.1.3,1 -4249,23,2.3.1,0 -4249,84,2.3,1 -4440,25,1.3.1.7,0 -4440,100,1.3.1,1 -4624,74,2.1.3,1 -4624,26,2.1.3.4,0 -5015,31,1.3.3.3,0 -5015,75,1.3.3,1 -5015,97,1.3.3,1 -5720,94,1.1,1 -5720,36,1.1.1,0 -5720,126,1.1,1 -7349,80,1.3.4,1 -7349,45,1.3.4.2,0 -7349,77,1.3.4,1 -7399,79,2.1.4.1,1 -7399,45,2.1.4.1.1,0 -8285,99,1.3.1,1 -8285,49,1.3.1.4,0 -8601,93,1.3.1,1 -8601,50,1.3.1.5,0 -8601,85,1.3.1,1 -9072,53,1.4.3.4,0 -9072,142,1.4.3,1 -9382,41,1.4.5,1 -9382,53,1.4.5.6,0 -10098,99,1.3.3,1 -10098,57,1.3.3.3,0 -10121,57,2.3.1,0 -10121,124,2.3,1 -10521,81,1.3.4,1 -10521,58,1.3.4.3,0 -10521,80,1.3.4,1 -11675,93,1.3.1,1 -11675,68,1.3.1.3,0 -11678,99,1.3.1,1 -11678,85,1.3.1,1 -11678,68,1.3.1.3,0 -12837,126,2.1.3,1 -12837,74,2.1.3,1 -12837,73,2.1.3,1 -12837,79,2.1.3.2,0 -13084,108,2.1.3,1 -13084,73,2.1.3,1 -13084,106,2.1.3,1 -13084,79,2.1.3.7,0 -13084,148,2.1.3,1 -13084,126,2.1.3,1 -16647,115,1.1.3,0 -16647,94,1.1,1 -16903,85,2.1.1,1 -16903,117,2.1.1.4,0 -17379,119,1.3.1.1,0 -17379,100,1.3.1,1 -17922,126,2.1.1.5,0 -17922,80,2.1.1,1 -18824,85,2.1.1,1 -18824,131,2.1.1.5,0 -19562,135,2.2,0 -19562,98,2,1 -21447,159,2.1.2,1 -21447,149,2.1.2.2,0 diff --git a/output_result/risk/g_bom_exp_id_1.png b/output_result/risk/g_bom_exp_id_1.png index a7131b7..7ea3357 100644 Binary files a/output_result/risk/g_bom_exp_id_1.png and b/output_result/risk/g_bom_exp_id_1.png differ diff --git a/查看进度.py b/查看进度.py index 72b67f1..c493475 100644 --- a/查看进度.py +++ b/查看进度.py @@ -28,7 +28,7 @@ def visualize_progress(): data.setdefault(flag, 0) # 准备数据 - labels = ['未完成 (-1)', '完成(0)', '等待 (1)'] + labels = ['未完成 (-1)', '计算中(0)', '完成 (1)'] values = [data[-1], data[0], data[1]] # 清空之前的绘图内容 @@ -46,7 +46,7 @@ def visualize_progress(): ax.text(i, v + 0.5, str(v), ha='center', fontsize=12) # 刷新绘图 - plt.pause(0.1) # 暂停一段时间以更新图表 + plt.pause(1) # 暂停一段时间以更新图表 # 关闭窗口时,停止交互模式 -plt.ioff() +# plt.ioff()