nim-sds/src/protobuf.nim

123 lines
3.5 KiB
Nim
Raw Permalink Normal View History

2024-11-29 14:07:24 +04:00
import libp2p/protobuf/minprotobuf
import std/options
2025-01-13 17:32:33 +04:00
import ../src/[message, protobufutil, bloom, reliability_utils]
2024-12-09 17:32:50 +04:00
proc toBytes(s: string): seq[byte] =
  ## Returns a copy of the raw bytes of `s` as a `seq[byte]`.
  ##
  ## An empty string yields an empty sequence.
  result = newSeq[byte](s.len)
  # Indexing element 0 of an empty string/seq fails the bounds check, so the
  # buffers are only touched when there is at least one byte to copy.
  if s.len > 0:
    copyMem(addr result[0], unsafeAddr s[0], s.len)
2024-11-29 14:07:24 +04:00
proc encode*(msg: Message): ProtoBuffer =
  ## Serializes `msg` into a finished protobuf buffer.
  ##
  ## Field numbers: 1 = messageId, 2 = lamportTimestamp (written as uint64),
  ## 3 = causalHistory (one record per entry), 4 = channelId, 5 = content,
  ## 6 = bloomFilter.
  result = initProtoBuffer()
  result.write(1, msg.messageId)
  result.write(2, uint64(msg.lamportTimestamp))

  # Every history entry is written as its own field-3 record, converted to
  # bytes first.
  for entry in msg.causalHistory:
    result.write(3, toBytes(entry))

  result.write(4, msg.channelId)
  result.write(5, msg.content)
  result.write(6, msg.bloomFilter)
  result.finish()
proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
  ## Parses `buffer` as a protobuf-encoded `Message`.
  ##
  ## Fields 1 (messageId), 2 (lamportTimestamp), 4 (channelId) and
  ## 5 (content) are required; when one is absent a `missingRequiredField`
  ## error naming it is returned. Field 3 (causalHistory) and field 6
  ## (bloomFilter) are optional.
  let reader = initProtoBuffer(buffer)
  var decoded = Message()

  let hasId = ?reader.getField(1, decoded.messageId)
  if not hasId:
    return err(ProtobufError.missingRequiredField("messageId"))

  # The timestamp is read as uint64 and converted to the signed field type.
  var rawTimestamp: uint64
  let hasTimestamp = ?reader.getField(2, rawTimestamp)
  if not hasTimestamp:
    return err(ProtobufError.missingRequiredField("lamportTimestamp"))
  decoded.lamportTimestamp = int64(rawTimestamp)

  # Causal history is applied only when reading the repeated field succeeds;
  # a failed read is ignored and the field keeps its default value.
  var history: seq[string]
  if reader.getRepeatedField(3, history).isOk:
    decoded.causalHistory = history

  let hasChannel = ?reader.getField(4, decoded.channelId)
  if not hasChannel:
    return err(ProtobufError.missingRequiredField("channelId"))

  let hasContent = ?reader.getField(5, decoded.content)
  if not hasContent:
    return err(ProtobufError.missingRequiredField("content"))

  # An absent bloom filter is represented by an empty sequence.
  let hasFilter = ?reader.getField(6, decoded.bloomFilter)
  if not hasFilter:
    decoded.bloomFilter = @[]

  ok(decoded)
proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] =
  ## Encodes `msg` and returns the bytes of the resulting protobuf buffer.
  ##
  ## Anything raised while encoding is reported as `reSerializationError`.
  try:
    ok(msg.encode().buffer)
  except:
    err(reSerializationError)
proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] =
  ## Decodes `data` into a `Message`.
  ##
  ## Returns `reDeserializationError` both when `Message.decode` reports an
  ## error and when anything is raised while decoding.
  try:
    let decoded = Message.decode(data)
    if decoded.isOk:
      ok(decoded.get)
    else:
      # A failed decode is a *de*serialization failure; report it with the
      # same error code as the `except` branch below.
      err(reDeserializationError)
  except:
    err(reDeserializationError)
proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] =
  ## Serializes `filter` into protobuf bytes.
  ##
  ## Field numbers: 1 = `intArray` as raw bytes, 2 = capacity,
  ## 3 = errorRate scaled by 1_000_000 and stored as an integer,
  ## 4 = kHashes, 5 = mBits.
  ##
  ## Anything raised while serializing is reported as `reSerializationError`.
  try:
    var pb = initProtoBuffer()

    # Each element of `intArray` is copied as `sizeof(int)` raw bytes, i.e.
    # in the in-memory representation of the running platform.
    var bytes = newSeq[byte](filter.intArray.len * sizeof(int))
    for i, val in filter.intArray:
      let start = i * sizeof(int)
      copyMem(addr bytes[start], unsafeAddr val, sizeof(int))

    pb.write(1, bytes)
    pb.write(2, uint64(filter.capacity))
    # Converting a float to an integer truncates, so a product that lands
    # just below a whole number (floating-point rounding) would lose one
    # unit. Adding 0.5 first rounds to the nearest integer instead.
    pb.write(3, uint64(filter.errorRate * 1_000_000 + 0.5))
    pb.write(4, uint64(filter.kHashes))
    pb.write(5, uint64(filter.mBits))

    pb.finish()
    ok(pb.buffer)
  except:
    err(reSerializationError)
proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
  ## Rebuilds a `BloomFilter` from protobuf bytes.
  ##
  ## Expects fields 1 (raw `intArray` bytes), 2 (capacity), 3 (errorRate
  ## scaled by 1_000_000), 4 (kHashes) and 5 (mBits). Returns
  ## `reDeserializationError` when `data` is empty, when any of these fields
  ## is missing or unreadable, when the byte payload is not a whole number of
  ## `int`s, or when anything is raised while decoding.
  if data.len == 0:
    return err(reDeserializationError)

  try:
    let pb = initProtoBuffer(data)
    var bytes: seq[byte]
    var cap, errRate, kHashes, mBits: uint64

    # `get(false)` maps a read error to "not present" so malformed input is
    # rejected here, rather than depending on `get()` raising on an error
    # result and the handler below catching it.
    if not pb.getField(1, bytes).get(false) or
        not pb.getField(2, cap).get(false) or
        not pb.getField(3, errRate).get(false) or
        not pb.getField(4, kHashes).get(false) or
        not pb.getField(5, mBits).get(false):
      return err(reDeserializationError)

    # A payload that does not divide evenly into `int`s would otherwise have
    # its trailing bytes silently dropped; treat it as corrupt.
    if bytes.len mod sizeof(int) != 0:
      return err(reDeserializationError)

    # Copy the raw bytes back into `int` elements, `sizeof(int)` at a time.
    var intArray = newSeq[int](bytes.len div sizeof(int))
    for i in 0 ..< intArray.len:
      let start = i * sizeof(int)
      copyMem(addr intArray[i], unsafeAddr bytes[start], sizeof(int))

    ok(
      BloomFilter(
        intArray: intArray,
        capacity: int(cap),
        errorRate: float(errRate) / 1_000_000,
        kHashes: int(kHashes),
        mBits: int(mBits),
      )
    )
  except:
    err(reDeserializationError)