mirror of
https://github.com/logos-messaging/nim-chat-sdk.git
synced 2026-01-03 14:43:07 +00:00
327 lines
13 KiB
Nim
327 lines
13 KiB
Nim
import math, times, sequtils, strutils, options
|
|
import nimcrypto
|
|
import results
|
|
import leopard
|
|
import chronicles
|
|
|
|
import protobuf_serialization/proto_parser
|
|
import protobuf_serialization
|
|
|
|
import db_models
|
|
import db
|
|
|
|
import_proto3 "segment_message.proto"
|
|
|
|
# Placeholder types (unchanged)
type
  Chunk* = object
    ## A single transport-level chunk: either a whole small message or one
    ## protobuf-encoded segment of a larger message.
    payload*: seq[byte]

  Message* = object
    ## An application message as seen by the segmentation layer.
    hash*: seq[byte]       # transport-level hash of this message (used for log scoping)
    payload*: seq[byte]    # raw payload; replaced with the reassembled payload on completion
    sigPubKey*: seq[byte]  # sender's signing public key; used to scope segment persistence

  SegmentationHander* = object
    ## Splits outgoing messages into fixed-size segments (plus optional
    ## Reed-Solomon parity) and reassembles incoming segments.
    ## NOTE(review): name looks like a typo for "SegmentationHandler", but it
    ## is exported — renaming would break external callers; confirm before fixing.
    segmentSize*: int                     # maximum payload bytes per segment
    persistence*: SegmentationPersistence # backing store for partially received segments
|
|
const
  # Fraction of data segments to emit as Reed-Solomon parity segments
  # (e.g. 16 data segments -> floor(16 * 0.125) = 2 parity segments).
  SegmentsParityRate = 0.125
  # Maximum total shard count (data + parity) supported by the
  # Reed-Solomon codec; above this, parity generation is skipped.
  SegmentsReedsolomonMaxCount = 256

const
  ## Error strings returned by the segmentation layer. These are part of the
  ## observable contract (callers compare against them) — do not reword.
  ErrMessageSegmentsIncomplete* = "message segments incomplete"
  ErrMessageSegmentsAlreadyCompleted* = "message segments already completed"
  ErrMessageSegmentsInvalidCount* = "invalid segments count"
  ErrMessageSegmentsHashMismatch* = "hash of entire payload does not match"
  ErrMessageSegmentsInvalidParity* = "invalid parity segments"
|
|
|
|
proc isValid(s: SegmentMessageProto): bool =
  ## Structural sanity check for an incoming segment message.
  ## Valid when: the hash is exactly 32 bytes (Keccak-256), any declared
  ## indices fall inside their declared counts, and the message is either a
  ## real multi-segment message (>= 2 segments) or a parity segment.
  let hashOk = s.entireMessageHash.len == 32
  let dataIndexOk = s.segmentsCount == 0 or s.index < s.segmentsCount
  let parityIndexOk =
    s.paritySegmentsCount == 0 or s.paritySegmentIndex < s.paritySegmentsCount
  let countsOk = s.segmentsCount >= 2 or s.paritySegmentsCount > 0
  hashOk and dataIndexOk and parityIndexOk and countsOk
|
|
|
|
proc isParityMessage*(s: SegmentMessage): bool =
  ## True when this segment carries Reed-Solomon parity data only:
  ## parity segments are encoded with segmentsCount == 0 and a positive
  ## paritySegmentsCount (see the parity branch of segmentMessage).
  result = (s.segmentsCount == 0) and (s.paritySegmentsCount > 0)
|
|
|
|
proc segmentMessage*(s: SegmentationHander, newMessage: Chunk): Result[seq[Chunk], string] =
  ## Splits `newMessage` into protobuf-encoded segments of at most
  ## `s.segmentSize` payload bytes each. Messages that already fit in one
  ## segment are returned unchanged. When the total shard count
  ## (data + parity) fits within `SegmentsReedsolomonMaxCount`, Reed-Solomon
  ## parity segments are appended so receivers can recover lost segments.
  ##
  ## Returns the ordered list of chunks (data segments first, parity last),
  ## or an error string if the Reed-Solomon encoder fails.
  # Fast path: no segmentation needed.
  if newMessage.payload.len <= s.segmentSize:
    return ok(@[newMessage])

  let entireMessageHash = keccak256.digest(newMessage.payload)
  let entirePayloadSize = newMessage.payload.len

  let segmentsCount = int(ceil(entirePayloadSize.float / s.segmentSize.float))
  let paritySegmentsCount = int(floor(segmentsCount.float * SegmentsParityRate))

  # Fixed: previous revision logged full payload hex dumps at info level
  # with placeholder keys ("saaa") — sensitive and noisy; log metadata only.
  debug "segmenting message",
    payloadSize = entirePayloadSize,
    segmentSize = s.segmentSize,
    segmentsCount = segmentsCount,
    paritySegmentsCount = paritySegmentsCount

  var segmentPayloads = newSeq[seq[byte]](segmentsCount)
  var segmentMessages = newSeq[Chunk](segmentsCount)

  for i in 0..<segmentsCount:
    let start = i * s.segmentSize
    var endIndex = start + s.segmentSize
    if endIndex > entirePayloadSize:
      endIndex = entirePayloadSize  # last segment may be short

    let segmentPayload = newMessage.payload[start..<endIndex]
    let segmentWithMetadata = SegmentMessageProto(
      entireMessageHash: entireMessageHash.data.toSeq,
      index: uint32(i),
      segmentsCount: uint32(segmentsCount),
      payload: segmentPayload
    )

    segmentPayloads[i] = segmentPayload
    segmentMessages[i] = Chunk(payload: Protobuf.encode(segmentWithMetadata))

  # Skip Reed-Solomon when no parity is wanted or the total shard count
  # exceeds what the codec supports.
  if paritySegmentsCount == 0 or (segmentsCount + paritySegmentsCount) > SegmentsReedsolomonMaxCount:
    return ok(segmentMessages)

  # leopard requires fixed-size shards: zero-pad the (possibly short) last
  # segment for encoding. The padded copy is used only for parity math —
  # the chunk already emitted above keeps the unpadded payload.
  let lastSegmentPayload = segmentPayloads[segmentsCount-1]
  segmentPayloads[segmentsCount-1] = newSeq[byte](s.segmentSize)
  segmentPayloads[segmentsCount-1][0..<lastSegmentPayload.len] = lastSegmentPayload

  # Pre-allocate parity output shards, one per parity segment.
  var parity = newSeq[seq[byte]](paritySegmentsCount)
  for i in 0..<paritySegmentsCount:
    parity[i] = newSeq[byte](s.segmentSize)

  var encoderRes = LeoEncoder.init(s.segmentSize, segmentsCount, paritySegmentsCount)
  if encoderRes.isErr:
    return err("failed to initialize encoder: " & $encoderRes.error)

  var encoder = encoderRes.get
  let encodeResult = encoder.encode(segmentPayloads, parity)
  if encodeResult.isErr:
    return err("failed to encode segments with leopard: " & $encodeResult.error)

  # Emit parity chunks. segmentsCount == 0 marks them as parity-only
  # (see isParityMessage on the receiving side).
  for i in 0..<paritySegmentsCount:
    let segmentWithMetadata = SegmentMessageProto(
      entireMessageHash: entireMessageHash.data.toSeq,
      segmentsCount: 0,
      paritySegmentIndex: uint32(i),
      paritySegmentsCount: uint32(paritySegmentsCount),
      payload: parity[i]
    )
    segmentMessages.add(Chunk(payload: Protobuf.encode(segmentWithMetadata)))

  return ok(segmentMessages)
|
|
|
|
proc handleSegmentationLayer*(handler: SegmentationHander, message: var Message): Result[void, string] =
  ## Processes one incoming segment chunk: validates it, persists it, and —
  ## once all data segments (or enough data + parity segments) are stored —
  ## reassembles the original payload into `message.payload`.
  ##
  ## Returns ErrMessageSegmentsIncomplete until enough segments have arrived;
  ## ok() once the message has been fully reassembled and verified.
  logScope:
    site = "handleSegmentationLayer"
    hash = message.hash.toHex

  # NOTE(review): Protobuf.decode may raise on malformed input; this proc
  # lets that propagate — confirm callers expect an exception rather than err().
  let segmentMessageProto = Protobuf.decode(message.payload, SegmentMessageProto)

  debug "handling message segment",
    entireMessageHash = segmentMessageProto.entireMessageHash.toHex,
    index = segmentMessageProto.index,
    segmentsCount = segmentMessageProto.segmentsCount,
    paritySegmentIndex = segmentMessageProto.paritySegmentIndex,
    paritySegmentsCount = segmentMessageProto.paritySegmentsCount

  # Drop duplicates of messages we have already reassembled.
  let alreadyCompleted = handler.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash)
  if alreadyCompleted.isErr:
    return err(alreadyCompleted.error)
  if alreadyCompleted.get():
    return err(ErrMessageSegmentsAlreadyCompleted)

  if not segmentMessageProto.isValid():
    return err(ErrMessageSegmentsInvalidCount)

  let segmentMessage = SegmentMessage(
    entireMessageHash: segmentMessageProto.entireMessageHash,
    index: segmentMessageProto.index,
    segmentsCount: segmentMessageProto.segmentsCount,
    paritySegmentIndex: segmentMessageProto.paritySegmentIndex,
    paritySegmentsCount: segmentMessageProto.paritySegmentsCount,
    payload: segmentMessageProto.payload
  )

  let saveResult = handler.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix)
  if saveResult.isErr:
    return err(saveResult.error)

  let segments = handler.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey)
  if segments.isErr:
    return err(segments.error)

  if segments.get().len == 0:
    return err("unexpected state: no segments found after save operation")

  # Segments come back ordered: data segments by index first, parity segments
  # last — presumably guaranteed by getMessageSegments; TODO confirm ordering.
  let firstSegmentMessage = segments.get()[0]
  let lastSegmentMessage = segments.get()[^1]

  # Fixed: the previous "last segment" log reported
  # firstSegmentMessage.payload.len as its "len" (copy-paste bug).
  debug "segment window",
    firstIndex = firstSegmentMessage.index,
    segmentsCount = firstSegmentMessage.segmentsCount,
    firstLen = firstSegmentMessage.payload.len,
    lastParityIndex = lastSegmentMessage.paritySegmentIndex,
    paritySegmentsCount = lastSegmentMessage.paritySegmentsCount,
    lastLen = lastSegmentMessage.payload.len,
    storedCount = segments.get().len

  # Incomplete until: we have at least one data segment (first is not parity)
  # AND the total stored shard count (data + parity) equals segmentsCount.
  if firstSegmentMessage.isParityMessage() or segments.get().len != int(firstSegmentMessage.segmentsCount):
    return err(ErrMessageSegmentsIncomplete)

  var payloads = newSeq[seq[byte]](firstSegmentMessage.segmentsCount)
  var parity = newSeq[seq[byte]](lastSegmentMessage.paritySegmentsCount)
  # Shard size for reconstruction — taken from the first stored segment;
  # assumes all data segments (except possibly the last) share this size.
  let payloadSize = firstSegmentMessage.payload.len

  # Parity data is needed iff some data segments are missing, which shows up
  # as the last stored segment being a parity segment.
  let restoreUsingParityData = lastSegmentMessage.isParityMessage()
  if not restoreUsingParityData:
    # All data segments present: place them directly by storage order.
    for i, segment in segments.get():
      payloads[i] = segment.payload
  else:
    # Some data segments are missing; rebuild them with Reed-Solomon.
    var lastNonParitySegmentPayload: seq[byte]
    for segment in segments.get():
      if not segment.isParityMessage():
        if segment.index == firstSegmentMessage.segmentsCount - 1:
          # The final data segment may be shorter than a full shard:
          # zero-pad it for the decoder but remember the original bytes so
          # the padding can be stripped after reconstruction.
          payloads[segment.index] = newSeq[byte](payloadSize)
          payloads[segment.index][0..<segment.payload.len] = segment.payload
          lastNonParitySegmentPayload = segment.payload
        else:
          payloads[segment.index] = segment.payload
      else:
        parity[segment.paritySegmentIndex] = segment.payload

    let decoderRes = LeoDecoder.init(payloadSize, int(firstSegmentMessage.segmentsCount), int(lastSegmentMessage.paritySegmentsCount))
    if decoderRes.isErr:
      return err("failed to initialize LeoDecoder: " & $decoderRes.error)
    var decoder = decoderRes.get()

    # Pre-allocate output shards for the recovered data segments.
    var recovered = newSeq[seq[byte]](int(firstSegmentMessage.segmentsCount))
    for i in 0..<int(firstSegmentMessage.segmentsCount):
      recovered[i] = newSeq[byte](payloadSize)

    let reconstructResult = decoder.decode(payloads, parity, recovered)
    if reconstructResult.isErr:
      return err("failed to reconstruct payloads with leopard: " & $reconstructResult.error)

    # Fill only the holes; segments we actually received stay authoritative.
    for i in 0..<firstSegmentMessage.segmentsCount:
      if payloads[i].len == 0:
        payloads[i] = recovered[i]

    # Strip the zero padding from the final segment if we had its original.
    # NOTE(review): if the final data segment itself was lost, the recovered
    # shard keeps its padding and the hash check below will fail — the true
    # trailing length is not transmitted; confirm whether this is acceptable.
    if lastNonParitySegmentPayload.len > 0:
      payloads[firstSegmentMessage.segmentsCount - 1] = lastNonParitySegmentPayload

  # Reassemble the entire payload in index order.
  var entirePayload = newSeq[byte]()
  for i in 0..<int(firstSegmentMessage.segmentsCount):
    entirePayload.add(payloads[i])

  # Integrity check: the Keccak-256 of the reassembled payload must match the
  # hash carried by every segment.
  let entirePayloadHash = keccak256.digest(entirePayload)
  if entirePayloadHash.data != segmentMessage.entireMessageHash:
    return err(ErrMessageSegmentsHashMismatch)

  let completeResult = handler.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix)
  if completeResult.isErr:
    return err(completeResult.error)

  # Hand the reassembled payload back to the caller in place.
  message.payload = entirePayload
  return ok()
|
|
|
|
|
|
proc cleanupSegments*(s: SegmentationHander): Result[void, string] =
  ## Removes stale, never-completed message segments from persistence.
  ## Currently a stub: no cleanup is performed yet.
  # Bug fix: the previous body was a bare `discard`, leaving the implicit
  # Result zero-initialized — nim-results treats that as the error state,
  # so every caller saw a spurious failure. Return an explicit ok().
  return ok()
|
|
|
|
proc demo() =
  ## Smoke test: verifies that a LeoEncoder and LeoDecoder can both be
  ## initialized with a near-maximal Reed-Solomon configuration.
  const
    bufSize = 64   # byte count per buffer, must be a multiple of 64
    buffers = 239  # number of data symbols
    parity = 17    # number of parity symbols

  let
    encoderRes = LeoEncoder.init(bufSize, buffers, parity)
    decoderRes = LeoDecoder.init(bufSize, buffers, parity)

  assert encoderRes.isOk
  assert decoderRes.isOk
|
|
|
|
|
|
# Run the encoder/decoder smoke test when this file is compiled as the
# main module (no effect when imported as a library).
when isMainModule:
  demo()
|
|
|
|
proc add2*(x, y: int): int =
  ## Returns the sum of `x` and `y`.
  x + y
|