10.09添加git
This commit is contained in:
128
risk_analysis_firm_network.py
Normal file
128
risk_analysis_firm_network.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
|
||||
plt.rcParams['font.sans-serif'] = 'SimHei'
|
||||
|
||||
# count firm category
|
||||
count_firm = pd.read_csv("output_result/risk/count_firm.csv")
|
||||
print(count_firm.describe())
|
||||
|
||||
count_dcp = pd.read_csv("output_result/risk/count_dcp.csv",
|
||||
dtype={
|
||||
'up_id_firm': str,
|
||||
'down_id_firm': str
|
||||
})
|
||||
count_dcp = count_dcp[count_dcp['count'] > 130]
|
||||
|
||||
list_firm = count_dcp['up_id_firm'].tolist(
|
||||
) + count_dcp['down_id_firm'].tolist()
|
||||
list_firm = list(set(list_firm))
|
||||
|
||||
# init graph firm
|
||||
Firm = pd.read_csv("input_data/input_firm_data/Firm_amended.csv")
|
||||
Firm['Code'] = Firm['Code'].astype('string')
|
||||
Firm.fillna(0, inplace=True)
|
||||
Firm_attr = Firm.loc[:, ["Code", "企业名称", "Type_Region", "Revenue_Log"]]
|
||||
firm_industry_relation = pd.read_csv("input_data/firm_industry_relation.csv")
|
||||
firm_industry_relation['Firm_Code'] = firm_industry_relation['Firm_Code'].astype('string')
|
||||
firm_product = []
|
||||
grouped = firm_industry_relation.groupby('Firm_Code')['Product_Code'].apply(list)
|
||||
firm_product.append(grouped)
|
||||
Firm_attr['Product_Code'] = Firm_attr['Code'].map(grouped)
|
||||
Firm_attr.set_index('Code', inplace=True)
|
||||
|
||||
G_firm = nx.MultiDiGraph()
|
||||
G_firm.add_nodes_from(list_firm)
|
||||
|
||||
firm_labels_dict = {}
|
||||
for code in G_firm.nodes:
|
||||
firm_labels_dict[code] = Firm_attr.loc[code].to_dict()
|
||||
nx.set_node_attributes(G_firm, firm_labels_dict)
|
||||
|
||||
count_max = count_dcp['count'].max()
|
||||
count_min = count_dcp['count'].min()
|
||||
k = 15 / (count_max - count_min)
|
||||
for _, row in count_dcp.iterrows():
|
||||
# print(row)
|
||||
lst_add_edge = [(
|
||||
row['up_id_firm'],
|
||||
row['down_id_firm'],
|
||||
{
|
||||
'up_id_product': row['up_id_product'],
|
||||
'down_id_product': row['down_id_product'],
|
||||
'edge_label': f"{row['up_id_product']} - {row['down_id_product']}",
|
||||
'edge_width': k * (row['count'] - count_min),
|
||||
'count': (row['count'])*18
|
||||
})]
|
||||
G_firm.add_edges_from(lst_add_edge)
|
||||
|
||||
# dcp_networkx
|
||||
pos = nx.nx_agraph.graphviz_layout(G_firm, prog="twopi", args="")
|
||||
node_label = nx.get_node_attributes(G_firm, '企业名称')
|
||||
# desensitize
|
||||
node_label = {key: f"{key} " for key, value in node_label.items()}
|
||||
node_label = {
|
||||
'343012684': '59',
|
||||
'2944892892': '165',
|
||||
'3269039233': '194',
|
||||
'503176785': '73',
|
||||
'3111033905': '178',
|
||||
'3215814536': '190',
|
||||
'413274977': '64',
|
||||
'2317841563': '131',
|
||||
'2354145351': '157',
|
||||
'653528340': '88',
|
||||
'888395016': '104',
|
||||
'3069206426': '174',
|
||||
'3299144127': '197',
|
||||
'2624175': '8',
|
||||
'25685135': '24',
|
||||
'2348941764': '151',
|
||||
'750610681': '95',
|
||||
'2320475044': '133',
|
||||
'571058167': '78',
|
||||
'152008168': '44',
|
||||
'448033045': '66',
|
||||
'2321109759': '134',
|
||||
'3445928818': '213'
|
||||
}
|
||||
|
||||
node_size = list(nx.get_node_attributes(G_firm, 'Revenue_Log').values())
|
||||
node_size = list(map(lambda x: x * 10, node_size))
|
||||
edge_label = nx.get_edge_attributes(G_firm, "edge_label")
|
||||
edge_label = {(n1, n2): label for (n1, n2, _), label in edge_label.items()}
|
||||
edge_width = nx.get_edge_attributes(G_firm, "edge_width")
|
||||
edge_width = [w for (n1, n2, _), w in edge_width.items()]
|
||||
colors = nx.get_edge_attributes(G_firm, "count")
|
||||
colors = [w for (n1, n2, _), w in colors.items()]
|
||||
vmin = min(colors)
|
||||
vmax = max(colors)
|
||||
cmap = plt.cm.Blues
|
||||
fig = plt.figure(figsize=(10, 8), dpi=500)
|
||||
nx.draw(G_firm,
|
||||
pos,
|
||||
node_size=node_size,
|
||||
labels=node_label,
|
||||
font_size=8,
|
||||
width=2,
|
||||
edge_color=colors,
|
||||
edge_cmap=cmap,
|
||||
edge_vmin=vmin,
|
||||
edge_vmax=vmax)
|
||||
# nx.draw_networkx_edge_labels(G_firm, pos, font_size=6)
|
||||
nx.draw_networkx_edge_labels(
|
||||
G_firm,
|
||||
pos,
|
||||
edge_labels=edge_label,
|
||||
font_size=5
|
||||
)
|
||||
|
||||
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
|
||||
sm._A = []
|
||||
position = fig.add_axes([0.95, 0.05, 0.01, 0.3])
|
||||
cb = plt.colorbar(sm, fraction=0.01, cax=position)
|
||||
cb.ax.tick_params(labelsize=4)
|
||||
cb.outline.set_visible(False)
|
||||
plt.savefig("output_result\\risk\\count_dcp_network")
|
||||
plt.close()
|
||||
Reference in New Issue
Block a user