mirror of
https://github.com/logos-messaging/nim-chat-sdk.git
synced 2026-01-02 14:13:07 +00:00
chore: refactor segment logic
This commit is contained in:
parent
211f8275db
commit
001e5b97c9
@ -3,11 +3,11 @@ import results
|
||||
import db_connector/db_sqlite
|
||||
import db_models
|
||||
|
||||
type MessageSegmentsPersistence = object
|
||||
type SegmentationPersistence* = object
|
||||
db: DbConn
|
||||
|
||||
proc removeMessageSegmentsOlderThan*(
|
||||
self: MessageSegmentsPersistence, timestamp: int64
|
||||
self: SegmentationPersistence, timestamp: int64
|
||||
): Result[void, string] =
|
||||
try:
|
||||
self.db.exec(sql"DELETE FROM message_segments WHERE timestamp < ?", timestamp)
|
||||
@ -16,7 +16,7 @@ proc removeMessageSegmentsOlderThan*(
|
||||
err("remove message segments with error: " & e.msg)
|
||||
|
||||
proc removeMessageSegmentsCompletedOlderThan*(
|
||||
self: MessageSegmentsPersistence, timestamp: int64
|
||||
self: SegmentationPersistence, timestamp: int64
|
||||
): Result[void, string] =
|
||||
try:
|
||||
self.db.exec(
|
||||
@ -27,7 +27,7 @@ proc removeMessageSegmentsCompletedOlderThan*(
|
||||
err("remove message segments completed with error: " & e.msg)
|
||||
|
||||
proc isMessageAlreadyCompleted*(
|
||||
self: MessageSegmentsPersistence, hash: seq[byte]
|
||||
self: SegmentationPersistence, hash: seq[byte]
|
||||
): Result[bool, string] =
|
||||
try:
|
||||
let row = self.db.getRow(
|
||||
@ -41,7 +41,7 @@ proc isMessageAlreadyCompleted*(
|
||||
err("check message already completed with error: " & e.msg)
|
||||
|
||||
proc saveMessageSegment*(
|
||||
self: MessageSegmentsPersistence,
|
||||
self: SegmentationPersistence,
|
||||
segment: SegmentMessage,
|
||||
sigPubKeyBlob: seq[byte],
|
||||
timestamp: int64,
|
||||
@ -66,7 +66,7 @@ proc saveMessageSegment*(
|
||||
except DbError as e:
|
||||
err("save message segments with error: " & e.msg)
|
||||
|
||||
proc getMessageSegments*(self: MessageSegmentsPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] =
|
||||
proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] =
|
||||
var segments = newSeq[SegmentMessage]()
|
||||
let query = sql"""
|
||||
SELECT
|
||||
@ -97,7 +97,7 @@ proc getMessageSegments*(self: MessageSegmentsPersistence, hash: seq[byte], sigP
|
||||
return err("get Message Segments with error: " & e.msg)
|
||||
|
||||
proc completeMessageSegments*(
|
||||
self: MessageSegmentsPersistence,
|
||||
self: SegmentationPersistence,
|
||||
hash: seq[byte],
|
||||
sigPubKeyBlob: seq[byte],
|
||||
timestamp: int64
|
||||
|
||||
@ -8,25 +8,28 @@ import protobuf_serialization/proto_parser
|
||||
import protobuf_serialization
|
||||
|
||||
import db_models
|
||||
import db
|
||||
|
||||
import_proto3 "segment_message.proto"
|
||||
|
||||
# Placeholder types (unchanged)
|
||||
type
|
||||
WakuNewMessage* = object
|
||||
Chunk* = object
|
||||
payload*: seq[byte]
|
||||
# Add other fields as needed
|
||||
|
||||
Message* = object
|
||||
hash*: seq[byte]
|
||||
payload*: seq[byte]
|
||||
sigPubKey*: seq[byte]
|
||||
|
||||
SegmentationHander* = object
|
||||
maxMessageSize*: int
|
||||
persistence*: SegmentationPersistence
|
||||
|
||||
Persistence = object
|
||||
completedMessages: Table[seq[byte], bool] # Mock storage for completed message hashes
|
||||
segments: Table[(seq[byte], seq[byte]), seq[SegmentMessage]] # Stores segments by (hash, sigPubKey)
|
||||
const
|
||||
SegmentsParityRate = 0.125
|
||||
SegmentsReedsolomonMaxCount = 256
|
||||
|
||||
# Error definitions (unchanged)
|
||||
const
|
||||
ErrMessageSegmentsIncomplete = "message segments incomplete"
|
||||
ErrMessageSegmentsAlreadyCompleted = "message segments already completed"
|
||||
@ -34,12 +37,6 @@ const
|
||||
ErrMessageSegmentsHashMismatch = "hash of entire payload does not match"
|
||||
ErrMessageSegmentsInvalidParity = "invalid parity segments"
|
||||
|
||||
# Constants (unchanged)
|
||||
const
|
||||
SegmentsParityRate = 0.125
|
||||
SegmentsReedsolomonMaxCount = 256
|
||||
|
||||
|
||||
proc isValid(s: SegmentMessageProto): bool =
|
||||
# Check if the hash length is valid (32 bytes for Keccak256)
|
||||
if s.entire_message_hash.len != 32:
|
||||
@ -59,51 +56,8 @@ proc isValid(s: SegmentMessageProto): bool =
|
||||
proc isParityMessage*(s: SegmentMessage): bool =
|
||||
s.segmentsCount == 0 and s.paritySegmentsCount > 0
|
||||
|
||||
# MessageSender type (unchanged)
|
||||
type
|
||||
MessageSender* = ref object
|
||||
messaging*: Messaging
|
||||
persistence*: Persistence
|
||||
|
||||
Messaging* = object
|
||||
maxMessageSize*: int
|
||||
|
||||
proc initPersistence*(): Persistence =
|
||||
Persistence(
|
||||
completedMessages: initTable[seq[byte], bool](),
|
||||
segments: initTable[(seq[byte], seq[byte]), seq[SegmentMessage]]()
|
||||
)
|
||||
|
||||
proc isMessageAlreadyCompleted(p: Persistence, hash: seq[byte]): Result[bool, string] =
|
||||
return ok(p.completedMessages.getOrDefault(hash, false))
|
||||
|
||||
proc saveMessageSegment(p: var Persistence, segment: SegmentMessage, sigPubKey: seq[byte], timestamp: int64): Result[void, string] =
|
||||
let key = (segment.entireMessageHash, sigPubKey)
|
||||
# Initialize or append to the segments list
|
||||
if not p.segments.hasKey(key):
|
||||
p.segments[key] = @[]
|
||||
p.segments[key].add(segment)
|
||||
return ok()
|
||||
|
||||
proc getMessageSegments(p: Persistence, hash: seq[byte], sigPubKey: seq[byte]): Result[seq[SegmentMessage], string] =
|
||||
let key = (hash, sigPubKey)
|
||||
return ok(p.segments.getOrDefault(key, @[]))
|
||||
|
||||
proc completeMessageSegments(p: var Persistence, hash: seq[byte], sigPubKey: seq[byte], timestamp: int64): Result[void, string] =
|
||||
p.completedMessages[hash] = true
|
||||
return ok()
|
||||
|
||||
# Replicate message (unchanged)
|
||||
proc replicateMessageWithNewPayload(message: WakuNewMessage, payload: seq[byte]): Result[WakuNewMessage, string] =
|
||||
var copiedMessage = WakuNewMessage(payload: payload)
|
||||
return ok(copiedMessage)
|
||||
|
||||
proc protoMarshal(msg: SegmentMessage): Result[seq[byte], string] =
|
||||
# Fake serialization (index + payload length) TODO
|
||||
return ok(@[byte(msg.index)] & msg.payload)
|
||||
|
||||
# Segment message into smaller chunks (updated with nim-leopard)
|
||||
proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] =
|
||||
proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] =
|
||||
let segmentSize = s.maxMessageSize div 4 * 3
|
||||
if newMessage.payload.len <= segmentSize:
|
||||
return ok(@[newMessage])
|
||||
|
||||
@ -114,7 +68,7 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
|
||||
let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate))
|
||||
|
||||
var segmentPayloads = newSeq[seq[byte]](segmentsCount + paritySegmentsCount)
|
||||
var segmentMessages = newSeq[WakuNewMessage](segmentsCount)
|
||||
var segmentMessages = newSeq[Chunk](segmentsCount)
|
||||
|
||||
for i in 0..<segmentsCount:
|
||||
let start = i * segmentSize
|
||||
@ -123,23 +77,18 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
|
||||
endIndex = entirePayloadSize
|
||||
|
||||
let segmentPayload = newMessage.payload[start..<endIndex]
|
||||
let segmentWithMetadata = SegmentMessage(
|
||||
let segmentWithMetadata = SegmentMessageProto(
|
||||
entireMessageHash: entireMessageHash.data.toSeq,
|
||||
index: uint32(i),
|
||||
segmentsCount: uint32(segmentsCount),
|
||||
payload: segmentPayload
|
||||
)
|
||||
|
||||
let marshaledSegment = protoMarshal(segmentWithMetadata)
|
||||
if marshaledSegment.isErr:
|
||||
return err("failed to marshal SegmentMessage: " & marshaledSegment.error)
|
||||
|
||||
let segmentMessage = replicateMessageWithNewPayload(newMessage, marshaledSegment.get())
|
||||
if segmentMessage.isErr:
|
||||
return err("failed to replicate message: " & segmentMessage.error)
|
||||
let marshaledSegment = Protobuf.encode(segmentWithMetadata)
|
||||
let segmentMessage = Chunk(payload: marshaledSegment)
|
||||
|
||||
segmentPayloads[i] = segmentPayload
|
||||
segmentMessages[i] = segmentMessage.get()
|
||||
segmentMessages[i] = segmentMessage
|
||||
|
||||
# Skip Reed-Solomon if parity segments are 0 or total exceeds max count
|
||||
if paritySegmentsCount == 0 or (segmentsCount + paritySegmentsCount) > SegmentsReedsolomonMaxCount:
|
||||
@ -157,9 +106,11 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
|
||||
# Use nim-leopard for Reed-Solomon encoding
|
||||
var data = segmentPayloads[0..<segmentsCount]
|
||||
var parity = segmentPayloads[segmentsCount..<(segmentsCount + paritySegmentsCount)]
|
||||
|
||||
var encoderRes = LeoEncoder.init(segmentSize, segmentsCount, paritySegmentsCount)
|
||||
if encoderRes.isErr:
|
||||
return err("failed to initialize encoder: " & $encoderRes.error)
|
||||
|
||||
var encoder = encoderRes.get
|
||||
let encodeResult = encoder.encode(data, parity)
|
||||
if encodeResult.isErr:
|
||||
@ -168,7 +119,7 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
|
||||
# Create parity messages
|
||||
for i in segmentsCount..<(segmentsCount + paritySegmentsCount):
|
||||
let parityIndex = i - segmentsCount
|
||||
let segmentWithMetadata = SegmentMessage(
|
||||
let segmentWithMetadata = SegmentMessageProto(
|
||||
entireMessageHash: entireMessageHash.data.toSeq,
|
||||
segmentsCount: 0,
|
||||
paritySegmentIndex: uint32(parityIndex),
|
||||
@ -176,19 +127,14 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
|
||||
payload: segmentPayloads[i]
|
||||
)
|
||||
|
||||
let marshaledSegment = protoMarshal(segmentWithMetadata)
|
||||
if marshaledSegment.isErr:
|
||||
return err("failed to marshal parity SegmentMessage: " & marshaledSegment.error)
|
||||
let marshaledSegment = Protobuf.encode(segmentWithMetadata)
|
||||
let segmentMessage = Chunk(payload: marshaledSegment)
|
||||
|
||||
let segmentMessage = replicateMessageWithNewPayload(newMessage, marshaledSegment.get())
|
||||
if segmentMessage.isErr:
|
||||
return err("failed to replicate parity message: " & segmentMessage.error)
|
||||
|
||||
segmentMessages.add(segmentMessage.get())
|
||||
segmentMessages.add(segmentMessage)
|
||||
|
||||
return ok(segmentMessages)
|
||||
|
||||
proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[void, string] =
|
||||
proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message): Result[void, string] =
|
||||
logScope:
|
||||
site = "handleSegmentationLayer"
|
||||
hash = message.hash.toHex
|
||||
@ -203,7 +149,7 @@ proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[vo
|
||||
"ParitySegmentsCount" = $segmentMessageProto.paritySegmentsCount
|
||||
|
||||
# TODO here use the mock function, real function should be used together with the persistence layer
|
||||
let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash)
|
||||
let alreadyCompleted = handler.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash)
|
||||
if alreadyCompleted.isErr:
|
||||
return err(alreadyCompleted.error)
|
||||
if alreadyCompleted.get():
|
||||
@ -220,11 +166,11 @@ proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[vo
|
||||
paritySegmentsCount: segmentMessageProto.paritySegmentsCount,
|
||||
payload: segmentMessageProto.payload
|
||||
)
|
||||
let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix)
|
||||
let saveResult = handler.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix)
|
||||
if saveResult.isErr:
|
||||
return err(saveResult.error)
|
||||
|
||||
let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey)
|
||||
let segments = handler.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey)
|
||||
if segments.isErr:
|
||||
return err(segments.error)
|
||||
|
||||
@ -292,7 +238,7 @@ proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[vo
|
||||
if entirePayloadHash.data != segmentMessage.entireMessageHash:
|
||||
return err(ErrMessageSegmentsHashMismatch)
|
||||
|
||||
let completeResult = s.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix)
|
||||
let completeResult = handler.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix)
|
||||
if completeResult.isErr:
|
||||
return err(completeResult.error)
|
||||
|
||||
@ -300,20 +246,10 @@ proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[vo
|
||||
return ok()
|
||||
|
||||
|
||||
proc cleanupSegments*(s: MessageSender): Result[void, string] =
|
||||
proc cleanupSegments*(s: SegmentationHander): Result[void, string] =
|
||||
# Same as previous translation
|
||||
discard
|
||||
|
||||
# SegmentMessage proc (unchanged)
|
||||
proc segmentMessage*(s: MessageSender, newMessage: WakuNewMessage): Result[seq[WakuNewMessage], string] =
|
||||
let segmentSize = s.messaging.maxMessageSize div 4 * 3
|
||||
let segmentResult = segmentMessageInternal(newMessage, segmentSize)
|
||||
if segmentResult.isErr:
|
||||
return err("segmentMessage failed: " & result.error)
|
||||
let messages = result.get()
|
||||
debug "message segmented", "segments" = $messages.len
|
||||
return ok(messages)
|
||||
|
||||
proc demo() =
|
||||
let
|
||||
bufSize = 64 # byte count per buffer, must be a multiple of 64
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user