commit 1f643c64e4dc5f8c189134ae0491e40c359357c7 Author: Cricial <2911646453@qq.com> Date: Sat Aug 24 11:20:13 2024 +0800 1 diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..359bb53 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/mesa.iml b/.idea/mesa.iml new file mode 100644 index 0000000..b6973c7 --- /dev/null +++ b/.idea/mesa.iml @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..e557d17 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..e08ffd7 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/SQL_analysis_experiment.sql b/SQL_analysis_experiment.sql new file mode 100644 index 0000000..7894992 --- /dev/null +++ b/SQL_analysis_experiment.sql @@ -0,0 +1,85 @@ +select distinct experiment.idx_scenario, +n_max_trial, prf_size, prf_conn, cap_limit_prob_type, cap_limit_level, diff_new_conn, remove_t, netw_prf_n, +mean_count_firm_prod, mean_count_firm, mean_count_prod, +mean_max_ts_firm_prod, mean_max_ts_firm, mean_max_ts_prod, +mean_n_remove_firm_prod, mean_n_all_prod_remove_firm, mean_end_ts +from iiabmdb.with_exp_experiment as experiment +left join +( +select +idx_scenario, +sum(count_firm_prod) / count(*) as mean_count_firm_prod, # Note to use count(*), to include NULL +sum(count_firm) / count(*) as mean_count_firm, +sum(count_prod) / count(*) as mean_count_prod, +sum(max_ts_firm_prod) / count(*) as mean_max_ts_firm_prod, +sum(max_ts_firm) / count(*) as mean_max_ts_firm, +sum(max_ts_prod) / count(*) as mean_max_ts_prod, +sum(n_remove_firm_prod) / count(*) as mean_n_remove_firm_prod, +sum(n_all_prod_remove_firm) / count(*) as mean_n_all_prod_remove_firm, +sum(end_ts) / count(*) as mean_end_ts +from ( +select sample.id, idx_scenario, +count_firm_prod, count_firm, count_prod, +max_ts_firm_prod, max_ts_firm, max_ts_prod, +n_remove_firm_prod, n_all_prod_remove_firm, end_ts +from iiabmdb.with_exp_sample as sample +# 1 2 3 + 9 +left join iiabmdb.with_exp_experiment as experiment +on sample.e_id = experiment.id +left join (select s_id, +count(distinct id_firm, id_product) as count_firm_prod, +count(distinct id_firm) as count_firm, +count(distinct id_product) as count_prod, +max(ts) as end_ts +from iiabmdb.with_exp_result group by s_id) as s_count +on sample.id = s_count.s_id +# 4 +left join # firm prod +(select s_id, max(ts) as max_ts_firm_prod from +(select s_id, id_firm, id_product, min(ts) as ts +from iiabmdb.with_exp_result +where `status` = "D" +group by s_id, id_firm, id_product) as ts +group by s_id) as s_max_ts_firm_prod +on sample.id = s_max_ts_firm_prod.s_id +# 5 +left join # firm +(select s_id, max(ts) as max_ts_firm from +(select s_id, id_firm, min(ts) as ts +from iiabmdb.with_exp_result +where `status` = "D" +group by s_id, id_firm) as ts +group by s_id) as s_max_ts_firm +on sample.id = s_max_ts_firm.s_id +# 6 +left join # prod +(select s_id, max(ts) as max_ts_prod from +(select s_id, id_product, min(ts) as ts +from iiabmdb.with_exp_result +where `status` = "D" +group by s_id, id_product) as ts +group by s_id) as s_max_ts_prod +on sample.id = s_max_ts_prod.s_id +# 7 +left join +(select s_id, count(distinct id_firm, id_product) as n_remove_firm_prod +from iiabmdb.with_exp_result +where `status` = "R" +group by s_id) as s_n_remove_firm_prod +on sample.id = s_n_remove_firm_prod.s_id +# 8 +left join +(select s_id, count(distinct id_firm) as n_all_prod_remove_firm from +(select s_id, id_firm, count(distinct id_product) as n_remove_prod +from iiabmdb.with_exp_result +where `status` = "R" +group by s_id, id_firm) as s_n_remove_prod +left join iiabmdb_basic_info.firm_n_prod as firm_n_prod +on s_n_remove_prod.id_firm = firm_n_prod.code +where n_remove_prod = n_prod +group by s_id) as s_n_all_prod_remove_firm +on sample.id = s_n_all_prod_remove_firm.s_id +) as secnario_count +group by idx_scenario +) as secnario_mean +on experiment.idx_scenario = secnario_mean.idx_scenario; \ No newline at end of file diff --git a/SQL_analysis_risk.sql b/SQL_analysis_risk.sql new file mode 100644 index 0000000..621186e --- /dev/null +++ b/SQL_analysis_risk.sql @@ -0,0 +1,12 @@ +select * from +(select s_id, id_firm, id_product, min(ts) as ts from iiabmdb.without_exp_result +where `status` = 'D' +group by s_id, id_firm, id_product) as s_disrupt +where s_id in +(select s_id from +(select s_id, id_firm, id_product, min(ts) as ts from iiabmdb.without_exp_result +where `status` = 'D' +group by s_id, id_firm, id_product) as t +group by s_id +having count(*) > 1) +order by s_id; \ No newline at end of file diff --git a/SQL_db_user_create.sql b/SQL_db_user_create.sql new file mode 100644 index 0000000..eaef3bd --- /dev/null +++ b/SQL_db_user_create.sql @@ -0,0 +1,6 @@ +CREATE USER 'iiabm_user'@'localhost' IDENTIFIED WITH authentication_plugin BY 'iiabm_pwd'; + +-- CREATE USER 'iiabm_user'@'localhost' IDENTIFIED BY 'iiabm_pwd'; + +GRANT ALL PRIVILEGES ON iiabmdb.* TO 'iiabm_user'@'localhost'; +FLUSH PRIVILEGES; \ No newline at end of file diff --git a/SQL_export_high_risk_setting.sql b/SQL_export_high_risk_setting.sql new file mode 100644 index 0000000..cfa3a17 --- /dev/null +++ b/SQL_export_high_risk_setting.sql @@ -0,0 +1,15 @@ +select e_id, n_disrupt_sample, total_n_disrupt_firm_prod_experiment, dct_lst_init_disrupt_firm_prod from iiabmdb.without_exp_experiment as experiment +inner join ( +select e_id, count(id) as n_disrupt_sample, sum(n_disrupt_firm_prod_sample) as total_n_disrupt_firm_prod_experiment from iiabmdb.without_exp_sample as sample +inner join ( +select * from +(select s_id, COUNT(DISTINCT id_firm, id_product) as n_disrupt_firm_prod_sample from iiabmdb.without_exp_result group by s_id +) as count_disrupt_firm_prod_sample +where n_disrupt_firm_prod_sample > 1 +) as disrupt_sample +on sample.id = disrupt_sample.s_id +group by e_id +) as disrupt_experiment +on experiment.id = disrupt_experiment.e_id +order by n_disrupt_sample desc, total_n_disrupt_firm_prod_experiment desc +limit 0, 95; \ No newline at end of file diff --git a/SQL_migrate_db.sql b/SQL_migrate_db.sql new file mode 100644 index 0000000..e5c0555 --- /dev/null +++ b/SQL_migrate_db.sql @@ -0,0 +1,13 @@ +CREATE DATABASE iiabmdb20230829; +RENAME TABLE iiabmdb.not_test_experiment TO iiabmdb20230829.not_test_experiment, +iiabmdb.not_test_result TO iiabmdb20230829.not_test_result, +iiabmdb.not_test_sample TO iiabmdb20230829.not_test_sample, +iiabmdb.test_experiment TO iiabmdb20230829.test_experiment, +iiabmdb.test_result TO iiabmdb20230829.test_result, +iiabmdb.test_sample TO iiabmdb20230829.test_sample; +RENAME TABLE iiabmdb.with_exp_experiment TO iiabmdb20230829.with_exp_experiment, +iiabmdb.with_exp_result TO iiabmdb20230829.with_exp_result, +iiabmdb.with_exp_sample TO iiabmdb20230829.with_exp_sample, +iiabmdb.without_exp_experiment TO iiabmdb20230829.without_exp_experiment, +iiabmdb.without_exp_result TO iiabmdb20230829.without_exp_result, +iiabmdb.without_exp_sample TO iiabmdb20230829.without_exp_sample; \ No newline at end of file diff --git a/computation.py b/computation.py new file mode 100644 index 0000000..8712c49 --- /dev/null +++ b/computation.py @@ -0,0 +1,40 @@ +import os +import datetime + +from model import Model + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from controller_db import ControllerDB + + +class Computation: + + def __init__(self, c_db: 'ControllerDB'): + # 控制不同进程 计算不同的样本 但使用同一个 数据库 c_db + self.c_db = c_db + self.pid = os.getpid() + + def run(self, str_code='0', s_id=None): + sample_random = self.c_db.fetch_a_sample(s_id) + if sample_random is None: + return True + + # lock this row by update is_done_flag to 0 将运行后的样本设置为 flag 0 + self.c_db.lock_the_sample(sample_random) + print( + f"Pid {self.pid} ({str_code}) is running " + f"sample {sample_random.id} at {datetime.datetime.now()}") + # 将sample 对应的 experiment 的一系列值 和 参数值 传入 模型 中 包括列名 和 值 + dct_exp = {column: getattr(sample_random.experiment, column) + for column in sample_random.experiment.__table__.c.keys()} + # 删除不需要的 主键 + del dct_exp['id'] + + dct_sample_para = {'sample': sample_random, + 'seed': sample_random.seed, + **dct_exp} + model = Model(dct_sample_para) + + model.run(display=False) + return False diff --git a/conf_db.yaml b/conf_db.yaml new file mode 100644 index 0000000..74e37e2 --- /dev/null +++ b/conf_db.yaml @@ -0,0 +1,10 @@ +# read by orm +is_local_db: True + +local: + user_name: iiabm_user + password: iiabm_pwd + db_name: iiabmdb + address: 'localhost' + port: 3306 + diff --git a/conf_db_prefix.yaml b/conf_db_prefix.yaml new file mode 100644 index 0000000..0e48fd8 --- /dev/null +++ b/conf_db_prefix.yaml @@ -0,0 +1 @@ +db_name_prefix: without_exp diff --git a/conf_experiment.yaml b/conf_experiment.yaml new file mode 100644 index 0000000..77309ea --- /dev/null +++ b/conf_experiment.yaml @@ -0,0 +1,12 @@ +# read by ControllerDB + +# run settings +meta_seed: 2 + +test: # only for test scenarios + n_sample: 1 + n_iter: 100 + +not_test: # normal scenarios + n_sample: 50 + n_iter: 100 diff --git a/controller_db.py b/controller_db.py new file mode 100644 index 0000000..3314ab7 --- /dev/null +++ b/controller_db.py @@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- +from orm import db_session, engine, Base, ins, connection +from orm import Experiment, Sample, Result +from sqlalchemy.exc import OperationalError +from sqlalchemy import text +import yaml +import random +import pandas as pd +import platform +import networkx as nx +import json +import pickle + + +class ControllerDB: + is_with_exp: bool + dct_parameter = None + is_test: bool = None + db_name_prefix: str = None + reset_flag: int + + lst_saved_s_id: list + + def __init__(self, prefix, reset_flag=0): + with open('conf_experiment.yaml') as yaml_file: + dct_conf_experiment = yaml.full_load(yaml_file) + assert prefix in ['test', 'without_exp', 'with_exp'], "db name not in test, without_exp, with_exp" + + self.is_test = prefix == 'test' + self.is_with_exp = False if prefix == 'test' or prefix == 'without_exp' else True + self.db_name_prefix = prefix + dct_para_in_test = dct_conf_experiment['test'] if self.is_test else dct_conf_experiment['not_test'] + self.dct_parameter = {'meta_seed': dct_conf_experiment['meta_seed'] , **dct_para_in_test} + + print(self.dct_parameter) + # 0, not reset; 1, reset self; 2, reset all + self.reset_flag = reset_flag + self.is_exist = False + self.lst_saved_s_id = [] + + + def init_tables(self): + self.fill_experiment_table() + self.fill_sample_table() + + def fill_experiment_table(self): + Firm = pd.read_csv("input_data/Firm_amended.csv") + Firm['Code'] = Firm['Code'].astype('string') + Firm.fillna(0, inplace=True) + + # fill dct_lst_init_disrupt_firm_prod + # 存储 公司-在供应链结点的位置.. 0 :‘1.1’ + list_dct = [] #存储 公司编码code 和对应的产业链 结点 + if self.is_with_exp: + #对于方差分析时候使用 + with open('SQL_export_high_risk_setting.sql', 'r') as f: + str_sql = text(f.read()) + result = pd.read_sql(sql=str_sql, con=connection) + result['dct_lst_init_disrupt_firm_prod'] = \ + result['dct_lst_init_disrupt_firm_prod'].apply( + lambda x: pickle.loads(x)) + list_dct = result['dct_lst_init_disrupt_firm_prod'].to_list() + else: + # 行索引 (index):这一行在数据帧中的索引值。 + # 行数据 (row):这一行的数据,是一个 pandas.Series 对象,包含该行的所有列和值。 + for _, row in Firm.iterrows(): + code = row['Code'] + row = row['1':] + for product_code in row.index[row == 1].to_list(): + dct = {code: [product_code]} + list_dct.append(dct) + + # fill g_bom + # 结点属性值 相当于 图上点的 原始 产品名称 + BomNodes = pd.read_csv('input_data/BomNodes.csv', index_col=0) + BomNodes.set_index('Code', inplace=True) + + BomCateNet = pd.read_csv('input_data/BomCateNet.csv', index_col=0) + BomCateNet.fillna(0, inplace=True) + # 创建 可以多边的有向图 同时 转置操作 使得 上游指向下游结点 也就是 1.1.1 - 1.1 类似这种 + g_bom = nx.from_pandas_adjacency(BomCateNet.T, + create_using=nx.MultiDiGraph()) + # 填充每一个结点 的具体内容 通过 相同的 code 并且通过BomNodes.loc[code].to_dict()字典化 格式类似 格式 { code(0) : {level: 0 ,name: 工业互联网 }} + bom_labels_dict = {} + for code in g_bom.nodes: + bom_labels_dict[code] = BomNodes.loc[code].to_dict() + # 分配属性 给每一个结点 获得类似 格式:{1: {'label': 'A', 'value': 10}, + nx.set_node_attributes(g_bom, bom_labels_dict) + # 改为json 格式 + g_product_js = json.dumps(nx.adjacency_data(g_bom)) + + # insert exp + df_xv = pd.read_csv( + "input_data/" + f"xv_{'with_exp' if self.is_with_exp else 'without_exp'}.csv", + index_col=None) + # read the OA table + df_oa = pd.read_csv( + "input_data/" + f"oa_{'with_exp' if self.is_with_exp else 'without_exp'}.csv", + index_col=None) + # .shape[1] 列数 .iloc 访问特定的值 而不是标签 + df_oa = df_oa.iloc[:, 0:df_xv.shape[1]] + # idx_scenario 是 0 指行 idx_init_removal 指 索引 0.. dct_init_removal 键 code 公司 g_product_js 图的json数据 dct_exp_para 解码 全局参数xv- + for idx_scenario, row in df_oa.iterrows(): + dct_exp_para = {} + for idx_col, para_level in enumerate(row): + dct_exp_para[df_xv.columns[idx_col]] = \ + df_xv.iloc[para_level, idx_col] + # different initial removal 只会得到 键 和 值 + for idx_init_removal, dct_init_removal in enumerate(list_dct): + self.add_experiment_1(idx_scenario, + idx_init_removal, + dct_init_removal, + g_product_js, + **dct_exp_para) + print(f"Inserted experiment for scenario {idx_scenario}, " + f"init_removal {idx_init_removal}!") + + def add_experiment_1(self, idx_scenario, idx_init_removal, + dct_lst_init_disrupt_firm_prod, g_bom, + n_max_trial, prf_size, prf_conn, + cap_limit_prob_type, cap_limit_level, + diff_new_conn, remove_t, netw_prf_n): + e = Experiment( + idx_scenario=idx_scenario, + idx_init_removal=idx_init_removal, + n_sample=int(self.dct_parameter['n_sample']), + n_iter=int(self.dct_parameter['n_iter']), + dct_lst_init_disrupt_firm_prod=dct_lst_init_disrupt_firm_prod, + g_bom=g_bom, + n_max_trial=n_max_trial, + prf_size=prf_size, + prf_conn=prf_conn, + cap_limit_prob_type=cap_limit_prob_type, + cap_limit_level=cap_limit_level, + diff_new_conn=diff_new_conn, + remove_t=remove_t, + netw_prf_n=netw_prf_n + ) + db_session.add(e) + db_session.commit() + + def fill_sample_table(self): + rng = random.Random(self.dct_parameter['meta_seed']) + # 根据样本数目 设置 32 位随机整数 + lst_seed = [ + rng.getrandbits(32) + for _ in range(int(self.dct_parameter['n_sample'])) + ] + lst_exp = db_session.query(Experiment).all() + + lst_sample = [] + for experiment in lst_exp: + # idx_sample: 1-50 + for idx_sample in range(int(experiment.n_sample)): + s = Sample(e_id=experiment.id, + idx_sample=idx_sample + 1, + seed=lst_seed[idx_sample], + is_done_flag=-1) + lst_sample.append(s) + db_session.bulk_save_objects(lst_sample) + db_session.commit() + print(f'Inserted {len(lst_sample)} samples!') + + def reset_db(self, force_drop=False): + # first, check if tables exist + lst_table_obj = [ + Base.metadata.tables[str_table] + for str_table in ins.get_table_names() + if str_table.startswith(self.db_name_prefix) + ] + self.is_exist = len(lst_table_obj) > 0 + if force_drop: + self.force_drop_db(lst_table_obj) + # while is_exist: + # a_table = random.choice(lst_table_obj) + # try: + # Base.metadata.drop_all(bind=engine, tables=[a_table]) + # except KeyError: + # pass + # except OperationalError: + # pass + # else: + # lst_table_obj.remove(a_table) + # print( + # f"Table {a_table.name} is dropped " + # f"for exp: {self.db_name_prefix}!!!" + # ) + # finally: + # is_exist = len(lst_table_obj) > 0 + + if self.is_exist: + print( + f"All tables exist. No need to reset " + f"for exp: {self.db_name_prefix}." + ) + # change the is_done_flag from 0 to -1 + # rerun the in-finished tasks + self.is_exist_reset_flag_resset_db() + # if self.reset_flag > 0: + # if self.reset_flag == 2: + # sample = db_session.query(Sample).filter( + # Sample.is_done_flag == 0) + # elif self.reset_flag == 1: + # sample = db_session.query(Sample).filter( + # Sample.is_done_flag == 0, + # Sample.computer_name == platform.node()) + # else: + # raise ValueError('Wrong reset flag') + # if sample.count() > 0: + # for s in sample: + # qry_result = db_session.query(Result).filter_by( + # s_id=s.id) + # if qry_result.count() > 0: + # db_session.query(Result).filter(s_id=s.id).delete() + # db_session.commit() + # s.is_done_flag = -1 + # db_session.commit() + # print(f"Reset the sample id {s.id} flag from 0 to -1") + + else: + # 不存在则重新生成所有的表结构 + Base.metadata.create_all(bind=engine) + self.init_tables() + print( + f"All tables are just created and initialized " + f"for exp: {self.db_name_prefix}." + ) + + def force_drop_db(self, lst_table_obj): + self.is_exist = len(lst_table_obj) > 0 + while self.is_exist: + a_table = random.choice(lst_table_obj) + try: + Base.metadata.drop_all(bind=engine, tables=[a_table]) + except KeyError: + pass + except OperationalError: + pass + else: + lst_table_obj.remove(a_table) + print( + f"Table {a_table.name} is dropped " + f"for exp: {self.db_name_prefix}!!!" + ) + finally: + self.is_exist = len(lst_table_obj) > 0 + + def is_exist_reset_flag_resset_db(self): + if self.reset_flag > 0: + if self.reset_flag == 2: + sample = db_session.query(Sample).filter( + Sample.is_done_flag == 0) + elif self.reset_flag == 1: + sample = db_session.query(Sample).filter( + Sample.is_done_flag == 0, + Sample.computer_name == platform.node()) + else: + raise ValueError('Wrong reset flag') + if sample.count() > 0: + for s in sample: + qry_result = db_session.query(Result).filter_by( + s_id=s.id) + if qry_result.count() > 0: + db_session.query(Result).filter(s_id=s.id).delete() + db_session.commit() + s.is_done_flag = -1 + db_session.commit() + print(f"Reset the sample id {s.id} flag from 0 to -1") + + def prepare_list_sample(self): + # 为了符合前面 重置表里面存在 重置本机 或者重置全部 或者不重置的部分 这个部分的 关于样本运行也得重新拿出来 + # 查找一个风险事件中 50 个样本 + res = db_session.execute( + text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, " + f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id" + )).scalar() + # 控制 n_sample数量 作为后面的参数 + n_sample = 0 if res is None else res + print(f'There are a total of {n_sample} samples.') + # 查找 is_done_flag = -1 也就是没有运行的 样本 运行后会改为0 + res = db_session.execute( + text(f"SELECT id FROM {self.db_name_prefix}_sample " + f"WHERE is_done_flag = -1" + )) + for row in res: + s_id = row[0] + self.lst_saved_s_id.append(s_id) + + @staticmethod + def select_random_sample(lst_s_id): + while 1: + if len(lst_s_id) == 0: + return None + s_id = random.choice(lst_s_id) + lst_s_id.remove(s_id) + res = db_session.query(Sample).filter(Sample.id == int(s_id), + Sample.is_done_flag == -1) + if res.count() == 1: + return res[0] + + def fetch_a_sample(self, s_id=None): + # 由Computation 调用 返回 sample对象 同时给出 2中 指定访问模式 抓取特定的 样本 通过s_id + # 默认访问 flag为-1的 lst_saved_s_id + if s_id is not None: + res = db_session.query(Sample).filter(Sample.id == int(s_id)) + if res.count() == 0: + return None + else: + return res[0] + + sample = self.select_random_sample(self.lst_saved_s_id) + if sample is not None: + return sample + + return None + + @staticmethod + def lock_the_sample(sample: Sample): + sample.is_done_flag, sample.computer_name = 0, platform.node() + db_session.commit() + + +if __name__ == '__main__': + print("Testing the database connection...") + try: + controller_db = ControllerDB('test') + Base.metadata.create_all(bind=engine) + except Exception as e: + print("Failed to connect to the database!") + print(e) + exit(1) diff --git a/firm.py b/firm.py new file mode 100644 index 0000000..241aa3c --- /dev/null +++ b/firm.py @@ -0,0 +1,199 @@ +from mesa import Agent + + +class FirmAgent(Agent): + def __init__(self, unique_id, model, code, type_region, revenue_log, a_lst_product): + # 调用超类的 __init__ 方法 + super().__init__(unique_id, model) + + # 初始化模型中的网络引用 + self.firm_network = self.model.firm_network + self.product_network = self.model.product_network + + # 初始化代理自身的属性 + self.code = code + self.type_region = type_region + + self.size_stat = [] + self.dct_prod_up_prod_stat = {} + self.dct_prod_capacity = {} + + # 试验中的参数 + self.dct_n_trial_up_prod_disrupted = {} + self.dct_cand_alt_supp_up_prod_disrupted = {} + self.dct_request_prod_from_firm = {} + + # 外部变量 + self.is_prf_size = self.model.is_prf_size + self.is_prf_conn = bool(self.model.prf_conn) + self.str_cap_limit_prob_type = str(self.model.cap_limit_prob_type) + self.flt_cap_limit_level = float(self.model.cap_limit_level) + self.flt_diff_new_conn = float(self.model.diff_new_conn) + + # 初始化 size_stat + self.size_stat.append((revenue_log, 0)) + + # 初始化 dct_prod_up_prod_stat + for prod in a_lst_product: + self.dct_prod_up_prod_stat[prod] = { + 'p_stat': [('N', 0)], + 's_stat': {up_prod: {'stat': True, 'set_disrupt_firm': set()} + for up_prod in prod.a_predecessors()} + } + + # 初始化额外容量 (dct_prod_capacity) + for product in a_lst_product: + assert self.str_cap_limit_prob_type in ['uniform', 'normal'], \ + "cap_limit_prob_type must be either 'uniform' or 'normal'" + extra_cap_mean = self.size_stat[0][0] / self.flt_cap_limit_level + if self.str_cap_limit_prob_type == 'uniform': + extra_cap = self.model.random.randint(extra_cap_mean - 2, extra_cap_mean + 2) + elif self.str_cap_limit_prob_type == 'normal': + extra_cap = self.model.random.normalvariate(extra_cap_mean, 1) + extra_cap = max(0, round(extra_cap)) + self.dct_prod_capacity[product] = extra_cap + + def remove_edge_to_cus(self, disrupted_prod): + lst_out_edge = list( + self.firm_network.get_neighbors(self.unique_id)) + for n2 in lst_out_edge: + edge_data = self.firm_network.G.edges[self.unique_id, n2] + if edge_data.get('Product') == disrupted_prod.code: + customer = self.model.schedule.get_agent(n2) + for prod in customer.dct_prod_up_prod_stat.keys(): + if disrupted_prod in customer.dct_prod_up_prod_stat[prod]['s_stat'].keys(): + customer.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_prod][ + 'set_disrupt_firm'].add(self) + self.firm_network.remove_edge(self.unique_id, n2) + + def disrupt_cus_prod(self, prod, disrupted_up_prod): + num_lost = len(self.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_up_prod]['set_disrupt_firm']) + num_remain = len([ + u for u in self.firm_network.get_neighbors(self.unique_id) + if self.firm_network.G.edges[u, self.unique_id].get('Product') == disrupted_up_prod.code]) + lost_percent = num_lost / (num_lost + num_remain) + lst_size = [firm.size_stat[-1][0] for firm in self.model.schedule.agents] + std_size = (self.size_stat[-1][0] - min(lst_size) + 1) / (max(lst_size) - min(lst_size) + 1) + + prob_disrupt = 1 - std_size * (1 - lost_percent) + if self.random.choice([True, False], p=[prob_disrupt, 1 - prob_disrupt]): + self.dct_n_trial_up_prod_disrupted[disrupted_up_prod] = 0 + self.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_up_prod]['stat'] = False + status, _ = self.dct_prod_up_prod_stat[prod]['p_stat'][-1] + if status != 'D': + self.dct_prod_up_prod_stat[prod]['p_stat'].append(('D', self.model.schedule.time)) + + def seek_alt_supply(self, product): + if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial: + if self.dct_n_trial_up_prod_disrupted[product] == 0: + self.dct_cand_alt_supp_up_prod_disrupted[product] = [ + firm for firm in self.model.schedule.agents + if firm.is_prod_in_current_normal(product)] + if self.dct_cand_alt_supp_up_prod_disrupted[product]: + lst_firm_connect = [] + if self.is_prf_conn: + for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]: + if self.firm_network.G.has_edge(self.unique_id, firm.unique_id) or \ + self.firm_network.G.has_edge(firm.unique_id, self.unique_id): + lst_firm_connect.append(firm) + if len(lst_firm_connect) == 0: + if self.is_prf_size: + lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]] + lst_prob = [size / sum(lst_size) for size in lst_size] + select_alt_supply = \ + self.random.choices(self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob)[0] + else: + select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product]) + elif len(lst_firm_connect) > 0: + if self.is_prf_size: + lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect] + lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size] + select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0] + else: + select_alt_supply = self.random.choice(lst_firm_connect) + + assert select_alt_supply.is_prod_in_current_normal(product) + + if product in select_alt_supply.dct_request_prod_from_firm: + select_alt_supply.dct_request_prod_from_firm[product].append(self) + else: + select_alt_supply.dct_request_prod_from_firm[product] = [self] + + self.dct_n_trial_up_prod_disrupted[product] += 1 + + def handle_request(self): + for product, lst_firm in self.dct_request_prod_from_firm.items(): + if self.dct_prod_capacity[product] > 0: + if len(lst_firm) == 0: + continue + elif len(lst_firm) == 1: + self.accept_request(lst_firm[0], product) + elif len(lst_firm) > 1: + lst_firm_connect = [] + if self.is_prf_conn: + for firm in lst_firm: + if self.firm_network.G.has_edge(self.unique_id, firm.unique_id) or \ + self.firm_network.G.has_edge(firm.unique_id, self.unique_id): + lst_firm_connect.append(firm) + if len(lst_firm_connect) == 0: + if self.is_prf_size: + lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm] + lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size] + select_customer = self.random.choices(lst_firm, weights=lst_prob)[0] + else: + select_customer = self.random.choice(lst_firm) + self.accept_request(select_customer, product) + elif len(lst_firm_connect) > 0: + if self.is_prf_size: + lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect] + lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size] + select_customer = self.random.choices(lst_firm_connect, weights=lst_prob)[0] + else: + select_customer = self.random.choice(lst_firm_connect) + self.accept_request(select_customer, product) + else: + for down_firm in lst_firm: + down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self) + + def accept_request(self, down_firm, product): + if self.firm_network.G.has_edge(self.unique_id, down_firm.unique_id) or \ + self.firm_network.G.has_edge(down_firm.unique_id, self.unique_id): + prod_accept = 1.0 + else: + prod_accept = self.flt_diff_new_conn + if self.random.choice([True, False], p=[prod_accept, 1 - prod_accept]): + self.firm_network.G.add_edge(self.unique_id, down_firm.unique_id, Product=product.code) + self.dct_prod_capacity[product] -= 1 + self.dct_request_prod_from_firm[product].remove(down_firm) + + for prod in down_firm.dct_prod_up_prod_stat.keys(): + if product in down_firm.dct_prod_up_prod_stat[prod]['s_stat']: + down_firm.dct_prod_up_prod_stat[prod]['s_stat'][product]['stat'] = True + down_firm.dct_prod_up_prod_stat[prod]['p_stat'].append( + ('N', self.model.schedule.time)) + del down_firm.dct_n_trial_up_prod_disrupted[product] + del down_firm.dct_cand_alt_supp_up_prod_disrupted[product] + else: + down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self) + + def clean_before_trial(self): + self.dct_request_prod_from_firm = {} + + def clean_before_time_step(self): + # Reset the number of trials and candidate suppliers for disrupted products + self.dct_n_trial_up_prod_disrupted = dict.fromkeys(self.dct_n_trial_up_prod_disrupted.keys(), 0) + self.dct_cand_alt_supp_up_prod_disrupted = {} + + # Update the status of products and refresh disruption sets + for prod in self.dct_prod_up_prod_stat.keys(): + status, ts = self.dct_prod_up_prod_stat[prod]['p_stat'][-1] + if ts != self.model.schedule.time: + self.dct_prod_up_prod_stat[prod]['p_stat'].append((status, self.model.schedule.time)) + + # Refresh the set of disrupted firms + for up_prod in self.dct_prod_up_prod_stat[prod]['s_stat'].keys(): + self.dct_prod_up_prod_stat[prod]['s_stat'][up_prod]['set_disrupt_firm'] = set() + + def step(self): + # 在每个时间步进行的操作 + pass diff --git a/main.py b/main.py new file mode 100644 index 0000000..70c84d1 --- /dev/null +++ b/main.py @@ -0,0 +1,67 @@ +import os +import random +import time +from multiprocessing import Process +import argparse +from computation import Computation +from sqlalchemy.orm import close_all_sessions + +import yaml + +from controller_db import ControllerDB + + +def controll_db_and_process(exp_argument, reset_sample_argument, reset_db_argument): + from controller_db import ControllerDB + controller_db = ControllerDB(exp_argument, reset_flag=reset_sample_argument) + # controller_db.reset_db() + # force drop + controller_db.reset_db(force_drop=reset_db_argument) + # 准备样本表 + controller_db.prepare_list_sample() + + close_all_sessions() + # 调用 do_process 利用计算机进行多核处理 仿真 将数据库中 + do_process(do_computation, controller_db) + + +def do_process(target: object, controller_db: ControllerDB, ): + process_list = [] + for i in range(int(args.job)): + p = Process(target=do_computation, args=(controller_db,)) + p.start() + process_list.append(p) + + for i in process_list: + i.join() + + +def do_computation(c_db): + exp = Computation(c_db) + + while 1: + time.sleep(random.uniform(0, 5)) + is_all_done = exp.run() + if is_all_done: + break + + +if __name__ == '__main__': + # 输入参数 + parser = argparse.ArgumentParser(description='setting') + parser.add_argument('--exp', type=str, default='test') + parser.add_argument('--job', type=int, default='3') + parser.add_argument('--reset_sample', type=int, default='0') + parser.add_argument('--reset_db', type=bool, default=False) + + args = parser.parse_args() + # 几核参与进程 + assert args.job >= 1, 'Number of jobs should >= 1' + # 控制参数 利用 prefix_file_name 前缀名字 控制 2项不同的实验 + prefix_file_name = 'conf_db_prefix.yaml' + if os.path.exists(prefix_file_name): + os.remove(prefix_file_name) + with open(prefix_file_name, 'w', encoding='utf-8') as file: + yaml.dump({'db_name_prefix': args.exp}, file) + # 数据库连接控制 和 进行模型运行 + controll_db_and_process(args.exp, args.reset_sample, args.reset_db) diff --git a/model.py b/model.py new file mode 100644 index 0000000..9756931 --- /dev/null +++ b/model.py @@ -0,0 +1,133 @@ +import json + +import networkx as nx +import pandas as pd +from mesa import Model +from mesa.time import RandomActivation +from mesa.space import MultiGrid +from mesa.datacollection import DataCollector + +from firm import FirmAgent +from product import ProductAgent + + +class MyModel(Model): + def __init__(self, params): + # self.num_agents = params['N'] + # self.grid = MultiGrid(params['width'], params['height'], True) + # self.schedule = RandomActivation(self) + + # Initialize parameters from `params` + self.sample = params['sample'] + self.int_stop_ts = 0 + self.int_n_iter = int(params['n_iter']) + self.dct_lst_init_disrupt_firm_prod = params['dct_lst_init_disrupt_firm_prod'] + # external variable + self.int_n_max_trial = int(params['n_max_trial']) + self.is_prf_size = bool(params['prf_size']) + + self.remove_t = int(params['remove_t']) + self.int_netw_prf_n = int(params['netw_prf_n']) + + self.product_network = None + self.firm_network = None + self.firm_prod_network = None + + # Initialize product network + G_bom = nx.adjacency_graph(json.loads(params['g_bom'])) + self.product_network = G_bom + + # Initialize firm network + self.initialize_firm_network() + + # Initialize firm product network + self.initialize_firm_prod_network() + + # Initialize agents + self.initialize_agents() + + # Data collector (if needed) + self.datacollector = DataCollector( + agent_reporters={"Product Code": "code"} + ) + + def initialize_firm_network(self): + # Read firm data and initialize firm network + firm = pd.read_csv("input_data/Firm_amended.csv") + firm['Code'] = firm['Code'].astype('string') + firm.fillna(0, inplace=True) + Firm_attr = firm[["Code", "Type_Region", "Revenue_Log"]] + firm_product = [] + for _, row in firm.loc[:, '1':].iterrows(): + firm_product.append(row[row == 1].index.to_list()) + Firm_attr['Product_Code'] = firm_product + Firm_attr.set_index('Code', inplace=True) + G_Firm = nx.MultiDiGraph() + G_Firm.add_nodes_from(firm["Code"]) + + # Add node attributes + firm_labels_dict = {} + for code in G_Firm.nodes: + firm_labels_dict[code] = Firm_attr.loc[code].to_dict() + nx.set_node_attributes(G_Firm, firm_labels_dict) + + # Add edges based on BOM graph + self.add_edges_based_on_bom(G_Firm) + + self.firm_network = G_Firm + + def initialize_firm_prod_network(self): + # Read firm product data and initialize firm product network + firm_prod = pd.read_csv("input_data/Firm_amended.csv") + firm_prod.fillna(0, inplace=True) + firm_prod = pd.DataFrame({'bool': firm_prod.loc[:, '1':].stack()}) + firm_prod = firm_prod[firm_prod['bool'] == 1].reset_index() + firm_prod.drop('bool', axis=1, inplace=True) + firm_prod.rename({'level_0': 'Firm_Code', 'level_1': 'Product_Code'}, axis=1, inplace=True) + firm_prod['Firm_Code'] = firm_prod['Firm_Code'].astype('string') + + G_FirmProd = nx.MultiDiGraph() + G_FirmProd.add_nodes_from(firm_prod.index) + + # Add node attributes + firm_prod_labels_dict = {} + for code in firm_prod.index: + firm_prod_labels_dict[code] = firm_prod.loc[code].to_dict() + nx.set_node_attributes(G_FirmProd, firm_prod_labels_dict) + + self.firm_prod_network = G_FirmProd + + def add_edges_based_on_bom(self, G_Firm): + # Logic to add edges to the G_Firm graph based on BOM + pass + + def initialize_agents(self): + # Initialize product and firm agents + for node, attr in self.product_network.nodes(data=True): + product = ProductAgent(node, self, code=node, name=attr['Name']) + self.schedule.add(product) + + for node, attr in self.firm_network.nodes(data=True): + firm_agent = FirmAgent( + node, + self, + code=node, + type_region=attr['Type_Region'], + revenue_log=attr['Revenue_Log'], + a_lst_product=[] # Populate based on products + ) + self.schedule.add(firm_agent) + + # Initialize disruptions + self.initialize_disruptions() + + def initialize_disruptions(self): + # Set the initial firm product disruptions + for firm, products in self.dct_lst_init_disrupt_firm_prod.items(): + for product in products: + if isinstance(firm, FirmAgent): + firm.dct_prod_up_prod_stat[product]['p_stat'].append(('D', self.schedule.steps)) + + def step(self): + self.schedule.step() + self.datacollector.collect(self) diff --git a/orm.py b/orm.py new file mode 100644 index 0000000..b96d88f --- /dev/null +++ b/orm.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +from sqlalchemy import create_engine, inspect, Inspector +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import (Column, Integer, DECIMAL, String, ForeignKey, + BigInteger, DateTime, PickleType, Boolean, Text) +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship, Session +from sqlalchemy.pool import NullPool +import yaml + +with open('conf_db.yaml') as file: + dct_conf_db_all = yaml.full_load(file) + is_local_db = dct_conf_db_all['is_local_db'] + if is_local_db: + dct_conf_db = dct_conf_db_all['local'] + else: + dct_conf_db = dct_conf_db_all['remote'] + +with open('conf_db_prefix.yaml') as file: + dct_conf_db_prefix = yaml.full_load(file) + db_name_prefix = dct_conf_db_prefix['db_name_prefix'] + +str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'], + dct_conf_db['password'], + dct_conf_db['address'], + dct_conf_db['port'], + dct_conf_db['db_name']) +# print('DB is {}:{}/{}'.format(dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name'])) + +# must be null pool to avoid connection lost error +engine = create_engine(str_login, poolclass=NullPool) +connection = engine.connect() +ins: Inspector = inspect(engine) + +Base = declarative_base() + +db_session = Session(bind=engine) + + +class Experiment(Base): + __tablename__ = f"{db_name_prefix}_experiment" + id = Column(Integer, primary_key=True, autoincrement=True) + + idx_scenario = Column(Integer, nullable=False) + idx_init_removal = Column(Integer, nullable=False) + + # fixed parameters + n_sample = Column(Integer, nullable=False) + n_iter = Column(Integer, nullable=False) + + # variables + dct_lst_init_disrupt_firm_prod = Column(PickleType, nullable=False) + g_bom = Column(Text(4294000000), nullable=False) + + n_max_trial = Column(Integer, nullable=False) + prf_size = Column(Boolean, nullable=False) + prf_conn = Column(Boolean, nullable=False) + cap_limit_prob_type = Column(String(16), nullable=False) + cap_limit_level = Column(DECIMAL(8, 4), nullable=False) + diff_new_conn = Column(DECIMAL(8, 4), nullable=False) + remove_t = Column(Integer, nullable=False) + netw_prf_n = Column(Integer, nullable=False) + + sample = relationship( + 'Sample', back_populates='experiment', lazy='dynamic') + + def __repr__(self): + return f'' + + +class Sample(Base): + __tablename__ = f"{db_name_prefix}_sample" + id = Column(Integer, primary_key=True, autoincrement=True) + e_id = Column(Integer, ForeignKey('{}.id'.format( + f"{db_name_prefix}_experiment")), nullable=False) + + idx_sample = Column(Integer, nullable=False) + seed = Column(BigInteger, nullable=False) + # -1, waiting; 0, running; 1, done + is_done_flag = Column(Integer, nullable=False) + computer_name = Column(String(64), nullable=True) + ts_done = Column(DateTime(timezone=True), onupdate=func.now()) + stop_t = Column(Integer, nullable=True) + + g_firm = Column(Text(4294000000), nullable=True) + + experiment = relationship( + 'Experiment', back_populates='sample', uselist=False) + result = relationship('Result', back_populates='sample', lazy='dynamic') + + def __repr__(self): + return f'' + + +class Result(Base): + __tablename__ = f"{db_name_prefix}_result" + id = Column(Integer, primary_key=True, autoincrement=True) + s_id = Column(Integer, ForeignKey('{}.id'.format( + f"{db_name_prefix}_sample")), nullable=False) + + id_firm = Column(String(10), nullable=False) + id_product = Column(String(10), nullable=False) + ts = Column(Integer, nullable=False) + status = Column(String(5), nullable=False) + + sample = relationship('Sample', back_populates='result', uselist=False) + + def __repr__(self): + return f'' + + +if __name__ == '__main__': + Base.metadata.drop_all() + Base.metadata.create_all() diff --git a/product.py b/product.py new file mode 100644 index 0000000..950137d --- /dev/null +++ b/product.py @@ -0,0 +1,24 @@ +from mesa import Agent + +class ProductAgent(Agent): + def __init__(self, unique_id, model, code, name): + # 调用超类的 __init__ 方法 + super().__init__(unique_id, model) + + # 初始化代理属性 + self.code = code + self.name = name + self.product_network = self.model.product_network + + def a_successors(self): + # Find successors of the current product and return them as a list of ProductAgent + successors = list(self.product_network.successors(self)) + return [self.model.schedule.agents[successor] for successor in successors] + + def a_predecessors(self): + # Find predecessors of the current product and return them as a list of ProductAgent + predecessors = list(self.product_network.predecessors(self)) + return [self.model.schedule.agents[predecessor] for predecessor in predecessors] + def step(self): + # 在每个时间步进行的操作 + pass diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..69c77f3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,55 @@ +agentpy==0.1.5 +alabaster==0.7.13 +Babel==2.12.1 +certifi @ file:///C:/b/abs_85o_6fm0se/croot/certifi_1671487778835/work/certifi +charset-normalizer==3.0.1 +colorama==0.4.6 +cycler==0.11.0 +decorator==5.1.1 +dill==0.3.6 +docutils==0.19 +greenlet==2.0.2 +idna==3.4 +imagesize==1.4.1 +importlib-metadata==6.0.0 +Jinja2==3.1.2 +joblib==1.2.0 +kiwisolver==1.4.4 +MarkupSafe==2.1.2 +matplotlib==3.3.4 +matplotlib-inline==0.1.6 +multiprocess==0.70.14 +mysqlclient==2.1.1 +networkx==2.5 +numpy==1.20.3 +numpydoc==1.1.0 +packaging==23.0 +pandas==1.4.1 +pandas-stubs==1.2.0.39 +Pillow==9.4.0 +Pygments==2.14.0 +pygraphviz @ file:///C:/Users/ASUS/Downloads/pygraphviz-1.9-cp38-cp38-win_amd64.whl +pyparsing==3.0.9 +python-dateutil==2.8.2 +pytz==2022.7.1 +PyYAML==6.0 +requests==2.28.2 +SALib==1.4.7 +scipy==1.10.1 +six==1.16.0 +snowballstemmer==2.2.0 +Sphinx==6.1.3 +sphinxcontrib-applehelp==1.0.4 +sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-htmlhelp==2.0.1 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-serializinghtml==1.1.5 +SQLAlchemy==2.0.5.post1 +traitlets==5.9.0 +typing_extensions==4.5.0 +urllib3==1.26.14 +wincertstore==0.2 +yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work +zipp==3.15.0 +mesa==2.1.5 diff --git a/requirements_manual_selected_20230304.txt b/requirements_manual_selected_20230304.txt new file mode 100644 index 0000000..91f2ba8 --- /dev/null +++ b/requirements_manual_selected_20230304.txt @@ -0,0 +1,9 @@ +agentpy==0.1.5 +matplotlib==3.3.4 +matplotlib-inline==0.1.6 +networkx==2.5 +numpy==1.20.3 +numpydoc==1.1.0 +pandas==1.4.1 +pandas-stubs==1.2.0.39 +pygraphviz==1.9