This commit is contained in:
2023-03-13 19:47:25 +08:00
parent 2f162b970b
commit 09b59d8778
10 changed files with 113 additions and 210 deletions

81
orm.py
View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger, DECIMAL, DateTime, Text
from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger, DateTime, PickleType, Boolean
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, Session
from sqlalchemy.pool import NullPool
@@ -40,22 +40,12 @@ class Experiment(Base):
idx_exp = Column(Integer, nullable=False)
# fixed parameters
int_n_country = Column(Integer, nullable=False)
max_int_n_supplier = Column(Integer, nullable=False) # uni(1, max), random parameter 1 of firm
int_n_product = Column(Integer, nullable=False)
int_n_firm_per_product_per_country = Column(Integer, nullable=False)
flt_demand_total = Column(DECIMAL(10, 2), nullable=False) # tri(0, total_demand, mean), to compute random para a
flt_bm_price_ratio = Column(DECIMAL(10, 2), nullable=False) # benchmark value of b, same for both countries
flt_beta_developing = Column(DECIMAL(10, 2), nullable=False) # benchmark value of c(beta), for developing countries
n_sample = Column(Integer, nullable=False)
n_iter = Column(Integer, nullable=False)
# variables
is_eliminated = Column(Integer, nullable=False)
flt_beta_developed = Column(DECIMAL(10, 2), nullable=False) # larger, for developed countries
lambda_tier = Column(DECIMAL(10, 2), nullable=False)
tariff_percentage_1 = Column(DECIMAL(10, 2), nullable=False)
tariff_percentage_2 = Column(DECIMAL(10, 2), nullable=False)
n_max_trial = Column(Integer, nullable=False)
dct_list_init_remove_firm_prod = Column(PickleType, nullable=False)
sample = relationship('Sample', back_populates='experiment', lazy='dynamic')
@@ -75,77 +65,30 @@ class Sample(Base):
ts_done = Column(DateTime(timezone=True), onupdate=func.now())
stop_t = Column(Integer, nullable=True)
c1_wealth = Column(DECIMAL(20, 2), nullable=True) # country 1, developing countries
c2_wealth = Column(DECIMAL(20, 2), nullable=True) # country 2, developed countries
c1_wealth_dgt = Column(Integer, nullable=True)
c2_wealth_dgt = Column(Integer, nullable=True)
c1_tariff = Column(DECIMAL(20, 2), nullable=True) # country 1, developing countries
c2_tariff = Column(DECIMAL(20, 2), nullable=True) # country 2, developed countries
c1_tariff_dgt = Column(Integer, nullable=True)
c2_tariff_dgt = Column(Integer, nullable=True)
c1_n_firms = Column(Integer, nullable=True)
c2_n_firms = Column(Integer, nullable=True)
c1_n_positive_firms = Column(Integer, nullable=True)
c2_n_positive_firms = Column(Integer, nullable=True)
network = Column(Text(4294000000), nullable=True)
network_order = Column(Text(4294000000), nullable=True)
network_country = Column(Text(4294000000), nullable=True)
experiment = relationship('Experiment', back_populates='sample', uselist=False)
product = relationship('Product', back_populates='sample', lazy='dynamic')
result = relationship('Result', back_populates='sample', lazy='dynamic')
def __repr__(self):
return f'<Sample id: {self.id}>'
class Product(Base):
__tablename__ = f"{db_name_prefix}_product"
class Result(Base):
__tablename__ = f"{db_name_prefix}_result"
id = Column(Integer, primary_key=True, autoincrement=True)
s_id = Column(Integer, ForeignKey('{}.id'.format(f"{db_name_prefix}_sample")), nullable=False)
int_name = Column(Integer, nullable=False)
int_tier = Column(Integer, nullable=False)
n_up_products = Column(Integer, nullable=False)
n_peer_products = Column(Integer, nullable=False)
n_positive_firms = Column(Integer, nullable=False)
n_all_firms = Column(Integer, nullable=False)
gini_acc_demand_per_age = Column(DECIMAL(10, 2), nullable=False)
gini_acc_wealth_per_age = Column(DECIMAL(10, 2), nullable=False)
gini_acc_demand_per_age_all = Column(DECIMAL(10, 2), nullable=False)
gini_acc_wealth_per_age_all = Column(DECIMAL(10, 2), nullable=False)
# lst_n_positive_firms = Column(Text(4294000000), nullable=False)
# lst_n_all_firms = Column(Text(4294000000), nullable=False)
# lst_gini_acc_demand_per_age = Column(Text(4294000000), nullable=False)
# lst_gini_acc_wealth_per_age = Column(Text(4294000000), nullable=False)
# lst_gini_acc_demand_per_age_all = Column(Text(4294000000), nullable=False)
# lst_gini_acc_wealth_per_age_all = Column(Text(4294000000), nullable=False)
id_firm = Column(Integer, nullable=False)
id_product = Column(Integer, nullable=False)
ts = Column(Integer, nullable=False)
is_disrupted = Column(Boolean, nullable=True)
is_removed = Column(Boolean, nullable=True)
sample = relationship('Sample', back_populates='product', uselist=False)
firm = relationship('Firm', back_populates='product', lazy='dynamic')
sample = relationship('Sample', back_populates='result', uselist=False)
def __repr__(self):
return f'<Product id: {self.id}>'
class Firm(Base):
__tablename__ = f"{db_name_prefix}_firm"
id = Column(Integer, primary_key=True, autoincrement=True)
p_id = Column(Integer, ForeignKey('{}.id'.format(f"{db_name_prefix}_product")), nullable=False)
idx_firm = Column(Integer, nullable=False)
int_n_supplier = Column(Integer, nullable=False)
flt_fix_cost = Column(DECIMAL(20, 2), nullable=False)
flt_q_star = Column(DECIMAL(20, 2), nullable=False)
acc_demand_per_age = Column(DECIMAL(20, 2), nullable=False)
acc_wealth_per_age = Column(DECIMAL(20, 2), nullable=False)
std_demand_per_age = Column(DECIMAL(20, 2), nullable=False)
product = relationship('Product', back_populates='firm', uselist=False)
def __repr__(self):
return f'<Firm id: {self.id}>'
if __name__ == '__main__':
Base.metadata.drop_all()
Base.metadata.create_all()