chore: refactor data types

This commit is contained in:
kaichaosun 2025-07-16 14:47:47 +08:00
parent 114f2d1473
commit 211f8275db
No known key found for this signature in database
GPG Key ID: 223E0F992F4F03BF
2 changed files with 46 additions and 59 deletions

View File

@ -118,4 +118,4 @@ proc completeMessageSegments*(
try: try:
self.db.exec(sql"ROLLBACK") self.db.exec(sql"ROLLBACK")
except: discard except: discard
return err("complete segment messages with error: " & e.msg) return err("complete segment messages with error: " & e.msg)

View File

@ -4,19 +4,12 @@ import results
import leopard import leopard
import chronicles import chronicles
# External dependencies (still needed) import protobuf_serialization/proto_parser
# import protobuf # Nim protobuf library (e.g., protobuf-nim) import protobuf_serialization
# SegmentMessage type (unchanged) import db_models
type
SegmentMessage* = ref object
entireMessageHash*: seq[byte]
index*: uint32
segmentsCount*: uint32
paritySegmentIndex*: uint32
paritySegmentsCount*: uint32
payload*: seq[byte]
import_proto3 "segment_message.proto"
# Placeholder types (unchanged) # Placeholder types (unchanged)
type type
@ -24,10 +17,7 @@ type
payload*: seq[byte] payload*: seq[byte]
# Add other fields as needed # Add other fields as needed
StatusMessage* = object Message* = object
transportLayer*: TransportLayer
TransportLayer* = object
hash*: seq[byte] hash*: seq[byte]
payload*: seq[byte] payload*: seq[byte]
sigPubKey*: seq[byte] sigPubKey*: seq[byte]
@ -50,9 +40,21 @@ const
SegmentsReedsolomonMaxCount = 256 SegmentsReedsolomonMaxCount = 256
# Validation methods (unchanged) proc isValid(s: SegmentMessageProto): bool =
proc isValid*(s: SegmentMessage): bool = # Check if the hash length is valid (32 bytes for Keccak256)
s.segmentsCount >= 2 or s.paritySegmentsCount > 0 if s.entire_message_hash.len != 32:
return false
# Check if the segment index is within the valid range
if s.segments_count > 0 and s.index >= s.segments_count:
return false
# Check if the parity segment index is within the valid range
if s.parity_segments_count > 0 and s.parity_segment_index >= s.parity_segments_count:
return false
# Check if segments_count is at least 2 or parity_segments_count is positive
return s.segments_count >= 2 or s.parity_segments_count > 0
proc isParityMessage*(s: SegmentMessage): bool = proc isParityMessage*(s: SegmentMessage): bool =
s.segmentsCount == 0 and s.paritySegmentsCount > 0 s.segmentsCount == 0 and s.paritySegmentsCount > 0
@ -100,16 +102,6 @@ proc protoMarshal(msg: SegmentMessage): Result[seq[byte], string] =
# Fake serialization (index + payload length) TODO # Fake serialization (index + payload length) TODO
return ok(@[byte(msg.index)] & msg.payload) return ok(@[byte(msg.index)] & msg.payload)
proc protoUnmarshal(data: seq[byte], msg: var SegmentMessage): Result[void, string] =
# Fake deserialization (reconstruct index and payload)
if data.len < 1:
return err("data too short")
msg.index = uint32(data[0])
msg.payload = data[1..^1]
msg.segmentsCount = 2
msg.entireMessageHash = @[byte 1, 2, 3] # Dummy hash
return ok()
# Segment message into smaller chunks (updated with nim-leopard) # Segment message into smaller chunks (updated with nim-leopard)
proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] = proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Result[seq[WakuNewMessage], string] =
if newMessage.payload.len <= segmentSize: if newMessage.payload.len <= segmentSize:
@ -196,38 +188,43 @@ proc segmentMessageInternal(newMessage: WakuNewMessage, segmentSize: int): Resul
return ok(segmentMessages) return ok(segmentMessages)
proc handleSegmentationLayer*(s: MessageSender, message: var StatusMessage): Result[void, string] = proc handleSegmentationLayer*(s: MessageSender, message: var Message): Result[void, string] =
logScope: logScope:
site = "handleSegmentationLayerV2" site = "handleSegmentationLayer"
hash = message.transportLayer.hash.toHex hash = message.hash.toHex
var segmentMessage = SegmentMessage() let segmentMessageProto = Protobuf.decode(message.payload, SegmentMessageProto)
let unmarshalResult = protoUnmarshal(message.transportLayer.payload, segmentMessage)
if unmarshalResult.isErr:
return err("failed to unmarshal SegmentMessage: " & unmarshalResult.error)
debug "handling message segment", debug "handling message segment",
"EntireMessageHash" = segmentMessage.entireMessageHash.toHex, "EntireMessageHash" = segmentMessageProto.entireMessageHash.toHex,
"Index" = $segmentMessage.index, "Index" = $segmentMessageProto.index,
"SegmentsCount" = $segmentMessage.segmentsCount, "SegmentsCount" = $segmentMessageProto.segmentsCount,
"ParitySegmentIndex" = $segmentMessage.paritySegmentIndex, "ParitySegmentIndex" = $segmentMessageProto.paritySegmentIndex,
"ParitySegmentsCount" = $segmentMessage.paritySegmentsCount "ParitySegmentsCount" = $segmentMessageProto.paritySegmentsCount
# TODO here use the mock function, real function should be used together with the persistence layer # TODO here use the mock function, real function should be used together with the persistence layer
let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessage.entireMessageHash) let alreadyCompleted = s.persistence.isMessageAlreadyCompleted(segmentMessageProto.entireMessageHash)
if alreadyCompleted.isErr: if alreadyCompleted.isErr:
return err(alreadyCompleted.error) return err(alreadyCompleted.error)
if alreadyCompleted.get(): if alreadyCompleted.get():
return err(ErrMessageSegmentsAlreadyCompleted) return err(ErrMessageSegmentsAlreadyCompleted)
if not segmentMessage.isValid(): if not segmentMessageProto.isValid():
return err(ErrMessageSegmentsInvalidCount) return err(ErrMessageSegmentsInvalidCount)
let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.transportLayer.sigPubKey, getTime().toUnix) let segmentMessage = SegmentMessage(
entireMessageHash: segmentMessageProto.entireMessageHash,
index: segmentMessageProto.index,
segmentsCount: segmentMessageProto.segmentsCount,
paritySegmentIndex: segmentMessageProto.paritySegmentIndex,
paritySegmentsCount: segmentMessageProto.paritySegmentsCount,
payload: segmentMessageProto.payload
)
let saveResult = s.persistence.saveMessageSegment(segmentMessage, message.sigPubKey, getTime().toUnix)
if saveResult.isErr: if saveResult.isErr:
return err(saveResult.error) return err(saveResult.error)
let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.transportLayer.sigPubKey) let segments = s.persistence.getMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey)
if segments.isErr: if segments.isErr:
return err(segments.error) return err(segments.error)
@ -295,11 +292,11 @@ proc handleSegmentationLayer*(s: MessageSender, message: var StatusMessage): Res
if entirePayloadHash.data != segmentMessage.entireMessageHash: if entirePayloadHash.data != segmentMessage.entireMessageHash:
return err(ErrMessageSegmentsHashMismatch) return err(ErrMessageSegmentsHashMismatch)
let completeResult = s.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.transportLayer.sigPubKey, getTime().toUnix) let completeResult = s.persistence.completeMessageSegments(segmentMessage.entireMessageHash, message.sigPubKey, getTime().toUnix)
if completeResult.isErr: if completeResult.isErr:
return err(completeResult.error) return err(completeResult.error)
message.transportLayer.payload = entirePayload message.payload = entirePayload
return ok() return ok()
@ -317,17 +314,7 @@ proc segmentMessage*(s: MessageSender, newMessage: WakuNewMessage): Result[seq[W
debug "message segmented", "segments" = $messages.len debug "message segmented", "segments" = $messages.len
return ok(messages) return ok(messages)
proc demoRounding() = proc demo() =
let x = 3.7
let y = -3.7
echo "ceil(", x, ") = ", ceil(x) # Rounds up
echo "floor(", x, ") = ", floor(x) # Rounds down
echo "round(", x, ") = ", round(x) # Rounds to nearest integer
echo "trunc(", x, ") = ", trunc(x) # Truncates decimal part
echo "ceil(", y, ") = ", ceil(y)
echo "floor(", y, ") = ", floor(y)
let let
bufSize = 64 # byte count per buffer, must be a multiple of 64 bufSize = 64 # byte count per buffer, must be a multiple of 64
buffers = 239 # number of data symbols buffers = 239 # number of data symbols
@ -342,7 +329,7 @@ proc demoRounding() =
when isMainModule: when isMainModule:
demoRounding() demo()
proc add2*(x, y: int): int = proc add2*(x, y: int): int =
## Adds two numbers together. ## Adds two numbers together.