diff --git a/waku/cliutils/protected_topic.go b/waku/cliutils/protected_topic.go index 2537f45d..e3302690 100644 --- a/waku/cliutils/protected_topic.go +++ b/waku/cliutils/protected_topic.go @@ -1,20 +1,24 @@ package cliutils import ( + "crypto/ecdsa" + "encoding/hex" "errors" "fmt" "strings" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" ) type ProtectedTopic struct { - Topic string - Address common.Address + Topic string + PublicKey *ecdsa.PublicKey } func (p ProtectedTopic) String() string { - return fmt.Sprintf("%s:%s", p.Topic, p.Address.String()) + pubKBytes := crypto.FromECDSAPub(p.PublicKey) + return fmt.Sprintf("%s:%s", p.Topic, hex.EncodeToString(pubKBytes)) } type ProtectedTopicSlice struct { @@ -27,13 +31,13 @@ func (k *ProtectedTopicSlice) Set(value string) error { return errors.New("expected topic_name:hex_encoded_public_key") } - if !common.IsHexAddress(protectedTopicParts[1]) { - return errors.New("invalid address format") + pubk, err := crypto.UnmarshalPubkey(common.FromHex(protectedTopicParts[1])) + if err != nil { + return err } - *k.Values = append(*k.Values, ProtectedTopic{ - Topic: protectedTopicParts[0], - Address: common.HexToAddress(protectedTopicParts[1]), + Topic: protectedTopicParts[0], + PublicKey: pubk, }) return nil } diff --git a/waku/node.go b/waku/node.go index 759a063d..2e99ca54 100644 --- a/waku/node.go +++ b/waku/node.go @@ -316,7 +316,7 @@ func Execute(options Options) { } for _, protectedTopic := range options.Relay.ProtectedTopics { - err := wakuNode.Relay().AddSignedTopicValidator(protectedTopic.Topic, protectedTopic.Address) + err := wakuNode.Relay().AddSignedTopicValidator(protectedTopic.Topic, protectedTopic.PublicKey) failOnErr(err, "Error adding signed topic validator") } } diff --git a/waku/v2/protocol/relay/validators.go b/waku/v2/protocol/relay/validators.go index 418b6f4d..dac50714 100644 --- a/waku/v2/protocol/relay/validators.go +++ b/waku/v2/protocol/relay/validators.go @@ -4,10 +4,11 @@ import ( "bytes" "context" "crypto/ecdsa" + "crypto/elliptic" "encoding/binary" + "encoding/hex" "time" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/secp256k1" @@ -55,7 +56,8 @@ func withinTimeWindow(t timesource.Timesource, msg *pb.WakuMessage) bool { type validatorFn = func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool -func validatorFnBuilder(t timesource.Timesource, address common.Address) (validatorFn, error) { +func validatorFnBuilder(t timesource.Timesource, publicKey *ecdsa.PublicKey) (validatorFn, error) { + address := crypto.PubkeyToAddress(*publicKey) topic := protocol.NewNamedShardingPubsubTopic(address.String() + "/proto").String() return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { msg := new(pb.WakuMessage) @@ -82,10 +84,10 @@ func validatorFnBuilder(t timesource.Timesource, address common.Address) (valida }, nil } -func (w *WakuRelay) AddSignedTopicValidator(topic string, address common.Address) error { - w.log.Info("adding validator to signed topic", zap.String("topic", topic), zap.String("address", address.String())) +func (w *WakuRelay) AddSignedTopicValidator(topic string, publicKey *ecdsa.PublicKey) error { + w.log.Info("adding validator to signed topic", zap.String("topic", topic), zap.String("publicKey", hex.EncodeToString(elliptic.Marshal(publicKey.Curve, publicKey.X, publicKey.Y)))) - fn, err := validatorFnBuilder(w.timesource, address) + fn, err := validatorFnBuilder(w.timesource, publicKey) if err != nil { return err } diff --git a/waku/v2/protocol/relay/validators_test.go b/waku/v2/protocol/relay/validators_test.go index 8b0820eb..352b1b1f 100644 --- a/waku/v2/protocol/relay/validators_test.go +++ b/waku/v2/protocol/relay/validators_test.go @@ -60,13 +60,11 @@ func TestMsgHash(t *testing.T) { msgData, _ := proto.Marshal(msg) - address := crypto.PubkeyToAddress(prvKey.PublicKey) - //expectedMessageHash, _ := hex.DecodeString("662F8C20A335F170BD60ABC1F02AD66F0C6A6EE285DA2A53C95259E7937C0AE9") //messageHash := MsgHash(pubsubTopic, msg) //require.True(t, bytes.Equal(expectedMessageHash, messageHash)) - myValidator, err := validatorFnBuilder(NewFakeTimesource(timestamp), address) + myValidator, err := validatorFnBuilder(NewFakeTimesource(timestamp), &prvKey.PublicKey) require.NoError(t, err) result := myValidator(context.Background(), "", &pubsub.Message{ Message: &pubsub_pb.Message{ @@ -77,7 +75,7 @@ func TestMsgHash(t *testing.T) { // Exceed 5m window in both directions now5m1sInPast := timestamp.Add(-5 * time.Minute).Add(-1 * time.Second) - myValidator, err = validatorFnBuilder(NewFakeTimesource(now5m1sInPast), address) + myValidator, err = validatorFnBuilder(NewFakeTimesource(now5m1sInPast), &prvKey.PublicKey) require.NoError(t, err) result = myValidator(context.Background(), "", &pubsub.Message{ Message: &pubsub_pb.Message{ @@ -87,7 +85,7 @@ func TestMsgHash(t *testing.T) { require.False(t, result) now5m1sInFuture := timestamp.Add(5 * time.Minute).Add(1 * time.Second) - myValidator, err = validatorFnBuilder(NewFakeTimesource(now5m1sInFuture), address) + myValidator, err = validatorFnBuilder(NewFakeTimesource(now5m1sInFuture), &prvKey.PublicKey) require.NoError(t, err) result = myValidator(context.Background(), "", &pubsub.Message{ Message: &pubsub_pb.Message{