nim-chat-sdk/chat_sdk/segmentation.nim

257 lines
9.5 KiB
Nim
Raw Normal View History

2025-05-12 10:16:25 +08:00
import math, times, sequtils, strutils, options
import nimcrypto
2025-05-12 10:16:25 +08:00
import results
import leopard
2025-05-14 16:28:27 +08:00
import chronicles
2025-05-12 10:16:25 +08:00
2025-07-16 14:47:47 +08:00
import protobuf_serialization/proto_parser
import protobuf_serialization
2025-05-12 10:16:25 +08:00
2025-07-16 14:47:47 +08:00
import db_models
2025-07-16 15:54:32 +08:00
import db
2025-05-14 16:28:27 +08:00
2025-07-16 14:47:47 +08:00
import_proto3 "segment_message.proto"
2025-05-14 16:28:27 +08:00
2025-05-12 10:16:25 +08:00
type
  # Unit of bytes exchanged with the transport: either an entire message
  # (when it fits in one segment) or one protobuf-encoded segment of it.
  Chunk* = object
    payload*: seq[byte]

  # Inbound message as seen by the segmentation layer.
  # `hash` is only used here for log context; `payload` carries an encoded
  # segment on input and is replaced by the reassembled payload on success;
  # `sigPubKey` scopes stored segments to one sender.
  Message* = object
    hash*, payload*, sigPubKey*: seq[byte]

  # Segmentation configuration plus the store that accumulates incoming
  # segments. (The type name's spelling is kept for API compatibility.)
  SegmentationHander* = object
    segmentSize*: int  ## maximum payload bytes carried by one data segment
    persistence*: SegmentationPersistence  ## backing store for segments
2025-05-12 10:16:25 +08:00
2025-07-16 15:54:32 +08:00
const
  # Reed-Solomon tuning for outgoing messages.
  SegmentsParityRate = 0.125
    ## parity segments generated per data segment (result is rounded down)
  SegmentsReedsolomonMaxCount = 256
    ## parity is skipped when data + parity segments would exceed this

  # Error strings reported by the segmentation procs.
  ErrMessageSegmentsIncomplete* = "message segments incomplete"
  ErrMessageSegmentsAlreadyCompleted* = "message segments already completed"
  ErrMessageSegmentsInvalidCount* = "invalid segments count"
  ErrMessageSegmentsHashMismatch* = "hash of entire payload does not match"
  ErrMessageSegmentsInvalidParity* = "invalid parity segments"
2025-05-12 10:16:25 +08:00
2025-07-16 14:47:47 +08:00
proc isValid(s: SegmentMessageProto): bool =
  ## Structural sanity check for a decoded segment.
  ##
  ## Accepts the segment only when:
  ## * the entire-message hash is exactly 32 bytes (Keccak-256 digest size);
  ## * `index` lies inside `[0, segmentsCount)` whenever a data count is set;
  ## * `paritySegmentIndex` lies inside `[0, paritySegmentsCount)` whenever a
  ##   parity count is set;
  ## * it is either part of a split of at least two data segments, or a
  ##   parity segment.
  let
    hashLenOk = s.entireMessageHash.len == 32
    dataIndexOk = s.segmentsCount == 0 or s.index < s.segmentsCount
    parityIndexOk = s.paritySegmentsCount == 0 or
      s.paritySegmentIndex < s.paritySegmentsCount
    countsOk = s.segmentsCount >= 2 or s.paritySegmentsCount > 0
  hashLenOk and dataIndexOk and parityIndexOk and countsOk
2025-05-12 10:16:25 +08:00
proc isParityMessage*(s: SegmentMessage): bool =
  ## Tells whether `s` is a Reed-Solomon parity segment rather than a data
  ## segment: parity segments are produced with `segmentsCount == 0` and a
  ## positive `paritySegmentsCount`.
  result = s.paritySegmentsCount > 0 and s.segmentsCount == 0
2025-07-16 15:54:32 +08:00
proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] =
  ## Splits `newMessage` into segments of at most `s.segmentSize` payload
  ## bytes each.
  ##
  ## * A payload that already fits is returned unchanged as a one-element seq.
  ## * Otherwise each data segment is a protobuf-encoded `SegmentMessageProto`
  ##   carrying the Keccak-256 hash of the entire payload, its index and the
  ##   total data-segment count.
  ## * When `floor(segmentsCount * SegmentsParityRate)` is non-zero and
  ##   data + parity stays within `SegmentsReedsolomonMaxCount`, leopard
  ##   Reed-Solomon parity segments (with `segmentsCount == 0`) are appended
  ##   after the data segments.
  ##
  ## Returns `err` when `s.segmentSize` is not positive (and the payload does
  ## not fit), or when the leopard encoder fails to initialise or encode.
  if newMessage.payload.len <= s.segmentSize:
    return ok(@[newMessage])

  # Reaching here with a non-positive size would divide by zero / allocate a
  # negative number of segments below.
  if s.segmentSize <= 0:
    return err("invalid segment size: " & $s.segmentSize)

  info "segmenting message",
    "payloadSize" = newMessage.payload.len,
    "segmentSize" = s.segmentSize

  let entireMessageHash = keccak256.digest(newMessage.payload)
  # Converted once; every data and parity segment carries the same hash.
  let hashBytes = entireMessageHash.data.toSeq
  let entirePayloadSize = newMessage.payload.len
  # Integer ceiling division: number of data segments.
  let segmentsCount = (entirePayloadSize + s.segmentSize - 1) div s.segmentSize
  let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate))

  var segmentPayloads = newSeq[seq[byte]](segmentsCount)
  var segmentMessages = newSeqOfCap[Chunk](segmentsCount + paritySegmentsCount)
  for i in 0..<segmentsCount:
    let start = i * s.segmentSize
    let endIndex = min(start + s.segmentSize, entirePayloadSize)
    let segmentPayload = newMessage.payload[start..<endIndex]
    let segmentWithMetadata = SegmentMessageProto(
      entireMessageHash: hashBytes,
      index: uint32(i),
      segmentsCount: uint32(segmentsCount),
      payload: segmentPayload
    )
    segmentPayloads[i] = segmentPayload
    segmentMessages.add(Chunk(payload: Protobuf.encode(segmentWithMetadata)))

  info "segments count", "len" = segmentMessages.len

  # Skip Reed-Solomon if there is no parity budget or the total shard count
  # would exceed the supported maximum.
  if paritySegmentsCount == 0 or
      (segmentsCount + paritySegmentsCount) > SegmentsReedsolomonMaxCount:
    return ok(segmentMessages)

  # leopard requires equally sized shards: zero-pad the (possibly short) last
  # shard to `segmentSize`. Only the shard copy is padded; the data segment
  # already encoded above keeps its real length.
  segmentPayloads[^1].setLen(s.segmentSize)

  var parity = newSeqWith(paritySegmentsCount, newSeq[byte](s.segmentSize))
  let encoderRes = LeoEncoder.init(s.segmentSize, segmentsCount, paritySegmentsCount)
  if encoderRes.isErr:
    return err("failed to initialize encoder: " & $encoderRes.error)
  var encoder = encoderRes.get
  # nim-leopard does not release its aligned work buffers in a destructor;
  # free them explicitly on every exit path.
  defer: encoder.free()

  let encodeResult = encoder.encode(segmentPayloads, parity)
  if encodeResult.isErr:
    return err("failed to encode segments with leopard: " & $encodeResult.error)

  # Parity segments are flagged by `segmentsCount == 0`.
  for i in 0..<paritySegmentsCount:
    let segmentWithMetadata = SegmentMessageProto(
      entireMessageHash: hashBytes,
      segmentsCount: 0,
      paritySegmentIndex: uint32(i),
      paritySegmentsCount: uint32(paritySegmentsCount),
      payload: parity[i]
    )
    segmentMessages.add(Chunk(payload: Protobuf.encode(segmentWithMetadata)))

  return ok(segmentMessages)
2025-07-16 15:54:32 +08:00
proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message): Result[void, string] =
  ## Processes one incoming segment carried in `message.payload`.
  ##
  ## The segment is decoded, validated and persisted under its entire-message
  ## hash and `message.sigPubKey`. Once the stored segments are sufficient,
  ## the original payload is reassembled (using leopard Reed-Solomon parity
  ## when data segments are missing), verified against the Keccak-256 hash,
  ## marked completed in persistence and written back to `message.payload`.
  ##
  ## Returns:
  ## * `ok()` after `message.payload` has been replaced by the full payload;
  ## * `err(ErrMessageSegmentsIncomplete)` while more segments are needed;
  ## * `err(ErrMessageSegmentsAlreadyCompleted)` for an already-reassembled
  ##   message;
  ## * `err(ErrMessageSegmentsInvalidCount)` / `err(ErrMessageSegmentsInvalidParity)`
  ##   for malformed or mutually inconsistent segment metadata;
  ## * `err(ErrMessageSegmentsHashMismatch)` if the reassembled bytes do not
  ##   hash to the advertised value;
  ## * any decoding, persistence or leopard error message otherwise.
  logScope:
    site = "handleSegmentationLayer"
    hash = message.hash.toHex

  # The payload comes from the network; a malformed protobuf must become an
  # error result rather than an exception escaping this Result-returning proc.
  let segmentMessageProto =
    try:
      Protobuf.decode(message.payload, SegmentMessageProto)
    except CatchableError as e:
      return err("failed to decode segment message: " & e.msg)

  debug "handling message segment",
    "EntireMessageHash" = segmentMessageProto.entireMessageHash.toHex,
    "Index" = $segmentMessageProto.index,
    "SegmentsCount" = $segmentMessageProto.segmentsCount,
    "ParitySegmentIndex" = $segmentMessageProto.paritySegmentIndex,
    "ParitySegmentsCount" = $segmentMessageProto.paritySegmentsCount

  let alreadyCompleted = handler.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash)
  if alreadyCompleted.isErr:
    return err(alreadyCompleted.error)
  if alreadyCompleted.get():
    return err(ErrMessageSegmentsAlreadyCompleted)

  if not segmentMessageProto.isValid():
    return err(ErrMessageSegmentsInvalidCount)

  let segmentMessage = SegmentMessage(
    entireMessageHash: segmentMessageProto.entireMessageHash,
    index: segmentMessageProto.index,
    segmentsCount: segmentMessageProto.segmentsCount,
    paritySegmentIndex: segmentMessageProto.paritySegmentIndex,
    paritySegmentsCount: segmentMessageProto.paritySegmentsCount,
    payload: segmentMessageProto.payload
  )

  let saveResult = handler.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix)
  if saveResult.isErr:
    return err(saveResult.error)

  let segmentsRes = handler.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey)
  if segmentsRes.isErr:
    return err(segmentsRes.error)
  let segments = segmentsRes.get()
  if segments.len == 0:
    return err("unexpected state: no segments found after save operation")

  # NOTE(review): the logic below assumes getMessageSegments returns data
  # segments ordered by index, followed by parity segments — confirm in db.
  let firstSegmentMessage = segments[0]
  let lastSegmentMessage = segments[^1]
  debug "first segment",
    "index" = firstSegmentMessage.index,
    "segmentsCount" = firstSegmentMessage.segmentsCount,
    "paritySegmentIndex" = firstSegmentMessage.paritySegmentIndex,
    "paritySegmentsCount" = firstSegmentMessage.paritySegmentsCount,
    "len" = firstSegmentMessage.payload.len
  debug "last segment",
    "index" = lastSegmentMessage.index,
    "segmentsCount" = lastSegmentMessage.segmentsCount,
    "paritySegmentIndex" = lastSegmentMessage.paritySegmentIndex,
    "len" = lastSegmentMessage.payload.len,
    "paritySegmentsCount" = lastSegmentMessage.paritySegmentsCount

  # Reassembly starts only when the first stored segment is a data segment
  # and the number of stored segments (data + parity) equals the advertised
  # data-segment count.
  if firstSegmentMessage.isParityMessage() or segments.len != int(firstSegmentMessage.segmentsCount):
    return err(ErrMessageSegmentsIncomplete)

  let dataCount = int(firstSegmentMessage.segmentsCount)
  let parityCount = int(lastSegmentMessage.paritySegmentsCount)
  var payloads = newSeq[seq[byte]](dataCount)
  var parity = newSeq[seq[byte]](parityCount)
  let payloadSize = firstSegmentMessage.payload.len
  let restoreUsingParityData = lastSegmentMessage.isParityMessage()

  if not restoreUsingParityData:
    # Every stored segment is a data segment: take them in stored order.
    for i, segment in segments:
      payloads[i] = segment.payload
  else:
    var lastNonParitySegmentPayload: seq[byte]
    for segment in segments:
      if not segment.isParityMessage():
        # Indices come from untrusted peers and each segment was validated
        # only against its own counts; re-check against the shared layout.
        let idx = int(segment.index)
        if idx >= dataCount:
          return err(ErrMessageSegmentsInvalidCount)
        if idx == dataCount - 1:
          if segment.payload.len > payloadSize:
            return err("last segment payload exceeds segment size")
          # leopard needs equally sized shards: zero-pad the last one.
          payloads[idx] = segment.payload
          payloads[idx].setLen(payloadSize)
          lastNonParitySegmentPayload = segment.payload
        else:
          payloads[idx] = segment.payload
      else:
        let parityIdx = int(segment.paritySegmentIndex)
        if parityIdx >= parityCount:
          return err(ErrMessageSegmentsInvalidParity)
        parity[parityIdx] = segment.payload

    # Reed-Solomon reconstruction of the missing (empty) data shards.
    let decoderRes = LeoDecoder.init(payloadSize, dataCount, parityCount)
    if decoderRes.isErr:
      return err("failed to initialize LeoDecoder: " & $decoderRes.error)
    var decoder = decoderRes.get()
    # nim-leopard does not release its aligned buffers in a destructor.
    defer: decoder.free()

    var recovered = newSeqWith(dataCount, newSeq[byte](payloadSize))
    let reconstructResult = decoder.decode(payloads, parity, recovered)
    if reconstructResult.isErr:
      return err("failed to reconstruct payloads with leopard: " & $reconstructResult.error)

    for i in 0..<dataCount:
      if payloads[i].len == 0:
        payloads[i] = recovered[i]
    # Undo the padding of the last shard when it was actually received.
    # NOTE(review): if the last data segment itself was recovered, it keeps
    # its zero padding (the unpadded size is not transmitted), so the hash
    # check below fails unless no padding was needed.
    if lastNonParitySegmentPayload.len > 0:
      payloads[dataCount - 1] = lastNonParitySegmentPayload

  var entirePayload = newSeqOfCap[byte](dataCount * payloadSize)
  for part in payloads:
    entirePayload.add(part)

  # Sanity check against the advertised hash before accepting the result.
  let entirePayloadHash = keccak256.digest(entirePayload)
  if entirePayloadHash.data != segmentMessage.entireMessageHash:
    return err(ErrMessageSegmentsHashMismatch)

  let completeResult = handler.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix)
  if completeResult.isErr:
    return err(completeResult.error)

  message.payload = entirePayload
  return ok()
2025-07-16 15:54:32 +08:00
proc cleanupSegments*(s: SegmentationHander): Result[void, string] =
  ## Placeholder for pruning stale stored segments.
  ##
  ## Not implemented yet: no persistence call is made and success is
  ## reported.
  # A bare `discard` body would leave `result` default-initialised, which for
  # a nim-results `Result` is the *error* branch with an empty message, so
  # every caller would see a spurious failure. Return success explicitly.
  # TODO: remove old segments / completion markers once the persistence
  # layer exposes such an operation.
  ok()