diff --git a/protobuf_serialization.nim b/protobuf_serialization.nim index 20106fb..74af9a7 100644 --- a/protobuf_serialization.nim +++ b/protobuf_serialization.nim @@ -3,8 +3,8 @@ import sets import serialization export serialization -import protobuf_serialization/[internal, types, reader, writer] -export types, reader, writer +import protobuf_serialization/[internal, types, reader, sizer, writer] +export types, reader, sizer, writer serializationFormat Protobuf @@ -27,3 +27,7 @@ func supports*[T](_: type Protobuf, ty: typedesc[T]): bool = # TODO return false when not supporting, instead of crashing compiler static: supportsCompileTime(T) true + +func computeSize*[T: object](_: type Protobuf, value: T): int = + ## Return the encoded size of the given value + computeObjectSize(value) diff --git a/protobuf_serialization/codec.nim b/protobuf_serialization/codec.nim index 1db1aa9..14dbb71 100644 --- a/protobuf_serialization/codec.nim +++ b/protobuf_serialization/codec.nim @@ -156,10 +156,26 @@ template toBytes*(x: pfloat): openArray[byte] = template toBytes*(header: FieldHeader): openArray[byte] = toBytes(uint32(header), Leb128).toOpenArray() -proc vsizeof*(x: SomeVarint): int = +func computeSize*(x: SomeVarint): int = ## Returns number of bytes required to encode integer ``x`` as varint. Leb128.len(toUleb(x)) +func computeSize*(x: SomeFixed64 | SomeFixed32): int = + ## Returns number of bytes required to encode integer ``x`` as varint. + sizeof(x) + +func computeSize*(x: pstring | pbytes): int = + let len = distinctBase(x).len() + computeSize(puint64(len)) + len + +func computeSize*(x: FieldHeader): int = + ## Returns number of bytes required to encode integer ``x`` as varint. + computeSize(puint32(x)) + +func computeSize*(field: int, x: SomeScalar): int = + computeSize(FieldHeader.init(field, wireKind(typeof(x)))) + + computeSize(x) + proc writeValue*(output: OutputStream, value: SomeVarint) = output.write(toBytes(value)) @@ -177,8 +193,11 @@ proc writeValue*(output: OutputStream, value: pbytes) = proc writeValue*(output: OutputStream, value: SomeFixed32) = output.write(toBytes(value)) +proc writeValue*(output: OutputStream, value: FieldHeader) = + output.write(toBytes(value)) + proc writeField*(output: OutputStream, field: int, value: SomeScalar) = - output.write(toBytes(FieldHeader.init(field, wireKind(typeof(value))))) + output.writeValue(FieldHeader.init(field, wireKind(typeof(value)))) output.writeValue(value) proc readValue*[T: SomeVarint](input: InputStream, _: type T): T = diff --git a/protobuf_serialization/internal.nim b/protobuf_serialization/internal.nim index eb5fbbf..10d72c4 100644 --- a/protobuf_serialization/internal.nim +++ b/protobuf_serialization/internal.nim @@ -54,6 +54,28 @@ proc fieldNumberOf*(T: type, fieldName: static string): int {.compileTime.} = else: fieldNum +template tableObject*(TableObject, K, V) = + when K is SomePBInt and V is SomePBInt: + type + TableObject {.proto3.} = object + key {.fieldNumber: 1, pint.}: K + value {.fieldNumber: 2, pint.}: V + elif K is SomePBInt: + type + TableObject {.proto3.} = object + key {.fieldNumber: 1, pint.}: K + value {.fieldNumber: 2.}: V + elif V is SomePBInt: + type + TableObject {.proto3.} = object + key {.fieldNumber: 1.}: K + value {.fieldNumber: 2, pint.}: V + else: + type + TableObject {.proto3.} = object + key {.fieldNumber: 1.}: K + value {.fieldNumber: 2.}: V + template protoType*(InnerType, RootType, FieldType: untyped, fieldName: untyped) = mixin flatType @@ -117,8 +139,10 @@ template protoType*(InnerType, RootType, FieldType: untyped, fieldName: untyped) type InnerType = pbytes elif FlatType is enum: type InnerType = penum - elif FlatType is object or FlatType is ref: - type InnerType = FieldType + elif FlatType is object: + type InnerType = pbytes + elif FlatType is ref and defined(ConformanceTest): + type InnerType = pbytes else: type InnerType = UnsupportedType[FieldType, RootType, fieldName] diff --git a/protobuf_serialization/reader.nim b/protobuf_serialization/reader.nim index fcd6527..35e1cbb 100644 --- a/protobuf_serialization/reader.nim +++ b/protobuf_serialization/reader.nim @@ -64,28 +64,7 @@ when defined(ConformanceTest): header: FieldHeader, ProtoType: type ) = - # I know it's ugly, but I cannot find a clean way to do it - # ... And nobody cares about map - when K is SomePBInt and V is SomePBInt: - type - TableObject {.proto3.} = object - key {.fieldNumber: 1, pint.}: K - value {.fieldNumber: 2, pint.}: V - elif K is SomePBInt: - type - TableObject {.proto3.} = object - key {.fieldNumber: 1, pint.}: K - value {.fieldNumber: 2.}: V - elif V is SomePBInt: - type - TableObject {.proto3.} = object - key {.fieldNumber: 1.}: K - value {.fieldNumber: 2, pint.}: V - else: - type - TableObject {.proto3.} = object - key {.fieldNumber: 1.}: K - value {.fieldNumber: 2.}: V + tableObject(TableObject, K, V) var tmp = default(TableObject) stream.readFieldInto(tmp, header, ProtoType) value[tmp.key] = tmp.value @@ -146,6 +125,7 @@ proc readFieldPackedInto[T]( elif ProtoType is SomeFixed32: WireKind.Fixed32 else: + static: doAssert ProtoType is SomeFixed64 WireKind.Fixed64 inner.readFieldInto(value[^1], FieldHeader.init(header.number, kind), ProtoType) @@ -184,8 +164,8 @@ proc readValueInternal[T: object](stream: InputStream, value: var T, silent: boo stream.readFieldPackedInto(fieldVar, header, ProtoType) else: stream.readFieldInto(fieldVar, header, ProtoType) - elif ProtoType is ref and defined(ConformanceTest): - fieldVar = new ProtoType + elif typeof(fieldVar) is ref and defined(ConformanceTest): + fieldVar = new typeof(fieldVar) stream.readFieldInto(fieldVar[], header, ProtoType) else: stream.readFieldInto(fieldVar, header, ProtoType) diff --git a/protobuf_serialization/sizer.nim b/protobuf_serialization/sizer.nim new file mode 100644 index 0000000..0a378a8 --- /dev/null +++ b/protobuf_serialization/sizer.nim @@ -0,0 +1,127 @@ +import + std/[typetraits, tables], + stew/shims/macros, + serialization, + "."/[codec, internal, types] + +func computeObjectSize*[T: object](value: T): int + +func computeFieldSize( + fieldNum: int, fieldVal: auto, ProtoType: type UnsupportedType, + _: static bool) = + # TODO turn this into an extension point + unsupportedProtoType ProtoType.FieldType, ProtoType.RootType, ProtoType.fieldName + +func computeFieldSize[T: object and not PBOption]( + fieldNum: int, fieldVal: T, ProtoType: type pbytes, + skipDefault: static bool): int = + let + size = computeObjectSize(fieldVal) + + when skipDefault: + if size == 0: + return 0 + + computeSize(FieldHeader.init(fieldNum, ProtoType.wireKind())) + + computeSize(puint64(size)) + + size + +proc computeFieldSize*[T: not object]( + fieldNum: int, fieldVal: T, + ProtoType: type SomeScalar, skipDefault: static bool): int = + when skipDefault: + const def = default(typeof(fieldVal)) + if fieldVal == def: + return + + computeSize(fieldNum, ProtoType(fieldVal)) + +proc computeFieldSize*( + fieldNum: int, fieldVal: PBOption, ProtoType: type, + skipDefault: static bool): int = + if fieldVal.isSome(): # TODO required field checking + computeFieldSize(fieldNum, fieldVal.get(), ProtoType, skipDefault) + else: + 0 + +when defined(ConformanceTest): + proc computeFieldSize*[T]( + fieldNum: int, fieldVal: ref T, + ProtoType: type pbytes, skipDefault: static bool): int = + if not fieldVal.isNil(): + computeFieldSize(fieldNum, fieldVal[], ProtoType, skipDefault) + else: + 0 + + proc writeField[T: enum]( + stream: OutputStream, fieldNum: int, fieldVal: T, ProtoType: type) = + when 0 notin T: + {.fatal: $T & " definition must contain a constant that maps to zero".} + stream.writeField(fieldNum, pint32(fieldVal.ord())) + + proc computeFieldSize*[K, V]( + fieldNum: int, fieldVal: Table[K, V], ProtoType: type pbytes, + skipDefault: static bool): int = + tableObject(TableObject, K, V) + for k, v in fieldVal.pairs(): + let tmp = TableObject(key: k, value: v) + result += computeFieldSize(fieldNum, tmp, ProtoType, false) + +proc computeSizePacked*[T: not byte, ProtoType: SomePrimitive]( + values: openArray[T], _: type ProtoType): int = + const canCopyMem = + ProtoType is SomeFixed32 or ProtoType is SomeFixed64 or ProtoType is pbool + when canCopyMem: + values.len() * sizeof(T) + else: + var total = 0 + for item in values: + total += computeSize(ProtoType(item)) + total + +proc computeFieldSizePacked*[ProtoType: SomePrimitive]( + field: int, values: openArray, _: type ProtoType): int = + # Packed encoding uses a length-delimited field byte length of the sum of the + # byte lengths of each field followed by the header-free contents + let + dataSize = computeSizePacked(values, ProtoType) + + computeSize(FieldHeader.init(field, WireKind.LengthDelim)) + + computeSize(puint64(dataSize)) + + dataSize + +func computeObjectSize*[T: object](value: T): int = + const + isProto2: bool = T.isProto2() + isProto3: bool = T.isProto3() + static: + doAssert isProto2 xor isProto3 + + var total = 0 + enumInstanceSerializedFields(value, fieldName, fieldVal): + const + fieldNum = T.fieldNumberOf(fieldName) + + type + FlatType = flatType(fieldVal) + + protoType(ProtoType, T, FlatType, fieldName) + + let fieldSize = when FlatType is seq and FlatType isnot seq[byte]: + const + isPacked = T.isPacked(fieldName).get(isProto3) + when isPacked and ProtoType is SomePrimitive: + computeFieldSizePacked(fieldNum, fieldVal, ProtoType) + else: + var dataSize = 0 + for i in 0.. 0: output.write( cast[ptr UncheckedArray[byte]]( - unsafeAddr values[0]).toOpenArray(0, dlength - 1)) + unsafeAddr values[0]).toOpenArray(0, dataSize - 1)) else: for value in values: output.write(toBytes(ProtoType(value))) when defined(ConformanceTest): proc writeField[T: enum]( - stream: OutputStream, fieldNum: int, fieldVal: T, ProtoType: type) = + stream: OutputStream, + fieldNum: int, + fieldVal: T, + ProtoType: type, + skipDefault: static bool = false + ) = when 0 notin T: {.fatal: $T & " definition must contain a constant that maps to zero".} stream.writeField(fieldNum, pint32(fieldVal.ord())) - proc writeField*[K, V]( + proc writeField[K, V]( stream: OutputStream, fieldNum: int, value: Table[K, V], - ProtoType: type + ProtoType: type, + skipDefault: static bool = false ) = - when K is SomePBInt and V is SomePBInt: - type - TableObject {.proto3.} = object - key {.fieldNumber: 1, pint.}: K - value {.fieldNumber: 2, pint.}: V - elif K is SomePBInt: - type - TableObject {.proto3.} = object - key {.fieldNumber: 1, pint.}: K - value {.fieldNumber: 2.}: V - elif V is SomePBInt: - type - TableObject {.proto3.} = object - key {.fieldNumber: 1.}: K - value {.fieldNumber: 2, pint.}: V - else: - type - TableObject {.proto3.} = object - key {.fieldNumber: 1.}: K - value {.fieldNumber: 2.}: V + tableObject(TableObject, K, V) for k, v in value.pairs(): let tmp = TableObject(key: k, value: v) stream.writeField(fieldNum, tmp, ProtoType) -proc writeValue*[T: object](stream: OutputStream, value: T) = +proc writeObject[T: object](stream: OutputStream, value: T) = const isProto2: bool = T.isProto2() isProto3: bool = T.isProto3() @@ -127,54 +116,21 @@ proc writeValue*[T: object](stream: OutputStream, value: T) = stream.writeFieldPacked(fieldNum, fieldVal, ProtoType) else: for i in 0..