import agentpy as ap
import math


class FirmAgent(ap.Agent):
    """Firm agent that supplies products to customers over ``firm_network``.

    Edges of ``firm_network.graph`` run from supplier to customer and carry
    the supplied product's code in their ``'Product'`` attribute. When a
    firm stops supplying a product, its customers may lose that upstream
    supply, look for an alternative supplier (``seek_alt_supply``) and have
    their requests granted or declined by it (``handle_request`` /
    ``accept_request``).
    """

    def setup(self, code, name, type_region, revenue_log, a_lst_product):
        """Initialise identity, product portfolio and disruption bookkeeping.

        Args:
            code: Firm identifier; compared with neighbours' codes in
                ``handle_request`` to detect an existing connection.
            name: Human-readable firm name.
            type_region: Region type of the firm (stored on the agent).
            revenue_log: Size weight of the firm, used as the sampling weight
                when choosing suppliers/customers and in ``accept_request``.
            a_lst_product: Products this firm can produce.
        """
        self.firm_network = self.model.firm_network
        self.product_network = self.model.product_network

        self.code = code
        self.name = name
        self.type_region = type_region
        self.revenue_log = revenue_log
        self.a_lst_product = a_lst_product
        # Every capacity starts as None. NOTE(review): presumably the model
        # assigns numeric capacities before handle_request compares them
        # with 0 (``None > 0`` raises TypeError) -- confirm.
        self.dct_prod_capacity = dict.fromkeys(self.a_lst_product)

        # Upstream products this firm has lost a supply edge for.
        self.a_lst_up_product_removed = ap.AgentList(self.model, [])
        self.a_lst_product_disrupted = ap.AgentList(self.model, [])
        # Own products this firm no longer offers (excluded as alt supply).
        self.a_lst_product_removed = ap.AgentList(self.model, [])

        # product -> number of alternative-supply requests sent so far.
        self.dct_n_trial_up_product_removed = {}
        # product -> list of firms that asked this firm to supply it.
        self.dct_request_prod_from_firm = {}

    @staticmethod
    def _size_probabilities(lst_size):
        """Return ``lst_size`` normalised to sum to 1 (total computed once)."""
        total = sum(lst_size)
        return [size / total for size in lst_size]

    def _is_connected_to(self, firm):
        """Return whether ``firm`` has an edge, in either direction, to a
        node whose first agent carries this firm's ``code``."""
        network = self.model.firm_network
        node = network.positions[firm]
        adjacent_nodes = [v for _, v, _ in network.graph.out_edges(
            node, keys=True)]
        adjacent_nodes += [u for u, _, _ in network.graph.in_edges(
            node, keys=True)]
        # Same ``in`` (identity-or-equality) test as a list membership,
        # but stops at the first match.
        return self.code in (
            ap.AgentIter(self.model, adj).to_list()[0].code
            for adj in adjacent_nodes)

    def remove_edge_to_cus_remove_cus_up_prod(self, remove_product):
        """Cut every outgoing edge carrying ``remove_product``.

        For each removed edge, the customer at the other end loses this
        upstream product with probability ``exp(-k)``, where ``k`` is the
        number of incoming edges for the same product the customer still has
        after the removal (so a customer left with no such supplier always
        loses it). A customer that loses it gets the product appended to its
        ``a_lst_up_product_removed`` once, with its trial counter set to 0.

        Args:
            remove_product: Product agent whose supply edges are removed.
        """
        graph = self.firm_network.graph
        # Snapshot the edges: the graph is mutated inside the loop.
        lst_out_edge = list(
            graph.out_edges(self.firm_network.positions[self],
                            keys=True, data='Product'))
        for n1, n2, key, product_code in lst_out_edge:
            if product_code != remove_product.code:
                continue
            graph.remove_edge(n1, n2, key)

            customer = ap.AgentIter(self.model, n2).to_list()[0]
            n_remaining_supply = sum(
                1 for *_, in_code in graph.in_edges(
                    n2, keys=True, data='Product')
                if in_code == remove_product.code)
            prod_remove = math.exp(-1 * n_remaining_supply)
            if not self.model.nprandom.choice(
                    [True, False], p=[prod_remove, 1 - prod_remove]):
                continue
            if remove_product not in customer.a_lst_up_product_removed:
                customer.a_lst_up_product_removed.append(remove_product)
                customer.dct_n_trial_up_product_removed[remove_product] = 0

    def seek_alt_supply(self):
        """Send one supply request per lost upstream product.

        Only products whose trial counter is at most
        ``model.int_n_max_trial`` are considered. The supplier is drawn among
        all firms that produce the product and have not removed it, with
        probability proportional to ``revenue_log``. This firm is appended to
        the supplier's ``dct_request_prod_from_firm[product]`` and the trial
        counter is incremented. Products with no candidate supplier are
        skipped without consuming a trial.

        Raises:
            AssertionError: if the drawn supplier does not list the product.
        """
        for product in self.a_lst_up_product_removed:
            if self.dct_n_trial_up_product_removed[product] > \
                    self.model.int_n_max_trial:
                continue
            candidate_alt_supply = self.model.a_lst_total_firms.select([
                product in firm.a_lst_product
                and product not in firm.a_lst_product_removed
                for firm in self.model.a_lst_total_firms
            ])
            if not candidate_alt_supply:
                continue
            # select based on size
            lst_prob = self._size_probabilities(
                list(candidate_alt_supply.revenue_log))
            select_alt_supply = self.model.nprandom.choice(
                candidate_alt_supply, p=lst_prob)
            assert product in select_alt_supply.a_lst_product, (
                f"{select_alt_supply} does not produce requested product "
                f"{product}")

            select_alt_supply.dct_request_prod_from_firm.setdefault(
                product, []).append(self)
            self.dct_n_trial_up_product_removed[product] += 1

    def handle_request(self):
        """Consider one pending request per product with positive capacity.

        Requesters already linked to this firm (see ``_is_connected_to``) are
        preferred. A single candidate is taken directly. Otherwise one is
        drawn from the pool (linked requesters if any, else all requesters)
        with probability proportional to ``revenue_log``. The chosen firm is
        passed to ``accept_request``, which may still decline.
        """
        for product, lst_firm in self.dct_request_prod_from_firm.items():
            if not (self.dct_prod_capacity[product] > 0):
                continue
            if not lst_firm:
                continue
            if len(lst_firm) == 1:
                self.accept_request(lst_firm[0], product)
                continue

            # handling based on connection
            lst_firm_connect = [
                firm for firm in lst_firm if self._is_connected_to(firm)
            ]
            if len(lst_firm_connect) == 1:
                self.accept_request(lst_firm_connect[0], product)
                continue

            # handling based on size, among connected firms when there are
            # any, otherwise among all requesters
            lst_pool = lst_firm_connect if lst_firm_connect else lst_firm
            lst_prob = self._size_probabilities(
                [firm.revenue_log for firm in lst_pool])
            select_customer = self.model.nprandom.choice(lst_pool, p=lst_prob)
            self.accept_request(select_customer, product)

    def accept_request(self, down_firm, product):
        """Randomly accept ``down_firm``'s request for ``product``.

        The acceptance probability is this firm's share of ``revenue_log``
        among all firms whose ``a_lst_product`` contains ``product``. On
        acceptance the method:

        * adds a supplier -> customer edge tagged with ``product.code``;
        * decrements this firm's capacity for the product;
        * drops the request;
        * removes the product from the customer's lost-upstream list.

        NOTE(review): the final ``remove`` raises ValueError if
        ``down_firm`` no longer lists ``product`` (e.g. a stale request after
        another supplier already accepted) -- confirm this cannot happen.
        """
        lst_firm_size = [
            firm.revenue_log for firm in self.model.a_lst_total_firms
            if product in firm.a_lst_product
        ]
        prod_accept = self.revenue_log / sum(lst_firm_size)
        if not self.model.nprandom.choice([True, False],
                                          p=[prod_accept, 1 - prod_accept]):
            return
        self.firm_network.graph.add_edges_from([
            (self.firm_network.positions[self],
             self.firm_network.positions[down_firm], {
                 'Product': product.code
             })
        ])
        self.dct_prod_capacity[product] -= 1
        self.dct_request_prod_from_firm[product].remove(down_firm)
        down_firm.a_lst_up_product_removed.remove(product)

    def clean_before_trial(self):
        """Forget all supply requests received in the previous trial."""
        self.dct_request_prod_from_firm = {}

    def clean_before_time_step(self):
        """Reset lost-upstream bookkeeping at the start of a time step."""
        self.dct_n_trial_up_product_removed = {}
        self.a_lst_up_product_removed = ap.AgentList(self.model, [])