import agentpy as ap
import math


class FirmAgent(ap.Agent):
    """A firm agent in a supply-chain network model.

    Each firm occupies a node of ``model.firm_network``. Edges are added from
    supplier to customer and carry the supplied product code under the
    ``'Product'`` attribute (see ``accept_request``). When a firm loses an
    upstream product it requests an alternative supplier
    (``seek_alt_supply``); suppliers pick among incoming requests
    (``handle_request``) and accept or reject them (``accept_request``).
    """

    def setup(self, code, name, type_region, revenue_log, a_lst_product):
        """Initialise firm attributes, bookkeeping containers and parameters.

        Args:
            code: firm identifier; compared against neighbours' ``code`` when
                checking whether two firms are already connected.
            name: display name of the firm.
            type_region: region category of the firm (only stored here).
            revenue_log: size measure used as a selection weight and in the
                acceptance probability.
            a_lst_product: products this firm produces.
        """
        self.firm_network = self.model.firm_network
        self.product_network = self.model.product_network

        self.code = code
        self.name = name
        self.type_region = type_region
        self.revenue_log = revenue_log
        self.a_lst_product = a_lst_product
        # Every product starts with capacity None; presumably the model
        # assigns numeric capacities before handle_request runs -- TODO
        # confirm (comparing None > 0 would raise TypeError).
        self.dct_prod_capacity = dict.fromkeys(self.a_lst_product)

        # upstream products this firm has lost and still needs to replace
        self.a_lst_up_product_removed = ap.AgentList(self.model, [])
        self.a_lst_product_disrupted = ap.AgentList(self.model, [])
        self.a_lst_product_removed = ap.AgentList(self.model, [])

        # product -> number of supply requests sent (reset every time step)
        self.dct_n_trial_up_prod_removed = {}
        # product -> candidate alternative suppliers for a lost product
        self.dct_cand_alt_supply_up_prod_removed = {}
        # product -> downstream firms requesting that product from this firm
        self.dct_request_prod_from_firm = {}

        # model parameters
        self.flt_crit_supplier = float(self.p.crit_supplier)
        self.flt_firm_req_prf_size = float(self.p.firm_req_prf_size)
        self.is_firm_req_prf_conn = bool(self.p.firm_req_prf_conn)
        self.flt_firm_acc_prf_size = float(self.p.firm_acc_prf_size)
        self.is_firm_acc_prf_conn = bool(self.p.firm_acc_prf_conn)
        self.flt_diff_new_conn = float(self.p.diff_new_conn)

    def remove_edge_to_cus_remove_cus_up_prod(self, remove_product):
        """Cut every outgoing edge that supplies ``remove_product``.

        For each affected customer, the customer is marked as having lost the
        product with probability ``exp(-crit_supplier * k)``, where ``k`` is
        the number of incoming edges carrying that product which the customer
        still has after the cut: the more remaining supply edges, the less
        likely the loss.

        Args:
            remove_product: product agent whose outgoing supply is removed.
        """
        graph = self.firm_network.graph
        # materialise first: edges are removed while looping
        lst_out_edge = list(
            graph.out_edges(self.firm_network.positions[self],
                            keys=True, data='Product'))
        for n1, n2, key, product_code in lst_out_edge:
            if product_code != remove_product.code:
                continue
            graph.remove_edge(n1, n2, key)

            customer = ap.AgentIter(self.model, n2).to_list()[0]
            n_remaining_in_edge = sum(
                1 for edge in graph.in_edges(n2, keys=True, data='Product')
                if edge[-1] == remove_product.code)
            prod_remove = math.exp(
                -1 * self.flt_crit_supplier * n_remaining_in_edge)
            if self.model.nprandom.choice([True, False],
                                          p=[prod_remove, 1 - prod_remove]):
                if remove_product not in customer.a_lst_up_product_removed:
                    customer.a_lst_up_product_removed.append(remove_product)
                    customer.dct_n_trial_up_prod_removed[remove_product] = 0

    def seek_alt_supply(self):
        """Send one supply request for each lost upstream product.

        On the first trial for a product the candidate suppliers are fixed:
        all firms that list the product and have not had it removed. A
        supplier is then drawn with probability proportional to
        ``revenue_log ** firm_req_prf_size``. When ``firm_req_prf_conn`` is
        set and some candidates are already adjacent to this firm, the draw
        is restricted to those. The request is queued in the supplier's
        ``dct_request_prod_from_firm`` and the trial counter is incremented.
        """
        for product in self.a_lst_up_product_removed:
            # NOTE(review): requests are sent while the counter is in
            # 0..int_n_max_trial inclusive, i.e. int_n_max_trial + 1
            # attempts; confirm this bound is intended.
            if self.dct_n_trial_up_prod_removed[product] > \
                    self.model.int_n_max_trial:
                continue
            if self.dct_n_trial_up_prod_removed[product] == 0:
                # built once per loss; accept_request later removes any
                # candidate that rejects this firm
                self.dct_cand_alt_supply_up_prod_removed[product] = \
                    self.model.a_lst_total_firms.select([
                        product in firm.a_lst_product
                        and product not in firm.a_lst_product_removed
                        for firm in self.model.a_lst_total_firms
                    ])
            a_lst_cand = self.dct_cand_alt_supply_up_prod_removed[product]
            if not a_lst_cand:
                continue

            lst_firm_connect = []
            if self.is_firm_req_prf_conn:
                lst_firm_connect = [
                    firm for firm in a_lst_cand
                    if self._is_connected_to(firm)
                ]
            # Both the connected and the size-only pools are weighted with
            # the *requesting* side's exponent (firm_req_prf_size).
            select_alt_supply = self._choose_by_size(
                lst_firm_connect or a_lst_cand, self.flt_firm_req_prf_size)

            assert product in select_alt_supply.a_lst_product, (
                f"{select_alt_supply} does not produce requested product "
                f"{product}")

            select_alt_supply.dct_request_prod_from_firm.setdefault(
                product, []).append(self)
            self.dct_n_trial_up_prod_removed[product] += 1

    def handle_request(self):
        """Pick at most one pending requester per product and process it.

        Only products with capacity above zero and at least one requester
        are considered. A lone requester is passed straight to
        ``accept_request``. With several requesters, one is drawn with
        probability proportional to ``revenue_log ** firm_acc_prf_size``.
        When ``firm_acc_prf_conn`` is set and some requesters are already
        adjacent to this firm, the draw is restricted to those.
        """
        for product, lst_firm in self.dct_request_prod_from_firm.items():
            if not (self.dct_prod_capacity[product] > 0 and lst_firm):
                continue
            if len(lst_firm) == 1:
                select_customer = lst_firm[0]
            else:
                lst_firm_connect = []
                if self.is_firm_acc_prf_conn:
                    lst_firm_connect = [
                        firm for firm in lst_firm
                        if self._is_connected_to(firm)
                    ]
                select_customer = self._choose_by_size(
                    lst_firm_connect or lst_firm, self.flt_firm_acc_prf_size)
            # accept_request may mutate lst_firm (a dict value), never the
            # dict itself, so iterating .items() stays valid
            self.accept_request(select_customer, product)

    def accept_request(self, down_firm, product):
        """Accept or reject ``down_firm``'s request for ``product``.

        The acceptance probability is this firm's share of ``revenue_log``
        among all firms that produce ``product`` and have not had it removed,
        raised to the power ``diff_new_conn``. On acceptance, an edge from
        this firm to ``down_firm`` labelled with the product code is added,
        one unit of capacity is consumed and the request is cleared on both
        sides. On rejection, this firm is removed from ``down_firm``'s
        candidate suppliers for the product.

        Args:
            down_firm: requesting (customer) firm.
            product: requested product agent.
        """
        flt_total_size = sum(
            firm.revenue_log for firm in self.model.a_lst_total_firms
            if product in firm.a_lst_product
            and product not in firm.a_lst_product_removed)
        # NOTE(review): if this firm is itself excluded from the sum (e.g.
        # its product was removed after the request was queued) the share
        # can exceed 1 and the draw below gets a negative probability;
        # confirm that cannot happen.
        prod_accept = (self.revenue_log / flt_total_size) ** \
            self.flt_diff_new_conn
        if self.model.nprandom.choice([True, False],
                                      p=[prod_accept, 1 - prod_accept]):
            self.firm_network.graph.add_edges_from([
                (self.firm_network.positions[self],
                 self.firm_network.positions[down_firm],
                 {'Product': product.code})
            ])
            self.dct_prod_capacity[product] -= 1
            self.dct_request_prod_from_firm[product].remove(down_firm)
            down_firm.a_lst_up_product_removed.remove(product)
        else:
            down_firm.dct_cand_alt_supply_up_prod_removed[product].remove(self)

    def clean_before_trial(self):
        """Drop all requests received during the previous trial."""
        self.dct_request_prod_from_firm = {}

    def clean_before_time_step(self):
        """Reset per-time-step trial counters and lost upstream products."""
        self.dct_n_trial_up_prod_removed = {}
        self.a_lst_up_product_removed = ap.AgentList(self.model, [])

    def get_firm_network_node(self):
        """Return this firm's node in the firm network."""
        return self.firm_network.positions[self]

    def _is_connected_to(self, firm):
        """Return True if ``firm`` shares an edge, in either direction, with
        a firm whose ``code`` equals this firm's code."""
        graph = self.firm_network.graph
        node = self.firm_network.positions[firm]
        lst_adj_node = [edge[1] for edge in graph.out_edges(node, keys=True)]
        lst_adj_node += [edge[0] for edge in graph.in_edges(node, keys=True)]
        return any(
            ap.AgentIter(self.model, adj_node).to_list()[0].code == self.code
            for adj_node in lst_adj_node)

    def _choose_by_size(self, lst_firm, flt_exp):
        """Draw one firm with probability proportional to
        ``revenue_log ** flt_exp``; uniform if every weight is zero."""
        lst_weight = [firm.revenue_log ** flt_exp for firm in lst_firm]
        flt_total = sum(lst_weight)  # computed once, not per element
        if flt_total == 0:
            # avoid ZeroDivisionError when no firm carries any weight
            return self.model.nprandom.choice(lst_firm)
        lst_prob = [weight / flt_total for weight in lst_weight]
        return self.model.nprandom.choice(lst_firm, p=lst_prob)