# -*- coding: utf-8 -*-
from orm import db_session, engine, Base, ins
from orm import Experiment, Sample, Result
from sqlalchemy.exc import OperationalError
from sqlalchemy import text
import yaml
import random
import pandas as pd
import platform
import networkx as nx
import json
import pickle


class ControllerDB:
    """Create, reset and hand out samples of an experiment database.

    The table classes come from the ``orm`` module. This controller fills the
    experiment and sample tables, resets unfinished samples, and lets a
    worker pick a pending sample and lock it.

    Sample ``is_done_flag`` values set in this class: -1 = waiting to run
    (new or reset samples), 0 = locked by a worker (``lock_the_sample``).
    Other values (e.g. 1) are presumably written by the simulation code
    when a sample finishes -- not visible here, TODO confirm.
    """
    dct_parameter = None
    is_test: bool = None
    db_name_prefix: str = None
    reset_flag: int

    lst_saved_s_id: list

    def __init__(self, prefix, reset_flag=0):
        """Load the experiment parameters from ``conf_experiment.yaml``.

        Args:
            prefix: table-name prefix of this experiment. ``'test'`` selects
                the ``test`` section of the config; any other value selects
                the ``not_test`` section.
            reset_flag: used by ``reset_db``. 0, do not reset; 1, reset the
                unfinished samples locked by this computer; 2, reset all
                unfinished samples.
        """
        with open('conf_experiment.yaml') as yaml_file:
            dct_conf_experiment = yaml.full_load(yaml_file)
        self.is_test = prefix == 'test'
        self.db_name_prefix = prefix
        dct_para_in_test = dct_conf_experiment[
            'test'] if self.is_test else dct_conf_experiment['not_test']
        self.dct_parameter = {
            'meta_seed': dct_conf_experiment['meta_seed'],
            **dct_para_in_test
        }
        print(self.dct_parameter)
        # 0, not reset; 1, reset self; 2, reset all
        self.reset_flag = reset_flag
        # ids of pending samples, filled by prepare_list_sample()
        self.lst_saved_s_id = []

    def init_tables(self):
        """Fill the experiment table, then one sample row per experiment run."""
        self.fill_experiment_table()
        self.fill_sample_table()

    def fill_experiment_table(self):
        """Insert one Experiment row per (OA scenario, initial removal) pair.

        ``xv.csv`` holds the candidate levels of each parameter, one column
        per parameter. Each cell of ``oa_with_exp.csv`` (the OA table) is a
        row index into ``xv.csv`` for the parameter in the same column
        position. Every OA row is combined with every initial-removal dict
        returned by ``_load_lst_init_removal``.
        """
        list_dct = self._load_lst_init_removal()
        # Manual override for debugging, e.g.:
        # list_dct = [{'140': ['1.4.5.1']}]

        g_product_js = self._build_bom_json()

        # insert exp
        df_xv = pd.read_csv("xv.csv", index_col=None)
        # read the OA table
        df_oa = pd.read_csv("oa_with_exp.csv", index_col=None)
        for idx_scenario, row in df_oa.iterrows():
            # the level in column i picks row `para_level` of xv.csv for
            # the parameter named df_xv.columns[i]
            dct_exp_para = {
                df_xv.columns[idx_col]: df_xv.iloc[para_level, idx_col]
                for idx_col, para_level in enumerate(row)
            }
            # different initial removal
            for idx_init_removal, dct_init_removal in enumerate(list_dct):
                self.add_experiment_1(idx_scenario,
                                      idx_init_removal,
                                      dct_init_removal,
                                      g_product_js,
                                      **dct_exp_para)
                print(f"Inserted experiment for scenario {idx_scenario}, "
                      f"init_removal {idx_init_removal}!")

    @staticmethod
    def _load_lst_init_removal(min_count=9):
        """Return the initial-removal dicts reused from an earlier run.

        Reads the ``iiabmdb.without_exp_*`` tables. For each experiment it
        counts the samples that have at least one result with ts > 0. It
        keeps the unpickled ``dct_lst_init_remove_firm_prod`` of the
        experiments whose count is >= ``min_count``, ordered by count in
        descending order.
        """
        str_sql = "select e_id, count, max_max_ts, " \
            "dct_lst_init_remove_firm_prod from " \
            "iiabmdb.without_exp_experiment as a " \
            "inner join " \
            "(select e_id, count(id) as count, max(max_ts) as max_max_ts " \
            "from iiabmdb.without_exp_sample as a " \
            "inner join (select s_id, max(ts) as max_ts from " \
            "iiabmdb.without_exp_result where ts > 0 group by s_id) as b " \
            "on a.id = b.s_id " \
            "group by e_id) as b " \
            "on a.id = b.e_id " \
            "order by count desc;"
        result = pd.read_sql(sql=str_sql, con=engine)
        # NOTE(review): pickle.loads can execute arbitrary code on crafted
        # input; this assumes the iiabmdb rows are trusted.
        result['dct_lst_init_remove_firm_prod'] = \
            result['dct_lst_init_remove_firm_prod'].apply(pickle.loads)
        return result.loc[result['count'] >= min_count,
                          'dct_lst_init_remove_firm_prod'].to_list()

    @staticmethod
    def _build_bom_json():
        """Build the BOM graph from CSV files and return it as a JSON string.

        ``BomCateNet.csv`` is read as an adjacency matrix and transposed
        before it is turned into a MultiDiGraph. Each node gets the columns
        of its row in ``BomNodes.csv`` (indexed by ``Code``) as attributes.
        The result is ``json.dumps(nx.adjacency_data(graph))``.
        """
        bom_nodes = pd.read_csv('BomNodes.csv', index_col=0)
        bom_nodes.set_index('Code', inplace=True)
        bom_cate_net = pd.read_csv('BomCateNet.csv', index_col=0)
        bom_cate_net.fillna(0, inplace=True)
        g_bom = nx.from_pandas_adjacency(bom_cate_net.T,
                                         create_using=nx.MultiDiGraph())
        bom_labels_dict = {
            code: bom_nodes.loc[code].to_dict() for code in g_bom.nodes
        }
        nx.set_node_attributes(g_bom, bom_labels_dict)
        return json.dumps(nx.adjacency_data(g_bom))

    def add_experiment_1(self, idx_scenario, idx_init_removal,
                         dct_lst_init_remove_firm_prod, g_bom,
                         n_max_trial, crit_supplier, firm_pref_request,
                         firm_pref_accept, netw_pref_cust_n,
                         netw_pref_cust_size, cap_limit, diff_new_conn,
                         diff_remove,):
        """Insert and commit one Experiment row.

        ``n_sample`` and ``n_iter`` come from the loaded config. All other
        columns are taken verbatim from the arguments; ``g_bom`` is the
        JSON string built by ``_build_bom_json``.
        """
        e = Experiment(
            idx_scenario=idx_scenario,
            idx_init_removal=idx_init_removal,
            n_sample=int(self.dct_parameter['n_sample']),
            n_iter=int(self.dct_parameter['n_iter']),
            dct_lst_init_remove_firm_prod=dct_lst_init_remove_firm_prod,
            g_bom=g_bom,
            n_max_trial=n_max_trial,
            crit_supplier=crit_supplier,
            firm_pref_request=firm_pref_request,
            firm_pref_accept=firm_pref_accept,
            netw_pref_cust_n=netw_pref_cust_n,
            netw_pref_cust_size=netw_pref_cust_size,
            cap_limit=cap_limit,
            diff_new_conn=diff_new_conn,
            diff_remove=diff_remove,
        )
        db_session.add(e)
        db_session.commit()

    def fill_sample_table(self):
        """Insert ``experiment.n_sample`` pending samples for every experiment.

        Seeds come from ``random.Random(meta_seed)``, so sample k of every
        experiment gets the same seed.
        """
        lst_exp = db_session.query(Experiment).all()
        # Draw enough seeds for the largest experiment. The first n seeds
        # are identical for any count because they come from the same RNG
        # stream. An experiment stored with a larger n_sample than the
        # current config no longer raises IndexError.
        n_seed = max([int(self.dct_parameter['n_sample'])]
                     + [int(e.n_sample) for e in lst_exp])
        rng = random.Random(self.dct_parameter['meta_seed'])
        lst_seed = [rng.getrandbits(32) for _ in range(n_seed)]
        lst_sample = [
            Sample(e_id=experiment.id,
                   idx_sample=idx_sample + 1,
                   seed=lst_seed[idx_sample],
                   is_done_flag=-1)
            for experiment in lst_exp
            for idx_sample in range(int(experiment.n_sample))
        ]
        db_session.bulk_save_objects(lst_sample)
        db_session.commit()
        print(f'Inserted {len(lst_sample)} samples!')

    def reset_db(self, force_drop=False):
        """Create and fill the tables, or reset unfinished samples.

        Args:
            force_drop: if True, first drop every ORM table whose name starts
                with ``db_name_prefix``.

        If such tables still exist afterwards, samples with is_done_flag 0
        are reset according to ``reset_flag`` (see ``__init__``). Otherwise
        all tables are created and filled via ``init_tables``.

        Raises:
            ValueError: ``reset_flag`` is positive but neither 1 nor 2.
        """
        # first, check if tables exist
        lst_table_obj = [
            Base.metadata.tables[str_table]
            for str_table in ins.get_table_names()
            if str_table.startswith(self.db_name_prefix)
        ]
        is_exist = len(lst_table_obj) > 0
        if force_drop:
            # Drop the tables one at a time in random order. A table whose
            # drop fails (presumably because another table still references
            # it) stays in the list and is retried later.
            # NOTE(review): a table that can never be dropped makes this
            # loop spin forever.
            while is_exist:
                a_table = random.choice(lst_table_obj)
                try:
                    Base.metadata.drop_all(bind=engine, tables=[a_table])
                except KeyError:
                    pass
                except OperationalError:
                    pass
                else:
                    lst_table_obj.remove(a_table)
                    print(
                        f"Table {a_table.name} is dropped "
                        f"for exp: {self.db_name_prefix}!!!"
                    )
                finally:
                    is_exist = len(lst_table_obj) > 0

        if is_exist:
            print(
                f"All tables exist. No need to reset "
                f"for exp: {self.db_name_prefix}."
            )
            if self.reset_flag > 0:
                self._reset_unfinished_samples()
        else:
            Base.metadata.create_all(bind=engine)
            self.init_tables()
            print(
                f"All tables are just created and initialized "
                f"for exp: {self.db_name_prefix}."
            )

    def _reset_unfinished_samples(self):
        """Rerun unfinished tasks: set is_done_flag 0 -> -1, drop their results.

        reset_flag 2 resets every sample with flag 0. reset_flag 1 resets
        only those whose computer_name equals this machine's
        ``platform.node()``.
        """
        if self.reset_flag == 2:
            qry_sample = db_session.query(Sample).filter(
                Sample.is_done_flag == 0)
        elif self.reset_flag == 1:
            qry_sample = db_session.query(Sample).filter(
                Sample.is_done_flag == 0,
                Sample.computer_name == platform.node())
        else:
            raise ValueError('Wrong reset flag')
        # Materialise first: the loop commits and changes the filtered flag.
        for s in qry_sample.all():
            # Query.filter() needs a SQL expression; the former
            # filter(s_id=...) raised TypeError.
            db_session.query(Result).filter(Result.s_id == s.id).delete()
            s.is_done_flag = -1
            # Delete and flag reset are committed together.
            db_session.commit()
            print(f"Reset the sample id {s.id} flag from 0 to -1")

    def prepare_list_sample(self):
        """Print the total sample count and queue ids of samples with flag -1.

        The ids are appended to ``self.lst_saved_s_id``; repeated calls
        append again.
        """
        res = db_session.execute(
            text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, "
                 f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id"
                 )).scalar()
        n_sample = 0 if res is None else res
        print(f'There are a total of {n_sample} samples.')
        res = db_session.execute(
            text(f"SELECT id FROM {self.db_name_prefix}_sample "
                 f"WHERE is_done_flag = -1"
                 ))
        self.lst_saved_s_id.extend(row[0] for row in res)

    @staticmethod
    def select_random_sample(lst_s_id):
        """Return a random still-pending Sample whose id is in ``lst_s_id``.

        Every id that is tried is removed from ``lst_s_id`` (in place),
        whether or not its sample is still pending (is_done_flag == -1).
        Returns None once the list is exhausted.
        """
        while len(lst_s_id) > 0:
            s_id = random.choice(lst_s_id)
            lst_s_id.remove(s_id)
            sample = db_session.query(Sample).filter(
                Sample.id == int(s_id),
                Sample.is_done_flag == -1).first()
            if sample is not None:
                return sample
        return None

    def fetch_a_sample(self, s_id=None):
        """Return the Sample with id ``s_id`` (or None if absent).

        If ``s_id`` is None, return a random pending sample from
        ``self.lst_saved_s_id`` (see ``select_random_sample``) or None.
        """
        if s_id is not None:
            return db_session.query(Sample).filter(
                Sample.id == int(s_id)).first()
        return self.select_random_sample(self.lst_saved_s_id)

    @staticmethod
    def lock_the_sample(sample: Sample):
        """Mark ``sample`` as taken (flag 0) by this computer and commit."""
        sample.is_done_flag, sample.computer_name = 0, platform.node()
        db_session.commit()


if __name__ == '__main__':
    db = ControllerDB('first')
    # Print the share of samples with is_done_flag = 1. The divisor 332750
    # is hard-coded and presumably the total number of samples of the
    # 'first' experiment -- confirm before reusing it for another prefix.
    # Textual SQL must be wrapped in text() (required by SQLAlchemy 2.0),
    # as prepare_list_sample already does.
    ratio = db_session.execute(
        text('SELECT COUNT(*) / 332750 FROM first_sample s '
             'WHERE s.is_done_flag = 1')
    ).scalar()
    print(ratio)
    # Other maintenance entry points:
    # db.reset_db(force_drop=True)
    # db.init_tables()
    # db.prepare_list_sample()
    # sample = db.fetch_a_sample()