Try multiprocessing

Hsiao-Wei Wang 2023-05-05 23:03:25 +08:00
parent 9f5bb03cb4
commit aeccd20fd1
No known key found for this signature in database
GPG Key ID: AE3D6B174F971DE4
3 changed files with 100 additions and 49 deletions

View File

@@ -1184,7 +1184,7 @@ setup(
     extras_require={
         "test": ["pytest>=4.4", "pytest-cov", "pytest-xdist"],
         "lint": ["flake8==5.0.4", "mypy==0.981", "pylint==2.15.3"],
-        "generator": ["python-snappy==0.6.1", "filelock"],
+        "generator": ["python-snappy==0.6.1", "filelock", "pathos==0.3.0"],
         "docs": ["mkdocs==1.4.2", "mkdocs-material==9.1.5", "mdx-truly-sane-lists==1.3", "mkdocs-awesome-pages-plugin==2.8.0"]
     },
     install_requires=[
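Note: pathos is pinned in the "generator" extra because its process pool serializes work with dill rather than the standard pickle, so locally defined functions and richer objects can cross the process boundary. A minimal sketch of that capability (illustration only, not part of this commit):

    # sketch only: shows why pathos is used instead of multiprocessing.Pool,
    # which cannot pickle lambdas
    from pathos.multiprocessing import ProcessingPool as Pool

    squares = Pool(nodes=2).map(lambda x: x * x, range(4))
    print(squares)  # [0, 1, 4, 9]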

View File

@@ -11,12 +11,16 @@ import sys
 import json
 from typing import Iterable, AnyStr, Any, Callable
 import traceback
+import multiprocessing
+from collections import namedtuple

 from ruamel.yaml import (
     YAML,
 )
 from filelock import FileLock
 from snappy import compress
+from pathos.multiprocessing import ProcessingPool as Pool

 from eth_utils import encode_hex
@@ -32,6 +36,12 @@ context.is_pytest = False
 TIME_THRESHOLD_TO_PRINT = 1.0  # seconds

+# Generator mode setting
+MODE_SINGLE_PROCESS = 'MODE_SINGLE_PROCESS'
+MODE_MULTIPROCESSING = 'MODE_MULTIPROCESSING'
+GENERATOR_MODE = MODE_SINGLE_PROCESS
+

 @dataclass
 class Diagnostics(object):
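Note: GENERATOR_MODE is a module-level constant, so multiprocessing stays opt-in and the default behaviour is unchanged. A hypothetical sketch (not in this commit) of driving the switch from the environment instead of editing the source:

    # sketch only: the commit itself hard-codes MODE_SINGLE_PROCESS;
    # an env var named GENERATOR_MODE here is an assumption for illustration
    import os

    MODE_SINGLE_PROCESS = 'MODE_SINGLE_PROCESS'
    MODE_MULTIPROCESSING = 'MODE_MULTIPROCESSING'
    GENERATOR_MODE = os.environ.get('GENERATOR_MODE', MODE_SINGLE_PROCESS)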
@@ -41,6 +51,45 @@ class Diagnostics(object):
     test_identifiers: list = field(default_factory=list)


+TestCaseParams = namedtuple(
+    'TestCaseParams', [
+        'test_case', 'case_dir', 'log_file', 'file_mode',
+    ])
+
+
+def worker_function(item):
+    return generate_test_vector(*item)
+
+
+def get_default_yaml():
+    yaml = YAML(pure=True)
+    yaml.default_flow_style = None
+
+    def _represent_none(self, _):
+        return self.represent_scalar('tag:yaml.org,2002:null', 'null')
+
+    yaml.representer.add_representer(type(None), _represent_none)
+
+    return yaml
+
+
+def get_cfg_yaml():
+    # Spec config is using a YAML subset
+    cfg_yaml = YAML(pure=True)
+    cfg_yaml.default_flow_style = False  # Emit separate line for each key
+
+    def cfg_represent_bytes(self, data):
+        return self.represent_int(encode_hex(data))
+
+    cfg_yaml.representer.add_representer(bytes, cfg_represent_bytes)
+
+    def cfg_represent_quoted_str(self, data):
+        return self.represent_scalar(u'tag:yaml.org,2002:str', data, style="'")
+
+    cfg_yaml.representer.add_representer(context.quoted_str, cfg_represent_quoted_str)
+
+    return cfg_yaml
+
+
 def validate_output_dir(path_str):
     path = Path(path_str)
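Note: worker_function is a plain module-level function that star-unpacks one TestCaseParams tuple, and the YAML encoders are now built inside the worker via get_default_yaml()/get_cfg_yaml(), so nothing awkward to serialize has to be shipped to the worker processes. A self-contained sketch of the same collect/map pattern with stand-in names (not the real generator code):

    # sketch only: mirrors the shape used by run_generator below
    from collections import namedtuple
    from pathos.multiprocessing import ProcessingPool as Pool

    Params = namedtuple('Params', ['case_id', 'out_dir'])

    def do_one(item):                       # module-level, so it pickles cleanly
        params = Params(*item)              # star-unpack, like worker_function
        return f'{params.out_dir}/{params.case_id}'   # plain, picklable result

    if __name__ == '__main__':
        work = [Params(i, 'tests') for i in range(4)]
        results = Pool(nodes=2).map(do_one, work)
        print(results)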
@@ -148,28 +197,6 @@ def run_generator(generator_name, test_providers: Iterable[TestProvider]):
     else:
         file_mode = "w"

-    yaml = YAML(pure=True)
-    yaml.default_flow_style = None
-
-    def _represent_none(self, _):
-        return self.represent_scalar('tag:yaml.org,2002:null', 'null')
-
-    yaml.representer.add_representer(type(None), _represent_none)
-
-    # Spec config is using a YAML subset
-    cfg_yaml = YAML(pure=True)
-    cfg_yaml.default_flow_style = False  # Emit separate line for each key
-
-    def cfg_represent_bytes(self, data):
-        return self.represent_int(encode_hex(data))
-
-    cfg_yaml.representer.add_representer(bytes, cfg_represent_bytes)
-
-    def cfg_represent_quoted_str(self, data):
-        return self.represent_scalar(u'tag:yaml.org,2002:str', data, style="'")
-
-    cfg_yaml.representer.add_representer(context.quoted_str, cfg_represent_quoted_str)
-
     log_file = Path(output_dir) / 'testgen_error_log.txt'

     print(f"Generating tests into {output_dir}")
@@ -185,8 +212,11 @@ def run_generator(generator_name, test_providers: Iterable[TestProvider]):
     collect_only = args.collect_only

     diagnostics_obj = Diagnostics()
     provider_start = time.time()

+    if GENERATOR_MODE == MODE_MULTIPROCESSING:
+        all_test_case_params = []
+
     for tprov in test_providers:
         if not collect_only:
             # runs anything that we don't want to repeat for every test case.
@@ -205,13 +235,20 @@ def run_generator(generator_name, test_providers: Iterable[TestProvider]):
             if is_skip:
                 continue

-            # generate test vector
-            is_skip, diagnostics_obj = generate_test_vector_and_diagnose(
-                test_case, case_dir, log_file, file_mode,
-                cfg_yaml, yaml, diagnostics_obj,
-            )
-            if is_skip:
-                continue
+            if GENERATOR_MODE == MODE_SINGLE_PROCESS:
+                result = generate_test_vector(test_case, case_dir, log_file, file_mode)
+                write_result_into_diagnostics_obj(result, diagnostics_obj)
+            elif GENERATOR_MODE == MODE_MULTIPROCESSING:
+                item = TestCaseParams(test_case, case_dir, log_file, file_mode)
+                all_test_case_params.append(item)
+
+    if GENERATOR_MODE == MODE_MULTIPROCESSING:
+        num_process = multiprocessing.cpu_count() // 2 - 1
+        with Pool(processes=num_process) as pool:
+            results = pool.map(worker_function, iter(all_test_case_params))
+
+        for result in results:
+            write_result_into_diagnostics_obj(result, diagnostics_obj)

     provider_end = time.time()
     span = round(provider_end - provider_start, 2)
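Note: multiprocessing.cpu_count() // 2 - 1 evaluates to 0 or a negative number on machines with fewer than four logical cores, which a process pool will typically reject. A defensive sketch (an observation of this note, not something the commit does):

    # sketch only: clamp the worker count so the pool always gets at least one process
    import multiprocessing

    num_process = max(1, multiprocessing.cpu_count() // 2 - 1)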
@@ -249,54 +286,55 @@ def run_generator(generator_name, test_providers: Iterable[TestProvider]):
         print(f"wrote diagnostics_obj to {diagnostics_path}")


-def generate_test_vector_and_diagnose(test_case, case_dir, log_file, file_mode, cfg_yaml, yaml, diagnostics_obj):
-    is_skip = False
+def generate_test_vector(test_case, case_dir, log_file, file_mode):
+    cfg_yaml = get_cfg_yaml()
+    yaml = get_default_yaml()
+
+    written_part = False

     print(f'Generating test: {case_dir}')
     test_start = time.time()

-    written_part = False
-
     # Add `INCOMPLETE` tag file to indicate that the test generation has not completed.
     incomplete_tag_file = get_incomplete_tag_file(case_dir)
     case_dir.mkdir(parents=True, exist_ok=True)
     with incomplete_tag_file.open("w") as f:
         f.write("\n")

+    result = None
     try:
         meta = dict()
         try:
             written_part, meta = execute_test(test_case, case_dir, meta, log_file, file_mode, cfg_yaml, yaml)
         except SkippedTest as e:
+            result = 0  # 0 means skipped
             print(e)
-            diagnostics_obj.skipped_test_count += 1
             shutil.rmtree(case_dir)
-            is_skip = True
-            return is_skip, diagnostics_obj
+            return result

         # Once all meta data is collected (if any), write it to a meta data file.
         if len(meta) != 0:
             written_part = True
             output_part(case_dir, log_file, "data", "meta", dump_yaml_fn(meta, "meta", file_mode, yaml))

-        if not written_part:
-            print(f"test case {case_dir} did not produce any test case parts")
     except Exception as e:
+        result = -1  # -1 means error
         error_message = f"[ERROR] failed to generate vector(s) for test {case_dir}: {e}"
         # Write to error log file
         with log_file.open("a+") as f:
             f.write(error_message)
             traceback.print_exc(file=f)
             f.write('\n')
+        print(error_message)
         traceback.print_exc()
     else:
         # If no written_part, the only file was incomplete_tag_file. Clear the existing case_dir folder.
         if not written_part:
+            print(f"test case {case_dir} did not produce any written_part")
             shutil.rmtree(case_dir)
+            result = -1
         else:
-            diagnostics_obj.generated_test_count += 1
-            test_identifier = get_test_identifier(test_case)
-            diagnostics_obj.test_identifiers.append(test_identifier)
+            result = get_test_identifier(test_case)
             # Only remove `INCOMPLETE` tag file
             os.remove(incomplete_tag_file)
     test_end = time.time()
@@ -304,7 +342,19 @@ def generate_test_vector_and_diagnose(test_case, case_dir, log_file, file_mode,
     if span > TIME_THRESHOLD_TO_PRINT:
         print(f' - generated in {span} seconds')

-    return is_skip, diagnostics_obj
+    return result
+
+
+def write_result_into_diagnostics_obj(result, diagnostics_obj):
+    if result == -1:  # error
+        pass
+    elif result == 0:
+        diagnostics_obj.skipped_test_count += 1
+    elif result is not None:
+        diagnostics_obj.generated_test_count += 1
+        diagnostics_obj.test_identifiers.append(result)
+    else:
+        raise Exception(f"Unexpected result: {result}")


 def dump_yaml_fn(data: Any, name: str, file_mode: str, yaml_encoder: YAML):
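Note: results are encoded as plain picklable values (-1 for an error, 0 for a skipped test, otherwise the test identifier string), so the parent process can fold them into a single Diagnostics object after the pool finishes. A small usage sketch with made-up results (not from the commit):

    # sketch only: fold a batch of worker results into one Diagnostics instance
    diagnostics_obj = Diagnostics()
    for result in [-1, 0, 'phase0/mainnet/sanity/blocks/empty_block']:  # hypothetical values
        write_result_into_diagnostics_obj(result, diagnostics_obj)

    # skipped_test_count == 1, generated_test_count == 1, one identifier recorded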
@@ -312,6 +362,7 @@ def dump_yaml_fn(data: Any, name: str, file_mode: str, yaml_encoder: YAML):
         out_path = case_path / Path(name + '.yaml')
         with out_path.open(file_mode) as f:
             yaml_encoder.dump(data, f)
+            f.close()
     return dump
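Note: the added f.close() runs inside the with block, where the context manager already closes the handle on exit, so it acts only as an extra safeguard (closing a file object twice is harmless). A minimal sketch of the equivalent behaviour without the explicit call:

    # sketch only: the context manager guarantees the handle is flushed and closed
    with out_path.open(file_mode) as f:
        yaml_encoder.dump(data, f)
    # f is closed here, even if dump() raised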
@@ -320,14 +371,14 @@ def output_part(case_dir, log_file, out_kind: str, name: str, fn: Callable[[Path
     case_dir.mkdir(parents=True, exist_ok=True)
     try:
         fn(case_dir)
-    except IOError as e:
+    except (IOError, ValueError) as e:
         error_message = f'[Error] error when dumping test "{case_dir}", part "{name}", kind "{out_kind}": {e}'
         # Write to error log file
         with log_file.open("a+") as f:
             f.write(error_message)
             traceback.print_exc(file=f)
             f.write('\n')
+        print(error_message)
         sys.exit(error_message)

View File

@@ -560,7 +560,7 @@ def _get_basic_dict(ssz_dict: Dict[str, Any]) -> Dict[str, Any]:
     return result


-def _get_copy_of_spec(spec):
+def get_copy_of_spec(spec):
     fork = spec.fork
     preset = spec.config.PRESET_BASE
     module_path = f"eth2spec.{fork}.{preset}"
@@ -601,14 +601,14 @@ def with_config_overrides(config_overrides, emitted_fork=None, emit=True):
     def decorator(fn):
         def wrapper(*args, spec: Spec, **kw):
             # Apply config overrides to spec
-            spec, output_config = spec_with_config_overrides(_get_copy_of_spec(spec), config_overrides)
+            spec, output_config = spec_with_config_overrides(get_copy_of_spec(spec), config_overrides)

             # Apply config overrides to additional phases, if present
             if 'phases' in kw:
                 phases = {}
                 for fork in kw['phases']:
                     phases[fork], output = spec_with_config_overrides(
-                        _get_copy_of_spec(kw['phases'][fork]), config_overrides)
+                        get_copy_of_spec(kw['phases'][fork]), config_overrides)
                     if emitted_fork == fork:
                         output_config = output
                 kw['phases'] = phases