# mesa/risk_analysis_sum_result.py
#
# Post-processing script for risk-analysis simulation results:
#   * runs the summary SQL and dumps raw counts,
#   * renders the BOM graph (per experiment) and firm network (per sample),
#   * aggregates disruption counts per firm / product,
#   * builds and counts DCP (disruption-causing-probability) pairs.

import pickle
from sqlalchemy import text
from orm import engine, connection
import pandas as pd
import networkx as nx
import json
import matplotlib.pyplot as plt
# ---- Prepare input data ----
# Firm master data; 'Code' is forced to pandas string dtype (presumably so
# firm codes keep leading zeros — confirm against the CSV) and missing
# values are zero-filled.
Firm = pd.read_csv("input_data/input_firm_data/Firm_amended.csv")
Firm['Code'] = Firm['Code'].astype('string')
Firm.fillna(0, inplace=True)
BomNodes = pd.read_csv('input_data/input_product_data/BomNodes.csv', index_col=0)

# ---- Run the risk-analysis SQL and dump the raw result ----
with open('SQL_analysis_risk.sql', 'r') as sql_file:
    risk_query = text(sql_file.read())
result = pd.read_sql(sql=risk_query, con=connection)
result.to_csv('output_result/risk/count.csv', index=False, encoding='utf-8-sig')
print(result)
# ---- BOM graph visualization for a single experiment ----
plt.rcParams['font.sans-serif'] = 'SimHei'  # CJK-capable font for node labels
exp_id = 1
G_bom_df = pd.read_sql(
    sql=text(f'select g_bom from iiabmdb.without_exp_experiment where id = {exp_id};'),
    con=connection,
)
if G_bom_df.empty:
    raise ValueError(f"No g_bom found for exp_id = {exp_id}")
G_bom_str = G_bom_df['g_bom'].iloc[0]
if G_bom_str is None:
    raise ValueError(f"g_bom data is None for exp_id = {exp_id}")

# The DB column stores a JSON adjacency dump; rebuild the networkx graph
# and lay it out radially with graphviz.
G_bom = nx.adjacency_graph(json.loads(G_bom_str))
bom_layout = nx.nx_agraph.graphviz_layout(G_bom, prog="twopi", args="")
bom_labels = nx.get_node_attributes(G_bom, 'Name')

plt.figure(figsize=(12, 12), dpi=300)
nx.draw_networkx_nodes(G_bom, bom_layout)
nx.draw_networkx_edges(G_bom, bom_layout)
nx.draw_networkx_labels(G_bom, bom_layout, labels=bom_labels, font_size=3)
plt.savefig(f"output_result/risk/g_bom_exp_id_{exp_id}.png")
plt.close()
# ---- Firm network visualization for a single sample ----
plt.rcParams['font.sans-serif'] = 'SimHei'  # CJK-capable font for node labels
sample_id = 1

# The firm graph is read from a local pickle cache instead of the DB.
# NOTE(review): pickle.load is only safe on trusted, locally produced files.
with open("firm_network.pkl", 'rb') as cache_file:
    G_firm = pickle.load(cache_file)
print(f"Successfully loaded cached data from firm_network.pkl")

# Drop degree-0 nodes so the layout is not cluttered by isolates.
G_firm.remove_nodes_from(list(nx.isolates(G_firm)))

# Node size scales with the 'Revenue_Log' attribute; edges carry a
# 'Product' attribute keyed (u, v, k), so the multigraph key is dropped
# when building the edge-label map.
pos = nx.nx_agraph.graphviz_layout(G_firm, prog="twopi", args="")
revenue_log = nx.get_node_attributes(G_firm, 'Revenue_Log')
node_label = {node: node for node in revenue_log}
node_size = [10 * rev for rev in revenue_log.values()]
edge_label = {
    (u, v): product
    for (u, v, _), product in nx.get_edge_attributes(G_firm, "Product").items()
}

plt.figure(figsize=(12, 12), dpi=500)
nx.draw(G_firm, pos, node_size=node_size, labels=node_label, font_size=5, width=0.5)
nx.draw_networkx_edge_labels(G_firm, pos, edge_label, font_size=2)
plt.axis('equal')  # lock the aspect ratio so the figure content is square
plt.savefig(f"output_result/risk/g_firm_sample_id_{sample_id}_de.png", bbox_inches='tight', pad_inches=0.1)
plt.close()
# ---- Aggregate disruption counts ----
# How often each (firm, product) pair appears in the risk result set.
count_firm_prod = (
    result.value_counts(subset=['id_firm', 'id_product'])
    .rename('count')
    .reset_index()
)
count_firm_prod.to_csv('output_result/risk/count_firm_prod.csv', index=False, encoding='utf-8-sig')
print(count_firm_prod)

# Totals per firm, most-affected first.
count_firm = (
    count_firm_prod.groupby('id_firm')['count']
    .sum()
    .reset_index()
    .sort_values('count', ascending=False)
)
count_firm.to_csv('output_result/risk/count_firm.csv', index=False, encoding='utf-8-sig')
print(count_firm)

# Totals per product, most-affected first.
count_prod = (
    count_firm_prod.groupby('id_product')['count']
    .sum()
    .reset_index()
    .sort_values('count', ascending=False)
)
count_prod.to_csv('output_result/risk/count_prod.csv', index=False, encoding='utf-8-sig')
print(count_prod)
# ---- DCP: disruption-causing-probability pairs ----
# For every simulation run (s_id), pair each disrupted (firm, product) at an
# earlier time step ("up") with each disrupted (firm, product) at a strictly
# later step ("down"), then count how often each directed pair occurs.
result_disrupt_ts_above_0 = result[result['ts'] > 0]
print(result_disrupt_ts_above_0)  # diagnostic output only; not used below

# Collect rows in a plain list — appending to a DataFrame row-by-row grows
# quadratically. (The previous pre-allocated empty DataFrame was dead code:
# it was unconditionally overwritten after the loop.)
result_dcp_list = []
for sid, group in result.groupby('s_id'):
    # Index the group's (firm, product) rows by time step ONCE, instead of
    # re-scanning the whole group with .loc for every (ts_start, ts_end)
    # pair; groupby preserves the original row order within each ts.
    rows_by_ts = {
        ts: frame[['id_firm', 'id_product']].values.tolist()
        for ts, frame in group.groupby('ts')
    }
    # Walk ts_start from the latest step down to 1, pairing it with every
    # strictly earlier step (including steps with no rows, which contribute
    # nothing) — same iteration order as the original loops.
    ts_start = max(group['ts'])
    while ts_start >= 1:
        down_rows = rows_by_ts.get(ts_start, [])
        ts_end = ts_start - 1
        while ts_end >= 0:
            for up_row in rows_by_ts.get(ts_end, []):
                for down_row in down_rows:
                    result_dcp_list.append([sid] + up_row + down_row)
            ts_end -= 1
        ts_start -= 1

result_dcp = pd.DataFrame(result_dcp_list, columns=[
    's_id', 'up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product'
])

# Frequency of each directed (up firm/product -> down firm/product) pair.
count_dcp = result_dcp.value_counts(
    subset=['up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product']
).reset_index(name='count')
count_dcp.to_csv('output_result/risk/count_dcp.csv', index=False, encoding='utf-8-sig')
print(count_dcp)