2023-03-07 12:29:27 +08:00
|
|
|
import os
|
|
|
|
import random
|
|
|
|
import time
|
|
|
|
from multiprocessing import Process
|
|
|
|
import argparse
|
|
|
|
from computation import Computation
|
|
|
|
from sqlalchemy.orm import close_all_sessions
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
|
|
|
|
|
|
|
def do_computation(c_db):
    """Worker entry point: keep running a Computation until it reports completion.

    Args:
        c_db: controller/database handle passed straight to ``Computation``.

    Before every ``run()`` call (including the first) the worker pauses for a
    random 0-10 second interval -- presumably to stagger concurrent workers so
    they do not hit the shared database at the same moment (TODO confirm).
    The loop ends as soon as ``run()`` returns a truthy value.
    """
    worker = Computation(c_db)
    finished = False
    while not finished:
        time.sleep(random.uniform(0, 10))
        finished = worker.run()
|
|
|
|
|
|
|
|
|
2023-03-12 12:02:01 +08:00
|
|
|
DB_PREFIX_FILE = 'conf_db_prefix.yaml'


def _parse_args():
    """Parse and validate the launcher's command-line options."""
    parser = argparse.ArgumentParser(description='setting')
    parser.add_argument('--exp', type=str, default='test')
    # Integer literals as defaults: argparse happens to convert string
    # defaults through ``type``, but an int default is the direct form.
    parser.add_argument('--job', type=int, default=3)
    parser.add_argument('--reset', type=int, default=0)
    args = parser.parse_args()
    # parser.error instead of assert: asserts vanish under ``python -O`` and
    # would print a traceback rather than a usage message.
    if args.job < 1:
        parser.error('Number of jobs should >= 1')
    return args


def _write_db_prefix(prefix, path=DB_PREFIX_FILE):
    """Write ``{'db_name_prefix': prefix}`` to the YAML file at *path*.

    Opening with mode 'w' truncates any existing file, so no separate
    delete step is needed.
    """
    with open(path, 'w', encoding='utf-8') as file:
        yaml.dump({'db_name_prefix': prefix}, file)


def main():
    """Prepare the experiment database, then run ``args.job`` worker processes.

    Exits with a non-zero status if any worker process terminates abnormally.
    """
    args = _parse_args()
    _write_db_prefix(args.exp)

    # Deliberately imported only after the prefix file has been written:
    # presumably controller_db reads conf_db_prefix.yaml at import time
    # (TODO confirm), so do not hoist this import to the top of the file.
    from controller_db import ControllerDB

    controller_db = ControllerDB(args.exp, reset_flag=args.reset)
    # NOTE(review): this force-drops on every run regardless of --reset;
    # confirm whether it should be gated by args.reset.
    controller_db.reset_db(force_drop=True)
    controller_db.prepare_list_sample()

    # Close every SQLAlchemy session held by this process before starting
    # workers, presumably so children do not inherit open sessions.
    close_all_sessions()

    workers = []
    for _ in range(args.job):
        worker = Process(target=do_computation, args=(controller_db,))
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()

    # Surface worker crashes instead of silently exiting with status 0.
    failed = [worker for worker in workers if worker.exitcode != 0]
    if failed:
        raise SystemExit('%d of %d worker process(es) exited abnormally'
                         % (len(failed), len(workers)))


if __name__ == '__main__':
    main()
|