475 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			475 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Python
		
	
	
	
| import agentpy as ap
 | |
| import pandas as pd
 | |
| import numpy as np
 | |
| # import random
 | |
| import networkx as nx
 | |
| from firm import FirmAgent
 | |
| from product import ProductAgent
 | |
| from orm import db_session, Result
 | |
| import platform
 | |
| import json
 | |
| 
 | |
| 
 | |
| class Model(ap.Model):
 | |
|     def setup(self):
 | |
|         self.sample = self.p.sample
 | |
|         self.int_stop_times, self.int_stop_t = 0, None
 | |
|         # self.random = random.Random(self.p.seed)
 | |
|         self.nprandom = np.random.default_rng(self.p.seed)
 | |
|         self.int_n_iter = int(self.p.n_iter)
 | |
| 
 | |
|         self.dct_lst_remove_firm_prod = self.p.dct_lst_init_remove_firm_prod
 | |
| 
 | |
|         self.int_n_max_trial = int(self.p.n_max_trial)
 | |
|         self.flt_netw_pref_supp_n = float(self.p.netw_pref_supp_n)
 | |
|         self.flt_netw_pref_supp_size = float(self.p.netw_pref_supp_size)
 | |
|         self.flt_cap_limit = int(self.p.cap_limit)
 | |
|         self.flt_diff_remove = float(self.p.diff_remove)
 | |
|         self.proactive_ratio = float(self.p.proactive_ratio)
 | |
| 
 | |
|         # init graph bom
 | |
|         G_bom = nx.adjacency_graph(json.loads(self.p.g_bom))
 | |
|         self.product_network = ap.Network(self, G_bom)
 | |
| 
 | |
|         # init graph firm
 | |
|         Firm = pd.read_csv("Firm_amended.csv")
 | |
|         Firm['Code'] = Firm['Code'].astype('string')
 | |
|         Firm.fillna(0, inplace=True)
 | |
|         Firm_attr = Firm.loc[:, ["Code", "Name", "Type_Region", "Revenue_Log"]]
 | |
|         firm_product = []
 | |
|         for _, row in Firm.loc[:, '1':].iterrows():
 | |
|             firm_product.append(row[row == 1].index.to_list())
 | |
|         Firm_attr.loc[:, 'Product_Code'] = firm_product
 | |
|         Firm_attr.set_index('Code', inplace=True)
 | |
|         G_Firm = nx.MultiDiGraph()
 | |
|         G_Firm.add_nodes_from(Firm["Code"])
 | |
| 
 | |
|         firm_labels_dict = {}
 | |
|         for code in G_Firm.nodes:
 | |
|             firm_labels_dict[code] = Firm_attr.loc[code].to_dict()
 | |
|         nx.set_node_attributes(G_Firm, firm_labels_dict)
 | |
| 
 | |
|         # init graph firm prod
 | |
|         Firm_Prod = pd.read_csv("Firm_amended.csv")
 | |
|         Firm_Prod.fillna(0, inplace=True)
 | |
|         firm_prod = pd.DataFrame({'bool': Firm_Prod.loc[:, '1':].stack()})
 | |
|         firm_prod = firm_prod[firm_prod['bool'] == 1].reset_index()
 | |
|         firm_prod.drop('bool', axis=1, inplace=True)
 | |
|         firm_prod.rename({'level_0': 'Firm_Code',
 | |
|                           'level_1': 'Product_Code'}, axis=1, inplace=True)
 | |
|         firm_prod['Firm_Code'] = firm_prod['Firm_Code'].astype('string')
 | |
|         G_FirmProd = nx.MultiDiGraph()
 | |
|         G_FirmProd.add_nodes_from(firm_prod.index)
 | |
| 
 | |
|         firm_prod_labels_dict = {}
 | |
|         for code in firm_prod.index:
 | |
|             firm_prod_labels_dict[code] = firm_prod.loc[code].to_dict()
 | |
|         nx.set_node_attributes(G_FirmProd, firm_prod_labels_dict)
 | |
| 
 | |
|         # add edge to G_firm according to G_bom
 | |
|         for node in nx.nodes(G_Firm):
 | |
|             for product_code in G_Firm.nodes[node]['Product_Code']:
 | |
|                 for pred_product_code in \
 | |
|                         list(G_bom.predecessors(product_code)):
 | |
|                     # for each product of a certain firm
 | |
|                     # get each predecessor (component) of this product
 | |
|                     # get a list of firm producing this component
 | |
|                     lst_pred_firm = Firm['Code'][Firm[pred_product_code] ==
 | |
|                                                  1].to_list()
 | |
|                     lst_pred_firm_size_damp = \
 | |
|                         [G_Firm.nodes[pred_firm]['Revenue_Log'] **
 | |
|                          self.flt_netw_pref_supp_size
 | |
|                          for pred_firm in lst_pred_firm
 | |
|                          ]
 | |
|                     lst_prob = \
 | |
|                         [size_damp / sum(lst_pred_firm_size_damp)
 | |
|                          for size_damp in lst_pred_firm_size_damp
 | |
|                          ]
 | |
|                     # select multiple supplier
 | |
|                     # based on relative size of this firm
 | |
|                     lst_same_prod_firm = Firm['Code'][Firm[product_code] ==
 | |
|                                                       1].to_list()
 | |
|                     lst_same_prod_firm_size = [
 | |
|                         G_Firm.nodes[f]['Revenue_Log']
 | |
|                         for f in lst_same_prod_firm
 | |
|                     ]
 | |
|                     share = G_Firm.nodes[node]['Revenue_Log'] / sum(
 | |
|                         lst_same_prod_firm_size)
 | |
|                     # damp share
 | |
|                     share = share ** self.flt_netw_pref_supp_n
 | |
|                     n_pred_firm = round(share * len(lst_pred_firm)) if round(
 | |
|                         share * len(lst_pred_firm)) > 0 else 1  # at least one
 | |
|                     lst_choose_firm = self.nprandom.choice(lst_pred_firm,
 | |
|                                                            n_pred_firm,
 | |
|                                                            replace=False,
 | |
|                                                            p=lst_prob)
 | |
|                     lst_add_edge = [(node, pred_firm, {
 | |
|                         'Product': product_code
 | |
|                     }) for pred_firm in lst_choose_firm]
 | |
|                     G_Firm.add_edges_from(lst_add_edge)
 | |
| 
 | |
|                     # graph firm prod
 | |
|                     pred_node = [n for n, v in G_FirmProd.nodes(data=True)
 | |
|                                  if v['Firm_Code'] == node and
 | |
|                                  v['Product_Code'] == product_code][0]
 | |
|                     for pred_firm in lst_choose_firm:
 | |
|                         pred_node = [n for n, v in G_FirmProd.nodes(data=True)
 | |
|                                      if v['Firm_Code'] == pred_firm and
 | |
|                                      v['Product_Code'] == pred_product_code][0]
 | |
|                         G_FirmProd.add_edge(pred_node, pred_node)
 | |
| 
 | |
|         # unconnected node
 | |
|         # for node in nx.nodes(G_Firm):
 | |
|         #     if node.
 | |
| 
 | |
|         self.sample.g_firm = json.dumps(nx.adjacency_data(G_Firm))
 | |
|         self.firm_network = ap.Network(self, G_Firm)
 | |
|         self.firm_prod_network = G_FirmProd
 | |
| 
 | |
|         # init product
 | |
|         for ag_node, attr in self.product_network.graph.nodes(data=True):
 | |
|             product = ProductAgent(self, code=ag_node.label, name=attr['Name'])
 | |
|             self.product_network.add_agents([product], [ag_node])
 | |
|         self.a_lst_total_products = ap.AgentList(self,
 | |
|                                                  self.product_network.agents)
 | |
| 
 | |
|         # init firm
 | |
|         for ag_node, attr in self.firm_network.graph.nodes(data=True):
 | |
|             firm_agent = FirmAgent(
 | |
|                 self,
 | |
|                 code=ag_node.label,
 | |
|                 name=attr['Name'],
 | |
|                 type_region=attr['Type_Region'],
 | |
|                 revenue_log=attr['Revenue_Log'],
 | |
|                 a_lst_product=self.a_lst_total_products.select([
 | |
|                     code in attr['Product_Code']
 | |
|                     for code in self.a_lst_total_products.code
 | |
|                 ]))
 | |
|             # init extra capacity based on discrete uniform distribution
 | |
|             for product in firm_agent.a_lst_product:
 | |
|                 firm_agent.dct_prod_capacity[product] = \
 | |
|                     self.nprandom.integers(firm_agent.revenue_log / 5,
 | |
|                                            firm_agent.revenue_log / 5 +
 | |
|                                            self.flt_cap_limit)
 | |
| 
 | |
|             self.firm_network.add_agents([firm_agent], [ag_node])
 | |
|         self.a_lst_total_firms = ap.AgentList(self, self.firm_network.agents)
 | |
| 
 | |
|         # init dct_list_remove_firm_prod (from string to agent)
 | |
|         t_dct = {}
 | |
|         for firm_code, lst_product in self.dct_lst_remove_firm_prod.items():
 | |
|             firm = self.a_lst_total_firms.select(
 | |
|                 self.a_lst_total_firms.code == firm_code)[0]
 | |
|             t_dct[firm] = self.a_lst_total_products.select([
 | |
|                 code in lst_product for code in self.a_lst_total_products.code
 | |
|             ])
 | |
|         self.dct_lst_remove_firm_prod = t_dct
 | |
|         self.dct_lst_disrupt_firm_prod = t_dct
 | |
| 
 | |
|         # init output
 | |
|         self.lst_dct_lst_remove_firm_prod = []
 | |
|         self.lst_dct_lst_disrupt_firm_prod = []
 | |
| 
 | |
|         # set the initial firm product that are removed
 | |
|         for firm, a_lst_product in self.dct_lst_remove_firm_prod.items():
 | |
|             for product in a_lst_product:
 | |
|                 assert product in firm.a_lst_product, \
 | |
|                     f"product {product.code} not in firm {firm.code}"
 | |
|                 firm.a_lst_product_removed.append(product)
 | |
| 
 | |
|         # proactive strategy
 | |
|         # get all the firm prod affected
 | |
|         for firm, a_lst_product in self.dct_lst_remove_firm_prod.items():
 | |
|             for product in a_lst_product:
 | |
|                 init_node = \
 | |
|                     [n for n, v in
 | |
|                      self.firm_prod_network.nodes(data=True)
 | |
|                      if v['Firm_Code'] == firm.code and
 | |
|                      v['Product_Code'] == product.code][0]
 | |
|                 dct_affected = \
 | |
|                     nx.dfs_successors(self.firm_prod_network,
 | |
|                                       init_node)
 | |
|                 lst_affected = set()
 | |
|                 for i, (u, vs) in enumerate(dct_affected.items()):
 | |
|                     # at least 2 hops away
 | |
|                     if i > 0:
 | |
|                         pred_node = self.firm_prod_network.nodes[u]
 | |
|                         for v in vs:
 | |
|                             succ_node = self.firm_prod_network.nodes[v]
 | |
|                             lst_affected.add((succ_node['Firm_Code'],
 | |
|                                               succ_node['Product_Code']))
 | |
|                 lst_affected = list(lst_affected)
 | |
|                 lst_firm_proactive = \
 | |
|                     [lst_affected[i] for i in
 | |
|                      self.nprandom.choice(range(len(lst_affected)),
 | |
|                                           round(len(lst_affected) *
 | |
|                                                 self.proactive_ratio))]
 | |
| 
 | |
|                 for firm_code, prod_code in lst_firm_proactive:
 | |
|                     pro_firm_prod_code = \
 | |
|                         [n for n, v in
 | |
|                          self.firm_prod_network.nodes(data=True)
 | |
|                          if v['Firm_Code'] == firm_code and
 | |
|                          v['Product_Code'] == prod_code][0]
 | |
|                     pro_firm_prod_node = \
 | |
|                         self.firm_prod_network.nodes[pro_firm_prod_code]
 | |
|                     pro_firm = \
 | |
|                         self.a_lst_total_firms.select(
 | |
|                             [firm.code == pro_firm_prod_node['Firm_Code']
 | |
|                              for firm in self.a_lst_total_firms])[0]
 | |
|                     lst_shortest_path = \
 | |
|                         list(nx.all_shortest_paths(self.firm_prod_network,
 | |
|                                                    source=init_node,
 | |
|                                                    target=pro_firm_prod_code))
 | |
| 
 | |
|                     dct_drs = {}
 | |
|                     for di_supp_code in self.firm_prod_network.predecessors(
 | |
|                             pro_firm_prod_code):
 | |
|                         di_supp_node = \
 | |
|                             self.firm_prod_network.nodes[di_supp_code]
 | |
|                         di_supp_prod = \
 | |
|                             self.a_lst_total_products.select(
 | |
|                                 [product.code == di_supp_node['Product_Code']
 | |
|                                  for product in self.a_lst_total_products])[0]
 | |
|                         di_supp_firm = \
 | |
|                             self.a_lst_total_firms.select(
 | |
|                                 [firm.code == di_supp_node['Firm_Code']
 | |
|                                  for firm in self.a_lst_total_firms])[0]
 | |
|                         lst_cand = self.model.a_lst_total_firms.select([
 | |
|                             di_supp_prod in firm.a_lst_product
 | |
|                             and di_supp_prod not in firm.a_lst_product_removed
 | |
|                             for firm in self.model.a_lst_total_firms
 | |
|                         ])
 | |
|                         n2n_betweenness = \
 | |
|                             sum([True if di_supp_code in path else False
 | |
|                                  for path in lst_shortest_path]) \
 | |
|                             / len(lst_shortest_path)
 | |
|                         drs = n2n_betweenness / \
 | |
|                             (len(lst_cand) * di_supp_firm.revenue_log)
 | |
|                         dct_drs[di_supp_code] = drs
 | |
|                     dct_drs = dict(sorted(
 | |
|                         dct_drs.items(), key=lambda kv: kv[1], reverse=True))
 | |
|                     for di_supp_code in dct_drs.keys():
 | |
|                         di_supp_node = \
 | |
|                             self.firm_prod_network.nodes[di_supp_code]
 | |
|                         di_supp_prod = \
 | |
|                             self.a_lst_total_products.select(
 | |
|                                 [product.code == di_supp_node['Product_Code']
 | |
|                                  for product in self.a_lst_total_products])[0]
 | |
|                         # find a dfferent firm can produce the same product
 | |
|                         lst_cand = self.model.a_lst_total_firms.select([
 | |
|                             di_supp_prod in firm.a_lst_product
 | |
|                             and di_supp_prod not in firm.a_lst_product_removed
 | |
|                             and firm.code != di_supp_node['Firm_Code']
 | |
|                             for firm in self.model.a_lst_total_firms
 | |
|                         ])
 | |
|                         if len(lst_cand) > 0:
 | |
|                             select_cand = self.nprandom.choice(lst_cand)
 | |
|                             self.firm_network.graph.add_edges_from([
 | |
|                                 (self.firm_network.positions[select_cand],
 | |
|                                  self.firm_network.positions[pro_firm], {
 | |
|                                     'Product': di_supp_prod.code
 | |
|                                 })
 | |
|                             ])
 | |
|                             # print(f"proactive add {select_cand.code} to "
 | |
|                             #       f"{pro_firm.code} "
 | |
|                             #       f"for {di_supp_code} {di_supp_prod.code}")
 | |
|                             # change capacity
 | |
|                             select_cand.dct_prod_capacity[di_supp_prod] -= 1
 | |
|                             break
 | |
| 
 | |
|         # draw network
 | |
|         # self.draw_network()
 | |
| 
 | |
|     def update(self):
 | |
|         self.a_lst_total_firms.clean_before_time_step()
 | |
|         # output
 | |
|         self.lst_dct_lst_remove_firm_prod.append(
 | |
|             (self.t, self.dct_lst_remove_firm_prod))
 | |
|         self.lst_dct_lst_disrupt_firm_prod.append(
 | |
|             (self.t, self.dct_lst_disrupt_firm_prod))
 | |
| 
 | |
|         # stop simulation if reached terminal number of iteration
 | |
|         if self.t == self.int_n_iter or len(
 | |
|                 self.dct_lst_remove_firm_prod) == 0:
 | |
|             self.int_stop_times = self.t
 | |
|             self.stop()
 | |
| 
 | |
|     def step(self):
 | |
|         # print('\n', '=' * 20, 'step', self.t, '=' * 20)
 | |
|         # print(
 | |
|         #     'dct_list_remove_firm_prod', {
 | |
|         #         key.name: value.code
 | |
|         #         for key, value in self.dct_lst_remove_firm_prod.items()
 | |
|         #     })
 | |
| 
 | |
|         # remove_edge_to_cus_and_cus_up_prod
 | |
|         for firm, a_lst_product in self.dct_lst_remove_firm_prod.items():
 | |
|             for product in a_lst_product:
 | |
|                 firm.remove_edge_to_cus_remove_cus_up_prod(product)
 | |
| 
 | |
|         for n_trial in range(self.int_n_max_trial):
 | |
|             # print('=' * 10, 'trial', n_trial, '=' * 10)
 | |
|             # seek_alt_supply
 | |
|             # shuffle self.a_lst_total_firms
 | |
|             self.a_lst_total_firms = self.a_lst_total_firms.shuffle()
 | |
|             for firm in self.a_lst_total_firms:
 | |
|                 if len(firm.a_lst_up_product_removed) > 0:
 | |
|                     firm.seek_alt_supply()
 | |
| 
 | |
|             # handle_request
 | |
|             # shuffle self.a_lst_total_firms
 | |
|             self.a_lst_total_firms = self.a_lst_total_firms.shuffle()
 | |
|             for firm in self.a_lst_total_firms:
 | |
|                 if len(firm.dct_request_prod_from_firm) > 0:
 | |
|                     firm.handle_request()
 | |
| 
 | |
|             # reset dct_request_prod_from_firm
 | |
|             self.a_lst_total_firms.clean_before_trial()
 | |
|             # do not use:
 | |
|             # self.a_lst_total_firms.dct_request_prod_from_firm = {} why?
 | |
| 
 | |
|         # based on a_lst_up_product_removed
 | |
|         # update a_lst_product_disrupted / a_lst_product_removed
 | |
|         # update dct_lst_disrupt_firm_prod / dct_lst_remove_firm_prod
 | |
|         self.dct_lst_remove_firm_prod = {}
 | |
|         self.dct_lst_disrupt_firm_prod = {}
 | |
|         for firm in self.a_lst_total_firms:
 | |
|             if len(firm.a_lst_up_product_removed) > 0:
 | |
|                 # print(firm.name, 'a_lst_up_product_removed', [
 | |
|                 #     product.code for product in firm.a_lst_up_product_removed
 | |
|                 # ])
 | |
|                 for product in firm.a_lst_product:
 | |
|                     n_up_product_removed = 0
 | |
|                     for up_product_removed in firm.a_lst_up_product_removed:
 | |
|                         if product in up_product_removed.a_successors():
 | |
|                             n_up_product_removed += 1
 | |
|                     if n_up_product_removed == 0:
 | |
|                         continue
 | |
|                     else:
 | |
|                         # update a_lst_product_disrupted
 | |
|                         # update dct_lst_disrupt_firm_prod
 | |
|                         if product not in firm.a_lst_product_disrupted:
 | |
|                             firm.a_lst_product_disrupted.append(product)
 | |
|                             if firm in self.dct_lst_disrupt_firm_prod.keys():
 | |
|                                 self.dct_lst_disrupt_firm_prod[firm].append(
 | |
|                                     product)
 | |
|                             else:
 | |
|                                 self.dct_lst_disrupt_firm_prod[
 | |
|                                     firm] = ap.AgentList(
 | |
|                                         self.model, [product])
 | |
|                         # update a_lst_product_removed
 | |
|                         # update dct_list_remove_firm_prod
 | |
|                         # mark disrupted firm as removed based conditionally
 | |
|                         lost_percent = n_up_product_removed / len(
 | |
|                             product.a_predecessors())
 | |
|                         lst_size = self.a_lst_total_firms.revenue_log
 | |
|                         lst_size = [firm.revenue_log for firm
 | |
|                                     in self.a_lst_total_firms
 | |
|                                     if product in firm.a_lst_product
 | |
|                                     and product
 | |
|                                     not in firm.a_lst_product_removed
 | |
|                                     ]
 | |
|                         std_size = (firm.revenue_log - min(lst_size) +
 | |
|                                     1) / (max(lst_size) - min(lst_size) + 1)
 | |
|                         prob_remove = 1 - std_size * (1 - lost_percent)
 | |
|                         # sample prob
 | |
|                         prob_remove = self.nprandom.uniform(
 | |
|                             prob_remove - 0.1, prob_remove + 0.1)
 | |
|                         prob_remove = 1 if prob_remove > 1 else prob_remove
 | |
|                         prob_remove = 0 if prob_remove < 0 else prob_remove
 | |
|                         # damp prod
 | |
|                         prob_remove = prob_remove ** self.flt_diff_remove
 | |
|                         if self.nprandom.choice([True, False],
 | |
|                                                 p=[prob_remove,
 | |
|                                                    1 - prob_remove]):
 | |
|                             firm.a_lst_product_removed.append(product)
 | |
|                             if firm in self.dct_lst_remove_firm_prod.keys():
 | |
|                                 self.dct_lst_remove_firm_prod[firm].append(
 | |
|                                     product)
 | |
|                             else:
 | |
|                                 self.dct_lst_remove_firm_prod[
 | |
|                                     firm] = ap.AgentList(
 | |
|                                         self.model, [product])
 | |
| 
 | |
|         # print(
 | |
|         #     'dct_list_remove_firm_prod', {
 | |
|         #         key.name: value.code
 | |
|         #         for key, value in self.dct_lst_remove_firm_prod.items()
 | |
|         #     })
 | |
| 
 | |
|     def end(self):
 | |
|         # print('/' * 20, 'output', '/' * 20)
 | |
|         # print('dct_list_remove_firm_prod')
 | |
|         # for t, dct in self.lst_dct_lst_remove_firm_prod:
 | |
|         #     for firm, a_lst_product in dct.items():
 | |
|         #         for product in a_lst_product:
 | |
|         #             print(t, firm.name, product.code)
 | |
|         # print('dct_lst_disrupt_firm_prod')
 | |
|         # for t, dct in self.lst_dct_lst_disrupt_firm_prod:
 | |
|         #     for firm, a_lst_product in dct.items():
 | |
|         #         for product in a_lst_product:
 | |
|         #             print(t, firm.name, product.code)
 | |
| 
 | |
|         qry_result = db_session.query(Result).filter_by(s_id=self.sample.id)
 | |
|         if qry_result.count() == 0:
 | |
|             lst_result_info = []
 | |
|             for t, dct in self.lst_dct_lst_disrupt_firm_prod:
 | |
|                 for firm, a_lst_product in dct.items():
 | |
|                     for product in a_lst_product:
 | |
|                         db_r = Result(s_id=self.sample.id,
 | |
|                                       id_firm=firm.code,
 | |
|                                       id_product=product.code,
 | |
|                                       ts=t,
 | |
|                                       is_disrupted=True)
 | |
|                         lst_result_info.append(db_r)
 | |
|             db_session.bulk_save_objects(lst_result_info)
 | |
|             db_session.commit()
 | |
|             for t, dct in self.lst_dct_lst_remove_firm_prod:
 | |
|                 for firm, a_lst_product in dct.items():
 | |
|                     for product in a_lst_product:
 | |
|                         # only firm disrupted can be removed theoretically
 | |
|                         qry_f_p = db_session.query(Result).filter(
 | |
|                             Result.s_id == self.sample.id,
 | |
|                             Result.id_firm == firm.code,
 | |
|                             Result.id_product == product.code)
 | |
|                         if qry_f_p.count() == 1:
 | |
|                             qry_f_p.update({"is_removed": True})
 | |
|                             db_session.commit()
 | |
|         self.sample.is_done_flag = 1
 | |
|         self.sample.computer_name = platform.node()
 | |
|         self.sample.stop_t = self.int_stop_times
 | |
|         db_session.commit()
 | |
| 
 | |
|     def draw_network(self):
 | |
|         import matplotlib.pyplot as plt
 | |
|         plt.rcParams['font.sans-serif'] = 'SimHei'
 | |
|         pos = nx.nx_agraph.graphviz_layout(self.firm_network.graph,
 | |
|                                            prog="twopi",
 | |
|                                            args="")
 | |
|         node_label = nx.get_node_attributes(self.firm_network.graph, 'Name')
 | |
|         node_degree = dict(self.firm_network.graph.out_degree())
 | |
|         node_label = {
 | |
|             key: f"{key} {node_label[key]} {node_degree[key]}"
 | |
|             for key in node_label.keys()
 | |
|         }
 | |
|         node_size = list(
 | |
|             nx.get_node_attributes(self.firm_network.graph,
 | |
|                                    'Revenue_Log').values())
 | |
|         node_size = list(map(lambda x: x**2, node_size))
 | |
|         edge_label = nx.get_edge_attributes(self.firm_network.graph, "Product")
 | |
|         # multi(di)graphs, the keys are 3-tuples
 | |
|         edge_label = {(n1, n2): label
 | |
|                       for (n1, n2, _), label in edge_label.items()}
 | |
|         plt.figure(figsize=(12, 12), dpi=300)
 | |
|         nx.draw(self.firm_network.graph,
 | |
|                 pos,
 | |
|                 node_size=node_size,
 | |
|                 labels=node_label,
 | |
|                 font_size=6)
 | |
|         nx.draw_networkx_edge_labels(self.firm_network.graph,
 | |
|                                      pos,
 | |
|                                      edge_label,
 | |
|                                      font_size=4)
 | |
|         plt.savefig("network.png")
 |