From 94f18c537c30fae6474938d5b0eff531057ceb88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?rich=CE=9Brd?= Date: Tue, 24 Oct 2023 12:26:02 -0400 Subject: [PATCH] refactor: validate protobuffers for lightpush and relay (#824) --- waku/v2/node/wakunode2_test.go | 9 +- waku/v2/protocol/legacy_filter/filter_map.go | 4 - waku/v2/protocol/lightpush/metrics.go | 16 +-- waku/v2/protocol/lightpush/pb/validation.go | 48 ++++++++ .../protocol/lightpush/pb/validation_test.go | 35 ++++++ waku/v2/protocol/lightpush/waku_lightpush.go | 103 ++++++++++-------- waku/v2/protocol/pb/validation.go | 47 ++++++++ waku/v2/protocol/relay/validators.go | 4 +- waku/v2/protocol/relay/waku_relay.go | 11 +- 9 files changed, 209 insertions(+), 68 deletions(-) create mode 100644 waku/v2/protocol/lightpush/pb/validation.go create mode 100644 waku/v2/protocol/lightpush/pb/validation_test.go create mode 100644 waku/v2/protocol/pb/validation.go diff --git a/waku/v2/node/wakunode2_test.go b/waku/v2/node/wakunode2_test.go index 333282e6..9d4d75e8 100644 --- a/waku/v2/node/wakunode2_test.go +++ b/waku/v2/node/wakunode2_test.go @@ -32,6 +32,7 @@ func createTestMsg(version uint32) *pb.WakuMessage { message.Payload = []byte{0, 1, 2} message.Version = version message.Timestamp = 123456 + message.ContentTopic = "abc" return message } @@ -264,7 +265,8 @@ func TestDecoupledStoreFromRelay(t *testing.T) { time.Sleep(2 * time.Second) _, filter, err := wakuNode2.LegacyFilter().Subscribe(ctx, legacy_filter.ContentFilter{ - Topic: string(relay.DefaultWakuTopic), + Topic: string(relay.DefaultWakuTopic), + ContentTopics: []string{"abc"}, }, legacy_filter.WithPeer(wakuNode1.host.ID())) require.NoError(t, err) @@ -281,7 +283,10 @@ func TestDecoupledStoreFromRelay(t *testing.T) { go func() { // MSG1 should be pushed in NODE2 via filter defer wg.Done() - env := <-filter.Chan + env, ok := <-filter.Chan + if !ok { + require.Fail(t, "no message") + } require.Equal(t, msg.Timestamp, env.Message().Timestamp) }() diff --git a/waku/v2/protocol/legacy_filter/filter_map.go b/waku/v2/protocol/legacy_filter/filter_map.go index f04d7de8..f71e35a6 100644 --- a/waku/v2/protocol/legacy_filter/filter_map.go +++ b/waku/v2/protocol/legacy_filter/filter_map.go @@ -104,10 +104,6 @@ func (fm *FilterMap) Notify(msg *pb.WakuMessage, requestID string) { // Broadcasting message so it's stored fm.broadcaster.Submit(envelope) - if msg.ContentTopic == "" { - filter.Chan <- envelope - } - // TODO: In case of no topics we should either trigger here for all messages, // or we should not allow such filter to exist in the first place. for _, contentTopic := range filter.ContentFilters { diff --git a/waku/v2/protocol/lightpush/metrics.go b/waku/v2/protocol/lightpush/metrics.go index 857243b5..8938bcd7 100644 --- a/waku/v2/protocol/lightpush/metrics.go +++ b/waku/v2/protocol/lightpush/metrics.go @@ -49,14 +49,14 @@ func (m *metricsImpl) RecordMessage() { type metricsErrCategory string var ( - decodeRPCFailure metricsErrCategory = "decode_rpc_failure" - writeRequestFailure metricsErrCategory = "write_request_failure" - writeResponseFailure metricsErrCategory = "write_response_failure" - dialFailure metricsErrCategory = "dial_failure" - messagePushFailure metricsErrCategory = "message_push_failure" - emptyRequestBodyFailure metricsErrCategory = "empty_request_body_failure" - emptyResponseBodyFailure metricsErrCategory = "empty_response_body_failure" - peerNotFoundFailure metricsErrCategory = "peer_not_found_failure" + decodeRPCFailure metricsErrCategory = "decode_rpc_failure" + writeRequestFailure metricsErrCategory = "write_request_failure" + writeResponseFailure metricsErrCategory = "write_response_failure" + dialFailure metricsErrCategory = "dial_failure" + messagePushFailure metricsErrCategory = "message_push_failure" + requestBodyFailure metricsErrCategory = "request_failure" + responseBodyFailure metricsErrCategory = "response_body_failure" + peerNotFoundFailure metricsErrCategory = "peer_not_found_failure" ) // RecordError increases the counter for different error types diff --git a/waku/v2/protocol/lightpush/pb/validation.go b/waku/v2/protocol/lightpush/pb/validation.go new file mode 100644 index 00000000..c2f0218b --- /dev/null +++ b/waku/v2/protocol/lightpush/pb/validation.go @@ -0,0 +1,48 @@ +package pb + +import "errors" + +var ( + errMissingRequestID = errors.New("missing RequestId field") + errMissingQuery = errors.New("missing Query field") + errMissingMessage = errors.New("missing Message field") + errMissingPubsubTopic = errors.New("missing PubsubTopic field") + errRequestIDMismatch = errors.New("RequestID in response does not match request") + errMissingResponse = errors.New("missing Response field") +) + +func (x *PushRPC) ValidateRequest() error { + if x.RequestId == "" { + return errMissingRequestID + } + + if x.Query == nil { + return errMissingQuery + } + + if x.Query.PubsubTopic == "" { + return errMissingPubsubTopic + } + + if x.Query.Message == nil { + return errMissingMessage + } + + return x.Query.Message.Validate() +} + +func (x *PushRPC) ValidateResponse(requestID string) error { + if x.RequestId == "" { + return errMissingRequestID + } + + if x.RequestId != requestID { + return errRequestIDMismatch + } + + if x.Response == nil { + return errMissingResponse + } + + return nil +} diff --git a/waku/v2/protocol/lightpush/pb/validation_test.go b/waku/v2/protocol/lightpush/pb/validation_test.go new file mode 100644 index 00000000..208edf2b --- /dev/null +++ b/waku/v2/protocol/lightpush/pb/validation_test.go @@ -0,0 +1,35 @@ +package pb + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/waku-org/go-waku/waku/v2/protocol/pb" +) + +func TestValidateRequest(t *testing.T) { + request := PushRPC{} + require.ErrorIs(t, request.ValidateRequest(), errMissingRequestID) + request.RequestId = "test" + require.ErrorIs(t, request.ValidateRequest(), errMissingQuery) + request.Query = &PushRequest{} + require.ErrorIs(t, request.ValidateRequest(), errMissingPubsubTopic) + request.Query.PubsubTopic = "test" + require.ErrorIs(t, request.ValidateRequest(), errMissingMessage) + request.Query.Message = &pb.WakuMessage{ + Payload: []byte{1, 2, 3}, + ContentTopic: "test", + } + require.NoError(t, request.ValidateRequest()) +} + +func TestValidateResponse(t *testing.T) { + response := PushRPC{} + require.ErrorIs(t, response.ValidateResponse("test"), errMissingRequestID) + response.RequestId = "test1" + require.ErrorIs(t, response.ValidateResponse("test"), errRequestIDMismatch) + response.RequestId = "test" + require.ErrorIs(t, response.ValidateResponse("test"), errMissingResponse) + response.Response = &PushResponse{} + require.NoError(t, response.ValidateResponse("test")) +} diff --git a/waku/v2/protocol/lightpush/waku_lightpush.go b/waku/v2/protocol/lightpush/waku_lightpush.go index 931de4fb..a54dee24 100644 --- a/waku/v2/protocol/lightpush/waku_lightpush.go +++ b/waku/v2/protocol/lightpush/waku_lightpush.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" "errors" + "fmt" "math" "github.com/libp2p/go-libp2p/core/host" @@ -81,7 +82,6 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream) logger := wakuLP.log.With(logging.HostID("peer", stream.Conn().RemotePeer())) requestPushRPC := &pb.PushRPC{} - writer := pbio.NewDelimitedWriter(stream) reader := pbio.NewDelimitedReader(stream, math.MaxInt32) err := reader.ReadMsg(requestPushRPC) @@ -94,67 +94,69 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream) return } + responsePushRPC := &pb.PushRPC{ + RequestId: requestPushRPC.RequestId, + Response: &pb.PushResponse{}, + } + + if err := requestPushRPC.ValidateRequest(); err != nil { + responsePushRPC.Response.Info = err.Error() + wakuLP.metrics.RecordError(requestBodyFailure) + wakuLP.reply(stream, responsePushRPC, logger) + return + } + logger = logger.With(zap.String("requestID", requestPushRPC.RequestId)) - responsePushRPC := &pb.PushRPC{} - responsePushRPC.RequestId = requestPushRPC.RequestId + logger.Info("push request") - if requestPushRPC.Query != nil { - logger.Info("push request") - response := new(pb.PushResponse) + pubSubTopic := requestPushRPC.Query.PubsubTopic + message := requestPushRPC.Query.Message - pubSubTopic := requestPushRPC.Query.PubsubTopic - message := requestPushRPC.Query.Message + wakuLP.metrics.RecordMessage() - wakuLP.metrics.RecordMessage() + // TODO: Assumes success, should probably be extended to check for network, peers, etc + // It might make sense to use WithReadiness option here? - // TODO: Assumes success, should probably be extended to check for network, peers, etc - // It might make sense to use WithReadiness option here? - - _, err := wakuLP.relay.PublishToTopic(ctx, message, pubSubTopic) - - if err != nil { - logger.Error("publishing message", zap.Error(err)) - wakuLP.metrics.RecordError(messagePushFailure) - response.Info = "Could not publish message" - } else { - response.IsSuccess = true - response.Info = "OK" - } - - responsePushRPC.Response = response - - err = writer.WriteMsg(responsePushRPC) - if err != nil { - wakuLP.metrics.RecordError(writeResponseFailure) - logger.Error("writing response", zap.Error(err)) - if err := stream.Reset(); err != nil { - wakuLP.log.Error("resetting connection", zap.Error(err)) - } - return - } - - logger.Info("response sent") - stream.Close() + _, err = wakuLP.relay.PublishToTopic(ctx, message, pubSubTopic) + if err != nil { + logger.Error("publishing message", zap.Error(err)) + wakuLP.metrics.RecordError(messagePushFailure) + responsePushRPC.Response.Info = fmt.Sprintf("Could not publish message: %s", err.Error()) + return } else { - wakuLP.metrics.RecordError(emptyRequestBodyFailure) - if err := stream.Reset(); err != nil { - wakuLP.log.Error("resetting connection", zap.Error(err)) - } + responsePushRPC.Response.IsSuccess = true + responsePushRPC.Response.Info = "OK" } - if requestPushRPC.Response != nil { - if requestPushRPC.Response.IsSuccess { - logger.Info("request success") - } else { - logger.Info("request failure", zap.String("info=", requestPushRPC.Response.Info)) - } + wakuLP.reply(stream, responsePushRPC, logger) + + logger.Info("response sent") + + stream.Close() + + if responsePushRPC.Response.IsSuccess { + logger.Info("request success") } else { - wakuLP.metrics.RecordError(emptyResponseBodyFailure) + logger.Info("request failure", zap.String("info", responsePushRPC.Response.Info)) } } } +func (wakuLP *WakuLightPush) reply(stream network.Stream, responsePushRPC *pb.PushRPC, logger *zap.Logger) { + writer := pbio.NewDelimitedWriter(stream) + err := writer.WriteMsg(responsePushRPC) + if err != nil { + wakuLP.metrics.RecordError(writeResponseFailure) + logger.Error("writing response", zap.Error(err)) + if err := stream.Reset(); err != nil { + wakuLP.log.Error("resetting connection", zap.Error(err)) + } + return + } + stream.Close() +} + // request sends a message via lightPush protocol to either a specified peer or peer that is selected. func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, params *lightPushParameters) (*pb.PushResponse, error) { if params == nil { @@ -201,6 +203,11 @@ func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, p stream.Close() + if err = pushResponseRPC.ValidateResponse(pushRequestRPC.RequestId); err != nil { + wakuLP.metrics.RecordError(responseBodyFailure) + return nil, err + } + return pushResponseRPC.Response, nil } diff --git a/waku/v2/protocol/pb/validation.go b/waku/v2/protocol/pb/validation.go new file mode 100644 index 00000000..fb1e000b --- /dev/null +++ b/waku/v2/protocol/pb/validation.go @@ -0,0 +1,47 @@ +package pb + +import ( + "errors" + + "google.golang.org/protobuf/proto" +) + +const MaxMetaAttrLength = 64 + +var ( + errMissingPayload = errors.New("missing Payload field") + errMissingContentTopic = errors.New("missing ContentTopic field") + errInvalidMetaLength = errors.New("invalid length for Meta field") +) + +func (msg *WakuMessage) Validate() error { + if len(msg.Payload) == 0 { + return errMissingPayload + } + + if msg.ContentTopic == "" { + return errMissingContentTopic + } + + if len(msg.Meta) > MaxMetaAttrLength { + return errInvalidMetaLength + } + + return nil +} + +func Unmarshal(data []byte) (*WakuMessage, error) { + msg := &WakuMessage{} + err := proto.Unmarshal(data, msg) + if err != nil { + return nil, err + } + + err = msg.Validate() + if err != nil { + return nil, err + } + + return msg, nil + +} diff --git a/waku/v2/protocol/relay/validators.go b/waku/v2/protocol/relay/validators.go index 1405179a..558d7f99 100644 --- a/waku/v2/protocol/relay/validators.go +++ b/waku/v2/protocol/relay/validators.go @@ -12,7 +12,6 @@ import ( "github.com/ethereum/go-ethereum/crypto/secp256k1" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/peer" - proto "google.golang.org/protobuf/proto" "github.com/waku-org/go-waku/waku/v2/hash" "github.com/waku-org/go-waku/waku/v2/protocol/pb" @@ -62,8 +61,7 @@ func (w *WakuRelay) RemoveTopicValidator(topic string) { func (w *WakuRelay) topicValidator(topic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { - msg := new(pb.WakuMessage) - err := proto.Unmarshal(message.Data, msg) + msg, err := pb.Unmarshal(message.Data) if err != nil { return false } diff --git a/waku/v2/protocol/relay/waku_relay.go b/waku/v2/protocol/relay/waku_relay.go index 465f8268..6f6ea12e 100644 --- a/waku/v2/protocol/relay/waku_relay.go +++ b/waku/v2/protocol/relay/waku_relay.go @@ -234,12 +234,15 @@ func (w *WakuRelay) PublishToTopic(ctx context.Context, message *pb.WakuMessage, return nil, errors.New("message can't be null") } + if err := message.Validate(); err != nil { + return nil, err + } + if !w.EnoughPeersToPublishToTopic(topic) { return nil, errors.New("not enough peers to publish") } pubSubTopic, err := w.upsertTopic(topic) - if err != nil { return nil, err } @@ -461,11 +464,13 @@ func (w *WakuRelay) pubsubTopicMsgHandler(pubsubTopic string, sub *pubsub.Subscr sub.Cancel() return } - wakuMessage := &pb.WakuMessage{} - if err := proto.Unmarshal(msg.Data, wakuMessage); err != nil { + + wakuMessage, err := pb.Unmarshal(msg.Data) + if err != nil { w.log.Error("decoding message", zap.Error(err)) return } + envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), pubsubTopic) w.metrics.RecordMessage(envelope)