chore: refactor segment logic

kaichaosun 2025-07-16 15:54:32 +08:00
parent 211f8275db
commit 001e5b97c9
GPG Key ID: 223E0F992F4F03BF
2 changed files with 35 additions and 99 deletions

View File

@@ -3,11 +3,11 @@ import results
import db_connector/db_sqlite
import db_models
type MessageSegmentsPersistence = object
type SegmentationPersistence* = object
db: DbConn
proc removeMessageSegmentsOlderThan*(
self: MessageSegmentsPersistence, timestamp: int64
self: SegmentationPersistence, timestamp: int64
): Result[void, string] =
try:
self.db.exec(sql"DELETE FROM message_segments WHERE timestamp < ?", timestamp)
@@ -16,7 +16,7 @@ proc removeMessageSegmentsOlderThan*(
err("remove message segments with error: " & e.msg)
proc removeMessageSegmentsCompletedOlderThan*(
self: MessageSegmentsPersistence, timestamp: int64
self: SegmentationPersistence, timestamp: int64
): Result[void, string] =
try:
self.db.exec(
@@ -27,7 +27,7 @@ proc removeMessageSegmentsCompletedOlderThan*(
err("remove message segments completed with error: " & e.msg)
proc isMessageAlreadyCompleted*(
self: MessageSegmentsPersistence, hash: seq[byte]
self: SegmentationPersistence, hash: seq[byte]
): Result[bool, string] =
try:
let row = self.db.getRow(
@@ -41,7 +41,7 @@ proc isMessageAlreadyCompleted*(
err("check message already completed with error: " & e.msg)
proc saveMessageSegment*(
self: MessageSegmentsPersistence,
self: SegmentationPersistence,
segment: SegmentMessage,
sigPubKeyBlob: seq[byte],
timestamp: int64,
@@ -66,7 +66,7 @@ proc saveMessageSegment*(
except DbError as e:
err("save message segments with error: " & e.msg)
proc getMessageSegments*(self: MessageSegmentsPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] =
proc getMessageSegments*(self: SegmentationPersistence, hash: seq[byte], sigPubKeyBlob: seq[byte]): Result[seq[SegmentMessage], string] =
var segments = newSeq[SegmentMessage]()
let query = sql"""
SELECT
@@ -97,7 +97,7 @@ proc getMessageSegments*(self: MessageSegmentsPersistence, hash: seq[byte], sigP
return err("get Message Segments with error: " & e.msg)
proc completeMessageSegments*(
self: MessageSegmentsPersistence,
self: SegmentationPersistence,
hash: seq[byte],
sigPubKeyBlob: seq[byte],
timestamp: int64
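
A usage sketch for the renamed persistence API above (illustrative only: how a SegmentationPersistence instance gets constructed is assumed here, since its db field stays private to the module):

# --- illustrative sketch, not part of this diff ---
import std/times
import results
import db # persistence module, as imported by the handler file below

proc pruneExample(persistence: SegmentationPersistence) =
  let cutoff = getTime().toUnix - 3600 # e.g. a one-hour retention window
  # Drop segment rows older than the cutoff...
  let stale = persistence.removeMessageSegmentsOlderThan(cutoff)
  if stale.isErr:
    echo "prune segments: ", stale.error
  # ...and, presumably, the completed-message bookkeeping rows as well.
  let done = persistence.removeMessageSegmentsCompletedOlderThan(cutoff)
  if done.isErr:
    echo "prune completed: ", done.error
# --- end sketch ---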

View File

@@ -8,25 +8,28 @@ import protobuf_serialization/proto_parser
import protobuf_serialization
import db_models
import db
import_proto3 "segment_message.proto"
# Placeholder types (unchanged)
type
WakuNewMessage* = object
Chunk* = object
payload*: seq[byte]
# Add other fields as needed
Message* = object
hash*: seq[byte]
payload*: seq[byte]
sigPubKey*: seq[byte]
SegmentationHander* = object
maxMessageSize*: int
persistence*: SegmentationPersistence
Persistence = object
completedMessages: Table[seq[byte], bool] # Mock storage for completed message hashes
segments: Table[(seq[byte], seq[byte]), seq[SegmentMessage]] # Stores segments by (hash, sigPubKey)
const
SegmentsParityRate = 0.125
SegmentsReedsolomonMaxCount = 256
# Error definitions (unchanged)
const
ErrMessageSegmentsIncomplete = "message segments incomplete"
ErrMessageSegmentsAlreadyCompleted = "message segments already completed"
@@ -34,12 +37,6 @@ const
ErrMessageSegmentsHashMismatch = "hash of entire payload does not match"
ErrMessageSegmentsInvalidParity = "invalid parity segments"
# Constants (unchanged)
const
SegmentsParityRate = 0.125
SegmentsReedsolomonMaxCount = 256
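
For context on how these two constants interact, a quick illustrative calculation (numbers made up, not taken from the commit):

# --- illustrative sketch, not part of this diff ---
import std/math

# With 16 data segments: floor(16 * 0.125) = 2 parity segments, and
# 16 + 2 = 18 <= SegmentsReedsolomonMaxCount (256), so Reed-Solomon runs.
# Below 8 data segments the parity count floors to 0 and the
# Reed-Solomon step is skipped.
let dataSegments = 16
let paritySegments = int(floor(dataSegments.float * SegmentsParityRate))
doAssert paritySegments == 2
# --- end sketch ---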
proc isValid(s: SegmentMessageProto): bool =
# Check if the hash length is valid (32 bytes for Keccak256)
if s.entire_message_hash.len != 32:
@@ -59,51 +56,8 @@ proc isValid(s: SegmentMessageProto): bool =
proc isParityMessage*(s: SegmentMessage): bool =
s.segmentsCount == 0 and s.paritySegmentsCount > 0
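
A parity chunk is flagged purely by its counters; for example (illustrative values, unset fields left at their defaults):

# --- illustrative sketch, not part of this diff ---
let parityExample = SegmentMessage(
  entireMessageHash: newSeq[byte](32), # placeholder 32-byte hash
  segmentsCount: 0,                    # zero data-segment count marks parity
  paritySegmentsCount: 2
)
doAssert parityExample.isParityMessage()
# --- end sketch ---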
# MessageSender type (unchanged)
type
MessageSender* = ref object
messaging*: Messaging
persistence*: Persistence
Messaging* = object
maxMessageSize*: int
proc initPersistence*(): Persistence =
Persistence(
completedMessages: initTable[seq[byte], bool](),
segments: initTable[(seq[byte], seq[byte]), seq[SegmentMessage]]()
)
proc isMessageAlreadyCompleted(p: Persistence, hash: seq[byte]): Result[bool, string] =
return ok(p.completedMessages.getOrDefault(hash, false))
proc saveMessageSegment(p: var Persistence, segment: SegmentMessage, sigPubKey: seq[byte], timestamp: int64): Result[void, string] =
let key = (segment.entireMessageHash, sigPubKey)
# Initialize or append to the segments list
if not p.segments.hasKey(key):
p.segments[key] = @[]
p.segments[key].add(segment)
return ok()
proc getMessageSegments(p: Persistence, hash: seq[byte], sigPubKey: seq[byte]): Result[seq[SegmentMessage], string] =
let key = (hash, sigPubKey)
return ok(p.segments.getOrDefault(key, @[]))
proc completeMessageSegments(p: var Persistence, hash: seq[byte], sigPubKey: seq[byte], timestamp: int64): Result[void, string] =
p.completedMessages[hash] = true
return ok()
# Replicate message (unchanged)
proc replicateMessageWithNewPayload(message: WakuNewMessage, payload: seq[byte]): Result[WakuNewMessage, string] =
var copiedMessage = WakuNewMessage(payload: payload)
return ok(copiedMessage)
proc protoMarshal(msg: SegmentMessage): Result[seq[byte], string] =
# Fake serialization (index + payload length) TODO
return ok(@[byte(msg.index)] & msg.payload)
# Segment message into smaller chunks (updated with nim-leopard)
proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] =
proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] =
let segmentSize = s.maxMessageSize div 4 * 3
if newMessage.payload.len <= segmentSize:
return ok(@[newMessage])
@@ -114,7 +68,7 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate))
var segmentPayloads = newSeq[seq[byte]](segmentsCount + paritySegmentsCount)
var segmentMessages = newSeq[WakuNewMessage](segmentsCount)
var segmentMessages = newSeq[Chunk](segmentsCount)
for i in 0..<segmentsCount:
let start = i * segmentSize
@@ -123,23 +77,18 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
endIndex = entirePayloadSize
let segmentPayload = newMessage.payload[start..<endIndex]
let segmentWithMetadata = SegmentMessage(
let segmentWithMetadata = SegmentMessageProto(
entireMessageHash: entireMessageHash.data.toSeq,
index: uint32(i),
segmentsCount: uint32(segmentsCount),
payload: segmentPayload
)
let marshaledSegment = protoMarshal(segmentWithMetadata)
if marshaledSegment.isErr:
return err("failed to marshal SegmentMessage: " & marshaledSegment.error)
let segmentMessage = replicateMessageWithNewPayload(newMessage, marshaledSegment.get())
if segmentMessage.isErr:
return err("failed to replicate message: " & segmentMessage.error)
let marshaledSegment = Protobuf.encode(segmentWithMetadata)
let segmentMessage = Chunk(payload: marshaledSegment)
segmentPayloads[i] = segmentPayload
segmentMessages[i] = segmentMessage.get()
segmentMessages[i] = segmentMessage
# Skip Reed-Solomon if parity segments are 0 or total exceeds max count
if paritySegmentsCount == 0 or (segmentsCount + paritySegmentsCount) > SegmentsReedsolomonMaxCount:
@@ -157,9 +106,11 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
# Use nim-leopard for Reed-Solomon encoding
var data = segmentPayloads[0..<segmentsCount]
var parity = segmentPayloads[segmentsCount..<(segmentsCount + paritySegmentsCount)]
var encoderRes = LeoEncoder.init(segmentSize, segmentsCount, paritySegmentsCount)
if encoderRes.isErr:
return err("failed to initialize encoder: " & $encoderRes.error)
var encoder = encoderRes.get
let encodeResult = encoder.encode(data, parity)
if encodeResult.isErr:
@@ -168,7 +119,7 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
# Create parity messages
for i in segmentsCount..<(segmentsCount + paritySegmentsCount):
let parityIndex = i - segmentsCount
let segmentWithMetadata = SegmentMessage(
let segmentWithMetadata = SegmentMessageProto(
entireMessageHash: entireMessageHash.data.toSeq,
segmentsCount: 0,
paritySegmentIndex: uint32(parityIndex),
@@ -176,19 +127,14 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
payload: segmentPayloads[i]
)
let marshaledSegment = protoMarshal(segmentWithMetadata)
if marshaledSegment.isErr:
return err("failed to marshal parity SegmentMessage: " & marshaledSegment.error)
let marshaledSegment = Protobuf.encode(segmentWithMetadata)
let segmentMessage = Chunk(payload: marshaledSegment)
let segmentMessage = replicateMessageWithNewPayload(newMessage, marshaledSegment.get())
if segmentMessage.isErr:
return err("failed to replicate parity message: " & segmentMessage.error)
segmentMessages.add(segmentMessage.get())
segmentMessages.add(segmentMessage)
return ok(segmentMessages)
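
A quick call sketch for the new segmentMessage entry point (handler construction and the sizes below are illustrative assumptions):

# --- illustrative sketch, not part of this diff ---
# With maxMessageSize = 128_000, segmentSize = 128_000 div 4 * 3 = 96_000,
# so a 307_200-byte payload should split into 4 data segments (the last
# one short) with floor(4 * 0.125) = 0 parity segments added.
let handler = SegmentationHander(maxMessageSize: 128_000)
let big = Chunk(payload: newSeq[byte](307_200))
let segmented = handler.segmentMessage(big)
if segmented.isOk:
  echo "chunks produced: ", segmented.get().len # 4 under these assumptions
# --- end sketch ---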
proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[void, string] =
proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message): Result[void, string] =
logScope:
site = "handleSegmentationLayer"
hash = message.hash.toHex
@@ -203,7 +149,7 @@ proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[vo
"ParitySegmentsCount" = $segmentMessageProto.paritySegmentsCount
# TODO: the mock function is used here; the real one should be wired together with the persistence layer
let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash)
let alreadyCompleted = handler.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash)
if alreadyCompleted.isErr:
return err(alreadyCompleted.error)
if alreadyCompleted.get():
@@ -220,11 +166,11 @@ proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[vo
paritySegmentsCount: segmentMessageProto.paritySegmentsCount,
payload: segmentMessageProto.payload
)
let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix)
let saveResult = handler.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix)
if saveResult.isErr:
return err(saveResult.error)
let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey)
let segments = handler.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey)
if segments.isErr:
return err(segments.error)
@@ -292,7 +238,7 @@ proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[vo
if entirePayloadHash.data != segmentMessage.entireMessageHash:
return err(ErrMessageSegmentsHashMismatch)
let completeResult = s.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix)
let completeResult = handler.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix)
if completeResult.isErr:
return err(completeResult.error)
@@ -300,20 +246,10 @@ proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[vo
return ok()
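
And a receive-path sketch for handleSegmentationLayer (the wiring is assumed; ErrMessageSegmentsIncomplete is taken here to mean "keep waiting for more segments"):

# --- illustrative sketch, not part of this diff ---
proc onSegmentedPayload(handler: SegmentationHander,
                        payload, sigPubKey: seq[byte]) =
  var msg = Message(hash: newSeq[byte](0), payload: payload, sigPubKey: sigPubKey)
  let res = handler.handleSegmentationLayer(msg)
  if res.isErr:
    if res.error == ErrMessageSegmentsIncomplete:
      return # not enough segments yet; more are expected
    echo "segmentation layer error: ", res.error
# --- end sketch ---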
proc cleanupSegments*(s: MessageSender): Result[void, string] =
proc cleanupSegments*(s: SegmentationHander): Result[void, string] =
# Same as previous translation
discard
# SegmentMessage proc (unchanged)
proc segmentMessage*(s: MessageSender, newMessage: WakuNewMessage): Result[seq[WakuNewMessage], string] =
let segmentSize = s.messaging.maxMessageSize div 4 * 3
let segmentResult = segmentMessageInternal(newMessage, segmentSize)
if segmentResult.isErr:
return err("segmentMessage failed: " & result.error)
let messages = result.get()
debug "message segmented", "segments" = $messages.len
return ok(messages)
proc demo() =
let
bufSize = 64 # byte count per buffer, must be a multiple of 64