package common import ( "bytes" "math" "time" "github.com/golang/protobuf/proto" "github.com/jinzhu/copier" "github.com/klauspost/reedsolomon" "github.com/pkg/errors" "go.uber.org/zap" "github.com/status-im/status-go/eth-node/crypto" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/protocol/protobuf" v1protocol "github.com/status-im/status-go/protocol/v1" ) var ErrMessageSegmentsIncomplete = errors.New("message segments incomplete") var ErrMessageSegmentsAlreadyCompleted = errors.New("message segments already completed") var ErrMessageSegmentsInvalidCount = errors.New("invalid segments count") var ErrMessageSegmentsHashMismatch = errors.New("hash of entire payload does not match") var ErrMessageSegmentsInvalidParity = errors.New("invalid parity segments") const ( segmentsParityRate = 0.125 segmentsReedsolomonMaxCount = 256 ) type SegmentMessage struct { *protobuf.SegmentMessage } func (s *SegmentMessage) IsValid() bool { return s.SegmentsCount >= 2 || s.ParitySegmentsCount > 0 } func (s *SegmentMessage) IsParityMessage() bool { return s.SegmentsCount == 0 && s.ParitySegmentsCount > 0 } func (s *MessageSender) segmentMessage(newMessage *types.NewMessage) ([]*types.NewMessage, error) { // We set the max message size to 3/4 of the allowed message size, to leave // room for segment message metadata. newMessages, err := segmentMessage(newMessage, int(s.transport.MaxMessageSize()/4*3)) s.logger.Debug("message segmented", zap.Int("segments", len(newMessages))) return newMessages, err } func replicateMessageWithNewPayload(message *types.NewMessage, payload []byte) (*types.NewMessage, error) { copy := &types.NewMessage{} err := copier.Copy(copy, message) if err != nil { return nil, err } copy.Payload = payload copy.PowTarget = calculatePoW(payload) return copy, nil } // Segments message into smaller chunks if the size exceeds segmentSize. func segmentMessage(newMessage *types.NewMessage, segmentSize int) ([]*types.NewMessage, error) { if len(newMessage.Payload) <= segmentSize { return []*types.NewMessage{newMessage}, nil } entireMessageHash := crypto.Keccak256(newMessage.Payload) entirePayloadSize := len(newMessage.Payload) segmentsCount := int(math.Ceil(float64(entirePayloadSize) / float64(segmentSize))) paritySegmentsCount := int(math.Floor(float64(segmentsCount) * segmentsParityRate)) segmentPayloads := make([][]byte, segmentsCount+paritySegmentsCount) segmentMessages := make([]*types.NewMessage, segmentsCount) for start, index := 0, 0; start < entirePayloadSize; start += segmentSize { end := start + segmentSize if end > entirePayloadSize { end = entirePayloadSize } segmentPayload := newMessage.Payload[start:end] segmentWithMetadata := &protobuf.SegmentMessage{ EntireMessageHash: entireMessageHash, Index: uint32(index), SegmentsCount: uint32(segmentsCount), Payload: segmentPayload, } marshaledSegmentWithMetadata, err := proto.Marshal(segmentWithMetadata) if err != nil { return nil, err } segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata) if err != nil { return nil, err } segmentPayloads[index] = segmentPayload segmentMessages[index] = segmentMessage index++ } // Skip reedsolomon if the combined total of data and parity segments exceeds the predefined limit of segmentsReedsolomonMaxCount. // Exceeding this limit necessitates shard sizes to be multiples of 64, which are incompatible with clients that do not support forward error correction. if paritySegmentsCount == 0 || segmentsCount+paritySegmentsCount > segmentsReedsolomonMaxCount { return segmentMessages, nil } enc, err := reedsolomon.New(segmentsCount, paritySegmentsCount) if err != nil { return nil, err } // Align the size of the last segment payload. lastSegmentPayload := segmentPayloads[segmentsCount-1] segmentPayloads[segmentsCount-1] = make([]byte, segmentSize) copy(segmentPayloads[segmentsCount-1], lastSegmentPayload) // Make space for parity data. for i := segmentsCount; i < segmentsCount+paritySegmentsCount; i++ { segmentPayloads[i] = make([]byte, segmentSize) } err = enc.Encode(segmentPayloads) if err != nil { return nil, err } // Create parity messages. for i, index := segmentsCount, 0; i < segmentsCount+paritySegmentsCount; i++ { segmentWithMetadata := &protobuf.SegmentMessage{ EntireMessageHash: entireMessageHash, SegmentsCount: 0, // indicates parity message ParitySegmentIndex: uint32(index), ParitySegmentsCount: uint32(paritySegmentsCount), Payload: segmentPayloads[i], } marshaledSegmentWithMetadata, err := proto.Marshal(segmentWithMetadata) if err != nil { return nil, err } segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata) if err != nil { return nil, err } segmentMessages = append(segmentMessages, segmentMessage) index++ } return segmentMessages, nil } // SegmentationLayerV1 reconstructs the message only when all segments have been successfully retrieved. // It lacks the capability to perform forward error correction. // Kept to test forward compatibility. func (s *MessageSender) handleSegmentationLayerV1(message *v1protocol.StatusMessage) error { logger := s.logger.With(zap.String("site", "handleSegmentationLayerV1")).With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String())) segmentMessage := &SegmentMessage{ SegmentMessage: &protobuf.SegmentMessage{}, } err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage) if err != nil { return errors.Wrap(err, "failed to unmarshal SegmentMessage") } logger.Debug("handling message segment", zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()), zap.Uint32("Index", segmentMessage.Index), zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount)) alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) if err != nil { return err } if alreadyCompleted { return ErrMessageSegmentsAlreadyCompleted } if segmentMessage.SegmentsCount < 2 { return ErrMessageSegmentsInvalidCount } err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) if err != nil { return err } segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) if err != nil { return err } if len(segments) != int(segmentMessage.SegmentsCount) { return ErrMessageSegmentsIncomplete } // Combine payload var entirePayload bytes.Buffer for _, segment := range segments { _, err := entirePayload.Write(segment.Payload) if err != nil { return errors.Wrap(err, "failed to write segment payload") } } // Sanity check entirePayloadHash := crypto.Keccak256(entirePayload.Bytes()) if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) { return ErrMessageSegmentsHashMismatch } err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) if err != nil { return err } message.TransportLayer.Payload = entirePayload.Bytes() return nil } // SegmentationLayerV2 is capable of reconstructing the message from both complete and partial sets of data segments. // It has capability to perform forward error correction. func (s *MessageSender) handleSegmentationLayerV2(message *v1protocol.StatusMessage) error { logger := s.logger.With(zap.String("site", "handleSegmentationLayerV2")).With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String())) segmentMessage := &SegmentMessage{ SegmentMessage: &protobuf.SegmentMessage{}, } err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage) if err != nil { return errors.Wrap(err, "failed to unmarshal SegmentMessage") } logger.Debug("handling message segment", zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()), zap.Uint32("Index", segmentMessage.Index), zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount), zap.Uint32("ParitySegmentIndex", segmentMessage.ParitySegmentIndex), zap.Uint32("ParitySegmentsCount", segmentMessage.ParitySegmentsCount)) alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) if err != nil { return err } if alreadyCompleted { return ErrMessageSegmentsAlreadyCompleted } if !segmentMessage.IsValid() { return ErrMessageSegmentsInvalidCount } err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) if err != nil { return err } segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) if err != nil { return err } if len(segments) == 0 { return errors.New("unexpected state: no segments found after save operation") // This should theoretically never occur. } firstSegmentMessage := segments[0] lastSegmentMessage := segments[len(segments)-1] // First segment message must not be a parity message. if firstSegmentMessage.IsParityMessage() || len(segments) != int(firstSegmentMessage.SegmentsCount) { return ErrMessageSegmentsIncomplete } payloads := make([][]byte, firstSegmentMessage.SegmentsCount+lastSegmentMessage.ParitySegmentsCount) payloadSize := len(firstSegmentMessage.Payload) restoreUsingParityData := lastSegmentMessage.IsParityMessage() if !restoreUsingParityData { for i, segment := range segments { payloads[i] = segment.Payload } } else { enc, err := reedsolomon.New(int(firstSegmentMessage.SegmentsCount), int(lastSegmentMessage.ParitySegmentsCount)) if err != nil { return err } var lastNonParitySegmentPayload []byte for _, segment := range segments { if !segment.IsParityMessage() { if segment.Index == firstSegmentMessage.SegmentsCount-1 { // Ensure last segment is aligned to payload size, as it is required by reedsolomon. payloads[segment.Index] = make([]byte, payloadSize) copy(payloads[segment.Index], segment.Payload) lastNonParitySegmentPayload = segment.Payload } else { payloads[segment.Index] = segment.Payload } } else { payloads[firstSegmentMessage.SegmentsCount+segment.ParitySegmentIndex] = segment.Payload } } err = enc.Reconstruct(payloads) if err != nil { return err } ok, err := enc.Verify(payloads) if err != nil { return err } if !ok { return ErrMessageSegmentsInvalidParity } if lastNonParitySegmentPayload != nil { payloads[firstSegmentMessage.SegmentsCount-1] = lastNonParitySegmentPayload // Bring back last segment with original length. } } // Combine payload. var entirePayload bytes.Buffer for i := 0; i < int(firstSegmentMessage.SegmentsCount); i++ { _, err := entirePayload.Write(payloads[i]) if err != nil { return errors.Wrap(err, "failed to write segment payload") } } // Sanity check. entirePayloadHash := crypto.Keccak256(entirePayload.Bytes()) if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) { return ErrMessageSegmentsHashMismatch } err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) if err != nil { return err } message.TransportLayer.Payload = entirePayload.Bytes() return nil } func (s *MessageSender) CleanupSegments() error { monthAgo := time.Now().AddDate(0, -1, 0).Unix() err := s.persistence.RemoveMessageSegmentsOlderThan(monthAgo) if err != nil { return err } err = s.persistence.RemoveMessageSegmentsCompletedOlderThan(monthAgo) if err != nil { return err } return nil }