This commit is contained in:
commit
1f643c64e4
|
@ -0,0 +1,3 @@
|
|||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
|
@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
|
@ -0,0 +1,10 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.8" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.8" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/mesa.iml" filepath="$PROJECT_DIR$/.idea/mesa.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,85 @@
|
|||
select distinct experiment.idx_scenario,
|
||||
n_max_trial, prf_size, prf_conn, cap_limit_prob_type, cap_limit_level, diff_new_conn, remove_t, netw_prf_n,
|
||||
mean_count_firm_prod, mean_count_firm, mean_count_prod,
|
||||
mean_max_ts_firm_prod, mean_max_ts_firm, mean_max_ts_prod,
|
||||
mean_n_remove_firm_prod, mean_n_all_prod_remove_firm, mean_end_ts
|
||||
from iiabmdb.with_exp_experiment as experiment
|
||||
left join
|
||||
(
|
||||
select
|
||||
idx_scenario,
|
||||
sum(count_firm_prod) / count(*) as mean_count_firm_prod, # Note to use count(*), to include NULL
|
||||
sum(count_firm) / count(*) as mean_count_firm,
|
||||
sum(count_prod) / count(*) as mean_count_prod,
|
||||
sum(max_ts_firm_prod) / count(*) as mean_max_ts_firm_prod,
|
||||
sum(max_ts_firm) / count(*) as mean_max_ts_firm,
|
||||
sum(max_ts_prod) / count(*) as mean_max_ts_prod,
|
||||
sum(n_remove_firm_prod) / count(*) as mean_n_remove_firm_prod,
|
||||
sum(n_all_prod_remove_firm) / count(*) as mean_n_all_prod_remove_firm,
|
||||
sum(end_ts) / count(*) as mean_end_ts
|
||||
from (
|
||||
select sample.id, idx_scenario,
|
||||
count_firm_prod, count_firm, count_prod,
|
||||
max_ts_firm_prod, max_ts_firm, max_ts_prod,
|
||||
n_remove_firm_prod, n_all_prod_remove_firm, end_ts
|
||||
from iiabmdb.with_exp_sample as sample
|
||||
# 1 2 3 + 9
|
||||
left join iiabmdb.with_exp_experiment as experiment
|
||||
on sample.e_id = experiment.id
|
||||
left join (select s_id,
|
||||
count(distinct id_firm, id_product) as count_firm_prod,
|
||||
count(distinct id_firm) as count_firm,
|
||||
count(distinct id_product) as count_prod,
|
||||
max(ts) as end_ts
|
||||
from iiabmdb.with_exp_result group by s_id) as s_count
|
||||
on sample.id = s_count.s_id
|
||||
# 4
|
||||
left join # firm prod
|
||||
(select s_id, max(ts) as max_ts_firm_prod from
|
||||
(select s_id, id_firm, id_product, min(ts) as ts
|
||||
from iiabmdb.with_exp_result
|
||||
where `status` = "D"
|
||||
group by s_id, id_firm, id_product) as ts
|
||||
group by s_id) as s_max_ts_firm_prod
|
||||
on sample.id = s_max_ts_firm_prod.s_id
|
||||
# 5
|
||||
left join # firm
|
||||
(select s_id, max(ts) as max_ts_firm from
|
||||
(select s_id, id_firm, min(ts) as ts
|
||||
from iiabmdb.with_exp_result
|
||||
where `status` = "D"
|
||||
group by s_id, id_firm) as ts
|
||||
group by s_id) as s_max_ts_firm
|
||||
on sample.id = s_max_ts_firm.s_id
|
||||
# 6
|
||||
left join # prod
|
||||
(select s_id, max(ts) as max_ts_prod from
|
||||
(select s_id, id_product, min(ts) as ts
|
||||
from iiabmdb.with_exp_result
|
||||
where `status` = "D"
|
||||
group by s_id, id_product) as ts
|
||||
group by s_id) as s_max_ts_prod
|
||||
on sample.id = s_max_ts_prod.s_id
|
||||
# 7
|
||||
left join
|
||||
(select s_id, count(distinct id_firm, id_product) as n_remove_firm_prod
|
||||
from iiabmdb.with_exp_result
|
||||
where `status` = "R"
|
||||
group by s_id) as s_n_remove_firm_prod
|
||||
on sample.id = s_n_remove_firm_prod.s_id
|
||||
# 8
|
||||
left join
|
||||
(select s_id, count(distinct id_firm) as n_all_prod_remove_firm from
|
||||
(select s_id, id_firm, count(distinct id_product) as n_remove_prod
|
||||
from iiabmdb.with_exp_result
|
||||
where `status` = "R"
|
||||
group by s_id, id_firm) as s_n_remove_prod
|
||||
left join iiabmdb_basic_info.firm_n_prod as firm_n_prod
|
||||
on s_n_remove_prod.id_firm = firm_n_prod.code
|
||||
where n_remove_prod = n_prod
|
||||
group by s_id) as s_n_all_prod_remove_firm
|
||||
on sample.id = s_n_all_prod_remove_firm.s_id
|
||||
) as secnario_count
|
||||
group by idx_scenario
|
||||
) as secnario_mean
|
||||
on experiment.idx_scenario = secnario_mean.idx_scenario;
|
|
@ -0,0 +1,12 @@
|
|||
select * from
|
||||
(select s_id, id_firm, id_product, min(ts) as ts from iiabmdb.without_exp_result
|
||||
where `status` = 'D'
|
||||
group by s_id, id_firm, id_product) as s_disrupt
|
||||
where s_id in
|
||||
(select s_id from
|
||||
(select s_id, id_firm, id_product, min(ts) as ts from iiabmdb.without_exp_result
|
||||
where `status` = 'D'
|
||||
group by s_id, id_firm, id_product) as t
|
||||
group by s_id
|
||||
having count(*) > 1)
|
||||
order by s_id;
|
|
@ -0,0 +1,6 @@
|
|||
CREATE USER 'iiabm_user'@'localhost' IDENTIFIED WITH authentication_plugin BY 'iiabm_pwd';
|
||||
|
||||
-- CREATE USER 'iiabm_user'@'localhost' IDENTIFIED BY 'iiabm_pwd';
|
||||
|
||||
GRANT ALL PRIVILEGES ON iiabmdb.* TO 'iiabm_user'@'localhost';
|
||||
FLUSH PRIVILEGES;
|
|
@ -0,0 +1,15 @@
|
|||
select e_id, n_disrupt_sample, total_n_disrupt_firm_prod_experiment, dct_lst_init_disrupt_firm_prod from iiabmdb.without_exp_experiment as experiment
|
||||
inner join (
|
||||
select e_id, count(id) as n_disrupt_sample, sum(n_disrupt_firm_prod_sample) as total_n_disrupt_firm_prod_experiment from iiabmdb.without_exp_sample as sample
|
||||
inner join (
|
||||
select * from
|
||||
(select s_id, COUNT(DISTINCT id_firm, id_product) as n_disrupt_firm_prod_sample from iiabmdb.without_exp_result group by s_id
|
||||
) as count_disrupt_firm_prod_sample
|
||||
where n_disrupt_firm_prod_sample > 1
|
||||
) as disrupt_sample
|
||||
on sample.id = disrupt_sample.s_id
|
||||
group by e_id
|
||||
) as disrupt_experiment
|
||||
on experiment.id = disrupt_experiment.e_id
|
||||
order by n_disrupt_sample desc, total_n_disrupt_firm_prod_experiment desc
|
||||
limit 0, 95;
|
|
@ -0,0 +1,13 @@
|
|||
CREATE DATABASE iiabmdb20230829;
|
||||
RENAME TABLE iiabmdb.not_test_experiment TO iiabmdb20230829.not_test_experiment,
|
||||
iiabmdb.not_test_result TO iiabmdb20230829.not_test_result,
|
||||
iiabmdb.not_test_sample TO iiabmdb20230829.not_test_sample,
|
||||
iiabmdb.test_experiment TO iiabmdb20230829.test_experiment,
|
||||
iiabmdb.test_result TO iiabmdb20230829.test_result,
|
||||
iiabmdb.test_sample TO iiabmdb20230829.test_sample;
|
||||
RENAME TABLE iiabmdb.with_exp_experiment TO iiabmdb20230829.with_exp_experiment,
|
||||
iiabmdb.with_exp_result TO iiabmdb20230829.with_exp_result,
|
||||
iiabmdb.with_exp_sample TO iiabmdb20230829.with_exp_sample,
|
||||
iiabmdb.without_exp_experiment TO iiabmdb20230829.without_exp_experiment,
|
||||
iiabmdb.without_exp_result TO iiabmdb20230829.without_exp_result,
|
||||
iiabmdb.without_exp_sample TO iiabmdb20230829.without_exp_sample;
|
|
@ -0,0 +1,40 @@
|
|||
import os
|
||||
import datetime
|
||||
|
||||
from model import Model
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from controller_db import ControllerDB
|
||||
|
||||
|
||||
class Computation:
|
||||
|
||||
def __init__(self, c_db: 'ControllerDB'):
|
||||
# 控制不同进程 计算不同的样本 但使用同一个 数据库 c_db
|
||||
self.c_db = c_db
|
||||
self.pid = os.getpid()
|
||||
|
||||
def run(self, str_code='0', s_id=None):
|
||||
sample_random = self.c_db.fetch_a_sample(s_id)
|
||||
if sample_random is None:
|
||||
return True
|
||||
|
||||
# lock this row by update is_done_flag to 0 将运行后的样本设置为 flag 0
|
||||
self.c_db.lock_the_sample(sample_random)
|
||||
print(
|
||||
f"Pid {self.pid} ({str_code}) is running "
|
||||
f"sample {sample_random.id} at {datetime.datetime.now()}")
|
||||
# 将sample 对应的 experiment 的一系列值 和 参数值 传入 模型 中 包括列名 和 值
|
||||
dct_exp = {column: getattr(sample_random.experiment, column)
|
||||
for column in sample_random.experiment.__table__.c.keys()}
|
||||
# 删除不需要的 主键
|
||||
del dct_exp['id']
|
||||
|
||||
dct_sample_para = {'sample': sample_random,
|
||||
'seed': sample_random.seed,
|
||||
**dct_exp}
|
||||
model = Model(dct_sample_para)
|
||||
|
||||
model.run(display=False)
|
||||
return False
|
|
@ -0,0 +1,10 @@
|
|||
# read by orm
|
||||
is_local_db: True
|
||||
|
||||
local:
|
||||
user_name: iiabm_user
|
||||
password: iiabm_pwd
|
||||
db_name: iiabmdb
|
||||
address: 'localhost'
|
||||
port: 3306
|
||||
|
|
@ -0,0 +1 @@
|
|||
db_name_prefix: without_exp
|
|
@ -0,0 +1,12 @@
|
|||
# read by ControllerDB
|
||||
|
||||
# run settings
|
||||
meta_seed: 2
|
||||
|
||||
test: # only for test scenarios
|
||||
n_sample: 1
|
||||
n_iter: 100
|
||||
|
||||
not_test: # normal scenarios
|
||||
n_sample: 50
|
||||
n_iter: 100
|
|
@ -0,0 +1,333 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from orm import db_session, engine, Base, ins, connection
|
||||
from orm import Experiment, Sample, Result
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy import text
|
||||
import yaml
|
||||
import random
|
||||
import pandas as pd
|
||||
import platform
|
||||
import networkx as nx
|
||||
import json
|
||||
import pickle
|
||||
|
||||
|
||||
class ControllerDB:
|
||||
is_with_exp: bool
|
||||
dct_parameter = None
|
||||
is_test: bool = None
|
||||
db_name_prefix: str = None
|
||||
reset_flag: int
|
||||
|
||||
lst_saved_s_id: list
|
||||
|
||||
def __init__(self, prefix, reset_flag=0):
|
||||
with open('conf_experiment.yaml') as yaml_file:
|
||||
dct_conf_experiment = yaml.full_load(yaml_file)
|
||||
assert prefix in ['test', 'without_exp', 'with_exp'], "db name not in test, without_exp, with_exp"
|
||||
|
||||
self.is_test = prefix == 'test'
|
||||
self.is_with_exp = False if prefix == 'test' or prefix == 'without_exp' else True
|
||||
self.db_name_prefix = prefix
|
||||
dct_para_in_test = dct_conf_experiment['test'] if self.is_test else dct_conf_experiment['not_test']
|
||||
self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'] , **dct_para_in_test}
|
||||
|
||||
print(self.dct_parameter)
|
||||
# 0, not reset; 1, reset self; 2, reset all
|
||||
self.reset_flag = reset_flag
|
||||
self.is_exist = False
|
||||
self.lst_saved_s_id = []
|
||||
|
||||
|
||||
def init_tables(self):
|
||||
self.fill_experiment_table()
|
||||
self.fill_sample_table()
|
||||
|
||||
def fill_experiment_table(self):
|
||||
Firm = pd.read_csv("input_data/Firm_amended.csv")
|
||||
Firm['Code'] = Firm['Code'].astype('string')
|
||||
Firm.fillna(0, inplace=True)
|
||||
|
||||
# fill dct_lst_init_disrupt_firm_prod
|
||||
# 存储 公司-在供应链结点的位置.. 0 :‘1.1’
|
||||
list_dct = [] #存储 公司编码code 和对应的产业链 结点
|
||||
if self.is_with_exp:
|
||||
#对于方差分析时候使用
|
||||
with open('SQL_export_high_risk_setting.sql', 'r') as f:
|
||||
str_sql = text(f.read())
|
||||
result = pd.read_sql(sql=str_sql, con=connection)
|
||||
result['dct_lst_init_disrupt_firm_prod'] = \
|
||||
result['dct_lst_init_disrupt_firm_prod'].apply(
|
||||
lambda x: pickle.loads(x))
|
||||
list_dct = result['dct_lst_init_disrupt_firm_prod'].to_list()
|
||||
else:
|
||||
# 行索引 (index):这一行在数据帧中的索引值。
|
||||
# 行数据 (row):这一行的数据,是一个 pandas.Series 对象,包含该行的所有列和值。
|
||||
for _, row in Firm.iterrows():
|
||||
code = row['Code']
|
||||
row = row['1':]
|
||||
for product_code in row.index[row == 1].to_list():
|
||||
dct = {code: [product_code]}
|
||||
list_dct.append(dct)
|
||||
|
||||
# fill g_bom
|
||||
# 结点属性值 相当于 图上点的 原始 产品名称
|
||||
BomNodes = pd.read_csv('input_data/BomNodes.csv', index_col=0)
|
||||
BomNodes.set_index('Code', inplace=True)
|
||||
|
||||
BomCateNet = pd.read_csv('input_data/BomCateNet.csv', index_col=0)
|
||||
BomCateNet.fillna(0, inplace=True)
|
||||
# 创建 可以多边的有向图 同时 转置操作 使得 上游指向下游结点 也就是 1.1.1 - 1.1 类似这种
|
||||
g_bom = nx.from_pandas_adjacency(BomCateNet.T,
|
||||
create_using=nx.MultiDiGraph())
|
||||
# 填充每一个结点 的具体内容 通过 相同的 code 并且通过BomNodes.loc[code].to_dict()字典化 格式类似 格式 { code(0) : {level: 0 ,name: 工业互联网 }}
|
||||
bom_labels_dict = {}
|
||||
for code in g_bom.nodes:
|
||||
bom_labels_dict[code] = BomNodes.loc[code].to_dict()
|
||||
# 分配属性 给每一个结点 获得类似 格式:{1: {'label': 'A', 'value': 10},
|
||||
nx.set_node_attributes(g_bom, bom_labels_dict)
|
||||
# 改为json 格式
|
||||
g_product_js = json.dumps(nx.adjacency_data(g_bom))
|
||||
|
||||
# insert exp
|
||||
df_xv = pd.read_csv(
|
||||
"input_data/"
|
||||
f"xv_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
|
||||
index_col=None)
|
||||
# read the OA table
|
||||
df_oa = pd.read_csv(
|
||||
"input_data/"
|
||||
f"oa_{'with_exp' if self.is_with_exp else 'without_exp'}.csv",
|
||||
index_col=None)
|
||||
# .shape[1] 列数 .iloc 访问特定的值 而不是标签
|
||||
df_oa = df_oa.iloc[:, 0:df_xv.shape[1]]
|
||||
# idx_scenario 是 0 指行 idx_init_removal 指 索引 0.. dct_init_removal 键 code 公司 g_product_js 图的json数据 dct_exp_para 解码 全局参数xv-
|
||||
for idx_scenario, row in df_oa.iterrows():
|
||||
dct_exp_para = {}
|
||||
for idx_col, para_level in enumerate(row):
|
||||
dct_exp_para[df_xv.columns[idx_col]] = \
|
||||
df_xv.iloc[para_level, idx_col]
|
||||
# different initial removal 只会得到 键 和 值
|
||||
for idx_init_removal, dct_init_removal in enumerate(list_dct):
|
||||
self.add_experiment_1(idx_scenario,
|
||||
idx_init_removal,
|
||||
dct_init_removal,
|
||||
g_product_js,
|
||||
**dct_exp_para)
|
||||
print(f"Inserted experiment for scenario {idx_scenario}, "
|
||||
f"init_removal {idx_init_removal}!")
|
||||
|
||||
def add_experiment_1(self, idx_scenario, idx_init_removal,
|
||||
dct_lst_init_disrupt_firm_prod, g_bom,
|
||||
n_max_trial, prf_size, prf_conn,
|
||||
cap_limit_prob_type, cap_limit_level,
|
||||
diff_new_conn, remove_t, netw_prf_n):
|
||||
e = Experiment(
|
||||
idx_scenario=idx_scenario,
|
||||
idx_init_removal=idx_init_removal,
|
||||
n_sample=int(self.dct_parameter['n_sample']),
|
||||
n_iter=int(self.dct_parameter['n_iter']),
|
||||
dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod,
|
||||
g_bom=g_bom,
|
||||
n_max_trial=n_max_trial,
|
||||
prf_size=prf_size,
|
||||
prf_conn=prf_conn,
|
||||
cap_limit_prob_type=cap_limit_prob_type,
|
||||
cap_limit_level=cap_limit_level,
|
||||
diff_new_conn=diff_new_conn,
|
||||
remove_t=remove_t,
|
||||
netw_prf_n=netw_prf_n
|
||||
)
|
||||
db_session.add(e)
|
||||
db_session.commit()
|
||||
|
||||
def fill_sample_table(self):
|
||||
rng = random.Random(self.dct_parameter['meta_seed'])
|
||||
# 根据样本数目 设置 32 位随机整数
|
||||
lst_seed = [
|
||||
rng.getrandbits(32)
|
||||
for _ in range(int(self.dct_parameter['n_sample']))
|
||||
]
|
||||
lst_exp = db_session.query(Experiment).all()
|
||||
|
||||
lst_sample = []
|
||||
for experiment in lst_exp:
|
||||
# idx_sample: 1-50
|
||||
for idx_sample in range(int(experiment.n_sample)):
|
||||
s = Sample(e_id=experiment.id,
|
||||
idx_sample=idx_sample + 1,
|
||||
seed=lst_seed[idx_sample],
|
||||
is_done_flag=-1)
|
||||
lst_sample.append(s)
|
||||
db_session.bulk_save_objects(lst_sample)
|
||||
db_session.commit()
|
||||
print(f'Inserted {len(lst_sample)} samples!')
|
||||
|
||||
def reset_db(self, force_drop=False):
|
||||
# first, check if tables exist
|
||||
lst_table_obj = [
|
||||
Base.metadata.tables[str_table]
|
||||
for str_table in ins.get_table_names()
|
||||
if str_table.startswith(self.db_name_prefix)
|
||||
]
|
||||
self.is_exist = len(lst_table_obj) > 0
|
||||
if force_drop:
|
||||
self.force_drop_db(lst_table_obj)
|
||||
# while is_exist:
|
||||
# a_table = random.choice(lst_table_obj)
|
||||
# try:
|
||||
# Base.metadata.drop_all(bind=engine, tables=[a_table])
|
||||
# except KeyError:
|
||||
# pass
|
||||
# except OperationalError:
|
||||
# pass
|
||||
# else:
|
||||
# lst_table_obj.remove(a_table)
|
||||
# print(
|
||||
# f"Table {a_table.name} is dropped "
|
||||
# f"for exp: {self.db_name_prefix}!!!"
|
||||
# )
|
||||
# finally:
|
||||
# is_exist = len(lst_table_obj) > 0
|
||||
|
||||
if self.is_exist:
|
||||
print(
|
||||
f"All tables exist. No need to reset "
|
||||
f"for exp: {self.db_name_prefix}."
|
||||
)
|
||||
# change the is_done_flag from 0 to -1
|
||||
# rerun the in-finished tasks
|
||||
self.is_exist_reset_flag_resset_db()
|
||||
# if self.reset_flag > 0:
|
||||
# if self.reset_flag == 2:
|
||||
# sample = db_session.query(Sample).filter(
|
||||
# Sample.is_done_flag == 0)
|
||||
# elif self.reset_flag == 1:
|
||||
# sample = db_session.query(Sample).filter(
|
||||
# Sample.is_done_flag == 0,
|
||||
# Sample.computer_name == platform.node())
|
||||
# else:
|
||||
# raise ValueError('Wrong reset flag')
|
||||
# if sample.count() > 0:
|
||||
# for s in sample:
|
||||
# qry_result = db_session.query(Result).filter_by(
|
||||
# s_id=s.id)
|
||||
# if qry_result.count() > 0:
|
||||
# db_session.query(Result).filter(s_id=s.id).delete()
|
||||
# db_session.commit()
|
||||
# s.is_done_flag = -1
|
||||
# db_session.commit()
|
||||
# print(f"Reset the sample id {s.id} flag from 0 to -1")
|
||||
|
||||
else:
|
||||
# 不存在则重新生成所有的表结构
|
||||
Base.metadata.create_all(bind=engine)
|
||||
self.init_tables()
|
||||
print(
|
||||
f"All tables are just created and initialized "
|
||||
f"for exp: {self.db_name_prefix}."
|
||||
)
|
||||
|
||||
def force_drop_db(self, lst_table_obj):
|
||||
self.is_exist = len(lst_table_obj) > 0
|
||||
while self.is_exist:
|
||||
a_table = random.choice(lst_table_obj)
|
||||
try:
|
||||
Base.metadata.drop_all(bind=engine, tables=[a_table])
|
||||
except KeyError:
|
||||
pass
|
||||
except OperationalError:
|
||||
pass
|
||||
else:
|
||||
lst_table_obj.remove(a_table)
|
||||
print(
|
||||
f"Table {a_table.name} is dropped "
|
||||
f"for exp: {self.db_name_prefix}!!!"
|
||||
)
|
||||
finally:
|
||||
self.is_exist = len(lst_table_obj) > 0
|
||||
|
||||
def is_exist_reset_flag_resset_db(self):
|
||||
if self.reset_flag > 0:
|
||||
if self.reset_flag == 2:
|
||||
sample = db_session.query(Sample).filter(
|
||||
Sample.is_done_flag == 0)
|
||||
elif self.reset_flag == 1:
|
||||
sample = db_session.query(Sample).filter(
|
||||
Sample.is_done_flag == 0,
|
||||
Sample.computer_name == platform.node())
|
||||
else:
|
||||
raise ValueError('Wrong reset flag')
|
||||
if sample.count() > 0:
|
||||
for s in sample:
|
||||
qry_result = db_session.query(Result).filter_by(
|
||||
s_id=s.id)
|
||||
if qry_result.count() > 0:
|
||||
db_session.query(Result).filter(s_id=s.id).delete()
|
||||
db_session.commit()
|
||||
s.is_done_flag = -1
|
||||
db_session.commit()
|
||||
print(f"Reset the sample id {s.id} flag from 0 to -1")
|
||||
|
||||
def prepare_list_sample(self):
|
||||
# 为了符合前面 重置表里面存在 重置本机 或者重置全部 或者不重置的部分 这个部分的 关于样本运行也得重新拿出来
|
||||
# 查找一个风险事件中 50 个样本
|
||||
res = db_session.execute(
|
||||
text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
|
||||
f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
|
||||
)).scalar()
|
||||
# 控制 n_sample数量 作为后面的参数
|
||||
n_sample = 0 if res is None else res
|
||||
print(f'There are a total of {n_sample} samples.')
|
||||
# 查找 is_done_flag = -1 也就是没有运行的 样本 运行后会改为0
|
||||
res = db_session.execute(
|
||||
text(f"SELECT id FROM {self.db_name_prefix}_sample "
|
||||
f"WHERE is_done_flag = -1"
|
||||
))
|
||||
for row in res:
|
||||
s_id = row[0]
|
||||
self.lst_saved_s_id.append(s_id)
|
||||
|
||||
@staticmethod
|
||||
def select_random_sample(lst_s_id):
|
||||
while 1:
|
||||
if len(lst_s_id) == 0:
|
||||
return None
|
||||
s_id = random.choice(lst_s_id)
|
||||
lst_s_id.remove(s_id)
|
||||
res = db_session.query(Sample).filter(Sample.id == int(s_id),
|
||||
Sample.is_done_flag == -1)
|
||||
if res.count() == 1:
|
||||
return res[0]
|
||||
|
||||
def fetch_a_sample(self, s_id=None):
|
||||
# 由Computation 调用 返回 sample对象 同时给出 2中 指定访问模式 抓取特定的 样本 通过s_id
|
||||
# 默认访问 flag为-1的 lst_saved_s_id
|
||||
if s_id is not None:
|
||||
res = db_session.query(Sample).filter(Sample.id == int(s_id))
|
||||
if res.count() == 0:
|
||||
return None
|
||||
else:
|
||||
return res[0]
|
||||
|
||||
sample = self.select_random_sample(self.lst_saved_s_id)
|
||||
if sample is not None:
|
||||
return sample
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def lock_the_sample(sample: Sample):
|
||||
sample.is_done_flag, sample.computer_name = 0, platform.node()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("Testing the database connection...")
|
||||
try:
|
||||
controller_db = ControllerDB('test')
|
||||
Base.metadata.create_all(bind=engine)
|
||||
except Exception as e:
|
||||
print("Failed to connect to the database!")
|
||||
print(e)
|
||||
exit(1)
|
|
@ -0,0 +1,199 @@
|
|||
from mesa import Agent
|
||||
|
||||
|
||||
class FirmAgent(Agent):
|
||||
def __init__(self, unique_id, model, code, type_region, revenue_log, a_lst_product):
|
||||
# 调用超类的 __init__ 方法
|
||||
super().__init__(unique_id, model)
|
||||
|
||||
# 初始化模型中的网络引用
|
||||
self.firm_network = self.model.firm_network
|
||||
self.product_network = self.model.product_network
|
||||
|
||||
# 初始化代理自身的属性
|
||||
self.code = code
|
||||
self.type_region = type_region
|
||||
|
||||
self.size_stat = []
|
||||
self.dct_prod_up_prod_stat = {}
|
||||
self.dct_prod_capacity = {}
|
||||
|
||||
# 试验中的参数
|
||||
self.dct_n_trial_up_prod_disrupted = {}
|
||||
self.dct_cand_alt_supp_up_prod_disrupted = {}
|
||||
self.dct_request_prod_from_firm = {}
|
||||
|
||||
# 外部变量
|
||||
self.is_prf_size = self.model.is_prf_size
|
||||
self.is_prf_conn = bool(self.model.prf_conn)
|
||||
self.str_cap_limit_prob_type = str(self.model.cap_limit_prob_type)
|
||||
self.flt_cap_limit_level = float(self.model.cap_limit_level)
|
||||
self.flt_diff_new_conn = float(self.model.diff_new_conn)
|
||||
|
||||
# 初始化 size_stat
|
||||
self.size_stat.append((revenue_log, 0))
|
||||
|
||||
# 初始化 dct_prod_up_prod_stat
|
||||
for prod in a_lst_product:
|
||||
self.dct_prod_up_prod_stat[prod] = {
|
||||
'p_stat': [('N', 0)],
|
||||
's_stat': {up_prod: {'stat': True, 'set_disrupt_firm': set()}
|
||||
for up_prod in prod.a_predecessors()}
|
||||
}
|
||||
|
||||
# 初始化额外容量 (dct_prod_capacity)
|
||||
for product in a_lst_product:
|
||||
assert self.str_cap_limit_prob_type in ['uniform', 'normal'], \
|
||||
"cap_limit_prob_type must be either 'uniform' or 'normal'"
|
||||
extra_cap_mean = self.size_stat[0][0] / self.flt_cap_limit_level
|
||||
if self.str_cap_limit_prob_type == 'uniform':
|
||||
extra_cap = self.model.random.randint(extra_cap_mean - 2, extra_cap_mean + 2)
|
||||
elif self.str_cap_limit_prob_type == 'normal':
|
||||
extra_cap = self.model.random.normalvariate(extra_cap_mean, 1)
|
||||
extra_cap = max(0, round(extra_cap))
|
||||
self.dct_prod_capacity[product] = extra_cap
|
||||
|
||||
def remove_edge_to_cus(self, disrupted_prod):
|
||||
lst_out_edge = list(
|
||||
self.firm_network.get_neighbors(self.unique_id))
|
||||
for n2 in lst_out_edge:
|
||||
edge_data = self.firm_network.G.edges[self.unique_id, n2]
|
||||
if edge_data.get('Product') == disrupted_prod.code:
|
||||
customer = self.model.schedule.get_agent(n2)
|
||||
for prod in customer.dct_prod_up_prod_stat.keys():
|
||||
if disrupted_prod in customer.dct_prod_up_prod_stat[prod]['s_stat'].keys():
|
||||
customer.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_prod][
|
||||
'set_disrupt_firm'].add(self)
|
||||
self.firm_network.remove_edge(self.unique_id, n2)
|
||||
|
||||
def disrupt_cus_prod(self, prod, disrupted_up_prod):
|
||||
num_lost = len(self.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_up_prod]['set_disrupt_firm'])
|
||||
num_remain = len([
|
||||
u for u in self.firm_network.get_neighbors(self.unique_id)
|
||||
if self.firm_network.G.edges[u, self.unique_id].get('Product') == disrupted_up_prod.code])
|
||||
lost_percent = num_lost / (num_lost + num_remain)
|
||||
lst_size = [firm.size_stat[-1][0] for firm in self.model.schedule.agents]
|
||||
std_size = (self.size_stat[-1][0] - min(lst_size) + 1) / (max(lst_size) - min(lst_size) + 1)
|
||||
|
||||
prob_disrupt = 1 - std_size * (1 - lost_percent)
|
||||
if self.random.choice([True, False], p=[prob_disrupt, 1 - prob_disrupt]):
|
||||
self.dct_n_trial_up_prod_disrupted[disrupted_up_prod] = 0
|
||||
self.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_up_prod]['stat'] = False
|
||||
status, _ = self.dct_prod_up_prod_stat[prod]['p_stat'][-1]
|
||||
if status != 'D':
|
||||
self.dct_prod_up_prod_stat[prod]['p_stat'].append(('D', self.model.schedule.time))
|
||||
|
||||
def seek_alt_supply(self, product):
|
||||
if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial:
|
||||
if self.dct_n_trial_up_prod_disrupted[product] == 0:
|
||||
self.dct_cand_alt_supp_up_prod_disrupted[product] = [
|
||||
firm for firm in self.model.schedule.agents
|
||||
if firm.is_prod_in_current_normal(product)]
|
||||
if self.dct_cand_alt_supp_up_prod_disrupted[product]:
|
||||
lst_firm_connect = []
|
||||
if self.is_prf_conn:
|
||||
for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]:
|
||||
if self.firm_network.G.has_edge(self.unique_id, firm.unique_id) or \
|
||||
self.firm_network.G.has_edge(firm.unique_id, self.unique_id):
|
||||
lst_firm_connect.append(firm)
|
||||
if len(lst_firm_connect) == 0:
|
||||
if self.is_prf_size:
|
||||
lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
|
||||
lst_prob = [size / sum(lst_size) for size in lst_size]
|
||||
select_alt_supply = \
|
||||
self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0]
|
||||
else:
|
||||
select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
|
||||
elif len(lst_firm_connect) > 0:
|
||||
if self.is_prf_size:
|
||||
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
|
||||
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
|
||||
select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
|
||||
else:
|
||||
select_alt_supply = self.random.choice(lst_firm_connect)
|
||||
|
||||
assert select_alt_supply.is_prod_in_current_normal(product)
|
||||
|
||||
if product in select_alt_supply.dct_request_prod_from_firm:
|
||||
select_alt_supply.dct_request_prod_from_firm[product].append(self)
|
||||
else:
|
||||
select_alt_supply.dct_request_prod_from_firm[product] = [self]
|
||||
|
||||
self.dct_n_trial_up_prod_disrupted[product] += 1
|
||||
|
||||
def handle_request(self):
|
||||
for product, lst_firm in self.dct_request_prod_from_firm.items():
|
||||
if self.dct_prod_capacity[product] > 0:
|
||||
if len(lst_firm) == 0:
|
||||
continue
|
||||
elif len(lst_firm) == 1:
|
||||
self.accept_request(lst_firm[0], product)
|
||||
elif len(lst_firm) > 1:
|
||||
lst_firm_connect = []
|
||||
if self.is_prf_conn:
|
||||
for firm in lst_firm:
|
||||
if self.firm_network.G.has_edge(self.unique_id, firm.unique_id) or \
|
||||
self.firm_network.G.has_edge(firm.unique_id, self.unique_id):
|
||||
lst_firm_connect.append(firm)
|
||||
if len(lst_firm_connect) == 0:
|
||||
if self.is_prf_size:
|
||||
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm]
|
||||
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
|
||||
select_customer = self.random.choices(lst_firm, weights=lst_prob)[0]
|
||||
else:
|
||||
select_customer = self.random.choice(lst_firm)
|
||||
self.accept_request(select_customer, product)
|
||||
elif len(lst_firm_connect) > 0:
|
||||
if self.is_prf_size:
|
||||
lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
|
||||
lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
|
||||
select_customer = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
|
||||
else:
|
||||
select_customer = self.random.choice(lst_firm_connect)
|
||||
self.accept_request(select_customer, product)
|
||||
else:
|
||||
for down_firm in lst_firm:
|
||||
down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
|
||||
|
||||
def accept_request(self, down_firm, product):
|
||||
if self.firm_network.G.has_edge(self.unique_id, down_firm.unique_id) or \
|
||||
self.firm_network.G.has_edge(down_firm.unique_id, self.unique_id):
|
||||
prod_accept = 1.0
|
||||
else:
|
||||
prod_accept = self.flt_diff_new_conn
|
||||
if self.random.choice([True, False], p=[prod_accept, 1 - prod_accept]):
|
||||
self.firm_network.G.add_edge(self.unique_id, down_firm.unique_id, Product=product.code)
|
||||
self.dct_prod_capacity[product] -= 1
|
||||
self.dct_request_prod_from_firm[product].remove(down_firm)
|
||||
|
||||
for prod in down_firm.dct_prod_up_prod_stat.keys():
|
||||
if product in down_firm.dct_prod_up_prod_stat[prod]['s_stat']:
|
||||
down_firm.dct_prod_up_prod_stat[prod]['s_stat'][product]['stat'] = True
|
||||
down_firm.dct_prod_up_prod_stat[prod]['p_stat'].append(
|
||||
('N', self.model.schedule.time))
|
||||
del down_firm.dct_n_trial_up_prod_disrupted[product]
|
||||
del down_firm.dct_cand_alt_supp_up_prod_disrupted[product]
|
||||
else:
|
||||
down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)
|
||||
|
||||
def clean_before_trial(self):
|
||||
self.dct_request_prod_from_firm = {}
|
||||
|
||||
def clean_before_time_step(self):
|
||||
# Reset the number of trials and candidate suppliers for disrupted products
|
||||
self.dct_n_trial_up_prod_disrupted = dict.fromkeys(self.dct_n_trial_up_prod_disrupted.keys(), 0)
|
||||
self.dct_cand_alt_supp_up_prod_disrupted = {}
|
||||
|
||||
# Update the status of products and refresh disruption sets
|
||||
for prod in self.dct_prod_up_prod_stat.keys():
|
||||
status, ts = self.dct_prod_up_prod_stat[prod]['p_stat'][-1]
|
||||
if ts != self.model.schedule.time:
|
||||
self.dct_prod_up_prod_stat[prod]['p_stat'].append((status, self.model.schedule.time))
|
||||
|
||||
# Refresh the set of disrupted firms
|
||||
for up_prod in self.dct_prod_up_prod_stat[prod]['s_stat'].keys():
|
||||
self.dct_prod_up_prod_stat[prod]['s_stat'][up_prod]['set_disrupt_firm'] = set()
|
||||
|
||||
def step(self):
|
||||
# 在每个时间步进行的操作
|
||||
pass
|
|
@ -0,0 +1,67 @@
|
|||
import os
|
||||
import random
|
||||
import time
|
||||
from multiprocessing import Process
|
||||
import argparse
|
||||
from computation import Computation
|
||||
from sqlalchemy.orm import close_all_sessions
|
||||
|
||||
import yaml
|
||||
|
||||
from controller_db import ControllerDB
|
||||
|
||||
|
||||
def controll_db_and_process(exp_argument, reset_sample_argument, reset_db_argument):
|
||||
from controller_db import ControllerDB
|
||||
controller_db = ControllerDB(exp_argument, reset_flag=reset_sample_argument)
|
||||
# controller_db.reset_db()
|
||||
# force drop
|
||||
controller_db.reset_db(force_drop=reset_db_argument)
|
||||
# 准备样本表
|
||||
controller_db.prepare_list_sample()
|
||||
|
||||
close_all_sessions()
|
||||
# 调用 do_process 利用计算机进行多核处理 仿真 将数据库中
|
||||
do_process(do_computation, controller_db)
|
||||
|
||||
|
||||
def do_process(target: object, controller_db: ControllerDB, ):
|
||||
process_list = []
|
||||
for i in range(int(args.job)):
|
||||
p = Process(target=do_computation, args=(controller_db,))
|
||||
p.start()
|
||||
process_list.append(p)
|
||||
|
||||
for i in process_list:
|
||||
i.join()
|
||||
|
||||
|
||||
def do_computation(c_db):
|
||||
exp = Computation(c_db)
|
||||
|
||||
while 1:
|
||||
time.sleep(random.uniform(0, 5))
|
||||
is_all_done = exp.run()
|
||||
if is_all_done:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 输入参数
|
||||
parser = argparse.ArgumentParser(description='setting')
|
||||
parser.add_argument('--exp', type=str, default='test')
|
||||
parser.add_argument('--job', type=int, default='3')
|
||||
parser.add_argument('--reset_sample', type=int, default='0')
|
||||
parser.add_argument('--reset_db', type=bool, default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
# 几核参与进程
|
||||
assert args.job >= 1, 'Number of jobs should >= 1'
|
||||
# 控制参数 利用 prefix_file_name 前缀名字 控制 2项不同的实验
|
||||
prefix_file_name = 'conf_db_prefix.yaml'
|
||||
if os.path.exists(prefix_file_name):
|
||||
os.remove(prefix_file_name)
|
||||
with open(prefix_file_name, 'w', encoding='utf-8') as file:
|
||||
yaml.dump({'db_name_prefix': args.exp}, file)
|
||||
# 数据库连接控制 和 进行模型运行
|
||||
controll_db_and_process(args.exp, args.reset_sample, args.reset_db)
|
|
@ -0,0 +1,133 @@
|
|||
import json
|
||||
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
from mesa import Model
|
||||
from mesa.time import RandomActivation
|
||||
from mesa.space import MultiGrid
|
||||
from mesa.datacollection import DataCollector
|
||||
|
||||
from firm import FirmAgent
|
||||
from product import ProductAgent
|
||||
|
||||
|
||||
class MyModel(Model):
|
||||
def __init__(self, params):
|
||||
# self.num_agents = params['N']
|
||||
# self.grid = MultiGrid(params['width'], params['height'], True)
|
||||
# self.schedule = RandomActivation(self)
|
||||
|
||||
# Initialize parameters from `params`
|
||||
self.sample = params['sample']
|
||||
self.int_stop_ts = 0
|
||||
self.int_n_iter = int(params['n_iter'])
|
||||
self.dct_lst_init_disrupt_firm_prod = params['dct_lst_init_disrupt_firm_prod']
|
||||
# external variable
|
||||
self.int_n_max_trial = int(params['n_max_trial'])
|
||||
self.is_prf_size = bool(params['prf_size'])
|
||||
|
||||
self.remove_t = int(params['remove_t'])
|
||||
self.int_netw_prf_n = int(params['netw_prf_n'])
|
||||
|
||||
self.product_network = None
|
||||
self.firm_network = None
|
||||
self.firm_prod_network = None
|
||||
|
||||
# Initialize product network
|
||||
G_bom = nx.adjacency_graph(json.loads(params['g_bom']))
|
||||
self.product_network = G_bom
|
||||
|
||||
# Initialize firm network
|
||||
self.initialize_firm_network()
|
||||
|
||||
# Initialize firm product network
|
||||
self.initialize_firm_prod_network()
|
||||
|
||||
# Initialize agents
|
||||
self.initialize_agents()
|
||||
|
||||
# Data collector (if needed)
|
||||
self.datacollector = DataCollector(
|
||||
agent_reporters={"Product Code": "code"}
|
||||
)
|
||||
|
||||
def initialize_firm_network(self):
|
||||
# Read firm data and initialize firm network
|
||||
firm = pd.read_csv("input_data/Firm_amended.csv")
|
||||
firm['Code'] = firm['Code'].astype('string')
|
||||
firm.fillna(0, inplace=True)
|
||||
Firm_attr = firm[["Code", "Type_Region", "Revenue_Log"]]
|
||||
firm_product = []
|
||||
for _, row in firm.loc[:, '1':].iterrows():
|
||||
firm_product.append(row[row == 1].index.to_list())
|
||||
Firm_attr['Product_Code'] = firm_product
|
||||
Firm_attr.set_index('Code', inplace=True)
|
||||
G_Firm = nx.MultiDiGraph()
|
||||
G_Firm.add_nodes_from(firm["Code"])
|
||||
|
||||
# Add node attributes
|
||||
firm_labels_dict = {}
|
||||
for code in G_Firm.nodes:
|
||||
firm_labels_dict[code] = Firm_attr.loc[code].to_dict()
|
||||
nx.set_node_attributes(G_Firm, firm_labels_dict)
|
||||
|
||||
# Add edges based on BOM graph
|
||||
self.add_edges_based_on_bom(G_Firm)
|
||||
|
||||
self.firm_network = G_Firm
|
||||
|
||||
def initialize_firm_prod_network(self):
|
||||
# Read firm product data and initialize firm product network
|
||||
firm_prod = pd.read_csv("input_data/Firm_amended.csv")
|
||||
firm_prod.fillna(0, inplace=True)
|
||||
firm_prod = pd.DataFrame({'bool': firm_prod.loc[:, '1':].stack()})
|
||||
firm_prod = firm_prod[firm_prod['bool'] == 1].reset_index()
|
||||
firm_prod.drop('bool', axis=1, inplace=True)
|
||||
firm_prod.rename({'level_0': 'Firm_Code', 'level_1': 'Product_Code'}, axis=1, inplace=True)
|
||||
firm_prod['Firm_Code'] = firm_prod['Firm_Code'].astype('string')
|
||||
|
||||
G_FirmProd = nx.MultiDiGraph()
|
||||
G_FirmProd.add_nodes_from(firm_prod.index)
|
||||
|
||||
# Add node attributes
|
||||
firm_prod_labels_dict = {}
|
||||
for code in firm_prod.index:
|
||||
firm_prod_labels_dict[code] = firm_prod.loc[code].to_dict()
|
||||
nx.set_node_attributes(G_FirmProd, firm_prod_labels_dict)
|
||||
|
||||
self.firm_prod_network = G_FirmProd
|
||||
|
||||
def add_edges_based_on_bom(self, G_Firm):
|
||||
# Logic to add edges to the G_Firm graph based on BOM
|
||||
pass
|
||||
|
||||
def initialize_agents(self):
|
||||
# Initialize product and firm agents
|
||||
for node, attr in self.product_network.nodes(data=True):
|
||||
product = ProductAgent(node, self, code=node, name=attr['Name'])
|
||||
self.schedule.add(product)
|
||||
|
||||
for node, attr in self.firm_network.nodes(data=True):
|
||||
firm_agent = FirmAgent(
|
||||
node,
|
||||
self,
|
||||
code=node,
|
||||
type_region=attr['Type_Region'],
|
||||
revenue_log=attr['Revenue_Log'],
|
||||
a_lst_product=[] # Populate based on products
|
||||
)
|
||||
self.schedule.add(firm_agent)
|
||||
|
||||
# Initialize disruptions
|
||||
self.initialize_disruptions()
|
||||
|
||||
def initialize_disruptions(self):
|
||||
# Set the initial firm product disruptions
|
||||
for firm, products in self.dct_lst_init_disrupt_firm_prod.items():
|
||||
for product in products:
|
||||
if isinstance(firm, FirmAgent):
|
||||
firm.dct_prod_up_prod_stat[product]['p_stat'].append(('D', self.schedule.steps))
|
||||
|
||||
def step(self):
|
||||
self.schedule.step()
|
||||
self.datacollector.collect(self)
|
|
@ -0,0 +1,114 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from sqlalchemy import create_engine, inspect, Inspector
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import (Column, Integer, DECIMAL, String, ForeignKey,
|
||||
BigInteger, DateTime, PickleType, Boolean, Text)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship, Session
|
||||
from sqlalchemy.pool import NullPool
|
||||
import yaml
|
||||
|
||||
with open('conf_db.yaml') as file:
|
||||
dct_conf_db_all = yaml.full_load(file)
|
||||
is_local_db = dct_conf_db_all['is_local_db']
|
||||
if is_local_db:
|
||||
dct_conf_db = dct_conf_db_all['local']
|
||||
else:
|
||||
dct_conf_db = dct_conf_db_all['remote']
|
||||
|
||||
with open('conf_db_prefix.yaml') as file:
|
||||
dct_conf_db_prefix = yaml.full_load(file)
|
||||
db_name_prefix = dct_conf_db_prefix['db_name_prefix']
|
||||
|
||||
str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'],
|
||||
dct_conf_db['password'],
|
||||
dct_conf_db['address'],
|
||||
dct_conf_db['port'],
|
||||
dct_conf_db['db_name'])
|
||||
# print('DB is {}:{}/{}'.format(dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name']))
|
||||
|
||||
# must be null pool to avoid connection lost error
|
||||
engine = create_engine(str_login, poolclass=NullPool)
|
||||
connection = engine.connect()
|
||||
ins: Inspector = inspect(engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
db_session = Session(bind=engine)
|
||||
|
||||
|
||||
class Experiment(Base):
|
||||
__tablename__ = f"{db_name_prefix}_experiment"
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
idx_scenario = Column(Integer, nullable=False)
|
||||
idx_init_removal = Column(Integer, nullable=False)
|
||||
|
||||
# fixed parameters
|
||||
n_sample = Column(Integer, nullable=False)
|
||||
n_iter = Column(Integer, nullable=False)
|
||||
|
||||
# variables
|
||||
dct_lst_init_disrupt_firm_prod = Column(PickleType, nullable=False)
|
||||
g_bom = Column(Text(4294000000), nullable=False)
|
||||
|
||||
n_max_trial = Column(Integer, nullable=False)
|
||||
prf_size = Column(Boolean, nullable=False)
|
||||
prf_conn = Column(Boolean, nullable=False)
|
||||
cap_limit_prob_type = Column(String(16), nullable=False)
|
||||
cap_limit_level = Column(DECIMAL(8, 4), nullable=False)
|
||||
diff_new_conn = Column(DECIMAL(8, 4), nullable=False)
|
||||
remove_t = Column(Integer, nullable=False)
|
||||
netw_prf_n = Column(Integer, nullable=False)
|
||||
|
||||
sample = relationship(
|
||||
'Sample', back_populates='experiment', lazy='dynamic')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Experiment: {self.id}>'
|
||||
|
||||
|
||||
class Sample(Base):
|
||||
__tablename__ = f"{db_name_prefix}_sample"
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
e_id = Column(Integer, ForeignKey('{}.id'.format(
|
||||
f"{db_name_prefix}_experiment")), nullable=False)
|
||||
|
||||
idx_sample = Column(Integer, nullable=False)
|
||||
seed = Column(BigInteger, nullable=False)
|
||||
# -1, waiting; 0, running; 1, done
|
||||
is_done_flag = Column(Integer, nullable=False)
|
||||
computer_name = Column(String(64), nullable=True)
|
||||
ts_done = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
stop_t = Column(Integer, nullable=True)
|
||||
|
||||
g_firm = Column(Text(4294000000), nullable=True)
|
||||
|
||||
experiment = relationship(
|
||||
'Experiment', back_populates='sample', uselist=False)
|
||||
result = relationship('Result', back_populates='sample', lazy='dynamic')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Sample id: {self.id}>'
|
||||
|
||||
|
||||
class Result(Base):
|
||||
__tablename__ = f"{db_name_prefix}_result"
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
s_id = Column(Integer, ForeignKey('{}.id'.format(
|
||||
f"{db_name_prefix}_sample")), nullable=False)
|
||||
|
||||
id_firm = Column(String(10), nullable=False)
|
||||
id_product = Column(String(10), nullable=False)
|
||||
ts = Column(Integer, nullable=False)
|
||||
status = Column(String(5), nullable=False)
|
||||
|
||||
sample = relationship('Sample', back_populates='result', uselist=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Product id: {self.id}>'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Base.metadata.drop_all()
|
||||
Base.metadata.create_all()
|
|
@ -0,0 +1,24 @@
|
|||
from mesa import Agent
|
||||
|
||||
class ProductAgent(Agent):
|
||||
def __init__(self, unique_id, model, code, name):
|
||||
# 调用超类的 __init__ 方法
|
||||
super().__init__(unique_id, model)
|
||||
|
||||
# 初始化代理属性
|
||||
self.code = code
|
||||
self.name = name
|
||||
self.product_network = self.model.product_network
|
||||
|
||||
def a_successors(self):
|
||||
# Find successors of the current product and return them as a list of ProductAgent
|
||||
successors = list(self.product_network.successors(self))
|
||||
return [self.model.schedule.agents[successor] for successor in successors]
|
||||
|
||||
def a_predecessors(self):
|
||||
# Find predecessors of the current product and return them as a list of ProductAgent
|
||||
predecessors = list(self.product_network.predecessors(self))
|
||||
return [self.model.schedule.agents[predecessor] for predecessor in predecessors]
|
||||
def step(self):
|
||||
# 在每个时间步进行的操作
|
||||
pass
|
|
@ -0,0 +1,55 @@
|
|||
agentpy==0.1.5
|
||||
alabaster==0.7.13
|
||||
Babel==2.12.1
|
||||
certifi @ file:///C:/b/abs_85o_6fm0se/croot/certifi_1671487778835/work/certifi
|
||||
charset-normalizer==3.0.1
|
||||
colorama==0.4.6
|
||||
cycler==0.11.0
|
||||
decorator==5.1.1
|
||||
dill==0.3.6
|
||||
docutils==0.19
|
||||
greenlet==2.0.2
|
||||
idna==3.4
|
||||
imagesize==1.4.1
|
||||
importlib-metadata==6.0.0
|
||||
Jinja2==3.1.2
|
||||
joblib==1.2.0
|
||||
kiwisolver==1.4.4
|
||||
MarkupSafe==2.1.2
|
||||
matplotlib==3.3.4
|
||||
matplotlib-inline==0.1.6
|
||||
multiprocess==0.70.14
|
||||
mysqlclient==2.1.1
|
||||
networkx==2.5
|
||||
numpy==1.20.3
|
||||
numpydoc==1.1.0
|
||||
packaging==23.0
|
||||
pandas==1.4.1
|
||||
pandas-stubs==1.2.0.39
|
||||
Pillow==9.4.0
|
||||
Pygments==2.14.0
|
||||
pygraphviz @ file:///C:/Users/ASUS/Downloads/pygraphviz-1.9-cp38-cp38-win_amd64.whl
|
||||
pyparsing==3.0.9
|
||||
python-dateutil==2.8.2
|
||||
pytz==2022.7.1
|
||||
PyYAML==6.0
|
||||
requests==2.28.2
|
||||
SALib==1.4.7
|
||||
scipy==1.10.1
|
||||
six==1.16.0
|
||||
snowballstemmer==2.2.0
|
||||
Sphinx==6.1.3
|
||||
sphinxcontrib-applehelp==1.0.4
|
||||
sphinxcontrib-devhelp==1.0.2
|
||||
sphinxcontrib-htmlhelp==2.0.1
|
||||
sphinxcontrib-jsmath==1.0.1
|
||||
sphinxcontrib-qthelp==1.0.3
|
||||
sphinxcontrib-serializinghtml==1.1.5
|
||||
SQLAlchemy==2.0.5.post1
|
||||
traitlets==5.9.0
|
||||
typing_extensions==4.5.0
|
||||
urllib3==1.26.14
|
||||
wincertstore==0.2
|
||||
yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work
|
||||
zipp==3.15.0
|
||||
mesa==2.1.5
|
|
@ -0,0 +1,9 @@
|
|||
agentpy==0.1.5
|
||||
matplotlib==3.3.4
|
||||
matplotlib-inline==0.1.6
|
||||
networkx==2.5
|
||||
numpy==1.20.3
|
||||
numpydoc==1.1.0
|
||||
pandas==1.4.1
|
||||
pandas-stubs==1.2.0.39
|
||||
pygraphviz==1.9
|
Loading…
Reference in New Issue