feat: introduce messages segmentation

This commit is contained in:
Patryk Osmaczko 2023-11-09 21:36:57 +01:00 committed by osmaczko
parent fa44e03ac2
commit 6bb806caad
8 changed files with 775 additions and 455 deletions

View File

@ -1,13 +1,16 @@
package common
import (
"bytes"
"context"
"crypto/ecdsa"
"database/sql"
"math"
"sync"
"time"
"github.com/golang/protobuf/proto"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
datasyncnode "github.com/vacp2p/mvds/node"
datasyncproto "github.com/vacp2p/mvds/protobuf"
@ -756,6 +759,12 @@ func (s *MessageSender) HandleMessages(wakuMessage *types.Message) ([]*v1protoco
return nil, nil, s.persistence.SaveHashRatchetMessage(info.GroupID, info.KeyID, wakuMessage)
}
// The current message segment has been successfully retrieved.
// However, the collection of segments is not yet complete.
if err == ErrMessageSegmentsIncomplete {
return nil, nil, nil
}
return nil, nil, err
}
statusMessages = append(statusMessages, response.Messages()...)
@ -820,6 +829,20 @@ func (s *MessageSender) handleMessage(wakuMessage *types.Message) (*handleMessag
return nil, err
}
err = s.handleSegmentationLayer(response.Message)
if err != nil {
hlogger.Debug("failed to handle segmentation layer message", zap.Error(err))
// Segments not completed yet, stop processing
if err == ErrMessageSegmentsIncomplete {
return nil, err
}
// Segments already completed, stop processing
if err == ErrMessageSegmentsAlreadyCompleted {
return nil, err
}
}
err = s.handleEncryptionLayer(context.Background(), response.Message)
if err != nil {
hlogger.Debug("failed to handle an encryption message", zap.Error(err))
@ -1178,3 +1201,126 @@ func (s *MessageSender) GetCurrentKeyForGroup(groupID []byte) (*encryption.HashR
func (s *MessageSender) GetKeysForGroup(groupID []byte) ([]*encryption.HashRatchetKeyCompatibility, error) {
return s.protocol.GetKeysForGroup(groupID)
}
// Segments message into smaller chunks if the size exceeds the maximum allowed
func segmentMessage(newMessage *types.NewMessage, maxSegmentSize int) ([]*types.NewMessage, error) {
if len(newMessage.Payload) <= maxSegmentSize {
return []*types.NewMessage{newMessage}, nil
}
createSegment := func(chunkPayload []byte) (*types.NewMessage, error) {
copy := &types.NewMessage{}
err := copier.Copy(copy, newMessage)
if err != nil {
return nil, err
}
copy.Payload = chunkPayload
copy.PowTarget = calculatePoW(chunkPayload)
return copy, nil
}
entireMessageHash := crypto.Keccak256(newMessage.Payload)
payloadSize := len(newMessage.Payload)
segmentsCount := int(math.Ceil(float64(payloadSize) / float64(maxSegmentSize)))
var segmentMessages []*types.NewMessage
for start, index := 0, 0; start < payloadSize; start += maxSegmentSize {
end := start + maxSegmentSize
if end > payloadSize {
end = payloadSize
}
chunk := newMessage.Payload[start:end]
segmentMessageProto := &protobuf.SegmentMessage{
EntireMessageHash: entireMessageHash,
Index: uint32(index),
SegmentsCount: uint32(segmentsCount),
Payload: chunk,
}
chunkPayload, err := proto.Marshal(segmentMessageProto)
if err != nil {
return nil, err
}
segmentMessage, err := createSegment(chunkPayload)
if err != nil {
return nil, err
}
segmentMessages = append(segmentMessages, segmentMessage)
index++
}
return segmentMessages, nil
}
var ErrMessageSegmentsIncomplete = errors.New("message segments incomplete")
var ErrMessageSegmentsAlreadyCompleted = errors.New("message segments already completed")
var ErrMessageSegmentsInvalidCount = errors.New("invalid segments count")
var ErrMessageSegmentsHashMismatch = errors.New("hash of entire payload does not match")
func (s *MessageSender) handleSegmentationLayer(message *v1protocol.StatusMessage) error {
logger := s.logger.With(zap.String("site", "handleSegmentationLayer"))
hlogger := logger.With(zap.ByteString("hash", message.TransportLayer.Hash))
var segmentMessage protobuf.SegmentMessage
err := proto.Unmarshal(message.TransportLayer.Payload, &segmentMessage)
if err != nil {
return errors.Wrap(err, "failed to unmarshal SegmentMessage")
}
hlogger.Debug("handling message segment", zap.ByteString("EntireMessageHash", segmentMessage.EntireMessageHash),
zap.Uint32("Index", segmentMessage.Index), zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount))
alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash)
if err != nil {
return err
}
if alreadyCompleted {
return ErrMessageSegmentsAlreadyCompleted
}
if segmentMessage.SegmentsCount < 2 {
return ErrMessageSegmentsInvalidCount
}
err = s.persistence.SaveMessageSegment(&segmentMessage, message.TransportLayer.SigPubKey)
if err != nil {
return err
}
segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey)
if err != nil {
return err
}
if len(segments) != int(segmentMessage.SegmentsCount) {
return ErrMessageSegmentsIncomplete
}
// Combine payload
var entirePayload bytes.Buffer
for _, segment := range segments {
_, err := entirePayload.Write(segment.Payload)
if err != nil {
return errors.Wrap(err, "failed to write segment payload")
}
}
// Sanity check
entirePayloadHash := crypto.Keccak256(entirePayload.Bytes())
if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) {
return ErrMessageSegmentsHashMismatch
}
err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey)
if err != nil {
return err
}
message.TransportLayer.Payload = entirePayload.Bytes()
return nil
}

View File

@ -1,6 +1,7 @@
package common
import (
"math"
"testing"
transport2 "github.com/status-im/status-go/protocol/transport"
@ -304,3 +305,44 @@ func (s *MessageSenderSuite) TestHandleOutOfOrderHashRatchet() {
s.Require().Len(msgs, 0)
}
func (s *MessageSenderSuite) TestHandleSegmentMessages() {
relayerKey, err := crypto.GenerateKey()
s.Require().NoError(err)
authorKey, err := crypto.GenerateKey()
s.Require().NoError(err)
encodedPayload, err := proto.Marshal(&s.testMessage)
s.Require().NoError(err)
wrappedPayload, err := v1protocol.WrapMessageV1(encodedPayload, protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, authorKey)
s.Require().NoError(err)
segmentedMessages, err := segmentMessage(&types.NewMessage{Payload: wrappedPayload}, int(math.Ceil(float64(len(wrappedPayload))/2)))
s.Require().NoError(err)
s.Require().Len(segmentedMessages, 2)
message := &types.Message{}
message.Sig = crypto.FromECDSAPub(&relayerKey.PublicKey)
message.Payload = segmentedMessages[0].Payload
// First segment is received, no messages are decoded
decodedMessages, _, err := s.sender.HandleMessages(message)
s.Require().NoError(err)
s.Require().Len(decodedMessages, 0)
// Second (and final) segment is received, reassembled message is decoded
message.Payload = segmentedMessages[1].Payload
decodedMessages, _, err = s.sender.HandleMessages(message)
s.Require().NoError(err)
s.Require().Len(decodedMessages, 1)
s.Require().Equal(&authorKey.PublicKey, decodedMessages[0].SigPubKey())
s.Require().Equal(v1protocol.MessageID(&authorKey.PublicKey, wrappedPayload), decodedMessages[0].ApplicationLayer.ID)
s.Require().Equal(encodedPayload, decodedMessages[0].ApplicationLayer.Payload)
s.Require().Equal(protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, decodedMessages[0].ApplicationLayer.Type)
// Receiving another segment after the message has been reassembled is considered an error
_, _, err = s.sender.HandleMessages(message)
s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted)
}

View File

@ -3,6 +3,7 @@ package common
import (
"bytes"
"context"
"crypto/ecdsa"
"database/sql"
"encoding/gob"
"strings"
@ -336,3 +337,78 @@ func (db RawMessagesPersistence) DeleteHashRatchetMessages(ids [][]byte) error {
return err
}
func (db *RawMessagesPersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) {
var alreadyCompleted int
err := db.db.QueryRow("SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash).Scan(&alreadyCompleted)
if err != nil {
return false, err
}
return alreadyCompleted > 0, nil
}
func (db *RawMessagesPersistence) SaveMessageSegment(segment *protobuf.SegmentMessage, sigPubKey *ecdsa.PublicKey) error {
sigPubKeyBlob := crypto.CompressPubkey(sigPubKey)
_, err := db.db.Exec("INSERT INTO message_segments (hash, segment_index, segments_count, sig_pub_key, payload) VALUES (?, ?, ?, ?, ?)",
segment.EntireMessageHash, segment.Index, segment.SegmentsCount, sigPubKeyBlob, segment.Payload)
return err
}
// Get ordered message segments for given hash
func (db *RawMessagesPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*protobuf.SegmentMessage, error) {
sigPubKeyBlob := crypto.CompressPubkey(sigPubKey)
rows, err := db.db.Query("SELECT hash, segment_index, segments_count, payload FROM message_segments WHERE hash = ? AND sig_pub_key = ? ORDER BY segment_index", hash, sigPubKeyBlob)
if err != nil {
return nil, err
}
defer rows.Close()
var segments []*protobuf.SegmentMessage
for rows.Next() {
var segment protobuf.SegmentMessage
err := rows.Scan(&segment.EntireMessageHash, &segment.Index, &segment.SegmentsCount, &segment.Payload)
if err != nil {
return nil, err
}
segments = append(segments, &segment)
}
err = rows.Err()
if err != nil {
return nil, err
}
return segments, nil
}
func (db *RawMessagesPersistence) CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) error {
tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return err
}
defer func() {
if err == nil {
err = tx.Commit()
return
}
// don't shadow original error
_ = tx.Rollback()
}()
sigPubKeyBlob := crypto.CompressPubkey(sigPubKey)
_, err = tx.Exec("DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hash, sigPubKeyBlob)
if err != nil {
return err
}
_, err = tx.Exec("INSERT INTO message_segments_completed (hash, sig_pub_key) VALUES (?,?)", hash, sigPubKeyBlob)
if err != nil {
return err
}
return err
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,14 @@
CREATE TABLE IF NOT EXISTS message_segments (
hash BLOB NOT NULL,
segment_index INTEGER NOT NULL,
segments_count INTEGER NOT NULL,
payload BLOB NOT NULL,
sig_pub_key BLOB NOT NULL,
PRIMARY KEY (hash, sig_pub_key, segment_index) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS message_segments_completed (
hash BLOB NOT NULL,
sig_pub_key BLOB NOT NULL,
PRIMARY KEY (hash, sig_pub_key)
);

View File

@ -0,0 +1,111 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: segment_message.proto
package protobuf
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type SegmentMessage struct {
// hash of the entire original message
EntireMessageHash []byte `protobuf:"bytes,1,opt,name=entire_message_hash,json=entireMessageHash,proto3" json:"entire_message_hash,omitempty"`
// Index of this segment within the entire original message
Index uint32 `protobuf:"varint,2,opt,name=index,proto3" json:"index,omitempty"`
// Total number of segments the entire original message is divided into
SegmentsCount uint32 `protobuf:"varint,3,opt,name=segments_count,json=segmentsCount,proto3" json:"segments_count,omitempty"`
// The payload data for this particular segment
Payload []byte `protobuf:"bytes,4,opt,name=payload,proto3" json:"payload,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SegmentMessage) Reset() { *m = SegmentMessage{} }
func (m *SegmentMessage) String() string { return proto.CompactTextString(m) }
func (*SegmentMessage) ProtoMessage() {}
func (*SegmentMessage) Descriptor() ([]byte, []int) {
return fileDescriptor_857302809a887a8b, []int{0}
}
func (m *SegmentMessage) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SegmentMessage.Unmarshal(m, b)
}
func (m *SegmentMessage) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SegmentMessage.Marshal(b, m, deterministic)
}
func (m *SegmentMessage) XXX_Merge(src proto.Message) {
xxx_messageInfo_SegmentMessage.Merge(m, src)
}
func (m *SegmentMessage) XXX_Size() int {
return xxx_messageInfo_SegmentMessage.Size(m)
}
func (m *SegmentMessage) XXX_DiscardUnknown() {
xxx_messageInfo_SegmentMessage.DiscardUnknown(m)
}
var xxx_messageInfo_SegmentMessage proto.InternalMessageInfo
func (m *SegmentMessage) GetEntireMessageHash() []byte {
if m != nil {
return m.EntireMessageHash
}
return nil
}
func (m *SegmentMessage) GetIndex() uint32 {
if m != nil {
return m.Index
}
return 0
}
func (m *SegmentMessage) GetSegmentsCount() uint32 {
if m != nil {
return m.SegmentsCount
}
return 0
}
func (m *SegmentMessage) GetPayload() []byte {
if m != nil {
return m.Payload
}
return nil
}
func init() {
proto.RegisterType((*SegmentMessage)(nil), "protobuf.SegmentMessage")
}
func init() {
proto.RegisterFile("segment_message.proto", fileDescriptor_857302809a887a8b)
}
var fileDescriptor_857302809a887a8b = []byte{
// 169 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2d, 0x4e, 0x4d, 0xcf,
0x4d, 0xcd, 0x2b, 0x89, 0xcf, 0x4d, 0x2d, 0x2e, 0x4e, 0x4c, 0x4f, 0xd5, 0x2b, 0x28, 0xca, 0x2f,
0xc9, 0x17, 0xe2, 0x00, 0x53, 0x49, 0xa5, 0x69, 0x4a, 0xd3, 0x19, 0xb9, 0xf8, 0x82, 0x21, 0x6a,
0x7c, 0x21, 0x4a, 0x84, 0xf4, 0xb8, 0x84, 0x53, 0xf3, 0x4a, 0x32, 0x8b, 0x52, 0x61, 0x9a, 0xe2,
0x33, 0x12, 0x8b, 0x33, 0x24, 0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x04, 0x21, 0x52, 0x50, 0xb5,
0x1e, 0x89, 0xc5, 0x19, 0x42, 0x22, 0x5c, 0xac, 0x99, 0x79, 0x29, 0xa9, 0x15, 0x12, 0x4c, 0x0a,
0x8c, 0x1a, 0xbc, 0x41, 0x10, 0x8e, 0x90, 0x2a, 0x17, 0x1f, 0xd4, 0xee, 0xe2, 0xf8, 0xe4, 0xfc,
0xd2, 0xbc, 0x12, 0x09, 0x66, 0xb0, 0x34, 0x2f, 0x4c, 0xd4, 0x19, 0x24, 0x28, 0x24, 0xc1, 0xc5,
0x5e, 0x90, 0x58, 0x99, 0x93, 0x9f, 0x98, 0x22, 0xc1, 0x02, 0xb6, 0x00, 0xc6, 0x75, 0xe2, 0x8d,
0xe2, 0xd6, 0xd3, 0xb7, 0x86, 0x39, 0x34, 0x89, 0x0d, 0xcc, 0x32, 0x06, 0x04, 0x00, 0x00, 0xff,
0xff, 0x12, 0x40, 0x55, 0x2e, 0xd2, 0x00, 0x00, 0x00,
}

View File

@ -0,0 +1,15 @@
syntax = "proto3";
option go_package = "./;protobuf";
package protobuf;
message SegmentMessage {
// hash of the entire original message
bytes entire_message_hash = 1;
// Index of this segment within the entire original message
uint32 index = 2;
// Total number of segments the entire original message is divided into
uint32 segments_count = 3;
// The payload data for this particular segment
bytes payload = 4;
}

View File

@ -4,7 +4,7 @@ import (
"github.com/golang/protobuf/proto"
)
//go:generate protoc --go_out=. ./chat_message.proto ./application_metadata_message.proto ./membership_update_message.proto ./command.proto ./contact.proto ./pairing.proto ./push_notifications.proto ./emoji_reaction.proto ./enums.proto ./shard.proto ./group_chat_invitation.proto ./chat_identity.proto ./communities.proto ./pin_message.proto ./anon_metrics.proto ./status_update.proto ./sync_settings.proto ./contact_verification.proto ./community_update.proto ./community_shard_key.proto ./url_data.proto ./community_privileged_user_sync_message.proto ./profile_showcase.proto
//go:generate protoc --go_out=. ./chat_message.proto ./application_metadata_message.proto ./membership_update_message.proto ./command.proto ./contact.proto ./pairing.proto ./push_notifications.proto ./emoji_reaction.proto ./enums.proto ./shard.proto ./group_chat_invitation.proto ./chat_identity.proto ./communities.proto ./pin_message.proto ./anon_metrics.proto ./status_update.proto ./sync_settings.proto ./contact_verification.proto ./community_update.proto ./community_shard_key.proto ./url_data.proto ./community_privileged_user_sync_message.proto ./profile_showcase.proto ./segment_message.proto
func Unmarshal(payload []byte) (*ApplicationMetadataMessage, error) {
var message ApplicationMetadataMessage