mesa/risk_analysis_sum_result.py

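"""Summarise risk-analysis results: query disruption records from the
iiabmdb database, plot the BOM and firm networks, and export firm/product
and disruption-causing-pair (DCP) counts to output_result/risk/."""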

import pickle
from sqlalchemy import text
from orm import engine, connection
import pandas as pd
import networkx as nx
import json
import matplotlib.pyplot as plt
# Prepare data
Firm = pd.read_csv("input_data/input_firm_data/Firm_amended.csv")
Firm['Code'] = Firm['Code'].astype('string')
Firm.fillna(0, inplace=True)
BomNodes = pd.read_csv('input_data/input_product_data/BomNodes.csv', index_col=0)
# SQL query
with open('SQL_analysis_risk.sql', 'r') as f:
    str_sql = text(f.read())
result = pd.read_sql(sql=str_sql, con=connection)
result.to_csv('output_result/risk/count.csv', index=False, encoding='utf-8-sig')
print(result)
# G_bom
plt.rcParams['font.sans-serif'] = 'SimHei'
exp_id = 1
G_bom_df = pd.read_sql(
    sql=text(f'select g_bom from iiabmdb.without_exp_experiment where id = {exp_id};'),
    con=connection
)
if G_bom_df.empty:
    raise ValueError(f"No g_bom found for exp_id = {exp_id}")
G_bom_str = G_bom_df['g_bom'].tolist()[0]
if G_bom_str is None:
    raise ValueError(f"g_bom data is None for exp_id = {exp_id}")
G_bom = nx.adjacency_graph(json.loads(G_bom_str))
pos = nx.nx_agraph.graphviz_layout(G_bom, prog="twopi", args="")
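# Note: nx.nx_agraph.graphviz_layout requires pygraphviz and a system
# Graphviz install. A hedged fallback sketch if they are unavailable:
# pos = nx.spring_layout(G_bom, seed=42)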
node_labels = nx.get_node_attributes(G_bom, 'Name')
plt.figure(figsize=(12, 12), dpi=300)
nx.draw_networkx_nodes(G_bom, pos)
nx.draw_networkx_edges(G_bom, pos)
nx.draw_networkx_labels(G_bom, pos, labels=node_labels, font_size=3)
plt.savefig(f"output_result/risk/g_bom_exp_id_{exp_id}.png")
plt.close()
# G_firm
plt.rcParams['font.sans-serif'] = 'SimHei'
sample_id = 1
# G_firm_df = pd.read_sql(
#     sql=text(f'select g_firm from iiabmdb.without_exp_sample where id = {sample_id};'),
#     con=connection
# )
#
# if G_firm_df.empty:
#     raise ValueError(f"No g_firm found for sample_id = {sample_id}")
#
# G_firm_str = G_firm_df['g_firm'].tolist()[0]
# if G_firm_str is None:
#     raise ValueError(f"g_firm data is None for sample_id = {sample_id}")
#
# G_firm = nx.adjacency_graph(json.loads(G_firm_str))
with open("firm_network.pkl", 'rb') as f:
G_firm = pickle.load(f)
print(f"Successfully loaded cached data from firm_network.pkl")
pos = nx.nx_agraph.graphviz_layout(G_firm, prog="twopi", args="")
# Label each node with its own id; node sizes come from the Revenue_Log attribute
node_label = nx.get_node_attributes(G_firm, 'Revenue_Log')
node_label = {key: key for key in node_label.keys()}
node_size = list(nx.get_node_attributes(G_firm, 'Revenue_Log').values())
# G_firm is a multigraph: edge keys are (u, v, key); keep (u, v) -> Product
edge_label = nx.get_edge_attributes(G_firm, "Product")
edge_label = {(n1, n2): label for (n1, n2, _), label in edge_label.items()}
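# Hedged sketch: Revenue_Log values are passed straight to node_size; if the
# nodes render too small, rescale with an assumed factor, e.g.:
# node_size = [20 * v for v in node_size]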
plt.figure(figsize=(12, 12), dpi=300)
nx.draw(G_firm, pos, node_size=node_size, labels=node_label, font_size=5)
nx.draw_networkx_edge_labels(G_firm, pos, edge_label, font_size=4)
plt.savefig(f"output_result/risk/g_firm_sample_id_{sample_id}_de.png")
plt.close()
# Count firm product
count_firm_prod = result.value_counts(subset=['id_firm', 'id_product'])
count_firm_prod.name = 'count'
count_firm_prod = count_firm_prod.to_frame().reset_index()
count_firm_prod.to_csv('output_result/risk/count_firm_prod.csv', index=False, encoding='utf-8-sig')
print(count_firm_prod)
# Count firm
count_firm = count_firm_prod.groupby('id_firm')['count'].sum()
count_firm = count_firm.to_frame().reset_index()
count_firm.sort_values('count', inplace=True, ascending=False)
count_firm.to_csv('output_result/risk/count_firm.csv', index=False, encoding='utf-8-sig')
print(count_firm)
# Count product
count_prod = count_firm_prod.groupby('id_product')['count'].sum()
count_prod = count_prod.to_frame().reset_index()
count_prod.sort_values('count', inplace=True, ascending=False)
count_prod.to_csv('output_result/risk/count_prod.csv', index=False, encoding='utf-8-sig')
print(count_prod)
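# Hedged refactor sketch: the two groupby/save blocks above share one pattern
# and could be folded into a loop (file names as in the original):
# for col, fname in [('id_firm', 'count_firm'), ('id_product', 'count_prod')]:
#     c = count_firm_prod.groupby(col)['count'].sum().reset_index()
#     c.sort_values('count', ascending=False, inplace=True)
#     c.to_csv(f'output_result/risk/{fname}.csv', index=False, encoding='utf-8-sig')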
# DCP: disruption causing probability
result_disrupt_ts_above_0 = result[result['ts'] > 0]
print(result_disrupt_ts_above_0)
result_dcp_list = []  # collect rows in a list to avoid growing a DataFrame row by row
for sid, group in result.groupby('s_id'):
    ts_start = max(group['ts'])
    # Pair every earlier timestep (upstream) with every later one (downstream)
    while ts_start >= 1:
        ts_end = ts_start - 1
        while ts_end >= 0:
            up = group.loc[group['ts'] == ts_end, ['id_firm', 'id_product']]
            down = group.loc[group['ts'] == ts_start, ['id_firm', 'id_product']]
            for _, up_row in up.iterrows():
                for _, down_row in down.iterrows():
                    result_dcp_list.append([sid] + up_row.tolist() + down_row.tolist())
            ts_end -= 1
        ts_start -= 1
# Convert the collected rows to a DataFrame
result_dcp = pd.DataFrame(result_dcp_list, columns=[
    's_id', 'up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product'
])
# Count occurrences of each upstream/downstream firm-product pair
count_dcp = result_dcp.value_counts(
    subset=['up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product']
).reset_index(name='count')
# Save to file
count_dcp.to_csv('output_result/risk/count_dcp.csv', index=False, encoding='utf-8-sig')
# Print the result
print(count_dcp)
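# A hedged alternative sketch (not the author's method): the nested loop above
# enumerates all within-scenario pairs with ts_up < ts_down; a pandas
# self-merge builds the same table, assuming group sizes stay modest:
# pairs = result.merge(result, on='s_id', suffixes=('_up', '_down'))
# pairs = pairs[pairs['ts_up'] < pairs['ts_down']]
# result_dcp = pairs.rename(columns={
#     'id_firm_up': 'up_id_firm', 'id_product_up': 'up_id_product',
#     'id_firm_down': 'down_id_firm', 'id_product_down': 'down_id_product',
# })[['s_id', 'up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product']]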