From 745763ce2e69c6788bcff10e491b1b1cdc96caca Mon Sep 17 00:00:00 2001 From: HaoYizhi Date: Mon, 15 May 2023 13:44:21 +0800 Subject: [PATCH] format --- controller_db.py | 31 ++++++++++++++++++------------- model.py | 14 +++++++++----- orm.py | 10 +++++++--- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/controller_db.py b/controller_db.py index 6c1bec4..0f4052a 100644 --- a/controller_db.py +++ b/controller_db.py @@ -6,11 +6,11 @@ from sqlalchemy import text import yaml import random import pandas as pd -import numpy as np import platform import networkx as nx import json + class ControllerDB: dct_parameter = None is_test: bool = None @@ -31,7 +31,8 @@ class ControllerDB: **dct_para_in_test } print(self.dct_parameter) - self.reset_flag = reset_flag # 0, not reset; 1, reset self; 2, reset all + # 0, not reset; 1, reset self; 2, reset all + self.reset_flag = reset_flag self.lst_saved_s_id = [] def init_tables(self): @@ -73,7 +74,7 @@ class ControllerDB: # insert exp for idx_exp, dct in enumerate(list_dct): self.add_experiment_1(idx_exp, self.dct_parameter['n_max_trial'], - dct, g_product_js) # same g_bom for all exp + dct, g_product_js) # same g_bom for all exp print(f'Inserted experiment for exp {idx_exp}!') def add_experiment_1(self, idx_exp, n_max_trial, @@ -127,16 +128,19 @@ class ControllerDB: else: lst_table_obj.remove(a_table) print( - f"Table {a_table.name} is dropped for exp: {self.db_name_prefix}!!!" + f"Table {a_table.name} is dropped " + f"for exp: {self.db_name_prefix}!!!" ) finally: is_exist = len(lst_table_obj) > 0 if is_exist: print( - f"All tables exist. No need to reset for exp: {self.db_name_prefix}." + f"All tables exist. No need to reset " + f"for exp: {self.db_name_prefix}." ) - # change the is_done_flag from 0 to -1, to rerun the in-finished tasks + # change the is_done_flag from 0 to -1 + # rerun the in-finished tasks if self.reset_flag > 0: if self.reset_flag == 2: sample = db_session.query(Sample).filter( @@ -152,7 +156,7 @@ class ControllerDB: qry_result = db_session.query(Result).filter_by( s_id=s.id) if qry_result.count() > 0: - db_session.query(Result).filter(s_id = s.id).delete() + db_session.query(Result).filter(s_id=s.id).delete() db_session.commit() s.is_done_flag = -1 db_session.commit() @@ -161,20 +165,21 @@ class ControllerDB: Base.metadata.create_all(bind=engine) self.init_tables() print( - f"All tables are just created and initialized for exp: {self.db_name_prefix}." + f"All tables are just created and initialized " + f"for exp: {self.db_name_prefix}." ) def prepare_list_sample(self): res = db_session.execute( - text(f'''SELECT count(*) FROM {self.db_name_prefix}_sample s, - {self.db_name_prefix}_experiment e WHERE s.e_id=e.id ''' + text(f"SELECT count(*) FROM {self.db_name_prefix}_sample s, " + f"{self.db_name_prefix}_experiment e WHERE s.e_id=e.id" )).scalar() n_sample = 0 if res is None else res print(f'There are a total of {n_sample} samples.') res = db_session.execute( - text( - f'SELECT id FROM {self.db_name_prefix}_sample WHERE is_done_flag = -1' - )) + text(f"SELECT id FROM {self.db_name_prefix}_sample " + f"WHERE is_done_flag = -1" + )) for row in res: s_id = row[0] self.lst_saved_s_id.append(s_id) diff --git a/model.py b/model.py index e76e44f..fc6754f 100644 --- a/model.py +++ b/model.py @@ -59,7 +59,8 @@ class Model(ap.Model): size / sum(lst_succ_firm_size) for size in lst_succ_firm_size ] - # select multiple successors based on relative size of this firm + # select multiple successors + # based on relative size of this firm lst_same_prod_firm = Firm['Code'][Firm[product_code] == 1].to_list() lst_same_prod_firm_size = [ @@ -202,7 +203,8 @@ class Model(ap.Model): if n_up_product_removed == 0: continue else: - # update a_lst_product_disrupted / dct_lst_disrupt_firm_prod + # update a_lst_product_disrupted + # update dct_lst_disrupt_firm_prod if product not in firm.a_lst_product_disrupted: firm.a_lst_product_disrupted.append(product) if firm in self.dct_lst_disrupt_firm_prod.keys(): @@ -212,7 +214,8 @@ class Model(ap.Model): self.dct_lst_disrupt_firm_prod[ firm] = ap.AgentList( self.model, [product]) - # update a_lst_product_removed / dct_list_remove_firm_prod + # update a_lst_product_removed + # update dct_list_remove_firm_prod # mark disrupted firm as removed based conditionally lost_percent = n_up_product_removed / len( product.a_predecessors()) @@ -220,8 +223,9 @@ class Model(ap.Model): std_size = (firm.revenue_log - min(lst_size) + 1) / (max(lst_size) - min(lst_size) + 1) prod_remove = 1 - std_size * (1 - lost_percent) - if self.nprandom.choice( - [True, False], p=[prod_remove, 1 - prod_remove]): + if self.nprandom.choice([True, False], + p=[prod_remove, + 1 - prod_remove]): firm.a_lst_product_removed.append(product) if firm in self.dct_lst_remove_firm_prod.keys(): self.dct_lst_remove_firm_prod[firm].append( diff --git a/orm.py b/orm.py index 5e12c73..9779f80 100644 --- a/orm.py +++ b/orm.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- from sqlalchemy import create_engine, inspect from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger, DateTime, PickleType, Boolean, Text +from sqlalchemy import (Column, Integer, String, ForeignKey, BigInteger, + DateTime, PickleType, Boolean, Text) from sqlalchemy.sql import func from sqlalchemy.orm import relationship, Session from sqlalchemy.pool import NullPool @@ -21,8 +22,11 @@ with open('conf_db_prefix.yaml') as file: db_name_prefix = dct_conf_db_prefix['db_name_prefix'] -str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'], dct_conf_db['password'], - dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name']) +str_login = 'mysql://{}:{}@{}:{}/{}'.format(dct_conf_db['user_name'], + dct_conf_db['password'], + dct_conf_db['address'], + dct_conf_db['port'], + dct_conf_db['db_name']) print('DB is {}:{}/{}'.format(dct_conf_db['address'], dct_conf_db['port'], dct_conf_db['db_name']))