mesa/main.py

62 lines
2.1 KiB
Python
Raw Normal View History

2024-08-24 11:20:13 +08:00
import os
import random
import time
from multiprocessing import Process
import argparse
from computation import Computation
from sqlalchemy.orm import close_all_sessions
import yaml
from controller_db import ControllerDB
def controll_db_and_process(exp_argument, reset_sample_argument, reset_db_argument):
    """Prepare the experiment database, then hand the simulation to worker processes.

    Args:
        exp_argument: experiment name, passed as the first argument of ``ControllerDB``.
        reset_sample_argument: forwarded to ``ControllerDB`` as ``reset_flag``.
        reset_db_argument: forwarded to ``ControllerDB.reset_db`` as ``force_drop``.
    """
    db = ControllerDB(exp_argument, reset_flag=reset_sample_argument)
    db.reset_db(force_drop=reset_db_argument)

    # Prepare the sample table.
    db.prepare_list_sample()

    # Close every SQLAlchemy session held by this (parent) process before the
    # workers are started — presumably so no open session is shared with the
    # child processes; confirm against ControllerDB.
    close_all_sessions()

    # Run the simulation on multiple cores via do_process.
    do_process(do_computation, db)
def do_process(target: object, controller_db: "ControllerDB", n_jobs=None):
    """Run ``target(controller_db)`` in several child processes and wait for all of them.

    Args:
        target: callable executed in each child process, receiving ``controller_db``
            as its only positional argument.
        controller_db: object passed to every worker.
        n_jobs: number of worker processes to start. When ``None`` (the default),
            falls back to the module-level ``args.job`` parsed from the command line;
            that global only exists when this file is run as a script.
    """
    if n_jobs is None:
        # Backward-compatible fallback: callers that do not pass n_jobs keep the
        # old behaviour of reading the CLI value.
        n_jobs = args.job
    n_jobs = int(n_jobs)

    workers = []
    for _ in range(n_jobs):
        # Use the caller-supplied target (the original ignored it and always
        # ran do_computation).
        worker = Process(target=target, args=(controller_db,))
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()
def do_computation(c_db):
    """Worker loop: call ``Computation.run`` repeatedly until it reports completion.

    Each call is preceded by a random pause of 0-1 seconds — presumably to stagger
    concurrent workers; confirm intent.

    Args:
        c_db: database controller handed to ``Computation``.
    """
    computation = Computation(c_db)
    finished = False
    while not finished:
        time.sleep(random.uniform(0, 1))
        finished = computation.run()
if __name__ == '__main__':
    def _str_to_bool(value):
        """Parse a command-line boolean such as 'true'/'false', '1'/'0' or 'yes'/'no'.

        Raises:
            argparse.ArgumentTypeError: if the text is not a recognised boolean.
        """
        normalized = str(value).strip().lower()
        if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if normalized in ('0', 'false', 'f', 'no', 'n', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    # Command-line arguments.
    parser = argparse.ArgumentParser(description='setting')
    parser.add_argument('--exp', type=str, default='without_exp')
    parser.add_argument('--job', type=int, default=4)
    parser.add_argument('--reset_sample', type=int, default=0)
    # type=bool would turn any non-empty string (including "False") into True,
    # so the value is parsed explicitly; a bare ``--reset_db`` means True.
    parser.add_argument('--reset_db', type=_str_to_bool, nargs='?', const=True, default=False)
    # Kept at module level on purpose: do_process() reads args.job as a global.
    args = parser.parse_args()

    # Number of worker processes. parser.error is used instead of assert,
    # which is stripped when Python runs with -O.
    if args.job < 1:
        parser.error('Number of jobs should >= 1')

    # Write the db-name prefix that distinguishes different experiments; this
    # file is presumably read elsewhere (e.g. by ControllerDB) — not visible here.
    prefix_file_name = 'conf_db_prefix.yaml'
    # Mode 'w' truncates any existing file, so no separate exists()/remove() is needed.
    with open(prefix_file_name, 'w', encoding='utf-8') as file:
        yaml.dump({'db_name_prefix': args.exp}, file)

    # Database connection control and model run.
    controll_db_and_process(args.exp, args.reset_sample, args.reset_db)