# -*- coding: utf-8 -*-
from urllib.parse import quote

import yaml
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import (Column, Integer, DECIMAL, String, ForeignKey,
                        BigInteger, DateTime, PickleType, Boolean, Text)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, Session
from sqlalchemy.pool import NullPool


# Load the DB connection settings.  conf_db.yaml holds an ``is_local_db``
# switch plus a ``local`` and a ``remote`` section; the section selected by
# the switch becomes ``dct_conf_db``.
# YAML streams are Unicode, so decode explicitly as UTF-8 instead of relying
# on the platform's locale encoding (not UTF-8 on some Windows setups, which
# would garble or reject non-ASCII values such as passwords).
# NOTE(review): full_load is acceptable for these local, trusted files; use
# yaml.safe_load if they could ever come from an untrusted source.
with open('conf_db.yaml', encoding='utf-8') as file:
    dct_conf_db_all = yaml.full_load(file)
    is_local_db = dct_conf_db_all['is_local_db']
    if is_local_db:
        dct_conf_db = dct_conf_db_all['local']
    else:
        dct_conf_db = dct_conf_db_all['remote']

# Table-name prefix: every model's __tablename__ is "<prefix>_<name>".
with open('conf_db_prefix.yaml', encoding='utf-8') as file:
    dct_conf_db_prefix = yaml.full_load(file)
    db_name_prefix = dct_conf_db_prefix['db_name_prefix']


# Build the SQLAlchemy connection URL.  SQLAlchemy URL-decodes the user name
# and password, so they are percent-escaped here: a password containing
# characters such as '@', '/', ':' or '%' would otherwise be mis-parsed.
# str() guards against YAML parsing a purely numeric password (e.g. 123456)
# as an int, which quote() would reject.
str_login = 'mysql://{}:{}@{}:{}/{}'.format(
    quote(str(dct_conf_db['user_name']), safe=''),
    quote(str(dct_conf_db['password']), safe=''),
    dct_conf_db['address'],
    dct_conf_db['port'],
    dct_conf_db['db_name'])
# Report which database is in use without echoing the credentials.
print('DB is {}:{}/{}'.format(dct_conf_db['address'],
      dct_conf_db['port'], dct_conf_db['db_name']))

# must be null pool to avoid connection lost error
# NullPool opens a fresh DBAPI connection on each checkout and closes it on
# release, so no idle pooled connection can be dropped by the server.
engine = create_engine(str_login, poolclass=NullPool)
# Inspector for schema reflection on this engine.
ins = inspect(engine)

# Declarative base shared by the ORM models defined in this module.
Base = declarative_base()

# Module-level session bound to the engine; presumably imported and shared
# by other modules -- verify against callers.
db_session = Session(bind=engine)


class Experiment(Base):
    """One experiment configuration and its parameter values.

    Mapped to the ``<db_name_prefix>_experiment`` table.  Each experiment
    owns many ``Sample`` rows (see ``sample``), which point back to it via
    ``Sample.e_id``.

    NOTE: the column declaration order below determines the column order of
    the generated CREATE TABLE statement; keep it stable.
    """
    __tablename__ = f"{db_name_prefix}_experiment"
    id = Column(Integer, primary_key=True, autoincrement=True)

    # indices identifying the scenario and the initial-removal setting
    idx_scenario = Column(Integer, nullable=False)
    idx_init_removal = Column(Integer, nullable=False)

    # fixed parameters
    n_sample = Column(Integer, nullable=False)
    n_iter = Column(Integer, nullable=False)

    # variables
    # Arbitrary Python object stored via pickle; presumably a dict of lists of
    # initially removed firm/product entries -- verify against the writer.
    dct_lst_init_remove_firm_prod = Column(PickleType, nullable=False)
    # Large text column; presumably a serialized graph.  The ~4.29e9 length is
    # presumably chosen so MySQL creates the column as LONGTEXT -- confirm.
    g_bom = Column(Text(4294000000), nullable=False)

    n_max_trial = Column(Integer, nullable=False)
    # DECIMAL(8, 4): 8 significant digits in total, 4 after the decimal point.
    crit_supplier = Column(DECIMAL(8, 4), nullable=False)
    firm_req_prf_size = Column(DECIMAL(8, 4), nullable=False)
    firm_req_prf_conn = Column(Boolean, nullable=False)
    firm_acc_prf_size = Column(DECIMAL(8, 4), nullable=False)
    firm_acc_prf_conn = Column(Boolean, nullable=False)
    netw_sply_prf_n = Column(Integer, nullable=False)
    netw_sply_prf_size = Column(DECIMAL(8, 4), nullable=False)
    cap_limit_prob_type = Column(String(16), nullable=False)
    cap_limit_level = Column(DECIMAL(8, 4), nullable=False)
    diff_new_conn = Column(DECIMAL(8, 4), nullable=False)
    diff_remove = Column(DECIMAL(8, 4), nullable=False)
    proactive_ratio = Column(DECIMAL(8, 4), nullable=False)

    # One-to-many link to Sample.  lazy='dynamic' makes ``experiment.sample``
    # a query object (can be filtered further) instead of a loaded list.
    sample = relationship(
        'Sample', back_populates='experiment', lazy='dynamic')

    def __repr__(self):
        # Debug representation showing only the primary key.
        return f'<Experiment: {self.id}>'


class Sample(Base):
    """One sample of an Experiment plus its execution bookkeeping.

    Mapped to ``<db_name_prefix>_sample``.  ``e_id`` references the parent
    experiment's ``id``; ``result`` gives access to this sample's ``Result``
    rows.

    NOTE: column declaration order determines the CREATE TABLE column order.
    """
    __tablename__ = f"{db_name_prefix}_sample"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Foreign key to "<db_name_prefix>_experiment.id".
    e_id = Column(Integer, ForeignKey('{}.id'.format(
        f"{db_name_prefix}_experiment")), nullable=False)

    idx_sample = Column(Integer, nullable=False)
    # BigInteger so large seed values fit; presumably the RNG seed -- verify.
    seed = Column(BigInteger, nullable=False)
    # -1, waiting; 0, running; 1, done
    is_done_flag = Column(Integer, nullable=False)
    # Nullable; presumably the host that ran the sample -- verify.
    computer_name = Column(String(64), nullable=True)
    # Set to the database's current time (func.now()) whenever SQLAlchemy
    # issues an UPDATE for this row that does not set ts_done explicitly; no
    # default, so it is NULL right after insert.
    # NOTE(review): despite its name this is refreshed on every such update
    # (e.g. also when the flag changes to running), not only on completion.
    ts_done = Column(DateTime(timezone=True), onupdate=func.now())
    # Nullable integer; presumably the step at which the run stopped -- verify.
    stop_t = Column(Integer, nullable=True)

    # Optional large text; presumably a serialized graph -- verify.
    g_firm = Column(Text(4294000000), nullable=True)

    # Many-to-one link back to the parent Experiment (a scalar attribute).
    experiment = relationship(
        'Experiment', back_populates='sample', uselist=False)
    # One-to-many link to Result rows; lazy='dynamic' yields a query object.
    result = relationship('Result', back_populates='sample', lazy='dynamic')

    def __repr__(self):
        # Debug representation showing only the primary key.
        return f'<Sample id: {self.id}>'


class Result(Base):
    """One result row recorded for a firm/product pair within a Sample.

    Mapped to ``<db_name_prefix>_result``; ``s_id`` references the parent
    sample's ``id``.  ``ts`` is an integer time index (presumably the
    simulation step -- verify against the writer).

    NOTE: column declaration order determines the CREATE TABLE column order.
    """
    __tablename__ = f"{db_name_prefix}_result"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Foreign key to "<db_name_prefix>_sample.id".
    s_id = Column(Integer, ForeignKey('{}.id'.format(
        f"{db_name_prefix}_sample")), nullable=False)

    id_firm = Column(String(10), nullable=False)
    id_product = Column(String(10), nullable=False)
    ts = Column(Integer, nullable=False)
    # Tri-state flags: NULL is allowed in addition to True/False.
    is_disrupted = Column(Boolean, nullable=True)
    is_removed = Column(Boolean, nullable=True)

    # Many-to-one link back to the parent Sample (a scalar attribute).
    sample = relationship('Sample', back_populates='result', uselist=False)

    def __repr__(self):
        # Label with this class's name (previously mislabelled "Product"),
        # consistent with the Experiment and Sample representations.
        return f'<Result id: {self.id}>'


if __name__ == '__main__':
    # Recreate the schema from scratch: drop every table registered on Base's
    # metadata, then create them again.  WARNING: this deletes all stored
    # experiments, samples and results.
    # Base's MetaData is never bound to an engine, so the engine must be
    # passed explicitly; calling drop_all()/create_all() without a bind raises
    # UnboundExecutionError (SQLAlchemy < 2.0) or TypeError (2.0+).
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)