mirror of
https://github.com/logos-messaging/nim-chat-sdk.git
synced 2026-01-02 14:13:07 +00:00
feat: fix tests
This commit is contained in:
parent
001e5b97c9
commit
7b71248a80
4
.gitignore
vendored
4
.gitignore
vendored
@ -20,8 +20,10 @@ nimcache/
|
||||
|
||||
# Compiled files
|
||||
chat_sdk/*
|
||||
!*.nim
|
||||
apps/*
|
||||
tests/*
|
||||
|
||||
!*.nim
|
||||
!*.proto
|
||||
nimble.develop
|
||||
nimble.paths
|
||||
|
||||
@ -2,9 +2,10 @@ import std/strutils
|
||||
import results
|
||||
import db_connector/db_sqlite
|
||||
import db_models
|
||||
import nimcrypto
|
||||
|
||||
type SegmentationPersistence* = object
|
||||
db: DbConn
|
||||
db*: DbConn
|
||||
|
||||
proc removeMessageSegmentsOlderThan*(
|
||||
self: SegmentationPersistence, timestamp: int64
|
||||
@ -46,6 +47,7 @@ proc saveMessageSegment*(
|
||||
sigPubKeyBlob: seq[byte],
|
||||
timestamp: int64,
|
||||
): Result[void, string] =
|
||||
echo "size.....", segment.payload.len
|
||||
try:
|
||||
self.db.exec(
|
||||
sql"""
|
||||
@ -53,15 +55,24 @@ proc saveMessageSegment*(
|
||||
hash, segment_index, segments_count, parity_segment_index,
|
||||
parity_segments_count, sig_pub_key, payload, timestamp
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
segment.entireMessageHash,
|
||||
segment.entireMessageHash.toHex,
|
||||
segment.index,
|
||||
segment.segmentsCount,
|
||||
segment.paritySegmentIndex,
|
||||
segment.paritySegmentsCount,
|
||||
sigPubKeyBlob,
|
||||
segment.payload,
|
||||
segment.payload.toHex,
|
||||
timestamp,
|
||||
)
|
||||
|
||||
let storedPayload = self.db.getValue(
|
||||
sql"SELECT payload FROM message_segments WHERE hash = ? AND sig_pub_key = ? AND segment_index = ?",
|
||||
segment.entireMessageHash.toHex, sigPubKeyBlob, segment.index
|
||||
)
|
||||
let payloadBytes = nimcrypto.fromHex(storedPayload)
|
||||
# let lenBytes = payloadBytes.get()
|
||||
echo "Stored payload length: ", payloadBytes.len
|
||||
|
||||
ok()
|
||||
except DbError as e:
|
||||
err("save message segments with error: " & e.msg)
|
||||
@ -82,15 +93,16 @@ proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubK
|
||||
"""
|
||||
|
||||
try:
|
||||
for row in self.db.rows(query, hash, sigPubKeyBlob):
|
||||
for row in self.db.rows(query, hash.toHex, sigPubKeyBlob):
|
||||
let segment = SegmentMessage(
|
||||
entireMessageHash: cast[seq[byte]](row[0]),
|
||||
entireMessageHash: nimcrypto.fromHex(row[0]),
|
||||
index: uint32(parseInt(row[1])),
|
||||
segmentsCount: uint32(parseInt(row[2])),
|
||||
paritySegmentIndex: uint32(parseInt(row[3])),
|
||||
paritySegmentsCount: uint32(parseInt(row[4])),
|
||||
payload: cast[seq[byte]](row[5])
|
||||
payload: nimcrypto.fromHex(row[5])
|
||||
)
|
||||
echo "read size.....", segment.payload.len
|
||||
segments.add(segment)
|
||||
return ok(segments)
|
||||
except CatchableError as e:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import os, sequtils, algorithm
|
||||
import os, sequtils, algorithm, strutils
|
||||
import db_connector/db_sqlite
|
||||
import chronicles
|
||||
|
||||
@ -27,6 +27,16 @@ proc runMigrations*(db: DbConn, dir = "migrations") =
|
||||
info "Migration already applied", file
|
||||
else:
|
||||
info "Applying migration", file
|
||||
let sql = readFile(file)
|
||||
db.exec(sql(sql))
|
||||
markMigrationRun(db, file)
|
||||
let script = readFile(file)
|
||||
db.exec(sql"BEGIN TRANSACTION")
|
||||
try:
|
||||
for stmt in script.split(";"):
|
||||
let trimmed = stmt.strip()
|
||||
if trimmed.len > 0:
|
||||
db.exec(sql(trimmed))
|
||||
|
||||
markMigrationRun(db, file)
|
||||
db.exec(sql"COMMIT")
|
||||
except:
|
||||
db.exec(sql"ROLLBACK")
|
||||
raise
|
||||
|
||||
16
chat_sdk/segment_message.proto
Normal file
16
chat_sdk/segment_message.proto
Normal file
@ -0,0 +1,16 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message SegmentMessageProto {
|
||||
// hash of the entire original message
|
||||
bytes entire_message_hash = 1;
|
||||
// Index of this segment within the entire original message
|
||||
uint32 index = 2;
|
||||
// Total number of segments the entire original message is divided into
|
||||
uint32 segments_count = 3;
|
||||
// The payload data for this particular segment
|
||||
bytes payload = 4;
|
||||
// Index of this parity segment
|
||||
uint32 parity_segment_index = 5;
|
||||
// Total number of parity segments
|
||||
uint32 parity_segments_count = 6;
|
||||
}
|
||||
@ -23,7 +23,7 @@ type
|
||||
sigPubKey*: seq[byte]
|
||||
|
||||
SegmentationHander* = object
|
||||
maxMessageSize*: int
|
||||
segmentSize*: int
|
||||
persistence*: SegmentationPersistence
|
||||
|
||||
const
|
||||
@ -31,11 +31,11 @@ const
|
||||
SegmentsReedsolomonMaxCount = 256
|
||||
|
||||
const
|
||||
ErrMessageSegmentsIncomplete = "message segments incomplete"
|
||||
ErrMessageSegmentsAlreadyCompleted = "message segments already completed"
|
||||
ErrMessageSegmentsInvalidCount = "invalid segments count"
|
||||
ErrMessageSegmentsHashMismatch = "hash of entire payload does not match"
|
||||
ErrMessageSegmentsInvalidParity = "invalid parity segments"
|
||||
ErrMessageSegmentsIncomplete* = "message segments incomplete"
|
||||
ErrMessageSegmentsAlreadyCompleted* = "message segments already completed"
|
||||
ErrMessageSegmentsInvalidCount* = "invalid segments count"
|
||||
ErrMessageSegmentsHashMismatch* = "hash of entire payload does not match"
|
||||
ErrMessageSegmentsInvalidParity* = "invalid parity segments"
|
||||
|
||||
proc isValid(s: SegmentMessageProto): bool =
|
||||
# Check if the hash length is valid (32 bytes for Keccak256)
|
||||
@ -57,22 +57,29 @@ proc isParityMessage*(s: SegmentMessage): bool =
|
||||
s.segmentsCount == 0 and s.paritySegmentsCount > 0
|
||||
|
||||
proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] =
|
||||
let segmentSize = s.maxMessageSize div 4 * 3
|
||||
if newMessage.payload.len <= segmentSize:
|
||||
if newMessage.payload.len <= s.segmentSize:
|
||||
return ok(@[newMessage])
|
||||
|
||||
info "segmenting message",
|
||||
"payloadSize" = newMessage.payload.len,
|
||||
"segmentSize" = s.segmentSize
|
||||
info "segment payload in", "saaa" = newMessage.payload.toHex
|
||||
let entireMessageHash = keccak256.digest(newMessage.payload)
|
||||
let entirePayloadSize = newMessage.payload.len
|
||||
|
||||
let segmentsCount = int(ceil(entirePayloadSize.float / segmentSize.float))
|
||||
let segmentsCount = int(ceil(entirePayloadSize.float / s.segmentSize.float))
|
||||
let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate))
|
||||
info "parity count", paritySegmentsCount
|
||||
info "segmentsCount", segmentsCount
|
||||
info "entirePayloadSize", entirePayloadSize
|
||||
info "segmentSize", "segmentSize" = s.segmentSize
|
||||
|
||||
var segmentPayloads = newSeq[seq[byte]](segmentsCount + paritySegmentsCount)
|
||||
var segmentPayloads = newSeq[seq[byte]](segmentsCount)
|
||||
var segmentMessages = newSeq[Chunk](segmentsCount)
|
||||
|
||||
for i in 0..<segmentsCount:
|
||||
let start = i * segmentSize
|
||||
var endIndex = start + segmentSize
|
||||
let start = i * s.segmentSize
|
||||
var endIndex = start + s.segmentSize
|
||||
if endIndex > entirePayloadSize:
|
||||
endIndex = entirePayloadSize
|
||||
|
||||
@ -84,6 +91,8 @@ proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk
|
||||
payload: segmentPayload
|
||||
)
|
||||
|
||||
info "entireMessageHash", "entireMessageHash" = entireMessageHash.data.toHex
|
||||
|
||||
let marshaledSegment = Protobuf.encode(segmentWithMetadata)
|
||||
let segmentMessage = Chunk(payload: marshaledSegment)
|
||||
|
||||
@ -91,40 +100,60 @@ proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk
|
||||
segmentMessages[i] = segmentMessage
|
||||
|
||||
# Skip Reed-Solomon if parity segments are 0 or total exceeds max count
|
||||
info "segments count", "len" = segmentMessages.len
|
||||
if paritySegmentsCount == 0 or (segmentsCount + paritySegmentsCount) > SegmentsReedsolomonMaxCount:
|
||||
return ok(segmentMessages)
|
||||
|
||||
# Align last segment payload for Reed-Solomon (leopard requires fixed-size shards)
|
||||
let lastSegmentPayload = segmentPayloads[segmentsCount-1]
|
||||
segmentPayloads[segmentsCount-1] = newSeq[byte](segmentSize)
|
||||
segmentPayloads[segmentsCount-1] = newSeq[byte](s.segmentSize)
|
||||
segmentPayloads[segmentsCount-1][0..<lastSegmentPayload.len] = lastSegmentPayload
|
||||
|
||||
for i, shard in segmentPayloads:
|
||||
info "segment payload before encoding", "index" = i, "len" = shard.len, "payload" = shard
|
||||
|
||||
# Allocate space for parity shards
|
||||
for i in segmentsCount..<(segmentsCount + paritySegmentsCount):
|
||||
segmentPayloads[i] = newSeq[byte](segmentSize)
|
||||
# for i in segmentsCount..<(segmentsCount + paritySegmentsCount):
|
||||
# segmentPayloads[i] = newSeq[byte](s.segmentSize)
|
||||
|
||||
# Use nim-leopard for Reed-Solomon encoding
|
||||
var data = segmentPayloads[0..<segmentsCount]
|
||||
var parity = segmentPayloads[segmentsCount..<(segmentsCount + paritySegmentsCount)]
|
||||
# var data = segmentPayloads[0..<segmentsCount]
|
||||
var parity = newSeq[seq[byte]](paritySegmentsCount)
|
||||
for i in 0..<paritySegmentsCount:
|
||||
newSeq(parity[i], s.segmentSize)
|
||||
|
||||
# var parity = segmentPayloads[segmentsCount..<(segmentsCount + paritySegmentsCount)]
|
||||
|
||||
var encoderRes = LeoEncoder.init(segmentSize, segmentsCount, paritySegmentsCount)
|
||||
info "parity segments",
|
||||
"paritySegmentsCount" = paritySegmentsCount,
|
||||
"parity" = parity
|
||||
info "initializing encoder",
|
||||
"segmentSize" = s.segmentSize,
|
||||
"segmentsCount" = segmentsCount,
|
||||
"paritySegmentsCount" = paritySegmentsCount
|
||||
var encoderRes = LeoEncoder.init(s.segmentSize, segmentsCount, paritySegmentsCount)
|
||||
if encoderRes.isErr:
|
||||
return err("failed to initialize encoder: " & $encoderRes.error)
|
||||
|
||||
var encoder = encoderRes.get
|
||||
let encodeResult = encoder.encode(data, parity)
|
||||
let encodeResult = encoder.encode(segmentPayloads, parity)
|
||||
info "parity segments",
|
||||
"paritySegmentsCount" = paritySegmentsCount,
|
||||
"parity" = parity
|
||||
info "encoding result", "encodeResult" = encodeResult
|
||||
if encodeResult.isErr:
|
||||
return err("failed to encode segments with leopard: " & $encodeResult.error)
|
||||
|
||||
# Create parity messages
|
||||
for i in segmentsCount..<(segmentsCount + paritySegmentsCount):
|
||||
let parityIndex = i - segmentsCount
|
||||
for i in 0..<paritySegmentsCount:
|
||||
info "parity segment index", "index" = i
|
||||
info "parity segment payload", "payload" = parity[i]
|
||||
let segmentWithMetadata = SegmentMessageProto(
|
||||
entireMessageHash: entireMessageHash.data.toSeq,
|
||||
segmentsCount: 0,
|
||||
paritySegmentIndex: uint32(parityIndex),
|
||||
paritySegmentIndex: uint32(i),
|
||||
paritySegmentsCount: uint32(paritySegmentsCount),
|
||||
payload: segmentPayloads[i]
|
||||
payload: parity[i]
|
||||
)
|
||||
|
||||
let marshaledSegment = Protobuf.encode(segmentWithMetadata)
|
||||
@ -148,7 +177,6 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message)
|
||||
"ParitySegmentIndex" = $segmentMessageProto.paritySegmentIndex,
|
||||
"ParitySegmentsCount" = $segmentMessageProto.paritySegmentsCount
|
||||
|
||||
# TODO here use the mock function, real function should be used together with the persistence layer
|
||||
let alreadyCompleted = handler.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash)
|
||||
if alreadyCompleted.isErr:
|
||||
return err(alreadyCompleted.error)
|
||||
@ -166,6 +194,7 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message)
|
||||
paritySegmentsCount: segmentMessageProto.paritySegmentsCount,
|
||||
payload: segmentMessageProto.payload
|
||||
)
|
||||
info "segment payload", "len" = segmentMessage.payload.len, "segment payload" = segmentMessage.payload.toHex
|
||||
let saveResult = handler.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix)
|
||||
if saveResult.isErr:
|
||||
return err(saveResult.error)
|
||||
@ -180,20 +209,40 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message)
|
||||
let firstSegmentMessage = segments.get()[0]
|
||||
let lastSegmentMessage = segments.get()[^1]
|
||||
|
||||
info "first segment",
|
||||
"index" = firstSegmentMessage.index,
|
||||
"segmentsCount" = firstSegmentMessage.segmentsCount,
|
||||
"paritySegmentIndex" = firstSegmentMessage.paritySegmentIndex,
|
||||
"paritySegmentsCount" = firstSegmentMessage.paritySegmentsCount,
|
||||
"len" = firstSegmentMessage.payload.len
|
||||
|
||||
info "last segment",
|
||||
"index" = lastSegmentMessage.index,
|
||||
"segmentsCount" = lastSegmentMessage.segmentsCount,
|
||||
"paritySegmentIndex" = lastSegmentMessage.paritySegmentIndex,
|
||||
"len" = firstSegmentMessage.payload.len,
|
||||
"paritySegmentsCount" = lastSegmentMessage.paritySegmentsCount
|
||||
|
||||
if firstSegmentMessage.isParityMessage() or segments.get().len != int(firstSegmentMessage.segmentsCount):
|
||||
return err(ErrMessageSegmentsIncomplete)
|
||||
|
||||
var payloads = newSeq[seq[byte]](firstSegmentMessage.segmentsCount + lastSegmentMessage.paritySegmentsCount)
|
||||
var payloads = newSeq[seq[byte]](firstSegmentMessage.segmentsCount)
|
||||
var parity = newSeq[seq[byte]](lastSegmentMessage.paritySegmentsCount)
|
||||
let payloadSize = firstSegmentMessage.payload.len
|
||||
|
||||
let restoreUsingParityData = lastSegmentMessage.isParityMessage()
|
||||
info "restoreUsingParityData", restoreUsingParityData
|
||||
if not restoreUsingParityData:
|
||||
for i, segment in segments.get():
|
||||
payloads[i] = segment.payload
|
||||
info "segment payload=====", "index" = i, "len" = segment.payload.len, "payload" = segment.payload.toHex
|
||||
else:
|
||||
info "restoring using parity data...", "payloadSize" = payloadSize
|
||||
var lastNonParitySegmentPayload: seq[byte]
|
||||
for segment in segments.get():
|
||||
info "segment after read", "segment"= segment.entireMessageHash.toHex, "index" = segment.index, "paritySegmentIndex" = segment.paritySegmentIndex, "paritySegmentsCount" = segment.paritySegmentsCount, "len" = segment.payload.len
|
||||
if not segment.isParityMessage():
|
||||
info "segment index for non parity", "index" = segment.index
|
||||
if segment.index == firstSegmentMessage.segmentsCount - 1:
|
||||
payloads[segment.index] = newSeq[byte](payloadSize)
|
||||
payloads[segment.index][0..<segment.payload.len] = segment.payload
|
||||
@ -201,30 +250,34 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message)
|
||||
else:
|
||||
payloads[segment.index] = segment.payload
|
||||
else:
|
||||
payloads[firstSegmentMessage.segmentsCount + segment.paritySegmentIndex] = segment.payload
|
||||
info "parity segment index.....", "index" = segment.paritySegmentIndex, "hhhh" = firstSegmentMessage.segmentsCount + segment.paritySegmentIndex
|
||||
info "parity segment index payload.....", "payload" = segment.payload
|
||||
parity[segment.paritySegmentIndex] = segment.payload
|
||||
|
||||
# Use nim-leopard for Reed-Solomon reconstruction
|
||||
info "leo decoder init",
|
||||
"payloadSize" = payloadSize,
|
||||
"segmentsCount" = firstSegmentMessage.segmentsCount,
|
||||
"paritySegmentsCount" = lastSegmentMessage.paritySegmentsCount
|
||||
let decoderRes = LeoDecoder.init(payloadSize, int(firstSegmentMessage.segmentsCount), int(lastSegmentMessage.paritySegmentsCount))
|
||||
if decoderRes.isErr:
|
||||
return err("failed to initialize LeoDecoder: " & $decoderRes.error)
|
||||
var decoder = decoderRes.get()
|
||||
var data = payloads[0..<int(firstSegmentMessage.segmentsCount)]
|
||||
var parity = payloads[int(firstSegmentMessage.segmentsCount)..<(int(firstSegmentMessage.segmentsCount) + int(lastSegmentMessage.paritySegmentsCount))]
|
||||
info "decode data", "data" = payloads
|
||||
info "decode data parity", "parityData" = parity
|
||||
var recovered = newSeq[seq[byte]](int(firstSegmentMessage.segmentsCount)) # Allocate for recovered shards
|
||||
for i in 0..<int(firstSegmentMessage.segmentsCount):
|
||||
recovered[i] = newSeq[byte](payloadSize)
|
||||
let reconstructResult = decoder.decode(data, parity, recovered)
|
||||
let reconstructResult = decoder.decode(payloads, parity, recovered)
|
||||
if reconstructResult.isErr:
|
||||
return err("failed to reconstruct payloads with leopard: " & $reconstructResult.error)
|
||||
|
||||
# Verify by checking hash (leopard doesn't have a direct verify function)
|
||||
var tempPayload = newSeq[byte]()
|
||||
for i in 0..<int(firstSegmentMessage.segmentsCount):
|
||||
tempPayload.add(payloads[i])
|
||||
let tempHash = keccak256.digest(tempPayload)
|
||||
if tempHash.data != segmentMessage.entireMessageHash:
|
||||
return err(ErrMessageSegmentsInvalidParity)
|
||||
info "recovered segments", "recoveredCount" = recovered.len, "recovered" = recovered
|
||||
|
||||
for i in 0..<firstSegmentMessage.segmentsCount:
|
||||
if payloads[i].len == 0:
|
||||
payloads[i] = recovered[i]
|
||||
|
||||
if lastNonParitySegmentPayload.len > 0:
|
||||
payloads[firstSegmentMessage.segmentsCount - 1] = lastNonParitySegmentPayload
|
||||
|
||||
@ -234,6 +287,8 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message)
|
||||
entirePayload.add(payloads[i])
|
||||
|
||||
# Sanity check
|
||||
info "entirePayloadSize", "entirePayloadSize" = entirePayload.len
|
||||
info "entire payload", "entirePayloadSize" = entirePayload.toHex
|
||||
let entirePayloadHash = keccak256.digest(entirePayload)
|
||||
if entirePayloadHash.data != segmentMessage.entireMessageHash:
|
||||
return err(ErrMessageSegmentsHashMismatch)
|
||||
@ -247,7 +302,6 @@ proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message)
|
||||
|
||||
|
||||
proc cleanupSegments*(s: SegmentationHander): Result[void, string] =
|
||||
# Same as previous translation
|
||||
discard
|
||||
|
||||
proc demo() =
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
CREATE TABLE IF NOT EXISTS message_segments (
|
||||
CREATE TABLE message_segments (
|
||||
hash BLOB NOT NULL,
|
||||
segment_index INTEGER NOT NULL,
|
||||
segments_count INTEGER NOT NULL,
|
||||
payload BLOB NOT NULL,
|
||||
sig_pub_key BLOB NOT NULL,
|
||||
timestamp INTEGER DEFAULT 0,
|
||||
PRIMARY KEY (hash, sig_pub_key, segment_index) ON CONFLICT REPLACE
|
||||
timestamp INTEGER NOT NULL,
|
||||
parity_segment_index INTEGER NOT NULL,
|
||||
parity_segments_count INTEGER NOT NULL,
|
||||
PRIMARY KEY (hash, sig_pub_key, segment_index, segments_count, parity_segment_index, parity_segments_count) ON CONFLICT REPLACE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS message_segments_completed (
|
||||
@ -16,4 +18,4 @@ CREATE TABLE IF NOT EXISTS message_segments_completed (
|
||||
);
|
||||
|
||||
CREATE INDEX idx_message_segments_timestamp ON message_segments(timestamp);
|
||||
CREATE INDEX idx_message_segments_completed_timestamp ON message_segments_completed(timestamp);
|
||||
CREATE INDEX idx_message_segments_completed_timestamp ON message_segments_completed(timestamp);
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import chat_sdk
|
||||
# import chat_sdk
|
||||
import chat_sdk/segmentation
|
||||
|
||||
test "can add":
|
||||
|
||||
@ -5,49 +5,153 @@
|
||||
#
|
||||
# To run these tests, simply execute `nimble test`.
|
||||
|
||||
import unittest, sequtils, random
|
||||
import unittest, sequtils, random, math
|
||||
import results
|
||||
import chat_sdk/segmentation # Replace with the actual file name
|
||||
import db_connector/db_sqlite
|
||||
import nimcrypto
|
||||
import chat_sdk/segmentation
|
||||
import chat_sdk/migration
|
||||
import chat_sdk/db
|
||||
|
||||
test "can add":
|
||||
check add2(5, 5) == 10
|
||||
|
||||
suite "Message Segmentation":
|
||||
let testPayload = toSeq(0..999).mapIt(byte(it))
|
||||
let testMessage = WakuNewMessage(payload: testPayload)
|
||||
let sender = MessageSender(
|
||||
messaging: Messaging(maxMessageSize: 500),
|
||||
persistence: initPersistence()
|
||||
)
|
||||
proc newInMemoryPersistence(): SegmentationPersistence =
|
||||
let conn = open(":memory:", "", "", "")
|
||||
# Define the tables (same schema as expected in your app)
|
||||
runMigrations(conn)
|
||||
result = SegmentationPersistence(db: conn)
|
||||
|
||||
test "Segment and reassemble full segments":
|
||||
let segmentResult = sender.segmentMessage(testMessage)
|
||||
check segmentResult.isOk
|
||||
var segments = segmentResult.get()
|
||||
check segments.len > 0
|
||||
suite "message Segmentation":
|
||||
var
|
||||
sender: SegmentationHander
|
||||
testPayload: seq[byte]
|
||||
mockPersistence: SegmentationPersistence
|
||||
|
||||
# Shuffle segment order for out-of-order test
|
||||
segments.shuffle()
|
||||
setup:
|
||||
# Initialize test payload (1000 bytes, each byte is its index)
|
||||
testPayload = newSeq[byte](1024)
|
||||
for i in 0..<1024:
|
||||
testPayload[i] = rand(255).byte
|
||||
|
||||
for segment in segments:
|
||||
var statusMsg: StatusMessage
|
||||
statusMsg.transportLayer = TransportLayer(
|
||||
hash: @[byte 0, 1, 2], # Dummy hash
|
||||
payload: segment.payload,
|
||||
sigPubKey: @[byte 3, 4, 5]
|
||||
)
|
||||
let result = sender.handleSegmentationLayerV2(statusMsg)
|
||||
discard result # Ignore intermediate errors
|
||||
# Setup mock persistence
|
||||
mockPersistence = newInMemoryPersistence()
|
||||
|
||||
# One last run to trigger completion
|
||||
var finalStatus: StatusMessage
|
||||
finalStatus.transportLayer = TransportLayer(
|
||||
hash: @[byte 0, 1, 2],
|
||||
payload: segments[0].payload,
|
||||
sigPubKey: @[byte 3, 4, 5]
|
||||
# Setup sender
|
||||
sender = SegmentationHander(
|
||||
segmentSize: 4000, # Arbitrary size to allow segmentation
|
||||
persistence: mockPersistence
|
||||
)
|
||||
let finalResult = sender.handleSegmentationLayerV2(finalStatus)
|
||||
check finalResult.isOk
|
||||
|
||||
test "HandleSegmentationLayer":
|
||||
# Define test cases (mirroring Go test cases)
|
||||
let testCases = @[
|
||||
(
|
||||
name: "all segments retrieved",
|
||||
segmentsCount: 2,
|
||||
expectedParitySegmentsCount: 0,
|
||||
retrievedSegments: @[0, 1],
|
||||
retrievedParitySegments: newSeq[int](),
|
||||
shouldSucceed: true
|
||||
),
|
||||
(
|
||||
name: "all segments retrieved out of order",
|
||||
segmentsCount: 2,
|
||||
expectedParitySegmentsCount: 0,
|
||||
retrievedSegments: @[1, 0],
|
||||
retrievedParitySegments: newSeq[int](),
|
||||
shouldSucceed: true
|
||||
),
|
||||
(
|
||||
name: "all segments&parity retrieved",
|
||||
segmentsCount: 8,
|
||||
expectedParitySegmentsCount: 1,
|
||||
retrievedSegments: @[0, 1, 2, 3, 4, 5, 6, 7, 8],
|
||||
retrievedParitySegments: @[8],
|
||||
shouldSucceed: true
|
||||
),
|
||||
(
|
||||
name: "all segments&parity retrieved out of order",
|
||||
segmentsCount: 8,
|
||||
expectedParitySegmentsCount: 1,
|
||||
retrievedSegments: @[8, 0, 7, 1, 6, 2, 5, 3, 4],
|
||||
retrievedParitySegments: @[8],
|
||||
shouldSucceed: true
|
||||
),
|
||||
(
|
||||
name: "no segments retrieved",
|
||||
segmentsCount: 2,
|
||||
expectedParitySegmentsCount: 0,
|
||||
retrievedSegments: @[],
|
||||
retrievedParitySegments: newSeq[int](),
|
||||
shouldSucceed: false
|
||||
),
|
||||
(
|
||||
name: "not all needed segments&parity retrieved",
|
||||
segmentsCount: 8,
|
||||
expectedParitySegmentsCount: 1,
|
||||
retrievedSegments: @[1, 2, 8],
|
||||
retrievedParitySegments: @[8],
|
||||
shouldSucceed: false
|
||||
),
|
||||
(
|
||||
name: "segments&parity retrieved",
|
||||
segmentsCount: 8,
|
||||
expectedParitySegmentsCount: 1,
|
||||
retrievedSegments: @[1, 2, 3, 4, 5, 6, 7, 8],
|
||||
retrievedParitySegments: @[8],
|
||||
shouldSucceed: true
|
||||
),
|
||||
(
|
||||
name: "segments&parity retrieved out of order",
|
||||
segmentsCount: 16,
|
||||
expectedParitySegmentsCount: 2,
|
||||
retrievedSegments: @[17, 0, 16, 1, 15, 2, 14, 3, 13, 4, 12, 5, 11, 6, 10, 7],
|
||||
retrievedParitySegments: @[16, 17],
|
||||
shouldSucceed: true
|
||||
)
|
||||
]
|
||||
|
||||
# Check payload restored
|
||||
check finalStatus.transportLayer.payload == testPayload[0..<finalStatus.transportLayer.payload.len]
|
||||
for tc in testCases:
|
||||
test tc.name:
|
||||
# Segment the message
|
||||
let chunk = Chunk(payload: testPayload)
|
||||
sender.segmentSize = int(ceil(testPayload.len.float / tc.segmentsCount.float))
|
||||
let segmentedMessagesRes = sender.segmentMessage(chunk)
|
||||
echo "segmentedMessagesRes: ", segmentedMessagesRes
|
||||
require(segmentedMessagesRes.isOk)
|
||||
let segmentedMessages = segmentedMessagesRes.get()
|
||||
check(segmentedMessages.len == tc.segmentsCount + tc.expectedParitySegmentsCount)
|
||||
|
||||
echo "segmet result: ", tc.segmentsCount, " ", tc.expectedParitySegmentsCount, " ", segmentedMessages.len
|
||||
|
||||
var message = Message(
|
||||
sigPubKey: keccak256.digest("testkey").data.toSeq,
|
||||
hash: keccak256.digest(testPayload).data.toSeq
|
||||
)
|
||||
|
||||
var messageRecreated = false
|
||||
var handledSegments: seq[int] = @[]
|
||||
|
||||
for i, segmentIndex in tc.retrievedSegments:
|
||||
echo "i=", i, " segmentIndex=", segmentIndex
|
||||
message.payload = segmentedMessages[segmentIndex].payload
|
||||
let err = sender.handleSegmentationLayer(message)
|
||||
|
||||
echo "handle err: ", err
|
||||
|
||||
handledSegments.add(segmentIndex)
|
||||
echo "handledSegments: ", handledSegments
|
||||
|
||||
echo "handledSegments result: ", handledSegments.len, " ", tc.segmentsCount
|
||||
|
||||
if handledSegments.len < tc.segmentsCount:
|
||||
check(err.isErr and err.error == ErrMessageSegmentsIncomplete)
|
||||
elif handledSegments.len == tc.segmentsCount:
|
||||
check(err.isOk)
|
||||
check(message.payload == testPayload)
|
||||
messageRecreated = true
|
||||
else:
|
||||
check(err.isErr and err.error == ErrMessageSegmentsAlreadyCompleted)
|
||||
|
||||
check(messageRecreated == tc.shouldSucceed)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user