refactor: validate protobuffers for lightpush and relay (#824)

This commit is contained in:
richΛrd 2023-10-24 12:26:02 -04:00
parent fa51d10b4b
commit 94f18c537c
9 changed files with 209 additions and 68 deletions

View File

@ -32,6 +32,7 @@ func createTestMsg(version uint32) *pb.WakuMessage {
message.Payload = []byte{0, 1, 2} message.Payload = []byte{0, 1, 2}
message.Version = version message.Version = version
message.Timestamp = 123456 message.Timestamp = 123456
message.ContentTopic = "abc"
return message return message
} }
@ -265,6 +266,7 @@ func TestDecoupledStoreFromRelay(t *testing.T) {
_, filter, err := wakuNode2.LegacyFilter().Subscribe(ctx, legacy_filter.ContentFilter{ _, filter, err := wakuNode2.LegacyFilter().Subscribe(ctx, legacy_filter.ContentFilter{
Topic: string(relay.DefaultWakuTopic), Topic: string(relay.DefaultWakuTopic),
ContentTopics: []string{"abc"},
}, legacy_filter.WithPeer(wakuNode1.host.ID())) }, legacy_filter.WithPeer(wakuNode1.host.ID()))
require.NoError(t, err) require.NoError(t, err)
@ -281,7 +283,10 @@ func TestDecoupledStoreFromRelay(t *testing.T) {
go func() { go func() {
// MSG1 should be pushed in NODE2 via filter // MSG1 should be pushed in NODE2 via filter
defer wg.Done() defer wg.Done()
env := <-filter.Chan env, ok := <-filter.Chan
if !ok {
require.Fail(t, "no message")
}
require.Equal(t, msg.Timestamp, env.Message().Timestamp) require.Equal(t, msg.Timestamp, env.Message().Timestamp)
}() }()

View File

@ -104,10 +104,6 @@ func (fm *FilterMap) Notify(msg *pb.WakuMessage, requestID string) {
// Broadcasting message so it's stored // Broadcasting message so it's stored
fm.broadcaster.Submit(envelope) fm.broadcaster.Submit(envelope)
if msg.ContentTopic == "" {
filter.Chan <- envelope
}
// TODO: In case of no topics we should either trigger here for all messages, // TODO: In case of no topics we should either trigger here for all messages,
// or we should not allow such filter to exist in the first place. // or we should not allow such filter to exist in the first place.
for _, contentTopic := range filter.ContentFilters { for _, contentTopic := range filter.ContentFilters {

View File

@ -54,8 +54,8 @@ var (
writeResponseFailure metricsErrCategory = "write_response_failure" writeResponseFailure metricsErrCategory = "write_response_failure"
dialFailure metricsErrCategory = "dial_failure" dialFailure metricsErrCategory = "dial_failure"
messagePushFailure metricsErrCategory = "message_push_failure" messagePushFailure metricsErrCategory = "message_push_failure"
emptyRequestBodyFailure metricsErrCategory = "empty_request_body_failure" requestBodyFailure metricsErrCategory = "request_failure"
emptyResponseBodyFailure metricsErrCategory = "empty_response_body_failure" responseBodyFailure metricsErrCategory = "response_body_failure"
peerNotFoundFailure metricsErrCategory = "peer_not_found_failure" peerNotFoundFailure metricsErrCategory = "peer_not_found_failure"
) )

View File

@ -0,0 +1,48 @@
package pb
import "errors"
var (
errMissingRequestID = errors.New("missing RequestId field")
errMissingQuery = errors.New("missing Query field")
errMissingMessage = errors.New("missing Message field")
errMissingPubsubTopic = errors.New("missing PubsubTopic field")
errRequestIDMismatch = errors.New("RequestID in response does not match request")
errMissingResponse = errors.New("missing Response field")
)
func (x *PushRPC) ValidateRequest() error {
if x.RequestId == "" {
return errMissingRequestID
}
if x.Query == nil {
return errMissingQuery
}
if x.Query.PubsubTopic == "" {
return errMissingPubsubTopic
}
if x.Query.Message == nil {
return errMissingMessage
}
return x.Query.Message.Validate()
}
func (x *PushRPC) ValidateResponse(requestID string) error {
if x.RequestId == "" {
return errMissingRequestID
}
if x.RequestId != requestID {
return errRequestIDMismatch
}
if x.Response == nil {
return errMissingResponse
}
return nil
}

View File

@ -0,0 +1,35 @@
package pb
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
)
func TestValidateRequest(t *testing.T) {
request := PushRPC{}
require.ErrorIs(t, request.ValidateRequest(), errMissingRequestID)
request.RequestId = "test"
require.ErrorIs(t, request.ValidateRequest(), errMissingQuery)
request.Query = &PushRequest{}
require.ErrorIs(t, request.ValidateRequest(), errMissingPubsubTopic)
request.Query.PubsubTopic = "test"
require.ErrorIs(t, request.ValidateRequest(), errMissingMessage)
request.Query.Message = &pb.WakuMessage{
Payload: []byte{1, 2, 3},
ContentTopic: "test",
}
require.NoError(t, request.ValidateRequest())
}
func TestValidateResponse(t *testing.T) {
response := PushRPC{}
require.ErrorIs(t, response.ValidateResponse("test"), errMissingRequestID)
response.RequestId = "test1"
require.ErrorIs(t, response.ValidateResponse("test"), errRequestIDMismatch)
response.RequestId = "test"
require.ErrorIs(t, response.ValidateResponse("test"), errMissingResponse)
response.Response = &PushResponse{}
require.NoError(t, response.ValidateResponse("test"))
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"math" "math"
"github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/host"
@ -81,7 +82,6 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream)
logger := wakuLP.log.With(logging.HostID("peer", stream.Conn().RemotePeer())) logger := wakuLP.log.With(logging.HostID("peer", stream.Conn().RemotePeer()))
requestPushRPC := &pb.PushRPC{} requestPushRPC := &pb.PushRPC{}
writer := pbio.NewDelimitedWriter(stream)
reader := pbio.NewDelimitedReader(stream, math.MaxInt32) reader := pbio.NewDelimitedReader(stream, math.MaxInt32)
err := reader.ReadMsg(requestPushRPC) err := reader.ReadMsg(requestPushRPC)
@ -94,14 +94,21 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream)
return return
} }
responsePushRPC := &pb.PushRPC{
RequestId: requestPushRPC.RequestId,
Response: &pb.PushResponse{},
}
if err := requestPushRPC.ValidateRequest(); err != nil {
responsePushRPC.Response.Info = err.Error()
wakuLP.metrics.RecordError(requestBodyFailure)
wakuLP.reply(stream, responsePushRPC, logger)
return
}
logger = logger.With(zap.String("requestID", requestPushRPC.RequestId)) logger = logger.With(zap.String("requestID", requestPushRPC.RequestId))
responsePushRPC := &pb.PushRPC{}
responsePushRPC.RequestId = requestPushRPC.RequestId
if requestPushRPC.Query != nil {
logger.Info("push request") logger.Info("push request")
response := new(pb.PushResponse)
pubSubTopic := requestPushRPC.Query.PubsubTopic pubSubTopic := requestPushRPC.Query.PubsubTopic
message := requestPushRPC.Query.Message message := requestPushRPC.Query.Message
@ -111,20 +118,34 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream)
// TODO: Assumes success, should probably be extended to check for network, peers, etc // TODO: Assumes success, should probably be extended to check for network, peers, etc
// It might make sense to use WithReadiness option here? // It might make sense to use WithReadiness option here?
_, err := wakuLP.relay.PublishToTopic(ctx, message, pubSubTopic) _, err = wakuLP.relay.PublishToTopic(ctx, message, pubSubTopic)
if err != nil { if err != nil {
logger.Error("publishing message", zap.Error(err)) logger.Error("publishing message", zap.Error(err))
wakuLP.metrics.RecordError(messagePushFailure) wakuLP.metrics.RecordError(messagePushFailure)
response.Info = "Could not publish message" responsePushRPC.Response.Info = fmt.Sprintf("Could not publish message: %s", err.Error())
return
} else { } else {
response.IsSuccess = true responsePushRPC.Response.IsSuccess = true
response.Info = "OK" responsePushRPC.Response.Info = "OK"
} }
responsePushRPC.Response = response wakuLP.reply(stream, responsePushRPC, logger)
err = writer.WriteMsg(responsePushRPC) logger.Info("response sent")
stream.Close()
if responsePushRPC.Response.IsSuccess {
logger.Info("request success")
} else {
logger.Info("request failure", zap.String("info", responsePushRPC.Response.Info))
}
}
}
func (wakuLP *WakuLightPush) reply(stream network.Stream, responsePushRPC *pb.PushRPC, logger *zap.Logger) {
writer := pbio.NewDelimitedWriter(stream)
err := writer.WriteMsg(responsePushRPC)
if err != nil { if err != nil {
wakuLP.metrics.RecordError(writeResponseFailure) wakuLP.metrics.RecordError(writeResponseFailure)
logger.Error("writing response", zap.Error(err)) logger.Error("writing response", zap.Error(err))
@ -133,26 +154,7 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream)
} }
return return
} }
logger.Info("response sent")
stream.Close() stream.Close()
} else {
wakuLP.metrics.RecordError(emptyRequestBodyFailure)
if err := stream.Reset(); err != nil {
wakuLP.log.Error("resetting connection", zap.Error(err))
}
}
if requestPushRPC.Response != nil {
if requestPushRPC.Response.IsSuccess {
logger.Info("request success")
} else {
logger.Info("request failure", zap.String("info=", requestPushRPC.Response.Info))
}
} else {
wakuLP.metrics.RecordError(emptyResponseBodyFailure)
}
}
} }
// request sends a message via lightPush protocol to either a specified peer or peer that is selected. // request sends a message via lightPush protocol to either a specified peer or peer that is selected.
@ -201,6 +203,11 @@ func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, p
stream.Close() stream.Close()
if err = pushResponseRPC.ValidateResponse(pushRequestRPC.RequestId); err != nil {
wakuLP.metrics.RecordError(responseBodyFailure)
return nil, err
}
return pushResponseRPC.Response, nil return pushResponseRPC.Response, nil
} }

View File

@ -0,0 +1,47 @@
package pb
import (
"errors"
"google.golang.org/protobuf/proto"
)
const MaxMetaAttrLength = 64
var (
errMissingPayload = errors.New("missing Payload field")
errMissingContentTopic = errors.New("missing ContentTopic field")
errInvalidMetaLength = errors.New("invalid length for Meta field")
)
func (msg *WakuMessage) Validate() error {
if len(msg.Payload) == 0 {
return errMissingPayload
}
if msg.ContentTopic == "" {
return errMissingContentTopic
}
if len(msg.Meta) > MaxMetaAttrLength {
return errInvalidMetaLength
}
return nil
}
func Unmarshal(data []byte) (*WakuMessage, error) {
msg := &WakuMessage{}
err := proto.Unmarshal(data, msg)
if err != nil {
return nil, err
}
err = msg.Validate()
if err != nil {
return nil, err
}
return msg, nil
}

View File

@ -12,7 +12,6 @@ import (
"github.com/ethereum/go-ethereum/crypto/secp256k1" "github.com/ethereum/go-ethereum/crypto/secp256k1"
pubsub "github.com/libp2p/go-libp2p-pubsub" pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peer"
proto "google.golang.org/protobuf/proto"
"github.com/waku-org/go-waku/waku/v2/hash" "github.com/waku-org/go-waku/waku/v2/hash"
"github.com/waku-org/go-waku/waku/v2/protocol/pb" "github.com/waku-org/go-waku/waku/v2/protocol/pb"
@ -62,8 +61,7 @@ func (w *WakuRelay) RemoveTopicValidator(topic string) {
func (w *WakuRelay) topicValidator(topic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { func (w *WakuRelay) topicValidator(topic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
msg := new(pb.WakuMessage) msg, err := pb.Unmarshal(message.Data)
err := proto.Unmarshal(message.Data, msg)
if err != nil { if err != nil {
return false return false
} }

View File

@ -234,12 +234,15 @@ func (w *WakuRelay) PublishToTopic(ctx context.Context, message *pb.WakuMessage,
return nil, errors.New("message can't be null") return nil, errors.New("message can't be null")
} }
if err := message.Validate(); err != nil {
return nil, err
}
if !w.EnoughPeersToPublishToTopic(topic) { if !w.EnoughPeersToPublishToTopic(topic) {
return nil, errors.New("not enough peers to publish") return nil, errors.New("not enough peers to publish")
} }
pubSubTopic, err := w.upsertTopic(topic) pubSubTopic, err := w.upsertTopic(topic)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -461,11 +464,13 @@ func (w *WakuRelay) pubsubTopicMsgHandler(pubsubTopic string, sub *pubsub.Subscr
sub.Cancel() sub.Cancel()
return return
} }
wakuMessage := &pb.WakuMessage{}
if err := proto.Unmarshal(msg.Data, wakuMessage); err != nil { wakuMessage, err := pb.Unmarshal(msg.Data)
if err != nil {
w.log.Error("decoding message", zap.Error(err)) w.log.Error("decoding message", zap.Error(err))
return return
} }
envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), pubsubTopic) envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), pubsubTopic)
w.metrics.RecordMessage(envelope) w.metrics.RecordMessage(envelope)