"""Draw two product-level risk network figures from pre-computed count tables.

Figure 1 ("Production Risk Network"): the BOM category graph built from
``合成结点.csv``; each node is sized and coloured by the ``count`` value found
for it in ``output_result/risk/count_prod.csv``.

Figure 2: a graph whose edges are the (up_id_product, down_id_product) pairs
from ``output_result/risk/count_dcp.csv`` with a summed ``count`` above 1000;
edges are coloured by that summed count.  The aggregated table is also written
to ``output_result/risk/count_dcp_prod.csv``.

Both layouts use ``nx.nx_agraph.graphviz_layout`` with the ``twopi`` program.
"""
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx

plt.rcParams['font.sans-serif'] = 'SimHei'

# ---------------------------------------------------------------------------
# Figure 1: per-product counts on the BOM category graph
# ---------------------------------------------------------------------------
count_prod = pd.read_csv("output_result/risk/count_prod.csv")
print(count_prod)
print(count_prod.describe())

bom_nodes = pd.read_csv('input_data/input_product_data/BomNodes.csv')
bom_nodes['Code'] = bom_nodes['Code'].astype(str)
bom_nodes.set_index('Index', inplace=True)

bom_cate_net = pd.read_csv('input_data/input_product_data/合成结点.csv')
g_bom = nx.from_pandas_edgelist(
    bom_cate_net,
    source='UPID',
    target='ID',
    create_using=nx.MultiDiGraph(),
)

# Build the id_product -> count lookup once instead of filtering the whole
# frame for every node.  keep=False drops *every* row of a duplicated
# id_product, so a product with zero or several rows falls back to 0 below.
count_by_product = (
    count_prod.drop_duplicates('id_product', keep=False)
    .set_index('id_product')['count']
    .to_dict()
)

labels_dict = {}
for code in g_bom.nodes:
    # Raises KeyError if a graph node has no row in BomNodes.csv.
    node_attr = bom_nodes.loc[code].to_dict()
    count = count_by_product.get(code, 0)
    node_attr['count'] = count
    node_attr['node_size'] = count / 10
    node_attr['node_color'] = count
    labels_dict[code] = node_attr
nx.set_node_attributes(g_bom, labels_dict)

pos = nx.nx_agraph.graphviz_layout(g_bom, prog="twopi", args="")

dict_node_name = nx.get_node_attributes(g_bom, 'Name')
node_labels = {node: f"{node} {dict_node_name[node]}" for node in g_bom.nodes}

colors = list(nx.get_node_attributes(g_bom, 'node_color').values())
vmin = min(colors)
vmax = max(colors)
cmap = plt.cm.Blues

# Create the figure and the main axes.
fig = plt.figure(figsize=(10, 10), dpi=300)
ax = fig.add_subplot(111)

# Draw the network on the axes created above.
nx.draw(
    g_bom,
    pos,
    ax=ax,
    node_size=list(nx.get_node_attributes(g_bom, 'node_size').values()),
    labels=node_labels,
    font_size=3,
    node_color=colors,
    cmap=cmap,
    vmin=vmin,
    vmax=vmax,
    edge_color='#808080',  # neutral grey
    width=0.3,
    edgecolors='#404040',
    linewidths=0.2,
)

# Stand-alone mappable so the colour bar shares the node colour scale.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])

# Colour bar in its own axes, vertically centred on the right-hand side.
cax = fig.add_axes([0.88, 0.3, 0.015, 0.4])
cb = plt.colorbar(sm, cax=cax)
cb.ax.tick_params(labelsize=4, width=0.5, colors='#333333')
cb.outline.set_linewidth(0.5)
cb.set_label('Risk Level', fontsize=5, labelpad=2)

# Title and footer text.
ax.set_title("Production Risk Network", fontsize=6, pad=8, color='#2F2F2F')
plt.text(
    0.5,
    0.02,
    'Data: USTB Production System | Viz: DeepSeek-R1',
    ha='center',
    fontsize=3,
    color='#666666',
    transform=fig.transFigure,
)

# Leave room on the right for the colour bar, then save.
plt.subplots_adjust(left=0.05, right=0.85, top=0.95, bottom=0.1)
plt.savefig(
    r"output_result/risk/count_prod_network.png",
    dpi=600,
    bbox_inches='tight',
    pad_inches=0.05,
    transparent=False,
)
plt.close()

# ---------------------------------------------------------------------------
# Figure 2: product-to-product counts aggregated from count_dcp.csv
# ---------------------------------------------------------------------------
count_dcp = pd.read_csv(
    "output_result/risk/count_dcp.csv",
    dtype={'up_id_firm': str, 'down_id_firm': str},
)
count_dcp_prod = count_dcp.groupby(
    ['up_id_product', 'down_id_product'])['count'].sum()
count_dcp_prod = count_dcp_prod.reset_index()
count_dcp_prod.sort_values('count', inplace=True, ascending=False)
# Forward slashes, like every other path in this script: a backslash path is
# only a directory path on Windows.
count_dcp_prod.to_csv('output_result/risk/count_dcp_prod.csv',
                      index=False,
                      encoding='utf-8-sig')
count_dcp_prod = count_dcp_prod[count_dcp_prod['count'] > 1000]

# NOTE(review): up-stream ids come from the filtered frame but down-stream ids
# come from the *unfiltered* count_dcp, so the graph can contain isolated
# nodes.  The hard-coded labels below may rely on that - confirm before
# changing it.
list_prod = (count_dcp_prod['up_id_product'].tolist()
             + count_dcp['down_id_product'].tolist())
list_prod = list(set(list_prod))

# Initialise the product graph and attach the BOM attributes to its nodes.
BomNodes = pd.read_csv('input_data/input_product_data/BomNodes.csv')
BomNodes.set_index('Index', inplace=True)

g_bom = nx.MultiDiGraph()
g_bom.add_nodes_from(list_prod)

# Raises KeyError if a product id has no row in BomNodes.csv.
bom_labels_dict = {code: BomNodes.loc[code].to_dict() for code in list_prod}
nx.set_node_attributes(g_bom, bom_labels_dict)

# One edge per (up, down) pair, carrying the summed count.  Reading the
# columns directly avoids the per-row dtype up-casting done by iterrows().
g_bom.add_edges_from(
    (up_id, down_id, {'count': count})
    for up_id, down_id, count in zip(
        count_dcp_prod['up_id_product'].tolist(),
        count_dcp_prod['down_id_product'].tolist(),
        count_dcp_prod['count'].tolist(),
    )
)

pos = nx.nx_agraph.graphviz_layout(g_bom, prog="twopi", args="")

# Hand-written English labels keyed by product id; these are used instead of
# the 'Name' node attribute.
node_labels = {
    38: 'SiC Substrate',
    39: 'GaN Substrate',
    40: 'Si Substrate',
    41: 'AlN Substrate',
    42: 'DUV LED Substrate',
    43: 'InP Substrate',
    44: 'Mono-Si Wafer',
    45: 'Poly-Si Wafer',
    46: 'InP Cryst./Wafer',
    47: 'SiC Cryst./Wafer',
    48: 'GaAs Wafer',
    49: 'GaN Cryst./Wafer',
    50: 'Si Epi Wafer',
    51: 'SiC Epi Wafer',
    52: 'AlN Epi',
    53: 'GaN Epi',
    54: 'InP Epi',
    55: 'LED Epi Wafer',
    90: 'Power Devices',
    91: 'Diode',
    92: 'Transistor',
    93: 'Thyristor',
    94: 'Rectifier',
    95: 'IC Fab',
    99: 'Wafer Test',
}
# A label can only be placed for a node that has a layout position, so keep
# the labels of nodes that are actually in the graph.
node_labels = {
    node: label for node, label in node_labels.items() if node in g_bom
}

# Edge colours in graph edge order (keys are (u, v, edge_key) on a multigraph).
colors = list(nx.get_edge_attributes(g_bom, "count").values())
vmin = min(colors)
vmax = max(colors)
cmap = plt.cm.Blues

# Swap the x and y coordinates of the layout.
pos_new = {node: (p[1], p[0]) for node, p in pos.items()}

fig = plt.figure(figsize=(8, 8), dpi=300)
plt.subplots_adjust(right=0.85)

# Main axes placed explicitly: 75% of the width, 10% margin top and bottom.
main_ax = fig.add_axes([0.1, 0.1, 0.75, 0.8])
nx.draw(
    g_bom,
    pos_new,
    ax=main_ax,
    node_size=50,
    labels=node_labels,
    font_size=5,
    width=1.5,
    edge_color=colors,
    edge_cmap=cmap,
    edge_vmin=vmin,
    edge_vmax=vmax,
)
main_ax.axis('off')

# Colour bar axes: left edge at 86%, starting at 15% height, 30% tall.
cbar_ax = fig.add_axes([0.86, 0.15, 0.015, 0.3])
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
# Public API instead of assigning the private ``_A`` attribute.
sm.set_array([])

cbar = fig.colorbar(sm, cax=cbar_ax, orientation='vertical')
cbar.ax.tick_params(
    labelsize=4,
    width=0.3,   # tick line width
    length=1.5,  # tick line length
    pad=0.8,     # gap between tick labels and the bar
)
cbar.outline.set_linewidth(0.5)

# Sanity check before saving; expected bounds are (0.86, 0.15, 0.015, 0.3).
print(f"Colorbar position: {cbar_ax.get_position().bounds}")

# NOTE(review): 'CreationDate' looks like a PDF-style metadata key - confirm
# it has any effect for PNG output.
plt.savefig(
    "output_result/risk/count_dcp_prod_network.png",
    dpi=900,
    bbox_inches='tight',
    pad_inches=0.05,
    metadata={'CreationDate': None},
)
plt.close()