From 001e5b97c9e7920e642de7d1764f1cbcfdbd9a95 Mon Sep 17 00:00:00 2001 From: kaichaosun Date: Wed, 16 Jul 2025 15:54:32 +0800 Subject: [PATCH] chore: refactor segment logic --- chat_sdk/db.nim | 14 ++--- chat_sdk/segmentation.nim | 120 +++++++++----------------------------- 2 files changed, 35 insertions(+), 99 deletions(-) diff --git a/chat_sdk/db.nim b/chat_sdk/db.nim index 23bda65..0cdd936 100644 --- a/chat_sdk/db.nim +++ b/chat_sdk/db.nim @@ -3,11 +3,11 @@ import results import db_connector/db_sqlite import db_models -type MessageSegmentsPersistence = object +type SegmentationPersistence* = object db: DbConn proc removeMessageSegmentsOlderThan*( - self: MessageSegmentsPersistence, timestamp: int64 + self: SegmentationPersistence, timestamp: int64 ): Result[void, string] = try: self.db.exec(sql"DELETE FROM message_segments WHERE timestamp < ?", timestamp) @@ -16,7 +16,7 @@ proc removeMessageSegmentsOlderThan*( err("remove message segments with error: " & e.msg) proc removeMessageSegmentsCompletedOlderThan*( - self: MessageSegmentsPersistence, timestamp: int64 + self: SegmentationPersistence, timestamp: int64 ): Result[void, string] = try: self.db.exec( @@ -27,7 +27,7 @@ proc removeMessageSegmentsCompletedOlderThan*( err("remove message segments completed with error: " & e.msg) proc isMessageAlreadyCompleted*( - self: MessageSegmentsPersistence, hash: seq[byte] + self: SegmentationPersistence, hash: seq[byte] ): Result[bool, string] = try: let row = self.db.getRow( @@ -41,7 +41,7 @@ proc isMessageAlreadyCompleted*( err("check message already completed with error: " & e.msg) proc saveMessageSegment*( - self: MessageSegmentsPersistence, + self: SegmentationPersistence, segment: SegmentMessage, sigPubKeyBlob: seq[byte], timestamp: int64, @@ -66,7 +66,7 @@ proc saveMessageSegment*( except DbError as e: err("save message segments with error: " & e.msg) -proc getMessageSegments*(self: MessageSegmentsPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] = +proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] = var segments = newSeq[SegmentMessage]() let query = sql""" SELECT @@ -97,7 +97,7 @@ proc getMessageSegments*(self: MessageSegmentsPersistence, hash: seq[byte], sigP return err("get Message Segments with error: " & e.msg) proc completeMessageSegments*( - self: MessageSegmentsPersistence, + self: SegmentationPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte], timestamp: int64 diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index eb21c6c..c758f98 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -8,25 +8,28 @@ import protobuf_serialization/proto_parser import protobuf_serialization import db_models +import db import_proto3 "segment_message.proto" # Placeholder types (unchanged) type - WakuNewMessage* = object + Chunk* = object payload*: seq[byte] - # Add other fields as needed Message* = object hash*: seq[byte] payload*: seq[byte] sigPubKey*: seq[byte] + + SegmentationHander* = object + maxMessageSize*: int + persistence*: SegmentationPersistence - Persistence = object - completedMessages: Table[seq[byte], bool] # Mock storage for completed message hashes - segments: Table[(seq[byte], seq[byte]), seq[SegmentMessage]] # Stores segments by (hash, sigPubKey) +const + SegmentsParityRate = 0.125 + SegmentsReedsolomonMaxCount = 256 -# Error definitions (unchanged) const ErrMessageSegmentsIncomplete = "message segments incomplete" ErrMessageSegmentsAlreadyCompleted = "message segments already completed" @@ -34,12 +37,6 @@ const ErrMessageSegmentsHashMismatch = "hash of entire payload does not match" ErrMessageSegmentsInvalidParity = "invalid parity segments" -# Constants (unchanged) -const - SegmentsParityRate = 0.125 - SegmentsReedsolomonMaxCount = 256 - - proc isValid(s: SegmentMessageProto): bool = # Check if the hash length is valid (32 bytes for Keccak256) if s.entire_message_hash.len != 32: @@ -59,51 +56,8 @@ proc isValid(s: SegmentMessageProto): bool = proc isParityMessage*(s: SegmentMessage): bool = s.segmentsCount == 0 and s.paritySegmentsCount > 0 -# MessageSender type (unchanged) -type - MessageSender* = ref object - messaging*: Messaging - persistence*: Persistence - - Messaging* = object - maxMessageSize*: int - -proc initPersistence*(): Persistence = - Persistence( - completedMessages: initTable[seq[byte], bool](), - segments: initTable[(seq[byte], seq[byte]), seq[SegmentMessage]]() - ) - -proc isMessageAlreadyCompleted(p: Persistence, hash: seq[byte]): Result[bool, string] = - return ok(p.completedMessages.getOrDefault(hash, false)) - -proc saveMessageSegment(p: var Persistence, segment: SegmentMessage, sigPubKey: seq[byte], timestamp: int64): Result[void, string] = - let key = (segment.entireMessageHash, sigPubKey) - # Initialize or append to the segments list - if not p.segments.hasKey(key): - p.segments[key] = @[] - p.segments[key].add(segment) - return ok() - -proc getMessageSegments(p: Persistence, hash: seq[byte], sigPubKey: seq[byte]): Result[seq[SegmentMessage], string] = - let key = (hash, sigPubKey) - return ok(p.segments.getOrDefault(key, @[])) - -proc completeMessageSegments(p: var Persistence, hash: seq[byte], sigPubKey: seq[byte], timestamp: int64): Result[void, string] = - p.completedMessages[hash] = true - return ok() - -# Replicate message (unchanged) -proc replicateMessageWithNewPayload(message: WakuNewMessage, payload: seq[byte]): Result[WakuNewMessage, string] = - var copiedMessage = WakuNewMessage(payload: payload) - return ok(copiedMessage) - -proc protoMarshal(msg: SegmentMessage): Result[seq[byte], string] = - # Fake serialization (index + payload length) TODO - return ok(@[byte(msg.index)] & msg.payload) - -# Segment message into smaller chunks (updated with nim-leopard) -proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] = +proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] = + let segmentSize = s.maxMessageSize div 4 * 3 if newMessage.payload.len <= segmentSize: return ok(@[newMessage]) @@ -114,7 +68,7 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate)) var segmentPayloads = newSeq[seq[byte]](segmentsCount + paritySegmentsCount) - var segmentMessages = newSeq[WakuNewMessage](segmentsCount) + var segmentMessages = newSeq[Chunk](segmentsCount) for i in 0.. SegmentsReedsolomonMaxCount: @@ -157,9 +106,11 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul # Use nim-leopard for Reed-Solomon encoding var data = segmentPayloads[0..