diff --git a/.gitignore b/.gitignore index 3a79a35..bd74343 100644 --- a/.gitignore +++ b/.gitignore @@ -20,8 +20,10 @@ nimcache/ # Compiled files chat_sdk/* -!*.nim apps/* +tests/* + !*.nim +!*.proto nimble.develop nimble.paths diff --git a/chat_sdk/db.nim b/chat_sdk/db.nim index 0cdd936..39e2bcf 100644 --- a/chat_sdk/db.nim +++ b/chat_sdk/db.nim @@ -2,9 +2,10 @@ import std/strutils import results import db_connector/db_sqlite import db_models +import nimcrypto type SegmentationPersistence* = object - db: DbConn + db*: DbConn proc removeMessageSegmentsOlderThan*( self: SegmentationPersistence, timestamp: int64 @@ -46,6 +47,7 @@ proc saveMessageSegment*( sigPubKeyBlob: seq[byte], timestamp: int64, ): Result[void, string] = + echo "size.....", segment.payload.len try: self.db.exec( sql""" @@ -53,15 +55,24 @@ proc saveMessageSegment*( hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - segment.entireMessageHash, + segment.entireMessageHash.toHex, segment.index, segment.segmentsCount, segment.paritySegmentIndex, segment.paritySegmentsCount, sigPubKeyBlob, - segment.payload, + segment.payload.toHex, timestamp, ) + + let storedPayload = self.db.getValue( + sql"SELECT payload FROM message_segments WHERE hash = ? AND sig_pub_key = ? AND segment_index = ?", + segment.entireMessageHash.toHex, sigPubKeyBlob, segment.index + ) + let payloadBytes = nimcrypto.fromHex(storedPayload) + # let lenBytes = payloadBytes.get() + echo "Stored payload length: ", payloadBytes.len + ok() except DbError as e: err("save message segments with error: " & e.msg) @@ -82,15 +93,16 @@ proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubK """ try: - for row in self.db.rows(query, hash, sigPubKeyBlob): + for row in self.db.rows(query, hash.toHex, sigPubKeyBlob): let segment = SegmentMessage( - entireMessageHash: cast[seq[byte]](row[0]), + entireMessageHash: nimcrypto.fromHex(row[0]), index: uint32(parseInt(row[1])), segmentsCount: uint32(parseInt(row[2])), paritySegmentIndex: uint32(parseInt(row[3])), paritySegmentsCount: uint32(parseInt(row[4])), - payload: cast[seq[byte]](row[5]) + payload: nimcrypto.fromHex(row[5]) ) + echo "read size.....", segment.payload.len segments.add(segment) return ok(segments) except CatchableError as e: diff --git a/chat_sdk/migration.nim b/chat_sdk/migration.nim index 8af88d3..a71f466 100644 --- a/chat_sdk/migration.nim +++ b/chat_sdk/migration.nim @@ -1,4 +1,4 @@ -import os, sequtils, algorithm +import os, sequtils, algorithm, strutils import db_connector/db_sqlite import chronicles @@ -27,6 +27,16 @@ proc runMigrations*(db: DbConn, dir = "migrations") = info "Migration already applied", file else: info "Applying migration", file - let sql = readFile(file) - db.exec(sql(sql)) - markMigrationRun(db, file) + let script = readFile(file) + db.exec(sql"BEGIN TRANSACTION") + try: + for stmt in script.split(";"): + let trimmed = stmt.strip() + if trimmed.len > 0: + db.exec(sql(trimmed)) + + markMigrationRun(db, file) + db.exec(sql"COMMIT") + except: + db.exec(sql"ROLLBACK") + raise diff --git a/chat_sdk/segment_message.proto b/chat_sdk/segment_message.proto new file mode 100644 index 0000000..1b8c556 --- /dev/null +++ b/chat_sdk/segment_message.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +message SegmentMessageProto { + // hash of the entire original message + bytes entire_message_hash = 1; + // Index of this segment within the entire original message + uint32 index = 2; + // Total number of segments the entire original message is divided into + uint32 segments_count = 3; + // The payload data for this particular segment + bytes payload = 4; + // Index of this parity segment + uint32 parity_segment_index = 5; + // Total number of parity segments + uint32 parity_segments_count = 6; +} diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim index c758f98..14b5b5c 100644 --- a/chat_sdk/segmentation.nim +++ b/chat_sdk/segmentation.nim @@ -23,7 +23,7 @@ type sigPubKey*: seq[byte] SegmentationHander* = object - maxMessageSize*: int + segmentSize*: int persistence*: SegmentationPersistence const @@ -31,11 +31,11 @@ const SegmentsReedsolomonMaxCount = 256 const - ErrMessageSegmentsIncomplete = "message segments incomplete" - ErrMessageSegmentsAlreadyCompleted = "message segments already completed" - ErrMessageSegmentsInvalidCount = "invalid segments count" - ErrMessageSegmentsHashMismatch = "hash of entire payload does not match" - ErrMessageSegmentsInvalidParity = "invalid parity segments" + ErrMessageSegmentsIncomplete* = "message segments incomplete" + ErrMessageSegmentsAlreadyCompleted* = "message segments already completed" + ErrMessageSegmentsInvalidCount* = "invalid segments count" + ErrMessageSegmentsHashMismatch* = "hash of entire payload does not match" + ErrMessageSegmentsInvalidParity* = "invalid parity segments" proc isValid(s: SegmentMessageProto): bool = # Check if the hash length is valid (32 bytes for Keccak256) @@ -57,22 +57,29 @@ proc isParityMessage*(s: SegmentMessage): bool = s.segmentsCount == 0 and s.paritySegmentsCount > 0 proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] = - let segmentSize = s.maxMessageSize div 4 * 3 - if newMessage.payload.len <= segmentSize: + if newMessage.payload.len <= s.segmentSize: return ok(@[newMessage]) + info "segmenting message", + "payloadSize" = newMessage.payload.len, + "segmentSize" = s.segmentSize + info "segment payload in", "saaa" = newMessage.payload.toHex let entireMessageHash = keccak256.digest(newMessage.payload) let entirePayloadSize = newMessage.payload.len - let segmentsCount = int(ceil(entirePayloadSize.float / segmentSize.float)) + let segmentsCount = int(ceil(entirePayloadSize.float / s.segmentSize.float)) let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate)) + info "parity count", paritySegmentsCount + info "segmentsCount", segmentsCount + info "entirePayloadSize", entirePayloadSize + info "segmentSize", "segmentSize" = s.segmentSize - var segmentPayloads = newSeq[seq[byte]](segmentsCount + paritySegmentsCount) + var segmentPayloads = newSeq[seq[byte]](segmentsCount) var segmentMessages = newSeq[Chunk](segmentsCount) for i in 0.. entirePayloadSize: endIndex = entirePayloadSize @@ -84,6 +91,8 @@ proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk payload: segmentPayload ) + info "entireMessageHash", "entireMessageHash" = entireMessageHash.data.toHex + let marshaledSegment = Protobuf.encode(segmentWithMetadata) let segmentMessage = Chunk(payload: marshaledSegment) @@ -91,40 +100,60 @@ proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk segmentMessages[i] = segmentMessage # Skip Reed-Solomon if parity segments are 0 or total exceeds max count + info "segments count", "len" = segmentMessages.len if paritySegmentsCount == 0 or (segmentsCount + paritySegmentsCount) > SegmentsReedsolomonMaxCount: return ok(segmentMessages) # Align last segment payload for Reed-Solomon (leopard requires fixed-size shards) let lastSegmentPayload = segmentPayloads[segmentsCount-1] - segmentPayloads[segmentsCount-1] = newSeq[byte](segmentSize) + segmentPayloads[segmentsCount-1] = newSeq[byte](s.segmentSize) segmentPayloads[segmentsCount-1][0.. 0: payloads[firstSegmentMessage.segmentsCount - 1] = lastNonParitySegmentPayload @@ -234,6 +287,8 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message) entirePayload.add(payloads[i]) # Sanity check + info "entirePayloadSize", "entirePayloadSize" = entirePayload.len + info "entire payload", "entirePayloadSize" = entirePayload.toHex let entirePayloadHash = keccak256.digest(entirePayload) if entirePayloadHash.data != segmentMessage.entireMessageHash: return err(ErrMessageSegmentsHashMismatch) @@ -247,7 +302,6 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message) proc cleanupSegments*(s: SegmentationHander): Result[void, string] = - # Same as previous translation discard proc demo() = diff --git a/migrations/002_create_message_segments_table.sql b/migrations/002_create_message_segments_table.sql index df75ee7..32adb31 100644 --- a/migrations/002_create_message_segments_table.sql +++ b/migrations/002_create_message_segments_table.sql @@ -1,11 +1,13 @@ -CREATE TABLE IF NOT EXISTS message_segments ( +CREATE TABLE message_segments ( hash BLOB NOT NULL, segment_index INTEGER NOT NULL, segments_count INTEGER NOT NULL, payload BLOB NOT NULL, sig_pub_key BLOB NOT NULL, - timestamp INTEGER DEFAULT 0, - PRIMARY KEY (hash, sig_pub_key, segment_index) ON CONFLICT REPLACE + timestamp INTEGER NOT NULL, + parity_segment_index INTEGER NOT NULL, + parity_segments_count INTEGER NOT NULL, + PRIMARY KEY (hash, sig_pub_key, segment_index, segments_count, parity_segment_index, parity_segments_count) ON CONFLICT REPLACE ); CREATE TABLE IF NOT EXISTS message_segments_completed ( @@ -16,4 +18,4 @@ CREATE TABLE IF NOT EXISTS message_segments_completed ( ); CREATE INDEX idx_message_segments_timestamp ON message_segments(timestamp); -CREATE INDEX idx_message_segments_completed_timestamp ON message_segments_completed(timestamp); \ No newline at end of file +CREATE INDEX idx_message_segments_completed_timestamp ON message_segments_completed(timestamp); diff --git a/tests/test_chat_sdk.nim b/tests/test_chat_sdk.nim index a588132..4c2a979 100644 --- a/tests/test_chat_sdk.nim +++ b/tests/test_chat_sdk.nim @@ -7,7 +7,7 @@ import unittest -import chat_sdk +# import chat_sdk import chat_sdk/segmentation test "can add": diff --git a/tests/test_segmentation.nim b/tests/test_segmentation.nim index fa1677f..e302d3b 100644 --- a/tests/test_segmentation.nim +++ b/tests/test_segmentation.nim @@ -5,49 +5,153 @@ # # To run these tests, simply execute `nimble test`. -import unittest, sequtils, random +import unittest, sequtils, random, math import results -import chat_sdk/segmentation # Replace with the actual file name +import db_connector/db_sqlite +import nimcrypto +import chat_sdk/segmentation +import chat_sdk/migration +import chat_sdk/db test "can add": check add2(5, 5) == 10 -suite "Message Segmentation": - let testPayload = toSeq(0..999).mapIt(byte(it)) - let testMessage = WakuNewMessage(payload: testPayload) - let sender = MessageSender( - messaging: Messaging(maxMessageSize: 500), - persistence: initPersistence() - ) +proc newInMemoryPersistence(): SegmentationPersistence = + let conn = open(":memory:", "", "", "") + # Define the tables (same schema as expected in your app) + runMigrations(conn) + result = SegmentationPersistence(db: conn) - test "Segment and reassemble full segments": - let segmentResult = sender.segmentMessage(testMessage) - check segmentResult.isOk - var segments = segmentResult.get() - check segments.len > 0 +suite "message Segmentation": + var + sender: SegmentationHander + testPayload: seq[byte] + mockPersistence: SegmentationPersistence - # Shuffle segment order for out-of-order test - segments.shuffle() + setup: + # Initialize test payload (1000 bytes, each byte is its index) + testPayload = newSeq[byte](1024) + for i in 0..<1024: + testPayload[i] = rand(255).byte - for segment in segments: - var statusMsg: StatusMessage - statusMsg.transportLayer = TransportLayer( - hash: @[byte 0, 1, 2], # Dummy hash - payload: segment.payload, - sigPubKey: @[byte 3, 4, 5] - ) - let result = sender.handleSegmentationLayerV2(statusMsg) - discard result # Ignore intermediate errors + # Setup mock persistence + mockPersistence = newInMemoryPersistence() - # One last run to trigger completion - var finalStatus: StatusMessage - finalStatus.transportLayer = TransportLayer( - hash: @[byte 0, 1, 2], - payload: segments[0].payload, - sigPubKey: @[byte 3, 4, 5] + # Setup sender + sender = SegmentationHander( + segmentSize: 4000, # Arbitrary size to allow segmentation + persistence: mockPersistence ) - let finalResult = sender.handleSegmentationLayerV2(finalStatus) - check finalResult.isOk + + test "HandleSegmentationLayer": + # Define test cases (mirroring Go test cases) + let testCases = @[ + ( + name: "all segments retrieved", + segmentsCount: 2, + expectedParitySegmentsCount: 0, + retrievedSegments: @[0, 1], + retrievedParitySegments: newSeq[int](), + shouldSucceed: true + ), + ( + name: "all segments retrieved out of order", + segmentsCount: 2, + expectedParitySegmentsCount: 0, + retrievedSegments: @[1, 0], + retrievedParitySegments: newSeq[int](), + shouldSucceed: true + ), + ( + name: "all segments&parity retrieved", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: @[0, 1, 2, 3, 4, 5, 6, 7, 8], + retrievedParitySegments: @[8], + shouldSucceed: true + ), + ( + name: "all segments&parity retrieved out of order", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: @[8, 0, 7, 1, 6, 2, 5, 3, 4], + retrievedParitySegments: @[8], + shouldSucceed: true + ), + ( + name: "no segments retrieved", + segmentsCount: 2, + expectedParitySegmentsCount: 0, + retrievedSegments: @[], + retrievedParitySegments: newSeq[int](), + shouldSucceed: false + ), + ( + name: "not all needed segments&parity retrieved", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: @[1, 2, 8], + retrievedParitySegments: @[8], + shouldSucceed: false + ), + ( + name: "segments&parity retrieved", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: @[1, 2, 3, 4, 5, 6, 7, 8], + retrievedParitySegments: @[8], + shouldSucceed: true + ), + ( + name: "segments&parity retrieved out of order", + segmentsCount: 16, + expectedParitySegmentsCount: 2, + retrievedSegments: @[17, 0, 16, 1, 15, 2, 14, 3, 13, 4, 12, 5, 11, 6, 10, 7], + retrievedParitySegments: @[16, 17], + shouldSucceed: true + ) + ] - # Check payload restored - check finalStatus.transportLayer.payload == testPayload[0..