// status-go/protocol/common/message_segmentation_test.go
// 2024-04-17 22:05:53 +02:00
//
// 206 lines
// 6.6 KiB
// Go

package common
import (
"fmt"
"math"
"testing"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"golang.org/x/exp/slices"
"github.com/status-im/status-go/appdatabase"
"github.com/status-im/status-go/eth-node/crypto"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/protocol/sqlite"
"github.com/status-im/status-go/protocol/v1"
"github.com/status-im/status-go/t/helpers"
)
func TestMessageSegmentationSuite(t *testing.T) {
suite.Run(t, new(MessageSegmentationSuite))
}
// MessageSegmentationSuite exercises segmentMessage together with the
// MessageSender segmentation-layer handlers (V1 and V2).
type MessageSegmentationSuite struct {
	suite.Suite

	// logger is the development zap logger handed to sender; recreated in SetupTest.
	logger *zap.Logger

	// sender is rebuilt by SetupTest on top of a new in-memory database.
	sender *MessageSender

	// testPayload is the fixed input that gets segmented; filled once in SetupSuite.
	testPayload []byte
}
// SetupSuite builds the shared 1000-byte test payload once for the whole suite.
// Byte i holds i truncated to a byte, so the values cycle through 0..255.
func (s *MessageSegmentationSuite) SetupSuite() {
	const payloadSize = 1000
	s.testPayload = make([]byte, payloadSize)
	for idx := range s.testPayload {
		s.testPayload[idx] = byte(idx)
	}
}
// SetupTest builds a fresh MessageSender before each test from a newly
// generated identity key, a migrated in-memory SQL database and a development
// zap logger. The two nil arguments are passed through as in the original
// setup; presumably they are dependencies the segmentation handlers do not
// touch — TODO confirm against NewMessageSender's signature.
func (s *MessageSegmentationSuite) SetupTest() {
	privateKey, err := crypto.GenerateKey()
	s.Require().NoError(err)

	db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
	s.Require().NoError(err)
	s.Require().NoError(sqlite.Migrate(db))

	devLogger, err := zap.NewDevelopment()
	s.logger = devLogger
	s.Require().NoError(err)

	s.sender, err = NewMessageSender(
		privateKey,
		db,
		nil,
		nil,
		s.logger,
		FeatureFlags{},
	)
	s.Require().NoError(err)
}
// SetupSubTest is the testify hook run before each s.Run subtest. It delegates
// to SetupTest so every table-driven case gets its own sender and database
// rather than reusing state left behind by earlier cases.
func (s *MessageSegmentationSuite) SetupSubTest() {
	s.SetupTest()
}
// TestHandleSegmentationLayer splits s.testPayload with segmentMessage and feeds
// a selection of the resulting segments, in the listed order, to the V1 or V2
// segmentation-layer handler of s.sender. For every table case it checks the
// error returned after each segment and whether the payload was reassembled.
//
// Table fields:
//   - segmentsCount: number of data segments segmentMessage should produce.
//   - expectedParitySegmentsCount: extra segments expected after the data
//     segments (parity segments).
//   - retrievedSegments: indices into segmentMessage's output, in delivery order.
//   - retrievedParitySegments: the subset of retrievedSegments that are parity.
//   - segmentationLayerV1ShouldSucceed / segmentationLayerV2ShouldSucceed:
//     whether that handler version is expected to rebuild the payload.
func (s *MessageSegmentationSuite) TestHandleSegmentationLayer() {
	testCases := []struct {
		name                             string
		segmentsCount                    int
		expectedParitySegmentsCount      int
		retrievedSegments                []int
		retrievedParitySegments          []int
		segmentationLayerV1ShouldSucceed bool
		segmentationLayerV2ShouldSucceed bool
	}{
		{
			name:                             "all segments retrieved",
			segmentsCount:                    2,
			expectedParitySegmentsCount:      0,
			retrievedSegments:                []int{0, 1},
			retrievedParitySegments:          []int{},
			segmentationLayerV1ShouldSucceed: true,
			segmentationLayerV2ShouldSucceed: true,
		},
		{
			name:                             "all segments retrieved out of order",
			segmentsCount:                    2,
			expectedParitySegmentsCount:      0,
			retrievedSegments:                []int{1, 0},
			retrievedParitySegments:          []int{},
			segmentationLayerV1ShouldSucceed: true,
			segmentationLayerV2ShouldSucceed: true,
		},
		{
			name:                             "all segments&parity retrieved",
			segmentsCount:                    8,
			expectedParitySegmentsCount:      1,
			retrievedSegments:                []int{0, 1, 2, 3, 4, 5, 6, 7, 8},
			retrievedParitySegments:          []int{8},
			segmentationLayerV1ShouldSucceed: true,
			segmentationLayerV2ShouldSucceed: true,
		},
		{
			name:                             "all segments&parity retrieved out of order",
			segmentsCount:                    8,
			expectedParitySegmentsCount:      1,
			retrievedSegments:                []int{8, 0, 7, 1, 6, 2, 5, 3, 4},
			retrievedParitySegments:          []int{8},
			segmentationLayerV1ShouldSucceed: true,
			segmentationLayerV2ShouldSucceed: true,
		},
		{
			name:                             "no segments retrieved",
			segmentsCount:                    2,
			expectedParitySegmentsCount:      0,
			retrievedSegments:                []int{},
			retrievedParitySegments:          []int{},
			segmentationLayerV1ShouldSucceed: false,
			segmentationLayerV2ShouldSucceed: false,
		},
		{
			name:                             "not all needed segments&parity retrieved",
			segmentsCount:                    8,
			expectedParitySegmentsCount:      1,
			retrievedSegments:                []int{1, 2, 8},
			retrievedParitySegments:          []int{8},
			segmentationLayerV1ShouldSucceed: false,
			segmentationLayerV2ShouldSucceed: false,
		},
		{
			name:                             "segments&parity retrieved",
			segmentsCount:                    8,
			expectedParitySegmentsCount:      1,
			retrievedSegments:                []int{1, 2, 3, 4, 5, 6, 7, 8},
			retrievedParitySegments:          []int{8},
			segmentationLayerV1ShouldSucceed: false,
			segmentationLayerV2ShouldSucceed: true, // succeed even though one segment is missing, thank you reedsolomon
		},
		{
			name:                             "segments&parity retrieved out of order",
			segmentsCount:                    16,
			expectedParitySegmentsCount:      2,
			retrievedSegments:                []int{17, 0, 16, 1, 15, 2, 14, 3, 13, 4, 12, 5, 11, 6, 10, 7},
			retrievedParitySegments:          []int{16, 17},
			segmentationLayerV1ShouldSucceed: false,
			segmentationLayerV2ShouldSucceed: true, // succeed even though two segments are missing, thank you reedsolomon
		},
	}

	for _, version := range []string{"V1", "V2"} {
		for _, tc := range testCases {
			s.Run(fmt.Sprintf("%s %s", version, tc.name), func() {
				// Segment size chosen so the payload splits into exactly
				// tc.segmentsCount data segments.
				segmentSize := int(math.Ceil(float64(len(s.testPayload)) / float64(tc.segmentsCount)))
				segmentedMessages, err := segmentMessage(&types.NewMessage{Payload: s.testPayload}, segmentSize)
				s.Require().NoError(err)
				s.Require().Len(segmentedMessages, tc.segmentsCount+tc.expectedParitySegmentsCount)

				message := &protocol.StatusMessage{TransportLayer: protocol.TransportLayer{
					SigPubKey: &s.sender.identity.PublicKey,
				}}

				messageRecreated := false
				handledSegments := []int{}

				for i, segmentIndex := range tc.retrievedSegments {
					s.T().Logf("i=%d segmentIndex=%d", i, segmentIndex)

					message.TransportLayer.Payload = segmentedMessages[segmentIndex].Payload

					if version == "V1" {
						err = s.sender.handleSegmentationLayerV1(message)
						// V1 is unable to handle parity segments: the test expects
						// them to be rejected, so they are not counted as handled.
						if slices.Contains(tc.retrievedParitySegments, segmentIndex) {
							if len(handledSegments) >= tc.segmentsCount {
								s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted)
							} else {
								s.Require().ErrorIs(err, ErrMessageSegmentsInvalidCount)
							}
							continue
						}
					} else {
						err = s.sender.handleSegmentationLayerV2(message)
					}

					handledSegments = append(handledSegments, segmentIndex)

					switch {
					case len(handledSegments) < tc.segmentsCount:
						s.Require().ErrorIs(err, ErrMessageSegmentsIncomplete)
					case len(handledSegments) == tc.segmentsCount:
						s.Require().NoError(err)
						// Equal rather than ElementsMatch: ElementsMatch ignores
						// order and would accept a payload whose segments were
						// rejoined in the wrong order.
						s.Require().Equal(s.testPayload, message.TransportLayer.Payload)
						messageRecreated = true
					default:
						s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted)
					}
				}

				shouldSucceed := tc.segmentationLayerV2ShouldSucceed
				if version == "V1" {
					shouldSucceed = tc.segmentationLayerV1ShouldSucceed
				}
				s.Require().Equal(shouldSucceed, messageRecreated)
			})
		}
	}
}