mesa-GA/GA_Agent_0925/orm.py

118 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
from datetime import datetime
from sqlalchemy import create_engine, inspect, Inspector, Float
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import (Column, Integer, DECIMAL, String, ForeignKey,
BigInteger, DateTime, PickleType, Boolean, Text)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, Session
from sqlalchemy.pool import NullPool
import yaml
with open('../conf_db.yaml') as file:
dct_conf_db_all = yaml.full_load(file)
is_local_db = dct_conf_db_all['is_local_db']
if is_local_db:
dct_conf_db = dct_conf_db_all['local']
else:
dct_conf_db = dct_conf_db_all['remote']
with open('../conf_db_prefix.yaml') as file:
dct_conf_db_prefix = yaml.full_load(file)
db_name_prefix = dct_conf_db_prefix['db_name_prefix']
str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'],
dct_conf_db['password'],
dct_conf_db['address'],
dct_conf_db['port'],
dct_conf_db['db_name'])
# print('DB is {}:{}/{}'.format(dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name']))
# must be null pool to avoid connection lost error
engine = create_engine(str_login, poolclass=NullPool)
connection = engine.connect()
ins: Inspector = inspect(engine)
Base = declarative_base()
db_session = Session(bind=engine)
class Experiment(Base):
__tablename__ = f"{db_name_prefix}_experiment"
id = Column(Integer, primary_key=True, autoincrement=True)
idx_scenario = Column(Integer, nullable=False)
idx_init_removal = Column(Integer, nullable=False)
# fixed parameters
n_sample = Column(Integer, nullable=False)
n_iter = Column(Integer, nullable=False)
# variables
dct_lst_init_disrupt_firm_prod = Column(PickleType, nullable=False)
g_bom = Column(Text(4294000000), nullable=False)
n_max_trial = Column(Integer, nullable=False)
prf_size = Column(Boolean, nullable=False)
prf_conn = Column(Boolean, nullable=False)
cap_limit_prob_type = Column(String(16), nullable=False)
cap_limit_level = Column(DECIMAL(8, 4), nullable=False)
diff_new_conn = Column(DECIMAL(8, 4), nullable=False)
remove_t = Column(Integer, nullable=False)
netw_prf_n = Column(Integer, nullable=False)
sample = relationship(
'Sample', back_populates='experiment', lazy='dynamic')
def __repr__(self):
return f'<Experiment: {self.id}>'
class Sample(Base):
__tablename__ = f"{db_name_prefix}_sample"
id = Column(Integer, primary_key=True, autoincrement=True)
e_id = Column(Integer, ForeignKey('{}.id'.format(
f"{db_name_prefix}_experiment")), nullable=False)
idx_sample = Column(Integer, nullable=False)
seed = Column(BigInteger, nullable=False)
# -1, waiting; 0, running; 1, done
is_done_flag = Column(Integer, nullable=False)
computer_name = Column(String(64), nullable=True)
ts_done = Column(DateTime(timezone=True), onupdate=func.now())
stop_t = Column(Integer, nullable=True)
g_firm = Column(Text(4294000000), nullable=True)
experiment = relationship(
'Experiment', back_populates='sample', uselist=False)
result = relationship('Result', back_populates='sample', lazy='dynamic')
def __repr__(self):
return f'<Sample id: {self.id}>'
class Result(Base):
__tablename__ = f"{db_name_prefix}_result"
id = Column(Integer, primary_key=True, autoincrement=True)
s_id = Column(Integer, ForeignKey('{}.id'.format(
f"{db_name_prefix}_sample")), nullable=False)
id_firm = Column(String(20), nullable=False)
id_product = Column(String(20), nullable=False)
ts = Column(Integer, nullable=False)
status = Column(String(5), nullable=False)
sample = relationship('Sample', back_populates='result', uselist=False)
# 💥 新增 GA 调用 ID用于标记属于哪一次遗传算法运行
ga_id = Column(String(50), nullable=True)
def __repr__(self):
return f'<Product id: {self.id}>'
if __name__ == '__main__':
Base.metadata.drop_all()
Base.metadata.create_all()