From 211f8275db5ddb9cab54ecd329e53bc3551ae45e Mon Sep 17 00:00:00 2001 From: kaichaosun Date: Wed, 16 Jul 2025 14:47:47 +0800 Subject: [PATCH] chore: refactor data types --- chat_sdk/db.nim | 2 +- chat_sdk/segmentation.nim | 103 +++++++++++++++++--------------------- 2 files changed, 46 insertions(+), 59 deletions(-) diff --git a/chat_sdk/db.nim b/chat_sdk/db.nim index cde7bf0..23bda65 100644 --- a/chat_sdk/db.nim +++ b/chat_sdk/db.nim @@ -118,4 +118,4 @@ proc completeMessageSegments*( try: self.db.exec(sql"ROLLBACK") except: discard - return err("complete segment messages with error: " & e.msg) \ No newline at end of file + return err("complete segment messages with error: " & e.msg) diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index 1f027ef..eb21c6c 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -4,19 +4,12 @@ import results import leopard import chronicles -# External dependencies (still needed) -# import protobuf # Nim protobuf library (e.g., protobuf-nim) +import protobuf_serialization/proto_parser +import protobuf_serialization -# SegmentMessage type (unchanged) -type - SegmentMessage* = ref object - entireMessageHash*: seq[byte] - index*: uint32 - segmentsCount*: uint32 - paritySegmentIndex*: uint32 - paritySegmentsCount*: uint32 - payload*: seq[byte] +import db_models +import_proto3 "segment_message.proto" # Placeholder types (unchanged) type @@ -24,10 +17,7 @@ type payload*: seq[byte] # Add other fields as needed - StatusMessage* = object - transportLayer*: TransportLayer - - TransportLayer* = object + Message* = object hash*: seq[byte] payload*: seq[byte] sigPubKey*: seq[byte] @@ -50,9 +40,21 @@ const SegmentsReedsolomonMaxCount = 256 -# Validation methods (unchanged) -proc isValid*(s: SegmentMessage): bool = - s.segmentsCount >= 2 or s.paritySegmentsCount > 0 +proc isValid(s: SegmentMessageProto): bool = + # Check if the hash length is valid (32 bytes for Keccak256) + if s.entire_message_hash.len != 32: + return false + + # Check if the segment index is within the valid range + if s.segments_count > 0 and s.index >= s.segments_count: + return false + + # Check if the parity segment index is within the valid range + if s.parity_segments_count > 0 and s.parity_segment_index >= s.parity_segments_count: + return false + + # Check if segments_count is at least 2 or parity_segments_count is positive + return s.segments_count >= 2 or s.parity_segments_count > 0 proc isParityMessage*(s: SegmentMessage): bool = s.segmentsCount == 0 and s.paritySegmentsCount > 0 @@ -100,16 +102,6 @@ proc protoMarshal(msg: SegmentMessage): Result[seq[byte], string] = # Fake serialization (index + payload length) TODO return ok(@[byte(msg.index)] & msg.payload) -proc protoUnmarshal(data: seq[byte], msg: var SegmentMessage): Result[void, string] = - # Fake deserialization (reconstruct index and payload) - if data.len < 1: - return err("data too short") - msg.index = uint32(data[0]) - msg.payload = data[1..^1] - msg.segmentsCount = 2 - msg.entireMessageHash = @[byte 1, 2, 3] # Dummy hash - return ok() - # Segment message into smaller chunks (updated with nim-leopard) proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] = if newMessage.payload.len <= segmentSize: @@ -196,38 +188,43 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul return ok(segmentMessages) -proc handleSegmentationLayer*(s: MessageSender, message: var StatusMessage): Result[void, string] = +proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[void, string] = logScope: - site = "handleSegmentationLayerV2" - hash = message.transportLayer.hash.toHex + site = "handleSegmentationLayer" + hash = message.hash.toHex - var segmentMessage = SegmentMessage() - let unmarshalResult = protoUnmarshal(message.transportLayer.payload, segmentMessage) - if unmarshalResult.isErr: - return err("failed to unmarshal SegmentMessage: " & unmarshalResult.error) + let segmentMessageProto = Protobuf.decode(message.payload, SegmentMessageProto) debug "handling message segment", - "EntireMessageHash" = segmentMessage.entireMessageHash.toHex, - "Index" = $segmentMessage.index, - "SegmentsCount" = $segmentMessage.segmentsCount, - "ParitySegmentIndex" = $segmentMessage.paritySegmentIndex, - "ParitySegmentsCount" = $segmentMessage.paritySegmentsCount + "EntireMessageHash" = segmentMessageProto.entireMessageHash.toHex, + "Index" = $segmentMessageProto.index, + "SegmentsCount" = $segmentMessageProto.segmentsCount, + "ParitySegmentIndex" = $segmentMessageProto.paritySegmentIndex, + "ParitySegmentsCount" = $segmentMessageProto.paritySegmentsCount # TODO here use the mock function, real function should be used together with the persistence layer - let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessage.entireMessageHash) + let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash) if alreadyCompleted.isErr: return err(alreadyCompleted.error) if alreadyCompleted.get(): return err(ErrMessageSegmentsAlreadyCompleted) - if not segmentMessage.isValid(): + if not segmentMessageProto.isValid(): return err(ErrMessageSegmentsInvalidCount) - let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.transportLayer.sigPubKey, getTime().toUnix) + let segmentMessage = SegmentMessage( + entireMessageHash: segmentMessageProto.entireMessageHash, + index: segmentMessageProto.index, + segmentsCount: segmentMessageProto.segmentsCount, + paritySegmentIndex: segmentMessageProto.paritySegmentIndex, + paritySegmentsCount: segmentMessageProto.paritySegmentsCount, + payload: segmentMessageProto.payload + ) + let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix) if saveResult.isErr: return err(saveResult.error) - let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.transportLayer.sigPubKey) + let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey) if segments.isErr: return err(segments.error) @@ -295,11 +292,11 @@ proc handleSegmentationLayer*(s: MessageSender, message: var StatusMessage): Res if entirePayloadHash.data != segmentMessage.entireMessageHash: return err(ErrMessageSegmentsHashMismatch) - let completeResult = s.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.transportLayer.sigPubKey, getTime().toUnix) + let completeResult = s.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix) if completeResult.isErr: return err(completeResult.error) - message.transportLayer.payload = entirePayload + message.payload = entirePayload return ok() @@ -317,17 +314,7 @@ proc segmentMessage*(s: MessageSender, newMessage: WakuNewMessage): Result[seq[W debug "message segmented", "segments" = $messages.len return ok(messages) -proc demoRounding() = - let x = 3.7 - let y = -3.7 - - echo "ceil(", x, ") = ", ceil(x) # Rounds up - echo "floor(", x, ") = ", floor(x) # Rounds down - echo "round(", x, ") = ", round(x) # Rounds to nearest integer - echo "trunc(", x, ") = ", trunc(x) # Truncates decimal part - echo "ceil(", y, ") = ", ceil(y) - echo "floor(", y, ") = ", floor(y) - +proc demo() = let bufSize = 64 # byte count per buffer, must be a multiple of 64 buffers = 239 # number of data symbols @@ -342,7 +329,7 @@ proc demoRounding() = when isMainModule: - demoRounding() + demo() proc add2*(x, y: int): int = ## Adds two numbers together.