148 lines
5.2 KiB
Python
148 lines
5.2 KiB
Python
import pickle
|
||
|
||
from sqlalchemy import text
|
||
from orm import engine, connection
|
||
import pandas as pd
|
||
import networkx as nx
|
||
import json
|
||
import matplotlib.pyplot as plt
|
||
|
||
# Prepare data
# Bill-of-materials node table, indexed by its first CSV column.
BomNodes = pd.read_csv('input_data/input_product_data/BomNodes.csv', index_col=0)

# Firm attribute table: firm codes kept as pandas string dtype, missing
# values zero-filled.
Firm = (
    pd.read_csv("input_data/input_firm_data/Firm_amended.csv")
    .astype({'Code': 'string'})
    .fillna(0)
)
|
||
|
||
# SQL query
# Load the risk-analysis query from disk, run it on the shared DB
# connection, and keep the raw rows both on disk and in `result` for the
# aggregations below.
with open('SQL_analysis_risk.sql', 'r') as sql_file:
    query_body = sql_file.read()
str_sql = text(query_body)

result = pd.read_sql(sql=str_sql, con=connection)
result.to_csv('output_result/risk/count.csv', encoding='utf-8-sig', index=False)
print(result)
|
||
|
||
# G_bom
# Render the BOM graph stored (as adjacency JSON) for one experiment row.
plt.rcParams['font.sans-serif'] = 'SimHei'  # font with CJK glyph support

exp_id = 1
G_bom_df = pd.read_sql(
    sql=text(f'select g_bom from iiabmdb.without_exp_experiment where id = {exp_id};'),
    con=connection
)

if G_bom_df.empty:
    raise ValueError(f"No g_bom found for exp_id = {exp_id}")

G_bom_str = G_bom_df['g_bom'].iloc[0]
if G_bom_str is None:
    raise ValueError(f"g_bom data is None for exp_id = {exp_id}")

# Rebuild the graph from its JSON adjacency representation and lay it out
# radially (graphviz 'twopi'; requires pygraphviz).
G_bom = nx.adjacency_graph(json.loads(G_bom_str))
pos = nx.nx_agraph.graphviz_layout(G_bom, prog="twopi", args="")

plt.figure(figsize=(12, 12), dpi=300)
nx.draw_networkx_nodes(G_bom, pos)
nx.draw_networkx_edges(G_bom, pos)
nx.draw_networkx_labels(
    G_bom, pos, labels=nx.get_node_attributes(G_bom, 'Name'), font_size=3
)
plt.savefig(f"output_result/risk/g_bom_exp_id_{exp_id}.png")
plt.close()
|
||
|
||
# G_firm
# Render the firm-to-firm supply network for one sample. The graph is read
# from a local pickle cache (an earlier revision loaded it from
# iiabmdb.without_exp_sample.g_firm; that path was removed as dead code).
plt.rcParams['font.sans-serif'] = 'SimHei'  # font with CJK glyph support

sample_id = 1

# NOTE(review): pickle is only safe for trusted, locally produced files —
# never point this at data from an untrusted source.
with open("firm_network.pkl", 'rb') as f:
    G_firm = pickle.load(f)
print("Successfully loaded cached data from firm_network.pkl")

# 1. Drop isolated (degree-0) nodes so they do not clutter the layout.
G_firm.remove_nodes_from(list(nx.isolates(G_firm)))

# 2. Re-layout and draw.
pos = nx.nx_agraph.graphviz_layout(G_firm, prog="twopi", args="")
# Label and size only nodes carrying a 'Revenue_Log' attribute; node size
# scales with log-revenue (x10 for visibility).
node_label = {key: key for key in nx.get_node_attributes(G_firm, 'Revenue_Log').keys()}
node_size = [value * 10 for value in nx.get_node_attributes(G_firm, 'Revenue_Log').values()]
# Edge attributes come keyed by (u, v, k) MultiGraph triples; collapse to
# (u, v) for display. NOTE(review): parallel edges overwrite each other's
# label here — confirm that is acceptable.
edge_label = {(n1, n2): label for (n1, n2, _), label in nx.get_edge_attributes(G_firm, "Product").items()}

plt.figure(figsize=(12, 12), dpi=500)
nx.draw(G_firm, pos, node_size=node_size, labels=node_label, font_size=5, width=0.5)
nx.draw_networkx_edge_labels(G_firm, pos, edge_label, font_size=2)
plt.axis('equal')  # equal aspect ratio keeps the radial layout circular
plt.savefig(f"output_result/risk/g_firm_sample_id_{sample_id}_de.png", bbox_inches='tight', pad_inches=0.1)
plt.close()
|
||
|
||
|
||
# Count firm product
# Frequency of each (firm, product) pair in the query result.
count_firm_prod = (
    result.value_counts(subset=['id_firm', 'id_product'])
    .reset_index(name='count')
)
count_firm_prod.to_csv('output_result/risk/count_firm_prod.csv', index=False, encoding='utf-8-sig')
print(count_firm_prod)
|
||
|
||
# Count firm
# Total count per firm, most affected first.
count_firm = (
    count_firm_prod.groupby('id_firm')['count']
    .sum()
    .reset_index()
    .sort_values('count', ascending=False)
)
count_firm.to_csv('output_result/risk/count_firm.csv', index=False, encoding='utf-8-sig')
print(count_firm)
|
||
|
||
# Count product
# Total count per product, most affected first.
count_prod = (
    count_firm_prod.groupby('id_product')['count']
    .sum()
    .reset_index()
    .sort_values('count', ascending=False)
)
count_prod.to_csv('output_result/risk/count_prod.csv', index=False, encoding='utf-8-sig')
print(count_prod)
|
||
|
||
# DCP disruption causing probability
# A DCP pair links an upstream (firm, product) disrupted at an earlier time
# step with a downstream (firm, product) disrupted at a later step of the
# same simulation run (s_id).
result_disrupt_ts_above_0 = result[result['ts'] > 0]
print(result_disrupt_ts_above_0)

# Collect rows in a plain list and build the DataFrame once at the end —
# growing a DataFrame row by row is quadratic. (A previous revision also
# created an empty result_dcp DataFrame here; it was dead code, immediately
# shadowed below.)
result_dcp_list = []
for sid, group in result.groupby('s_id'):
    # Index this run's rows by time step once, instead of re-filtering the
    # whole group with a boolean mask for every (ts_start, ts_end) pair.
    rows_by_ts = {
        ts: sub[['id_firm', 'id_product']]
        for ts, sub in group.groupby('ts')
    }
    # Pair every later time step with every strictly earlier one.
    # NOTE(review): assumes 'ts' values are unit-spaced integers reachable
    # by decrementing from max(ts) — confirm against the table's producer.
    ts_start = max(group['ts'])
    while ts_start >= 1:
        down = rows_by_ts.get(ts_start)
        ts_end = ts_start - 1
        while ts_end >= 0:
            up = rows_by_ts.get(ts_end)
            if down is not None and up is not None:
                for _, up_row in up.iterrows():
                    for _, down_row in down.iterrows():
                        result_dcp_list.append([sid] + up_row.tolist() + down_row.tolist())
            ts_end -= 1
        ts_start -= 1

# Convert to DataFrame.
result_dcp = pd.DataFrame(result_dcp_list, columns=[
    's_id', 'up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product'
])

# Count how often each (upstream, downstream) pair co-occurs across runs.
count_dcp = result_dcp.value_counts(
    subset=['up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product']
).reset_index(name='count')

# Save to file.
count_dcp.to_csv('output_result/risk/count_dcp.csv', index=False, encoding='utf-8-sig')

# Print the result.
print(count_dcp)
|