add --num-workers arg

This commit is contained in:
Youngjoon Lee 2024-08-08 15:04:54 +09:00
parent 156bff6ddc
commit ff5baa4bfe
No known key found for this signature in database
GPG Key ID: 167546E2D1712F8C
2 changed files with 4 additions and 4 deletions

View File

@ -19,6 +19,7 @@ if __name__ == "__main__":
required=True,
help=f"Queue type: {' | '.join([t.value for t in TemporalMixType])}",
)
parser.add_argument("--num-workers", type=int, required=True, help="num workers")
parser.add_argument("--outdir", type=str, required=True, help="output directory")
parser.add_argument(
"--from-paramset",
@ -33,6 +34,7 @@ if __name__ == "__main__":
ExperimentID(args.exp_id),
SessionID(args.session_id),
TemporalMixType(args.queue_type),
args.num_workers,
args.outdir,
args.from_paramset,
)

View File

@ -71,6 +71,7 @@ def run_session(
exp_id: ExperimentID,
session_id: SessionID,
queue_type: TemporalMixType,
num_workers: int,
outdir: str,
from_paramset: int = 1,
):
@ -93,10 +94,7 @@ def run_session(
future_map: dict[concurrent.futures.Future[tuple[bool, float]], IterationInfo] = (
dict()
)
total_cores = os.cpu_count()
assert total_cores is not None
max_workers = max(1, total_cores - 1)
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
# Submit all iterations of all parameter sets to the ProcessPoolExecutor
for paramset_idx, paramset in enumerate(paramsets):
paramset_id = paramset_idx + 1