"""Plot the firm-to-firm (dcp) supply network for links whose count exceeds a threshold.

Reads per-firm and per-link counts from ``output_result/risk``, builds a
MultiDiGraph of firms using firm attributes from ``input_data``, and saves a
radial (graphviz ``twopi``) network drawing with edges coloured by link count.
"""
import os

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

# Font able to render Chinese text (e.g. the "企业名称" column) in matplotlib output.
plt.rcParams['font.sans-serif'] = 'SimHei'

# Only rows of count_dcp.csv whose 'count' is strictly greater than this are plotted.
COUNT_THRESHOLD = 130
# Upper end of the per-edge 'edge_width' scale stored on each edge.
MAX_EDGE_WIDTH = 15
# Multiplier applied to 'count' before it is stored as the edge colour value.
COUNT_COLOR_FACTOR = 18
# Multiplier turning a firm's Revenue_Log into a matplotlib node size.
NODE_SIZE_FACTOR = 10
# os.path.join keeps this correct on every OS; the original "\\"-joined path
# produced a single oddly named file in the working directory on POSIX.
OUTPUT_PATH = os.path.join("output_result", "risk", "count_dcp_network.png")

# Desensitized display labels: firm code -> anonymised index shown on the plot.
DESENSITIZED_LABELS = {
    '343012684': '59', '2944892892': '165', '3269039233': '194',
    '503176785': '73', '3111033905': '178', '3215814536': '190',
    '413274977': '64', '2317841563': '131', '2354145351': '157',
    '653528340': '88', '888395016': '104', '3069206426': '174',
    '3299144127': '197', '2624175': '8', '25685135': '24',
    '2348941764': '151', '750610681': '95', '2320475044': '133',
    '571058167': '78', '152008168': '44', '448033045': '66',
    '2321109759': '134', '3445928818': '213',
}

# count firm category
count_firm = pd.read_csv("output_result/risk/count_firm.csv")
print(count_firm.describe())

count_dcp = pd.read_csv("output_result/risk/count_dcp.csv", dtype={
    'up_id_firm': str,
    'down_id_firm': str
})
count_dcp = count_dcp[count_dcp['count'] > COUNT_THRESHOLD]
if count_dcp.empty:
    # Without this, the script fails much later at min() over an empty colour list.
    raise SystemExit(
        f"No rows in count_dcp.csv with count > {COUNT_THRESHOLD}; nothing to plot."
    )
# Sorted (not just set-deduplicated) so the node order, and hence the layout,
# is reproducible: iteration order of a set of str varies between processes.
list_firm = sorted(
    set(count_dcp['up_id_firm']) | set(count_dcp['down_id_firm']), key=str
)

# init graph firm
Firm = pd.read_csv("input_data/input_firm_data/Firm_amended.csv")
Firm['Code'] = Firm['Code'].astype('string')
Firm.fillna(0, inplace=True)
# .copy() so the column assignment below modifies an independent frame.
Firm_attr = Firm.loc[:, ["Code", "企业名称", "Type_Region", "Revenue_Log"]].copy()

firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv")
firm_industry_relation['Firm_Code'] = firm_industry_relation['Firm_Code'].astype('string')
# Firm code -> list of Product_Code values linked to that firm.
grouped = firm_industry_relation.groupby('Firm_Code')['Product_Code'].apply(list)
Firm_attr['Product_Code'] = Firm_attr['Code'].map(grouped)
Firm_attr.set_index('Code', inplace=True)

# Report every unknown firm at once instead of failing on the first .loc lookup.
missing_firms = [code for code in list_firm if code not in Firm_attr.index]
if missing_firms:
    raise KeyError(f"Firm codes missing from Firm_amended.csv: {missing_firms}")

G_firm = nx.MultiDiGraph()
G_firm.add_nodes_from(list_firm)
firm_labels_dict = {code: Firm_attr.loc[code].to_dict() for code in G_firm.nodes}
nx.set_node_attributes(G_firm, firm_labels_dict)

count_max = count_dcp['count'].max()
count_min = count_dcp['count'].min()
count_range = count_max - count_min
# A zero range (one row, or all counts equal) would give inf/NaN widths
# through numpy division; give every edge width 0 instead.
k = MAX_EDGE_WIDTH / count_range if count_range else 0

lst_add_edge = [
    (
        row['up_id_firm'],
        row['down_id_firm'],
        {
            'up_id_product': row['up_id_product'],
            'down_id_product': row['down_id_product'],
            'edge_label': f"{row['up_id_product']} - {row['down_id_product']}",
            'edge_width': k * (row['count'] - count_min),
            'count': row['count'] * COUNT_COLOR_FACTOR,
        },
    )
    for _, row in count_dcp.iterrows()
]
G_firm.add_edges_from(lst_add_edge)

# dcp_networkx
pos = nx.nx_agraph.graphviz_layout(G_firm, prog="twopi")

# desensitize: show only anonymised ids. Labels for firms absent from this
# graph are dropped, because networkx raises KeyError for a label whose node
# has no position.
node_label = {
    code: label for code, label in DESENSITIZED_LABELS.items() if code in G_firm
}

# Built in G_firm.nodes order, which is the order nx.draw uses for node_size.
node_size = [G_firm.nodes[n]['Revenue_Log'] * NODE_SIZE_FACTOR for n in G_firm.nodes]

# A MultiDiGraph can hold several product-pair edges between the same two
# firms; join their labels rather than letting the last one overwrite the rest.
labels_by_pair = {}
for u, v, label in G_firm.edges(data='edge_label'):
    labels_by_pair.setdefault((u, v), []).append(label)
edge_label = {pair: "\n".join(labels) for pair, labels in labels_by_pair.items()}

# NOTE: each edge carries an 'edge_width' attribute, but the drawing below
# uses a fixed width=2 (as before).
# Edge colour values in G_firm.edges order, which matches nx.draw's edge order.
colors = [count for _, _, count in G_firm.edges(data='count')]
vmin = min(colors)
vmax = max(colors)
cmap = plt.cm.Blues

fig = plt.figure(figsize=(10, 8), dpi=500)
nx.draw(
    G_firm, pos,
    node_size=node_size, labels=node_label, font_size=8, width=2,
    edge_color=colors, edge_cmap=cmap, edge_vmin=vmin, edge_vmax=vmax,
)
nx.draw_networkx_edge_labels(G_firm, pos, edge_labels=edge_label, font_size=5)

sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])  # public replacement for assigning the private sm._A
position = fig.add_axes([0.95, 0.05, 0.01, 0.3])
cb = fig.colorbar(sm, cax=position)
cb.ax.tick_params(labelsize=4)
cb.outline.set_visible(False)
fig.savefig(OUTPUT_PATH)
plt.close(fig)