mesa-GA/GA_Agent_0925/risk_ay/risk_sum.py

81 lines
2.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pickle
from sqlalchemy import text
from orm import engine, connection
import pandas as pd
import networkx as nx
import json
import matplotlib.pyplot as plt
# Prepare data: load the firm attribute table and the BOM node table.
# NOTE(review): neither Firm nor BomNodes is read again in the visible part of
# this script; kept so the script's file reads are unchanged.
firm_csv = "../../input_data/input_firm_data/firm_amended.csv"
bom_nodes_csv = '../../input_data/input_product_data/BomNodes.csv'
# Cast 'Code' to pandas' string dtype, then replace every missing value with 0.
Firm = pd.read_csv(firm_csv).astype({'Code': 'string'})
Firm.fillna(0, inplace=True)
# First CSV column becomes the index.
BomNodes = pd.read_csv(bom_nodes_csv, index_col=0)
# SQL query: load the analysis statement from disk and run it on the shared
# connection. The resulting frame must provide the columns used below:
# 's_id', 'ts', 'id_firm', 'id_product'.
# The encoding is explicit so decoding does not depend on the platform's
# locale default (e.g. cp936 on Chinese Windows).
# NOTE(review): assumes the .sql file is saved as UTF-8 — confirm.
with open('../../SQL_analysis_risk.sql', 'r', encoding='utf-8') as f:
    str_sql = text(f.read())
result = pd.read_sql(sql=str_sql, con=connection)
# Raw query result, written with a BOM so Excel detects UTF-8.
result.to_csv('count.csv', index=False, encoding='utf-8-sig')
print(result)
# Count firm product: number of result rows per (id_firm, id_product) pair.
count_firm_prod = result.value_counts(
    subset=['id_firm', 'id_product']
).reset_index(name='count')
count_firm_prod.to_csv('count_firm_prod.csv', index=False, encoding='utf-8-sig')
print(count_firm_prod)


def _total_counts_by(counts, key, csv_path):
    """Sum the ``count`` column of *counts* per value of *key*.

    Args:
        counts: DataFrame holding a *key* column and a ``count`` column.
        key: name of the column to group by.
        csv_path: destination CSV file (written as UTF-8 with BOM).

    Returns:
        DataFrame with columns ``[key, 'count']`` sorted by ``count`` in
        descending order. The same frame is also saved to *csv_path* and
        printed.
    """
    totals = counts.groupby(key)['count'].sum().reset_index()
    totals = totals.sort_values('count', ascending=False)
    totals.to_csv(csv_path, index=False, encoding='utf-8-sig')
    print(totals)
    return totals


# Count firm: total over all products of each firm.
count_firm = _total_counts_by(count_firm_prod, 'id_firm', 'count_firm.csv')
# Count product: total over all firms for each product.
count_prod = _total_counts_by(count_firm_prod, 'id_product', 'count_prod.csv')
# DCP (disruption causing probability): count how often a (firm, product)
# row at an earlier step ts is followed, in the same simulation s_id, by a
# (firm, product) row at a later step.
# NOTE(review): rows are presumably disruption events; the final probability
# would be derived from these counts elsewhere — confirm.

# Rows with ts > 0; printed for inspection only.
result_disrupt_ts_above_0 = result[result['ts'] > 0]
print(result_disrupt_ts_above_0)


def _dcp_pairs(sid, group):
    """Return all (upstream, downstream) row pairs of one simulation group.

    ``ts_start`` steps down by 1 from the group's largest ``ts`` to 1. For
    each ``ts_start``, ``ts_end`` steps down from ``ts_start - 1`` to 0.
    Every (id_firm, id_product) row at ``ts_end`` (upstream) is paired with
    every row at ``ts_start`` (downstream). Each pair is emitted as
    ``[sid, up_id_firm, up_id_product, down_id_firm, down_id_product]``.
    """
    # Split the group by step once instead of re-masking it for every pair.
    # itertuples keeps each column's own dtype, whereas iterrows would
    # upcast mixed int/float id columns to float.
    rows_at_ts = {
        ts: list(sub[['id_firm', 'id_product']].itertuples(index=False, name=None))
        for ts, sub in group.groupby('ts')
    }
    pairs = []
    ts_start = max(group['ts'])
    while ts_start >= 1:
        # Loop-invariant for the inner walk over earlier steps.
        down_rows = rows_at_ts.get(ts_start, [])
        ts_end = ts_start - 1
        while ts_end >= 0:
            for up_row in rows_at_ts.get(ts_end, []):
                for down_row in down_rows:
                    pairs.append([sid, *up_row, *down_row])
            ts_end -= 1
        ts_start -= 1
    return pairs


# Collect pairs in a list first; growing a DataFrame row by row is slow.
result_dcp_list = []
for sid, group in result.groupby('s_id'):
    result_dcp_list.extend(_dcp_pairs(sid, group))
# Convert to a DataFrame.
result_dcp = pd.DataFrame(result_dcp_list, columns=[
    's_id', 'up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product'
])
# Count each distinct (upstream, downstream) pair across all simulations.
count_dcp = result_dcp.value_counts(
    subset=['up_id_firm', 'up_id_product', 'down_id_firm', 'down_id_product']
).reset_index(name='count')
# Save the file.
count_dcp.to_csv('count_dcp.csv', index=False, encoding='utf-8-sig')
# Print the result.
print(count_dcp)