diff --git a/beacon_chain/ssz.nim b/beacon_chain/ssz.nim index 3e775f133..341fcb99d 100644 --- a/beacon_chain/ssz.nim +++ b/beacon_chain/ssz.nim @@ -9,7 +9,7 @@ # See https://github.com/ethereum/beacon_chain/issues/100 # and https://github.com/ethereum/beacon_chain/tree/master/ssz -import ./datatypes, eth_common, endians, typetraits +import ./datatypes, eth_common, endians, typetraits, options # ################### Helper functions ################################### func `+`[T](p: ptr T, offset: int): ptr T {.inline.}= @@ -17,81 +17,91 @@ func `+`[T](p: ptr T, offset: int): ptr T {.inline.}= const size = sizeof T cast[ptr T](cast[ByteAddress](p) +% offset * size) -func checkSize[T: not seq](x: T, pos, len: int) {.inline.}= - # This assumes that T is packed - doAssert pos + T.sizeof < len, "Deserialization overflow" +func eat(x: var auto, data: ptr byte, pos: var int, len: int): bool = + if pos + x.sizeof > len: return + copyMem(x.addr, data + pos, x.sizeof) + inc pos, x.sizeof + return true -func checkSize[T](x: seq[T], pos, len: int) {.inline.}= - # seq length is stored in an uint32 (4 bytes) for SSZ - doAssert pos + 4 + x.len * T.sizeof < len, "Deserialization overflow" +func eatInt[T: SomeInteger or byte](x: var T, data: ptr byte, pos: var int, len: int): + bool = + if pos + x.sizeof > len: return + + # XXX: any better way to get a suitably aligned buffer in nim??? + # see also: https://github.com/nim-lang/Nim/issues/9206 + var tmp: uint64 + var alignedBuf = cast[ptr byte](tmp.addr) + copyMem(alignedBuf, data + pos, x.sizeof) -template deserInt(x: var SomeInteger or byte, data: ptr byte, pos: var int) = when x.sizeof == 8: - bigEndian64(x.addr, data + pos) - inc pos, 8 + bigEndian64(x.addr, alignedBuf) elif x.sizeof == 4: - bigEndian32(x.addr, data + pos) - inc pos, 4 + bigEndian32(x.addr, alignedBuf) elif x.sizeof == 2: - bigEndian16(x.addr, data + pos) - inc pos, 2 + bigEndian16(x.addr, alignedBuf) + elif x.sizeof == 1: + x = cast[ptr type x](alignedBuf)[] else: - x = cast[ptr type x](data + pos)[] - inc pos + {.fatal: "Unsupported type deserialization: " & $(type(x)).name.} -func deserSeq[T](dest: var seq[T], len: int, src: ptr byte, pos: var int) = - dest = newSeqUninitialized[T](len) - for val in dest.mitems: - val.deserInt(src, pos) + inc pos, x.sizeof + return true -func serInt[T: SomeInteger or byte](dest: var seq[byte], src: T, buffer: var array[sizeof(T), byte]) {.inline.}= - when T.sizeof == 8: - bigEndian64(buffer.addr, src.unsafeAddr) - elif T.sizeof == 4: - bigEndian32(buffer.addr, src.unsafeAddr) - elif T.sizeof == 2: - bigEndian16(buffer.addr, src.unsafeAddr) - else: - dest.add byte(src) - return - dest.add buffer +func eatSeq[T: SomeInteger or byte](x: var seq[T], data: ptr byte, pos: var int, + len: int): bool = + var items: int32 + if not eatInt(items, data, pos, len): return + if pos + T.sizeof * items > len: return + + x = newSeqUninitialized[T](items) + for val in x.mitems: + discard eatInt(val, data, pos, len) # Bounds-checked above + return true func serInt[T: SomeInteger or byte](dest: var seq[byte], src: T) {.inline.}= - var buffer: array[T.sizeof, byte] - dest.serInt(src, buffer) + # XXX: any better way to get a suitably aligned buffer in nim??? + var tmp: T + var alignedBuf = cast[ptr array[src.sizeof, byte]](tmp.addr) + when src.sizeof == 8: + bigEndian64(alignedBuf, src.unsafeAddr) + elif src.sizeof == 4: + bigEndian32(alignedBuf, src.unsafeAddr) + elif src.sizeof == 2: + bigEndian16(alignedBuf, src.unsafeAddr) + elif src.sizeof == 1: + copyMem(alignedBuf, src.unsafeAddr, src.sizeof) # careful, aliasing.. + else: + {.fatal: "Unsupported type deserialization: " & $(type(x)).name.} + + dest.add alignedBuf[] func serSeq[T: SomeInteger or byte](dest: var seq[byte], src: seq[T]) = dest.serInt src.len.uint32 - var buffer: array[T.sizeof, byte] for val in src: - dest.serInt(val, buffer) + dest.serInt(val) # ################### Core functions ################################### -func deserialize(data: ptr byte, pos: var int, len: int, typ: typedesc[object]): typ = - for field in result.fields: - checkSize field, pos, len - when field is EthAddress: - copyMem(field.addr, data + pos, 20) - inc pos, 20 - elif field is MDigest: - const size = field.bits div 8 - copyMem(field.addr, data + pos, size) - inc pos, size +func deserialize(data: ptr byte, pos: var int, len: int, typ: typedesc[object]): + auto = + var t: typ + + for field in t.fields: + when field is EthAddress | MDigest: + if not eat(field, data, pos, len): return elif field is (SomeInteger or byte): - field.deserInt(data, pos) + if not eatInt(field, data, pos, len): return elif field is seq[SomeInteger or byte]: - var length: int32 - bigEndian32(length.addr, data + pos) - inc pos, 4 - deserSeq(field, length, data, pos) + if not eatSeq(field, data, pos, len): return else: # TODO: deserializing subtypes (?, depends on final spec) {.fatal: "Unsupported type deserialization: " & $typ.name.} + return some(t) func deserialize*( data: seq[byte or uint8] or openarray[byte or uint8] or string, - typ: typedesc[object]): typ {.inline.}= + typ: typedesc[object]): auto {.inline.} = + # XXX: returns Option[typ]: https://github.com/nim-lang/Nim/issues/9195 var pos = 0 - deserialize((ptr byte)(data[0].unsafeAddr), pos, data.len, typ) + return deserialize((ptr byte)(data[0].unsafeAddr), pos, data.len, typ) func serialize*[T](value: T): seq[byte] = for field in value.fields: diff --git a/tests/test_ssz.nim b/tests/test_ssz.nim index 152b35898..453499603 100644 --- a/tests/test_ssz.nim +++ b/tests/test_ssz.nim @@ -6,7 +6,7 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms. import - unittest, nimcrypto, eth_common, sequtils, + unittest, nimcrypto, eth_common, sequtils, options, ../beacon_chain/ssz func filled[N: static[int], T](typ: type array[N, T], value: T): array[N, T] = @@ -47,9 +47,14 @@ suite "Simple serialization": expected_ser &= [byte 0, 0, 0, 3, 'c'.ord, 'o'.ord, 'w'.ord] test "Deserialization": - let deser = expected_ser.deserialize(Foo) + let deser = expected_ser.deserialize(Foo).get() check: expected_deser == deser test "Serialization": let ser = expected_deser.serialize() check: expected_ser == ser + + test "Overflow": + check: + expected_ser[0..^2].deserialize(Foo).isNone() + expected_ser[1..^1].deserialize(Foo).isNone()