Merge pull request #3347 from ethereum/testgen-refactor

Multiprocessing testgen runner
Hsiao-Wei Wang 2023-05-18 23:03:17 +08:00 committed by GitHub
commit e18e9743ed
4 changed files with 261 additions and 146 deletions

--------------------------------------------------------------------------------

@@ -1184,7 +1184,7 @@ setup(
     extras_require={
         "test": ["pytest>=4.4", "pytest-cov", "pytest-xdist"],
         "lint": ["flake8==5.0.4", "mypy==0.981", "pylint==2.15.3"],
-        "generator": ["python-snappy==0.6.1", "filelock"],
+        "generator": ["python-snappy==0.6.1", "filelock", "pathos==0.3.0"],
         "docs": ["mkdocs==1.4.2", "mkdocs-material==9.1.5", "mdx-truly-sane-lists==1.3", "mkdocs-awesome-pages-plugin==2.8.0"]
     },
     install_requires=[
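
Note: the only dependency change is the new pathos pin. pathos' process pool serializes work with dill rather than the stdlib pickle, which is plausibly why it was chosen here: dill can ship the closures and dynamically built spec objects that test cases drag along, where pickle gives up. A minimal sketch of the pool API the runner relies on (the square() worker is illustrative, not part of this PR):

    from pathos.multiprocessing import ProcessingPool as Pool

    def square(x):
        return x * x

    if __name__ == '__main__':
        with Pool(processes=4) as pool:
            # dill-based pickling also accepts lambdas and closures,
            # which the stdlib pickle module rejects
            print(pool.map(square, [1, 2, 3]))  # -> [1, 4, 9]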

--------------------------------------------------------------------------------

@@ -1,4 +1,7 @@
-from eth_utils import encode_hex
+from dataclasses import (
+    dataclass,
+    field,
+)
 import os
 import time
 import shutil
@@ -8,24 +11,80 @@ import sys
 import json
 from typing import Iterable, AnyStr, Any, Callable
 import traceback
+from collections import namedtuple

 from ruamel.yaml import (
     YAML,
 )
 from filelock import FileLock
 from snappy import compress
+from pathos.multiprocessing import ProcessingPool as Pool

+from eth_utils import encode_hex
 from eth2spec.test import context
 from eth2spec.test.exceptions import SkippedTest
 from .gen_typing import TestProvider
+from .settings import (
+    GENERATOR_MODE,
+    MODE_MULTIPROCESSING,
+    MODE_SINGLE_PROCESS,
+    NUM_PROCESS,
+    TIME_THRESHOLD_TO_PRINT,
+)

 # Flag that the runner does NOT run test via pytest
 context.is_pytest = False

-TIME_THRESHOLD_TO_PRINT = 1.0  # seconds
+
+@dataclass
+class Diagnostics(object):
+    collected_test_count: int = 0
+    generated_test_count: int = 0
+    skipped_test_count: int = 0
+    test_identifiers: list = field(default_factory=list)
+
+
+TestCaseParams = namedtuple(
+    'TestCaseParams', [
+        'test_case', 'case_dir', 'log_file', 'file_mode',
+    ])
+
+
+def worker_function(item):
+    return generate_test_vector(*item)
+
+
+def get_default_yaml():
+    yaml = YAML(pure=True)
+    yaml.default_flow_style = None
+
+    def _represent_none(self, _):
+        return self.represent_scalar('tag:yaml.org,2002:null', 'null')
+
+    yaml.representer.add_representer(type(None), _represent_none)
+
+    return yaml
+
+
+def get_cfg_yaml():
+    # Spec config is using a YAML subset
+    cfg_yaml = YAML(pure=True)
+    cfg_yaml.default_flow_style = False  # Emit separate line for each key
+
+    def cfg_represent_bytes(self, data):
+        return self.represent_int(encode_hex(data))
+
+    cfg_yaml.representer.add_representer(bytes, cfg_represent_bytes)
+
+    def cfg_represent_quoted_str(self, data):
+        return self.represent_scalar(u'tag:yaml.org,2002:str', data, style="'")
+
+    cfg_yaml.representer.add_representer(context.quoted_str, cfg_represent_quoted_str)
+
+    return cfg_yaml


 def validate_output_dir(path_str):
@@ -40,6 +99,47 @@ def validate_output_dir(path_str):
     return path


+def get_test_case_dir(test_case, output_dir):
+    return (
+        Path(output_dir) / Path(test_case.preset_name) / Path(test_case.fork_name)
+        / Path(test_case.runner_name) / Path(test_case.handler_name)
+        / Path(test_case.suite_name) / Path(test_case.case_name)
+    )
+
+
+def get_test_identifier(test_case):
+    return "::".join([
+        test_case.preset_name,
+        test_case.fork_name,
+        test_case.runner_name,
+        test_case.handler_name,
+        test_case.suite_name,
+        test_case.case_name
+    ])
+
+
+def get_incomplete_tag_file(case_dir):
+    return case_dir / "INCOMPLETE"
+
+
+def should_skip_case_dir(case_dir, is_force, diagnostics_obj):
+    is_skip = False
+    incomplete_tag_file = get_incomplete_tag_file(case_dir)
+
+    if case_dir.exists():
+        if not is_force and not incomplete_tag_file.exists():
+            diagnostics_obj.skipped_test_count += 1
+            print(f'Skipping already existing test: {case_dir}')
+            is_skip = True
+        else:
+            print(f'Warning, output directory {case_dir} already exist,'
+                  ' old files will be deleted and it will generate test vector files with the latest version')
+            # Clear the existing case_dir folder
+            shutil.rmtree(case_dir)
+
+    return is_skip, diagnostics_obj
+
+
 def run_generator(generator_name, test_providers: Iterable[TestProvider]):
     """
     Implementation for a general test generator.
@@ -94,28 +194,6 @@ def run_generator(generator_name, test_providers: Iterable[TestProvider]):
     else:
         file_mode = "w"

-    yaml = YAML(pure=True)
-    yaml.default_flow_style = None
-
-    def _represent_none(self, _):
-        return self.represent_scalar('tag:yaml.org,2002:null', 'null')
-
-    yaml.representer.add_representer(type(None), _represent_none)
-
-    # Spec config is using a YAML subset
-    cfg_yaml = YAML(pure=True)
-    cfg_yaml.default_flow_style = False  # Emit separate line for each key
-
-    def cfg_represent_bytes(self, data):
-        return self.represent_int(encode_hex(data))
-
-    cfg_yaml.representer.add_representer(bytes, cfg_represent_bytes)
-
-    def cfg_represent_quoted_str(self, data):
-        return self.represent_scalar(u'tag:yaml.org,2002:str', data, style="'")
-
-    cfg_yaml.representer.add_representer(context.quoted_str, cfg_represent_quoted_str)
-
     log_file = Path(output_dir) / 'testgen_error_log.txt'

     print(f"Generating tests into {output_dir}")
@@ -129,12 +207,13 @@ def run_generator(generator_name, test_providers: Iterable[TestProvider]):
         print(f"Filtering test-generator runs to only include presets: {', '.join(presets)}")

     collect_only = args.collect_only
-    collected_test_count = 0
-    generated_test_count = 0
-    skipped_test_count = 0
-    test_identifiers = []
+
+    diagnostics_obj = Diagnostics()
     provider_start = time.time()
+
+    if GENERATOR_MODE == MODE_MULTIPROCESSING:
+        all_test_case_params = []
+
     for tprov in test_providers:
         if not collect_only:
             # runs anything that we don't want to repeat for every test case.
@@ -145,146 +224,133 @@ def run_generator(generator_name, test_providers: Iterable[TestProvider]):
             if len(presets) != 0 and test_case.preset_name not in presets:
                 continue

-            case_dir = (
-                Path(output_dir) / Path(test_case.preset_name) / Path(test_case.fork_name)
-                / Path(test_case.runner_name) / Path(test_case.handler_name)
-                / Path(test_case.suite_name) / Path(test_case.case_name)
-            )
-            collected_test_count += 1
+            case_dir = get_test_case_dir(test_case, output_dir)
             print(f"Collected test at: {case_dir}")
+            diagnostics_obj.collected_test_count += 1

-            incomplete_tag_file = case_dir / "INCOMPLETE"
+            is_skip, diagnostics_obj = should_skip_case_dir(case_dir, args.force, diagnostics_obj)
+            if is_skip:
+                continue

-            if case_dir.exists():
-                if not args.force and not incomplete_tag_file.exists():
-                    skipped_test_count += 1
-                    print(f'Skipping already existing test: {case_dir}')
-                    continue
-                else:
-                    print(f'Warning, output directory {case_dir} already exist,'
-                          f' old files will be deleted and it will generate test vector files with the latest version')
-                    # Clear the existing case_dir folder
-                    shutil.rmtree(case_dir)
+            if GENERATOR_MODE == MODE_SINGLE_PROCESS:
+                result = generate_test_vector(test_case, case_dir, log_file, file_mode)
+                write_result_into_diagnostics_obj(result, diagnostics_obj)
+            elif GENERATOR_MODE == MODE_MULTIPROCESSING:
+                item = TestCaseParams(test_case, case_dir, log_file, file_mode)
+                all_test_case_params.append(item)

-            print(f'Generating test: {case_dir}')
-            test_start = time.time()
-
-            written_part = False
-
-            # Add `INCOMPLETE` tag file to indicate that the test generation has not completed.
-            case_dir.mkdir(parents=True, exist_ok=True)
-            with incomplete_tag_file.open("w") as f:
-                f.write("\n")
-
-            try:
-                def output_part(out_kind: str, name: str, fn: Callable[[Path, ], None]):
-                    # make sure the test case directory is created before any test part is written.
-                    case_dir.mkdir(parents=True, exist_ok=True)
-                    try:
-                        fn(case_dir)
-                    except IOError as e:
-                        error_message = (
-                            f'[Error] error when dumping test "{case_dir}", part "{name}", kind "{out_kind}": {e}'
-                        )
-                        # Write to error log file
-                        with log_file.open("a+") as f:
-                            f.write(error_message)
-                            traceback.print_exc(file=f)
-                            f.write('\n')
-                        sys.exit(error_message)
-
-                meta = dict()
-
-                try:
-                    for (name, out_kind, data) in test_case.case_fn():
-                        written_part = True
-                        if out_kind == "meta":
-                            meta[name] = data
-                        elif out_kind == "cfg":
-                            output_part(out_kind, name, dump_yaml_fn(data, name, file_mode, cfg_yaml))
-                        elif out_kind == "data":
-                            output_part(out_kind, name, dump_yaml_fn(data, name, file_mode, yaml))
-                        elif out_kind == "ssz":
-                            output_part(out_kind, name, dump_ssz_fn(data, name, file_mode))
-                        else:
-                            assert False  # Unknown kind
-                except SkippedTest as e:
-                    print(e)
-                    skipped_test_count += 1
-                    shutil.rmtree(case_dir)
-                    continue
-
-                # Once all meta data is collected (if any), write it to a meta data file.
-                if len(meta) != 0:
-                    written_part = True
-                    output_part("data", "meta", dump_yaml_fn(meta, "meta", file_mode, yaml))
-
-                if not written_part:
-                    print(f"test case {case_dir} did not produce any test case parts")
-
-            except Exception as e:
-                error_message = f"[ERROR] failed to generate vector(s) for test {case_dir}: {e}"
-                # Write to error log file
-                with log_file.open("a+") as f:
-                    f.write(error_message)
-                    traceback.print_exc(file=f)
-                    f.write('\n')
-                traceback.print_exc()
-            else:
-                # If no written_part, the only file was incomplete_tag_file. Clear the existing case_dir folder.
-                if not written_part:
-                    shutil.rmtree(case_dir)
-                else:
-                    generated_test_count += 1
-                    test_identifier = "::".join([
-                        test_case.preset_name,
-                        test_case.fork_name,
-                        test_case.runner_name,
-                        test_case.handler_name,
-                        test_case.suite_name,
-                        test_case.case_name
-                    ])
-                    test_identifiers.append(test_identifier)
-                    # Only remove `INCOMPLETE` tag file
-                    os.remove(incomplete_tag_file)
-
-            test_end = time.time()
-            span = round(test_end - test_start, 2)
-            if span > TIME_THRESHOLD_TO_PRINT:
-                print(f' - generated in {span} seconds')
+    if GENERATOR_MODE == MODE_MULTIPROCESSING:
+        with Pool(processes=NUM_PROCESS) as pool:
+            results = pool.map(worker_function, iter(all_test_case_params))
+
+        for result in results:
+            write_result_into_diagnostics_obj(result, diagnostics_obj)

     provider_end = time.time()
     span = round(provider_end - provider_start, 2)

     if collect_only:
-        print(f"Collected {collected_test_count} tests in total")
+        print(f"Collected {diagnostics_obj.collected_test_count} tests in total")
     else:
-        summary_message = f"completed generation of {generator_name} with {generated_test_count} tests"
-        summary_message += f" ({skipped_test_count} skipped tests)"
+        summary_message = f"completed generation of {generator_name} with {diagnostics_obj.generated_test_count} tests"
+        summary_message += f" ({diagnostics_obj.skipped_test_count} skipped tests)"
         if span > TIME_THRESHOLD_TO_PRINT:
             summary_message += f" in {span} seconds"
         print(summary_message)

-    diagnostics = {
-        "collected_test_count": collected_test_count,
-        "generated_test_count": generated_test_count,
-        "skipped_test_count": skipped_test_count,
-        "test_identifiers": test_identifiers,
+    diagnostics_output = {
+        "collected_test_count": diagnostics_obj.collected_test_count,
+        "generated_test_count": diagnostics_obj.generated_test_count,
+        "skipped_test_count": diagnostics_obj.skipped_test_count,
+        "test_identifiers": diagnostics_obj.test_identifiers,
         "durations": [f"{span} seconds"],
     }
-    diagnostics_path = Path(os.path.join(output_dir, "diagnostics.json"))
-    diagnostics_lock = FileLock(os.path.join(output_dir, "diagnostics.json.lock"))
+    diagnostics_path = Path(os.path.join(output_dir, "diagnostics_obj.json"))
+    diagnostics_lock = FileLock(os.path.join(output_dir, "diagnostics_obj.json.lock"))
     with diagnostics_lock:
         diagnostics_path.touch(exist_ok=True)
         if os.path.getsize(diagnostics_path) == 0:
             with open(diagnostics_path, "w+") as f:
-                json.dump(diagnostics, f)
+                json.dump(diagnostics_output, f)
         else:
             with open(diagnostics_path, "r+") as f:
                 existing_diagnostics = json.load(f)
-                for k, v in diagnostics.items():
+                for k, v in diagnostics_output.items():
                     existing_diagnostics[k] += v
             with open(diagnostics_path, "w+") as f:
                 json.dump(existing_diagnostics, f)
-    print(f"wrote diagnostics to {diagnostics_path}")
+    print(f"wrote diagnostics_obj to {diagnostics_path}")
+
+
+def generate_test_vector(test_case, case_dir, log_file, file_mode):
+    cfg_yaml = get_cfg_yaml()
+    yaml = get_default_yaml()
+
+    written_part = False
+
+    print(f'Generating test: {case_dir}')
+    test_start = time.time()
+
+    # Add `INCOMPLETE` tag file to indicate that the test generation has not completed.
+    incomplete_tag_file = get_incomplete_tag_file(case_dir)
+    case_dir.mkdir(parents=True, exist_ok=True)
+    with incomplete_tag_file.open("w") as f:
+        f.write("\n")
+
+    result = None
+    try:
+        meta = dict()
+        try:
+            written_part, meta = execute_test(test_case, case_dir, meta, log_file, file_mode, cfg_yaml, yaml)
+        except SkippedTest as e:
+            result = 0  # 0 means skipped
+            print(e)
+            shutil.rmtree(case_dir)
+            return result
+
+        # Once all meta data is collected (if any), write it to a meta data file.
+        if len(meta) != 0:
+            written_part = True
+            output_part(case_dir, log_file, "data", "meta", dump_yaml_fn(meta, "meta", file_mode, yaml))
+    except Exception as e:
+        result = -1  # -1 means error
+        error_message = f"[ERROR] failed to generate vector(s) for test {case_dir}: {e}"
+        # Write to error log file
+        with log_file.open("a+") as f:
+            f.write(error_message)
+            traceback.print_exc(file=f)
+            f.write('\n')
+        print(error_message)
+        traceback.print_exc()
+    else:
+        # If no written_part, the only file was incomplete_tag_file. Clear the existing case_dir folder.
+        if not written_part:
+            print(f"[Error] test case {case_dir} did not produce any written_part")
+            shutil.rmtree(case_dir)
+            result = -1
+        else:
+            result = get_test_identifier(test_case)
+            # Only remove `INCOMPLETE` tag file
+            os.remove(incomplete_tag_file)
+
+    test_end = time.time()
+    span = round(test_end - test_start, 2)
+    if span > TIME_THRESHOLD_TO_PRINT:
+        print(f' - generated in {span} seconds')
+
+    return result
+
+
+def write_result_into_diagnostics_obj(result, diagnostics_obj):
+    if result == -1:  # error
+        pass
+    elif result == 0:
+        diagnostics_obj.skipped_test_count += 1
+    elif result is not None:
+        diagnostics_obj.generated_test_count += 1
+        diagnostics_obj.test_identifiers.append(result)
+    else:
+        raise Exception(f"Unexpected result: {result}")


 def dump_yaml_fn(data: Any, name: str, file_mode: str, yaml_encoder: YAML):
@@ -292,9 +358,45 @@ def dump_yaml_fn(data: Any, name: str, file_mode: str, yaml_encoder: YAML):
         out_path = case_path / Path(name + '.yaml')
         with out_path.open(file_mode) as f:
             yaml_encoder.dump(data, f)
-            f.close()
     return dump


+def output_part(case_dir, log_file, out_kind: str, name: str, fn: Callable[[Path, ], None]):
+    # make sure the test case directory is created before any test part is written.
+    case_dir.mkdir(parents=True, exist_ok=True)
+    try:
+        fn(case_dir)
+    except (IOError, ValueError) as e:
+        error_message = f'[Error] error when dumping test "{case_dir}", part "{name}", kind "{out_kind}": {e}'
+        # Write to error log file
+        with log_file.open("a+") as f:
+            f.write(error_message)
+            traceback.print_exc(file=f)
+            f.write('\n')
+        print(error_message)
+        sys.exit(error_message)
+
+
+def execute_test(test_case, case_dir, meta, log_file, file_mode, cfg_yaml, yaml):
+    result = test_case.case_fn()
+    written_part = False
+    for (name, out_kind, data) in result:
+        written_part = True
+        if out_kind == "meta":
+            meta[name] = data
+        elif out_kind == "cfg":
+            output_part(case_dir, log_file, out_kind, name, dump_yaml_fn(data, name, file_mode, cfg_yaml))
+        elif out_kind == "data":
+            output_part(case_dir, log_file, out_kind, name, dump_yaml_fn(data, name, file_mode, yaml))
+        elif out_kind == "ssz":
+            output_part(case_dir, log_file, out_kind, name, dump_ssz_fn(data, name, file_mode))
+        else:
+            raise ValueError("Unknown out_kind %s" % out_kind)
+    return written_part, meta
+

 def dump_ssz_fn(data: AnyStr, name: str, file_mode: str):
     def dump(case_path: Path):
         out_path = case_path / Path(name + '.ssz_snappy')
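
Note: each worker collapses its outcome into a small result protocol: -1 for an error, 0 for a skipped test, and the test identifier string ("preset::fork::runner::handler::suite::case") on success. Workers never mutate the parent's Diagnostics; the parent folds the returned results in after pool.map completes, so the pool code needs no locks — only the shared diagnostics JSON merge is guarded by a FileLock, and that protects against concurrent generator processes, not the workers. A self-contained sketch of that fold under those conventions (the sample results are made up):

    from dataclasses import dataclass, field

    @dataclass
    class Diagnostics:
        generated_test_count: int = 0
        skipped_test_count: int = 0
        test_identifiers: list = field(default_factory=list)

    def fold(result, diag):
        if result == -1:   # error: the worker already logged it
            pass
        elif result == 0:  # skipped test
            diag.skipped_test_count += 1
        else:              # success: result is the test identifier
            diag.generated_test_count += 1
            diag.test_identifiers.append(result)

    diag = Diagnostics()
    for r in (-1, 0, "mainnet::phase0::sanity::blocks::pyspec_tests::empty_block"):
        fold(r, diag)
    print(diag.generated_test_count, diag.skipped_test_count)  # -> 1 1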

--------------------------------------------------------------------------------

@@ -0,0 +1,13 @@
+import multiprocessing
+
+# Generator mode setting
+MODE_SINGLE_PROCESS = 'MODE_SINGLE_PROCESS'
+MODE_MULTIPROCESSING = 'MODE_MULTIPROCESSING'
+
+# Test generator mode
+GENERATOR_MODE = MODE_MULTIPROCESSING
+# Number of subprocesses when using MODE_MULTIPROCESSING
+NUM_PROCESS = multiprocessing.cpu_count() // 2 - 1
+
+# Diagnostics
+TIME_THRESHOLD_TO_PRINT = 1.0  # seconds
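
Note: the worker count is derived from the host: cpu_count() // 2 - 1 gives 7 workers on 16 cores and 3 on 8, but only 1 on 4 cores and 0 on 2. A defensive variant (an assumption, not something this PR does) would clamp the value:

    import multiprocessing

    # hypothetical hardening: never drop below one worker
    NUM_PROCESS = max(1, multiprocessing.cpu_count() // 2 - 1)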

--------------------------------------------------------------------------------

@@ -560,7 +560,7 @@ def _get_basic_dict(ssz_dict: Dict[str, Any]) -> Dict[str, Any]:
     return result


-def _get_copy_of_spec(spec):
+def get_copy_of_spec(spec):
     fork = spec.fork
     preset = spec.config.PRESET_BASE
     module_path = f"eth2spec.{fork}.{preset}"
@@ -601,14 +601,14 @@ def with_config_overrides(config_overrides, emitted_fork=None, emit=True):
     def decorator(fn):
         def wrapper(*args, spec: Spec, **kw):
             # Apply config overrides to spec
-            spec, output_config = spec_with_config_overrides(_get_copy_of_spec(spec), config_overrides)
+            spec, output_config = spec_with_config_overrides(get_copy_of_spec(spec), config_overrides)

             # Apply config overrides to additional phases, if present
             if 'phases' in kw:
                 phases = {}
                 for fork in kw['phases']:
                     phases[fork], output = spec_with_config_overrides(
-                        _get_copy_of_spec(kw['phases'][fork]), config_overrides)
+                        get_copy_of_spec(kw['phases'][fork]), config_overrides)
                     if emitted_fork == fork:
                         output_config = output
                 kw['phases'] = phases
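
Note: the only change in this file is visibility: _get_copy_of_spec loses its leading underscore, so generator-side code can import the spec-copy helper without reaching into a private API. Its job is unchanged: return a fresh copy of the spec module so config overrides mutate the copy rather than the shared import. A hedged usage sketch (the wrapper is made up, and the import path assumes this helper lives in eth2spec.test.context, as gen_runner's imports suggest):

    from eth2spec.test.context import get_copy_of_spec

    def run_on_scratch_spec(spec, fn):
        # overrides applied inside fn cannot leak into the spec module
        # that other tests share
        return fn(get_copy_of_spec(spec))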