From 9e5462eb9ed6703c9e8bd49453d411ad561e1632 Mon Sep 17 00:00:00 2001 From: Patryk Osmaczko Date: Thu, 4 Apr 2024 19:46:51 +0200 Subject: [PATCH] feat_: introduce forward error correction in segmentation layer closes: #4330 --- protocol/common/message_segmentation.go | 365 ++++++++++++++++++ protocol/common/message_segmentation_test.go | 205 ++++++++++ protocol/common/message_sender.go | 153 +------- protocol/common/raw_messages_persistence.go | 31 +- protocol/migrations/migrations.go | 23 ++ ...5223_add_parity_to_message_segments.up.sql | 19 + protocol/protobuf/segment_message.pb.go | 32 +- protocol/protobuf/segment_message.proto | 4 + 8 files changed, 667 insertions(+), 165 deletions(-) create mode 100644 protocol/common/message_segmentation.go create mode 100644 protocol/common/message_segmentation_test.go create mode 100644 protocol/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql diff --git a/protocol/common/message_segmentation.go b/protocol/common/message_segmentation.go new file mode 100644 index 000000000..e45a0c782 --- /dev/null +++ b/protocol/common/message_segmentation.go @@ -0,0 +1,365 @@ +package common + +import ( + "bytes" + "math" + "time" + + "github.com/golang/protobuf/proto" + "github.com/jinzhu/copier" + "github.com/klauspost/reedsolomon" + "github.com/pkg/errors" + "go.uber.org/zap" + + "github.com/status-im/status-go/eth-node/crypto" + "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/protocol/protobuf" + v1protocol "github.com/status-im/status-go/protocol/v1" +) + +var ErrMessageSegmentsIncomplete = errors.New("message segments incomplete") +var ErrMessageSegmentsAlreadyCompleted = errors.New("message segments already completed") +var ErrMessageSegmentsInvalidCount = errors.New("invalid segments count") +var ErrMessageSegmentsHashMismatch = errors.New("hash of entire payload does not match") +var ErrMessageSegmentsInvalidParity = errors.New("invalid parity segments") + +const ( + segmentsParityRate = 0.125 + segmentsReedsolomonMaxCount = 256 +) + +type SegmentMessage struct { + *protobuf.SegmentMessage +} + +func (s *SegmentMessage) IsValid() bool { + return s.SegmentsCount >= 2 || s.ParitySegmentsCount > 0 +} + +func (s *SegmentMessage) IsParityMessage() bool { + return s.SegmentsCount == 0 && s.ParitySegmentsCount > 0 +} + +func (s *MessageSender) segmentMessage(newMessage *types.NewMessage) ([]*types.NewMessage, error) { + // We set the max message size to 3/4 of the allowed message size, to leave + // room for segment message metadata. + newMessages, err := segmentMessage(newMessage, int(s.transport.MaxMessageSize()/4*3)) + s.logger.Debug("message segmented", zap.Int("segments", len(newMessages))) + return newMessages, err +} + +func replicateMessageWithNewPayload(message *types.NewMessage, payload []byte) (*types.NewMessage, error) { + copy := &types.NewMessage{} + err := copier.Copy(copy, message) + if err != nil { + return nil, err + } + + copy.Payload = payload + copy.PowTarget = calculatePoW(payload) + return copy, nil +} + +// Segments message into smaller chunks if the size exceeds segmentSize. +func segmentMessage(newMessage *types.NewMessage, segmentSize int) ([]*types.NewMessage, error) { + if len(newMessage.Payload) <= segmentSize { + return []*types.NewMessage{newMessage}, nil + } + + entireMessageHash := crypto.Keccak256(newMessage.Payload) + entirePayloadSize := len(newMessage.Payload) + + segmentsCount := int(math.Ceil(float64(entirePayloadSize) / float64(segmentSize))) + paritySegmentsCount := int(math.Floor(float64(segmentsCount) * segmentsParityRate)) + + segmentPayloads := make([][]byte, segmentsCount+paritySegmentsCount) + segmentMessages := make([]*types.NewMessage, segmentsCount) + + for start, index := 0, 0; start < entirePayloadSize; start += segmentSize { + end := start + segmentSize + if end > entirePayloadSize { + end = entirePayloadSize + } + + segmentPayload := newMessage.Payload[start:end] + segmentWithMetadata := &protobuf.SegmentMessage{ + EntireMessageHash: entireMessageHash, + Index: uint32(index), + SegmentsCount: uint32(segmentsCount), + Payload: segmentPayload, + } + marshaledSegmentWithMetadata, err := proto.Marshal(segmentWithMetadata) + if err != nil { + return nil, err + } + segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata) + if err != nil { + return nil, err + } + + segmentPayloads[index] = segmentPayload + segmentMessages[index] = segmentMessage + index++ + } + + // Skip reedsolomon if the combined total of data and parity segments exceeds the predefined limit of segmentsReedsolomonMaxCount. + // Exceeding this limit necessitates shard sizes to be multiples of 64, which are incompatible with clients that do not support forward error correction. + if paritySegmentsCount == 0 || segmentsCount+paritySegmentsCount > segmentsReedsolomonMaxCount { + return segmentMessages, nil + } + + enc, err := reedsolomon.New(segmentsCount, paritySegmentsCount) + if err != nil { + return nil, err + } + + // Align the size of the last segment payload. + lastSegmentPayload := segmentPayloads[segmentsCount-1] + segmentPayloads[segmentsCount-1] = make([]byte, segmentSize) + copy(segmentPayloads[segmentsCount-1], lastSegmentPayload) + + // Make space for parity data. + for i := segmentsCount; i < segmentsCount+paritySegmentsCount; i++ { + segmentPayloads[i] = make([]byte, segmentSize) + } + + err = enc.Encode(segmentPayloads) + if err != nil { + return nil, err + } + + // Create parity messages. + for i, index := segmentsCount, 0; i < segmentsCount+paritySegmentsCount; i++ { + segmentWithMetadata := &protobuf.SegmentMessage{ + EntireMessageHash: entireMessageHash, + SegmentsCount: 0, // indicates parity message + ParitySegmentIndex: uint32(index), + ParitySegmentsCount: uint32(paritySegmentsCount), + Payload: segmentPayloads[i], + } + marshaledSegmentWithMetadata, err := proto.Marshal(segmentWithMetadata) + if err != nil { + return nil, err + } + segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata) + if err != nil { + return nil, err + } + + segmentMessages = append(segmentMessages, segmentMessage) + index++ + } + + return segmentMessages, nil +} + +// SegmentationLayerV1 reconstructs the message only when all segments have been successfully retrieved. +// It lacks the capability to perform forward error correction. +// Kept to test forward compatibility. +func (s *MessageSender) handleSegmentationLayerV1(message *v1protocol.StatusMessage) error { + logger := s.logger.With(zap.String("site", "handleSegmentationLayerV1")).With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String())) + + segmentMessage := &SegmentMessage{ + SegmentMessage: &protobuf.SegmentMessage{}, + } + err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage) + if err != nil { + return errors.Wrap(err, "failed to unmarshal SegmentMessage") + } + + logger.Debug("handling message segment", zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()), + zap.Uint32("Index", segmentMessage.Index), zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount)) + + alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) + if err != nil { + return err + } + if alreadyCompleted { + return ErrMessageSegmentsAlreadyCompleted + } + + if segmentMessage.SegmentsCount < 2 { + return ErrMessageSegmentsInvalidCount + } + + err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) + if err != nil { + return err + } + + segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) + if err != nil { + return err + } + + if len(segments) != int(segmentMessage.SegmentsCount) { + return ErrMessageSegmentsIncomplete + } + + // Combine payload + var entirePayload bytes.Buffer + for _, segment := range segments { + _, err := entirePayload.Write(segment.Payload) + if err != nil { + return errors.Wrap(err, "failed to write segment payload") + } + } + + // Sanity check + entirePayloadHash := crypto.Keccak256(entirePayload.Bytes()) + if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) { + return ErrMessageSegmentsHashMismatch + } + + err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) + if err != nil { + return err + } + + message.TransportLayer.Payload = entirePayload.Bytes() + + return nil +} + +// SegmentationLayerV2 is capable of reconstructing the message from both complete and partial sets of data segments. +// It has capability to perform forward error correction. +func (s *MessageSender) handleSegmentationLayerV2(message *v1protocol.StatusMessage) error { + logger := s.logger.With(zap.String("site", "handleSegmentationLayerV2")).With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String())) + + segmentMessage := &SegmentMessage{ + SegmentMessage: &protobuf.SegmentMessage{}, + } + err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage) + if err != nil { + return errors.Wrap(err, "failed to unmarshal SegmentMessage") + } + + logger.Debug("handling message segment", + zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()), + zap.Uint32("Index", segmentMessage.Index), + zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount), + zap.Uint32("ParitySegmentIndex", segmentMessage.ParitySegmentIndex), + zap.Uint32("ParitySegmentsCount", segmentMessage.ParitySegmentsCount)) + + alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) + if err != nil { + return err + } + if alreadyCompleted { + return ErrMessageSegmentsAlreadyCompleted + } + + if !segmentMessage.IsValid() { + return ErrMessageSegmentsInvalidCount + } + + err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) + if err != nil { + return err + } + + segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) + if err != nil { + return err + } + + if len(segments) == 0 { + return errors.New("unexpected state: no segments found after save operation") // This should theoretically never occur. + } + + firstSegmentMessage := segments[0] + lastSegmentMessage := segments[len(segments)-1] + + // First segment message must not be a parity message. + if firstSegmentMessage.IsParityMessage() || len(segments) != int(firstSegmentMessage.SegmentsCount) { + return ErrMessageSegmentsIncomplete + } + + payloads := make([][]byte, firstSegmentMessage.SegmentsCount+lastSegmentMessage.ParitySegmentsCount) + payloadSize := len(firstSegmentMessage.Payload) + + restoreUsingParityData := lastSegmentMessage.IsParityMessage() + if !restoreUsingParityData { + for i, segment := range segments { + payloads[i] = segment.Payload + } + } else { + enc, err := reedsolomon.New(int(firstSegmentMessage.SegmentsCount), int(lastSegmentMessage.ParitySegmentsCount)) + if err != nil { + return err + } + + var lastNonParitySegmentPayload []byte + for _, segment := range segments { + if !segment.IsParityMessage() { + if segment.Index == firstSegmentMessage.SegmentsCount-1 { + // Ensure last segment is aligned to payload size, as it is required by reedsolomon. + payloads[segment.Index] = make([]byte, payloadSize) + copy(payloads[segment.Index], segment.Payload) + lastNonParitySegmentPayload = segment.Payload + } else { + payloads[segment.Index] = segment.Payload + } + } else { + payloads[firstSegmentMessage.SegmentsCount+segment.ParitySegmentIndex] = segment.Payload + } + } + + err = enc.Reconstruct(payloads) + if err != nil { + return err + } + + ok, err := enc.Verify(payloads) + if err != nil { + return err + } + if !ok { + return ErrMessageSegmentsInvalidParity + } + + if lastNonParitySegmentPayload != nil { + payloads[firstSegmentMessage.SegmentsCount-1] = lastNonParitySegmentPayload // Bring back last segment with original length. + } + } + + // Combine payload. + var entirePayload bytes.Buffer + for i := 0; i < int(firstSegmentMessage.SegmentsCount); i++ { + _, err := entirePayload.Write(payloads[i]) + if err != nil { + return errors.Wrap(err, "failed to write segment payload") + } + } + + // Sanity check. + entirePayloadHash := crypto.Keccak256(entirePayload.Bytes()) + if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) { + return ErrMessageSegmentsHashMismatch + } + + err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) + if err != nil { + return err + } + + message.TransportLayer.Payload = entirePayload.Bytes() + + return nil +} + +func (s *MessageSender) CleanupSegments() error { + monthAgo := time.Now().AddDate(0, -1, 0).Unix() + + err := s.persistence.RemoveMessageSegmentsOlderThan(monthAgo) + if err != nil { + return err + } + + err = s.persistence.RemoveMessageSegmentsCompletedOlderThan(monthAgo) + if err != nil { + return err + } + + return nil +} diff --git a/protocol/common/message_segmentation_test.go b/protocol/common/message_segmentation_test.go new file mode 100644 index 000000000..83a007844 --- /dev/null +++ b/protocol/common/message_segmentation_test.go @@ -0,0 +1,205 @@ +package common + +import ( + "fmt" + "math" + "testing" + + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + "golang.org/x/exp/slices" + + "github.com/status-im/status-go/appdatabase" + "github.com/status-im/status-go/eth-node/crypto" + "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/protocol/sqlite" + "github.com/status-im/status-go/protocol/v1" + "github.com/status-im/status-go/t/helpers" +) + +func TestMessageSegmentationSuite(t *testing.T) { + suite.Run(t, new(MessageSegmentationSuite)) +} + +type MessageSegmentationSuite struct { + suite.Suite + + sender *MessageSender + testPayload []byte + logger *zap.Logger +} + +func (s *MessageSegmentationSuite) SetupSuite() { + s.testPayload = make([]byte, 1000) + for i := 0; i < 1000; i++ { + s.testPayload[i] = byte(i) + } +} + +func (s *MessageSegmentationSuite) SetupTest() { + identity, err := crypto.GenerateKey() + s.Require().NoError(err) + + database, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) + s.Require().NoError(err) + err = sqlite.Migrate(database) + s.Require().NoError(err) + + s.logger, err = zap.NewDevelopment() + s.Require().NoError(err) + + s.sender, err = NewMessageSender( + identity, + database, + nil, + nil, + s.logger, + FeatureFlags{}, + ) + s.Require().NoError(err) +} + +func (s *MessageSegmentationSuite) SetupSubTest() { + s.SetupTest() +} + +func (s *MessageSegmentationSuite) TestHandleSegmentationLayer() { + testCases := []struct { + name string + segmentsCount int + expectedParitySegmentsCount int + retrievedSegments []int + retrievedParitySegments []int + segmentationLayerV1ShouldSucceed bool + segmentationLayerV2ShouldSucceed bool + }{ + { + name: "all segments retrieved", + segmentsCount: 2, + expectedParitySegmentsCount: 0, + retrievedSegments: []int{0, 1}, + retrievedParitySegments: []int{}, + segmentationLayerV1ShouldSucceed: true, + segmentationLayerV2ShouldSucceed: true, + }, + { + name: "all segments retrieved out of order", + segmentsCount: 2, + expectedParitySegmentsCount: 0, + retrievedSegments: []int{1, 0}, + retrievedParitySegments: []int{}, + segmentationLayerV1ShouldSucceed: true, + segmentationLayerV2ShouldSucceed: true, + }, + { + name: "all segments&parity retrieved", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: []int{0, 1, 2, 3, 4, 5, 6, 7, 8}, + retrievedParitySegments: []int{8}, + segmentationLayerV1ShouldSucceed: true, + segmentationLayerV2ShouldSucceed: true, + }, + { + name: "all segments&parity retrieved out of order", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: []int{8, 0, 7, 1, 6, 2, 5, 3, 4}, + retrievedParitySegments: []int{8}, + segmentationLayerV1ShouldSucceed: true, + segmentationLayerV2ShouldSucceed: true, + }, + { + name: "no segments retrieved", + segmentsCount: 2, + expectedParitySegmentsCount: 0, + retrievedSegments: []int{}, + retrievedParitySegments: []int{}, + segmentationLayerV1ShouldSucceed: false, + segmentationLayerV2ShouldSucceed: false, + }, + { + name: "not all needed segments&parity retrieved", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: []int{1, 2, 8}, + retrievedParitySegments: []int{8}, + segmentationLayerV1ShouldSucceed: false, + segmentationLayerV2ShouldSucceed: false, + }, + { + name: "segments&parity retrieved", + segmentsCount: 8, + expectedParitySegmentsCount: 1, + retrievedSegments: []int{1, 2, 3, 4, 5, 6, 7, 8}, + retrievedParitySegments: []int{8}, + segmentationLayerV1ShouldSucceed: false, + segmentationLayerV2ShouldSucceed: true, // succeed even though one segment is missing, thank you reedsolomon + }, + { + name: "segments&parity retrieved out of order", + segmentsCount: 16, + expectedParitySegmentsCount: 2, + retrievedSegments: []int{17, 0, 16, 1, 15, 2, 14, 3, 13, 4, 12, 5, 11, 6, 10, 7}, + retrievedParitySegments: []int{16, 17}, + segmentationLayerV1ShouldSucceed: false, + segmentationLayerV2ShouldSucceed: true, // succeed even though two segments are missing, thank you reedsolomon + }, + } + + for _, version := range []string{"V1", "V2"} { + for _, tc := range testCases { + s.Run(fmt.Sprintf("%s %s", version, tc.name), func() { + segmentedMessages, err := segmentMessage(&types.NewMessage{Payload: s.testPayload}, int(math.Ceil(float64(len(s.testPayload))/float64(tc.segmentsCount)))) + s.Require().NoError(err) + s.Require().Len(segmentedMessages, tc.segmentsCount+tc.expectedParitySegmentsCount) + + message := &protocol.StatusMessage{TransportLayer: protocol.TransportLayer{ + SigPubKey: &s.sender.identity.PublicKey, + }} + + messageRecreated := false + handledSegments := []int{} + + for i, segmentIndex := range tc.retrievedSegments { + s.T().Log("i=", i, "segmentIndex=", segmentIndex) + + message.TransportLayer.Payload = segmentedMessages[segmentIndex].Payload + + if version == "V1" { + err = s.sender.handleSegmentationLayerV1(message) + // V1 is unable to handle parity segment + if slices.Contains(tc.retrievedParitySegments, segmentIndex) { + if len(handledSegments) >= tc.segmentsCount { + s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted) + } else { + s.Require().ErrorIs(err, ErrMessageSegmentsInvalidCount) + } + continue + } + } else { + err = s.sender.handleSegmentationLayerV2(message) + } + + handledSegments = append(handledSegments, segmentIndex) + + if len(handledSegments) < tc.segmentsCount { + s.Require().ErrorIs(err, ErrMessageSegmentsIncomplete) + } else if len(handledSegments) == tc.segmentsCount { + s.Require().NoError(err) + s.Require().ElementsMatch(s.testPayload, message.TransportLayer.Payload) + messageRecreated = true + } else { + s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted) + } + } + + if version == "V1" { + s.Require().Equal(tc.segmentationLayerV1ShouldSucceed, messageRecreated) + } else { + s.Require().Equal(tc.segmentationLayerV2ShouldSucceed, messageRecreated) + } + }) + } + } +} diff --git a/protocol/common/message_sender.go b/protocol/common/message_sender.go index 2bcbd6c3b..c5352b1e0 100644 --- a/protocol/common/message_sender.go +++ b/protocol/common/message_sender.go @@ -1,16 +1,13 @@ package common import ( - "bytes" "context" "crypto/ecdsa" "database/sql" - "math" "sync" "time" "github.com/golang/protobuf/proto" - "github.com/jinzhu/copier" "github.com/pkg/errors" datasyncnode "github.com/status-im/mvds/node" datasyncproto "github.com/status-im/mvds/protobuf" @@ -897,7 +894,7 @@ func (s *MessageSender) handleMessage(wakuMessage *types.Message) (*handleMessag return nil, err } - err = s.handleSegmentationLayer(message) + err = s.handleSegmentationLayerV2(message) if err != nil { hlogger.Debug("failed to handle segmentation layer message", zap.Error(err)) @@ -1281,151 +1278,3 @@ func (s *MessageSender) GetCurrentKeyForGroup(groupID []byte) (*encryption.HashR func (s *MessageSender) GetKeysForGroup(groupID []byte) ([]*encryption.HashRatchetKeyCompatibility, error) { return s.protocol.GetKeysForGroup(groupID) } - -// Segments message into smaller chunks if the size exceeds the maximum allowed -func segmentMessage(newMessage *types.NewMessage, maxSegmentSize int) ([]*types.NewMessage, error) { - if len(newMessage.Payload) <= maxSegmentSize { - return []*types.NewMessage{newMessage}, nil - } - - createSegment := func(chunkPayload []byte) (*types.NewMessage, error) { - copy := &types.NewMessage{} - err := copier.Copy(copy, newMessage) - if err != nil { - return nil, err - } - - copy.Payload = chunkPayload - copy.PowTarget = calculatePoW(chunkPayload) - return copy, nil - } - - entireMessageHash := crypto.Keccak256(newMessage.Payload) - payloadSize := len(newMessage.Payload) - segmentsCount := int(math.Ceil(float64(payloadSize) / float64(maxSegmentSize))) - - var segmentMessages []*types.NewMessage - - for start, index := 0, 0; start < payloadSize; start += maxSegmentSize { - end := start + maxSegmentSize - if end > payloadSize { - end = payloadSize - } - - chunk := newMessage.Payload[start:end] - - segmentMessageProto := &protobuf.SegmentMessage{ - EntireMessageHash: entireMessageHash, - Index: uint32(index), - SegmentsCount: uint32(segmentsCount), - Payload: chunk, - } - chunkPayload, err := proto.Marshal(segmentMessageProto) - if err != nil { - return nil, err - } - segmentMessage, err := createSegment(chunkPayload) - if err != nil { - return nil, err - } - - segmentMessages = append(segmentMessages, segmentMessage) - index++ - } - - return segmentMessages, nil -} - -func (s *MessageSender) segmentMessage(newMessage *types.NewMessage) ([]*types.NewMessage, error) { - // We set the max message size to 3/4 of the allowed message size, to leave - // room for segment message metadata. - newMessages, err := segmentMessage(newMessage, int(s.transport.MaxMessageSize()/4*3)) - s.logger.Debug("message segmented", zap.Int("segments", len(newMessages))) - return newMessages, err -} - -var ErrMessageSegmentsIncomplete = errors.New("message segments incomplete") -var ErrMessageSegmentsAlreadyCompleted = errors.New("message segments already completed") -var ErrMessageSegmentsInvalidCount = errors.New("invalid segments count") -var ErrMessageSegmentsHashMismatch = errors.New("hash of entire payload does not match") - -func (s *MessageSender) handleSegmentationLayer(message *v1protocol.StatusMessage) error { - logger := s.logger.With(zap.String("site", "handleSegmentationLayer")) - hlogger := logger.With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String())) - - var segmentMessage protobuf.SegmentMessage - err := proto.Unmarshal(message.TransportLayer.Payload, &segmentMessage) - if err != nil { - return errors.Wrap(err, "failed to unmarshal SegmentMessage") - } - - hlogger.Debug("handling message segment", zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()), - zap.Uint32("Index", segmentMessage.Index), zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount)) - - alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) - if err != nil { - return err - } - if alreadyCompleted { - return ErrMessageSegmentsAlreadyCompleted - } - - if segmentMessage.SegmentsCount < 2 { - return ErrMessageSegmentsInvalidCount - } - - err = s.persistence.SaveMessageSegment(&segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) - if err != nil { - return err - } - - segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) - if err != nil { - return err - } - - if len(segments) != int(segmentMessage.SegmentsCount) { - return ErrMessageSegmentsIncomplete - } - - // Combine payload - var entirePayload bytes.Buffer - for _, segment := range segments { - _, err := entirePayload.Write(segment.Payload) - if err != nil { - return errors.Wrap(err, "failed to write segment payload") - } - } - - // Sanity check - entirePayloadHash := crypto.Keccak256(entirePayload.Bytes()) - if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) { - return ErrMessageSegmentsHashMismatch - } - - err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) - if err != nil { - return err - } - - message.TransportLayer.Payload = entirePayload.Bytes() - - return nil -} - -func (s *MessageSender) CleanupSegments() error { - weekAgo := time.Now().AddDate(0, 0, -7).Unix() - monthAgo := time.Now().AddDate(0, -1, 0).Unix() - - err := s.persistence.RemoveMessageSegmentsOlderThan(weekAgo) - if err != nil { - return err - } - - err = s.persistence.RemoveMessageSegmentsCompletedOlderThan(monthAgo) - if err != nil { - return err - } - - return nil -} diff --git a/protocol/common/raw_messages_persistence.go b/protocol/common/raw_messages_persistence.go index 5ef37508b..099f1f586 100644 --- a/protocol/common/raw_messages_persistence.go +++ b/protocol/common/raw_messages_persistence.go @@ -347,33 +347,46 @@ func (db *RawMessagesPersistence) IsMessageAlreadyCompleted(hash []byte) (bool, return alreadyCompleted > 0, nil } -func (db *RawMessagesPersistence) SaveMessageSegment(segment *protobuf.SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error { +func (db *RawMessagesPersistence) SaveMessageSegment(segment *SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error { sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) - _, err := db.db.Exec("INSERT INTO message_segments (hash, segment_index, segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?)", - segment.EntireMessageHash, segment.Index, segment.SegmentsCount, sigPubKeyBlob, segment.Payload, timestamp) + _, err := db.db.Exec("INSERT INTO message_segments (hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + segment.EntireMessageHash, segment.Index, segment.SegmentsCount, segment.ParitySegmentIndex, segment.ParitySegmentsCount, sigPubKeyBlob, segment.Payload, timestamp) return err } // Get ordered message segments for given hash -func (db *RawMessagesPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*protobuf.SegmentMessage, error) { +func (db *RawMessagesPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*SegmentMessage, error) { sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) - rows, err := db.db.Query("SELECT hash, segment_index, segments_count, payload FROM message_segments WHERE hash = ? AND sig_pub_key = ? ORDER BY segment_index", hash, sigPubKeyBlob) + rows, err := db.db.Query(` + SELECT + hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload + FROM + message_segments + WHERE + hash = ? AND sig_pub_key = ? + ORDER BY + (segments_count = 0) ASC, -- Prioritize segments_count > 0 + segment_index ASC, + parity_segment_index ASC`, + hash, sigPubKeyBlob) if err != nil { return nil, err } defer rows.Close() - var segments []*protobuf.SegmentMessage + var segments []*SegmentMessage for rows.Next() { - var segment protobuf.SegmentMessage - err := rows.Scan(&segment.EntireMessageHash, &segment.Index, &segment.SegmentsCount, &segment.Payload) + segment := &SegmentMessage{ + SegmentMessage: &protobuf.SegmentMessage{}, + } + err := rows.Scan(&segment.EntireMessageHash, &segment.Index, &segment.SegmentsCount, &segment.ParitySegmentIndex, &segment.ParitySegmentsCount, &segment.Payload) if err != nil { return nil, err } - segments = append(segments, &segment) + segments = append(segments, segment) } err = rows.Err() if err != nil { diff --git a/protocol/migrations/migrations.go b/protocol/migrations/migrations.go index ced922c9a..cbd73ec29 100644 --- a/protocol/migrations/migrations.go +++ b/protocol/migrations/migrations.go @@ -134,6 +134,7 @@ // 1711389881_add_profile_showcase_community_grant.up.sql (86B) // 1711937186_add_contact_customization_color.up.sql (172B) // 1712745141_hash_ratchet_encrypted_messages_key_id.up.sql (111B) +// 1712905223_add_parity_to_message_segments.up.sql (792B) // README.md (554B) // doc.go (870B) @@ -2883,6 +2884,26 @@ func _1712745141_hash_ratchet_encrypted_messages_key_idUpSql() (*asset, error) { return a, nil } +var __1712905223_add_parity_to_message_segmentsUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xac\x92\x3d\x6b\xc3\x30\x10\x86\x77\xfd\x8a\x1b\x13\xd0\x90\x3d\x93\xec\x5e\x8a\xa9\x2c\x99\x8b\x3a\x64\x12\x6e\x2d\x1c\xd3\xf8\x83\x4a\x81\xfa\xdf\x97\x04\x53\x1a\x23\xf7\x03\xb2\xde\x3d\xd2\x3d\xbc\x77\x42\x1a\x24\x30\x22\x91\x08\xad\xf3\xbe\xac\x9d\xf5\xae\x6e\x5d\x17\x3c\x10\x2a\x91\x23\x18\x0d\xfd\xa9\xb2\xf3\xf6\x96\xb1\x94\x50\x18\x5c\x7a\xbe\x62\x00\x00\xc7\xd2\x1f\x21\x91\x3a\x01\xa5\x0d\xa8\x67\x29\xf9\xb5\x3e\x61\xb6\xe9\x2a\xf7\x01\x99\x32\xf8\x88\x14\x67\xbc\x7d\xed\xcf\x5d\x58\x80\x86\x72\x3c\xf5\x65\x15\x9d\xd1\xd4\x76\x38\xbf\xd8\x37\x37\xc6\xda\xa1\x69\x9d\x0f\x65\x3b\x2c\xfe\xfc\xde\x84\xd1\xfe\xc5\xf4\x16\xfd\x59\xb8\xa0\x2c\x17\x74\x80\x27\x3c\xc0\xea\x12\x0f\xff\x2e\xca\x6f\x93\xe1\xb3\x10\x78\xd4\x6a\x5e\x9d\xe0\x35\x68\x05\xa9\x56\x3b\x99\xa5\x06\x08\x0b\x29\x52\x64\xeb\x2d\x63\x99\xda\x23\x99\x8b\xa0\x8e\x2c\x6e\x92\xfa\xcd\xe3\x9a\xfb\x4c\xfe\x2b\xd3\xff\x89\xb2\x3d\x4a\x4c\x0d\xdc\x67\xf2\x86\xc3\x86\xed\x48\xe7\x4b\x87\xfb\x40\xba\x98\xce\x36\x4e\x7c\x06\x00\x00\xff\xff\xfd\x2e\x92\x5d\x18\x03\x00\x00") + +func _1712905223_add_parity_to_message_segmentsUpSqlBytes() ([]byte, error) { + return bindataRead( + __1712905223_add_parity_to_message_segmentsUpSql, + "1712905223_add_parity_to_message_segments.up.sql", + ) +} + +func _1712905223_add_parity_to_message_segmentsUpSql() (*asset, error) { + bytes, err := _1712905223_add_parity_to_message_segmentsUpSqlBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "1712905223_add_parity_to_message_segments.up.sql", size: 792, mode: os.FileMode(0644), modTime: time.Unix(1700000000, 0)} + a := &asset{bytes: bytes, info: info, digest: [32]uint8{0x9, 0x78, 0x5e, 0x84, 0x2b, 0xf9, 0x52, 0x77, 0x7, 0x6c, 0xb6, 0x76, 0x6d, 0x59, 0xd2, 0x1f, 0x6c, 0xe6, 0xd, 0x86, 0x85, 0xeb, 0x34, 0x95, 0x6e, 0xa, 0xa9, 0xd8, 0x3b, 0x7a, 0xd, 0x1a}} + return a, nil +} + var _readmeMd = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x54\x91\xc1\xce\xd3\x30\x10\x84\xef\x7e\x8a\x91\x7a\x01\xa9\x2a\x8f\xc0\x0d\x71\x82\x03\x48\x1c\xc9\x36\x9e\x36\x96\x1c\x6f\xf0\xae\x93\xe6\xed\x91\xa3\xc2\xdf\xff\x66\xed\xd8\x33\xdf\x78\x4f\xa7\x13\xbe\xea\x06\x57\x6c\x35\x39\x31\xa7\x7b\x15\x4f\x5a\xec\x73\x08\xbf\x08\x2d\x79\x7f\x4a\x43\x5b\x86\x17\xfd\x8c\x21\xea\x56\x5e\x47\x90\x4a\x14\x75\x48\xde\x64\x37\x2c\x6a\x96\xae\x99\x48\x05\xf6\x27\x77\x13\xad\x08\xae\x8a\x51\xe7\x25\xf3\xf1\xa9\x9f\xf9\x58\x58\x2c\xad\xbc\xe0\x8b\x56\xf0\x21\x5d\xeb\x4c\x95\xb3\xae\x84\x60\xd4\xdc\xe6\x82\x5d\x1b\x36\x6d\x39\x62\x92\xf5\xb8\x11\xdb\x92\xd3\x28\xce\xe0\x13\xe1\x72\xcd\x3c\x63\xd4\x65\x87\xae\xac\xe8\xc3\x28\x2e\x67\x44\x66\x3a\x21\x25\xa2\x72\xac\x14\x67\xbc\x84\x9f\x53\x32\x8c\x52\x70\x25\x56\xd6\xfd\x8d\x05\x37\xad\x30\x9d\x9f\xa6\x86\x0f\xcd\x58\x7f\xcf\x34\x93\x3b\xed\x90\x9f\xa4\x1f\xcf\x30\x85\x4d\x07\x58\xaf\x7f\x25\xc4\x9d\xf3\x72\x64\x84\xd0\x7f\xf9\x9b\x3a\x2d\x84\xef\x85\x48\x66\x8d\xd8\x88\x9b\x8c\x8c\x98\x5b\xf6\x74\x14\x4e\x33\x0d\xc9\xe0\x93\x38\xda\x12\xc5\x69\xbd\xe4\xf0\x2e\x7a\x78\x07\x1c\xfe\x13\x9f\x91\x29\x31\x95\x7b\x7f\x62\x59\x37\xb4\xe5\x5e\x25\xfe\x33\xee\xd5\x53\x71\xd6\xda\x3a\xd8\xcb\xde\x2e\xf8\xa1\x90\x55\x53\x0c\xc7\xaa\x0d\xe9\x76\x14\x29\x1c\x7b\x68\xdd\x2f\xe1\x6f\x00\x00\x00\xff\xff\x3c\x0a\xc2\xfe\x2a\x02\x00\x00") func readmeMdBytes() ([]byte, error) { @@ -3148,6 +3169,7 @@ var _bindata = map[string]func() (*asset, error){ "1711389881_add_profile_showcase_community_grant.up.sql": _1711389881_add_profile_showcase_community_grantUpSql, "1711937186_add_contact_customization_color.up.sql": _1711937186_add_contact_customization_colorUpSql, "1712745141_hash_ratchet_encrypted_messages_key_id.up.sql": _1712745141_hash_ratchet_encrypted_messages_key_idUpSql, + "1712905223_add_parity_to_message_segments.up.sql": _1712905223_add_parity_to_message_segmentsUpSql, "README.md": readmeMd, "doc.go": docGo, } @@ -3332,6 +3354,7 @@ var _bintree = &bintree{nil, map[string]*bintree{ "1711389881_add_profile_showcase_community_grant.up.sql": {_1711389881_add_profile_showcase_community_grantUpSql, map[string]*bintree{}}, "1711937186_add_contact_customization_color.up.sql": {_1711937186_add_contact_customization_colorUpSql, map[string]*bintree{}}, "1712745141_hash_ratchet_encrypted_messages_key_id.up.sql": {_1712745141_hash_ratchet_encrypted_messages_key_idUpSql, map[string]*bintree{}}, + "1712905223_add_parity_to_message_segments.up.sql": {_1712905223_add_parity_to_message_segmentsUpSql, map[string]*bintree{}}, "README.md": {readmeMd, map[string]*bintree{}}, "doc.go": {docGo, map[string]*bintree{}}, }} diff --git a/protocol/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql b/protocol/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql new file mode 100644 index 000000000..d42158e78 --- /dev/null +++ b/protocol/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql @@ -0,0 +1,19 @@ +ALTER TABLE message_segments RENAME TO old_message_segments; + +CREATE TABLE message_segments ( + hash BLOB NOT NULL, + segment_index INTEGER NOT NULL, + segments_count INTEGER NOT NULL, + payload BLOB NOT NULL, + sig_pub_key BLOB NOT NULL, + timestamp INTEGER NOT NULL, + parity_segment_index INTEGER NOT NULL, + parity_segments_count INTEGER NOT NULL, + PRIMARY KEY (hash, sig_pub_key, segment_index, segments_count, parity_segment_index, parity_segments_count) ON CONFLICT REPLACE +); + +INSERT INTO message_segments (hash, segment_index, segments_count, payload, sig_pub_key, timestamp, parity_segment_index, parity_segments_count) +SELECT hash, segment_index, segments_count, payload, sig_pub_key, timestamp, 0, 0 +FROM old_message_segments; + +DROP TABLE old_message_segments; diff --git a/protocol/protobuf/segment_message.pb.go b/protocol/protobuf/segment_message.pb.go index b593d31a8..ca8d91fab 100644 --- a/protocol/protobuf/segment_message.pb.go +++ b/protocol/protobuf/segment_message.pb.go @@ -33,6 +33,10 @@ type SegmentMessage struct { SegmentsCount uint32 `protobuf:"varint,3,opt,name=segments_count,json=segmentsCount,proto3" json:"segments_count,omitempty"` // The payload data for this particular segment Payload []byte `protobuf:"bytes,4,opt,name=payload,proto3" json:"payload,omitempty"` + // Index of this parity segment + ParitySegmentIndex uint32 `protobuf:"varint,5,opt,name=parity_segment_index,json=paritySegmentIndex,proto3" json:"parity_segment_index,omitempty"` + // Total number of parity segments + ParitySegmentsCount uint32 `protobuf:"varint,6,opt,name=parity_segments_count,json=paritySegmentsCount,proto3" json:"parity_segments_count,omitempty"` } func (x *SegmentMessage) Reset() { @@ -95,12 +99,26 @@ func (x *SegmentMessage) GetPayload() []byte { return nil } +func (x *SegmentMessage) GetParitySegmentIndex() uint32 { + if x != nil { + return x.ParitySegmentIndex + } + return 0 +} + +func (x *SegmentMessage) GetParitySegmentsCount() uint32 { + if x != nil { + return x.ParitySegmentsCount + } + return 0 +} + var File_segment_message_proto protoreflect.FileDescriptor var file_segment_message_proto_rawDesc = []byte{ 0x0a, 0x15, 0x73, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x22, 0x97, 0x01, 0x0a, 0x0e, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x4d, 0x65, 0x73, + 0x66, 0x22, 0xfd, 0x01, 0x0a, 0x0e, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x2e, 0x0a, 0x13, 0x65, 0x6e, 0x74, 0x69, 0x72, 0x65, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x68, 0x61, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x65, 0x6e, 0x74, 0x69, 0x72, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, @@ -109,9 +127,15 @@ var file_segment_message_proto_rawDesc = []byte{ 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0d, 0x73, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x42, 0x0d, 0x5a, 0x0b, 0x2e, - 0x2f, 0x3b, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x30, 0x0a, 0x14, 0x70, + 0x61, 0x72, 0x69, 0x74, 0x79, 0x5f, 0x73, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x6e, + 0x64, 0x65, 0x78, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x12, 0x70, 0x61, 0x72, 0x69, 0x74, + 0x79, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x12, 0x32, 0x0a, + 0x15, 0x70, 0x61, 0x72, 0x69, 0x74, 0x79, 0x5f, 0x73, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x73, + 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x13, 0x70, 0x61, + 0x72, 0x69, 0x74, 0x79, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x43, 0x6f, 0x75, 0x6e, + 0x74, 0x42, 0x0d, 0x5a, 0x0b, 0x2e, 0x2f, 0x3b, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/protocol/protobuf/segment_message.proto b/protocol/protobuf/segment_message.proto index 8b9bbc20e..2de195056 100644 --- a/protocol/protobuf/segment_message.proto +++ b/protocol/protobuf/segment_message.proto @@ -12,4 +12,8 @@ message SegmentMessage { uint32 segments_count = 3; // The payload data for this particular segment bytes payload = 4; + // Index of this parity segment + uint32 parity_segment_index = 5; + // Total number of parity segments + uint32 parity_segments_count = 6; }