mirror of
https://github.com/logos-messaging/nim-chat-sdk.git
synced 2026-01-04 07:03:09 +00:00
Merge remote-tracking branch 'origin/main' into feat/rate-limit-store-state
This commit is contained in:
commit
0e0f357f70
2
.gitignore
vendored
2
.gitignore
vendored
@ -20,12 +20,12 @@ nimcache/
|
|||||||
|
|
||||||
# Compiled files
|
# Compiled files
|
||||||
chat_sdk/*
|
chat_sdk/*
|
||||||
!*.nim
|
|
||||||
apps/*
|
apps/*
|
||||||
!*.nim
|
!*.nim
|
||||||
tests/*
|
tests/*
|
||||||
!*.nim
|
!*.nim
|
||||||
ratelimit/*
|
ratelimit/*
|
||||||
!*.nim
|
!*.nim
|
||||||
|
!*.proto
|
||||||
nimble.develop
|
nimble.develop
|
||||||
nimble.paths
|
nimble.paths
|
||||||
|
|||||||
@ -17,3 +17,6 @@ task buildStaticLib, "Build static library for C bindings":
|
|||||||
|
|
||||||
task migrate, "Run database migrations":
|
task migrate, "Run database migrations":
|
||||||
exec "nim c -r apps/run_migration.nim"
|
exec "nim c -r apps/run_migration.nim"
|
||||||
|
|
||||||
|
task segment, "Run segmentation":
|
||||||
|
exec "nim c -r chat_sdk/segmentation.nim"
|
||||||
|
|||||||
129
chat_sdk/db.nim
Normal file
129
chat_sdk/db.nim
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
import std/strutils
|
||||||
|
import results
|
||||||
|
import db_connector/db_sqlite
|
||||||
|
import db_models
|
||||||
|
import nimcrypto
|
||||||
|
|
||||||
|
type SegmentationPersistence* = object
|
||||||
|
db*: DbConn
|
||||||
|
|
||||||
|
proc removeMessageSegmentsOlderThan*(
|
||||||
|
self: SegmentationPersistence, timestamp: int64
|
||||||
|
): Result[void, string] =
|
||||||
|
try:
|
||||||
|
self.db.exec(sql"DELETE FROM message_segments WHERE timestamp < ?", timestamp)
|
||||||
|
ok()
|
||||||
|
except DbError as e:
|
||||||
|
err("remove message segments with error: " & e.msg)
|
||||||
|
|
||||||
|
proc removeMessageSegmentsCompletedOlderThan*(
|
||||||
|
self: SegmentationPersistence, timestamp: int64
|
||||||
|
): Result[void, string] =
|
||||||
|
try:
|
||||||
|
self.db.exec(
|
||||||
|
sql"DELETE FROM message_segments_completed WHERE timestamp < ?", timestamp
|
||||||
|
)
|
||||||
|
ok()
|
||||||
|
except DbError as e:
|
||||||
|
err("remove message segments completed with error: " & e.msg)
|
||||||
|
|
||||||
|
proc isMessageAlreadyCompleted*(
|
||||||
|
self: SegmentationPersistence, hash: seq[byte]
|
||||||
|
): Result[bool, string] =
|
||||||
|
try:
|
||||||
|
let row = self.db.getRow(
|
||||||
|
sql"SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash
|
||||||
|
)
|
||||||
|
if row.len == 0:
|
||||||
|
return ok(false)
|
||||||
|
let count = row[0].parseInt
|
||||||
|
ok(count > 0)
|
||||||
|
except CatchableError as e:
|
||||||
|
err("check message already completed with error: " & e.msg)
|
||||||
|
|
||||||
|
proc saveMessageSegment*(
|
||||||
|
self: SegmentationPersistence,
|
||||||
|
segment: SegmentMessage,
|
||||||
|
sigPubKeyBlob: seq[byte],
|
||||||
|
timestamp: int64,
|
||||||
|
): Result[void, string] =
|
||||||
|
try:
|
||||||
|
self.db.exec(
|
||||||
|
sql"""
|
||||||
|
INSERT INTO message_segments (
|
||||||
|
hash, segment_index, segments_count, parity_segment_index,
|
||||||
|
parity_segments_count, sig_pub_key, payload, timestamp
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||||
|
segment.entireMessageHash.toHex,
|
||||||
|
segment.index,
|
||||||
|
segment.segmentsCount,
|
||||||
|
segment.paritySegmentIndex,
|
||||||
|
segment.paritySegmentsCount,
|
||||||
|
sigPubKeyBlob,
|
||||||
|
segment.payload.toHex,
|
||||||
|
timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
let storedPayload = self.db.getValue(
|
||||||
|
sql"SELECT payload FROM message_segments WHERE hash = ? AND sig_pub_key = ? AND segment_index = ?",
|
||||||
|
segment.entireMessageHash.toHex, sigPubKeyBlob, segment.index
|
||||||
|
)
|
||||||
|
let payloadBytes = nimcrypto.fromHex(storedPayload)
|
||||||
|
|
||||||
|
ok()
|
||||||
|
except DbError as e:
|
||||||
|
err("save message segments with error: " & e.msg)
|
||||||
|
|
||||||
|
proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] =
|
||||||
|
var segments = newSeq[SegmentMessage]()
|
||||||
|
let query = sql"""
|
||||||
|
SELECT
|
||||||
|
hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload
|
||||||
|
FROM
|
||||||
|
message_segments
|
||||||
|
WHERE
|
||||||
|
hash = ? AND sig_pub_key = ?
|
||||||
|
ORDER BY
|
||||||
|
(segments_count = 0) ASC,
|
||||||
|
segment_index ASC,
|
||||||
|
parity_segment_index ASC
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
for row in self.db.rows(query, hash.toHex, sigPubKeyBlob):
|
||||||
|
let segment = SegmentMessage(
|
||||||
|
entireMessageHash: nimcrypto.fromHex(row[0]),
|
||||||
|
index: uint32(parseInt(row[1])),
|
||||||
|
segmentsCount: uint32(parseInt(row[2])),
|
||||||
|
paritySegmentIndex: uint32(parseInt(row[3])),
|
||||||
|
paritySegmentsCount: uint32(parseInt(row[4])),
|
||||||
|
payload: nimcrypto.fromHex(row[5])
|
||||||
|
)
|
||||||
|
segments.add(segment)
|
||||||
|
return ok(segments)
|
||||||
|
except CatchableError as e:
|
||||||
|
return err("get Message Segments with error: " & e.msg)
|
||||||
|
|
||||||
|
proc completeMessageSegments*(
|
||||||
|
self: SegmentationPersistence,
|
||||||
|
hash: seq[byte],
|
||||||
|
sigPubKeyBlob: seq[byte],
|
||||||
|
timestamp: int64
|
||||||
|
): Result[void, string] =
|
||||||
|
try:
|
||||||
|
self.db.exec(sql"BEGIN")
|
||||||
|
|
||||||
|
# Delete old message segments
|
||||||
|
self.db.exec(sql"DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hash, sigPubKeyBlob)
|
||||||
|
|
||||||
|
# Insert completed marker
|
||||||
|
self.db.exec(sql"INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?, ?, ?)",
|
||||||
|
hash, sigPubKeyBlob, timestamp)
|
||||||
|
|
||||||
|
self.db.exec(sql"COMMIT")
|
||||||
|
return ok()
|
||||||
|
except DbError as e:
|
||||||
|
try:
|
||||||
|
self.db.exec(sql"ROLLBACK")
|
||||||
|
except: discard
|
||||||
|
return err("complete segment messages with error: " & e.msg)
|
||||||
8
chat_sdk/db_models.nim
Normal file
8
chat_sdk/db_models.nim
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
type
|
||||||
|
SegmentMessage* = ref object
|
||||||
|
entireMessageHash*: seq[byte]
|
||||||
|
index*: uint32
|
||||||
|
segmentsCount*: uint32
|
||||||
|
paritySegmentIndex*: uint32
|
||||||
|
paritySegmentsCount*: uint32
|
||||||
|
payload*: seq[byte]
|
||||||
@ -31,16 +31,16 @@ proc runMigrations*(db: DbConn, dir = "migrations") =
|
|||||||
info "Migration already applied", file
|
info "Migration already applied", file
|
||||||
else:
|
else:
|
||||||
info "Applying migration", file
|
info "Applying migration", file
|
||||||
let sqlContent = readFile(file)
|
let script = readFile(file)
|
||||||
db.exec(sql"BEGIN TRANSACTION")
|
db.exec(sql"BEGIN TRANSACTION")
|
||||||
try:
|
try:
|
||||||
# Split by semicolon and execute each statement separately
|
for stmt in script.split(";"):
|
||||||
for stmt in sqlContent.split(';'):
|
let trimmed = stmt.strip()
|
||||||
let trimmedStmt = stmt.strip()
|
if trimmed.len > 0:
|
||||||
if trimmedStmt.len > 0:
|
db.exec(sql(trimmed))
|
||||||
db.exec(sql(trimmedStmt))
|
|
||||||
markMigrationRun(db, file)
|
markMigrationRun(db, file)
|
||||||
db.exec(sql"COMMIT")
|
db.exec(sql"COMMIT")
|
||||||
except:
|
except:
|
||||||
db.exec(sql"ROLLBACK")
|
db.exec(sql"ROLLBACK")
|
||||||
raise newException(ValueError, "Migration failed: " & file & " - " & getCurrentExceptionMsg())
|
raise
|
||||||
|
|||||||
16
chat_sdk/segment_message.proto
Normal file
16
chat_sdk/segment_message.proto
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message SegmentMessageProto {
|
||||||
|
// hash of the entire original message
|
||||||
|
bytes entire_message_hash = 1;
|
||||||
|
// Index of this segment within the entire original message
|
||||||
|
uint32 index = 2;
|
||||||
|
// Total number of segments the entire original message is divided into
|
||||||
|
uint32 segments_count = 3;
|
||||||
|
// The payload data for this particular segment
|
||||||
|
bytes payload = 4;
|
||||||
|
// Index of this parity segment
|
||||||
|
uint32 parity_segment_index = 5;
|
||||||
|
// Total number of parity segments
|
||||||
|
uint32 parity_segments_count = 6;
|
||||||
|
}
|
||||||
256
chat_sdk/segmentation.nim
Normal file
256
chat_sdk/segmentation.nim
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
import math, times, sequtils, strutils, options
|
||||||
|
import nimcrypto
|
||||||
|
import results
|
||||||
|
import leopard
|
||||||
|
import chronicles
|
||||||
|
|
||||||
|
import protobuf_serialization/proto_parser
|
||||||
|
import protobuf_serialization
|
||||||
|
|
||||||
|
import db_models
|
||||||
|
import db
|
||||||
|
|
||||||
|
import_proto3 "segment_message.proto"
|
||||||
|
|
||||||
|
type
|
||||||
|
Chunk* = object
|
||||||
|
payload*: seq[byte]
|
||||||
|
|
||||||
|
Message* = object
|
||||||
|
hash*: seq[byte]
|
||||||
|
payload*: seq[byte]
|
||||||
|
sigPubKey*: seq[byte]
|
||||||
|
|
||||||
|
SegmentationHander* = object
|
||||||
|
segmentSize*: int
|
||||||
|
persistence*: SegmentationPersistence
|
||||||
|
|
||||||
|
const
|
||||||
|
SegmentsParityRate = 0.125
|
||||||
|
SegmentsReedsolomonMaxCount = 256
|
||||||
|
|
||||||
|
const
|
||||||
|
ErrMessageSegmentsIncomplete* = "message segments incomplete"
|
||||||
|
ErrMessageSegmentsAlreadyCompleted* = "message segments already completed"
|
||||||
|
ErrMessageSegmentsInvalidCount* = "invalid segments count"
|
||||||
|
ErrMessageSegmentsHashMismatch* = "hash of entire payload does not match"
|
||||||
|
ErrMessageSegmentsInvalidParity* = "invalid parity segments"
|
||||||
|
|
||||||
|
proc isValid(s: SegmentMessageProto): bool =
|
||||||
|
# Check if the hash length is valid (32 bytes for Keccak256)
|
||||||
|
if s.entire_message_hash.len != 32:
|
||||||
|
return false
|
||||||
|
|
||||||
|
# Check if the segment index is within the valid range
|
||||||
|
if s.segments_count > 0 and s.index >= s.segments_count:
|
||||||
|
return false
|
||||||
|
|
||||||
|
# Check if the parity segment index is within the valid range
|
||||||
|
if s.parity_segments_count > 0 and s.parity_segment_index >= s.parity_segments_count:
|
||||||
|
return false
|
||||||
|
|
||||||
|
# Check if segments_count is at least 2 or parity_segments_count is positive
|
||||||
|
return s.segments_count >= 2 or s.parity_segments_count > 0
|
||||||
|
|
||||||
|
proc isParityMessage*(s: SegmentMessage): bool =
|
||||||
|
s.segmentsCount == 0 and s.paritySegmentsCount > 0
|
||||||
|
|
||||||
|
proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] =
|
||||||
|
if newMessage.payload.len <= s.segmentSize:
|
||||||
|
return ok(@[newMessage])
|
||||||
|
|
||||||
|
info "segmenting message",
|
||||||
|
"payloadSize" = newMessage.payload.len,
|
||||||
|
"segmentSize" = s.segmentSize
|
||||||
|
let entireMessageHash = keccak256.digest(newMessage.payload)
|
||||||
|
let entirePayloadSize = newMessage.payload.len
|
||||||
|
|
||||||
|
let segmentsCount = int(ceil(entirePayloadSize.float / s.segmentSize.float))
|
||||||
|
let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate))
|
||||||
|
|
||||||
|
var segmentPayloads = newSeq[seq[byte]](segmentsCount)
|
||||||
|
var segmentMessages = newSeq[Chunk](segmentsCount)
|
||||||
|
|
||||||
|
for i in 0..<segmentsCount:
|
||||||
|
let start = i * s.segmentSize
|
||||||
|
var endIndex = start + s.segmentSize
|
||||||
|
if endIndex > entirePayloadSize:
|
||||||
|
endIndex = entirePayloadSize
|
||||||
|
|
||||||
|
let segmentPayload = newMessage.payload[start..<endIndex]
|
||||||
|
let segmentWithMetadata = SegmentMessageProto(
|
||||||
|
entireMessageHash: entireMessageHash.data.toSeq,
|
||||||
|
index: uint32(i),
|
||||||
|
segmentsCount: uint32(segmentsCount),
|
||||||
|
payload: segmentPayload
|
||||||
|
)
|
||||||
|
|
||||||
|
let marshaledSegment = Protobuf.encode(segmentWithMetadata)
|
||||||
|
let segmentMessage = Chunk(payload: marshaledSegment)
|
||||||
|
|
||||||
|
segmentPayloads[i] = segmentPayload
|
||||||
|
segmentMessages[i] = segmentMessage
|
||||||
|
|
||||||
|
# Skip Reed-Solomon if parity segments are 0 or total exceeds max count
|
||||||
|
info "segments count", "len" = segmentMessages.len
|
||||||
|
if paritySegmentsCount == 0 or (segmentsCount + paritySegmentsCount) > SegmentsReedsolomonMaxCount:
|
||||||
|
return ok(segmentMessages)
|
||||||
|
|
||||||
|
# Align last segment payload for Reed-Solomon (leopard requires fixed-size shards)
|
||||||
|
let lastSegmentPayload = segmentPayloads[segmentsCount-1]
|
||||||
|
segmentPayloads[segmentsCount-1] = newSeq[byte](s.segmentSize)
|
||||||
|
segmentPayloads[segmentsCount-1][0..<lastSegmentPayload.len] = lastSegmentPayload
|
||||||
|
|
||||||
|
# Use nim-leopard for Reed-Solomon encoding
|
||||||
|
var parity = newSeq[seq[byte]](paritySegmentsCount)
|
||||||
|
for i in 0..<paritySegmentsCount:
|
||||||
|
newSeq(parity[i], s.segmentSize)
|
||||||
|
|
||||||
|
var encoderRes = LeoEncoder.init(s.segmentSize, segmentsCount, paritySegmentsCount)
|
||||||
|
if encoderRes.isErr:
|
||||||
|
return err("failed to initialize encoder: " & $encoderRes.error)
|
||||||
|
|
||||||
|
var encoder = encoderRes.get
|
||||||
|
let encodeResult = encoder.encode(segmentPayloads, parity)
|
||||||
|
if encodeResult.isErr:
|
||||||
|
return err("failed to encode segments with leopard: " & $encodeResult.error)
|
||||||
|
|
||||||
|
# Create parity messages
|
||||||
|
for i in 0..<paritySegmentsCount:
|
||||||
|
let segmentWithMetadata = SegmentMessageProto(
|
||||||
|
entireMessageHash: entireMessageHash.data.toSeq,
|
||||||
|
segmentsCount: 0,
|
||||||
|
paritySegmentIndex: uint32(i),
|
||||||
|
paritySegmentsCount: uint32(paritySegmentsCount),
|
||||||
|
payload: parity[i]
|
||||||
|
)
|
||||||
|
|
||||||
|
let marshaledSegment = Protobuf.encode(segmentWithMetadata)
|
||||||
|
let segmentMessage = Chunk(payload: marshaledSegment)
|
||||||
|
|
||||||
|
segmentMessages.add(segmentMessage)
|
||||||
|
|
||||||
|
return ok(segmentMessages)
|
||||||
|
|
||||||
|
proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message): Result[void, string] =
|
||||||
|
logScope:
|
||||||
|
site = "handleSegmentationLayer"
|
||||||
|
hash = message.hash.toHex
|
||||||
|
|
||||||
|
let segmentMessageProto = Protobuf.decode(message.payload, SegmentMessageProto)
|
||||||
|
|
||||||
|
debug "handling message segment",
|
||||||
|
"EntireMessageHash" = segmentMessageProto.entireMessageHash.toHex,
|
||||||
|
"Index" = $segmentMessageProto.index,
|
||||||
|
"SegmentsCount" = $segmentMessageProto.segmentsCount,
|
||||||
|
"ParitySegmentIndex" = $segmentMessageProto.paritySegmentIndex,
|
||||||
|
"ParitySegmentsCount" = $segmentMessageProto.paritySegmentsCount
|
||||||
|
|
||||||
|
let alreadyCompleted = handler.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash)
|
||||||
|
if alreadyCompleted.isErr:
|
||||||
|
return err(alreadyCompleted.error)
|
||||||
|
if alreadyCompleted.get():
|
||||||
|
return err(ErrMessageSegmentsAlreadyCompleted)
|
||||||
|
|
||||||
|
if not segmentMessageProto.isValid():
|
||||||
|
return err(ErrMessageSegmentsInvalidCount)
|
||||||
|
|
||||||
|
let segmentMessage = SegmentMessage(
|
||||||
|
entireMessageHash: segmentMessageProto.entireMessageHash,
|
||||||
|
index: segmentMessageProto.index,
|
||||||
|
segmentsCount: segmentMessageProto.segmentsCount,
|
||||||
|
paritySegmentIndex: segmentMessageProto.paritySegmentIndex,
|
||||||
|
paritySegmentsCount: segmentMessageProto.paritySegmentsCount,
|
||||||
|
payload: segmentMessageProto.payload
|
||||||
|
)
|
||||||
|
let saveResult = handler.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix)
|
||||||
|
if saveResult.isErr:
|
||||||
|
return err(saveResult.error)
|
||||||
|
|
||||||
|
let segments = handler.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey)
|
||||||
|
if segments.isErr:
|
||||||
|
return err(segments.error)
|
||||||
|
|
||||||
|
if segments.get().len == 0:
|
||||||
|
return err("unexpected state: no segments found after save operation")
|
||||||
|
|
||||||
|
let firstSegmentMessage = segments.get()[0]
|
||||||
|
let lastSegmentMessage = segments.get()[^1]
|
||||||
|
|
||||||
|
debug "first segment",
|
||||||
|
"index" = firstSegmentMessage.index,
|
||||||
|
"segmentsCount" = firstSegmentMessage.segmentsCount,
|
||||||
|
"paritySegmentIndex" = firstSegmentMessage.paritySegmentIndex,
|
||||||
|
"paritySegmentsCount" = firstSegmentMessage.paritySegmentsCount,
|
||||||
|
"len" = firstSegmentMessage.payload.len
|
||||||
|
|
||||||
|
debug "last segment",
|
||||||
|
"index" = lastSegmentMessage.index,
|
||||||
|
"segmentsCount" = lastSegmentMessage.segmentsCount,
|
||||||
|
"paritySegmentIndex" = lastSegmentMessage.paritySegmentIndex,
|
||||||
|
"len" = firstSegmentMessage.payload.len,
|
||||||
|
"paritySegmentsCount" = lastSegmentMessage.paritySegmentsCount
|
||||||
|
|
||||||
|
if firstSegmentMessage.isParityMessage() or segments.get().len != int(firstSegmentMessage.segmentsCount):
|
||||||
|
return err(ErrMessageSegmentsIncomplete)
|
||||||
|
|
||||||
|
var payloads = newSeq[seq[byte]](firstSegmentMessage.segmentsCount)
|
||||||
|
var parity = newSeq[seq[byte]](lastSegmentMessage.paritySegmentsCount)
|
||||||
|
let payloadSize = firstSegmentMessage.payload.len
|
||||||
|
|
||||||
|
let restoreUsingParityData = lastSegmentMessage.isParityMessage()
|
||||||
|
if not restoreUsingParityData:
|
||||||
|
for i, segment in segments.get():
|
||||||
|
payloads[i] = segment.payload
|
||||||
|
else:
|
||||||
|
var lastNonParitySegmentPayload: seq[byte]
|
||||||
|
for segment in segments.get():
|
||||||
|
if not segment.isParityMessage():
|
||||||
|
if segment.index == firstSegmentMessage.segmentsCount - 1:
|
||||||
|
payloads[segment.index] = newSeq[byte](payloadSize)
|
||||||
|
payloads[segment.index][0..<segment.payload.len] = segment.payload
|
||||||
|
lastNonParitySegmentPayload = segment.payload
|
||||||
|
else:
|
||||||
|
payloads[segment.index] = segment.payload
|
||||||
|
else:
|
||||||
|
parity[segment.paritySegmentIndex] = segment.payload
|
||||||
|
|
||||||
|
# Use nim-leopard for Reed-Solomon reconstruction
|
||||||
|
let decoderRes = LeoDecoder.init(payloadSize, int(firstSegmentMessage.segmentsCount), int(lastSegmentMessage.paritySegmentsCount))
|
||||||
|
if decoderRes.isErr:
|
||||||
|
return err("failed to initialize LeoDecoder: " & $decoderRes.error)
|
||||||
|
var decoder = decoderRes.get()
|
||||||
|
var recovered = newSeq[seq[byte]](int(firstSegmentMessage.segmentsCount)) # Allocate for recovered shards
|
||||||
|
for i in 0..<int(firstSegmentMessage.segmentsCount):
|
||||||
|
recovered[i] = newSeq[byte](payloadSize)
|
||||||
|
let reconstructResult = decoder.decode(payloads, parity, recovered)
|
||||||
|
if reconstructResult.isErr:
|
||||||
|
return err("failed to reconstruct payloads with leopard: " & $reconstructResult.error)
|
||||||
|
|
||||||
|
for i in 0..<firstSegmentMessage.segmentsCount:
|
||||||
|
if payloads[i].len == 0:
|
||||||
|
payloads[i] = recovered[i]
|
||||||
|
|
||||||
|
if lastNonParitySegmentPayload.len > 0:
|
||||||
|
payloads[firstSegmentMessage.segmentsCount - 1] = lastNonParitySegmentPayload
|
||||||
|
|
||||||
|
# Combine payload
|
||||||
|
var entirePayload = newSeq[byte]()
|
||||||
|
for i in 0..<int(firstSegmentMessage.segmentsCount):
|
||||||
|
entirePayload.add(payloads[i])
|
||||||
|
|
||||||
|
# Sanity check
|
||||||
|
let entirePayloadHash = keccak256.digest(entirePayload)
|
||||||
|
if entirePayloadHash.data != segmentMessage.entireMessageHash:
|
||||||
|
return err(ErrMessageSegmentsHashMismatch)
|
||||||
|
|
||||||
|
let completeResult = handler.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix)
|
||||||
|
if completeResult.isErr:
|
||||||
|
return err(completeResult.error)
|
||||||
|
|
||||||
|
message.payload = entirePayload
|
||||||
|
return ok()
|
||||||
|
|
||||||
|
|
||||||
|
proc cleanupSegments*(s: SegmentationHander): Result[void, string] =
|
||||||
|
discard
|
||||||
21
migrations/002_create_message_segments_table.sql
Normal file
21
migrations/002_create_message_segments_table.sql
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
CREATE TABLE message_segments (
|
||||||
|
hash BLOB NOT NULL,
|
||||||
|
segment_index INTEGER NOT NULL,
|
||||||
|
segments_count INTEGER NOT NULL,
|
||||||
|
payload BLOB NOT NULL,
|
||||||
|
sig_pub_key BLOB NOT NULL,
|
||||||
|
timestamp INTEGER NOT NULL,
|
||||||
|
parity_segment_index INTEGER NOT NULL,
|
||||||
|
parity_segments_count INTEGER NOT NULL,
|
||||||
|
PRIMARY KEY (hash, sig_pub_key, segment_index, segments_count, parity_segment_index, parity_segments_count) ON CONFLICT REPLACE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS message_segments_completed (
|
||||||
|
hash BLOB NOT NULL,
|
||||||
|
sig_pub_key BLOB NOT NULL,
|
||||||
|
timestamp INTEGER DEFAULT 0,
|
||||||
|
PRIMARY KEY (hash, sig_pub_key)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX idx_message_segments_timestamp ON message_segments(timestamp);
|
||||||
|
CREATE INDEX idx_message_segments_completed_timestamp ON message_segments_completed(timestamp);
|
||||||
@ -7,6 +7,8 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import chat_sdk
|
# import chat_sdk
|
||||||
|
import chat_sdk/segmentation
|
||||||
|
|
||||||
test "can add":
|
test "can add":
|
||||||
check add(5, 5) == 10
|
check add(5, 5) == 10
|
||||||
143
tests/test_segmentation.nim
Normal file
143
tests/test_segmentation.nim
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
# This is just an example to get you started. You may wish to put all of your
|
||||||
|
# tests into a single file, or separate them into multiple `test1`, `test2`
|
||||||
|
# etc. files (better names are recommended, just make sure the name starts with
|
||||||
|
# the letter 't').
|
||||||
|
#
|
||||||
|
# To run these tests, simply execute `nimble test`.
|
||||||
|
|
||||||
|
import unittest, sequtils, random, math
|
||||||
|
import results
|
||||||
|
import db_connector/db_sqlite
|
||||||
|
import nimcrypto
|
||||||
|
import chat_sdk/segmentation
|
||||||
|
import chat_sdk/migration
|
||||||
|
import chat_sdk/db
|
||||||
|
|
||||||
|
proc newInMemoryPersistence(): SegmentationPersistence =
|
||||||
|
let conn = open(":memory:", "", "", "")
|
||||||
|
# Define the tables (same schema as expected in your app)
|
||||||
|
runMigrations(conn)
|
||||||
|
result = SegmentationPersistence(db: conn)
|
||||||
|
|
||||||
|
suite "message Segmentation":
|
||||||
|
var
|
||||||
|
sender: SegmentationHander
|
||||||
|
testPayload: seq[byte]
|
||||||
|
mockPersistence: SegmentationPersistence
|
||||||
|
|
||||||
|
setup:
|
||||||
|
# Initialize test payload (1000 bytes, each byte is its index)
|
||||||
|
testPayload = newSeq[byte](1024)
|
||||||
|
for i in 0..<1024:
|
||||||
|
testPayload[i] = rand(255).byte
|
||||||
|
|
||||||
|
# Setup mock persistence
|
||||||
|
mockPersistence = newInMemoryPersistence()
|
||||||
|
|
||||||
|
# Setup sender
|
||||||
|
sender = SegmentationHander(
|
||||||
|
segmentSize: 4000, # Arbitrary size to allow segmentation
|
||||||
|
persistence: mockPersistence
|
||||||
|
)
|
||||||
|
|
||||||
|
test "HandleSegmentationLayer":
|
||||||
|
# Define test cases (mirroring Go test cases)
|
||||||
|
let testCases = @[
|
||||||
|
(
|
||||||
|
name: "all segments retrieved",
|
||||||
|
segmentsCount: 2,
|
||||||
|
expectedParitySegmentsCount: 0,
|
||||||
|
retrievedSegments: @[0, 1],
|
||||||
|
retrievedParitySegments: newSeq[int](),
|
||||||
|
shouldSucceed: true
|
||||||
|
),
|
||||||
|
(
|
||||||
|
name: "all segments retrieved out of order",
|
||||||
|
segmentsCount: 2,
|
||||||
|
expectedParitySegmentsCount: 0,
|
||||||
|
retrievedSegments: @[1, 0],
|
||||||
|
retrievedParitySegments: newSeq[int](),
|
||||||
|
shouldSucceed: true
|
||||||
|
),
|
||||||
|
(
|
||||||
|
name: "all segments&parity retrieved",
|
||||||
|
segmentsCount: 8,
|
||||||
|
expectedParitySegmentsCount: 1,
|
||||||
|
retrievedSegments: @[0, 1, 2, 3, 4, 5, 6, 7, 8],
|
||||||
|
retrievedParitySegments: @[8],
|
||||||
|
shouldSucceed: true
|
||||||
|
),
|
||||||
|
(
|
||||||
|
name: "all segments&parity retrieved out of order",
|
||||||
|
segmentsCount: 8,
|
||||||
|
expectedParitySegmentsCount: 1,
|
||||||
|
retrievedSegments: @[8, 0, 7, 1, 6, 2, 5, 3, 4],
|
||||||
|
retrievedParitySegments: @[8],
|
||||||
|
shouldSucceed: true
|
||||||
|
),
|
||||||
|
(
|
||||||
|
name: "no segments retrieved",
|
||||||
|
segmentsCount: 2,
|
||||||
|
expectedParitySegmentsCount: 0,
|
||||||
|
retrievedSegments: @[],
|
||||||
|
retrievedParitySegments: newSeq[int](),
|
||||||
|
shouldSucceed: false
|
||||||
|
),
|
||||||
|
(
|
||||||
|
name: "not all needed segments&parity retrieved",
|
||||||
|
segmentsCount: 8,
|
||||||
|
expectedParitySegmentsCount: 1,
|
||||||
|
retrievedSegments: @[1, 2, 8],
|
||||||
|
retrievedParitySegments: @[8],
|
||||||
|
shouldSucceed: false
|
||||||
|
),
|
||||||
|
(
|
||||||
|
name: "segments&parity retrieved",
|
||||||
|
segmentsCount: 8,
|
||||||
|
expectedParitySegmentsCount: 1,
|
||||||
|
retrievedSegments: @[1, 2, 3, 4, 5, 6, 7, 8],
|
||||||
|
retrievedParitySegments: @[8],
|
||||||
|
shouldSucceed: true
|
||||||
|
),
|
||||||
|
(
|
||||||
|
name: "segments&parity retrieved out of order",
|
||||||
|
segmentsCount: 16,
|
||||||
|
expectedParitySegmentsCount: 2,
|
||||||
|
retrievedSegments: @[17, 0, 16, 1, 15, 2, 14, 3, 13, 4, 12, 5, 11, 6, 10, 7],
|
||||||
|
retrievedParitySegments: @[16, 17],
|
||||||
|
shouldSucceed: true
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
for tc in testCases:
|
||||||
|
test tc.name:
|
||||||
|
# Segment the message
|
||||||
|
let chunk = Chunk(payload: testPayload)
|
||||||
|
sender.segmentSize = int(ceil(testPayload.len.float / tc.segmentsCount.float))
|
||||||
|
let segmentedMessagesRes = sender.segmentMessage(chunk)
|
||||||
|
require(segmentedMessagesRes.isOk)
|
||||||
|
let segmentedMessages = segmentedMessagesRes.get()
|
||||||
|
check(segmentedMessages.len == tc.segmentsCount + tc.expectedParitySegmentsCount)
|
||||||
|
|
||||||
|
var message = Message(
|
||||||
|
sigPubKey: keccak256.digest("testkey").data.toSeq,
|
||||||
|
hash: keccak256.digest(testPayload).data.toSeq
|
||||||
|
)
|
||||||
|
|
||||||
|
var messageRecreated = false
|
||||||
|
var handledSegments: seq[int] = @[]
|
||||||
|
|
||||||
|
for i, segmentIndex in tc.retrievedSegments:
|
||||||
|
message.payload = segmentedMessages[segmentIndex].payload
|
||||||
|
let err = sender.handleSegmentationLayer(message)
|
||||||
|
handledSegments.add(segmentIndex)
|
||||||
|
if handledSegments.len < tc.segmentsCount:
|
||||||
|
check(err.isErr and err.error == ErrMessageSegmentsIncomplete)
|
||||||
|
elif handledSegments.len == tc.segmentsCount:
|
||||||
|
check(err.isOk)
|
||||||
|
check(message.payload == testPayload)
|
||||||
|
messageRecreated = true
|
||||||
|
else:
|
||||||
|
check(err.isErr and err.error == ErrMessageSegmentsAlreadyCompleted)
|
||||||
|
|
||||||
|
check(messageRecreated == tc.shouldSucceed)
|
||||||
Loading…
x
Reference in New Issue
Block a user