diff --git a/libp2p/varint.nim b/libp2p/varint.nim index 69fd910c8..9b8b70cf5 100644 --- a/libp2p/varint.nim +++ b/libp2p/varint.nim @@ -21,6 +21,7 @@ type Success, Overflow, Incomplete, + Overlong, Overrun PB* = object @@ -99,6 +100,13 @@ proc getUVarint*[T: PB|LP](vtype: typedesc[T], outlen = 0 outval = cast[type(outval)](0) + when vtype is LP: + if result == VarintStatus.Success: + if outlen != vsizeof(outval): + outval = cast[type(outval)](0) + outlen = 0 + result = VarintStatus.Overlong + proc putUVarint*[T: PB|LP](vtype: typedesc[T], pbytes: var openarray[byte], outlen: var int, @@ -242,18 +250,3 @@ proc encodeVarint*(vtype: typedesc[LP], result.setLen(outsize) else: raise newException(VarintError, "Error '" & $res & "'") - -proc decodeSVarint*(data: openarray[byte]): int {.inline.} = - ## Decode signed integer from array ``data`` and return it as result. - var outsize = 0 - let res = getSVarint(data, outsize, result) - if res != VarintStatus.Success: - raise newException(VarintError, "Error '" & $res & "'") - -proc decodeUVarint*[T: PB|LP](vtype: typedesc[T], - data: openarray[byte]): uint {.inline.} = - ## Decode unsigned integer from array ``data`` and return it as result. - var outsize = 0 - let res = vtype.getUVarint(data, outsize, result) - if res != VarintStatus.Success: - raise newException(VarintError, "Error '" & $res & "'") diff --git a/tests/testvarint.nim b/tests/testvarint.nim index 63e6b9184..15caa0a3a 100644 --- a/tests/testvarint.nim +++ b/tests/testvarint.nim @@ -216,3 +216,54 @@ suite "Variable integer test suite": 0x8000_0000_0000_0000'u64) == VarintStatus.Overflow LP.putUVarint(buffer, length, 0xFFFF_FFFF_FFFF_FFFF'u64) == VarintStatus.Overflow + + test "[LibP2P] Overlong values test": + const OverlongValues = [ + # Zero bytes at the end + @[0x81'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, + 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, + 0x80'u8, 0x00'u8], + # Zero bytes at the middle and zero byte at the end + @[0x81'u8, 0x80'u8, 0x81'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x81'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x81'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x81'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x81'u8, + 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, + 0x81'u8, 0x00'u8], + # Zero bytes at the middle and zero bytes at the end + @[0x81'u8, 0x80'u8, 0x80'u8, 0x81'u8, 0x80'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x81'u8, 0x80'u8, 0x80'u8, 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x81'u8, 0x80'u8, 0x80'u8, + 0x00'u8], + @[0x81'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x80'u8, 0x81'u8, 0x80'u8, + 0x80'u8, 0x00'u8], + ] + var length = 0 + var value = 0'u64 + + for item in OverlongValues: + check: + LP.getUVarint(item, length, value) == VarintStatus.Overlong + length == 0 + value == 0 + + # We still should be able to decode zero value + check: + LP.getUVarint(@[0x00'u8], length, value) == VarintStatus.Success + length == 1 + value == 0 + + # But not overlonged zero value + check: + LP.getUVarint(@[0x80'u8, 0x00'u8], length, value) == VarintStatus.Overlong + length == 0 + value == 0