diff --git a/.gitignore b/.gitignore
index 1843ca9..9a9c03f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,12 +20,12 @@ nimcache/
 
 # Compiled files
 chat_sdk/*
-!*.nim
 apps/*
 !*.nim
 tests/*
 !*.nim
 ratelimit/*
 !*.nim
+!*.proto
 nimble.develop
 nimble.paths
diff --git a/chat_sdk.nimble b/chat_sdk.nimble
index 8841ff3..b1165f3 100644
--- a/chat_sdk.nimble
+++ b/chat_sdk.nimble
@@ -17,3 +17,6 @@ task buildStaticLib, "Build static library for C bindings":
 
 task migrate, "Run database migrations":
   exec "nim c -r apps/run_migration.nim"
+
+task segment, "Run segmentation":
+  exec "nim c -r chat_sdk/segmentation.nim"
diff --git a/chat_sdk/db.nim b/chat_sdk/db.nim
new file mode 100644
index 0000000..7b46906
--- /dev/null
+++ b/chat_sdk/db.nim
@@ -0,0 +1,137 @@
+import std/strutils
+import results
+import db_connector/db_sqlite
+import db_models
+import nimcrypto
+
+type SegmentationPersistence* = object
+  ## Wraps the database connection used to persist message segments.
+  db*: DbConn
+
+proc removeMessageSegmentsOlderThan*(
+    self: SegmentationPersistence, timestamp: int64
+): Result[void, string] =
+  ## Deletes every row of message_segments whose timestamp is below `timestamp`.
+  try:
+    self.db.exec(sql"DELETE FROM message_segments WHERE timestamp < ?", timestamp)
+    ok()
+  except DbError as e:
+    err("remove message segments with error: " & e.msg)
+
+proc removeMessageSegmentsCompletedOlderThan*(
+    self: SegmentationPersistence, timestamp: int64
+): Result[void, string] =
+  ## Deletes every completion marker whose timestamp is below `timestamp`.
+  try:
+    self.db.exec(
+      sql"DELETE FROM message_segments_completed WHERE timestamp < ?", timestamp
+    )
+    ok()
+  except DbError as e:
+    err("remove message segments completed with error: " & e.msg)
+
+proc isMessageAlreadyCompleted*(
+    self: SegmentationPersistence, hash: seq[byte]
+): Result[bool, string] =
+  ## Returns true when a completion marker exists for `hash`.
+  try:
+    # Hex-encode the hash, matching what completeMessageSegments stores.
+    let row = self.db.getRow(
+      sql"SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?",
+      hash.toHex,
+    )
+    if row.len == 0:
+      return ok(false)
+    let count = row[0].parseInt
+    ok(count > 0)
+  except CatchableError as e:
+    err("check message already completed with error: " & e.msg)
+
+proc saveMessageSegment*(
+    self: SegmentationPersistence,
+    segment: SegmentMessage,
+    sigPubKeyBlob: seq[byte],
+    timestamp: int64,
+): Result[void, string] =
+  ## Inserts one segment row; the hash and payload are stored hex-encoded.
+  try:
+    self.db.exec(
+      sql"""
+      INSERT INTO message_segments (
+        hash, segment_index, segments_count, parity_segment_index,
+        parity_segments_count, sig_pub_key, payload, timestamp
+      ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
+      segment.entireMessageHash.toHex,
+      segment.index,
+      segment.segmentsCount,
+      segment.paritySegmentIndex,
+      segment.paritySegmentsCount,
+      sigPubKeyBlob,
+      segment.payload.toHex,
+      timestamp,
+    )
+    ok()
+  except DbError as e:
+    err("save message segments with error: " & e.msg)
+
+proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] =
+  ## Loads the rows stored for `hash` and `sigPubKeyBlob`. Rows with a non-zero
+  ## segments_count come first; ties sort by segment_index, then parity index.
+  var segments = newSeq[SegmentMessage]()
+  let query = sql"""
+    SELECT
+      hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload
+    FROM
+      message_segments
+    WHERE
+      hash = ? AND sig_pub_key = ?
+    ORDER BY
+      (segments_count = 0) ASC,
+      segment_index ASC,
+      parity_segment_index ASC
+  """
+
+  try:
+    for row in self.db.rows(query, hash.toHex, sigPubKeyBlob):
+      let segment = SegmentMessage(
+        entireMessageHash: nimcrypto.fromHex(row[0]),
+        index: uint32(parseInt(row[1])),
+        segmentsCount: uint32(parseInt(row[2])),
+        paritySegmentIndex: uint32(parseInt(row[3])),
+        paritySegmentsCount: uint32(parseInt(row[4])),
+        payload: nimcrypto.fromHex(row[5])
+      )
+      segments.add(segment)
+    return ok(segments)
+  except CatchableError as e:
+    return err("get Message Segments with error: " & e.msg)
+
+proc completeMessageSegments*(
+    self: SegmentationPersistence,
+    hash: seq[byte],
+    sigPubKeyBlob: seq[byte],
+    timestamp: int64
+): Result[void, string] =
+  ## Inside one transaction, deletes the stored segments of a message and
+  ## inserts a completion marker for it.
+  # saveMessageSegment stores the hash hex-encoded, so the same encoding must
+  # be bound here; binding the raw seq[byte] would make the DELETE match nothing.
+  let hashHex = hash.toHex
+  try:
+    self.db.exec(sql"BEGIN")
+
+    # Delete old message segments
+    self.db.exec(sql"DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hashHex, sigPubKeyBlob)
+
+    # Insert completed marker
+    self.db.exec(sql"INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?, ?, ?)",
+      hashHex, sigPubKeyBlob, timestamp)
+
+    self.db.exec(sql"COMMIT")
+    return ok()
+  except DbError as e:
+    try:
+      self.db.exec(sql"ROLLBACK")
+    except CatchableError:
+      discard # best-effort rollback; the original error is reported below
+    return err("complete segment messages with error: " & e.msg)
diff --git a/chat_sdk/db_models.nim b/chat_sdk/db_models.nim
new file mode 100644
index 0000000..392760b
--- /dev/null
+++ b/chat_sdk/db_models.nim
@@ -0,0 +1,8 @@
+type
+  SegmentMessage* = ref object
+    entireMessageHash*: seq[byte]
+    index*: uint32
+    segmentsCount*: uint32
+    paritySegmentIndex*: uint32
+    paritySegmentsCount*: uint32
+    payload*: seq[byte]
diff --git a/chat_sdk/migration.nim b/chat_sdk/migration.nim
index a23f740..2c91f21 100644
--- a/chat_sdk/migration.nim
+++ b/chat_sdk/migration.nim
@@ -31,16 +31,16 @@ proc runMigrations*(db: DbConn, dir = "migrations") =
       info "Migration already applied", file
     else:
       info "Applying migration", file
-      let sqlContent = readFile(file)
+      let script = readFile(file)
       db.exec(sql"BEGIN TRANSACTION")
       try:
-        # Split by semicolon and execute each statement separately
-        for stmt in sqlContent.split(';'):
-          let trimmedStmt = stmt.strip()
-          if trimmedStmt.len > 0:
-            db.exec(sql(trimmedStmt))
+        for stmt in script.split(";"):
+          let trimmed = stmt.strip()
+          if trimmed.len > 0:
+            db.exec(sql(trimmed))
+
         markMigrationRun(db, file)
         db.exec(sql"COMMIT")
       except:
         db.exec(sql"ROLLBACK")
-        raise newException(ValueError, "Migration failed: " & file & " - " & getCurrentExceptionMsg())
+        raise
diff --git a/chat_sdk/segment_message.proto b/chat_sdk/segment_message.proto
new file mode 100644
index 0000000..1b8c556
--- /dev/null
+++ b/chat_sdk/segment_message.proto
@@ -0,0 +1,16 @@
+syntax = "proto3";
+
+message SegmentMessageProto {
+  // hash of the entire original message
+  bytes entire_message_hash = 1;
+  // Index of this segment within the entire original message
+  uint32 index = 2;
+  // Total number of segments the entire original message is divided into
+  uint32 segments_count = 3;
+  // The payload data for this particular segment
+  bytes payload = 4;
+  // Index of this parity segment
+  uint32 parity_segment_index = 5;
+  // Total number of parity segments
+  uint32 parity_segments_count = 6;
+}
diff --git a/chat_sdk/segmentation.nim b/chat_sdk/segmentation.nim
new file mode 100644
index 0000000..85d20ed
--- /dev/null
+++ b/chat_sdk/segmentation.nim
@@ -0,0 +1,256 @@
+import math, times, sequtils, strutils, options
+import nimcrypto
+import results
+import leopard
+import chronicles
+
+import protobuf_serialization/proto_parser
+import protobuf_serialization
+
+import db_models
+import db
+
+import_proto3 "segment_message.proto"
+
+type
+  Chunk* = object
+    payload*: seq[byte]
+
+  Message* = object
+    hash*: seq[byte]
+    payload*: seq[byte]
+    sigPubKey*: seq[byte]
+
+  SegmentationHander* = object
+    segmentSize*: int
+    persistence*: SegmentationPersistence
+
+const
+  SegmentsParityRate = 0.125
+  SegmentsReedsolomonMaxCount = 256
+
+const + ErrMessageSegmentsIncomplete* = "message segments incomplete" + ErrMessageSegmentsAlreadyCompleted* = "message segments already completed" + ErrMessageSegmentsInvalidCount* = "invalid segments count" + ErrMessageSegmentsHashMismatch* = "hash of entire payload does not match" + ErrMessageSegmentsInvalidParity* = "invalid parity segments" + +proc isValid(s: SegmentMessageProto): bool = + # Check if the hash length is valid (32 bytes for Keccak256) + if s.entire_message_hash.len != 32: + return false + + # Check if the segment index is within the valid range + if s.segments_count > 0 and s.index >= s.segments_count: + return false + + # Check if the parity segment index is within the valid range + if s.parity_segments_count > 0 and s.parity_segment_index >= s.parity_segments_count: + return false + + # Check if segments_count is at least 2 or parity_segments_count is positive + return s.segments_count >= 2 or s.parity_segments_count > 0 + +proc isParityMessage*(s: SegmentMessage): bool = + s.segmentsCount == 0 and s.paritySegmentsCount > 0 + +proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] = + if newMessage.payload.len <= s.segmentSize: + return ok(@[newMessage]) + + info "segmenting message", + "payloadSize" = newMessage.payload.len, + "segmentSize" = s.segmentSize + let entireMessageHash = keccak256.digest(newMessage.payload) + let entirePayloadSize = newMessage.payload.len + + let segmentsCount = int(ceil(entirePayloadSize.float / s.segmentSize.float)) + let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate)) + + var segmentPayloads = newSeq[seq[byte]](segmentsCount) + var segmentMessages = newSeq[Chunk](segmentsCount) + + for i in 0.. entirePayloadSize: + endIndex = entirePayloadSize + + let segmentPayload = newMessage.payload[start.. 
SegmentsReedsolomonMaxCount: + return ok(segmentMessages) + + # Align last segment payload for Reed-Solomon (leopard requires fixed-size shards) + let lastSegmentPayload = segmentPayloads[segmentsCount-1] + segmentPayloads[segmentsCount-1] = newSeq[byte](s.segmentSize) + segmentPayloads[segmentsCount-1][0.. 0: + payloads[firstSegmentMessage.segmentsCount - 1] = lastNonParitySegmentPayload + + # Combine payload + var entirePayload = newSeq[byte]() + for i in 0..