mesa/risk_analysis_prod_network.py

224 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
# SimHei is a CJK font; presumably selected so the Chinese text used in node
# names/labels renders instead of showing empty boxes (confirm).
plt.rcParams.update({'font.sans-serif': 'SimHei'})

# Per-product counts read from the risk output folder; the columns used
# later in this script are 'id_product' and 'count'.
count_prod = pd.read_csv("output_result/risk/count_prod.csv")

# Echo the raw table and then its per-column summary statistics.
for summary in (count_prod, count_prod.describe()):
    print(summary)
# prod_networkx: product BOM graph whose nodes carry the per-product counts
# taken from count_prod (loaded earlier in this script).
bom_nodes = pd.read_csv('input_data/input_product_data/BomNodes.csv')
bom_nodes['Code'] = bom_nodes['Code'].astype(str)
# Graph node ids are looked up against this 'Index' column in the loop below.
bom_nodes.set_index('Index', inplace=True)

# Directed edges UPID -> ID (UPID is presumably the upstream product; confirm).
bom_cate_net = pd.read_csv('input_data/input_product_data/合成结点.csv')
g_bom = nx.from_pandas_edgelist(bom_cate_net, source='UPID', target='ID',
                                create_using=nx.MultiDiGraph())

# Product id -> count, built once. The previous version filtered the whole of
# count_prod with a boolean mask for every node (O(nodes * rows)).
# Same semantics as before: an id that is absent OR occurs more than once in
# count_prod gets count 0.
# NOTE(review): confirm that duplicated ids should count as 0 rather than be summed.
is_dup_id = count_prod['id_product'].duplicated(keep=False)
unique_rows = count_prod.loc[~is_dup_id]
count_by_prod = dict(zip(unique_rows['id_product'], unique_rows['count']))

labels_dict = {}
for code in g_bom.nodes:
    # Raises KeyError if a graph node has no BomNodes row (unchanged behaviour).
    node_attr = bom_nodes.loc[code].to_dict()
    node_count = count_by_prod.get(code, 0)
    node_attr['count'] = node_count
    node_attr['node_size'] = node_count / 10  # marker size when the graph is drawn
    node_attr['node_color'] = node_count      # value fed through the colormap
    labels_dict[code] = node_attr
nx.set_node_attributes(g_bom, labels_dict)
# Draw the product network: node size and node colour both encode 'count'.
pos = nx.nx_agraph.graphviz_layout(g_bom, prog="twopi", args="")
dict_node_name = nx.get_node_attributes(g_bom, 'Name')
# Each node is labelled "<id> <Name>".
node_labels = {node: f"{node} {dict_node_name[node]!s}" for node in g_bom.nodes}

# Read sizes and colours in one pass over g_bom.nodes so both lists follow the
# node order nx.draw uses (list(G)); a node missing either attribute now fails
# loudly instead of silently shifting the lists out of alignment.
node_sizes = [g_bom.nodes[n]['node_size'] for n in g_bom.nodes]
colors = [g_bom.nodes[n]['node_color'] for n in g_bom.nodes]
vmin = min(colors)
vmax = max(colors)
cmap = plt.cm.Blues

# Create the figure and its main axes.
fig = plt.figure(figsize=(10, 10), dpi=300)
ax = fig.add_subplot(111)
# Draw the network with tuned styling, explicitly on `ax`.
nx.draw(g_bom, pos,
        ax=ax,
        node_size=node_sizes,
        labels=node_labels,
        font_size=3,
        node_color=colors,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        edge_color='#808080',  # neutral grey
        width=0.3,
        edgecolors='#404040',
        linewidths=0.2)

# Stand-alone mappable so the colour bar uses the same cmap/range as the nodes.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])
# Slim vertical colour bar on the right-hand side.
cax = fig.add_axes([0.88, 0.3, 0.015, 0.4])
cb = fig.colorbar(sm, cax=cax)
cb.ax.tick_params(labelsize=4, width=0.5, colors='#333333')
cb.outline.set_linewidth(0.5)
cb.set_label('Risk Level', fontsize=5, labelpad=2)

# Title plus a credit line. The credit is attached to the figure itself:
# plt.text would add it to the *current* axes, which after fig.add_axes above
# is the colour-bar axes.
ax.set_title("Production Risk Network", fontsize=6, pad=8, color='#2F2F2F')
fig.text(0.5, 0.02, 'Data: USTB Production System | Viz: DeepSeek-R1',
         ha='center', fontsize=3, color='#666666')

# Shrink the subplot area to leave room for the colour bar, then save.
fig.subplots_adjust(left=0.05, right=0.85, top=0.95, bottom=0.1)
fig.savefig("output_result/risk/count_prod_network.png",
            dpi=600,
            bbox_inches='tight',
            pad_inches=0.05,
            transparent=False)
plt.close(fig)
# dcp_prod: product-to-product flows aggregated from firm-level records,
# drawn as a directed graph whose edge colour encodes the aggregated count.

# Only product pairs whose aggregated count exceeds this are plotted.
MIN_EDGE_COUNT = 1000

count_dcp = pd.read_csv("output_result/risk/count_dcp.csv",
                        dtype={
                            'up_id_firm': str,
                            'down_id_firm': str
                        })
# Sum 'count' over all rows sharing the same (up product, down product) pair,
# largest totals first.
count_dcp_prod = (count_dcp
                  .groupby(['up_id_product', 'down_id_product'])['count']
                  .sum()
                  .reset_index()
                  .sort_values('count', ascending=False))
# Forward slashes like every other path in this script; the previous
# 'output_result\\risk\\...' only resolved on Windows and elsewhere created a
# file whose name contained literal backslashes.
count_dcp_prod.to_csv('output_result/risk/count_dcp_prod.csv',
                      index=False,
                      encoding='utf-8-sig')
count_dcp_prod = count_dcp_prod[count_dcp_prod['count'] > MIN_EDGE_COUNT]

# Graph nodes are the endpoints of the retained edges. Downstream ids used to
# be taken from the unfiltered count_dcp, which added every downstream product
# in the raw data as an isolated node.
list_prod = list(set(count_dcp_prod['up_id_product'].tolist()
                     + count_dcp_prod['down_id_product'].tolist()))

# Node attributes are the BomNodes rows, looked up by their 'Index' column.
BomNodes = pd.read_csv('input_data/input_product_data/BomNodes.csv')
BomNodes.set_index('Index', inplace=True)
g_bom = nx.MultiDiGraph()
g_bom.add_nodes_from(list_prod)
bom_labels_dict = {code: BomNodes.loc[code].to_dict() for code in list_prod}
nx.set_node_attributes(g_bom, bom_labels_dict)
# One edge per retained product pair, added in descending-count order.
g_bom.add_edges_from(
    (up, down, {'count': cnt})
    for up, down, cnt in zip(count_dcp_prod['up_id_product'],
                             count_dcp_prod['down_id_product'],
                             count_dcp_prod['count']))

# Layout and hand-written English display labels keyed by product id.
pos = nx.nx_agraph.graphviz_layout(g_bom, prog="twopi", args="")
node_labels = {
    38: 'SiC Substrate',
    39: 'GaN Substrate',
    40: 'Si Substrate',
    41: 'AlN Substrate',
    42: 'DUV LED Substrate',
    43: 'InP Substrate',
    44: 'Mono-Si Wafer',
    45: 'Poly-Si Wafer',
    46: 'InP Cryst./Wafer',
    47: 'SiC Cryst./Wafer',
    48: 'GaAs Wafer',
    49: 'GaN Cryst./Wafer',
    50: 'Si Epi Wafer',
    51: 'SiC Epi Wafer',
    52: 'AlN Epi',
    53: 'GaN Epi',
    54: 'InP Epi',
    55: 'LED Epi Wafer',
    90: 'Power Devices',
    91: 'Diode',
    92: 'Transistor',
    93: 'Thyristor',
    94: 'Rectifier',
    95: 'IC Fab',
    99: 'Wafer Test'
}
# draw_networkx_labels raises KeyError for a labelled node that has no
# position, so keep only labels for nodes actually present in the graph.
node_labels = {node: label for node, label in node_labels.items() if node in g_bom}

# Edge colours: aggregated counts in g_bom.edges() order, which is the edge
# order nx.draw uses.
colors = [data['count'] for _, _, data in g_bom.edges(data=True)]
vmin = min(colors)
vmax = max(colors)
cmap = plt.cm.Blues
# Swap every node's x and y (mirrors the layout across the diagonal).
pos_new = {node: (y, x) for node, (x, y) in pos.items()}

fig = plt.figure(figsize=(8, 8), dpi=300)
# Main axes placed explicitly: left 75% of the width, 10% margins top and
# bottom, leaving the right-hand strip free for the colour bar.
main_ax = fig.add_axes([0.1, 0.1, 0.75, 0.8])
nx.draw(g_bom, pos_new,
        ax=main_ax,
        node_size=50,
        labels=node_labels,
        font_size=5,
        width=1.5,
        edge_color=colors,
        edge_cmap=cmap,
        edge_vmin=vmin,
        edge_vmax=vmax,
        )
main_ax.axis('off')

# Colour bar: left edge at 86% of the width, starting 15% up, 30% tall.
cbar_ax = fig.add_axes([0.86, 0.15, 0.015, 0.3])
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])  # public API instead of assigning the private `_A` attribute
cbar = fig.colorbar(sm, cax=cbar_ax, orientation='vertical')
cbar.ax.tick_params(labelsize=4,
                    width=0.3,   # tick line width
                    length=1.5,  # tick line length
                    pad=0.8)     # gap between tick labels and the bar
cbar.outline.set_linewidth(0.5)  # colour-bar border width
# Report the colour-bar placement before saving; expected (0.86, 0.15, 0.015, 0.3).
print(f"Colorbar position: {cbar_ax.get_position().bounds}")

fig.savefig("output_result/risk/count_dcp_prod_network.png",
            dpi=900,
            bbox_inches='tight',  # crop surrounding whitespace
            pad_inches=0.05,      # keep a 0.05-inch margin
            # NOTE(review): 'CreationDate' is the PDF metadata key; matplotlib's
            # PNG writer skips None values and adds no timestamp by default, so
            # this is most likely a no-op for .png output — confirm.
            metadata={'CreationDate': None})
plt.close(fig)