From b7b2fee6350489ec4cd1c2e118b8bfba21db8055 Mon Sep 17 00:00:00 2001 From: protolambda Date: Fri, 21 Jun 2019 21:12:27 +0200 Subject: [PATCH] uint add/sub type checking, fixes #1029 --- .../pyspec/eth2spec/utils/ssz/ssz_typing.py | 8 ++++- .../eth2spec/utils/ssz/test_ssz_typing.py | 32 +++++++++++++------ 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index 22f76bada..971b2106b 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -52,9 +52,15 @@ class uint(BasicValue, metaclass=BasicType): if value < 0: raise ValueError("unsigned types must not be negative") if cls.byte_len and value.bit_length() > (cls.byte_len << 3): - raise ValueError("value out of bounds for uint{}".format(cls.byte_len)) + raise ValueError("value out of bounds for uint{}".format(cls.byte_len * 8)) return super().__new__(cls, value) + def __add__(self, other): + return self.__class__(super().__add__(coerce_type_maybe(other, self.__class__, strict=True))) + + def __sub__(self, other): + return self.__class__(super().__sub__(coerce_type_maybe(other, self.__class__, strict=True))) + @classmethod def default(cls): return cls(0) diff --git a/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py index 895a074a9..4325501aa 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py @@ -6,6 +6,14 @@ from .ssz_typing import ( ) +def expect_value_error(fn, msg): + try: + fn() + raise AssertionError(msg) + except ValueError: + pass + + def test_subclasses(): for u in [uint, uint8, uint16, uint32, uint64, uint128, uint256]: assert issubclass(u, uint) @@ -55,21 +63,13 @@ def test_basic_value_bounds(): # this should work assert k(v - 1) == v - 1 # but we do not allow overflows - try: - k(v) - assert False - except ValueError: - pass + expect_value_error(lambda: k(v), "no overflows allowed") for k, _ in max.items(): # this should work assert k(0) == 0 # but we do not allow underflows - try: - k(-1) - assert False - except ValueError: - pass + expect_value_error(lambda: k(-1), "no underflows allowed") def test_container(): @@ -213,3 +213,15 @@ def test_bytesn_subclass(): assert not issubclass(Bytes48, Bytes32) assert len(Bytes32() + Bytes48()) == 80 + + +def test_uint_math(): + assert uint8(0) + uint8(uint32(16)) == uint8(16) # allow explict casting to make invalid addition valid + + expect_value_error(lambda: uint8(0) - uint8(1), "no underflows allowed") + expect_value_error(lambda: uint8(1) + uint8(255), "no overflows allowed") + expect_value_error(lambda: uint8(0) + 256, "no overflows allowed") + expect_value_error(lambda: uint8(42) + uint32(123), "no mixed types") + expect_value_error(lambda: uint32(42) + uint8(123), "no mixed types") + + assert type(uint32(1234) + 56) == uint32