"""Launch several worker processes that run computations for one experiment.

The script records the experiment name in ``conf_db_prefix.yaml``, prepares
the controller database, and then starts ``--job`` worker processes that
each call ``Computation.run()`` until it reports completion.
"""
import os
import random
import time
from multiprocessing import Process
import argparse
from computation import Computation
from sqlalchemy.orm import close_all_sessions
import yaml

# Accepted spellings for boolean command-line values (compared lower-cased).
_TRUE_STRINGS = frozenset({'1', 'true', 't', 'yes', 'y', 'on'})
# The empty string is kept as "false" because ``bool('')`` was False under
# the previous ``type=bool`` behaviour.
_FALSE_STRINGS = frozenset({'', '0', 'false', 'f', 'no', 'n', 'off'})


def _str_to_bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``'False'``) as True; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def do_computation(c_db):
    """Run computations in a loop until ``Computation.run()`` returns truthy.

    Args:
        c_db: controller database object handed to ``Computation``.
    """
    exp = Computation(c_db)
    while True:
        # Random pause of up to 10 seconds before every run; presumably this
        # staggers the workers' access to the shared database (TODO confirm).
        time.sleep(random.uniform(0, 10))
        if exp.run():
            break


def _parse_args():
    """Build the command-line parser and return the validated arguments."""
    parser = argparse.ArgumentParser(description='setting')
    parser.add_argument('--exp', type=str, default='test')
    parser.add_argument('--job', type=int, default=3)
    parser.add_argument('--reset_sample', type=int, default=0)
    # ``--reset_db`` alone means True; ``--reset_db False`` now really is
    # False (with ``type=bool`` it silently evaluated to True).
    parser.add_argument('--reset_db', type=_str_to_bool, nargs='?',
                        const=True, default=False)
    args = parser.parse_args()
    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if args.job < 1:
        parser.error('Number of jobs should >= 1')
    return args


def main():
    """Prepare the experiment database and run the worker processes."""
    args = _parse_args()

    # Write the experiment name to the prefix file, replacing any old copy.
    prefix_file_name = 'conf_db_prefix.yaml'
    if os.path.exists(prefix_file_name):
        os.remove(prefix_file_name)
    with open(prefix_file_name, 'w', encoding='utf-8') as file:
        yaml.dump({'db_name_prefix': args.exp}, file)

    # Imported only after the prefix file has been written; the original
    # script ordered it this way, presumably because controller_db reads the
    # file at import time (TODO confirm).
    from controller_db import ControllerDB

    controller_db = ControllerDB(args.exp, reset_flag=args.reset_sample)
    controller_db.reset_db(force_drop=args.reset_db)
    controller_db.prepare_list_sample()
    # Close all SQLAlchemy sessions before the worker processes are started.
    close_all_sessions()

    process_list = []
    for _ in range(args.job):
        worker = Process(target=do_computation, args=(controller_db,))
        worker.start()
        process_list.append(worker)
    for worker in process_list:
        worker.join()


if __name__ == '__main__':
    main()