commit 1f643c64e4
Cricial 2024-08-24 11:20:13 +08:00
23 changed files with 1168 additions and 0 deletions

3
.idea/.gitignore vendored Normal file

@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml

6
.idea/inspectionProfiles/profiles_settings.xml Normal file

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

10
.idea/mesa.iml Normal file

@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" />
</content>
<orderEntry type="jdk" jdkName="Python 3.8" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

7
.idea/misc.xml Normal file

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.8" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/mesa.iml" filepath="$PROJECT_DIR$/.idea/mesa.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml Normal file

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>


@ -0,0 +1,85 @@
select distinct experiment.idx_scenario,
n_max_trial, prf_size, prf_conn, cap_limit_prob_type, cap_limit_level, diff_new_conn, remove_t, netw_prf_n,
mean_count_firm_prod, mean_count_firm, mean_count_prod,
mean_max_ts_firm_prod, mean_max_ts_firm, mean_max_ts_prod,
mean_n_remove_firm_prod, mean_n_all_prod_remove_firm, mean_end_ts
from iiabmdb.with_exp_experiment as experiment
left join
(
select
idx_scenario,
sum(count_firm_prod) / count(*) as mean_count_firm_prod, # note: count(*) keeps NULL rows in the denominator
sum(count_firm) / count(*) as mean_count_firm,
sum(count_prod) / count(*) as mean_count_prod,
sum(max_ts_firm_prod) / count(*) as mean_max_ts_firm_prod,
sum(max_ts_firm) / count(*) as mean_max_ts_firm,
sum(max_ts_prod) / count(*) as mean_max_ts_prod,
sum(n_remove_firm_prod) / count(*) as mean_n_remove_firm_prod,
sum(n_all_prod_remove_firm) / count(*) as mean_n_all_prod_remove_firm,
sum(end_ts) / count(*) as mean_end_ts
from (
select sample.id, idx_scenario,
count_firm_prod, count_firm, count_prod,
max_ts_firm_prod, max_ts_firm, max_ts_prod,
n_remove_firm_prod, n_all_prod_remove_firm, end_ts
from iiabmdb.with_exp_sample as sample
# subqueries 1-3 (distinct counts) and 9 (end_ts)
left join iiabmdb.with_exp_experiment as experiment
on sample.e_id = experiment.id
left join (select s_id,
count(distinct id_firm, id_product) as count_firm_prod,
count(distinct id_firm) as count_firm,
count(distinct id_product) as count_prod,
max(ts) as end_ts
from iiabmdb.with_exp_result group by s_id) as s_count
on sample.id = s_count.s_id
# 4
left join # firm prod
(select s_id, max(ts) as max_ts_firm_prod from
(select s_id, id_firm, id_product, min(ts) as ts
from iiabmdb.with_exp_result
where `status` = "D"
group by s_id, id_firm, id_product) as ts
group by s_id) as s_max_ts_firm_prod
on sample.id = s_max_ts_firm_prod.s_id
# 5
left join # firm
(select s_id, max(ts) as max_ts_firm from
(select s_id, id_firm, min(ts) as ts
from iiabmdb.with_exp_result
where `status` = "D"
group by s_id, id_firm) as ts
group by s_id) as s_max_ts_firm
on sample.id = s_max_ts_firm.s_id
# 6
left join # prod
(select s_id, max(ts) as max_ts_prod from
(select s_id, id_product, min(ts) as ts
from iiabmdb.with_exp_result
where `status` = "D"
group by s_id, id_product) as ts
group by s_id) as s_max_ts_prod
on sample.id = s_max_ts_prod.s_id
# 7
left join
(select s_id, count(distinct id_firm, id_product) as n_remove_firm_prod
from iiabmdb.with_exp_result
where `status` = "R"
group by s_id) as s_n_remove_firm_prod
on sample.id = s_n_remove_firm_prod.s_id
# 8
left join
(select s_id, count(distinct id_firm) as n_all_prod_remove_firm from
(select s_id, id_firm, count(distinct id_product) as n_remove_prod
from iiabmdb.with_exp_result
where `status` = "R"
group by s_id, id_firm) as s_n_remove_prod
left join iiabmdb_basic_info.firm_n_prod as firm_n_prod
on s_n_remove_prod.id_firm = firm_n_prod.code
where n_remove_prod = n_prod
group by s_id) as s_n_all_prod_remove_firm
on sample.id = s_n_all_prod_remove_firm.s_id
) as scenario_count
group by idx_scenario
) as scenario_mean
on experiment.idx_scenario = scenario_mean.idx_scenario;
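A note on the averaging convention above: MySQL's AVG() skips NULL rows, while SUM()/COUNT(*) keeps them in the denominator (SUM() ignores NULLs, so a missing value contributes zero), which is what the left joins above rely on. Illustrative values, assuming a column v holding 2, 4 and NULL:

-- AVG(v)          -> 3.0  (the NULL row is ignored)
-- SUM(v)/COUNT(*) -> 2.0  (the NULL row counts as zero in the mean)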

12
SQL_analysis_risk.sql Normal file

@ -0,0 +1,12 @@
select * from
(select s_id, id_firm, id_product, min(ts) as ts from iiabmdb.without_exp_result
where `status` = 'D'
group by s_id, id_firm, id_product) as s_disrupt
where s_id in
(select s_id from
(select s_id, id_firm, id_product, min(ts) as ts from iiabmdb.without_exp_result
where `status` = 'D'
group by s_id, id_firm, id_product) as t
group by s_id
having count(*) > 1)
order by s_id;

6
SQL_db_user_create.sql Normal file

@ -0,0 +1,6 @@
-- replace authentication_plugin with the plugin actually in use, e.g. mysql_native_password
CREATE USER 'iiabm_user'@'localhost' IDENTIFIED WITH authentication_plugin BY 'iiabm_pwd';
-- CREATE USER 'iiabm_user'@'localhost' IDENTIFIED BY 'iiabm_pwd';
GRANT ALL PRIVILEGES ON iiabmdb.* TO 'iiabm_user'@'localhost';
FLUSH PRIVILEGES;

15
SQL_export_high_risk_setting.sql Normal file

@ -0,0 +1,15 @@
select e_id, n_disrupt_sample, total_n_disrupt_firm_prod_experiment, dct_lst_init_disrupt_firm_prod from iiabmdb.without_exp_experiment as experiment
inner join (
select e_id, count(id) as n_disrupt_sample, sum(n_disrupt_firm_prod_sample) as total_n_disrupt_firm_prod_experiment from iiabmdb.without_exp_sample as sample
inner join (
select * from
(select s_id, COUNT(DISTINCT id_firm, id_product) as n_disrupt_firm_prod_sample from iiabmdb.without_exp_result group by s_id
) as count_disrupt_firm_prod_sample
where n_disrupt_firm_prod_sample > 1
) as disrupt_sample
on sample.id = disrupt_sample.s_id
group by e_id
) as disrupt_experiment
on experiment.id = disrupt_experiment.e_id
order by n_disrupt_sample desc, total_n_disrupt_firm_prod_experiment desc
limit 0, 95;
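controller_db.py below consumes this query as-is: it reads the file into pandas and unpickles the dct_lst_init_disrupt_firm_prod blobs. A condensed sketch of that pattern (connection as exported by orm.py):

import pickle
import pandas as pd
from sqlalchemy import text
from orm import connection

with open('SQL_export_high_risk_setting.sql') as f:
    df = pd.read_sql(sql=text(f.read()), con=connection)
# each cell of this column is a pickled dict of {firm code: [product codes]}
df['dct_lst_init_disrupt_firm_prod'] = \
    df['dct_lst_init_disrupt_firm_prod'].apply(pickle.loads)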

13
SQL_migrate_db.sql Normal file

@ -0,0 +1,13 @@
CREATE DATABASE iiabmdb20230829;
RENAME TABLE iiabmdb.not_test_experiment TO iiabmdb20230829.not_test_experiment,
iiabmdb.not_test_result TO iiabmdb20230829.not_test_result,
iiabmdb.not_test_sample TO iiabmdb20230829.not_test_sample,
iiabmdb.test_experiment TO iiabmdb20230829.test_experiment,
iiabmdb.test_result TO iiabmdb20230829.test_result,
iiabmdb.test_sample TO iiabmdb20230829.test_sample;
RENAME TABLE iiabmdb.with_exp_experiment TO iiabmdb20230829.with_exp_experiment,
iiabmdb.with_exp_result TO iiabmdb20230829.with_exp_result,
iiabmdb.with_exp_sample TO iiabmdb20230829.with_exp_sample,
iiabmdb.without_exp_experiment TO iiabmdb20230829.without_exp_experiment,
iiabmdb.without_exp_result TO iiabmdb20230829.without_exp_result,
iiabmdb.without_exp_sample TO iiabmdb20230829.without_exp_sample;
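A quick sanity check after the renames (plain MySQL, not part of this commit):

SHOW TABLES IN iiabmdb20230829;  -- should list all twelve renamed tables
SHOW TABLES IN iiabmdb;          -- should list none of them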

40
computation.py Normal file

@ -0,0 +1,40 @@
import os
import datetime
from model import Model
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from controller_db import ControllerDB
class Computation:
def __init__(self, c_db: 'ControllerDB'):
# each process computes different samples but shares the same database controller c_db
self.c_db = c_db
self.pid = os.getpid()
def run(self, str_code='0', s_id=None):
sample_random = self.c_db.fetch_a_sample(s_id)
if sample_random is None:
return True
# lock this row by setting is_done_flag to 0 before running the sample
self.c_db.lock_the_sample(sample_random)
print(
f"Pid {self.pid} ({str_code}) is running "
f"sample {sample_random.id} at {datetime.datetime.now()}")
# pass the sample's experiment columns and parameter values (names and values) into the model
dct_exp = {column: getattr(sample_random.experiment, column)
for column in sample_random.experiment.__table__.c.keys()}
# drop the primary key; the model does not need it
del dct_exp['id']
dct_sample_para = {'sample': sample_random,
'seed': sample_random.seed,
**dct_exp}
model = Model(dct_sample_para)
model.run(display=False)
return False
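For orientation, a single-process sketch of how this class is driven; main.py below runs the same loop in several processes (ControllerDB as defined in controller_db.py):

from controller_db import ControllerDB
from computation import Computation

c_db = ControllerDB('test')
c_db.reset_db()
c_db.prepare_list_sample()
comp = Computation(c_db)
while not comp.run():  # run() returns True once no unfinished sample is left
    pass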

10
conf_db.yaml Normal file

@ -0,0 +1,10 @@
# read by orm
is_local_db: True
local:
user_name: iiabm_user
password: iiabm_pwd
db_name: iiabmdb
address: 'localhost'
port: 3306
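orm.py below assembles these fields into a SQLAlchemy URL of the form mysql://iiabm_user:iiabm_pwd@localhost:3306/iiabmdb. Only the local block is defined here; when is_local_db is False, orm.py expects a matching remote block.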

1
conf_db_prefix.yaml Normal file

@ -0,0 +1 @@
db_name_prefix: without_exp

12
conf_experiment.yaml Normal file

@ -0,0 +1,12 @@
# read by ControllerDB
# run settings
meta_seed: 2
test: # only for test scenarios
n_sample: 1
n_iter: 100
not_test: # normal scenarios
n_sample: 50
n_iter: 100

333
controller_db.py Normal file

@ -0,0 +1,333 @@
# -*- coding: utf-8 -*-
from orm import db_session, engine, Base, ins, connection
from orm import Experiment, Sample, Result
from sqlalchemy.exc import OperationalError
from sqlalchemy import text
import yaml
import random
import pandas as pd
import platform
import networkx as nx
import json
import pickle
class ControllerDB:
is_with_exp: bool
dct_parameter = None
is_test: bool = None
db_name_prefix: str = None
reset_flag: int
lst_saved_s_id: list
def __init__(self, prefix, reset_flag=0):
with open('conf_experiment.yaml') as yaml_file:
dct_conf_experiment = yaml.full_load(yaml_file)
assert prefix in ['test', 'without_exp', 'with_exp'], "db name not in test, without_exp, with_exp"
self.is_test = prefix == 'test'
self.is_with_exp = prefix == 'with_exp'
self.db_name_prefix = prefix
dct_para_in_test = dct_conf_experiment['test'] if self.is_test else dct_conf_experiment['not_test']
self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'], **dct_para_in_test}
print(self.dct_parameter)
# 0, not reset; 1, reset self; 2, reset all
self.reset_flag = reset_flag
self.is_exist = False
self.lst_saved_s_id = []
def init_tables(self):
self.fill_experiment_table()
self.fill_sample_table()
def fill_experiment_table(self):
Firm = pd.read_csv("input_data/Firm_amended.csv")
Firm['Code'] = Firm['Code'].astype('string')
Firm.fillna(0, inplace=True)
# fill dct_lst_init_disrupt_firm_prod
# maps each firm to its position in the supply chain, e.g. 0, 1.1
list_dct = []  # stores dicts of firm code -> product-chain nodes
if self.is_with_exp:
# used for the with_exp (ANOVA) experiments
with open('SQL_export_high_risk_setting.sql', 'r') as f:
str_sql = text(f.read())
result = pd.read_sql(sql=str_sql, con=connection)
result['dct_lst_init_disrupt_firm_prod'] = \
result['dct_lst_init_disrupt_firm_prod'].apply(
lambda x: pickle.loads(x))
list_dct = result['dct_lst_init_disrupt_firm_prod'].to_list()
else:
# index: the row's index label in the DataFrame
# row: the row's data, a pandas.Series containing all of its columns and values
for _, row in Firm.iterrows():
code = row['Code']
row = row['1':]
for product_code in row.index[row == 1].to_list():
dct = {code: [product_code]}
list_dct.append(dct)
# fill g_bom
# 结点属性值 相当于 图上点的 原始 产品名称
BomNodes = pd.read_csv('input_data/BomNodes.csv', index_col=0)
BomNodes.set_index('Code', inplace=True)
BomCateNet = pd.read_csv('input_data/BomCateNet.csv', index_col=0)
BomCateNet.fillna(0, inplace=True)
# build a directed multigraph; the transpose makes edges point from upstream to downstream nodes (e.g. 1.1.1 -> 1.1)
g_bom = nx.from_pandas_adjacency(BomCateNet.T,
create_using=nx.MultiDiGraph())
# attach each node's details via its code, using BomNodes.loc[code].to_dict(), e.g. {code0: {'level': 0, 'name': 'Industrial Internet'}}
bom_labels_dict = {}
for code in g_bom.nodes:
bom_labels_dict[code] = BomNodes.loc[code].to_dict()
# assign the attributes to every node, yielding e.g. {1: {'label': 'A', 'value': 10}}
nx.set_node_attributes(g_bom, bom_labels_dict)
# serialize the graph to JSON
g_product_js = json.dumps(nx.adjacency_data(g_bom))
# insert exp
df_xv = pd.read_csv(
"input_data/"
f"xv_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
index_col=None)
# read the OA table
df_oa = pd.read_csv(
"input_data/"
f"oa_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
index_col=None)
# .shape[1] is the number of columns; .iloc selects by position rather than label
df_oa = df_oa.iloc[:, 0:df_xv.shape[1]]
# idx_scenario: the OA row index; idx_init_removal: index into list_dct; dct_init_removal: {firm code: removed products}; g_product_js: BOM graph JSON; dct_exp_para: decoded global parameters from xv
for idx_scenario, row in df_oa.iterrows():
dct_exp_para = {}
for idx_col, para_level in enumerate(row):
dct_exp_para[df_xv.columns[idx_col]] = \
df_xv.iloc[para_level, idx_col]
# iterate over the different initial removals (one key/value dict each)
for idx_init_removal, dct_init_removal in enumerate(list_dct):
self.add_experiment_1(idx_scenario,
idx_init_removal,
dct_init_removal,
g_product_js,
**dct_exp_para)
print(f"Inserted experiment for scenario {idx_scenario}, "
f"init_removal {idx_init_removal}!")
def add_experiment_1(self, idx_scenario, idx_init_removal,
dct_lst_init_disrupt_firm_prod, g_bom,
n_max_trial, prf_size, prf_conn,
cap_limit_prob_type, cap_limit_level,
diff_new_conn, remove_t, netw_prf_n):
e = Experiment(
idx_scenario=idx_scenario,
idx_init_removal=idx_init_removal,
n_sample=int(self.dct_parameter['n_sample']),
n_iter=int(self.dct_parameter['n_iter']),
dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod,
g_bom=g_bom,
n_max_trial=n_max_trial,
prf_size=prf_size,
prf_conn=prf_conn,
cap_limit_prob_type=cap_limit_prob_type,
cap_limit_level=cap_limit_level,
diff_new_conn=diff_new_conn,
remove_t=remove_t,
netw_prf_n=netw_prf_n
)
db_session.add(e)
db_session.commit()
def fill_sample_table(self):
rng = random.Random(self.dct_parameter['meta_seed'])
# draw one 32-bit random seed per sample
lst_seed = [
rng.getrandbits(32)
for _ in range(int(self.dct_parameter['n_sample']))
]
lst_exp = db_session.query(Experiment).all()
lst_sample = []
for experiment in lst_exp:
# idx_sample: 1-50
for idx_sample in range(int(experiment.n_sample)):
s = Sample(e_id=experiment.id,
idx_sample=idx_sample + 1,
seed=lst_seed[idx_sample],
is_done_flag=-1)
lst_sample.append(s)
db_session.bulk_save_objects(lst_sample)
db_session.commit()
print(f'Inserted {len(lst_sample)} samples!')
def reset_db(self, force_drop=False):
# first, check if tables exist
lst_table_obj = [
Base.metadata.tables[str_table]
for str_table in ins.get_table_names()
if str_table.startswith(self.db_name_prefix)
]
self.is_exist = len(lst_table_obj) > 0
if force_drop:
self.force_drop_db(lst_table_obj)
if self.is_exist:
print(
f"All tables exist. No need to reset "
f"for exp: {self.db_name_prefix}."
)
# flip is_done_flag from 0 back to -1 so unfinished tasks rerun
self.reset_unfinished_samples()
else:
# tables do not exist: create the whole schema and initialize it
Base.metadata.create_all(bind=engine)
self.init_tables()
print(
f"All tables are just created and initialized "
f"for exp: {self.db_name_prefix}."
)
def force_drop_db(self, lst_table_obj):
self.is_exist = len(lst_table_obj) > 0
while self.is_exist:
a_table = random.choice(lst_table_obj)
try:
Base.metadata.drop_all(bind=engine, tables=[a_table])
except KeyError:
pass
except OperationalError:
pass
else:
lst_table_obj.remove(a_table)
print(
f"Table {a_table.name} is dropped "
f"for exp: {self.db_name_prefix}!!!"
)
finally:
self.is_exist = len(lst_table_obj) > 0
def reset_unfinished_samples(self):
if self.reset_flag > 0:
if self.reset_flag == 2:
sample = db_session.query(Sample).filter(
Sample.is_done_flag == 0)
elif self.reset_flag == 1:
sample = db_session.query(Sample).filter(
Sample.is_done_flag == 0,
Sample.computer_name == platform.node())
else:
raise ValueError('Wrong reset flag')
if sample.count() > 0:
for s in sample:
qry_result = db_session.query(Result).filter_by(
s_id=s.id)
if qry_result.count() > 0:
db_session.query(Result).filter_by(s_id=s.id).delete()
db_session.commit()
s.is_done_flag = -1
db_session.commit()
print(f"Reset the sample id {s.id} flag from 0 to -1")
def prepare_list_sample(self):
# consistent with the reset logic above (reset local, reset all, or no reset), re-collect the samples that still need to run
# count the samples (e.g. 50 per risk event)
res = db_session.execute(
text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
)).scalar()
# n_sample is only used for the progress message below
n_sample = 0 if res is None else res
print(f'There are a total of {n_sample} samples.')
# collect samples with is_done_flag = -1, i.e. not yet run (the flag flips to 0 once running)
res = db_session.execute(
text(f"SELECT id FROM {self.db_name_prefix}_sample "
f"WHERE is_done_flag = -1"
))
for row in res:
s_id = row[0]
self.lst_saved_s_id.append(s_id)
@staticmethod
def select_random_sample(lst_s_id):
while 1:
if len(lst_s_id) == 0:
return None
s_id = random.choice(lst_s_id)
lst_s_id.remove(s_id)
res = db_session.query(Sample).filter(Sample.id == int(s_id),
Sample.is_done_flag == -1)
if res.count() == 1:
return res[0]
def fetch_a_sample(self, s_id=None):
# called by Computation; returns a Sample object; two access modes: fetch a specific sample by s_id,
# or (default) draw a random sample with flag -1 from lst_saved_s_id
if s_id is not None:
res = db_session.query(Sample).filter(Sample.id == int(s_id))
if res.count() == 0:
return None
else:
return res[0]
sample = self.select_random_sample(self.lst_saved_s_id)
if sample is not None:
return sample
return None
@staticmethod
def lock_the_sample(sample: Sample):
sample.is_done_flag, sample.computer_name = 0, platform.node()
db_session.commit()
if __name__ == '__main__':
print("Testing the database connection...")
try:
controller_db = ControllerDB('test')
Base.metadata.create_all(bind=engine)
except Exception as e:
print("Failed to connect to the database!")
print(e)
exit(1)

199
firm.py Normal file

@ -0,0 +1,199 @@
from mesa import Agent
class FirmAgent(Agent):
def __init__(self, unique_id, model, code, type_region, revenue_log, a_lst_product):
# call the superclass __init__
super().__init__(unique_id, model)
# references to the model's networks
self.firm_network = self.model.firm_network
self.product_network = self.model.product_network
# the agent's own attributes
self.code = code
self.type_region = type_region
self.size_stat = []
self.dct_prod_up_prod_stat = {}
self.dct_prod_capacity = {}
# trial bookkeeping
self.dct_n_trial_up_prod_disrupted = {}
self.dct_cand_alt_supp_up_prod_disrupted = {}
self.dct_request_prod_from_firm = {}
# external variables
self.is_prf_size = self.model.is_prf_size
self.is_prf_conn = bool(self.model.prf_conn)
self.str_cap_limit_prob_type = str(self.model.cap_limit_prob_type)
self.flt_cap_limit_level = float(self.model.cap_limit_level)
self.flt_diff_new_conn = float(self.model.diff_new_conn)
# initialize size_stat
self.size_stat.append((revenue_log, 0))
# initialize dct_prod_up_prod_stat
for prod in a_lst_product:
self.dct_prod_up_prod_stat[prod] = {
'p_stat': [('N', 0)],
's_stat': {up_prod: {'stat': True, 'set_disrupt_firm': set()}
for up_prod in prod.a_predecessors()}
}
# initialize extra capacity (dct_prod_capacity)
for product in a_lst_product:
assert self.str_cap_limit_prob_type in ['uniform', 'normal'], \
"cap_limit_prob_type must be either 'uniform' or 'normal'"
extra_cap_mean = self.size_stat[0][0] / self.flt_cap_limit_level
if self.str_cap_limit_prob_type == 'uniform':
extra_cap = self.model.random.randint(int(extra_cap_mean) - 2, int(extra_cap_mean) + 2)  # random.randint requires integer bounds
elif self.str_cap_limit_prob_type == 'normal':
extra_cap = self.model.random.normalvariate(extra_cap_mean, 1)
extra_cap = max(0, round(extra_cap))
self.dct_prod_capacity[product] = extra_cap
def remove_edge_to_cus(self, disrupted_prod):
lst_out_edge = list(
self.firm_network.get_neighbors(self.unique_id))
for n2 in lst_out_edge:
edge_data = self.firm_network.G.edges[self.unique_id, n2]
if edge_data.get('Product') == disrupted_prod.code:
customer = self.model.schedule.get_agent(n2)
for prod in customer.dct_prod_up_prod_stat.keys():
if disrupted_prod in customer.dct_prod_up_prod_stat[prod]['s_stat'].keys():
customer.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_prod][
'set_disrupt_firm'].add(self)
self.firm_network.remove_edge(self.unique_id, n2)
def disrupt_cus_prod(self, prod, disrupted_up_prod):
num_lost = len(self.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_up_prod]['set_disrupt_firm'])
num_remain = len([
u for u in self.firm_network.get_neighbors(self.unique_id)
if self.firm_network.G.edges[u, self.unique_id].get('Product') == disrupted_up_prod.code])
lost_percent = num_lost / (num_lost + num_remain)
lst_size = [firm.size_stat[-1][0] for firm in self.model.schedule.agents]
std_size = (self.size_stat[-1][0] - min(lst_size) + 1) / (max(lst_size) - min(lst_size) + 1)
prob_disrupt = 1 - std_size * (1 - lost_percent)
if self.random.choices([True, False], weights=[prob_disrupt, 1 - prob_disrupt])[0]:  # stdlib random has no p= argument; use choices with weights
self.dct_n_trial_up_prod_disrupted[disrupted_up_prod] = 0
self.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_up_prod]['stat'] = False
status, _ = self.dct_prod_up_prod_stat[prod]['p_stat'][-1]
if status != 'D':
self.dct_prod_up_prod_stat[prod]['p_stat'].append(('D', self.model.schedule.time))
def seek_alt_supply(self, product):
if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial:
if self.dct_n_trial_up_prod_disrupted[product] == 0:
self.dct_cand_alt_supp_up_prod_disrupted[product] = [
firm for firm in self.model.schedule.agents
if firm.is_prod_in_current_normal(product)]
if self.dct_cand_alt_supp_up_prod_disrupted[product]:
lst_firm_connect = []
if self.is_prf_conn:
for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]:
if self.firm_network.G.has_edge(self.unique_id, firm.unique_id) or \
self.firm_network.G.has_edge(firm.unique_id, self.unique_id):
lst_firm_connect.append(firm)
if len(lst_firm_connect) == 0:
if self.is_prf_size:
lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
lst_prob = [size / sum(lst_size) for size in lst_size]
select_alt_supply = \
self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
else:
select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
elif len(lst_firm_connect) > 0:
if self.is_prf_size:
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
else:
select_alt_supply = self.random.choice(lst_firm_connect)
assert select_alt_supply.is_prod_in_current_normal(product)
if product in select_alt_supply.dct_request_prod_from_firm:
select_alt_supply.dct_request_prod_from_firm[product].append(self)
else:
select_alt_supply.dct_request_prod_from_firm[product] = [self]
self.dct_n_trial_up_prod_disrupted[product] += 1
def handle_request(self):
for product, lst_firm in self.dct_request_prod_from_firm.items():
if self.dct_prod_capacity[product] > 0:
if len(lst_firm) == 0:
continue
elif len(lst_firm) == 1:
self.accept_request(lst_firm[0], product)
elif len(lst_firm) > 1:
lst_firm_connect = []
if self.is_prf_conn:
for firm in lst_firm:
if self.firm_network.G.has_edge(self.unique_id, firm.unique_id) or \
self.firm_network.G.has_edge(firm.unique_id, self.unique_id):
lst_firm_connect.append(firm)
if len(lst_firm_connect) == 0:
if self.is_prf_size:
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm]
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
select_customer = self.random.choices(lst_firm, weights=lst_prob)[0]
else:
select_customer = self.random.choice(lst_firm)
self.accept_request(select_customer, product)
elif len(lst_firm_connect) > 0:
if self.is_prf_size:
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
select_customer = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
else:
select_customer = self.random.choice(lst_firm_connect)
self.accept_request(select_customer, product)
else:
for down_firm in lst_firm:
down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
def accept_request(self, down_firm, product):
if self.firm_network.G.has_edge(self.unique_id, down_firm.unique_id) or \
self.firm_network.G.has_edge(down_firm.unique_id, self.unique_id):
prod_accept = 1.0
else:
prod_accept = self.flt_diff_new_conn
if self.random.choices([True, False], weights=[prod_accept, 1 - prod_accept])[0]:
self.firm_network.G.add_edge(self.unique_id, down_firm.unique_id, Product=product.code)
self.dct_prod_capacity[product] -= 1
self.dct_request_prod_from_firm[product].remove(down_firm)
for prod in down_firm.dct_prod_up_prod_stat.keys():
if product in down_firm.dct_prod_up_prod_stat[prod]['s_stat']:
down_firm.dct_prod_up_prod_stat[prod]['s_stat'][product]['stat'] = True
down_firm.dct_prod_up_prod_stat[prod]['p_stat'].append(
('N', self.model.schedule.time))
del down_firm.dct_n_trial_up_prod_disrupted[product]
del down_firm.dct_cand_alt_supp_up_prod_disrupted[product]
else:
down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
def clean_before_trial(self):
self.dct_request_prod_from_firm = {}
def clean_before_time_step(self):
# Reset the number of trials and candidate suppliers for disrupted products
self.dct_n_trial_up_prod_disrupted = dict.fromkeys(self.dct_n_trial_up_prod_disrupted.keys(), 0)
self.dct_cand_alt_supp_up_prod_disrupted = {}
# Update the status of products and refresh disruption sets
for prod in self.dct_prod_up_prod_stat.keys():
status, ts = self.dct_prod_up_prod_stat[prod]['p_stat'][-1]
if ts != self.model.schedule.time:
self.dct_prod_up_prod_stat[prod]['p_stat'].append((status, self.model.schedule.time))
# Refresh the set of disrupted firms
for up_prod in self.dct_prod_up_prod_stat[prod]['s_stat'].keys():
self.dct_prod_up_prod_stat[prod]['s_stat'][up_prod]['set_disrupt_firm'] = set()
def step(self):
# per-time-step behavior (driven by the model schedule)
pass
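A quick numeric check of disrupt_cus_prod above (illustrative values): with firm sizes ranging from 10 to 110 and an own size of 60, std_size = (60 - 10 + 1) / (110 - 10 + 1) ≈ 0.50; losing 2 of 5 supplying firms gives lost_percent = 0.4, so prob_disrupt = 1 - 0.50 * (1 - 0.4) ≈ 0.70. The largest firm (std_size = 1) is disrupted exactly at the rate of its lost suppliers.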

67
main.py Normal file

@ -0,0 +1,67 @@
import os
import random
import time
from multiprocessing import Process
import argparse
from computation import Computation
from sqlalchemy.orm import close_all_sessions
import yaml
from controller_db import ControllerDB
def control_db_and_process(exp_argument, reset_sample_argument, reset_db_argument):
controller_db = ControllerDB(exp_argument, reset_flag=reset_sample_argument)
# force-drop the tables first when requested
controller_db.reset_db(force_drop=reset_db_argument)
# prepare the list of pending samples
controller_db.prepare_list_sample()
close_all_sessions()
# fan the simulation out over multiple processes against the shared database
do_process(do_computation, controller_db)
def do_process(target, controller_db: ControllerDB):
process_list = []
for i in range(int(args.job)):
p = Process(target=target, args=(controller_db,))
p.start()
process_list.append(p)
for i in process_list:
i.join()
def do_computation(c_db):
exp = Computation(c_db)
while 1:
time.sleep(random.uniform(0, 5))
is_all_done = exp.run()
if is_all_done:
break
if __name__ == '__main__':
# command-line arguments
parser = argparse.ArgumentParser(description='setting')
parser.add_argument('--exp', type=str, default='test')
parser.add_argument('--job', type=int, default=3)
parser.add_argument('--reset_sample', type=int, default=0)
# a flag avoids the type=bool pitfall: any non-empty string (including 'False') would parse as True
parser.add_argument('--reset_db', action='store_true')
args = parser.parse_args()
# number of worker processes
assert args.job >= 1, 'Number of jobs should >= 1'
# write the prefix file so orm.py builds table names for the chosen experiment
prefix_file_name = 'conf_db_prefix.yaml'
if os.path.exists(prefix_file_name):
os.remove(prefix_file_name)
with open(prefix_file_name, 'w', encoding='utf-8') as file:
yaml.dump({'db_name_prefix': args.exp}, file)
# set up the database, then run the model across processes
control_db_and_process(args.exp, args.reset_sample, args.reset_db)
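Typical invocations, given the argparse defaults above:

python main.py --exp without_exp --job 3
python main.py --exp with_exp --job 8 --reset_sample 1 --reset_db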

133
model.py Normal file

@ -0,0 +1,133 @@
import json
import networkx as nx
import pandas as pd
from mesa import Model as MesaModel
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector
from firm import FirmAgent
from product import ProductAgent
class Model(MesaModel):  # imported by computation.py as `from model import Model`
def __init__(self, params):
super().__init__()
# the scheduler is required by initialize_agents() and step()
self.schedule = RandomActivation(self)
# Initialize parameters from `params`
self.sample = params['sample']
self.int_stop_ts = 0
self.int_n_iter = int(params['n_iter'])
self.dct_lst_init_disrupt_firm_prod = params['dct_lst_init_disrupt_firm_prod']
# external variable
self.int_n_max_trial = int(params['n_max_trial'])
self.is_prf_size = bool(params['prf_size'])
self.remove_t = int(params['remove_t'])
self.int_netw_prf_n = int(params['netw_prf_n'])
self.product_network = None
self.firm_network = None
self.firm_prod_network = None
# Initialize product network
G_bom = nx.adjacency_graph(json.loads(params['g_bom']))
self.product_network = G_bom
# Initialize firm network
self.initialize_firm_network()
# Initialize firm product network
self.initialize_firm_prod_network()
# Initialize agents
self.initialize_agents()
# Data collector (if needed)
self.datacollector = DataCollector(
agent_reporters={"Product Code": "code"}
)
def initialize_firm_network(self):
# Read firm data and initialize firm network
firm = pd.read_csv("input_data/Firm_amended.csv")
firm['Code'] = firm['Code'].astype('string')
firm.fillna(0, inplace=True)
Firm_attr = firm[["Code", "Type_Region", "Revenue_Log"]].copy()  # copy() avoids SettingWithCopyWarning when adding Product_Code below
firm_product = []
for _, row in firm.loc[:, '1':].iterrows():
firm_product.append(row[row == 1].index.to_list())
Firm_attr['Product_Code'] = firm_product
Firm_attr.set_index('Code', inplace=True)
G_Firm = nx.MultiDiGraph()
G_Firm.add_nodes_from(firm["Code"])
# Add node attributes
firm_labels_dict = {}
for code in G_Firm.nodes:
firm_labels_dict[code] = Firm_attr.loc[code].to_dict()
nx.set_node_attributes(G_Firm, firm_labels_dict)
# Add edges based on BOM graph
self.add_edges_based_on_bom(G_Firm)
self.firm_network = G_Firm
def initialize_firm_prod_network(self):
# Read firm product data and initialize firm product network
firm_prod = pd.read_csv("input_data/Firm_amended.csv")
firm_prod.fillna(0, inplace=True)
firm_prod = pd.DataFrame({'bool': firm_prod.loc[:, '1':].stack()})
firm_prod = firm_prod[firm_prod['bool'] == 1].reset_index()
firm_prod.drop('bool', axis=1, inplace=True)
firm_prod.rename({'level_0': 'Firm_Code', 'level_1': 'Product_Code'}, axis=1, inplace=True)
firm_prod['Firm_Code'] = firm_prod['Firm_Code'].astype('string')
G_FirmProd = nx.MultiDiGraph()
G_FirmProd.add_nodes_from(firm_prod.index)
# Add node attributes
firm_prod_labels_dict = {}
for code in firm_prod.index:
firm_prod_labels_dict[code] = firm_prod.loc[code].to_dict()
nx.set_node_attributes(G_FirmProd, firm_prod_labels_dict)
self.firm_prod_network = G_FirmProd
def add_edges_based_on_bom(self, G_Firm):
# Logic to add edges to the G_Firm graph based on BOM
pass
def initialize_agents(self):
# Initialize product and firm agents
for node, attr in self.product_network.nodes(data=True):
product = ProductAgent(node, self, code=node, name=attr['Name'])
self.schedule.add(product)
for node, attr in self.firm_network.nodes(data=True):
firm_agent = FirmAgent(
node,
self,
code=node,
type_region=attr['Type_Region'],
revenue_log=attr['Revenue_Log'],
a_lst_product=[] # Populate based on products
)
self.schedule.add(firm_agent)
# Initialize disruptions
self.initialize_disruptions()
def initialize_disruptions(self):
# Set the initial firm product disruptions
for firm, products in self.dct_lst_init_disrupt_firm_prod.items():
for product in products:
if isinstance(firm, FirmAgent):
firm.dct_prod_up_prod_stat[product]['p_stat'].append(('D', self.schedule.steps))
def step(self):
self.schedule.step()
self.datacollector.collect(self)
def run(self, display=False):
# minimal driver for the model.run(display=False) call in computation.py: advance n_iter steps
for _ in range(self.int_n_iter):
self.step()

114
orm.py Normal file

@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
from sqlalchemy import create_engine, inspect, Inspector
from sqlalchemy.orm import declarative_base  # moved out of sqlalchemy.ext.declarative in SQLAlchemy 2.0
from sqlalchemy import (Column, Integer, DECIMAL, String, ForeignKey,
BigInteger, DateTime, PickleType, Boolean, Text)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, Session
from sqlalchemy.pool import NullPool
import yaml
with open('conf_db.yaml') as file:
dct_conf_db_all = yaml.full_load(file)
is_local_db = dct_conf_db_all['is_local_db']
if is_local_db:
dct_conf_db = dct_conf_db_all['local']
else:
dct_conf_db = dct_conf_db_all['remote']
with open('conf_db_prefix.yaml') as file:
dct_conf_db_prefix = yaml.full_load(file)
db_name_prefix = dct_conf_db_prefix['db_name_prefix']
str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'],
dct_conf_db['password'],
dct_conf_db['address'],
dct_conf_db['port'],
dct_conf_db['db_name'])
# print('DB is {}:{}/{}'.format(dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name']))
# must use NullPool to avoid lost-connection errors when many processes share the DB
engine = create_engine(str_login, poolclass=NullPool)
connection = engine.connect()
ins: Inspector = inspect(engine)
Base = declarative_base()
db_session = Session(bind=engine)
class Experiment(Base):
__tablename__ = f"{db_name_prefix}_experiment"
id = Column(Integer, primary_key=True, autoincrement=True)
idx_scenario = Column(Integer, nullable=False)
idx_init_removal = Column(Integer, nullable=False)
# fixed parameters
n_sample = Column(Integer, nullable=False)
n_iter = Column(Integer, nullable=False)
# variables
dct_lst_init_disrupt_firm_prod = Column(PickleType, nullable=False)
g_bom = Column(Text(4294000000), nullable=False)
n_max_trial = Column(Integer, nullable=False)
prf_size = Column(Boolean, nullable=False)
prf_conn = Column(Boolean, nullable=False)
cap_limit_prob_type = Column(String(16), nullable=False)
cap_limit_level = Column(DECIMAL(8, 4), nullable=False)
diff_new_conn = Column(DECIMAL(8, 4), nullable=False)
remove_t = Column(Integer, nullable=False)
netw_prf_n = Column(Integer, nullable=False)
sample = relationship(
'Sample', back_populates='experiment', lazy='dynamic')
def __repr__(self):
return f'<Experiment: {self.id}>'
class Sample(Base):
__tablename__ = f"{db_name_prefix}_sample"
id = Column(Integer, primary_key=True, autoincrement=True)
e_id = Column(Integer, ForeignKey('{}.id'.format(
f"{db_name_prefix}_experiment")), nullable=False)
idx_sample = Column(Integer, nullable=False)
seed = Column(BigInteger, nullable=False)
# -1, waiting; 0, running; 1, done
is_done_flag = Column(Integer, nullable=False)
computer_name = Column(String(64), nullable=True)
ts_done = Column(DateTime(timezone=True), onupdate=func.now())
stop_t = Column(Integer, nullable=True)
g_firm = Column(Text(4294000000), nullable=True)
experiment = relationship(
'Experiment', back_populates='sample', uselist=False)
result = relationship('Result', back_populates='sample', lazy='dynamic')
def __repr__(self):
return f'<Sample id: {self.id}>'
class Result(Base):
__tablename__ = f"{db_name_prefix}_result"
id = Column(Integer, primary_key=True, autoincrement=True)
s_id = Column(Integer, ForeignKey('{}.id'.format(
f"{db_name_prefix}_sample")), nullable=False)
id_firm = Column(String(10), nullable=False)
id_product = Column(String(10), nullable=False)
ts = Column(Integer, nullable=False)
status = Column(String(5), nullable=False)
sample = relationship('Sample', back_populates='result', uselist=False)
def __repr__(self):
return f'<Product id: {self.id}>'
if __name__ == '__main__':
# drop_all/create_all need an explicit bind in SQLAlchemy 2.0
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
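A minimal smoke test of the mappings above (assumes the conf files are in place and the tables exist):

from orm import db_session, Sample

pending = db_session.query(Sample).filter(Sample.is_done_flag == -1).count()
print(f'{pending} samples still waiting to run')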

24
product.py Normal file

@ -0,0 +1,24 @@
from mesa import Agent
class ProductAgent(Agent):
def __init__(self, unique_id, model, code, name):
# call the superclass __init__
super().__init__(unique_id, model)
# initialize agent attributes
self.code = code
self.name = name
self.product_network = self.model.product_network
def a_successors(self):
# successors of this product in the BOM graph (its nodes are product codes), returned as ProductAgent objects
successors = list(self.product_network.successors(self.code))
return [a for a in self.model.schedule.agents if isinstance(a, ProductAgent) and a.code in successors]
def a_predecessors(self):
# predecessors of this product in the BOM graph, returned as ProductAgent objects
predecessors = list(self.product_network.predecessors(self.code))
return [a for a in self.model.schedule.agents if isinstance(a, ProductAgent) and a.code in predecessors]
def step(self):
# per-time-step behavior (driven by the model schedule)
pass

55
requirements.txt Normal file

@ -0,0 +1,55 @@
agentpy==0.1.5
alabaster==0.7.13
Babel==2.12.1
certifi @ file:///C:/b/abs_85o_6fm0se/croot/certifi_1671487778835/work/certifi
charset-normalizer==3.0.1
colorama==0.4.6
cycler==0.11.0
decorator==5.1.1
dill==0.3.6
docutils==0.19
greenlet==2.0.2
idna==3.4
imagesize==1.4.1
importlib-metadata==6.0.0
Jinja2==3.1.2
joblib==1.2.0
kiwisolver==1.4.4
MarkupSafe==2.1.2
matplotlib==3.3.4
matplotlib-inline==0.1.6
multiprocess==0.70.14
mysqlclient==2.1.1
networkx==2.5
numpy==1.20.3
numpydoc==1.1.0
packaging==23.0
pandas==1.4.1
pandas-stubs==1.2.0.39
Pillow==9.4.0
Pygments==2.14.0
pygraphviz @ file:///C:/Users/ASUS/Downloads/pygraphviz-1.9-cp38-cp38-win_amd64.whl
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.7.1
PyYAML==6.0
requests==2.28.2
SALib==1.4.7
scipy==1.10.1
six==1.16.0
snowballstemmer==2.2.0
Sphinx==6.1.3
sphinxcontrib-applehelp==1.0.4
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.1
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
SQLAlchemy==2.0.5.post1
traitlets==5.9.0
typing_extensions==4.5.0
urllib3==1.26.14
wincertstore==0.2
yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work
zipp==3.15.0
mesa==2.1.5


@ -0,0 +1,9 @@
agentpy==0.1.5
matplotlib==3.3.4
matplotlib-inline==0.1.6
networkx==2.5
numpy==1.20.3
numpydoc==1.1.0
pandas==1.4.1
pandas-stubs==1.2.0.39
pygraphviz==1.9