import agentpy as ap
import pandas as pd
import networkx as nx
from firm import FirmAgent
from product import ProductAgent
from orm import db_session, Result
import platform
import json


class Model(ap.Model):
    """Agent-based supply-chain disruption model built on agentpy.

    ``setup`` builds three graphs from the parameters in ``self.p`` and the
    firm table ``Firm_amended.csv``:

    * ``product_network`` -- agentpy network over the BOM graph ``p.g_bom``,
      holding one ``ProductAgent`` per product;
    * ``firm_network`` -- agentpy network holding one ``FirmAgent`` per firm,
      with supplier -> customer edges labelled by the traded ``'Product'``;
    * ``firm_prod_network`` -- plain networkx graph whose nodes are
      (firm, product) pairs, linking a supplier's product node to the
      customer's product node that uses it.

    It then applies the initial disruptions and the proactive strategy.
    ``step`` calls ``remove_edge_to_cus_disrupt_cus_up_prod`` for products
    disrupted in the previous step and lets firms seek alternative supply,
    ``update`` shrinks/removes disrupted firms and decides when to stop, and
    ``end`` stores the non-normal status histories in the database.
    """

    def setup(self):
        """Build the networks and agents, then apply initial disruptions.

        Reads parameters from ``self.p`` and the firm table from
        ``Firm_amended.csv`` (relative to the working directory). The firm
        graph is serialized to ``self.sample.g_firm`` (JSON adjacency data)
        before the proactive strategy adds any extra edges.
        """
        # self para
        self.sample = self.p.sample
        self.int_stop_ts = 0
        self.int_n_iter = int(self.p.n_iter)
        self.product_network = None  # agentpy network
        self.firm_network = None  # agentpy network
        self.firm_prod_network = None  # networkx
        self.dct_lst_init_disrupt_firm_prod = \
            self.p.dct_lst_init_disrupt_firm_prod

        # external variable
        self.int_n_max_trial = int(self.p.n_max_trial)
        self.is_prf_size = bool(self.p.prf_size)
        self.proactive_ratio = float(self.p.proactive_ratio)
        self.remove_t = int(self.p.remove_t)
        self.int_netw_prf_n = int(self.p.netw_prf_n)

        # init graph bom
        G_bom = nx.adjacency_graph(json.loads(self.p.g_bom))
        self.product_network = ap.Network(self, G_bom)

        # read the firm table once; it feeds both the firm graph and the
        # firm-product graph
        Firm = pd.read_csv("Firm_amended.csv")
        Firm['Code'] = Firm['Code'].astype('string')
        Firm.fillna(0, inplace=True)

        G_Firm = self._init_graph_firm(Firm)
        G_FirmProd = self._init_graph_firm_prod(Firm)
        self._add_edges_by_bom(Firm, G_bom, G_Firm, G_FirmProd)
        self._add_edges_for_unconnected(Firm, G_bom, G_Firm, G_FirmProd)

        self.sample.g_firm = json.dumps(nx.adjacency_data(G_Firm))
        self.firm_network = ap.Network(self, G_Firm)
        self.firm_prod_network = G_FirmProd

        self._init_agents()
        self._init_disruption()
        self._apply_proactive_strategy()
        # To inspect the resulting network, call self.draw_network() here.

    @staticmethod
    def _init_graph_firm(Firm):
        """Return a MultiDiGraph with one node per firm code (no edges yet).

        Node attributes: ``Name``, ``Type_Region``, ``Revenue_Log`` and
        ``Product_Code`` -- the list of product columns (from ``'1'``
        onward) whose value is 1 for that firm.
        """
        Firm_attr = Firm.loc[:, ["Code", "Name", "Type_Region", "Revenue_Log"]]
        firm_product = [row[row == 1].index.to_list()
                        for _, row in Firm.loc[:, '1':].iterrows()]
        Firm_attr.loc[:, 'Product_Code'] = firm_product
        Firm_attr.set_index('Code', inplace=True)

        G_Firm = nx.MultiDiGraph()
        G_Firm.add_nodes_from(Firm["Code"])
        nx.set_node_attributes(
            G_Firm,
            {code: Firm_attr.loc[code].to_dict() for code in G_Firm.nodes})
        return G_Firm

    def _init_graph_firm_prod(self, Firm):
        """Return a MultiDiGraph with one node per (firm, product) pair.

        Nodes are consecutive integers carrying ``Firm_Code`` and
        ``Product_Code`` attributes. Also fills ``self._dct_firm_prod_node``
        ((Firm_Code, Product_Code) -> node id) for ``_get_firm_prod_node``.
        """
        # Index by the firm Code so that 'Firm_Code' holds the real firm code;
        # the CSV row position only matches it when codes are exactly 0..n-1
        # in file order.
        sr_stack = Firm.set_index('Code').loc[:, '1':].stack()
        firm_prod = sr_stack[sr_stack == 1].index.to_frame(
            index=False, name=['Firm_Code', 'Product_Code'])
        firm_prod['Firm_Code'] = firm_prod['Firm_Code'].astype('string')

        G_FirmProd = nx.MultiDiGraph()
        G_FirmProd.add_nodes_from(firm_prod.index)
        firm_prod_labels_dict = firm_prod.to_dict('index')
        nx.set_node_attributes(G_FirmProd, firm_prod_labels_dict)

        # O(1) lookup instead of scanning every node for each query; keep the
        # first node if a pair were ever duplicated.
        self._dct_firm_prod_node = {}
        for node_id, attr in firm_prod_labels_dict.items():
            self._dct_firm_prod_node.setdefault(
                (attr['Firm_Code'], attr['Product_Code']), node_id)
        return G_FirmProd

    def _get_firm_prod_node(self, firm_code, prod_code):
        """Return the firm-product graph node id for (firm_code, prod_code).

        Raises KeyError if the firm does not produce that product.
        """
        return self._dct_firm_prod_node[(firm_code, prod_code)]

    def _choose_firms(self, G_Firm, lst_firm):
        """Randomly pick up to ``int_netw_prf_n`` distinct firms of lst_firm.

        With ``is_prf_size`` set, firms are weighted by their ``Revenue_Log``
        node attribute in ``G_Firm``. If those weights sum to zero (which
        includes an empty ``lst_firm``), the pick falls back to uniform
        instead of failing on a zero division / invalid probability vector.
        Returns the array produced by ``self.nprandom.choice``.
        """
        n_choose = min(self.int_netw_prf_n, len(lst_firm))
        if self.is_prf_size:
            lst_size = [G_Firm.nodes[firm]['Revenue_Log'] for firm in lst_firm]
            total_size = sum(lst_size)
            if total_size != 0:
                return self.nprandom.choice(
                    lst_firm, n_choose, replace=False,
                    p=[size / total_size for size in lst_size])
        return self.nprandom.choice(lst_firm, n_choose, replace=False)

    def _add_edges_by_bom(self, Firm, G_bom, G_Firm, G_FirmProd):
        """Connect every firm to suppliers of the components it needs.

        For each firm and each BOM predecessor (component) of its products,
        suppliers are picked among the firms producing that component (see
        ``_choose_firms``) and a supplier -> firm edge labelled with the
        component is added to ``G_Firm``. In ``G_FirmProd`` each supplier's
        component node is linked to every product of the firm that is a BOM
        successor of the component.
        """
        for node in nx.nodes(G_Firm):
            set_node_prod_code = set(G_Firm.nodes[node]['Product_Code'])
            # sorted so the generated graph is reproducible under a fixed seed
            lst_pred_product_code = sorted({
                pred_code
                for product_code in G_Firm.nodes[node]['Product_Code']
                for pred_code in G_bom.predecessors(product_code)})
            for pred_product_code in lst_pred_product_code:
                # firms producing this component; select multiple suppliers
                # (multi-sourcing)
                lst_pred_firm = \
                    Firm['Code'][Firm[pred_product_code] == 1].to_list()
                lst_choose_firm = self._choose_firms(G_Firm, lst_pred_firm)
                G_Firm.add_edges_from([
                    (pred_firm, node, {'Product': pred_product_code})
                    for pred_firm in lst_choose_firm])

                # graph firm prod; sorted because set iteration order of str
                # changes between processes (hash randomization), which would
                # change edge order and hence later DFS results
                lst_use_pred_prod_code = sorted(
                    set_node_prod_code
                    & set(G_bom.successors(pred_product_code)))
                for pred_firm in lst_choose_firm:
                    pred_node = self._get_firm_prod_node(pred_firm,
                                                         pred_product_code)
                    for use_pred_prod_code in lst_use_pred_prod_code:
                        current_node = self._get_firm_prod_node(
                            node, use_pred_prod_code)
                        G_FirmProd.add_edge(pred_node, current_node)

    def _add_edges_for_unconnected(self, Firm, G_bom, G_Firm, G_FirmProd):
        """Give every firm that is still isolated some customers.

        A firm with degree 0 has no possible suppliers. For each of its
        products and each BOM successor of that product, customers are
        picked among the firms producing the successor (see
        ``_choose_firms``); edges firm -> customer are added to ``G_Firm``
        and (firm, product) -> (customer, successor) to ``G_FirmProd``.
        Degrees are checked as the loop goes, so a firm that became a
        customer earlier in this loop is skipped.
        """
        for node in nx.nodes(G_Firm):
            if G_Firm.degree(node) != 0:
                continue
            for product_code in G_Firm.nodes[node]['Product_Code']:
                current_node = self._get_firm_prod_node(node, product_code)
                # unlike the supplier logic (one supplier set per component),
                # a customer set is drawn separately for each successor
                for succ_product_code in list(G_bom.successors(product_code)):
                    lst_succ_firm = \
                        Firm['Code'][Firm[succ_product_code] == 1].to_list()
                    # select multiple customer (multi-selling)
                    lst_choose_firm = self._choose_firms(G_Firm, lst_succ_firm)
                    G_Firm.add_edges_from([
                        (node, succ_firm, {'Product': product_code})
                        for succ_firm in lst_choose_firm])

                    # graph firm prod
                    for succ_firm in lst_choose_firm:
                        succ_node = self._get_firm_prod_node(
                            succ_firm, succ_product_code)
                        G_FirmProd.add_edge(current_node, succ_node)

    def _init_agents(self):
        """Create one ProductAgent per BOM node and one FirmAgent per firm."""
        # init product
        for ag_node, attr in self.product_network.graph.nodes(data=True):
            product = ProductAgent(self, code=ag_node.label, name=attr['Name'])
            self.product_network.add_agents([product], [ag_node])
        self.a_lst_total_products = ap.AgentList(self,
                                                 self.product_network.agents)

        # init firm
        for ag_node, attr in self.firm_network.graph.nodes(data=True):
            firm_agent = FirmAgent(
                self,
                code=ag_node.label,
                name=attr['Name'],
                type_region=attr['Type_Region'],
                revenue_log=attr['Revenue_Log'],
                a_lst_product=self.a_lst_total_products.select([
                    code in attr['Product_Code']
                    for code in self.a_lst_total_products.code
                ]))
            self.firm_network.add_agents([firm_agent], [ag_node])
        self.a_lst_total_firms = ap.AgentList(self, self.firm_network.agents)

    def _init_disruption(self):
        """Convert the initial-disruption mapping to agents and mark it 'D'.

        ``self.dct_lst_init_disrupt_firm_prod`` arrives keyed by firm code
        with lists of product codes, and is replaced by a mapping
        FirmAgent -> AgentList of ProductAgent. Each listed (firm, product)
        gets a ``('D', self.t)`` status entry.
        """
        dct_code_firm = {firm.code: firm for firm in self.a_lst_total_firms}
        t_dct = {}
        for firm_code, lst_product in \
                self.dct_lst_init_disrupt_firm_prod.items():
            t_dct[dct_code_firm[firm_code]] = self.a_lst_total_products.select([
                code in lst_product for code in self.a_lst_total_products.code
            ])
        self.dct_lst_init_disrupt_firm_prod = t_dct

        # set the initial firm product that are disrupted
        print('\n', '=' * 20, 'step', self.t, '=' * 20)
        for firm, a_lst_product in self.dct_lst_init_disrupt_firm_prod.items():
            for product in a_lst_product:
                assert product in firm.dct_prod_up_prod_stat.keys(), \
                    f"product {product.code} not in firm {firm.code}"
                firm.dct_prod_up_prod_stat[
                    product]['status'].append(('D', self.t))
                print(f"initial disruption {firm.name} {product.code}")

    def _get_affected_firm_prod(self, init_node):
        """Return sorted (Firm_Code, Product_Code) pairs downstream of a node.

        Collects the DFS-tree children of every key of
        ``nx.dfs_successors(firm_prod_network, init_node)`` except the first
        key, which is ``init_node`` itself (the first DFS edge leaves the
        source). NOTE(review): intended as "at least 2 hops away", but a
        direct successor of ``init_node`` that DFS first reaches through a
        longer path is still included. The result is sorted so the later
        random draw is reproducible (set order of str tuples is not).
        """
        dct_affected = nx.dfs_successors(self.firm_prod_network, init_node)
        set_affected = set()
        for i, lst_succ in enumerate(dct_affected.values()):
            # at least 2 hops away: skip the DFS children of init_node
            if i == 0:
                continue
            for succ in lst_succ:
                succ_node = self.firm_prod_network.nodes[succ]
                set_affected.add((succ_node['Firm_Code'],
                                  succ_node['Product_Code']))
        return sorted(set_affected)

    def _apply_proactive_strategy(self):
        """Pre-emptively add backup suppliers downstream of initial disruptions.

        For each initially disrupted (firm, product), a ``proactive_ratio``
        share of the affected (firm, product) nodes is drawn at random. For
        each drawn node, its direct suppliers are ranked by
        ``drs = betweenness / (n_candidates * supplier_size)`` (highest
        first). For the first supplier whose product has another normal
        producer, that producer is added as an extra supplier in
        ``firm_network``, and its ``dct_prod_capacity`` for the product is
        decreased by one.
        """
        dct_code_firm = {firm.code: firm for firm in self.a_lst_total_firms}
        dct_code_prod = {prod.code: prod
                         for prod in self.a_lst_total_products}
        G = self.firm_prod_network
        for firm, a_lst_product in self.dct_lst_init_disrupt_firm_prod.items():
            for product in a_lst_product:
                init_node = self._get_firm_prod_node(firm.code, product.code)
                lst_affected = self._get_affected_firm_prod(init_node)
                lst_firm_proactive = \
                    [lst_affected[i] for i in
                     self.nprandom.choice(range(len(lst_affected)),
                                          round(len(lst_affected) *
                                                self.proactive_ratio),
                                          replace=False)]

                for firm_code, prod_code in lst_firm_proactive:
                    pro_firm_prod_code = self._get_firm_prod_node(firm_code,
                                                                  prod_code)
                    pro_firm = dct_code_firm[firm_code]
                    lst_shortest_path = list(nx.all_shortest_paths(
                        G, source=init_node, target=pro_firm_prod_code))

                    # disruption risk score of every direct supplier
                    dct_drs = {}
                    for di_supp_code in G.predecessors(pro_firm_prod_code):
                        di_supp_node = G.nodes[di_supp_code]
                        di_supp_prod = \
                            dct_code_prod[di_supp_node['Product_Code']]
                        di_supp_firm = dct_code_firm[di_supp_node['Firm_Code']]
                        lst_cand = self.a_lst_total_firms.select([
                            cand.is_prod_in_current_normal(di_supp_prod)
                            for cand in self.a_lst_total_firms
                        ])
                        n2n_betweenness = sum(
                            di_supp_code in path for path in lst_shortest_path
                        ) / len(lst_shortest_path)
                        denominator = \
                            len(lst_cand) * di_supp_firm.size_stat[-1][0]
                        if denominator:
                            drs = n2n_betweenness / denominator
                        else:
                            # No normal producer or zero size (Revenue_Log is
                            # fillna(0)'d): treat as maximal risk rather than
                            # raising ZeroDivisionError, unless the supplier
                            # lies on no shortest path at all.
                            drs = float('inf') if n2n_betweenness else 0.0
                        dct_drs[di_supp_code] = drs

                    # highest risk first; sorted() is stable, so ties keep
                    # predecessor order
                    for di_supp_code in sorted(dct_drs, key=dct_drs.get,
                                               reverse=True):
                        di_supp_node = G.nodes[di_supp_code]
                        di_supp_prod = \
                            dct_code_prod[di_supp_node['Product_Code']]
                        # find a different firm that can produce the product
                        lst_cand = self.a_lst_total_firms.select([
                            cand.is_prod_in_current_normal(di_supp_prod)
                            and cand.code != di_supp_node['Firm_Code']
                            for cand in self.a_lst_total_firms
                        ])
                        if len(lst_cand) > 0:
                            select_cand = self.nprandom.choice(lst_cand)
                            self.firm_network.graph.add_edges_from([
                                (self.firm_network.positions[select_cand],
                                 self.firm_network.positions[pro_firm], {
                                    'Product': di_supp_prod.code
                                })
                            ])
                            print(f"proactive add {select_cand.name} to "
                                  f"{pro_firm.name} "
                                  f"for {di_supp_node['Firm_Code']} "
                                  f"{di_supp_node['Product_Code']}")
                            # change capacity
                            select_cand.dct_prod_capacity[di_supp_prod] -= 1
                            break

    def update(self):
        """Shrink/remove disrupted firms and decide whether to stop.

        Calls ``clean_before_time_step`` on all firms. For every product
        whose latest status is 'D', the firm's size is reduced by
        (initial size / number of products / ``remove_t``), and the product
        is marked 'R' once its last ``remove_t`` statuses are all 'D'. After
        t = 0 the run stops, recording ``int_stop_ts``, as soon as no
        (firm, product) other than the initial disruptions is still 'D'.
        It also stops when t reaches ``int_n_iter``.
        """
        self.a_lst_total_firms.clean_before_time_step()

        # reduce the size of disrupted firm
        for firm in self.a_lst_total_firms:
            for prod in firm.dct_prod_up_prod_stat.keys():
                status, ts = firm.dct_prod_up_prod_stat[prod]['status'][-1]
                if status == 'D':
                    size = firm.size_stat[-1][0] - \
                        firm.size_stat[0][0] \
                        / len(firm.dct_prod_up_prod_stat.keys()) \
                        / self.remove_t
                    firm.size_stat.append((size, self.t))
                    print(f'in ts {self.t}, reduce {firm.name} size '
                          f'to {firm.size_stat[-1][0]} due to {prod.code}')
                    lst_is_disrupt = \
                        [stat == 'D' for stat, _ in
                         firm.dct_prod_up_prod_stat[prod]['status']
                         [-1 * self.remove_t:]]
                    if all(lst_is_disrupt):
                        # turn disrupted firm into removed firm
                        # when last self.remove_t times status is all disrupted
                        firm.dct_prod_up_prod_stat[
                            prod]['status'].append(('R', self.t))

        # stop simulation if no firm is still disrupted except initial removal
        if self.t > 0:
            dct_init = self.dct_lst_init_disrupt_firm_prod
            blocker = next(
                ((firm, prod)
                 for firm in self.a_lst_total_firms
                 for prod in firm.dct_prod_up_prod_stat.keys()
                 if firm.dct_prod_up_prod_stat[prod]['status'][-1][0] == 'D'
                 and not (firm in dct_init and prod in dct_init[firm])),
                None)
            if blocker is None:
                self.int_stop_ts = self.t
                self.stop()
            else:
                print("not stop because", blocker[0].name, blocker[1].code)

        if self.t == self.int_n_iter:
            self.stop()

    def step(self):
        """Run one time step.

        First calls ``remove_edge_to_cus_disrupt_cus_up_prod`` for every
        product disrupted in the previous step. Then, for up to
        ``int_n_max_trial`` trials, firms (in shuffled order) seek alternative
        supply for the missing supplies of their disrupted products and
        handle the resulting requests. The loop ends early when no firm has
        anything left to seek.
        """
        print('\n', '=' * 20, 'step', self.t, '=' * 20)

        # remove edge to customer and disrupt customer up product
        for firm in self.a_lst_total_firms:
            for prod in firm.dct_prod_up_prod_stat.keys():
                status, ts = firm.dct_prod_up_prod_stat[prod]['status'][-1]
                if status == 'D' and ts == self.t - 1:
                    firm.remove_edge_to_cus_disrupt_cus_up_prod(prod)

        for n_trial in range(self.int_n_max_trial):
            print('=' * 10, 'trial', n_trial, '=' * 10)
            # seek_alt_supply
            # shuffle self.a_lst_total_firms
            self.a_lst_total_firms = self.a_lst_total_firms.shuffle()
            is_stop_trial = True
            for firm in self.a_lst_total_firms:
                lst_seek_prod = []
                for prod in firm.dct_prod_up_prod_stat.keys():
                    status = firm.dct_prod_up_prod_stat[prod]['status'][-1][0]
                    if status == 'D':
                        dct_supply = firm.dct_prod_up_prod_stat[prod]['supply']
                        for supply in dct_supply.keys():
                            if not dct_supply[supply]:
                                lst_seek_prod.append(supply)
                # common supply only seek once; dict.fromkeys dedupes while
                # keeping first-seen order. A set of agents (presumably hashed
                # by identity, i.e. memory address) would iterate in a
                # run-dependent order and break reproducibility.
                lst_seek_prod = list(dict.fromkeys(lst_seek_prod))
                if lst_seek_prod:
                    is_stop_trial = False
                for supply in lst_seek_prod:
                    firm.seek_alt_supply(supply)
            if is_stop_trial:
                break

            # handle_request
            # shuffle self.a_lst_total_firms
            self.a_lst_total_firms = self.a_lst_total_firms.shuffle()
            for firm in self.a_lst_total_firms:
                if len(firm.dct_request_prod_from_firm) > 0:
                    firm.handle_request()

            # reset dct_request_prod_from_firm
            self.a_lst_total_firms.clean_before_trial()
            # do not use:
            # self.a_lst_total_firms.dct_request_prod_from_firm = {}
            # presumably because AgentList attribute assignment would give
            # every firm the *same* dict object -- TODO confirm

    def end(self):
        """Persist the results of this run.

        If no ``Result`` rows exist yet for ``self.sample.id``, writes one
        row per status entry of every (firm, product) whose status history is
        not all 'N'. Then marks the sample done, records the host name and
        the stop step, and commits.
        """
        print('/' * 20, 'output', '/' * 20)

        qry_result = db_session.query(Result).filter_by(s_id=self.sample.id)
        if qry_result.count() == 0:
            lst_result_info = []
            for firm in self.a_lst_total_firms:
                for prod, dct_status_supply in \
                        firm.dct_prod_up_prod_stat.items():
                    if not all(stat == 'N' for stat, _
                               in dct_status_supply['status']):
                        print(f"{firm.name} {prod.code}:")
                        print(dct_status_supply['status'])
                        for status, ts in dct_status_supply['status']:
                            db_r = Result(s_id=self.sample.id,
                                          id_firm=firm.code,
                                          id_product=prod.code,
                                          ts=ts,
                                          status=status)
                            lst_result_info.append(db_r)
            db_session.bulk_save_objects(lst_result_info)
            db_session.commit()
        self.sample.is_done_flag = 1
        self.sample.computer_name = platform.node()
        self.sample.stop_t = self.int_stop_ts
        db_session.commit()

    def draw_network(self):
        """Save a radial ``twopi`` layout of ``firm_network`` to network.png.

        Node labels are "code name out-degree", node size is Revenue_Log
        squared, and edges are labelled with their ``Product``. For parallel
        edges only one label per node pair survives. Uses networkx's graphviz
        layout (requires pygraphviz) and the SimHei font for the names.
        """
        import matplotlib.pyplot as plt
        plt.rcParams['font.sans-serif'] = 'SimHei'
        pos = nx.nx_agraph.graphviz_layout(self.firm_network.graph,
                                           prog="twopi",
                                           args="")
        node_label = nx.get_node_attributes(self.firm_network.graph, 'Name')
        node_degree = dict(self.firm_network.graph.out_degree())
        node_label = {
            key: f"{key} {node_label[key]} {node_degree[key]}"
            for key in node_label.keys()
        }
        node_size = list(
            nx.get_node_attributes(self.firm_network.graph,
                                   'Revenue_Log').values())
        node_size = [x**2 for x in node_size]
        edge_label = nx.get_edge_attributes(self.firm_network.graph, "Product")
        # multi(di)graphs, the keys are 3-tuples
        edge_label = {(n1, n2): label
                      for (n1, n2, _), label in edge_label.items()}
        plt.figure(figsize=(12, 12), dpi=300)
        nx.draw(self.firm_network.graph,
                pos,
                node_size=node_size,
                labels=node_label,
                font_size=6)
        nx.draw_networkx_edge_labels(self.firm_network.graph,
                                     pos,
                                     edge_label,
                                     font_size=4)
        plt.savefig("network.png")