import numpy as np
import pandas as pd
from orm import engine
from scipy.stats import f

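"""ANOVA for orthogonal-array (OA) experiments.

Define the experiment inputs (column layout, number of levels and data) in
the ``__main__`` block at the bottom of this file, then run ``anova`` on them.
"""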

def do_print(lst_value, str_col):
    """Print one tab-separated row of an ANOVA table (friendly-looking output).

    Values are rounded to 4 decimals when ``str_col`` contains ``'P value'``
    (p-values need the extra precision) and to 3 decimals otherwise.

    :param lst_value: iterable of numbers forming the row.
    :param str_col: row label(s); printed before the values, tab-separated,
        so it lines up with the two "Source" label columns of the header.
    :return: the printed line (returned for convenience; callers may ignore it).
    """
    n_digits = 4 if 'P value' in str_col else 3
    str_data = '\t'.join(str(round(e, n_digits)) for e in lst_value)
    str_line = f"{str_col}\t{str_data}"
    print(str_line)
    return str_line

def anova(lst_col_seg, n_level, oa_file, result_file, alpha=0.1):
    """Give the files and info, compute the significance of each X for each Y.

    For every indicator (Y) column this prints, via ``do_print``, the sum of
    squares (S) of every X/E column, the mean squares (MS) of the factors and
    of the error term, the F ratios and the p-values of the factors. It then
    prints one line per indicator that has at least one factor with
    ``p < alpha``.

    :param lst_col_seg: ``[n_x, n_e, n_y]`` -- the number of factor (X),
        error (E) and indicator (Y) columns, in that order.
    :param n_level: number of levels of every X/E column. Levels must be
        coded ``0 .. n_level - 1`` in the orthogonal array, and each level is
        assumed to occur equally often in each column (balanced OA).
    :param oa_file: the orthogonal array (its X columns followed by its E
        columns), either a path to a CSV file or a ``pandas.DataFrame``.
    :param result_file: the experiment table (X, E, then Y columns), either a
        ``pandas.DataFrame`` or a path to a CSV file. Its rows must be in the
        same order as the rows of the orthogonal array.
    :param alpha: significance level, usually 0.1, 0.05, 0.01
    :return: the list of report lines that were printed at the end.
    :raises AssertionError: if the column/row layout is inconsistent or no
        degrees of freedom are left for the error term.
    """
    # read and check the inputs
    if isinstance(oa_file, pd.DataFrame):
        df_oa = oa_file
    else:
        df_oa = pd.read_csv(oa_file, index_col=None)
    if isinstance(result_file, pd.DataFrame):
        df_res = result_file
    else:
        df_res = pd.read_csv(result_file, index_col=None)

    n_x, n_e, n_y = lst_col_seg
    assert df_res.shape[1] == n_x + n_e + n_y, 'the column number is wrong'
    assert df_oa.shape[1] == n_x + n_e, 'the column number is wrong'
    assert df_oa.shape[0] == df_res.shape[0], 'the row number is wrong'
    lst_head = [f"{idx + 1}_{ind_name}"
                for idx, ind_name in enumerate(df_res.columns)]

    n_col_input = n_x + n_e
    n_exp_row = df_res.shape[0]
    # Error DOF = total DOF (n - 1) minus (n_level - 1) for each factor.
    n_degree_error = n_exp_row - 1 - (n_level - 1) * n_x
    assert n_degree_error > 0, 'no degrees of freedom left for the error'

    df_output = df_res.iloc[:, n_col_input:]

    # 0/1 indicator matrices (rows x X/E columns), one per level. They do not
    # depend on Y, so build them once instead of once per indicator.
    arr_oa = df_oa.to_numpy()
    lst_level_mask = [(arr_oa == level).astype(float)
                      for level in range(n_level)]

    print("Source\tSource\t" + '\t'.join(lst_head[:n_x]) + "\te")
    print("DOF\tDOF\t" + '\t'.join([str(n_level - 1)] * n_x)
          + f"\t{n_degree_error}")

    lst_report = []

    # start to loop each Y
    for idx_col in range(n_y):
        str_ind_name = lst_head[idx_col + n_col_input]

        # Coerce to float so non-float columns (ints, object/Decimal values)
        # are handled uniformly by the arithmetic below.
        arr_y = df_output.iloc[:, idx_col].to_numpy(dtype=float)
        big_t = arr_y.sum()  # the big T

        # T1, ..., T(n_level): per column, the sum of y over the rows where
        # that column is at the given level (Table 1, rows 10, 11, 12).
        arr_big_t = np.array([arr_y @ mask for mask in lst_level_mask])

        # Sum of squares of each X/E column for a balanced OA, where each
        # level appears n_exp_row / n_level times (Table 1, last row).
        arr_s = (np.sum(arr_big_t ** 2, axis=0) / (n_exp_row / n_level)
                 - big_t * big_t / n_exp_row)
        assert arr_s.size == n_col_input, 'wrong arr_s size'

        # so far, the first table is computed. Now, compute the second table
        do_print(arr_s.tolist(), f'{str_ind_name}\tS')  # Table 2, col 2

        arr_ms_factor = arr_s[:n_x] / (n_level - 1)
        # Residual SS = total SS - factor SS, which matches n_degree_error.
        # For a saturated OA this equals the sum of the E columns' S.
        ss_total = np.sum(arr_y ** 2) - big_t * big_t / n_exp_row
        ms_of_error = (ss_total - arr_s[:n_x].sum()) / n_degree_error
        # Table 2, col 4
        do_print(arr_ms_factor.tolist() + [ms_of_error],
                 f'{str_ind_name}\tMS')

        arr_f = arr_ms_factor / ms_of_error
        # Table 2, col 5
        do_print(arr_f.tolist(), f'{str_ind_name}\tF ratio')

        # Upper-tail probability of the F(n_level - 1, n_degree_error) law.
        arr_p_value = f.sf(arr_f, n_level - 1, n_degree_error)
        # Table 2, col 6
        do_print(arr_p_value.tolist(), f'{str_ind_name}\tP value')

        lst_sig = [c for c, p in zip(lst_head[:n_x], arr_p_value.tolist())
                   if p < alpha]

        if lst_sig:
            lst_report.append(
                f"For indicator {str_ind_name}, the sig factors are {lst_sig}")

    for s in lst_report:
        print(s)
    return lst_report

if __name__ == '__main__':
    # prep data: one row per scenario with its 9 factor settings (table a),
    # joined with the disruption counts aggregated per scenario (table b).
    # NOTE(review): this query was partly truncated in the repository; the
    # sub-query parentheses and the join keys `a.id = b.s_id` (sample id) and
    # `a.id = b.e_id` (experiment id) were reconstructed from the column
    # names -- confirm them against the iiabmdb schema.
    # NOTE(review): `order by` assumes idx_scenario order equals the row order
    # of oa_with_exp.csv, since the OA error columns are concatenated
    # positionally below -- confirm.
    str_sql = """
    select * from
    (select distinct idx_scenario, n_max_trial, crit_supplier,
    firm_pref_request, firm_pref_accept, netw_pref_cust_n,
    netw_pref_cust_size, cap_limit, diff_new_conn, diff_remove
    from iiabmdb.with_exp_experiment) as a
    inner join
    (select idx_scenario,
    sum(n_disrupt_s) as n_disrupt_s, sum(n_disrupt_t) as n_disrupt_t from
    iiabmdb.with_exp_experiment as a
    inner join
    (select e_id, count(n_s_disrupt_t) as n_disrupt_s,
    sum(n_s_disrupt_t) as n_disrupt_t from
    iiabmdb.with_exp_sample as a
    inner join
    (select a.s_id as s_id, count(id) as n_s_disrupt_t from
    iiabmdb.with_exp_result as a
    inner join
    (select distinct s_id from iiabmdb.with_exp_result where ts > 0) as b
    on a.s_id = b.s_id
    group by a.s_id
    ) as b
    on a.id = b.s_id
    group by e_id
    ) as b
    on a.id = b.e_id
    group by idx_scenario) as b
    on a.idx_scenario = b.idx_scenario
    order by a.idx_scenario;
    """

    result = pd.read_sql(sql=str_sql, con=engine)
    # Drops both idx_scenario columns (from a and b), leaving 9 factors + 2 Y.
    # `axis` is keyword-only in DataFrame.drop since pandas 2.0.
    result = result.drop(columns='idx_scenario')
    df_oa = pd.read_csv("oa_with_exp.csv", index_col=None)
    result = pd.concat(
        [result.iloc[:, 0:9],       # the 9 factor columns (X)
         df_oa.iloc[:, -4:],        # the 4 unused OA columns as error (E)
         result.iloc[:, -2:]],      # the 2 indicators (Y)
        axis=1)

    # 9 factors (X), 4 for error (E), and 2 indicators (Y)
    the_lst_col_seg = [9, 4, 2]
    the_n_level = 3
    anova(the_lst_col_seg, the_n_level, "oa_with_exp.csv", result, 0.1)