From 420eb331d91adf1de938bf2c5fcd7f172de05a89 Mon Sep 17 00:00:00 2001 From: kaichaosun Date: Mon, 12 May 2025 10:16:25 +0800 Subject: [PATCH 1/9] feat: segentation logic --- chat_sdk/segmentation.nim | 304 +++++++++++++++++++++++++ tests/{test1.nim => test_chat_sdk.nim} | 3 + 2 files changed, 307 insertions(+) create mode 100644 chat_sdk/segmentation.nim rename tests/{test1.nim => test_chat_sdk.nim} (87%) diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim new file mode 100644 index 0000000..222ab6f --- /dev/null +++ b/chat_sdk/segmentation.nim @@ -0,0 +1,304 @@ +import math, times, sequtils, strutils, options +import nimcrypto # For Keccak256 hashing +import logging # Placeholder for logging +import results +import leopard # Import nim-leopard for Reed-Solomon + +# External dependencies (still needed) +# import protobuf # Nim protobuf library (e.g., protobuf-nim) + +# Placeholder types (unchanged) +type + WakuNewMessage = object + payload: seq[byte] + # Add other fields as needed + + StatusMessage = object + transportLayer: TransportLayer + + TransportLayer = object + hash: seq[byte] + payload: seq[byte] + sigPubKey: seq[byte] + + Persistence = object + # Placeholder for persistence interface + +# Error definitions (unchanged) +const + ErrMessageSegmentsIncomplete = "message segments incomplete" + ErrMessageSegmentsAlreadyCompleted = "message segments already completed" + ErrMessageSegmentsInvalidCount = "invalid segments count" + ErrMessageSegmentsHashMismatch = "hash of entire payload does not match" + ErrMessageSegmentsInvalidParity = "invalid parity segments" + +# Constants (unchanged) +const + SegmentsParityRate = 0.125 + SegmentsReedsolomonMaxCount = 256 + +# SegmentMessage type (unchanged) +type + SegmentMessage* = ref object + entireMessageHash*: seq[byte] + index*: uint32 + segmentsCount*: uint32 + paritySegmentIndex*: uint32 + paritySegmentsCount*: uint32 + payload*: seq[byte] + +# Validation methods (unchanged) +proc isValid*(s: SegmentMessage): bool = + s.segmentsCount >= 2 or s.paritySegmentsCount > 0 + +proc isParityMessage*(s: SegmentMessage): bool = + s.segmentsCount == 0 and s.paritySegmentsCount > 0 + +# MessageSender type (unchanged) +type + MessageSender* = ref object + messaging: Messaging + persistence: Persistence + logger: Logger + + Messaging = object + maxMessageSize: int + +# SegmentMessage proc (unchanged) +proc segmentMessage*(s: MessageSender, newMessage: WakuNewMessage): Result[seq[WakuNewMessage], string] = + let segmentSize = s.messaging.maxMessageSize div 4 * 3 + let (messages, err) = segmentMessage(newMessage, segmentSize) + if err.isSome: + return err("segmentMessage failed: " & err.get()) + s.logger.debug("message segmented", "segments", $messages.len) + return ok(messages) + +# # Replicate message (unchanged) +# proc replicateMessageWithNewPayload(message: WakuNewMessage, payload: seq[byte]): Result[WakuNewMessage, string] = +# var copy = WakuNewMessage(payload: payload) +# return ok(copy) + +# # Segment message into smaller chunks (updated with nim-leopard) +# proc segmentMessage(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] = +# if newMessage.payload.len <= segmentSize: +# return ok(@[newMessage]) + +# let entireMessageHash = keccak256.digest(newMessage.payload) +# let entirePayloadSize = newMessage.payload.len + +# let segmentsCount = int(ceil(entirePayloadSize.float / segmentSize.float)) +# let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate)) + +# var segmentPayloads = newSeq[seq[byte]](segmentsCount + paritySegmentsCount) +# var segmentMessages = newSeq[WakuNewMessage](segmentsCount) + +# for i in 0.. entirePayloadSize: +# end = entirePayloadSize + +# let segmentPayload = newMessage.payload[start.. SegmentsReedsolomonMaxCount: +# return ok(segmentMessages) + +# # Align last segment payload for Reed-Solomon (leopard requires fixed-size shards) +# let lastSegmentPayload = segmentPayloads[segmentsCount-1] +# segmentPayloads[segmentsCount-1] = newSeq[byte](segmentSize) +# copy(lastSegmentPayload, segmentPayloads[segmentsCount-1]) + +# # Allocate space for parity shards +# for i in segmentsCount..<(segmentsCount + paritySegmentsCount): +# segmentPayloads[i] = newSeq[byte](segmentSize) + +# # Use nim-leopard for Reed-Solomon encoding +# let encodeResult = leopard.encode(segmentPayloads, segmentsCount, paritySegmentsCount) +# if encodeResult.isErr: +# return err("failed to encode segments with leopard: " & encodeResult.error) + +# # Create parity messages +# for i in segmentsCount..<(segmentsCount + paritySegmentsCount): +# let parityIndex = i - segmentsCount +# let segmentWithMetadata = SegmentMessage( +# entireMessageHash: entireMessageHash.data, +# segmentsCount: 0, +# paritySegmentIndex: uint32(parityIndex), +# paritySegmentsCount: uint32(paritySegmentsCount), +# payload: segmentPayloads[i] +# ) + +# let marshaledSegment = protoMarshal(segmentWithMetadata) +# if marshaledSegment.isErr: +# return err("failed to marshal parity SegmentMessage: " & marshaledSegment.error) + +# let segmentMessage = replicateMessageWithNewPayload(newMessage, marshaledSegment.get()) +# if segmentMessage.isErr: +# return err("failed to replicate parity message: " & segmentMessage.error) + +# segmentMessages.add(segmentMessage.get()) + +# return ok(segmentMessages) + +# # Handle SegmentationLayerV2 (updated with nim-leopard) +# proc handleSegmentationLayerV2*(s: MessageSender, message: StatusMessage): Result[void, string] = +# let logger = s.logger.withFields( +# "site", "handleSegmentationLayerV2", +# "hash", message.transportLayer.hash.toHex +# ) + +# var segmentMessage = SegmentMessage() +# let unmarshalResult = protoUnmarshal(message.transportLayer.payload, segmentMessage) +# if unmarshalResult.isErr: +# return err("failed to unmarshal SegmentMessage: " & unmarshalResult.error) + +# logger.debug("handling message segment", +# "EntireMessageHash", segmentMessage.entireMessageHash.toHex, +# "Index", $segmentMessage.index, +# "SegmentsCount", $segmentMessage.segmentsCount, +# "ParitySegmentIndex", $segmentMessage.paritySegmentIndex, +# "ParitySegmentsCount", $segmentMessage.paritySegmentsCount +# ) + +# let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessage.entireMessageHash) +# if alreadyCompleted.isErr: +# return err(alreadyCompleted.error) +# if alreadyCompleted.get(): +# return err(ErrMessageSegmentsAlreadyCompleted) + +# if not segmentMessage.isValid(): +# return err(ErrMessageSegmentsInvalidCount) + +# let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.transportLayer.sigPubKey, getTime().toUnix) +# if saveResult.isErr: +# return err(saveResult.error) + +# let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.transportLayer.sigPubKey) +# if segments.isErr: +# return err(segments.error) + +# if segments.get().len == 0: +# return err("unexpected state: no segments found after save operation") + +# let firstSegmentMessage = segments.get()[0] +# let lastSegmentMessage = segments.get()[^1] + +# if firstSegmentMessage.isParityMessage() or segments.get().len != int(firstSegmentMessage.segmentsCount): +# return err(ErrMessageSegmentsIncomplete) + +# var payloads = newSeq[seq[byte]](firstSegmentMessage.segmentsCount + lastSegmentMessage.paritySegmentsCount) +# let payloadSize = firstSegmentMessage.payload.len + +# let restoreUsingParityData = lastSegmentMessage.isParityMessage() +# if not restoreUsingParityData: +# for i, segment in segments.get(): +# payloads[i] = segment.payload +# else: +# var lastNonParitySegmentPayload: seq[byte] +# for segment in segments.get(): +# if not segment.isParityMessage(): +# if segment.index == firstSegmentMessage.segmentsCount - 1: +# payloads[segment.index] = newSeq[byte](payloadSize) +# copy(segment.payload, payloads[segment.index]) +# lastNonParitySegmentPayload = segment.payload +# else: +# payloads[segment.index] = segment.payload +# else: +# payloads[firstSegmentMessage.segmentsCount + segment.paritySegmentIndex] = segment.payload + +# # Use nim-leopard for Reed-Solomon reconstruction +# let reconstructResult = leopard.decode(payloads, int(firstSegmentMessage.segmentsCount), int(lastSegmentMessage.paritySegmentsCount)) +# if reconstructResult.isErr: +# return err("failed to reconstruct payloads with leopard: " & reconstructResult.error) + +# # Verify by checking hash (leopard doesn't have a direct verify function) +# var tempPayload = newSeq[byte]() +# for i in 0.. 0: +# payloads[firstSegmentMessage.segmentsCount - 1] = lastNonParitySegmentPayload + +# # Combine payload +# var entirePayload = newSeq[byte]() +# for i in 0.. Date: Wed, 14 May 2025 16:28:27 +0800 Subject: [PATCH 2/9] chore: compile the segmentation logic. --- chat_sdk/segmentation.nim | 472 +++++++++++++++++++++----------------- 1 file changed, 257 insertions(+), 215 deletions(-) diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index 222ab6f..ff38105 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -1,12 +1,23 @@ import math, times, sequtils, strutils, options import nimcrypto # For Keccak256 hashing -import logging # Placeholder for logging import results import leopard # Import nim-leopard for Reed-Solomon +import chronicles # External dependencies (still needed) # import protobuf # Nim protobuf library (e.g., protobuf-nim) +# SegmentMessage type (unchanged) +type + SegmentMessage* = ref object + entireMessageHash*: seq[byte] + index*: uint32 + segmentsCount*: uint32 + paritySegmentIndex*: uint32 + paritySegmentsCount*: uint32 + payload*: seq[byte] + + # Placeholder types (unchanged) type WakuNewMessage = object @@ -22,7 +33,8 @@ type sigPubKey: seq[byte] Persistence = object - # Placeholder for persistence interface + completedMessages: Table[seq[byte], bool] # Mock storage for completed message hashes + segments: Table[(seq[byte], seq[byte]), seq[SegmentMessage]] # Stores segments by (hash, sigPubKey) # Error definitions (unchanged) const @@ -37,15 +49,6 @@ const SegmentsParityRate = 0.125 SegmentsReedsolomonMaxCount = 256 -# SegmentMessage type (unchanged) -type - SegmentMessage* = ref object - entireMessageHash*: seq[byte] - index*: uint32 - segmentsCount*: uint32 - paritySegmentIndex*: uint32 - paritySegmentsCount*: uint32 - payload*: seq[byte] # Validation methods (unchanged) proc isValid*(s: SegmentMessage): bool = @@ -59,219 +62,258 @@ type MessageSender* = ref object messaging: Messaging persistence: Persistence - logger: Logger Messaging = object maxMessageSize: int +proc initPersistence(): Persistence = + Persistence( + completedMessages: initTable[seq[byte], bool](), + segments: initTable[(seq[byte], seq[byte]), seq[SegmentMessage]]() + ) + +proc isMessageAlreadyCompleted(p: Persistence, hash: seq[byte]): Result[bool, string] = + return ok(p.completedMessages.getOrDefault(hash, false)) + +proc saveMessageSegment(p: var Persistence, segment: SegmentMessage, sigPubKey: seq[byte], timestamp: int64): Result[void, string] = + let key = (segment.entireMessageHash, sigPubKey) + # Initialize or append to the segments list + if not p.segments.hasKey(key): + p.segments[key] = @[] + p.segments[key].add(segment) + return ok() + +proc getMessageSegments(p: Persistence, hash: seq[byte], sigPubKey: seq[byte]): Result[seq[SegmentMessage], string] = + let key = (hash, sigPubKey) + return ok(p.segments.getOrDefault(key, @[])) + +proc completeMessageSegments(p: var Persistence, hash: seq[byte], sigPubKey: seq[byte], timestamp: int64): Result[void, string] = + p.completedMessages[hash] = true + return ok() + +# Replicate message (unchanged) +proc replicateMessageWithNewPayload(message: WakuNewMessage, payload: seq[byte]): Result[WakuNewMessage, string] = + var copiedMessage = WakuNewMessage(payload: payload) + return ok(copiedMessage) + +proc protoMarshal(msg: SegmentMessage): Result[seq[byte], string] = + return err("protoMarshal not implemented") + +proc protoUnmarshal(data: seq[byte], msg: var SegmentMessage): Result[void, string] = + return err("protoUnmarshal not implemented") + +# Segment message into smaller chunks (updated with nim-leopard) +proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] = + if newMessage.payload.len <= segmentSize: + return ok(@[newMessage]) + + let entireMessageHash = keccak256.digest(newMessage.payload) + let entirePayloadSize = newMessage.payload.len + + let segmentsCount = int(ceil(entirePayloadSize.float / segmentSize.float)) + let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate)) + + var segmentPayloads = newSeq[seq[byte]](segmentsCount + paritySegmentsCount) + var segmentMessages = newSeq[WakuNewMessage](segmentsCount) + + for i in 0.. entirePayloadSize: + endIndex = entirePayloadSize + + let segmentPayload = newMessage.payload[start.. SegmentsReedsolomonMaxCount: + return ok(segmentMessages) + + # Align last segment payload for Reed-Solomon (leopard requires fixed-size shards) + let lastSegmentPayload = segmentPayloads[segmentsCount-1] + segmentPayloads[segmentsCount-1] = newSeq[byte](segmentSize) + segmentPayloads[segmentsCount-1][0.. 0: + payloads[firstSegmentMessage.segmentsCount - 1] = lastNonParitySegmentPayload + + # Combine payload + var entirePayload = newSeq[byte]() + for i in 0.. entirePayloadSize: -# end = entirePayloadSize - -# let segmentPayload = newMessage.payload[start.. SegmentsReedsolomonMaxCount: -# return ok(segmentMessages) - -# # Align last segment payload for Reed-Solomon (leopard requires fixed-size shards) -# let lastSegmentPayload = segmentPayloads[segmentsCount-1] -# segmentPayloads[segmentsCount-1] = newSeq[byte](segmentSize) -# copy(lastSegmentPayload, segmentPayloads[segmentsCount-1]) - -# # Allocate space for parity shards -# for i in segmentsCount..<(segmentsCount + paritySegmentsCount): -# segmentPayloads[i] = newSeq[byte](segmentSize) - -# # Use nim-leopard for Reed-Solomon encoding -# let encodeResult = leopard.encode(segmentPayloads, segmentsCount, paritySegmentsCount) -# if encodeResult.isErr: -# return err("failed to encode segments with leopard: " & encodeResult.error) - -# # Create parity messages -# for i in segmentsCount..<(segmentsCount + paritySegmentsCount): -# let parityIndex = i - segmentsCount -# let segmentWithMetadata = SegmentMessage( -# entireMessageHash: entireMessageHash.data, -# segmentsCount: 0, -# paritySegmentIndex: uint32(parityIndex), -# paritySegmentsCount: uint32(paritySegmentsCount), -# payload: segmentPayloads[i] -# ) - -# let marshaledSegment = protoMarshal(segmentWithMetadata) -# if marshaledSegment.isErr: -# return err("failed to marshal parity SegmentMessage: " & marshaledSegment.error) - -# let segmentMessage = replicateMessageWithNewPayload(newMessage, marshaledSegment.get()) -# if segmentMessage.isErr: -# return err("failed to replicate parity message: " & segmentMessage.error) - -# segmentMessages.add(segmentMessage.get()) - -# return ok(segmentMessages) - -# # Handle SegmentationLayerV2 (updated with nim-leopard) -# proc handleSegmentationLayerV2*(s: MessageSender, message: StatusMessage): Result[void, string] = -# let logger = s.logger.withFields( -# "site", "handleSegmentationLayerV2", -# "hash", message.transportLayer.hash.toHex -# ) - -# var segmentMessage = SegmentMessage() -# let unmarshalResult = protoUnmarshal(message.transportLayer.payload, segmentMessage) -# if unmarshalResult.isErr: -# return err("failed to unmarshal SegmentMessage: " & unmarshalResult.error) - -# logger.debug("handling message segment", -# "EntireMessageHash", segmentMessage.entireMessageHash.toHex, -# "Index", $segmentMessage.index, -# "SegmentsCount", $segmentMessage.segmentsCount, -# "ParitySegmentIndex", $segmentMessage.paritySegmentIndex, -# "ParitySegmentsCount", $segmentMessage.paritySegmentsCount -# ) - -# let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessage.entireMessageHash) -# if alreadyCompleted.isErr: -# return err(alreadyCompleted.error) -# if alreadyCompleted.get(): -# return err(ErrMessageSegmentsAlreadyCompleted) - -# if not segmentMessage.isValid(): -# return err(ErrMessageSegmentsInvalidCount) - -# let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.transportLayer.sigPubKey, getTime().toUnix) -# if saveResult.isErr: -# return err(saveResult.error) - -# let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.transportLayer.sigPubKey) -# if segments.isErr: -# return err(segments.error) - -# if segments.get().len == 0: -# return err("unexpected state: no segments found after save operation") - -# let firstSegmentMessage = segments.get()[0] -# let lastSegmentMessage = segments.get()[^1] - -# if firstSegmentMessage.isParityMessage() or segments.get().len != int(firstSegmentMessage.segmentsCount): -# return err(ErrMessageSegmentsIncomplete) - -# var payloads = newSeq[seq[byte]](firstSegmentMessage.segmentsCount + lastSegmentMessage.paritySegmentsCount) -# let payloadSize = firstSegmentMessage.payload.len - -# let restoreUsingParityData = lastSegmentMessage.isParityMessage() -# if not restoreUsingParityData: -# for i, segment in segments.get(): -# payloads[i] = segment.payload -# else: -# var lastNonParitySegmentPayload: seq[byte] -# for segment in segments.get(): -# if not segment.isParityMessage(): -# if segment.index == firstSegmentMessage.segmentsCount - 1: -# payloads[segment.index] = newSeq[byte](payloadSize) -# copy(segment.payload, payloads[segment.index]) -# lastNonParitySegmentPayload = segment.payload -# else: -# payloads[segment.index] = segment.payload -# else: -# payloads[firstSegmentMessage.segmentsCount + segment.paritySegmentIndex] = segment.payload - -# # Use nim-leopard for Reed-Solomon reconstruction -# let reconstructResult = leopard.decode(payloads, int(firstSegmentMessage.segmentsCount), int(lastSegmentMessage.paritySegmentsCount)) -# if reconstructResult.isErr: -# return err("failed to reconstruct payloads with leopard: " & reconstructResult.error) - -# # Verify by checking hash (leopard doesn't have a direct verify function) -# var tempPayload = newSeq[byte]() -# for i in 0.. 0: -# payloads[firstSegmentMessage.segmentsCount - 1] = lastNonParitySegmentPayload - -# # Combine payload -# var entirePayload = newSeq[byte]() -# for i in 0.. Date: Tue, 1 Jul 2025 14:16:12 +0800 Subject: [PATCH 3/9] chore: add test for segmentation --- chat_sdk/segmentation.nim | 38 +++++++++++++++----------- tests/test_chat_sdk.nim | 1 - tests/test_segmentation.nim | 53 +++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 16 deletions(-) create mode 100644 tests/test_segmentation.nim diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index ff38105..3a61390 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -20,17 +20,17 @@ type # Placeholder types (unchanged) type - WakuNewMessage = object - payload: seq[byte] + WakuNewMessage* = object + payload*: seq[byte] # Add other fields as needed - StatusMessage = object - transportLayer: TransportLayer + StatusMessage* = object + transportLayer*: TransportLayer - TransportLayer = object - hash: seq[byte] - payload: seq[byte] - sigPubKey: seq[byte] + TransportLayer* = object + hash*: seq[byte] + payload*: seq[byte] + sigPubKey*: seq[byte] Persistence = object completedMessages: Table[seq[byte], bool] # Mock storage for completed message hashes @@ -60,13 +60,13 @@ proc isParityMessage*(s: SegmentMessage): bool = # MessageSender type (unchanged) type MessageSender* = ref object - messaging: Messaging - persistence: Persistence + messaging*: Messaging + persistence*: Persistence - Messaging = object - maxMessageSize: int + Messaging* = object + maxMessageSize*: int -proc initPersistence(): Persistence = +proc initPersistence*(): Persistence = Persistence( completedMessages: initTable[seq[byte], bool](), segments: initTable[(seq[byte], seq[byte]), seq[SegmentMessage]]() @@ -97,10 +97,18 @@ proc replicateMessageWithNewPayload(message: WakuNewMessage, payload: seq[byte]) return ok(copiedMessage) proc protoMarshal(msg: SegmentMessage): Result[seq[byte], string] = - return err("protoMarshal not implemented") + # Fake serialization (index + payload length) TODO + return ok(@[byte(msg.index)] & msg.payload) proc protoUnmarshal(data: seq[byte], msg: var SegmentMessage): Result[void, string] = - return err("protoUnmarshal not implemented") + # Fake deserialization (reconstruct index and payload) + if data.len < 1: + return err("data too short") + msg.index = uint32(data[0]) + msg.payload = data[1..^1] + msg.segmentsCount = 2 + msg.entireMessageHash = @[byte 1, 2, 3] # Dummy hash + return ok() # Segment message into smaller chunks (updated with nim-leopard) proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] = diff --git a/tests/test_chat_sdk.nim b/tests/test_chat_sdk.nim index abb9df8..a588132 100644 --- a/tests/test_chat_sdk.nim +++ b/tests/test_chat_sdk.nim @@ -12,4 +12,3 @@ import chat_sdk/segmentation test "can add": check add(5, 5) == 10 - check add2(5, 5) == 11 diff --git a/tests/test_segmentation.nim b/tests/test_segmentation.nim new file mode 100644 index 0000000..fa1677f --- /dev/null +++ b/tests/test_segmentation.nim @@ -0,0 +1,53 @@ +# This is just an example to get you started. You may wish to put all of your +# tests into a single file, or separate them into multiple `test1`, `test2` +# etc. files (better names are recommended, just make sure the name starts with +# the letter 't'). +# +# To run these tests, simply execute `nimble test`. + +import unittest, sequtils, random +import results +import chat_sdk/segmentation # Replace with the actual file name + +test "can add": + check add2(5, 5) == 10 + +suite "Message Segmentation": + let testPayload = toSeq(0..999).mapIt(byte(it)) + let testMessage = WakuNewMessage(payload: testPayload) + let sender = MessageSender( + messaging: Messaging(maxMessageSize: 500), + persistence: initPersistence() + ) + + test "Segment and reassemble full segments": + let segmentResult = sender.segmentMessage(testMessage) + check segmentResult.isOk + var segments = segmentResult.get() + check segments.len > 0 + + # Shuffle segment order for out-of-order test + segments.shuffle() + + for segment in segments: + var statusMsg: StatusMessage + statusMsg.transportLayer = TransportLayer( + hash: @[byte 0, 1, 2], # Dummy hash + payload: segment.payload, + sigPubKey: @[byte 3, 4, 5] + ) + let result = sender.handleSegmentationLayerV2(statusMsg) + discard result # Ignore intermediate errors + + # One last run to trigger completion + var finalStatus: StatusMessage + finalStatus.transportLayer = TransportLayer( + hash: @[byte 0, 1, 2], + payload: segments[0].payload, + sigPubKey: @[byte 3, 4, 5] + ) + let finalResult = sender.handleSegmentationLayerV2(finalStatus) + check finalResult.isOk + + # Check payload restored + check finalStatus.transportLayer.payload == testPayload[0.. Date: Wed, 16 Jul 2025 11:56:10 +0800 Subject: [PATCH 4/9] feat: db operations for message segments --- chat_sdk.nimble | 3 + chat_sdk/db.nim | 121 ++++++++++++++++++ chat_sdk/db_models.nim | 8 ++ chat_sdk/segmentation.nim | 11 +- .../002_create_message_segments_table.sql | 19 +++ 5 files changed, 154 insertions(+), 8 deletions(-) create mode 100644 chat_sdk/db.nim create mode 100644 chat_sdk/db_models.nim create mode 100644 migrations/002_create_message_segments_table.sql diff --git a/chat_sdk.nimble b/chat_sdk.nimble index b7d1da0..199cdf5 100644 --- a/chat_sdk.nimble +++ b/chat_sdk.nimble @@ -17,3 +17,6 @@ task buildStaticLib, "Build static library for C bindings": task migrate, "Run database migrations": exec "nim c -r apps/run_migration.nim" + +task segment, "Run segmentation": + exec "nim c -r chat_sdk/segmentation.nim" diff --git a/chat_sdk/db.nim b/chat_sdk/db.nim new file mode 100644 index 0000000..cde7bf0 --- /dev/null +++ b/chat_sdk/db.nim @@ -0,0 +1,121 @@ +import std/strutils +import results +import db_connector/db_sqlite +import db_models + +type MessageSegmentsPersistence = object + db: DbConn + +proc removeMessageSegmentsOlderThan*( + self: MessageSegmentsPersistence, timestamp: int64 +): Result[void, string] = + try: + self.db.exec(sql"DELETE FROM message_segments WHERE timestamp < ?", timestamp) + ok() + except DbError as e: + err("remove message segments with error: " & e.msg) + +proc removeMessageSegmentsCompletedOlderThan*( + self: MessageSegmentsPersistence, timestamp: int64 +): Result[void, string] = + try: + self.db.exec( + sql"DELETE FROM message_segments_completed WHERE timestamp < ?", timestamp + ) + ok() + except DbError as e: + err("remove message segments completed with error: " & e.msg) + +proc isMessageAlreadyCompleted*( + self: MessageSegmentsPersistence, hash: seq[byte] +): Result[bool, string] = + try: + let row = self.db.getRow( + sql"SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash + ) + if row.len == 0: + return ok(false) + let count = row[0].parseInt + ok(count > 0) + except CatchableError as e: + err("check message already completed with error: " & e.msg) + +proc saveMessageSegment*( + self: MessageSegmentsPersistence, + segment: SegmentMessage, + sigPubKeyBlob: seq[byte], + timestamp: int64, +): Result[void, string] = + try: + self.db.exec( + sql""" + INSERT INTO message_segments ( + hash, segment_index, segments_count, parity_segment_index, + parity_segments_count, sig_pub_key, payload, timestamp + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + segment.entireMessageHash, + segment.index, + segment.segmentsCount, + segment.paritySegmentIndex, + segment.paritySegmentsCount, + sigPubKeyBlob, + segment.payload, + timestamp, + ) + ok() + except DbError as e: + err("save message segments with error: " & e.msg) + +proc getMessageSegments*(self: MessageSegmentsPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] = + var segments = newSeq[SegmentMessage]() + let query = sql""" + SELECT + hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload + FROM + message_segments + WHERE + hash = ? AND sig_pub_key = ? + ORDER BY + (segments_count = 0) ASC, + segment_index ASC, + parity_segment_index ASC + """ + + try: + for row in self.db.rows(query, hash, sigPubKeyBlob): + let segment = SegmentMessage( + entireMessageHash: cast[seq[byte]](row[0]), + index: uint32(parseInt(row[1])), + segmentsCount: uint32(parseInt(row[2])), + paritySegmentIndex: uint32(parseInt(row[3])), + paritySegmentsCount: uint32(parseInt(row[4])), + payload: cast[seq[byte]](row[5]) + ) + segments.add(segment) + return ok(segments) + except CatchableError as e: + return err("get Message Segments with error: " & e.msg) + +proc completeMessageSegments*( + self: MessageSegmentsPersistence, + hash: seq[byte], + sigPubKeyBlob: seq[byte], + timestamp: int64 +): Result[void, string] = + try: + self.db.exec(sql"BEGIN") + + # Delete old message segments + self.db.exec(sql"DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hash, sigPubKeyBlob) + + # Insert completed marker + self.db.exec(sql"INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?, ?, ?)", + hash, sigPubKeyBlob, timestamp) + + self.db.exec(sql"COMMIT") + return ok() + except DbError as e: + try: + self.db.exec(sql"ROLLBACK") + except: discard + return err("complete segment messages with error: " & e.msg) \ No newline at end of file diff --git a/chat_sdk/db_models.nim b/chat_sdk/db_models.nim new file mode 100644 index 0000000..392760b --- /dev/null +++ b/chat_sdk/db_models.nim @@ -0,0 +1,8 @@ +type + SegmentMessage* = ref object + entireMessageHash*: seq[byte] + index*: uint32 + segmentsCount*: uint32 + paritySegmentIndex*: uint32 + paritySegmentsCount*: uint32 + payload*: seq[byte] diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index 3a61390..1f027ef 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -1,7 +1,7 @@ import math, times, sequtils, strutils, options -import nimcrypto # For Keccak256 hashing +import nimcrypto import results -import leopard # Import nim-leopard for Reed-Solomon +import leopard import chronicles # External dependencies (still needed) @@ -196,8 +196,7 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul return ok(segmentMessages) -# Handle SegmentationLayerV2 (updated with nim-leopard) -proc handleSegmentationLayerV2*(s: MessageSender, message: var StatusMessage): Result[void, string] = +proc handleSegmentationLayer*(s: MessageSender, message: var StatusMessage): Result[void, string] = logScope: site = "handleSegmentationLayerV2" hash = message.transportLayer.hash.toHex @@ -303,10 +302,6 @@ proc handleSegmentationLayerV2*(s: MessageSender, message: var StatusMessage): R message.transportLayer.payload = entirePayload return ok() -# Other procs (unchanged) -proc handleSegmentationLayerV1*(s: MessageSender, message: StatusMessage): Result[void, string] = - # Same as previous translation - discard proc cleanupSegments*(s: MessageSender): Result[void, string] = # Same as previous translation diff --git a/migrations/002_create_message_segments_table.sql b/migrations/002_create_message_segments_table.sql new file mode 100644 index 0000000..df75ee7 --- /dev/null +++ b/migrations/002_create_message_segments_table.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS message_segments ( + hash BLOB NOT NULL, + segment_index INTEGER NOT NULL, + segments_count INTEGER NOT NULL, + payload BLOB NOT NULL, + sig_pub_key BLOB NOT NULL, + timestamp INTEGER DEFAULT 0, + PRIMARY KEY (hash, sig_pub_key, segment_index) ON CONFLICT REPLACE +); + +CREATE TABLE IF NOT EXISTS message_segments_completed ( + hash BLOB NOT NULL, + sig_pub_key BLOB NOT NULL, + timestamp INTEGER DEFAULT 0, + PRIMARY KEY (hash, sig_pub_key) +); + +CREATE INDEX idx_message_segments_timestamp ON message_segments(timestamp); +CREATE INDEX idx_message_segments_completed_timestamp ON message_segments_completed(timestamp); \ No newline at end of file From 211f8275db5ddb9cab54ecd329e53bc3551ae45e Mon Sep 17 00:00:00 2001 From: kaichaosun Date: Wed, 16 Jul 2025 14:47:47 +0800 Subject: [PATCH 5/9] chore: refactor data types --- chat_sdk/db.nim | 2 +- chat_sdk/segmentation.nim | 103 +++++++++++++++++--------------------- 2 files changed, 46 insertions(+), 59 deletions(-) diff --git a/chat_sdk/db.nim b/chat_sdk/db.nim index cde7bf0..23bda65 100644 --- a/chat_sdk/db.nim +++ b/chat_sdk/db.nim @@ -118,4 +118,4 @@ proc completeMessageSegments*( try: self.db.exec(sql"ROLLBACK") except: discard - return err("complete segment messages with error: " & e.msg) \ No newline at end of file + return err("complete segment messages with error: " & e.msg) diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index 1f027ef..eb21c6c 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -4,19 +4,12 @@ import results import leopard import chronicles -# External dependencies (still needed) -# import protobuf # Nim protobuf library (e.g., protobuf-nim) +import protobuf_serialization/proto_parser +import protobuf_serialization -# SegmentMessage type (unchanged) -type - SegmentMessage* = ref object - entireMessageHash*: seq[byte] - index*: uint32 - segmentsCount*: uint32 - paritySegmentIndex*: uint32 - paritySegmentsCount*: uint32 - payload*: seq[byte] +import db_models +import_proto3 "segment_message.proto" # Placeholder types (unchanged) type @@ -24,10 +17,7 @@ type payload*: seq[byte] # Add other fields as needed - StatusMessage* = object - transportLayer*: TransportLayer - - TransportLayer* = object + Message* = object hash*: seq[byte] payload*: seq[byte] sigPubKey*: seq[byte] @@ -50,9 +40,21 @@ const SegmentsReedsolomonMaxCount = 256 -# Validation methods (unchanged) -proc isValid*(s: SegmentMessage): bool = - s.segmentsCount >= 2 or s.paritySegmentsCount > 0 +proc isValid(s: SegmentMessageProto): bool = + # Check if the hash length is valid (32 bytes for Keccak256) + if s.entire_message_hash.len != 32: + return false + + # Check if the segment index is within the valid range + if s.segments_count > 0 and s.index >= s.segments_count: + return false + + # Check if the parity segment index is within the valid range + if s.parity_segments_count > 0 and s.parity_segment_index >= s.parity_segments_count: + return false + + # Check if segments_count is at least 2 or parity_segments_count is positive + return s.segments_count >= 2 or s.parity_segments_count > 0 proc isParityMessage*(s: SegmentMessage): bool = s.segmentsCount == 0 and s.paritySegmentsCount > 0 @@ -100,16 +102,6 @@ proc protoMarshal(msg: SegmentMessage): Result[seq[byte], string] = # Fake serialization (index + payload length) TODO return ok(@[byte(msg.index)] & msg.payload) -proc protoUnmarshal(data: seq[byte], msg: var SegmentMessage): Result[void, string] = - # Fake deserialization (reconstruct index and payload) - if data.len < 1: - return err("data too short") - msg.index = uint32(data[0]) - msg.payload = data[1..^1] - msg.segmentsCount = 2 - msg.entireMessageHash = @[byte 1, 2, 3] # Dummy hash - return ok() - # Segment message into smaller chunks (updated with nim-leopard) proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] = if newMessage.payload.len <= segmentSize: @@ -196,38 +188,43 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul return ok(segmentMessages) -proc handleSegmentationLayer*(s: MessageSender, message: var StatusMessage): Result[void, string] = +proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[void, string] = logScope: - site = "handleSegmentationLayerV2" - hash = message.transportLayer.hash.toHex + site = "handleSegmentationLayer" + hash = message.hash.toHex - var segmentMessage = SegmentMessage() - let unmarshalResult = protoUnmarshal(message.transportLayer.payload, segmentMessage) - if unmarshalResult.isErr: - return err("failed to unmarshal SegmentMessage: " & unmarshalResult.error) + let segmentMessageProto = Protobuf.decode(message.payload, SegmentMessageProto) debug "handling message segment", - "EntireMessageHash" = segmentMessage.entireMessageHash.toHex, - "Index" = $segmentMessage.index, - "SegmentsCount" = $segmentMessage.segmentsCount, - "ParitySegmentIndex" = $segmentMessage.paritySegmentIndex, - "ParitySegmentsCount" = $segmentMessage.paritySegmentsCount + "EntireMessageHash" = segmentMessageProto.entireMessageHash.toHex, + "Index" = $segmentMessageProto.index, + "SegmentsCount" = $segmentMessageProto.segmentsCount, + "ParitySegmentIndex" = $segmentMessageProto.paritySegmentIndex, + "ParitySegmentsCount" = $segmentMessageProto.paritySegmentsCount # TODO here use the mock function, real function should be used together with the persistence layer - let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessage.entireMessageHash) + let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash) if alreadyCompleted.isErr: return err(alreadyCompleted.error) if alreadyCompleted.get(): return err(ErrMessageSegmentsAlreadyCompleted) - if not segmentMessage.isValid(): + if not segmentMessageProto.isValid(): return err(ErrMessageSegmentsInvalidCount) - let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.transportLayer.sigPubKey, getTime().toUnix) + let segmentMessage = SegmentMessage( + entireMessageHash: segmentMessageProto.entireMessageHash, + index: segmentMessageProto.index, + segmentsCount: segmentMessageProto.segmentsCount, + paritySegmentIndex: segmentMessageProto.paritySegmentIndex, + paritySegmentsCount: segmentMessageProto.paritySegmentsCount, + payload: segmentMessageProto.payload + ) + let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix) if saveResult.isErr: return err(saveResult.error) - let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.transportLayer.sigPubKey) + let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey) if segments.isErr: return err(segments.error) @@ -295,11 +292,11 @@ proc handleSegmentationLayer*(s: MessageSender, message: var StatusMessage): Res if entirePayloadHash.data != segmentMessage.entireMessageHash: return err(ErrMessageSegmentsHashMismatch) - let completeResult = s.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.transportLayer.sigPubKey, getTime().toUnix) + let completeResult = s.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix) if completeResult.isErr: return err(completeResult.error) - message.transportLayer.payload = entirePayload + message.payload = entirePayload return ok() @@ -317,17 +314,7 @@ proc segmentMessage*(s: MessageSender, newMessage: WakuNewMessage): Result[seq[W debug "message segmented", "segments" = $messages.len return ok(messages) -proc demoRounding() = - let x = 3.7 - let y = -3.7 - - echo "ceil(", x, ") = ", ceil(x) # Rounds up - echo "floor(", x, ") = ", floor(x) # Rounds down - echo "round(", x, ") = ", round(x) # Rounds to nearest integer - echo "trunc(", x, ") = ", trunc(x) # Truncates decimal part - echo "ceil(", y, ") = ", ceil(y) - echo "floor(", y, ") = ", floor(y) - +proc demo() = let bufSize = 64 # byte count per buffer, must be a multiple of 64 buffers = 239 # number of data symbols @@ -342,7 +329,7 @@ proc demoRounding() = when isMainModule: - demoRounding() + demo() proc add2*(x, y: int): int = ## Adds two numbers together. From 001e5b97c9e7920e642de7d1764f1cbcfdbd9a95 Mon Sep 17 00:00:00 2001 From: kaichaosun Date: Wed, 16 Jul 2025 15:54:32 +0800 Subject: [PATCH 6/9] chore: refactor segment logic --- chat_sdk/db.nim | 14 ++--- chat_sdk/segmentation.nim | 120 +++++++++----------------------------- 2 files changed, 35 insertions(+), 99 deletions(-) diff --git a/chat_sdk/db.nim b/chat_sdk/db.nim index 23bda65..0cdd936 100644 --- a/chat_sdk/db.nim +++ b/chat_sdk/db.nim @@ -3,11 +3,11 @@ import results import db_connector/db_sqlite import db_models -type MessageSegmentsPersistence = object +type SegmentationPersistence* = object db: DbConn proc removeMessageSegmentsOlderThan*( - self: MessageSegmentsPersistence, timestamp: int64 + self: SegmentationPersistence, timestamp: int64 ): Result[void, string] = try: self.db.exec(sql"DELETE FROM message_segments WHERE timestamp < ?", timestamp) @@ -16,7 +16,7 @@ proc removeMessageSegmentsOlderThan*( err("remove message segments with error: " & e.msg) proc removeMessageSegmentsCompletedOlderThan*( - self: MessageSegmentsPersistence, timestamp: int64 + self: SegmentationPersistence, timestamp: int64 ): Result[void, string] = try: self.db.exec( @@ -27,7 +27,7 @@ proc removeMessageSegmentsCompletedOlderThan*( err("remove message segments completed with error: " & e.msg) proc isMessageAlreadyCompleted*( - self: MessageSegmentsPersistence, hash: seq[byte] + self: SegmentationPersistence, hash: seq[byte] ): Result[bool, string] = try: let row = self.db.getRow( @@ -41,7 +41,7 @@ proc isMessageAlreadyCompleted*( err("check message already completed with error: " & e.msg) proc saveMessageSegment*( - self: MessageSegmentsPersistence, + self: SegmentationPersistence, segment: SegmentMessage, sigPubKeyBlob: seq[byte], timestamp: int64, @@ -66,7 +66,7 @@ proc saveMessageSegment*( except DbError as e: err("save message segments with error: " & e.msg) -proc getMessageSegments*(self: MessageSegmentsPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] = +proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] = var segments = newSeq[SegmentMessage]() let query = sql""" SELECT @@ -97,7 +97,7 @@ proc getMessageSegments*(self: MessageSegmentsPersistence, hash: seq[byte], sigP return err("get Message Segments with error: " & e.msg) proc completeMessageSegments*( - self: MessageSegmentsPersistence, + self: SegmentationPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte], timestamp: int64 diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index eb21c6c..c758f98 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -8,25 +8,28 @@ import protobuf_serialization/proto_parser import protobuf_serialization import db_models +import db import_proto3 "segment_message.proto" # Placeholder types (unchanged) type - WakuNewMessage* = object + Chunk* = object payload*: seq[byte] - # Add other fields as needed Message* = object hash*: seq[byte] payload*: seq[byte] sigPubKey*: seq[byte] + + SegmentationHander* = object + maxMessageSize*: int + persistence*: SegmentationPersistence - Persistence = object - completedMessages: Table[seq[byte], bool] # Mock storage for completed message hashes - segments: Table[(seq[byte], seq[byte]), seq[SegmentMessage]] # Stores segments by (hash, sigPubKey) +const + SegmentsParityRate = 0.125 + SegmentsReedsolomonMaxCount = 256 -# Error definitions (unchanged) const ErrMessageSegmentsIncomplete = "message segments incomplete" ErrMessageSegmentsAlreadyCompleted = "message segments already completed" @@ -34,12 +37,6 @@ const ErrMessageSegmentsHashMismatch = "hash of entire payload does not match" ErrMessageSegmentsInvalidParity = "invalid parity segments" -# Constants (unchanged) -const - SegmentsParityRate = 0.125 - SegmentsReedsolomonMaxCount = 256 - - proc isValid(s: SegmentMessageProto): bool = # Check if the hash length is valid (32 bytes for Keccak256) if s.entire_message_hash.len != 32: @@ -59,51 +56,8 @@ proc isValid(s: SegmentMessageProto): bool = proc isParityMessage*(s: SegmentMessage): bool = s.segmentsCount == 0 and s.paritySegmentsCount > 0 -# MessageSender type (unchanged) -type - MessageSender* = ref object - messaging*: Messaging - persistence*: Persistence - - Messaging* = object - maxMessageSize*: int - -proc initPersistence*(): Persistence = - Persistence( - completedMessages: initTable[seq[byte], bool](), - segments: initTable[(seq[byte], seq[byte]), seq[SegmentMessage]]() - ) - -proc isMessageAlreadyCompleted(p: Persistence, hash: seq[byte]): Result[bool, string] = - return ok(p.completedMessages.getOrDefault(hash, false)) - -proc saveMessageSegment(p: var Persistence, segment: SegmentMessage, sigPubKey: seq[byte], timestamp: int64): Result[void, string] = - let key = (segment.entireMessageHash, sigPubKey) - # Initialize or append to the segments list - if not p.segments.hasKey(key): - p.segments[key] = @[] - p.segments[key].add(segment) - return ok() - -proc getMessageSegments(p: Persistence, hash: seq[byte], sigPubKey: seq[byte]): Result[seq[SegmentMessage], string] = - let key = (hash, sigPubKey) - return ok(p.segments.getOrDefault(key, @[])) - -proc completeMessageSegments(p: var Persistence, hash: seq[byte], sigPubKey: seq[byte], timestamp: int64): Result[void, string] = - p.completedMessages[hash] = true - return ok() - -# Replicate message (unchanged) -proc replicateMessageWithNewPayload(message: WakuNewMessage, payload: seq[byte]): Result[WakuNewMessage, string] = - var copiedMessage = WakuNewMessage(payload: payload) - return ok(copiedMessage) - -proc protoMarshal(msg: SegmentMessage): Result[seq[byte], string] = - # Fake serialization (index + payload length) TODO - return ok(@[byte(msg.index)] & msg.payload) - -# Segment message into smaller chunks (updated with nim-leopard) -proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] = +proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] = + let segmentSize = s.maxMessageSize div 4 * 3 if newMessage.payload.len <= segmentSize: return ok(@[newMessage]) @@ -114,7 +68,7 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate)) var segmentPayloads = newSeq[seq[byte]](segmentsCount + paritySegmentsCount) - var segmentMessages = newSeq[WakuNewMessage](segmentsCount) + var segmentMessages = newSeq[Chunk](segmentsCount) for i in 0.. SegmentsReedsolomonMaxCount: @@ -157,9 +106,11 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul # Use nim-leopard for Reed-Solomon encoding var data = segmentPayloads[0.. Date: Wed, 23 Jul 2025 15:25:04 +0800 Subject: [PATCH 7/9] feat: fix tests --- .gitignore | 4 +- chat_sdk/db.nim | 24 ++- chat_sdk/migration.nim | 18 +- chat_sdk/segment_message.proto | 16 ++ chat_sdk/segmentation.nim | 128 +++++++++---- .../002_create_message_segments_table.sql | 10 +- tests/test_chat_sdk.nim | 2 +- tests/test_segmentation.nim | 174 ++++++++++++++---- 8 files changed, 288 insertions(+), 88 deletions(-) create mode 100644 chat_sdk/segment_message.proto diff --git a/.gitignore b/.gitignore index 3a79a35..bd74343 100644 --- a/.gitignore +++ b/.gitignore @@ -20,8 +20,10 @@ nimcache/ # Compiled files chat_sdk/* -!*.nim apps/* +tests/* + !*.nim +!*.proto nimble.develop nimble.paths diff --git a/chat_sdk/db.nim b/chat_sdk/db.nim index 0cdd936..39e2bcf 100644 --- a/chat_sdk/db.nim +++ b/chat_sdk/db.nim @@ -2,9 +2,10 @@ import std/strutils import results import db_connector/db_sqlite import db_models +import nimcrypto type SegmentationPersistence* = object - db: DbConn + db*: DbConn proc removeMessageSegmentsOlderThan*( self: SegmentationPersistence, timestamp: int64 @@ -46,6 +47,7 @@ proc saveMessageSegment*( sigPubKeyBlob: seq[byte], timestamp: int64, ): Result[void, string] = + echo "size.....", segment.payload.len try: self.db.exec( sql""" @@ -53,15 +55,24 @@ proc saveMessageSegment*( hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - segment.entireMessageHash, + segment.entireMessageHash.toHex, segment.index, segment.segmentsCount, segment.paritySegmentIndex, segment.paritySegmentsCount, sigPubKeyBlob, - segment.payload, + segment.payload.toHex, timestamp, ) + + let storedPayload = self.db.getValue( + sql"SELECT payload FROM message_segments WHERE hash = ? AND sig_pub_key = ? AND segment_index = ?", + segment.entireMessageHash.toHex, sigPubKeyBlob, segment.index + ) + let payloadBytes = nimcrypto.fromHex(storedPayload) + # let lenBytes = payloadBytes.get() + echo "Stored payload length: ", payloadBytes.len + ok() except DbError as e: err("save message segments with error: " & e.msg) @@ -82,15 +93,16 @@ proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubK """ try: - for row in self.db.rows(query, hash, sigPubKeyBlob): + for row in self.db.rows(query, hash.toHex, sigPubKeyBlob): let segment = SegmentMessage( - entireMessageHash: cast[seq[byte]](row[0]), + entireMessageHash: nimcrypto.fromHex(row[0]), index: uint32(parseInt(row[1])), segmentsCount: uint32(parseInt(row[2])), paritySegmentIndex: uint32(parseInt(row[3])), paritySegmentsCount: uint32(parseInt(row[4])), - payload: cast[seq[byte]](row[5]) + payload: nimcrypto.fromHex(row[5]) ) + echo "read size.....", segment.payload.len segments.add(segment) return ok(segments) except CatchableError as e: diff --git a/chat_sdk/migration.nim b/chat_sdk/migration.nim index 8af88d3..a71f466 100644 --- a/chat_sdk/migration.nim +++ b/chat_sdk/migration.nim @@ -1,4 +1,4 @@ -import os, sequtils, algorithm +import os, sequtils, algorithm, strutils import db_connector/db_sqlite import chronicles @@ -27,6 +27,16 @@ proc runMigrations*(db: DbConn, dir = "migrations") = info "Migration already applied", file else: info "Applying migration", file - let sql = readFile(file) - db.exec(sql(sql)) - markMigrationRun(db, file) + let script = readFile(file) + db.exec(sql"BEGIN TRANSACTION") + try: + for stmt in script.split(";"): + let trimmed = stmt.strip() + if trimmed.len > 0: + db.exec(sql(trimmed)) + + markMigrationRun(db, file) + db.exec(sql"COMMIT") + except: + db.exec(sql"ROLLBACK") + raise diff --git a/chat_sdk/segment_message.proto b/chat_sdk/segment_message.proto new file mode 100644 index 0000000..1b8c556 --- /dev/null +++ b/chat_sdk/segment_message.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +message SegmentMessageProto { + // hash of the entire original message + bytes entire_message_hash = 1; + // Index of this segment within the entire original message + uint32 index = 2; + // Total number of segments the entire original message is divided into + uint32 segments_count = 3; + // The payload data for this particular segment + bytes payload = 4; + // Index of this parity segment + uint32 parity_segment_index = 5; + // Total number of parity segments + uint32 parity_segments_count = 6; +} diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index c758f98..14b5b5c 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -23,7 +23,7 @@ type sigPubKey*: seq[byte] SegmentationHander* = object - maxMessageSize*: int + segmentSize*: int persistence*: SegmentationPersistence const @@ -31,11 +31,11 @@ const SegmentsReedsolomonMaxCount = 256 const - ErrMessageSegmentsIncomplete = "message segments incomplete" - ErrMessageSegmentsAlreadyCompleted = "message segments already completed" - ErrMessageSegmentsInvalidCount = "invalid segments count" - ErrMessageSegmentsHashMismatch = "hash of entire payload does not match" - ErrMessageSegmentsInvalidParity = "invalid parity segments" + ErrMessageSegmentsIncomplete* = "message segments incomplete" + ErrMessageSegmentsAlreadyCompleted* = "message segments already completed" + ErrMessageSegmentsInvalidCount* = "invalid segments count" + ErrMessageSegmentsHashMismatch* = "hash of entire payload does not match" + ErrMessageSegmentsInvalidParity* = "invalid parity segments" proc isValid(s: SegmentMessageProto): bool = # Check if the hash length is valid (32 bytes for Keccak256) @@ -57,22 +57,29 @@ proc isParityMessage*(s: SegmentMessage): bool = s.segmentsCount == 0 and s.paritySegmentsCount > 0 proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] = - let segmentSize = s.maxMessageSize div 4 * 3 - if newMessage.payload.len <= segmentSize: + if newMessage.payload.len <= s.segmentSize: return ok(@[newMessage]) + info "segmenting message", + "payloadSize" = newMessage.payload.len, + "segmentSize" = s.segmentSize + info "segment payload in", "saaa" = newMessage.payload.toHex let entireMessageHash = keccak256.digest(newMessage.payload) let entirePayloadSize = newMessage.payload.len - let segmentsCount = int(ceil(entirePayloadSize.float / segmentSize.float)) + let segmentsCount = int(ceil(entirePayloadSize.float / s.segmentSize.float)) let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate)) + info "parity count", paritySegmentsCount + info "segmentsCount", segmentsCount + info "entirePayloadSize", entirePayloadSize + info "segmentSize", "segmentSize" = s.segmentSize - var segmentPayloads = newSeq[seq[byte]](segmentsCount + paritySegmentsCount) + var segmentPayloads = newSeq[seq[byte]](segmentsCount) var segmentMessages = newSeq[Chunk](segmentsCount) for i in 0.. entirePayloadSize: endIndex = entirePayloadSize @@ -84,6 +91,8 @@ proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk payload: segmentPayload ) + info "entireMessageHash", "entireMessageHash" = entireMessageHash.data.toHex + let marshaledSegment = Protobuf.encode(segmentWithMetadata) let segmentMessage = Chunk(payload: marshaledSegment) @@ -91,40 +100,60 @@ proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk segmentMessages[i] = segmentMessage # Skip Reed-Solomon if parity segments are 0 or total exceeds max count + info "segments count", "len" = segmentMessages.len if paritySegmentsCount == 0 or (segmentsCount + paritySegmentsCount) > SegmentsReedsolomonMaxCount: return ok(segmentMessages) # Align last segment payload for Reed-Solomon (leopard requires fixed-size shards) let lastSegmentPayload = segmentPayloads[segmentsCount-1] - segmentPayloads[segmentsCount-1] = newSeq[byte](segmentSize) + segmentPayloads[segmentsCount-1] = newSeq[byte](s.segmentSize) segmentPayloads[segmentsCount-1][0.. 0: payloads[firstSegmentMessage.segmentsCount - 1] = lastNonParitySegmentPayload @@ -234,6 +287,8 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message) entirePayload.add(payloads[i]) # Sanity check + info "entirePayloadSize", "entirePayloadSize" = entirePayload.len + info "entire payload", "entirePayloadSize" = entirePayload.toHex let entirePayloadHash = keccak256.digest(entirePayload) if entirePayloadHash.data != segmentMessage.entireMessageHash: return err(ErrMessageSegmentsHashMismatch) @@ -247,7 +302,6 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message) proc cleanupSegments*(s: SegmentationHander): Result[void, string] = - # Same as previous translation discard proc demo() = diff --git a/migrations/002_create_message_segments_table.sql b/migrations/002_create_message_segments_table.sql index df75ee7..32adb31 100644 --- a/migrations/002_create_message_segments_table.sql +++ b/migrations/002_create_message_segments_table.sql @@ -1,11 +1,13 @@ -CREATE TABLE IF NOT EXISTS message_segments ( +CREATE TABLE message_segments ( hash BLOB NOT NULL, segment_index INTEGER NOT NULL, segments_count INTEGER NOT NULL, payload BLOB NOT NULL, sig_pub_key BLOB NOT NULL, - timestamp INTEGER DEFAULT 0, - PRIMARY KEY (hash, sig_pub_key, segment_index) ON CONFLICT REPLACE + timestamp INTEGER NOT NULL, + parity_segment_index INTEGER NOT NULL, + parity_segments_count INTEGER NOT NULL, + PRIMARY KEY (hash, sig_pub_key, segment_index, segments_count, parity_segment_index, parity_segments_count) ON CONFLICT REPLACE ); CREATE TABLE IF NOT EXISTS message_segments_completed ( @@ -16,4 +18,4 @@ CREATE TABLE IF NOT EXISTS message_segments_completed ( ); CREATE INDEX idx_message_segments_timestamp ON message_segments(timestamp); -CREATE INDEX idx_message_segments_completed_timestamp ON message_segments_completed(timestamp); \ No newline at end of file +CREATE INDEX idx_message_segments_completed_timestamp ON message_segments_completed(timestamp); diff --git a/tests/test_chat_sdk.nim b/tests/test_chat_sdk.nim index a588132..4c2a979 100644 --- a/tests/test_chat_sdk.nim +++ b/tests/test_chat_sdk.nim @@ -7,7 +7,7 @@ import unittest -import chat_sdk +# import chat_sdk import chat_sdk/segmentation test "can add": diff --git a/tests/test_segmentation.nim b/tests/test_segmentation.nim index fa1677f..e302d3b 100644 --- a/tests/test_segmentation.nim +++ b/tests/test_segmentation.nim @@ -5,49 +5,153 @@ # # To run these tests, simply execute `nimble test`. -import unittest, sequtils, random +import unittest, sequtils, random, math import results -import chat_sdk/segmentation # Replace with the actual file name +import db_connector/db_sqlite +import nimcrypto +import chat_sdk/segmentation +import chat_sdk/migration +import chat_sdk/db test "can add": check add2(5, 5) == 10 -suite "Message Segmentation": - let testPayload = toSeq(0..999).mapIt(byte(it)) - let testMessage = WakuNewMessage(payload: testPayload) - let sender = MessageSender( - messaging: Messaging(maxMessageSize: 500), - persistence: initPersistence() - ) +proc newInMemoryPersistence(): SegmentationPersistence = + let conn = open(":memory:", "", "", "") + # Define the tables (same schema as expected in your app) + runMigrations(conn) + result = SegmentationPersistence(db: conn) - test "Segment and reassemble full segments": - let segmentResult = sender.segmentMessage(testMessage) - check segmentResult.isOk - var segments = segmentResult.get() - check segments.len > 0 +suite "message Segmentation": + var + sender: SegmentationHander + testPayload: seq[byte] + mockPersistence: SegmentationPersistence - # Shuffle segment order for out-of-order test - segments.shuffle() + setup: + # Initialize test payload (1000 bytes, each byte is its index) + testPayload = newSeq[byte](1024) + for i in 0..<1024: + testPayload[i] = rand(255).byte - for segment in segments: - var statusMsg: StatusMessage - statusMsg.transportLayer = TransportLayer( - hash: @[byte 0, 1, 2], # Dummy hash - payload: segment.payload, - sigPubKey: @[byte 3, 4, 5] - ) - let result = sender.handleSegmentationLayerV2(statusMsg) - discard result # Ignore intermediate errors + # Setup mock persistence + mockPersistence = newInMemoryPersistence() - # One last run to trigger completion - var finalStatus: StatusMessage - finalStatus.transportLayer = TransportLayer( - hash: @[byte 0, 1, 2], - payload: segments[0].payload, - sigPubKey: @[byte 3, 4, 5] + # Setup sender + sender = SegmentationHander( + segmentSize: 4000, # Arbitrary size to allow segmentation + persistence: mockPersistence ) - let finalResult = sender.handleSegmentationLayerV2(finalStatus) - check finalResult.isOk + + test "HandleSegmentationLayer": + # Define test cases (mirroring Go test cases) + let testCases = @[ + ( + name: "all segments retrieved", + segmentsCount: 2, + expectedParitySegmentsCount: 0, + retrievedSegments: @[0, 1], + retrievedParitySegments: newSeq[int](), + shouldSucceed: true + ), + ( + name: "all segments retrieved out of order", + segmentsCount: 2, + expectedParitySegmentsCount: 0, + retrievedSegments: @[1, 0], + retrievedParitySegments: newSeq[int](), + shouldSucceed: true + ), + ( + name: "all segments&parity retrieved", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: @[0, 1, 2, 3, 4, 5, 6, 7, 8], + retrievedParitySegments: @[8], + shouldSucceed: true + ), + ( + name: "all segments&parity retrieved out of order", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: @[8, 0, 7, 1, 6, 2, 5, 3, 4], + retrievedParitySegments: @[8], + shouldSucceed: true + ), + ( + name: "no segments retrieved", + segmentsCount: 2, + expectedParitySegmentsCount: 0, + retrievedSegments: @[], + retrievedParitySegments: newSeq[int](), + shouldSucceed: false + ), + ( + name: "not all needed segments&parity retrieved", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: @[1, 2, 8], + retrievedParitySegments: @[8], + shouldSucceed: false + ), + ( + name: "segments&parity retrieved", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: @[1, 2, 3, 4, 5, 6, 7, 8], + retrievedParitySegments: @[8], + shouldSucceed: true + ), + ( + name: "segments&parity retrieved out of order", + segmentsCount: 16, + expectedParitySegmentsCount: 2, + retrievedSegments: @[17, 0, 16, 1, 15, 2, 14, 3, 13, 4, 12, 5, 11, 6, 10, 7], + retrievedParitySegments: @[16, 17], + shouldSucceed: true + ) + ] - # Check payload restored - check finalStatus.transportLayer.payload == testPayload[0.. Date: Wed, 23 Jul 2025 15:34:36 +0800 Subject: [PATCH 8/9] chore: clear log --- chat_sdk/db.nim | 4 --- chat_sdk/segmentation.nim | 53 ++----------------------------------- tests/test_segmentation.nim | 11 -------- 3 files changed, 2 insertions(+), 66 deletions(-) diff --git a/chat_sdk/db.nim b/chat_sdk/db.nim index 39e2bcf..7b46906 100644 --- a/chat_sdk/db.nim +++ b/chat_sdk/db.nim @@ -47,7 +47,6 @@ proc saveMessageSegment*( sigPubKeyBlob: seq[byte], timestamp: int64, ): Result[void, string] = - echo "size.....", segment.payload.len try: self.db.exec( sql""" @@ -70,8 +69,6 @@ proc saveMessageSegment*( segment.entireMessageHash.toHex, sigPubKeyBlob, segment.index ) let payloadBytes = nimcrypto.fromHex(storedPayload) - # let lenBytes = payloadBytes.get() - echo "Stored payload length: ", payloadBytes.len ok() except DbError as e: @@ -102,7 +99,6 @@ proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubK paritySegmentsCount: uint32(parseInt(row[4])), payload: nimcrypto.fromHex(row[5]) ) - echo "read size.....", segment.payload.len segments.add(segment) return ok(segments) except CatchableError as e: diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index 14b5b5c..994a095 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -12,7 +12,6 @@ import db import_proto3 "segment_message.proto" -# Placeholder types (unchanged) type Chunk* = object payload*: seq[byte] @@ -63,16 +62,11 @@ proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk info "segmenting message", "payloadSize" = newMessage.payload.len, "segmentSize" = s.segmentSize - info "segment payload in", "saaa" = newMessage.payload.toHex let entireMessageHash = keccak256.digest(newMessage.payload) let entirePayloadSize = newMessage.payload.len let segmentsCount = int(ceil(entirePayloadSize.float / s.segmentSize.float)) let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate)) - info "parity count", paritySegmentsCount - info "segmentsCount", segmentsCount - info "entirePayloadSize", entirePayloadSize - info "segmentSize", "segmentSize" = s.segmentSize var segmentPayloads = newSeq[seq[byte]](segmentsCount) var segmentMessages = newSeq[Chunk](segmentsCount) @@ -91,8 +85,6 @@ proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk payload: segmentPayload ) - info "entireMessageHash", "entireMessageHash" = entireMessageHash.data.toHex - let marshaledSegment = Protobuf.encode(segmentWithMetadata) let segmentMessage = Chunk(payload: marshaledSegment) @@ -109,45 +101,22 @@ proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk segmentPayloads[segmentsCount-1] = newSeq[byte](s.segmentSize) segmentPayloads[segmentsCount-1][0.. Date: Tue, 5 Aug 2025 15:45:32 +0800 Subject: [PATCH 9/9] chore: clear unused code --- chat_sdk/segmentation.nim | 21 --------------------- tests/test_segmentation.nim | 3 --- 2 files changed, 24 deletions(-) diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index 994a095..85d20ed 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -254,24 +254,3 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message) proc cleanupSegments*(s: SegmentationHander): Result[void, string] = discard - -proc demo() = - let - bufSize = 64 # byte count per buffer, must be a multiple of 64 - buffers = 239 # number of data symbols - parity = 17 # number of parity symbols - - var - encoderRes = LeoEncoder.init(bufSize, buffers, parity) - decoderRes = LeoDecoder.init(bufSize, buffers, parity) - - assert encoderRes.isOk - assert decoderRes.isOk - - -when isMainModule: - demo() - -proc add2*(x, y: int): int = - ## Adds two numbers together. - return x + y diff --git a/tests/test_segmentation.nim b/tests/test_segmentation.nim index 8112c06..c1064b5 100644 --- a/tests/test_segmentation.nim +++ b/tests/test_segmentation.nim @@ -13,9 +13,6 @@ import chat_sdk/segmentation import chat_sdk/migration import chat_sdk/db -test "can add": - check add2(5, 5) == 10 - proc newInMemoryPersistence(): SegmentationPersistence = let conn = open(":memory:", "", "", "") # Define the tables (same schema as expected in your app)