commit 1f643c64e4dc5f8c189134ae0491e40c359357c7
Author: Cricial <2911646453@qq.com>
Date: Sat Aug 24 11:20:13 2024 +0800
1
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..359bb53
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/mesa.iml b/.idea/mesa.iml
new file mode 100644
index 0000000..b6973c7
--- /dev/null
+++ b/.idea/mesa.iml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..e557d17
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..e08ffd7
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/SQL_analysis_experiment.sql b/SQL_analysis_experiment.sql
new file mode 100644
index 0000000..7894992
--- /dev/null
+++ b/SQL_analysis_experiment.sql
@@ -0,0 +1,85 @@
+select distinct experiment.idx_scenario,
+n_max_trial, prf_size, prf_conn, cap_limit_prob_type, cap_limit_level, diff_new_conn, remove_t, netw_prf_n,
+mean_count_firm_prod, mean_count_firm, mean_count_prod,
+mean_max_ts_firm_prod, mean_max_ts_firm, mean_max_ts_prod,
+mean_n_remove_firm_prod, mean_n_all_prod_remove_firm, mean_end_ts
+from iiabmdb.with_exp_experiment as experiment
+left join
+(
+select
+idx_scenario,
+sum(count_firm_prod) / count(*) as mean_count_firm_prod, # Note to use count(*), to include NULL
+sum(count_firm) / count(*) as mean_count_firm,
+sum(count_prod) / count(*) as mean_count_prod,
+sum(max_ts_firm_prod) / count(*) as mean_max_ts_firm_prod,
+sum(max_ts_firm) / count(*) as mean_max_ts_firm,
+sum(max_ts_prod) / count(*) as mean_max_ts_prod,
+sum(n_remove_firm_prod) / count(*) as mean_n_remove_firm_prod,
+sum(n_all_prod_remove_firm) / count(*) as mean_n_all_prod_remove_firm,
+sum(end_ts) / count(*) as mean_end_ts
+from (
+select sample.id, idx_scenario,
+count_firm_prod, count_firm, count_prod,
+max_ts_firm_prod, max_ts_firm, max_ts_prod,
+n_remove_firm_prod, n_all_prod_remove_firm, end_ts
+from iiabmdb.with_exp_sample as sample
+# 1 2 3 + 9
+left join iiabmdb.with_exp_experiment as experiment
+on sample.e_id = experiment.id
+left join (select s_id,
+count(distinct id_firm, id_product) as count_firm_prod,
+count(distinct id_firm) as count_firm,
+count(distinct id_product) as count_prod,
+max(ts) as end_ts
+from iiabmdb.with_exp_result group by s_id) as s_count
+on sample.id = s_count.s_id
+# 4
+left join # firm prod
+(select s_id, max(ts) as max_ts_firm_prod from
+(select s_id, id_firm, id_product, min(ts) as ts
+from iiabmdb.with_exp_result
+where `status` = "D"
+group by s_id, id_firm, id_product) as ts
+group by s_id) as s_max_ts_firm_prod
+on sample.id = s_max_ts_firm_prod.s_id
+# 5
+left join # firm
+(select s_id, max(ts) as max_ts_firm from
+(select s_id, id_firm, min(ts) as ts
+from iiabmdb.with_exp_result
+where `status` = "D"
+group by s_id, id_firm) as ts
+group by s_id) as s_max_ts_firm
+on sample.id = s_max_ts_firm.s_id
+# 6
+left join # prod
+(select s_id, max(ts) as max_ts_prod from
+(select s_id, id_product, min(ts) as ts
+from iiabmdb.with_exp_result
+where `status` = "D"
+group by s_id, id_product) as ts
+group by s_id) as s_max_ts_prod
+on sample.id = s_max_ts_prod.s_id
+# 7
+left join
+(select s_id, count(distinct id_firm, id_product) as n_remove_firm_prod
+from iiabmdb.with_exp_result
+where `status` = "R"
+group by s_id) as s_n_remove_firm_prod
+on sample.id = s_n_remove_firm_prod.s_id
+# 8
+left join
+(select s_id, count(distinct id_firm) as n_all_prod_remove_firm from
+(select s_id, id_firm, count(distinct id_product) as n_remove_prod
+from iiabmdb.with_exp_result
+where `status` = "R"
+group by s_id, id_firm) as s_n_remove_prod
+left join iiabmdb_basic_info.firm_n_prod as firm_n_prod
+on s_n_remove_prod.id_firm = firm_n_prod.code
+where n_remove_prod = n_prod
+group by s_id) as s_n_all_prod_remove_firm
+on sample.id = s_n_all_prod_remove_firm.s_id
+) as secnario_count
+group by idx_scenario
+) as secnario_mean
+on experiment.idx_scenario = secnario_mean.idx_scenario;
\ No newline at end of file
diff --git a/SQL_analysis_risk.sql b/SQL_analysis_risk.sql
new file mode 100644
index 0000000..621186e
--- /dev/null
+++ b/SQL_analysis_risk.sql
@@ -0,0 +1,12 @@
+select * from
+(select s_id, id_firm, id_product, min(ts) as ts from iiabmdb.without_exp_result
+where `status` = 'D'
+group by s_id, id_firm, id_product) as s_disrupt
+where s_id in
+(select s_id from
+(select s_id, id_firm, id_product, min(ts) as ts from iiabmdb.without_exp_result
+where `status` = 'D'
+group by s_id, id_firm, id_product) as t
+group by s_id
+having count(*) > 1)
+order by s_id;
\ No newline at end of file
diff --git a/SQL_db_user_create.sql b/SQL_db_user_create.sql
new file mode 100644
index 0000000..eaef3bd
--- /dev/null
+++ b/SQL_db_user_create.sql
@@ -0,0 +1,6 @@
+CREATE USER 'iiabm_user'@'localhost' IDENTIFIED WITH authentication_plugin BY 'iiabm_pwd';
+
+-- CREATE USER 'iiabm_user'@'localhost' IDENTIFIED BY 'iiabm_pwd';
+
+GRANT ALL PRIVILEGES ON iiabmdb.* TO 'iiabm_user'@'localhost';
+FLUSH PRIVILEGES;
\ No newline at end of file
diff --git a/SQL_export_high_risk_setting.sql b/SQL_export_high_risk_setting.sql
new file mode 100644
index 0000000..cfa3a17
--- /dev/null
+++ b/SQL_export_high_risk_setting.sql
@@ -0,0 +1,15 @@
+select e_id, n_disrupt_sample, total_n_disrupt_firm_prod_experiment, dct_lst_init_disrupt_firm_prod from iiabmdb.without_exp_experiment as experiment
+inner join (
+select e_id, count(id) as n_disrupt_sample, sum(n_disrupt_firm_prod_sample) as total_n_disrupt_firm_prod_experiment from iiabmdb.without_exp_sample as sample
+inner join (
+select * from
+(select s_id, COUNT(DISTINCT id_firm, id_product) as n_disrupt_firm_prod_sample from iiabmdb.without_exp_result group by s_id
+) as count_disrupt_firm_prod_sample
+where n_disrupt_firm_prod_sample > 1
+) as disrupt_sample
+on sample.id = disrupt_sample.s_id
+group by e_id
+) as disrupt_experiment
+on experiment.id = disrupt_experiment.e_id
+order by n_disrupt_sample desc, total_n_disrupt_firm_prod_experiment desc
+limit 0, 95;
\ No newline at end of file
diff --git a/SQL_migrate_db.sql b/SQL_migrate_db.sql
new file mode 100644
index 0000000..e5c0555
--- /dev/null
+++ b/SQL_migrate_db.sql
@@ -0,0 +1,13 @@
+CREATE DATABASE iiabmdb20230829;
+RENAME TABLE iiabmdb.not_test_experiment TO iiabmdb20230829.not_test_experiment,
+iiabmdb.not_test_result TO iiabmdb20230829.not_test_result,
+iiabmdb.not_test_sample TO iiabmdb20230829.not_test_sample,
+iiabmdb.test_experiment TO iiabmdb20230829.test_experiment,
+iiabmdb.test_result TO iiabmdb20230829.test_result,
+iiabmdb.test_sample TO iiabmdb20230829.test_sample;
+RENAME TABLE iiabmdb.with_exp_experiment TO iiabmdb20230829.with_exp_experiment,
+iiabmdb.with_exp_result TO iiabmdb20230829.with_exp_result,
+iiabmdb.with_exp_sample TO iiabmdb20230829.with_exp_sample,
+iiabmdb.without_exp_experiment TO iiabmdb20230829.without_exp_experiment,
+iiabmdb.without_exp_result TO iiabmdb20230829.without_exp_result,
+iiabmdb.without_exp_sample TO iiabmdb20230829.without_exp_sample;
\ No newline at end of file
diff --git a/computation.py b/computation.py
new file mode 100644
index 0000000..8712c49
--- /dev/null
+++ b/computation.py
@@ -0,0 +1,40 @@
+import os
+import datetime
+
+from model import Model
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from controller_db import ControllerDB
+
+
+class Computation:
+
+ def __init__(self, c_db: 'ControllerDB'):
+ # 控制不同进程 计算不同的样本 但使用同一个 数据库 c_db
+ self.c_db = c_db
+ self.pid = os.getpid()
+
+ def run(self, str_code='0', s_id=None):
+ sample_random = self.c_db.fetch_a_sample(s_id)
+ if sample_random is None:
+ return True
+
+ # lock this row by update is_done_flag to 0 将运行后的样本设置为 flag 0
+ self.c_db.lock_the_sample(sample_random)
+ print(
+ f"Pid {self.pid} ({str_code}) is running "
+ f"sample {sample_random.id} at {datetime.datetime.now()}")
+ # 将sample 对应的 experiment 的一系列值 和 参数值 传入 模型 中 包括列名 和 值
+ dct_exp = {column: getattr(sample_random.experiment, column)
+ for column in sample_random.experiment.__table__.c.keys()}
+ # 删除不需要的 主键
+ del dct_exp['id']
+
+ dct_sample_para = {'sample': sample_random,
+ 'seed': sample_random.seed,
+ **dct_exp}
+ model = Model(dct_sample_para)
+
+ model.run(display=False)
+ return False
diff --git a/conf_db.yaml b/conf_db.yaml
new file mode 100644
index 0000000..74e37e2
--- /dev/null
+++ b/conf_db.yaml
@@ -0,0 +1,10 @@
+# read by orm
+is_local_db: True
+
+local:
+ user_name: iiabm_user
+ password: iiabm_pwd
+ db_name: iiabmdb
+ address: 'localhost'
+ port: 3306
+
diff --git a/conf_db_prefix.yaml b/conf_db_prefix.yaml
new file mode 100644
index 0000000..0e48fd8
--- /dev/null
+++ b/conf_db_prefix.yaml
@@ -0,0 +1 @@
+db_name_prefix: without_exp
diff --git a/conf_experiment.yaml b/conf_experiment.yaml
new file mode 100644
index 0000000..77309ea
--- /dev/null
+++ b/conf_experiment.yaml
@@ -0,0 +1,12 @@
+# read by ControllerDB
+
+# run settings
+meta_seed: 2
+
+test: # only for test scenarios
+ n_sample: 1
+ n_iter: 100
+
+not_test: # normal scenarios
+ n_sample: 50
+ n_iter: 100
diff --git a/controller_db.py b/controller_db.py
new file mode 100644
index 0000000..3314ab7
--- /dev/null
+++ b/controller_db.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+from orm import db_session, engine, Base, ins, connection
+from orm import Experiment, Sample, Result
+from sqlalchemy.exc import OperationalError
+from sqlalchemy import text
+import yaml
+import random
+import pandas as pd
+import platform
+import networkx as nx
+import json
+import pickle
+
+
+class ControllerDB:
+ is_with_exp: bool
+ dct_parameter = None
+ is_test: bool = None
+ db_name_prefix: str = None
+ reset_flag: int
+
+ lst_saved_s_id: list
+
+ def __init__(self, prefix, reset_flag=0):
+ with open('conf_experiment.yaml') as yaml_file:
+ dct_conf_experiment = yaml.full_load(yaml_file)
+ assert prefix in ['test', 'without_exp', 'with_exp'], "db name not in test, without_exp, with_exp"
+
+ self.is_test = prefix == 'test'
+ self.is_with_exp = False if prefix == 'test' or prefix == 'without_exp' else True
+ self.db_name_prefix = prefix
+ dct_para_in_test = dct_conf_experiment['test'] if self.is_test else dct_conf_experiment['not_test']
+ self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'] , **dct_para_in_test}
+
+ print(self.dct_parameter)
+ # 0, not reset; 1, reset self; 2, reset all
+ self.reset_flag = reset_flag
+ self.is_exist = False
+ self.lst_saved_s_id = []
+
+
+ def init_tables(self):
+ self.fill_experiment_table()
+ self.fill_sample_table()
+
+ def fill_experiment_table(self):
+ Firm = pd.read_csv("input_data/Firm_amended.csv")
+ Firm['Code'] = Firm['Code'].astype('string')
+ Firm.fillna(0, inplace=True)
+
+ # fill dct_lst_init_disrupt_firm_prod
+ # 存储 公司-在供应链结点的位置.. 0 :‘1.1’
+ list_dct = [] #存储 公司编码code 和对应的产业链 结点
+ if self.is_with_exp:
+ #对于方差分析时候使用
+ with open('SQL_export_high_risk_setting.sql', 'r') as f:
+ str_sql = text(f.read())
+ result = pd.read_sql(sql=str_sql, con=connection)
+ result['dct_lst_init_disrupt_firm_prod'] = \
+ result['dct_lst_init_disrupt_firm_prod'].apply(
+ lambda x: pickle.loads(x))
+ list_dct = result['dct_lst_init_disrupt_firm_prod'].to_list()
+ else:
+ # 行索引 (index):这一行在数据帧中的索引值。
+ # 行数据 (row):这一行的数据,是一个 pandas.Series 对象,包含该行的所有列和值。
+ for _, row in Firm.iterrows():
+ code = row['Code']
+ row = row['1':]
+ for product_code in row.index[row == 1].to_list():
+ dct = {code: [product_code]}
+ list_dct.append(dct)
+
+ # fill g_bom
+ # 结点属性值 相当于 图上点的 原始 产品名称
+ BomNodes = pd.read_csv('input_data/BomNodes.csv', index_col=0)
+ BomNodes.set_index('Code', inplace=True)
+
+ BomCateNet = pd.read_csv('input_data/BomCateNet.csv', index_col=0)
+ BomCateNet.fillna(0, inplace=True)
+ # 创建 可以多边的有向图 同时 转置操作 使得 上游指向下游结点 也就是 1.1.1 - 1.1 类似这种
+ g_bom = nx.from_pandas_adjacency(BomCateNet.T,
+ create_using=nx.MultiDiGraph())
+ # 填充每一个结点 的具体内容 通过 相同的 code 并且通过BomNodes.loc[code].to_dict()字典化 格式类似 格式 { code(0) : {level: 0 ,name: 工业互联网 }}
+ bom_labels_dict = {}
+ for code in g_bom.nodes:
+ bom_labels_dict[code] = BomNodes.loc[code].to_dict()
+ # 分配属性 给每一个结点 获得类似 格式:{1: {'label': 'A', 'value': 10},
+ nx.set_node_attributes(g_bom, bom_labels_dict)
+ # 改为json 格式
+ g_product_js = json.dumps(nx.adjacency_data(g_bom))
+
+ # insert exp
+ df_xv = pd.read_csv(
+ "input_data/"
+ f"xv_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
+ index_col=None)
+ # read the OA table
+ df_oa = pd.read_csv(
+ "input_data/"
+ f"oa_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
+ index_col=None)
+ # .shape[1] 列数 .iloc 访问特定的值 而不是标签
+ df_oa = df_oa.iloc[:, 0:df_xv.shape[1]]
+ # idx_scenario 是 0 指行 idx_init_removal 指 索引 0.. dct_init_removal 键 code 公司 g_product_js 图的json数据 dct_exp_para 解码 全局参数xv-
+ for idx_scenario, row in df_oa.iterrows():
+ dct_exp_para = {}
+ for idx_col, para_level in enumerate(row):
+ dct_exp_para[df_xv.columns[idx_col]] = \
+ df_xv.iloc[para_level, idx_col]
+ # different initial removal 只会得到 键 和 值
+ for idx_init_removal, dct_init_removal in enumerate(list_dct):
+ self.add_experiment_1(idx_scenario,
+ idx_init_removal,
+ dct_init_removal,
+ g_product_js,
+ **dct_exp_para)
+ print(f"Inserted experiment for scenario {idx_scenario}, "
+ f"init_removal {idx_init_removal}!")
+
+ def add_experiment_1(self, idx_scenario, idx_init_removal,
+ dct_lst_init_disrupt_firm_prod, g_bom,
+ n_max_trial, prf_size, prf_conn,
+ cap_limit_prob_type, cap_limit_level,
+ diff_new_conn, remove_t, netw_prf_n):
+ e = Experiment(
+ idx_scenario=idx_scenario,
+ idx_init_removal=idx_init_removal,
+ n_sample=int(self.dct_parameter['n_sample']),
+ n_iter=int(self.dct_parameter['n_iter']),
+ dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod,
+ g_bom=g_bom,
+ n_max_trial=n_max_trial,
+ prf_size=prf_size,
+ prf_conn=prf_conn,
+ cap_limit_prob_type=cap_limit_prob_type,
+ cap_limit_level=cap_limit_level,
+ diff_new_conn=diff_new_conn,
+ remove_t=remove_t,
+ netw_prf_n=netw_prf_n
+ )
+ db_session.add(e)
+ db_session.commit()
+
+ def fill_sample_table(self):
+ rng = random.Random(self.dct_parameter['meta_seed'])
+ # 根据样本数目 设置 32 位随机整数
+ lst_seed = [
+ rng.getrandbits(32)
+ for _ in range(int(self.dct_parameter['n_sample']))
+ ]
+ lst_exp = db_session.query(Experiment).all()
+
+ lst_sample = []
+ for experiment in lst_exp:
+ # idx_sample: 1-50
+ for idx_sample in range(int(experiment.n_sample)):
+ s = Sample(e_id=experiment.id,
+ idx_sample=idx_sample + 1,
+ seed=lst_seed[idx_sample],
+ is_done_flag=-1)
+ lst_sample.append(s)
+ db_session.bulk_save_objects(lst_sample)
+ db_session.commit()
+ print(f'Inserted {len(lst_sample)} samples!')
+
+ def reset_db(self, force_drop=False):
+ # first, check if tables exist
+ lst_table_obj = [
+ Base.metadata.tables[str_table]
+ for str_table in ins.get_table_names()
+ if str_table.startswith(self.db_name_prefix)
+ ]
+ self.is_exist = len(lst_table_obj) > 0
+ if force_drop:
+ self.force_drop_db(lst_table_obj)
+ # while is_exist:
+ # a_table = random.choice(lst_table_obj)
+ # try:
+ # Base.metadata.drop_all(bind=engine, tables=[a_table])
+ # except KeyError:
+ # pass
+ # except OperationalError:
+ # pass
+ # else:
+ # lst_table_obj.remove(a_table)
+ # print(
+ # f"Table {a_table.name} is dropped "
+ # f"for exp: {self.db_name_prefix}!!!"
+ # )
+ # finally:
+ # is_exist = len(lst_table_obj) > 0
+
+ if self.is_exist:
+ print(
+ f"All tables exist. No need to reset "
+ f"for exp: {self.db_name_prefix}."
+ )
+ # change the is_done_flag from 0 to -1
+ # rerun the in-finished tasks
+ self.is_exist_reset_flag_resset_db()
+ # if self.reset_flag > 0:
+ # if self.reset_flag == 2:
+ # sample = db_session.query(Sample).filter(
+ # Sample.is_done_flag == 0)
+ # elif self.reset_flag == 1:
+ # sample = db_session.query(Sample).filter(
+ # Sample.is_done_flag == 0,
+ # Sample.computer_name == platform.node())
+ # else:
+ # raise ValueError('Wrong reset flag')
+ # if sample.count() > 0:
+ # for s in sample:
+ # qry_result = db_session.query(Result).filter_by(
+ # s_id=s.id)
+ # if qry_result.count() > 0:
+ # db_session.query(Result).filter(s_id=s.id).delete()
+ # db_session.commit()
+ # s.is_done_flag = -1
+ # db_session.commit()
+ # print(f"Reset the sample id {s.id} flag from 0 to -1")
+
+ else:
+ # 不存在则重新生成所有的表结构
+ Base.metadata.create_all(bind=engine)
+ self.init_tables()
+ print(
+ f"All tables are just created and initialized "
+ f"for exp: {self.db_name_prefix}."
+ )
+
+ def force_drop_db(self, lst_table_obj):
+ self.is_exist = len(lst_table_obj) > 0
+ while self.is_exist:
+ a_table = random.choice(lst_table_obj)
+ try:
+ Base.metadata.drop_all(bind=engine, tables=[a_table])
+ except KeyError:
+ pass
+ except OperationalError:
+ pass
+ else:
+ lst_table_obj.remove(a_table)
+ print(
+ f"Table {a_table.name} is dropped "
+ f"for exp: {self.db_name_prefix}!!!"
+ )
+ finally:
+ self.is_exist = len(lst_table_obj) > 0
+
+ def is_exist_reset_flag_resset_db(self):
+ if self.reset_flag > 0:
+ if self.reset_flag == 2:
+ sample = db_session.query(Sample).filter(
+ Sample.is_done_flag == 0)
+ elif self.reset_flag == 1:
+ sample = db_session.query(Sample).filter(
+ Sample.is_done_flag == 0,
+ Sample.computer_name == platform.node())
+ else:
+ raise ValueError('Wrong reset flag')
+ if sample.count() > 0:
+ for s in sample:
+ qry_result = db_session.query(Result).filter_by(
+ s_id=s.id)
+ if qry_result.count() > 0:
+ db_session.query(Result).filter(s_id=s.id).delete()
+ db_session.commit()
+ s.is_done_flag = -1
+ db_session.commit()
+ print(f"Reset the sample id {s.id} flag from 0 to -1")
+
+ def prepare_list_sample(self):
+ # 为了符合前面 重置表里面存在 重置本机 或者重置全部 或者不重置的部分 这个部分的 关于样本运行也得重新拿出来
+ # 查找一个风险事件中 50 个样本
+ res = db_session.execute(
+ text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
+ f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
+ )).scalar()
+ # 控制 n_sample数量 作为后面的参数
+ n_sample = 0 if res is None else res
+ print(f'There are a total of {n_sample} samples.')
+ # 查找 is_done_flag = -1 也就是没有运行的 样本 运行后会改为0
+ res = db_session.execute(
+ text(f"SELECT id FROM {self.db_name_prefix}_sample "
+ f"WHERE is_done_flag = -1"
+ ))
+ for row in res:
+ s_id = row[0]
+ self.lst_saved_s_id.append(s_id)
+
+ @staticmethod
+ def select_random_sample(lst_s_id):
+ while 1:
+ if len(lst_s_id) == 0:
+ return None
+ s_id = random.choice(lst_s_id)
+ lst_s_id.remove(s_id)
+ res = db_session.query(Sample).filter(Sample.id == int(s_id),
+ Sample.is_done_flag == -1)
+ if res.count() == 1:
+ return res[0]
+
+ def fetch_a_sample(self, s_id=None):
+ # 由Computation 调用 返回 sample对象 同时给出 2中 指定访问模式 抓取特定的 样本 通过s_id
+ # 默认访问 flag为-1的 lst_saved_s_id
+ if s_id is not None:
+ res = db_session.query(Sample).filter(Sample.id == int(s_id))
+ if res.count() == 0:
+ return None
+ else:
+ return res[0]
+
+ sample = self.select_random_sample(self.lst_saved_s_id)
+ if sample is not None:
+ return sample
+
+ return None
+
+ @staticmethod
+ def lock_the_sample(sample: Sample):
+ sample.is_done_flag, sample.computer_name = 0, platform.node()
+ db_session.commit()
+
+
+if __name__ == '__main__':
+ print("Testing the database connection...")
+ try:
+ controller_db = ControllerDB('test')
+ Base.metadata.create_all(bind=engine)
+ except Exception as e:
+ print("Failed to connect to the database!")
+ print(e)
+ exit(1)
diff --git a/firm.py b/firm.py
new file mode 100644
index 0000000..241aa3c
--- /dev/null
+++ b/firm.py
@@ -0,0 +1,199 @@
+from mesa import Agent
+
+
+class FirmAgent(Agent):
+ def __init__(self, unique_id, model, code, type_region, revenue_log, a_lst_product):
+ # 调用超类的 __init__ 方法
+ super().__init__(unique_id, model)
+
+ # 初始化模型中的网络引用
+ self.firm_network = self.model.firm_network
+ self.product_network = self.model.product_network
+
+ # 初始化代理自身的属性
+ self.code = code
+ self.type_region = type_region
+
+ self.size_stat = []
+ self.dct_prod_up_prod_stat = {}
+ self.dct_prod_capacity = {}
+
+ # 试验中的参数
+ self.dct_n_trial_up_prod_disrupted = {}
+ self.dct_cand_alt_supp_up_prod_disrupted = {}
+ self.dct_request_prod_from_firm = {}
+
+ # 外部变量
+ self.is_prf_size = self.model.is_prf_size
+ self.is_prf_conn = bool(self.model.prf_conn)
+ self.str_cap_limit_prob_type = str(self.model.cap_limit_prob_type)
+ self.flt_cap_limit_level = float(self.model.cap_limit_level)
+ self.flt_diff_new_conn = float(self.model.diff_new_conn)
+
+ # 初始化 size_stat
+ self.size_stat.append((revenue_log, 0))
+
+ # 初始化 dct_prod_up_prod_stat
+ for prod in a_lst_product:
+ self.dct_prod_up_prod_stat[prod] = {
+ 'p_stat': [('N', 0)],
+ 's_stat': {up_prod: {'stat': True, 'set_disrupt_firm': set()}
+ for up_prod in prod.a_predecessors()}
+ }
+
+ # 初始化额外容量 (dct_prod_capacity)
+ for product in a_lst_product:
+ assert self.str_cap_limit_prob_type in ['uniform', 'normal'], \
+ "cap_limit_prob_type must be either 'uniform' or 'normal'"
+ extra_cap_mean = self.size_stat[0][0] / self.flt_cap_limit_level
+ if self.str_cap_limit_prob_type == 'uniform':
+ extra_cap = self.model.random.randint(extra_cap_mean - 2, extra_cap_mean + 2)
+ elif self.str_cap_limit_prob_type == 'normal':
+ extra_cap = self.model.random.normalvariate(extra_cap_mean, 1)
+ extra_cap = max(0, round(extra_cap))
+ self.dct_prod_capacity[product] = extra_cap
+
+ def remove_edge_to_cus(self, disrupted_prod):
+ lst_out_edge = list(
+ self.firm_network.get_neighbors(self.unique_id))
+ for n2 in lst_out_edge:
+ edge_data = self.firm_network.G.edges[self.unique_id, n2]
+ if edge_data.get('Product') == disrupted_prod.code:
+ customer = self.model.schedule.get_agent(n2)
+ for prod in customer.dct_prod_up_prod_stat.keys():
+ if disrupted_prod in customer.dct_prod_up_prod_stat[prod]['s_stat'].keys():
+ customer.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_prod][
+ 'set_disrupt_firm'].add(self)
+ self.firm_network.remove_edge(self.unique_id, n2)
+
+ def disrupt_cus_prod(self, prod, disrupted_up_prod):
+ num_lost = len(self.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_up_prod]['set_disrupt_firm'])
+ num_remain = len([
+ u for u in self.firm_network.get_neighbors(self.unique_id)
+ if self.firm_network.G.edges[u, self.unique_id].get('Product') == disrupted_up_prod.code])
+ lost_percent = num_lost / (num_lost + num_remain)
+ lst_size = [firm.size_stat[-1][0] for firm in self.model.schedule.agents]
+ std_size = (self.size_stat[-1][0] - min(lst_size) + 1) / (max(lst_size) - min(lst_size) + 1)
+
+ prob_disrupt = 1 - std_size * (1 - lost_percent)
+ if self.random.choice([True, False], p=[prob_disrupt, 1 - prob_disrupt]):
+ self.dct_n_trial_up_prod_disrupted[disrupted_up_prod] = 0
+ self.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_up_prod]['stat'] = False
+ status, _ = self.dct_prod_up_prod_stat[prod]['p_stat'][-1]
+ if status != 'D':
+ self.dct_prod_up_prod_stat[prod]['p_stat'].append(('D', self.model.schedule.time))
+
+ def seek_alt_supply(self, product):
+ if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial:
+ if self.dct_n_trial_up_prod_disrupted[product] == 0:
+ self.dct_cand_alt_supp_up_prod_disrupted[product] = [
+ firm for firm in self.model.schedule.agents
+ if firm.is_prod_in_current_normal(product)]
+ if self.dct_cand_alt_supp_up_prod_disrupted[product]:
+ lst_firm_connect = []
+ if self.is_prf_conn:
+ for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]:
+ if self.firm_network.G.has_edge(self.unique_id, firm.unique_id) or \
+ self.firm_network.G.has_edge(firm.unique_id, self.unique_id):
+ lst_firm_connect.append(firm)
+ if len(lst_firm_connect) == 0:
+ if self.is_prf_size:
+ lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
+ lst_prob = [size / sum(lst_size) for size in lst_size]
+ select_alt_supply = \
+ self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
+ else:
+ select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
+ elif len(lst_firm_connect) > 0:
+ if self.is_prf_size:
+ lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
+ lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
+ select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
+ else:
+ select_alt_supply = self.random.choice(lst_firm_connect)
+
+ assert select_alt_supply.is_prod_in_current_normal(product)
+
+ if product in select_alt_supply.dct_request_prod_from_firm:
+ select_alt_supply.dct_request_prod_from_firm[product].append(self)
+ else:
+ select_alt_supply.dct_request_prod_from_firm[product] = [self]
+
+ self.dct_n_trial_up_prod_disrupted[product] += 1
+
+ def handle_request(self):
+ for product, lst_firm in self.dct_request_prod_from_firm.items():
+ if self.dct_prod_capacity[product] > 0:
+ if len(lst_firm) == 0:
+ continue
+ elif len(lst_firm) == 1:
+ self.accept_request(lst_firm[0], product)
+ elif len(lst_firm) > 1:
+ lst_firm_connect = []
+ if self.is_prf_conn:
+ for firm in lst_firm:
+ if self.firm_network.G.has_edge(self.unique_id, firm.unique_id) or \
+ self.firm_network.G.has_edge(firm.unique_id, self.unique_id):
+ lst_firm_connect.append(firm)
+ if len(lst_firm_connect) == 0:
+ if self.is_prf_size:
+ lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm]
+ lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
+ select_customer = self.random.choices(lst_firm, weights=lst_prob)[0]
+ else:
+ select_customer = self.random.choice(lst_firm)
+ self.accept_request(select_customer, product)
+ elif len(lst_firm_connect) > 0:
+ if self.is_prf_size:
+ lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
+ lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
+ select_customer = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
+ else:
+ select_customer = self.random.choice(lst_firm_connect)
+ self.accept_request(select_customer, product)
+ else:
+ for down_firm in lst_firm:
+ down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
+
+ def accept_request(self, down_firm, product):
+ if self.firm_network.G.has_edge(self.unique_id, down_firm.unique_id) or \
+ self.firm_network.G.has_edge(down_firm.unique_id, self.unique_id):
+ prod_accept = 1.0
+ else:
+ prod_accept = self.flt_diff_new_conn
+ if self.random.choice([True, False], p=[prod_accept, 1 - prod_accept]):
+ self.firm_network.G.add_edge(self.unique_id, down_firm.unique_id, Product=product.code)
+ self.dct_prod_capacity[product] -= 1
+ self.dct_request_prod_from_firm[product].remove(down_firm)
+
+ for prod in down_firm.dct_prod_up_prod_stat.keys():
+ if product in down_firm.dct_prod_up_prod_stat[prod]['s_stat']:
+ down_firm.dct_prod_up_prod_stat[prod]['s_stat'][product]['stat'] = True
+ down_firm.dct_prod_up_prod_stat[prod]['p_stat'].append(
+ ('N', self.model.schedule.time))
+ del down_firm.dct_n_trial_up_prod_disrupted[product]
+ del down_firm.dct_cand_alt_supp_up_prod_disrupted[product]
+ else:
+ down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
+
+ def clean_before_trial(self):
+ self.dct_request_prod_from_firm = {}
+
+ def clean_before_time_step(self):
+ # Reset the number of trials and candidate suppliers for disrupted products
+ self.dct_n_trial_up_prod_disrupted = dict.fromkeys(self.dct_n_trial_up_prod_disrupted.keys(), 0)
+ self.dct_cand_alt_supp_up_prod_disrupted = {}
+
+ # Update the status of products and refresh disruption sets
+ for prod in self.dct_prod_up_prod_stat.keys():
+ status, ts = self.dct_prod_up_prod_stat[prod]['p_stat'][-1]
+ if ts != self.model.schedule.time:
+ self.dct_prod_up_prod_stat[prod]['p_stat'].append((status, self.model.schedule.time))
+
+ # Refresh the set of disrupted firms
+ for up_prod in self.dct_prod_up_prod_stat[prod]['s_stat'].keys():
+ self.dct_prod_up_prod_stat[prod]['s_stat'][up_prod]['set_disrupt_firm'] = set()
+
+ def step(self):
+ # 在每个时间步进行的操作
+ pass
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..70c84d1
--- /dev/null
+++ b/main.py
@@ -0,0 +1,67 @@
+import os
+import random
+import time
+from multiprocessing import Process
+import argparse
+from computation import Computation
+from sqlalchemy.orm import close_all_sessions
+
+import yaml
+
+from controller_db import ControllerDB
+
+
+def controll_db_and_process(exp_argument, reset_sample_argument, reset_db_argument):
+ from controller_db import ControllerDB
+ controller_db = ControllerDB(exp_argument, reset_flag=reset_sample_argument)
+ # controller_db.reset_db()
+ # force drop
+ controller_db.reset_db(force_drop=reset_db_argument)
+ # 准备样本表
+ controller_db.prepare_list_sample()
+
+ close_all_sessions()
+ # 调用 do_process 利用计算机进行多核处理 仿真 将数据库中
+ do_process(do_computation, controller_db)
+
+
+def do_process(target: object, controller_db: ControllerDB, ):
+ process_list = []
+ for i in range(int(args.job)):
+ p = Process(target=do_computation, args=(controller_db,))
+ p.start()
+ process_list.append(p)
+
+ for i in process_list:
+ i.join()
+
+
+def do_computation(c_db):
+ exp = Computation(c_db)
+
+ while 1:
+ time.sleep(random.uniform(0, 5))
+ is_all_done = exp.run()
+ if is_all_done:
+ break
+
+
+if __name__ == '__main__':
+ # 输入参数
+ parser = argparse.ArgumentParser(description='setting')
+ parser.add_argument('--exp', type=str, default='test')
+ parser.add_argument('--job', type=int, default='3')
+ parser.add_argument('--reset_sample', type=int, default='0')
+ parser.add_argument('--reset_db', type=bool, default=False)
+
+ args = parser.parse_args()
+ # 几核参与进程
+ assert args.job >= 1, 'Number of jobs should >= 1'
+ # 控制参数 利用 prefix_file_name 前缀名字 控制 2项不同的实验
+ prefix_file_name = 'conf_db_prefix.yaml'
+ if os.path.exists(prefix_file_name):
+ os.remove(prefix_file_name)
+ with open(prefix_file_name, 'w', encoding='utf-8') as file:
+ yaml.dump({'db_name_prefix': args.exp}, file)
+ # 数据库连接控制 和 进行模型运行
+ controll_db_and_process(args.exp, args.reset_sample, args.reset_db)
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..9756931
--- /dev/null
+++ b/model.py
@@ -0,0 +1,133 @@
+import json
+
+import networkx as nx
+import pandas as pd
+from mesa import Model
+from mesa.time import RandomActivation
+from mesa.space import MultiGrid
+from mesa.datacollection import DataCollector
+
+from firm import FirmAgent
+from product import ProductAgent
+
+
+class MyModel(Model):
+ def __init__(self, params):
+ # self.num_agents = params['N']
+ # self.grid = MultiGrid(params['width'], params['height'], True)
+ # self.schedule = RandomActivation(self)
+
+ # Initialize parameters from `params`
+ self.sample = params['sample']
+ self.int_stop_ts = 0
+ self.int_n_iter = int(params['n_iter'])
+ self.dct_lst_init_disrupt_firm_prod = params['dct_lst_init_disrupt_firm_prod']
+ # external variable
+ self.int_n_max_trial = int(params['n_max_trial'])
+ self.is_prf_size = bool(params['prf_size'])
+
+ self.remove_t = int(params['remove_t'])
+ self.int_netw_prf_n = int(params['netw_prf_n'])
+
+ self.product_network = None
+ self.firm_network = None
+ self.firm_prod_network = None
+
+ # Initialize product network
+ G_bom = nx.adjacency_graph(json.loads(params['g_bom']))
+ self.product_network = G_bom
+
+ # Initialize firm network
+ self.initialize_firm_network()
+
+ # Initialize firm product network
+ self.initialize_firm_prod_network()
+
+ # Initialize agents
+ self.initialize_agents()
+
+ # Data collector (if needed)
+ self.datacollector = DataCollector(
+ agent_reporters={"Product Code": "code"}
+ )
+
+ def initialize_firm_network(self):
+ # Read firm data and initialize firm network
+ firm = pd.read_csv("input_data/Firm_amended.csv")
+ firm['Code'] = firm['Code'].astype('string')
+ firm.fillna(0, inplace=True)
+ Firm_attr = firm[["Code", "Type_Region", "Revenue_Log"]]
+ firm_product = []
+ for _, row in firm.loc[:, '1':].iterrows():
+ firm_product.append(row[row == 1].index.to_list())
+ Firm_attr['Product_Code'] = firm_product
+ Firm_attr.set_index('Code', inplace=True)
+ G_Firm = nx.MultiDiGraph()
+ G_Firm.add_nodes_from(firm["Code"])
+
+ # Add node attributes
+ firm_labels_dict = {}
+ for code in G_Firm.nodes:
+ firm_labels_dict[code] = Firm_attr.loc[code].to_dict()
+ nx.set_node_attributes(G_Firm, firm_labels_dict)
+
+ # Add edges based on BOM graph
+ self.add_edges_based_on_bom(G_Firm)
+
+ self.firm_network = G_Firm
+
+ def initialize_firm_prod_network(self):
+ # Read firm product data and initialize firm product network
+ firm_prod = pd.read_csv("input_data/Firm_amended.csv")
+ firm_prod.fillna(0, inplace=True)
+ firm_prod = pd.DataFrame({'bool': firm_prod.loc[:, '1':].stack()})
+ firm_prod = firm_prod[firm_prod['bool'] == 1].reset_index()
+ firm_prod.drop('bool', axis=1, inplace=True)
+ firm_prod.rename({'level_0': 'Firm_Code', 'level_1': 'Product_Code'}, axis=1, inplace=True)
+ firm_prod['Firm_Code'] = firm_prod['Firm_Code'].astype('string')
+
+ G_FirmProd = nx.MultiDiGraph()
+ G_FirmProd.add_nodes_from(firm_prod.index)
+
+ # Add node attributes
+ firm_prod_labels_dict = {}
+ for code in firm_prod.index:
+ firm_prod_labels_dict[code] = firm_prod.loc[code].to_dict()
+ nx.set_node_attributes(G_FirmProd, firm_prod_labels_dict)
+
+ self.firm_prod_network = G_FirmProd
+
+ def add_edges_based_on_bom(self, G_Firm):
+ # Logic to add edges to the G_Firm graph based on BOM
+ pass
+
+ def initialize_agents(self):
+ # Initialize product and firm agents
+ for node, attr in self.product_network.nodes(data=True):
+ product = ProductAgent(node, self, code=node, name=attr['Name'])
+ self.schedule.add(product)
+
+ for node, attr in self.firm_network.nodes(data=True):
+ firm_agent = FirmAgent(
+ node,
+ self,
+ code=node,
+ type_region=attr['Type_Region'],
+ revenue_log=attr['Revenue_Log'],
+ a_lst_product=[] # Populate based on products
+ )
+ self.schedule.add(firm_agent)
+
+ # Initialize disruptions
+ self.initialize_disruptions()
+
+ def initialize_disruptions(self):
+ # Set the initial firm product disruptions
+ for firm, products in self.dct_lst_init_disrupt_firm_prod.items():
+ for product in products:
+ if isinstance(firm, FirmAgent):
+ firm.dct_prod_up_prod_stat[product]['p_stat'].append(('D', self.schedule.steps))
+
+ def step(self):
+ self.schedule.step()
+ self.datacollector.collect(self)
diff --git a/orm.py b/orm.py
new file mode 100644
index 0000000..b96d88f
--- /dev/null
+++ b/orm.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+from sqlalchemy import create_engine, inspect, Inspector
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import (Column, Integer, DECIMAL, String, ForeignKey,
+ BigInteger, DateTime, PickleType, Boolean, Text)
+from sqlalchemy.sql import func
+from sqlalchemy.orm import relationship, Session
+from sqlalchemy.pool import NullPool
+import yaml
+
+with open('conf_db.yaml') as file:
+ dct_conf_db_all = yaml.full_load(file)
+ is_local_db = dct_conf_db_all['is_local_db']
+ if is_local_db:
+ dct_conf_db = dct_conf_db_all['local']
+ else:
+ dct_conf_db = dct_conf_db_all['remote']
+
+with open('conf_db_prefix.yaml') as file:
+ dct_conf_db_prefix = yaml.full_load(file)
+ db_name_prefix = dct_conf_db_prefix['db_name_prefix']
+
+str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'],
+ dct_conf_db['password'],
+ dct_conf_db['address'],
+ dct_conf_db['port'],
+ dct_conf_db['db_name'])
+# print('DB is {}:{}/{}'.format(dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name']))
+
+# must be null pool to avoid connection lost error
+engine = create_engine(str_login, poolclass=NullPool)
+connection = engine.connect()
+ins: Inspector = inspect(engine)
+
+Base = declarative_base()
+
+db_session = Session(bind=engine)
+
+
+class Experiment(Base):
+ __tablename__ = f"{db_name_prefix}_experiment"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+
+ idx_scenario = Column(Integer, nullable=False)
+ idx_init_removal = Column(Integer, nullable=False)
+
+ # fixed parameters
+ n_sample = Column(Integer, nullable=False)
+ n_iter = Column(Integer, nullable=False)
+
+ # variables
+ dct_lst_init_disrupt_firm_prod = Column(PickleType, nullable=False)
+ g_bom = Column(Text(4294000000), nullable=False)
+
+ n_max_trial = Column(Integer, nullable=False)
+ prf_size = Column(Boolean, nullable=False)
+ prf_conn = Column(Boolean, nullable=False)
+ cap_limit_prob_type = Column(String(16), nullable=False)
+ cap_limit_level = Column(DECIMAL(8, 4), nullable=False)
+ diff_new_conn = Column(DECIMAL(8, 4), nullable=False)
+ remove_t = Column(Integer, nullable=False)
+ netw_prf_n = Column(Integer, nullable=False)
+
+ sample = relationship(
+ 'Sample', back_populates='experiment', lazy='dynamic')
+
+ def __repr__(self):
+ return f''
+
+
+class Sample(Base):
+ __tablename__ = f"{db_name_prefix}_sample"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ e_id = Column(Integer, ForeignKey('{}.id'.format(
+ f"{db_name_prefix}_experiment")), nullable=False)
+
+ idx_sample = Column(Integer, nullable=False)
+ seed = Column(BigInteger, nullable=False)
+ # -1, waiting; 0, running; 1, done
+ is_done_flag = Column(Integer, nullable=False)
+ computer_name = Column(String(64), nullable=True)
+ ts_done = Column(DateTime(timezone=True), onupdate=func.now())
+ stop_t = Column(Integer, nullable=True)
+
+ g_firm = Column(Text(4294000000), nullable=True)
+
+ experiment = relationship(
+ 'Experiment', back_populates='sample', uselist=False)
+ result = relationship('Result', back_populates='sample', lazy='dynamic')
+
+ def __repr__(self):
+ return f''
+
+
+class Result(Base):
+ __tablename__ = f"{db_name_prefix}_result"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ s_id = Column(Integer, ForeignKey('{}.id'.format(
+ f"{db_name_prefix}_sample")), nullable=False)
+
+ id_firm = Column(String(10), nullable=False)
+ id_product = Column(String(10), nullable=False)
+ ts = Column(Integer, nullable=False)
+ status = Column(String(5), nullable=False)
+
+ sample = relationship('Sample', back_populates='result', uselist=False)
+
+ def __repr__(self):
+ return f''
+
+
+if __name__ == '__main__':
+ Base.metadata.drop_all()
+ Base.metadata.create_all()
diff --git a/product.py b/product.py
new file mode 100644
index 0000000..950137d
--- /dev/null
+++ b/product.py
@@ -0,0 +1,24 @@
+from mesa import Agent
+
+class ProductAgent(Agent):
+ def __init__(self, unique_id, model, code, name):
+ # 调用超类的 __init__ 方法
+ super().__init__(unique_id, model)
+
+ # 初始化代理属性
+ self.code = code
+ self.name = name
+ self.product_network = self.model.product_network
+
+ def a_successors(self):
+ # Find successors of the current product and return them as a list of ProductAgent
+ successors = list(self.product_network.successors(self))
+ return [self.model.schedule.agents[successor] for successor in successors]
+
+ def a_predecessors(self):
+ # Find predecessors of the current product and return them as a list of ProductAgent
+ predecessors = list(self.product_network.predecessors(self))
+ return [self.model.schedule.agents[predecessor] for predecessor in predecessors]
+ def step(self):
+ # 在每个时间步进行的操作
+ pass
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..69c77f3
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,55 @@
+agentpy==0.1.5
+alabaster==0.7.13
+Babel==2.12.1
+certifi @ file:///C:/b/abs_85o_6fm0se/croot/certifi_1671487778835/work/certifi
+charset-normalizer==3.0.1
+colorama==0.4.6
+cycler==0.11.0
+decorator==5.1.1
+dill==0.3.6
+docutils==0.19
+greenlet==2.0.2
+idna==3.4
+imagesize==1.4.1
+importlib-metadata==6.0.0
+Jinja2==3.1.2
+joblib==1.2.0
+kiwisolver==1.4.4
+MarkupSafe==2.1.2
+matplotlib==3.3.4
+matplotlib-inline==0.1.6
+multiprocess==0.70.14
+mysqlclient==2.1.1
+networkx==2.5
+numpy==1.20.3
+numpydoc==1.1.0
+packaging==23.0
+pandas==1.4.1
+pandas-stubs==1.2.0.39
+Pillow==9.4.0
+Pygments==2.14.0
+pygraphviz @ file:///C:/Users/ASUS/Downloads/pygraphviz-1.9-cp38-cp38-win_amd64.whl
+pyparsing==3.0.9
+python-dateutil==2.8.2
+pytz==2022.7.1
+PyYAML==6.0
+requests==2.28.2
+SALib==1.4.7
+scipy==1.10.1
+six==1.16.0
+snowballstemmer==2.2.0
+Sphinx==6.1.3
+sphinxcontrib-applehelp==1.0.4
+sphinxcontrib-devhelp==1.0.2
+sphinxcontrib-htmlhelp==2.0.1
+sphinxcontrib-jsmath==1.0.1
+sphinxcontrib-qthelp==1.0.3
+sphinxcontrib-serializinghtml==1.1.5
+SQLAlchemy==2.0.5.post1
+traitlets==5.9.0
+typing_extensions==4.5.0
+urllib3==1.26.14
+wincertstore==0.2
+yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work
+zipp==3.15.0
+mesa==2.1.5
diff --git a/requirements_manual_selected_20230304.txt b/requirements_manual_selected_20230304.txt
new file mode 100644
index 0000000..91f2ba8
--- /dev/null
+++ b/requirements_manual_selected_20230304.txt
@@ -0,0 +1,9 @@
+agentpy==0.1.5
+matplotlib==3.3.4
+matplotlib-inline==0.1.6
+networkx==2.5
+numpy==1.20.3
+numpydoc==1.1.0
+pandas==1.4.1
+pandas-stubs==1.2.0.39
+pygraphviz==1.9