IIabm/orm.py

152 lines
6.6 KiB
Python
Raw Normal View History

2023-03-12 12:02:01 +08:00
# -*- coding: utf-8 -*-
from urllib.parse import quote_plus

import yaml
from sqlalchemy import create_engine, inspect
from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger, DECIMAL, DateTime, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, Session
from sqlalchemy.pool import NullPool
from sqlalchemy.sql import func
# --- Database configuration ---------------------------------------------
# conf_db.yaml holds a boolean `is_local_db` plus `local` / `remote`
# sections; the flag selects which section supplies the connection settings.
# NOTE(review): full_load can build some Python objects from tags;
# yaml.safe_load would suffice if these files only hold plain mappings.
with open('conf_db.yaml', encoding='utf-8') as file:
    dct_conf_db_all = yaml.full_load(file)
is_local_db = dct_conf_db_all['is_local_db']
if is_local_db:
    dct_conf_db = dct_conf_db_all['local']
else:
    dct_conf_db = dct_conf_db_all['remote']

# `db_name_prefix` is prepended to every __tablename__ defined in this module.
with open('conf_db_prefix.yaml', encoding='utf-8') as file:
    dct_conf_db_prefix = yaml.full_load(file)
db_name_prefix = dct_conf_db_prefix['db_name_prefix']

# Credentials are percent-escaped so characters such as '@', ':' or '/' in
# the user name or password cannot break the URL structure. str() keeps
# non-string YAML scalars (e.g. an all-digit password parsed as int) working.
str_login = 'mysql://{}:{}@{}:{}/{}'.format(
    quote_plus(str(dct_conf_db['user_name'])),
    quote_plus(str(dct_conf_db['password'])),
    dct_conf_db['address'],
    dct_conf_db['port'],
    dct_conf_db['db_name'],
)
print('DB is {}:{}/{}'.format(dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name']))
engine = create_engine(str_login, poolclass=NullPool)  # must be null pool to avoid connection lost error
# Not referenced elsewhere in this file; presumably used by importing modules.
ins = inspect(engine)

# Declarative base shared by all models below; created without a bind.
Base = declarative_base()

# Module-level ORM session bound to `engine`.
db_session = Session(bind=engine)
class Experiment(Base):
    """One experiment configuration; owns many ``Sample`` rows.

    Columns are split into the fixed model parameters and the swept
    variables. ``sample`` is a dynamic (query-returning) relationship back to
    ``Sample.experiment``.
    """

    __tablename__ = "{}_experiment".format(db_name_prefix)

    id = Column(Integer(), autoincrement=True, primary_key=True)
    idx_exp = Column(Integer(), nullable=False)

    # ---- fixed parameters ----
    int_n_country = Column(Integer(), nullable=False)
    # upper bound of uni(1, max); random parameter 1 of a firm
    max_int_n_supplier = Column(Integer(), nullable=False)
    int_n_product = Column(Integer(), nullable=False)
    int_n_firm_per_product_per_country = Column(Integer(), nullable=False)
    # tri(0, total_demand, mean), used to compute random parameter a
    flt_demand_total = Column(DECIMAL(precision=10, scale=2), nullable=False)
    # benchmark value of b, identical for both countries
    flt_bm_price_ratio = Column(DECIMAL(precision=10, scale=2), nullable=False)
    # benchmark value of c (beta) for developing countries
    flt_beta_developing = Column(DECIMAL(precision=10, scale=2), nullable=False)
    n_sample = Column(Integer(), nullable=False)
    n_iter = Column(Integer(), nullable=False)

    # ---- variables ----
    is_eliminated = Column(Integer(), nullable=False)
    # larger beta, for developed countries
    flt_beta_developed = Column(DECIMAL(precision=10, scale=2), nullable=False)
    lambda_tier = Column(DECIMAL(precision=10, scale=2), nullable=False)
    tariff_percentage_1 = Column(DECIMAL(precision=10, scale=2), nullable=False)
    tariff_percentage_2 = Column(DECIMAL(precision=10, scale=2), nullable=False)

    sample = relationship('Sample', lazy='dynamic', back_populates='experiment')

    def __repr__(self):
        return '<Experiment: %s>' % (self.id,)
class Sample(Base):
    """One sampled run belonging to an ``Experiment``.

    ``e_id`` references the experiment table. ``is_done_flag`` tracks run
    state (-1 waiting, 0 running, 1 done). ``ts_done`` is set to the DB's
    ``now()`` whenever the row is updated without an explicit value. The
    nullable result columns are presumably filled once the run finishes.
    ``product`` is a dynamic relationship back to ``Product.sample``.
    """

    __tablename__ = f"{db_name_prefix}_sample"

    id = Column(Integer, primary_key=True, autoincrement=True)
    e_id = Column(Integer, ForeignKey(f"{db_name_prefix}_experiment.id"), nullable=False)
    idx_sample = Column(Integer, nullable=False)
    seed = Column(BigInteger, nullable=False)
    is_done_flag = Column(Integer, nullable=False)  # -1, waiting; 0, running; 1, done
    computer_name = Column(String(64), nullable=True)
    ts_done = Column(DateTime(timezone=True), onupdate=func.now())
    stop_t = Column(Integer, nullable=True)

    c1_wealth = Column(DECIMAL(20, 2), nullable=True)  # country 1, developing countries
    c2_wealth = Column(DECIMAL(20, 2), nullable=True)  # country 2, developed countries
    c1_wealth_dgt = Column(Integer, nullable=True)
    c2_wealth_dgt = Column(Integer, nullable=True)
    c1_tariff = Column(DECIMAL(20, 2), nullable=True)  # country 1, developing countries
    c2_tariff = Column(DECIMAL(20, 2), nullable=True)  # country 2, developed countries
    c1_tariff_dgt = Column(Integer, nullable=True)
    c2_tariff_dgt = Column(Integer, nullable=True)
    c1_n_firms = Column(Integer, nullable=True)
    c2_n_firms = Column(Integer, nullable=True)
    c1_n_positive_firms = Column(Integer, nullable=True)
    c2_n_positive_firms = Column(Integer, nullable=True)

    # Very large Text length so MySQL maps these to LONGTEXT.
    network = Column(Text(4294000000), nullable=True)
    network_order = Column(Text(4294000000), nullable=True)
    network_country = Column(Text(4294000000), nullable=True)

    experiment = relationship('Experiment', back_populates='sample', uselist=False)
    product = relationship('Product', back_populates='sample', lazy='dynamic')

    def __repr__(self):
        return f'<Sample id: {self.id}>'
class Product(Base):
    """Per-product statistics recorded for one ``Sample``.

    ``s_id`` references the sample table. ``firm`` is a dynamic relationship
    back to ``Firm.product``.
    """

    __tablename__ = f"{db_name_prefix}_product"

    id = Column(Integer, primary_key=True, autoincrement=True)
    s_id = Column(Integer, ForeignKey(f"{db_name_prefix}_sample.id"), nullable=False)
    int_name = Column(Integer, nullable=False)
    int_tier = Column(Integer, nullable=False)
    n_up_products = Column(Integer, nullable=False)
    n_peer_products = Column(Integer, nullable=False)
    n_positive_firms = Column(Integer, nullable=False)
    n_all_firms = Column(Integer, nullable=False)
    gini_acc_demand_per_age = Column(DECIMAL(10, 2), nullable=False)
    gini_acc_wealth_per_age = Column(DECIMAL(10, 2), nullable=False)
    gini_acc_demand_per_age_all = Column(DECIMAL(10, 2), nullable=False)
    gini_acc_wealth_per_age_all = Column(DECIMAL(10, 2), nullable=False)

    # Disabled list-valued columns, kept for reference (not created in the DB):
    # lst_n_positive_firms = Column(Text(4294000000), nullable=False)
    # lst_n_all_firms = Column(Text(4294000000), nullable=False)
    # lst_gini_acc_demand_per_age = Column(Text(4294000000), nullable=False)
    # lst_gini_acc_wealth_per_age = Column(Text(4294000000), nullable=False)
    # lst_gini_acc_demand_per_age_all = Column(Text(4294000000), nullable=False)
    # lst_gini_acc_wealth_per_age_all = Column(Text(4294000000), nullable=False)

    sample = relationship('Sample', back_populates='product', uselist=False)
    firm = relationship('Firm', back_populates='product', lazy='dynamic')

    def __repr__(self):
        return f'<Product id: {self.id}>'
class Firm(Base):
    """Per-firm results belonging to one ``Product``.

    ``p_id`` references the product table; ``product`` is the scalar side of
    the relationship with ``Product.firm``.
    """

    __tablename__ = f"{db_name_prefix}_firm"

    id = Column(Integer, primary_key=True, autoincrement=True)
    p_id = Column(Integer, ForeignKey(f"{db_name_prefix}_product.id"), nullable=False)
    idx_firm = Column(Integer, nullable=False)
    int_n_supplier = Column(Integer, nullable=False)
    flt_fix_cost = Column(DECIMAL(20, 2), nullable=False)
    flt_q_star = Column(DECIMAL(20, 2), nullable=False)
    acc_demand_per_age = Column(DECIMAL(20, 2), nullable=False)
    acc_wealth_per_age = Column(DECIMAL(20, 2), nullable=False)
    std_demand_per_age = Column(DECIMAL(20, 2), nullable=False)

    product = relationship('Product', back_populates='firm', uselist=False)

    def __repr__(self):
        return f'<Firm id: {self.id}>'
if __name__ == '__main__':
    # Rebuild the schema from scratch.
    # WARNING: drop_all removes every table defined on Base, including its data.
    # `Base` was created without a bind, so the engine must be passed
    # explicitly; bare drop_all()/create_all() raise an error instead.
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)