import agentpy as ap
import pandas as pd
import numpy as np
import random
import networkx as nx
from firm import FirmAgent
from product import ProductAgent

# Default parameters for a single run of the model.
sample = 0
seed = 0
n_iter = 10
n_max_trial = 5

# Initial disruption: firm code -> list of product codes removed at t = 0.
# Smaller alternative scenarios: {133: ['1.4.4.1'], 2: ['1.1.3']}, or the
# same plus 135: ['1.3.2.1'].
dct_list_init_remove_firm_prod = {
    140: ['1.4.5.1'],
    135: ['1.3.2.1'],
    133: ['1.4.4.1'],
    2: ['1.1.3'],
}

# Parameter dictionary handed to Model (read there via self.p).
dct_sample_para = dict(
    sample=sample,
    seed=seed,
    n_iter=n_iter,
    n_max_trial=n_max_trial,
    dct_list_init_remove_firm_prod=dct_list_init_remove_firm_prod,
)


class Model(ap.Model):
    """Agent-based model of product-removal cascades in a firm network.

    Firms (``FirmAgent``) are connected by supplier edges derived from a
    product bill-of-materials graph whose nodes are ``ProductAgent`` s.
    Each step, every (firm, product) pair in ``dct_list_remove_firm_prod``
    is removed; firms with removed upstream products then get up to
    ``n_max_trial`` rounds of ``seek_alt_supply`` / ``handle_request``,
    and firms still left with removed upstream products have their
    dependent products removed, which seeds the next step.

    Parameters read from ``self.p``: ``sample``, ``seed``, ``n_iter``,
    ``n_max_trial`` and ``dct_list_init_remove_firm_prod`` (firm code ->
    list of product codes removed at the start).
    """

    def setup(self):
        """Build the product/firm networks and create all agents.

        Reads ``BomNodes.csv``, ``BomCateNet.csv`` and ``Firm_amended.csv``
        from the working directory and writes ``network.png``.

        Raises:
            ValueError: if an initially removed firm code is unknown, or a
                listed product is not made by that firm.
        """
        self.sample = self.p.sample
        # Generators seeded from p.seed. nprandom is consumed first by the
        # supplier draw and then by the capacity draw, in that order.
        self.random = random.Random(self.p.seed)
        self.nprandom = np.random.default_rng(self.p.seed)
        self.int_n_iter = int(self.p.n_iter)
        self.int_n_max_trial = int(self.p.n_max_trial)
        # Holds codes for now; converted to agents in _init_removed_firm_prod.
        self.dct_list_remove_firm_prod = self.p.dct_list_init_remove_firm_prod

        G_bom = self._build_bom_graph()
        Firm, G_Firm = self._build_firm_graph()
        self._add_supply_edges(G_Firm, G_bom, Firm)

        self.firm_network = ap.Network(self, G_Firm)
        self.product_network = ap.Network(self, G_bom)

        self._init_products()
        self._init_firms()
        self._init_removed_firm_prod()

        self.draw_network()

    @staticmethod
    def _build_bom_graph():
        """Return the product graph with node attributes from BomNodes.csv."""
        BomNodes = pd.read_csv('BomNodes.csv', index_col=0)
        BomNodes.set_index('Code', inplace=True)
        BomCateNet = pd.read_csv('BomCateNet.csv', index_col=0)
        BomCateNet.fillna(0, inplace=True)

        # Built from the transposed adjacency; G_bom.successors(p) is used in
        # _add_supply_edges as the downstream products of p (direction
        # assumed from the CSV layout -- TODO confirm).
        G_bom = nx.from_pandas_adjacency(BomCateNet.T,
                                         create_using=nx.MultiDiGraph())
        bom_labels_dict = {
            code: BomNodes.loc[code].to_dict()
            for code in G_bom.nodes
        }
        nx.set_node_attributes(G_bom, bom_labels_dict)
        return G_bom

    @staticmethod
    def _build_firm_graph():
        """Return the firm table and an edgeless firm graph keyed by ``Code``.

        Each node carries ``Code``, ``Name``, ``Type_Region``,
        ``Revenue_Log`` and ``Product_Code`` (list of product column names
        whose value is 1 for that firm).
        """
        Firm = pd.read_csv("Firm_amended.csv")
        Firm.fillna(0, inplace=True)
        Firm_attr = Firm.loc[:, ["Code", "Name", "Type_Region",
                                 "Revenue_Log"]].copy()
        # Columns from '1' onward are 0/1 product indicators.
        firm_product = [
            row[row == 1].index.to_list()
            for _, row in Firm.loc[:, '1':].iterrows()
        ]
        Firm_attr['Product_Code'] = firm_product
        # Index by Code (keeping the column, which FirmAgent needs) so that
        # lookups use the same ids as the graph nodes. A bare
        # `set_index('Code')` returns a new frame and leaves Firm_attr on its
        # positional index.
        Firm_attr.set_index('Code', drop=False, inplace=True)

        G_Firm = nx.MultiDiGraph()
        G_Firm.add_nodes_from(Firm["Code"])
        firm_labels_dict = {
            code: Firm_attr.loc[code].to_dict()
            for code in G_Firm.nodes
        }
        nx.set_node_attributes(G_Firm, firm_labels_dict)
        return Firm, G_Firm

    def _add_supply_edges(self, G_Firm, G_bom, Firm):
        """Add one supplier edge per (firm, product, successor product).

        For every product a firm makes and every BOM successor of that
        product, one firm making the successor is drawn with probability
        proportional to its ``Revenue_Log``, and an edge
        ``firm -> drawn firm`` with attribute ``Product`` is added.
        """
        for node in nx.nodes(G_Firm):
            for product_code in G_Firm.nodes[node]['Product_Code']:
                for succ_product_code in list(G_bom.successors(product_code)):
                    # Use the Code column (the node ids), not the positional
                    # DataFrame index.
                    list_succ_firms = Firm.loc[Firm[succ_product_code] == 1,
                                               'Code'].to_list()
                    if not list_succ_firms:
                        # No firm makes this product: nothing to connect,
                        # and Generator.choice would raise on an empty list.
                        continue
                    list_revenue_log = [
                        G_Firm.nodes[succ_firm]['Revenue_Log']
                        for succ_firm in list_succ_firms
                    ]
                    total_revenue_log = sum(list_revenue_log)
                    # Weight by Revenue_Log; fall back to a uniform draw
                    # (p=None) when all weights are zero.
                    list_prob = [
                        size / total_revenue_log for size in list_revenue_log
                    ] if total_revenue_log > 0 else None
                    succ_firm = self.nprandom.choice(list_succ_firms,
                                                     p=list_prob)
                    G_Firm.add_edge(node, succ_firm, Product=product_code)

    def _init_products(self):
        """Create a ProductAgent on every node of the product network."""
        for ag_node, attr in self.product_network.graph.nodes(data=True):
            product_agent = ProductAgent(self,
                                         code=ag_node.label,
                                         name=attr['Name'])
            self.product_network.add_agents([product_agent], [ag_node])
        self.a_list_total_products = ap.AgentList(self,
                                                  self.product_network.agents)

    def _init_firms(self):
        """Create a FirmAgent on every firm-network node and draw capacities."""
        for ag_node, attr in self.firm_network.graph.nodes(data=True):
            a_list_product = self.a_list_total_products.select([
                code in attr['Product_Code']
                for code in self.a_list_total_products.code
            ])
            firm_agent = FirmAgent(self,
                                   code=attr['Code'],
                                   name=attr['Name'],
                                   type_region=attr['Type_Region'],
                                   revenue_log=attr['Revenue_Log'],
                                   a_list_product=a_list_product)
            # Capacity per product: integer drawn from
            # [revenue_log / 5, revenue_log / 5 + 2) (high is exclusive).
            # NOTE(review): the bounds are floats; confirm numpy's handling.
            for product in firm_agent.a_list_product:
                firm_agent.dct_prod_capacity[product] = self.nprandom.integers(
                    firm_agent.revenue_log / 5, firm_agent.revenue_log / 5 + 2)
            self.firm_network.add_agents([firm_agent], [ag_node])
        self.a_list_total_firms = ap.AgentList(self, self.firm_network.agents)

    def _init_removed_firm_prod(self):
        """Convert dct_list_remove_firm_prod from codes to agents.

        Also records each initially removed product in the firm's
        ``a_list_product_removed``.

        Raises:
            ValueError: for an unknown firm code or a product the firm does
                not make.
        """
        t_dct = {}
        for firm_code, list_product in self.dct_list_remove_firm_prod.items():
            a_list_firm = self.a_list_total_firms.select(
                self.a_list_total_firms.code == firm_code)
            if len(a_list_firm) == 0:
                raise ValueError(f"firm {firm_code} not found")
            t_dct[a_list_firm[0]] = self.a_list_total_products.select([
                code in list_product
                for code in self.a_list_total_products.code
            ])
        self.dct_list_remove_firm_prod = t_dct

        for firm, a_list_product in self.dct_list_remove_firm_prod.items():
            for product in a_list_product:
                if product not in firm.a_list_product:
                    raise ValueError(
                        f"product {product.code} not in firm {firm.code}")
                firm.a_list_product_removed.append(product)

    def update(self):
        """Reset per-step firm state; stop at n_iter or when nothing is removed."""
        self.a_list_total_firms.clean_before_time_step()
        if self.t >= self.int_n_iter or not self.dct_list_remove_firm_prod:
            self.stop()

    def _print_dct_list_remove_firm_prod(self):
        """Print the pending removals as {firm name: product codes}."""
        print(
            'dct_list_remove_firm_prod', {
                key.name: value.code
                for key, value in self.dct_list_remove_firm_prod.items()
            })

    def step(self):
        """Propagate one round of product removals.

        1. Each (firm, product) in ``dct_list_remove_firm_prod`` calls
           ``remove_edge_to_cus_remove_cus_up_prod``.
        2. For ``int_n_max_trial`` trials, firms with removed upstream
           products call ``seek_alt_supply`` and firms with pending requests
           call ``handle_request``, each pass in shuffled order.
        3. Firms still holding removed upstream products mark each affected
           product as disrupted and removed; these pairs become the next
           ``dct_list_remove_firm_prod``.
        """
        print('\n', '=' * 20, 'step', self.t, '=' * 20)
        self._print_dct_list_remove_firm_prod()

        # remove_edge_to_cus_and_cus_up_prod
        for firm, a_list_product in self.dct_list_remove_firm_prod.items():
            for product in a_list_product:
                firm.remove_edge_to_cus_remove_cus_up_prod(product)

        for n_trial in range(self.int_n_max_trial):
            print('=' * 10, 'trial', n_trial, '=' * 10)
            # seek_alt_supply, in shuffled firm order
            self.a_list_total_firms = self.a_list_total_firms.shuffle()
            for firm in self.a_list_total_firms:
                if len(firm.a_list_up_product_removed) > 0:
                    firm.seek_alt_supply()

            # handle_request, in a freshly shuffled order
            self.a_list_total_firms = self.a_list_total_firms.shuffle()
            for firm in self.a_list_total_firms:
                if len(firm.dct_request_prod_from_firm) > 0:
                    firm.handle_request()

            # Reset dct_request_prod_from_firm on every firm. Assigning
            # `self.a_list_total_firms.dct_request_prod_from_firm = {}` is
            # avoided: presumably AgentList broadcasts that one dict object
            # to every firm, so they would share it (NOTE(review): confirm).
            self.a_list_total_firms.clean_before_trial()

        # Based on a_list_up_product_removed, update
        # a_list_product_disrupted / a_list_product_removed and build the
        # next dct_list_remove_firm_prod.
        self.dct_list_remove_firm_prod = {}
        # Loop-invariant: revenue range over all firms.
        list_revenue_log = list(self.a_list_total_firms.revenue_log)
        min_revenue_log = min(list_revenue_log, default=0)
        max_revenue_log = max(list_revenue_log, default=0)
        for firm in self.a_list_total_firms:
            if len(firm.a_list_up_product_removed) == 0:
                continue
            print(firm.name, 'a_list_up_product_removed',
                  [product.code for product in firm.a_list_up_product_removed])
            for product in firm.a_list_product:
                n_up_product_removed = sum(
                    1 for up_product_removed in firm.a_list_up_product_removed
                    if product in up_product_removed.a_successors())
                if n_up_product_removed == 0:
                    continue

                if product not in firm.a_list_product_disrupted:
                    firm.a_list_product_disrupted.append(product)

                lost_percent = n_up_product_removed / len(
                    product.a_predecessors())
                std_size = (firm.revenue_log - min_revenue_log + 1) / (
                    max_revenue_log - min_revenue_log + 1)
                # NOTE: p_remove only feeds the Bernoulli draw below, which is
                # currently disabled (flag is fixed to 1).
                p_remove = 1 - std_size * (1 - lost_percent)
                # flag = self.nprandom.choice([1, 0],
                #                             p=[p_remove, 1 - p_remove])
                flag = 1
                if flag == 1:
                    # Avoid recording the same product twice.
                    if product not in firm.a_list_product_removed:
                        firm.a_list_product_removed.append(product)
                    if firm in self.dct_list_remove_firm_prod:
                        self.dct_list_remove_firm_prod[firm].append(product)
                    else:
                        self.dct_list_remove_firm_prod[firm] = ap.AgentList(
                            self, [product])

        self._print_dct_list_remove_firm_prod()

    def end(self):
        """No end-of-run processing."""
        pass

    def draw_network(self):
        """Draw the firm network and save it to ``network.png``.

        Node labels are "<Name> <out-degree>", node sizes are
        ``Revenue_Log ** 2`` and edges are labelled with their ``Product``
        (parallel edges between the same pair share one label).
        """
        import matplotlib.pyplot as plt
        # Presumably chosen so that Chinese firm names render.
        plt.rcParams['font.sans-serif'] = 'SimHei'
        graph = self.firm_network.graph
        pos = nx.nx_agraph.graphviz_layout(graph, prog="twopi", args="")
        node_name = nx.get_node_attributes(graph, 'Name')
        node_degree = dict(graph.out_degree())
        node_label = {
            key: f"{node_name[key]} {node_degree[key]}"
            for key in node_name
        }
        node_size = [
            x**2 for x in nx.get_node_attributes(graph, 'Revenue_Log').values()
        ]
        # Multi(di)graph edge keys are (u, v, key); collapse to (u, v).
        edge_label = {
            (n1, n2): label
            for (n1, n2, _), label in nx.get_edge_attributes(
                graph, "Product").items()
        }
        fig = plt.figure(figsize=(12, 12), dpi=300)
        nx.draw(graph,
                pos,
                node_size=node_size,
                labels=node_label,
                font_size=6)
        nx.draw_networkx_edge_labels(graph, pos, edge_label, font_size=4)
        plt.savefig("network.png")
        # Release the (large) figure instead of leaving it open in pyplot.
        plt.close(fig)


# Script entry: build the model from the parameter dict defined above and
# run it until Model.update() calls stop().
model = Model(parameters=dct_sample_para)
model.run()