from mesa import Agent


class FirmAgent(Agent):
    def __init__(self, unique_id, model, type_region, revenue_log, a_lst_product,
                 production_output, demand_quantity, R, P, C):
        # 调用超类的 __init__ 方法
        super().__init__(unique_id, model)

        # 初始化模型中的网络引用
        self.firm_network = self.model.firm_network
        self.product_network = self.model.product_network
        # 初始化代理自身的属性
        self.type_region = type_region
        self.size_stat = []
        self.dct_prod_up_prod_stat = {}
        self.dct_prod_capacity = {}
        # 企业涉及的产业
        self.indus_i = a_lst_product
        # 各资源库存信息,库存资源,库存量
        self.R = R
        # 包括库存时间的值 方便后面统计
        self.R1 = {0: R}
        # 设备资产信息,持有设备,设备数量, 增加 设备残值 [[1,2,3],[] ]
        self.C = C
        # 包括设备时间步的值
        self.C1 = {0: C}
        # 复制一份
        self.C0 = C
        # 产品库存信息 库存产品,库存量 ID 数量
        self.P = P
        # 包括 产品时间
        self.P1 = {0: P}
        # 企业i的供应商
        self.upper_i = [self.model.agent_map[u] for u, v in self.firm_network.in_edges(self.unique_id)
                        if u in self.model.agent_map]
        # 企业i的客户
        self.downer_i = [self.model.agent_map[v] for u, v in self.firm_network.out_edges(self.unique_id)
                         if v in self.model.agent_map]
        # 设备c的数量 (总量) 使用这个来判断设备数量
        # self.n_equip_c = n_equip_c
        # 设备c产量 根据设备量进行估算
        self.c_yield = production_output
        # 消耗材料量 根据设备量进行估算   {           }
        self.c_consumption = demand_quantity
        # 设备c购买价格(初始值)
        # self.c_price = c_price
        # 资源r补货库存阈值 很重要设置
        self.s_r = 40
        self.S_r = 120
        # 设备补货阙值 可选
        # self.ss_r = 70
        # 每一个周期步减少残值:x
        self.x = 20
        # 试验中的参数
        self.dct_n_trial_up_prod_disrupted = {}
        self.dct_cand_alt_supp_up_prod_disrupted = {}
        self.dct_request_prod_from_firm = {}

        # 外部变量
        self.is_prf_size = self.model.is_prf_size
        self.is_prf_conn = bool(self.model.prf_conn)
        self.str_cap_limit_prob_type = str(self.model.cap_limit_prob_type)
        self.flt_cap_limit_level = float(self.model.cap_limit_level)
        self.flt_diff_new_conn = float(self.model.diff_new_conn)

        # 初始化 size_stat
        self.size_stat.append((revenue_log, 0))

        # 初始化 dct_prod_up_prod_stat
        for prod in a_lst_product:
            self.dct_prod_up_prod_stat[prod] = {
                'p_stat': [('N', 0)],
                's_stat': {up_prod: {'stat': True, 'set_disrupt_firm': set()}
                           for up_prod in prod.a_predecessors()}
            }

        # 初始化额外容量 (dct_prod_capacity)
        for product in a_lst_product:
            assert self.str_cap_limit_prob_type in ['uniform', 'normal'], \
                "cap_limit_prob_type must be either 'uniform' or 'normal'"
            extra_cap_mean = self.size_stat[0][0] / self.flt_cap_limit_level
            if self.str_cap_limit_prob_type == 'uniform':
                extra_cap = self.model.random.uniform(extra_cap_mean - 2, extra_cap_mean + 2)
                extra_cap = 0 if round(extra_cap) < 0 else round(extra_cap)
            elif self.str_cap_limit_prob_type == 'normal':
                extra_cap = self.model.random.normalvariate(extra_cap_mean, 1)
                extra_cap = 0 if round(extra_cap) < 0 else round(extra_cap)
            self.dct_prod_capacity[product] = extra_cap

    def remove_edge_to_cus(self, disrupted_prod):
        # parameter disrupted_prod is the product that self got disrupted
        lst_out_edge = list(
            self.firm_network.out_edges(
                self.unique_id, keys=True, data='Product'))
        for n1, n2, key, product_code in lst_out_edge:
            if product_code == disrupted_prod.unique_id:
                # update customer up product supplier status
                customer = next(agent for agent in self.model.company_agents if agent.unique_id == n2)

                for prod in customer.dct_prod_up_prod_stat.keys():
                    if disrupted_prod in customer.dct_prod_up_prod_stat[prod]['s_stat'].keys():
                        customer.dct_prod_up_prod_stat[prod]['s_stat'][disrupted_prod][
                            'set_disrupt_firm'].add(self)
                        # print(f"{self.name} disrupt {customer.name}'s "
                        #       f"{prod.code} due to {disrupted_prod.code}")
                # remove edge to customer
                self.firm_network.remove_edge(n1, n2, key)

    def disrupt_cus_prod(self, prod, disrupted_up_prod):
        # parameter prod is the product that has disrupted_up_prod
        # parameter disrupted_up_prod is the product that
        # self's component exists disrupted supplier
        num_lost = \
            len(self.dct_prod_up_prod_stat[prod]['s_stat']
                [disrupted_up_prod]['set_disrupt_firm'])
        num_remain = \
            len([u for u, _, _, d in
                 self.firm_network.in_edges(self.get_firm_network_unique_id(),
                                            keys=True,
                                            data='Product')
                 if d == disrupted_up_prod.unique_id])
        lost_percent = num_lost / (num_lost + num_remain)
        lst_size = \
            [firm.size_stat[-1][0] for firm in self.model.company_agents]
        std_size = (self.size_stat[-1][0] - min(lst_size) + 1) \
                   / (max(lst_size) - min(lst_size) + 1)

        # calculate probability of disruption
        prob_disrupt = 1 - std_size * (1 - lost_percent)
        if self.model.nprandom.choice([True, False],
                                      p=[prob_disrupt,
                                         1 - prob_disrupt]):
            self.dct_n_trial_up_prod_disrupted[disrupted_up_prod] = 0
            self.dct_prod_up_prod_stat[
                prod]['s_stat'][disrupted_up_prod]['stat'] = False
            status, _ = self.dct_prod_up_prod_stat[
                prod]['p_stat'][-1]
            if status != 'D':
                self.dct_prod_up_prod_stat[
                    prod]['p_stat'].append(('D', self.model.t))
                # print(f"{self.name}'s {prod.code} turn to D status due to "
                #       f"disrupted supplier of {disrupted_up_prod.code}")

    def seek_alt_supply(self, product):
        # 检查当前产品的尝试次数是否达到最大值
        if self.dct_n_trial_up_prod_disrupted[product] <= self.model.int_n_max_trial:
            # 初始化候选供应商列表
            if self.dct_n_trial_up_prod_disrupted[product] == 0:
                self.dct_cand_alt_supp_up_prod_disrupted[product] = [
                    firm for firm in self.model.company_agents if firm.is_prod_in_current_normal(product)
                ]

            # 如果没有候选供应商,直接退出
            if not self.dct_cand_alt_supp_up_prod_disrupted[product]:
                # print(f"No valid candidates found for product {product.unique_id}")
                return

            # 查找与当前企业已连接的候选供应商
            lst_firm_connect = []
            if self.is_prf_conn:
                lst_firm_connect = [
                    firm for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]
                    if self.firm_network.has_edge(self.unique_id, firm.unique_id) or
                       self.firm_network.has_edge(firm.unique_id, self.unique_id)
                ]

            # 如果没有连接的供应商
            if not lst_firm_connect:
                if self.is_prf_size:  # 根据规模加权选择
                    lst_size = [firm.size_stat[-1][0] for firm in self.dct_cand_alt_supp_up_prod_disrupted[product]]
                    lst_prob = [size / sum(lst_size) for size in lst_size]
                    select_alt_supply = self.random.choices(
                        self.dct_cand_alt_supp_up_prod_disrupted[product], weights=lst_prob
                    )[0]
                else:  # 随机选择
                    select_alt_supply = self.random.choice(self.dct_cand_alt_supp_up_prod_disrupted[product])
            else:  # 如果存在连接的供应商
                if self.is_prf_size:  # 根据规模加权选择
                    lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
                    lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
                    select_alt_supply = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
                else:  # 随机选择
                    select_alt_supply = self.random.choice(lst_firm_connect)

            # 检查选中的供应商是否能够生产产品
            if not select_alt_supply.is_prod_in_current_normal(product):
               # print(f"Selected supplier {select_alt_supply.unique_id} cannot produce product {product.unique_id}")

                # 打印供应商的生产状态字典
                #print(f"Supplier production state: {select_alt_supply.dct_prod_up_prod_stat}")

                # 检查产品是否存在于生产状态字典中
                if product in select_alt_supply.dct_prod_up_prod_stat:
                    print(
                        f"Product {product.unique_id} production state: {select_alt_supply.dct_prod_up_prod_stat[product]['p_stat']}")
                else:
                    print(f"Product {product.unique_id} not found in supplier production state.")
                return

            # 添加到供应商的请求字典
            if product in select_alt_supply.dct_request_prod_from_firm:
                select_alt_supply.dct_request_prod_from_firm[product].append(self)
            else:
                select_alt_supply.dct_request_prod_from_firm[product] = [self]

            # 更新尝试次数
            self.dct_n_trial_up_prod_disrupted[product] += 1

    def handle_request(self):
        for product, lst_firm in self.dct_request_prod_from_firm.items():
            if self.dct_prod_capacity[product] > 0:
                if len(lst_firm) == 0:
                    continue
                elif len(lst_firm) == 1:
                    self.accept_request(lst_firm[0], product)
                elif len(lst_firm) > 1:
                    lst_firm_connect = []
                    if self.is_prf_conn:
                        for firm in lst_firm:
                            if self.firm_network.has_edge(self.unique_id, firm.unique_id) or \
                                    self.firm_network.has_edge(firm.unique_id, self.unique_id):
                                lst_firm_connect.append(firm)
                    if len(lst_firm_connect) == 0:
                        if self.is_prf_size:
                            lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm]
                            lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
                            select_customer = self.random.choices(lst_firm, weights=lst_prob)[0]
                        else:
                            select_customer = self.random.choice(lst_firm)
                        self.accept_request(select_customer, product)
                    elif len(lst_firm_connect) > 0:
                        if self.is_prf_size:
                            lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_connect]
                            lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
                            select_customer = self.random.choices(lst_firm_connect, weights=lst_prob)[0]
                        else:
                            select_customer = self.random.choice(lst_firm_connect)
                        self.accept_request(select_customer, product)
            else:
                for down_firm in lst_firm:
                    down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)

    def accept_request(self, down_firm, product):
        if self.firm_network.has_edge(self.unique_id, down_firm.unique_id) or \
                self.firm_network.has_edge(down_firm.unique_id, self.unique_id):
            prod_accept = 1.0
        else:
            prod_accept = self.flt_diff_new_conn
        if self.model.nprandom.choice([True, False], p=[prod_accept, 1 - prod_accept]):
            self.firm_network.add_edge(self.unique_id, down_firm.unique_id, Product=product.unique_id)
            self.dct_prod_capacity[product] -= 1
            self.dct_request_prod_from_firm[product].remove(down_firm)

            for prod in down_firm.dct_prod_up_prod_stat.keys():
                if product in down_firm.dct_prod_up_prod_stat[prod]['s_stat']:
                    down_firm.dct_prod_up_prod_stat[prod]['s_stat'][product]['stat'] = True
                    down_firm.dct_prod_up_prod_stat[prod]['p_stat'].append(
                        ('N', self.model.t))
            del down_firm.dct_n_trial_up_prod_disrupted[product]
            del down_firm.dct_cand_alt_supp_up_prod_disrupted[product]
        else:
            down_firm.dct_cand_alt_supp_up_prod_disrupted[product].remove(self)

    def seek_material_supply(self, material_type):
        lst_firm_material_connect = []  # 符合条件 可选择的上游
        upper_i_material = []  # 特定 资源的上游 企业集合
        for firm in self.upper_i:
            for sub_list in firm.R:
                if sub_list[0] == material_type:
                    upper_i_material.append(firm)
        # 没有 上游 没有 材料的情况,也就是紊乱的情况
        if len(upper_i_material) == 0:
            return -1
        if self.is_prf_conn:
            for firm in upper_i_material:
                if self.firm_network.has_edge(self.unique_id, firm.unique_id) or self.firm_network.has_edge(
                        firm.unique_id, self.unique_id):
                    lst_firm_material_connect.append(firm)
        if len(lst_firm_material_connect) == 0:
            if self.is_prf_size:
                lst_size = [firm.size_stat[-1][0] for firm in upper_i_material]
                lst_prob = [size / sum(lst_size) for size in lst_size]
                select_alt_supply = \
                    self.random.choices(upper_i_material, weights=lst_prob)[0]
            else:
                select_alt_supply = self.random.choice(upper_i_material)
        elif len(lst_firm_material_connect) > 0:
            if self.is_prf_size:
                lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_material_connect]
                lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
                select_alt_supply = self.random.choices(lst_firm_material_connect, weights=lst_prob)[0]
            else:
                select_alt_supply = self.random.choice(lst_firm_material_connect)
        return select_alt_supply

    def seek_machinery_supply(self, machinery_type):
        lst_firm_machinery_connect = []  # 符合条件 可选择的上游
        upper_i_machinery = []  # 特定 资源的上游 企业集合
        for firm in self.upper_i:
            for sub_list in firm.R:
                if sub_list[0] == machinery_type:
                    upper_i_machinery.append(firm)
        # 没有 上游 没有 材料的情况,也就是紊乱的情况
        if len(upper_i_machinery) == 0:
            return -1
        if self.is_prf_conn:
            for firm in upper_i_machinery:
                if self.firm_network.has_edge(self.unique_id, firm.unique_id) or self.firm_network.has_edge(
                        firm.unique_id, self.unique_id):
                    lst_firm_machinery_connect.append(firm)
        if len(lst_firm_machinery_connect) == 0:
            if self.is_prf_size:
                lst_size = [firm.size_stat[-1][0] for firm in upper_i_machinery]
                lst_prob = [size / sum(lst_size) for size in lst_size]
                select_alt_supply = \
                    self.random.choices(upper_i_machinery, weights=lst_prob)[0]
            else:
                select_alt_supply = self.random.choice(upper_i_machinery)
        elif len(lst_firm_machinery_connect) > 0:
            if self.is_prf_size:
                lst_firm_size = [firm.size_stat[-1][0] for firm in lst_firm_machinery_connect]
                lst_prob = [size / sum(lst_firm_size) for size in lst_firm_size]
                select_alt_supply = self.random.choices(lst_firm_machinery_connect, weights=lst_prob)[0]
            else:
                select_alt_supply = self.random.choice(lst_firm_machinery_connect)
        return select_alt_supply

    def handle_material_request(self, mater_list):
        """Deduct a requested quantity from this firm's product stock ``P``.

        Args:
            mater_list: Request whose first item identifies the stock record
                and whose second item is the quantity to subtract.
        """
        # Every record with a matching first item is reduced, in place.
        matching = (stock for stock in self.P if stock[0] == mater_list[0])
        for stock in matching:
            stock[1] -= mater_list[1]

    def handle_machinery_request(self, machi_list):
        """Deduct a requested quantity from this firm's equipment ``C``.

        Args:
            machi_list: Request whose first item identifies the equipment
                record and whose second item is the quantity to subtract.
        """
        # Every record with a matching first item is reduced, in place.
        for equipment in filter(lambda record: record[0] == machi_list[0], self.C):
            equipment[1] -= machi_list[1]

    def refresh_R(self):
        self.R1[self.model.t] = self.R

    def refresh_C(self):
        self.C1[self.model.t] = self.C

    def refresh_P(self):
        self.P1[self.model.t] = self.P

    def clean_before_trial(self):
        """Start a trial with an empty queue of supply requests.

        The attribute is rebound to a new dictionary; the previous one is
        left untouched.
        """
        self.dct_request_prod_from_firm = dict()

    def clean_before_time_step(self):
        # Reset the number of trials and candidate suppliers for disrupted products
        self.dct_n_trial_up_prod_disrupted = dict.fromkeys(self.dct_n_trial_up_prod_disrupted.keys(), 0)
        self.dct_cand_alt_supp_up_prod_disrupted = {}

        # Update the status of products and refresh disruption sets
        for prod in self.dct_prod_up_prod_stat.keys():
            status, ts = self.dct_prod_up_prod_stat[prod]['p_stat'][-1]
            if ts != self.model.t:
                self.dct_prod_up_prod_stat[prod]['p_stat'].append((status, self.model.t))

            # Refresh the set of disrupted firms
            for up_prod in self.dct_prod_up_prod_stat[prod]['s_stat'].keys():
                self.dct_prod_up_prod_stat[prod]['s_stat'][up_prod]['set_disrupt_firm'] = set()

    def get_firm_network_unique_id(self):
        """Return ``self.unique_id``, the key this firm is looked up by in ``firm_network``."""
        return self.unique_id

    def is_prod_in_current_normal(self, prod):
        """Tell whether this firm has ``prod`` and its latest status is 'N'.

        Args:
            prod: Product to look up in ``dct_prod_up_prod_stat``.

        Returns:
            ``True`` when the product is tracked by this firm and the most
            recent entry of its status history is ``'N'``; ``False``
            otherwise, including when the product is not tracked at all.
        """
        if prod not in self.dct_prod_up_prod_stat:
            return False
        latest_status = self.dct_prod_up_prod_stat[prod]['p_stat'][-1][0]
        return True if latest_status == 'N' else False