From 19ba25ffcb3e4b11a7903ad709dbe49b53777b4f Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Sun, 15 Oct 2023 15:16:40 -0400 Subject: [PATCH] feat: metadata protocol --- cmd/waku/flags.go | 7 + cmd/waku/main.go | 1 + cmd/waku/node.go | 1 + cmd/waku/options.go | 1 + waku/v2/discv5/discover.go | 6 +- waku/v2/node/wakunode2.go | 12 + waku/v2/node/wakuoptions.go | 9 + waku/v2/peermanager/connection_gater.go | 17 +- waku/v2/protocol/metadata/pb/generate.go | 3 + .../protocol/metadata/pb/waku_metadata.pb.go | 229 +++++++++++++++++ .../protocol/metadata/pb/waku_metadata.proto | 13 + waku/v2/protocol/metadata/waku_metadata.go | 243 ++++++++++++++++++ .../protocol/metadata/waku_metadata_test.go | 154 +++++++++++ 13 files changed, 679 insertions(+), 17 deletions(-) create mode 100644 waku/v2/protocol/metadata/pb/generate.go create mode 100644 waku/v2/protocol/metadata/pb/waku_metadata.pb.go create mode 100644 waku/v2/protocol/metadata/pb/waku_metadata.proto create mode 100644 waku/v2/protocol/metadata/waku_metadata.go create mode 100644 waku/v2/protocol/metadata/waku_metadata_test.go diff --git a/cmd/waku/flags.go b/cmd/waku/flags.go index 071c3f84..e83e44dc 100644 --- a/cmd/waku/flags.go +++ b/cmd/waku/flags.go @@ -121,6 +121,13 @@ var ( Destination: &options.KeyPasswd, EnvVars: []string{"WAKUNODE2_KEY_PASSWORD"}, }) + ClusterID = altsrc.NewUintFlag(&cli.UintFlag{ + Name: "cluster-id", + Value: 0, + Usage: "Cluster id that the node is running in. Node in a different cluster id is disconnected.", + Destination: &options.ClusterID, + EnvVars: []string{"WAKUNODE2_CLUSTER_ID"}, + }) StaticNode = cliutils.NewGenericFlagMultiValue(&cli.GenericFlag{ Name: "staticnode", Usage: "Multiaddr of peer to directly connect with. Option may be repeated", diff --git a/cmd/waku/main.go b/cmd/waku/main.go index eeedb3a8..93ffd946 100644 --- a/cmd/waku/main.go +++ b/cmd/waku/main.go @@ -36,6 +36,7 @@ func main() { NodeKey, KeyFile, KeyPassword, + ClusterID, StaticNode, KeepAlive, PersistPeers, diff --git a/cmd/waku/node.go b/cmd/waku/node.go index 57387dd1..0126aaa2 100644 --- a/cmd/waku/node.go +++ b/cmd/waku/node.go @@ -141,6 +141,7 @@ func Execute(options NodeOptions) error { node.WithMaxPeerConnections(options.MaxPeerConnections), node.WithPrometheusRegisterer(prometheus.DefaultRegisterer), node.WithPeerStoreCapacity(options.PeerStoreCapacity), + node.WithClusterID(uint16(options.ClusterID)), } if len(options.AdvertiseAddresses) != 0 { nodeOpts = append(nodeOpts, node.WithAdvertiseAddresses(options.AdvertiseAddresses...)) diff --git a/cmd/waku/options.go b/cmd/waku/options.go index 904bdd55..35b79249 100644 --- a/cmd/waku/options.go +++ b/cmd/waku/options.go @@ -148,6 +148,7 @@ type RendezvousOptions struct { type NodeOptions struct { Port int Address string + ClusterID uint DNS4DomainName string NodeKey *ecdsa.PrivateKey KeyFile string diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index 67bd4a2d..e723d20e 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -397,10 +397,8 @@ func (d *DiscoveryV5) peerLoop(ctx context.Context) error { } if nodeRS == nil { - // TODO: Node has no shard registered. - // Since for now, status-go uses both mixed static and named shards, we assume the node is valid - // Once status-go uses only static shards, we can't return true anymore. - return true + // Node has no shards registered. + return false } if nodeRS.Cluster != localRS.Cluster { diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go index c19e6aad..25228398 100644 --- a/waku/v2/node/wakunode2.go +++ b/waku/v2/node/wakunode2.go @@ -37,6 +37,7 @@ import ( "github.com/waku-org/go-waku/waku/v2/protocol/filter" "github.com/waku-org/go-waku/waku/v2/protocol/legacy_filter" "github.com/waku-org/go-waku/waku/v2/protocol/lightpush" + "github.com/waku-org/go-waku/waku/v2/protocol/metadata" "github.com/waku-org/go-waku/waku/v2/protocol/pb" "github.com/waku-org/go-waku/waku/v2/protocol/peer_exchange" "github.com/waku-org/go-waku/waku/v2/protocol/relay" @@ -94,6 +95,7 @@ type WakuNode struct { discoveryV5 Service peerExchange Service rendezvous Service + metadata Service legacyFilter ReceptorService filterFullNode ReceptorService filterLightNode Service @@ -253,6 +255,8 @@ func New(opts ...WakuNodeOption) (*WakuNode, error) { w.log.Error("creating localnode", zap.Error(err)) } + w.metadata = metadata.NewWakuMetadata(w.opts.clusterID, w.localNode, w.log) + //Initialize peer manager. w.peermanager = peermanager.NewPeerManager(w.opts.maxPeerConnections, w.opts.peerStoreCapacity, w.log) @@ -388,6 +392,12 @@ func (w *WakuNode) Start(ctx context.Context) error { go w.startKeepAlive(ctx, w.opts.keepAliveInterval) } + w.metadata.SetHost(host) + err = w.metadata.Start(ctx) + if err != nil { + return err + } + w.peerConnector.SetHost(host) w.peermanager.SetHost(host) err = w.peerConnector.Start(ctx) @@ -508,6 +518,8 @@ func (w *WakuNode) Stop() { defer w.identificationEventSub.Close() defer w.addressChangesSub.Close() + w.host.Network().StopNotify(w.connectionNotif) + w.relay.Stop() w.lightPush.Stop() w.store.Stop() diff --git a/waku/v2/node/wakuoptions.go b/waku/v2/node/wakuoptions.go index 9e364e21..39f04fe3 100644 --- a/waku/v2/node/wakuoptions.go +++ b/waku/v2/node/wakuoptions.go @@ -45,6 +45,7 @@ const defaultMinRelayPeersToPublish = 0 type WakuNodeParameters struct { hostAddr *net.TCPAddr + clusterID uint16 dns4Domain string advertiseAddrs []multiaddr.Multiaddr multiAddr []multiaddr.Multiaddr @@ -294,6 +295,14 @@ func WithPrivateKey(privKey *ecdsa.PrivateKey) WakuNodeOption { } } +// WithClusterID is used to set the node's ClusterID +func WithClusterID(clusterID uint16) WakuNodeOption { + return func(params *WakuNodeParameters) error { + params.clusterID = clusterID + return nil + } +} + // WithNTP is used to use ntp for any operation that requires obtaining time // A list of ntp servers can be passed but if none is specified, some defaults // will be used diff --git a/waku/v2/peermanager/connection_gater.go b/waku/v2/peermanager/connection_gater.go index 99abdcfa..5b9b761b 100644 --- a/waku/v2/peermanager/connection_gater.go +++ b/waku/v2/peermanager/connection_gater.go @@ -16,10 +16,8 @@ import ( // the number of connections per IP address type ConnectionGater struct { sync.Mutex - logger *zap.Logger - limiter map[string]int - inbound int - outbound int + logger *zap.Logger + limiter map[string]int } const maxConnsPerIP = 10 @@ -27,10 +25,8 @@ const maxConnsPerIP = 10 // NewConnectionGater creates a new instance of ConnectionGater func NewConnectionGater(logger *zap.Logger) *ConnectionGater { c := &ConnectionGater{ - logger: logger.Named("connection-gater"), - limiter: make(map[string]int), - inbound: 0, - outbound: 0, + logger: logger.Named("connection-gater"), + limiter: make(map[string]int), } return c @@ -61,11 +57,6 @@ func (c *ConnectionGater) InterceptAccept(n network.ConnMultiaddrs) (allow bool) return false } - if false { // inbound > someLimit - c.logger.Info("connection not accepted. Max inbound connections reached", zap.String("multiaddr", n.RemoteMultiaddr().String())) - return false - } - return true } diff --git a/waku/v2/protocol/metadata/pb/generate.go b/waku/v2/protocol/metadata/pb/generate.go new file mode 100644 index 00000000..09c5cdb1 --- /dev/null +++ b/waku/v2/protocol/metadata/pb/generate.go @@ -0,0 +1,3 @@ +package pb + +//go:generate protoc -I. --go_opt=paths=source_relative --go_opt=Mwaku_metadata.proto=github.com/waku-org/go-waku/waku/v2/protocol/metadata/pb --go_out=. ./waku_metadata.proto diff --git a/waku/v2/protocol/metadata/pb/waku_metadata.pb.go b/waku/v2/protocol/metadata/pb/waku_metadata.pb.go new file mode 100644 index 00000000..ee87a3c4 --- /dev/null +++ b/waku/v2/protocol/metadata/pb/waku_metadata.pb.go @@ -0,0 +1,229 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.31.0 +// protoc v3.21.12 +// source: waku_metadata.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type WakuMetadataRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClusterId *uint32 `protobuf:"varint,1,opt,name=cluster_id,json=clusterId,proto3,oneof" json:"cluster_id,omitempty"` + Shards []uint32 `protobuf:"varint,2,rep,packed,name=shards,proto3" json:"shards,omitempty"` +} + +func (x *WakuMetadataRequest) Reset() { + *x = WakuMetadataRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_waku_metadata_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *WakuMetadataRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WakuMetadataRequest) ProtoMessage() {} + +func (x *WakuMetadataRequest) ProtoReflect() protoreflect.Message { + mi := &file_waku_metadata_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WakuMetadataRequest.ProtoReflect.Descriptor instead. +func (*WakuMetadataRequest) Descriptor() ([]byte, []int) { + return file_waku_metadata_proto_rawDescGZIP(), []int{0} +} + +func (x *WakuMetadataRequest) GetClusterId() uint32 { + if x != nil && x.ClusterId != nil { + return *x.ClusterId + } + return 0 +} + +func (x *WakuMetadataRequest) GetShards() []uint32 { + if x != nil { + return x.Shards + } + return nil +} + +type WakuMetadataResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClusterId *uint32 `protobuf:"varint,1,opt,name=cluster_id,json=clusterId,proto3,oneof" json:"cluster_id,omitempty"` + Shards []uint32 `protobuf:"varint,2,rep,packed,name=shards,proto3" json:"shards,omitempty"` +} + +func (x *WakuMetadataResponse) Reset() { + *x = WakuMetadataResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_waku_metadata_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *WakuMetadataResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WakuMetadataResponse) ProtoMessage() {} + +func (x *WakuMetadataResponse) ProtoReflect() protoreflect.Message { + mi := &file_waku_metadata_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WakuMetadataResponse.ProtoReflect.Descriptor instead. +func (*WakuMetadataResponse) Descriptor() ([]byte, []int) { + return file_waku_metadata_proto_rawDescGZIP(), []int{1} +} + +func (x *WakuMetadataResponse) GetClusterId() uint32 { + if x != nil && x.ClusterId != nil { + return *x.ClusterId + } + return 0 +} + +func (x *WakuMetadataResponse) GetShards() []uint32 { + if x != nil { + return x.Shards + } + return nil +} + +var File_waku_metadata_proto protoreflect.FileDescriptor + +var file_waku_metadata_proto_rawDesc = []byte{ + 0x0a, 0x13, 0x77, 0x61, 0x6b, 0x75, 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x70, 0x62, 0x22, 0x60, 0x0a, 0x13, 0x57, 0x61, 0x6b, + 0x75, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x22, 0x0a, 0x0a, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x09, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x49, + 0x64, 0x88, 0x01, 0x01, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x68, 0x61, 0x72, 0x64, 0x73, 0x18, 0x02, + 0x20, 0x03, 0x28, 0x0d, 0x52, 0x06, 0x73, 0x68, 0x61, 0x72, 0x64, 0x73, 0x42, 0x0d, 0x0a, 0x0b, + 0x5f, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x22, 0x61, 0x0a, 0x14, 0x57, + 0x61, 0x6b, 0x75, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x22, 0x0a, 0x0a, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x09, 0x63, 0x6c, 0x75, 0x73, 0x74, + 0x65, 0x72, 0x49, 0x64, 0x88, 0x01, 0x01, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x68, 0x61, 0x72, 0x64, + 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x06, 0x73, 0x68, 0x61, 0x72, 0x64, 0x73, 0x42, + 0x0d, 0x0a, 0x0b, 0x5f, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_waku_metadata_proto_rawDescOnce sync.Once + file_waku_metadata_proto_rawDescData = file_waku_metadata_proto_rawDesc +) + +func file_waku_metadata_proto_rawDescGZIP() []byte { + file_waku_metadata_proto_rawDescOnce.Do(func() { + file_waku_metadata_proto_rawDescData = protoimpl.X.CompressGZIP(file_waku_metadata_proto_rawDescData) + }) + return file_waku_metadata_proto_rawDescData +} + +var file_waku_metadata_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_waku_metadata_proto_goTypes = []interface{}{ + (*WakuMetadataRequest)(nil), // 0: pb.WakuMetadataRequest + (*WakuMetadataResponse)(nil), // 1: pb.WakuMetadataResponse +} +var file_waku_metadata_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_waku_metadata_proto_init() } +func file_waku_metadata_proto_init() { + if File_waku_metadata_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_waku_metadata_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*WakuMetadataRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_waku_metadata_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*WakuMetadataResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_waku_metadata_proto_msgTypes[0].OneofWrappers = []interface{}{} + file_waku_metadata_proto_msgTypes[1].OneofWrappers = []interface{}{} + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_waku_metadata_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_waku_metadata_proto_goTypes, + DependencyIndexes: file_waku_metadata_proto_depIdxs, + MessageInfos: file_waku_metadata_proto_msgTypes, + }.Build() + File_waku_metadata_proto = out.File + file_waku_metadata_proto_rawDesc = nil + file_waku_metadata_proto_goTypes = nil + file_waku_metadata_proto_depIdxs = nil +} diff --git a/waku/v2/protocol/metadata/pb/waku_metadata.proto b/waku/v2/protocol/metadata/pb/waku_metadata.proto new file mode 100644 index 00000000..81e1c87e --- /dev/null +++ b/waku/v2/protocol/metadata/pb/waku_metadata.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package pb; + +message WakuMetadataRequest { + optional uint32 cluster_id = 1; + repeated uint32 shards = 2; +} + +message WakuMetadataResponse { + optional uint32 cluster_id = 1; + repeated uint32 shards = 2; +} \ No newline at end of file diff --git a/waku/v2/protocol/metadata/waku_metadata.go b/waku/v2/protocol/metadata/waku_metadata.go new file mode 100644 index 00000000..d5c2d832 --- /dev/null +++ b/waku/v2/protocol/metadata/waku_metadata.go @@ -0,0 +1,243 @@ +package metadata + +import ( + "context" + "errors" + "math" + "strings" + + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + libp2pProtocol "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-msgio/pbio" + "github.com/multiformats/go-multiaddr" + "github.com/waku-org/go-waku/logging" + "github.com/waku-org/go-waku/waku/v2/protocol" + "github.com/waku-org/go-waku/waku/v2/protocol/enr" + "github.com/waku-org/go-waku/waku/v2/protocol/metadata/pb" + "go.uber.org/zap" +) + +// MetadataID_v1 is the current Waku Metadata protocol identifier +const MetadataID_v1 = libp2pProtocol.ID("/vac/waku/metadata/1.0.0") + +// WakuMetadata is the implementation of the Waku Metadata protocol +type WakuMetadata struct { + network.Notifiee + + h host.Host + ctx context.Context + cancel context.CancelFunc + clusterID uint16 + localnode *enode.LocalNode + + log *zap.Logger +} + +// NewWakuMetadata returns a new instance of Waku Metadata struct +// Takes an optional peermanager if WakuLightPush is being created along with WakuNode. +// If using libp2p host, then pass peermanager as nil +func NewWakuMetadata(clusterID uint16, localnode *enode.LocalNode, log *zap.Logger) *WakuMetadata { + m := new(WakuMetadata) + m.log = log.Named("metadata") + m.clusterID = clusterID + m.localnode = localnode + + return m +} + +// Sets the host to be able to mount or consume a protocol +func (wakuM *WakuMetadata) SetHost(h host.Host) { + wakuM.h = h +} + +// Start inits the metadata protocol +func (wakuM *WakuMetadata) Start(ctx context.Context) error { + if wakuM.clusterID == 0 { + wakuM.log.Warn("no clusterID is specified. Protocol will not be initialized") + } + + ctx, cancel := context.WithCancel(ctx) + + wakuM.ctx = ctx + wakuM.cancel = cancel + + wakuM.h.Network().Notify(wakuM) + + wakuM.h.SetStreamHandlerMatch(MetadataID_v1, protocol.PrefixTextMatch(string(MetadataID_v1)), wakuM.onRequest(ctx)) + wakuM.log.Info("metadata protocol started") + return nil +} + +func (wakuM *WakuMetadata) getClusterAndShards() (*uint32, []uint32, error) { + shard, err := enr.RelaySharding(wakuM.localnode.Node().Record()) + if err != nil { + return nil, nil, err + } + + var shards []uint32 + if shard != nil && shard.Cluster == uint16(wakuM.clusterID) { + for _, idx := range shard.Indices { + shards = append(shards, uint32(idx)) + } + } + + u32ClusterID := uint32(wakuM.clusterID) + + return &u32ClusterID, shards, nil +} + +func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protocol.RelayShards, error) { + logger := wakuM.log.With(logging.HostID("peer", peerID)) + + connOpt, err := wakuM.h.NewStream(ctx, peerID, MetadataID_v1) + if err != nil { + logger.Error("creating stream to peer", zap.Error(err)) + return nil, err + } + + defer connOpt.Close() + defer func() { + err := connOpt.Reset() + if err != nil { + logger.Error("resetting connection", zap.Error(err)) + } + }() + + clusterID, shards, err := wakuM.getClusterAndShards() + if err != nil { + return nil, err + } + + request := &pb.WakuMetadataRequest{} + request.ClusterId = clusterID + request.Shards = shards + + writer := pbio.NewDelimitedWriter(connOpt) + reader := pbio.NewDelimitedReader(connOpt, math.MaxInt32) + + err = writer.WriteMsg(request) + if err != nil { + logger.Error("writing request", zap.Error(err)) + return nil, err + } + + response := &pb.WakuMetadataResponse{} + err = reader.ReadMsg(response) + if err != nil { + logger.Error("reading response", zap.Error(err)) + return nil, err + } + + if response.ClusterId == nil { + return nil, nil // Node is not using sharding + } + + result := &protocol.RelayShards{} + result.Cluster = uint16(*response.ClusterId) + for _, i := range response.Shards { + result.Indices = append(result.Indices, uint16(i)) + } + + return result, nil +} + +func (wakuM *WakuMetadata) onRequest(ctx context.Context) func(s network.Stream) { + return func(s network.Stream) { + defer s.Close() + logger := wakuM.log.With(logging.HostID("peer", s.Conn().RemotePeer())) + request := &pb.WakuMetadataRequest{} + + writer := pbio.NewDelimitedWriter(s) + reader := pbio.NewDelimitedReader(s, math.MaxInt32) + + err := reader.ReadMsg(request) + if err != nil { + logger.Error("reading request", zap.Error(err)) + return + } + + response := new(pb.WakuMetadataResponse) + + clusterID, shards, err := wakuM.getClusterAndShards() + if err != nil { + logger.Error("obtaining shard info", zap.Error(err)) + } else { + response.ClusterId = clusterID + response.Shards = shards + } + + err = writer.WriteMsg(response) + if err != nil { + logger.Error("writing response", zap.Error(err)) + _ = s.Reset() + } + } +} + +// Stop unmounts the metadata protocol +func (wakuM *WakuMetadata) Stop() { + if wakuM.cancel == nil { + return + } + + wakuM.h.Network().StopNotify(wakuM) + wakuM.cancel() + wakuM.h.RemoveStreamHandler(MetadataID_v1) + +} + +// Listen is called when network starts listening on an addr +func (wakuM *WakuMetadata) Listen(n network.Network, m multiaddr.Multiaddr) { + // Do nothing +} + +// ListenClose is called when network stops listening on an address +func (wakuM *WakuMetadata) ListenClose(n network.Network, m multiaddr.Multiaddr) { + // Do nothing +} + +// Connected is called when a connection is opened +func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) { + go func() { + // Metadata verification is done only if a clusterID is specified + if wakuM.clusterID == 0 { + return + } + + peerID := cc.RemotePeer() + + logger := wakuM.log.With(logging.HostID("peerID", peerID)) + + shouldDisconnect := true + shard, err := wakuM.Request(wakuM.ctx, peerID) + if err == nil { + if shard == nil { + err = errors.New("no shard reported") + } else if shard.Cluster != wakuM.clusterID { + err = errors.New("different clusterID reported") + } + } else { + // Only disconnect from peers if they support the protocol + // TODO: open a PR in go-libp2p to create a var with this error to not have to compare strings but use errors.Is instead + if strings.Contains(err.Error(), "protocols not supported") { + shouldDisconnect = false + } + } + + if shouldDisconnect && err != nil { + logger.Error("disconnecting from peer", zap.Error(err)) + wakuM.h.Peerstore().RemovePeer(peerID) + if err := wakuM.h.Network().ClosePeer(peerID); err != nil { + logger.Error("could not disconnect from peer", zap.Error(err)) + } + } + }() +} + +// Disconnected is called when a connection closed +func (wakuM *WakuMetadata) Disconnected(n network.Network, cc network.Conn) { + // Do nothing +} diff --git a/waku/v2/protocol/metadata/waku_metadata_test.go b/waku/v2/protocol/metadata/waku_metadata_test.go new file mode 100644 index 00000000..7b59d2b8 --- /dev/null +++ b/waku/v2/protocol/metadata/waku_metadata_test.go @@ -0,0 +1,154 @@ +package metadata + +import ( + "context" + "crypto/rand" + "testing" + "time" + + gcrypto "github.com/ethereum/go-ethereum/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/stretchr/testify/require" + "github.com/waku-org/go-waku/tests" + "github.com/waku-org/go-waku/waku/v2/protocol" + "github.com/waku-org/go-waku/waku/v2/protocol/enr" + "github.com/waku-org/go-waku/waku/v2/utils" +) + +func createWakuMetadata(t *testing.T, rs *protocol.RelayShards) *WakuMetadata { + port, err := tests.FindFreePort(t, "", 5) + require.NoError(t, err) + host, err := tests.MakeHost(context.Background(), port, rand.Reader) + require.NoError(t, err) + + key, _ := gcrypto.GenerateKey() + + localNode, err := enr.NewLocalnode(key) + require.NoError(t, err) + + cluster := uint16(0) + if rs != nil { + err = enr.WithWakuRelaySharding(*rs)(localNode) + require.NoError(t, err) + cluster = rs.Cluster + } + + m1 := NewWakuMetadata(cluster, localNode, utils.Logger()) + m1.SetHost(host) + err = m1.Start(context.TODO()) + require.NoError(t, err) + + return m1 +} + +func TestWakuMetadataRequest(t *testing.T) { + testShard16 := uint16(16) + + rs16_1, err := protocol.NewRelayShards(testShard16, 1) + require.NoError(t, err) + rs16_2, err := protocol.NewRelayShards(testShard16, 2) + require.NoError(t, err) + + m16_1 := createWakuMetadata(t, &rs16_1) + m16_2 := createWakuMetadata(t, &rs16_2) + m_noRS := createWakuMetadata(t, nil) + + m16_1.h.Peerstore().AddAddrs(m16_2.h.ID(), m16_2.h.Network().ListenAddresses(), peerstore.PermanentAddrTTL) + m16_1.h.Peerstore().AddAddrs(m_noRS.h.ID(), m_noRS.h.Network().ListenAddresses(), peerstore.PermanentAddrTTL) + + // Query a peer that is subscribed to a shard + result, err := m16_1.Request(context.Background(), m16_2.h.ID()) + require.NoError(t, err) + require.Equal(t, testShard16, result.Cluster) + require.Equal(t, rs16_2.Indices, result.Indices) + + // Updating the peer shards + rs16_2.Indices = append(rs16_2.Indices, 3, 4) + err = enr.WithWakuRelaySharding(rs16_2)(m16_2.localnode) + require.NoError(t, err) + + // Query same peer, after that peer subscribes to more shards + result, err = m16_1.Request(context.Background(), m16_2.h.ID()) + require.NoError(t, err) + require.Equal(t, testShard16, result.Cluster) + require.ElementsMatch(t, rs16_2.Indices, result.Indices) + + // Query a peer not subscribed to a shard + result, err = m16_1.Request(context.Background(), m_noRS.h.ID()) + require.NoError(t, err) + require.Equal(t, uint16(0), result.Cluster) + require.Len(t, result.Indices, 0) +} + +func TestNoNetwork(t *testing.T) { + cluster1 := uint16(1) + + rs1, err := protocol.NewRelayShards(cluster1, 1) + require.NoError(t, err) + m1 := createWakuMetadata(t, &rs1) + + // host2 does not support metadata protocol + port, err := tests.FindFreePort(t, "", 5) + require.NoError(t, err) + host2, err := tests.MakeHost(context.Background(), port, rand.Reader) + require.NoError(t, err) + + m1.h.Peerstore().AddAddrs(host2.ID(), host2.Network().ListenAddresses(), peerstore.PermanentAddrTTL) + _, err = m1.h.Network().DialPeer(context.TODO(), host2.ID()) + require.NoError(t, err) + + time.Sleep(2 * time.Second) + + // Verifying peer connections + require.Len(t, m1.h.Network().Peers(), 1) + require.Len(t, host2.Network().Peers(), 1) +} + +// go test -timeout 300s -run TestDropConnectionOnDiffNetworks github.com/waku-org/go-waku/waku/v2/protocol/metadata -count 1 -v + +func TestDropConnectionOnDiffNetworks(t *testing.T) { + cluster1 := uint16(1) + cluster2 := uint16(2) + + // Initializing metadata and peer managers + + rs1, err := protocol.NewRelayShards(cluster1, 1) + require.NoError(t, err) + m1 := createWakuMetadata(t, &rs1) + + rs2, err := protocol.NewRelayShards(cluster2, 1) + require.NoError(t, err) + m2 := createWakuMetadata(t, &rs2) + + rs3, err := protocol.NewRelayShards(cluster2, 1) + require.NoError(t, err) + m3 := createWakuMetadata(t, &rs3) + + // Creating connection between peers + + // 1->2 (fails) + m1.h.Peerstore().AddAddrs(m2.h.ID(), m2.h.Network().ListenAddresses(), peerstore.PermanentAddrTTL) + _, err = m1.h.Network().DialPeer(context.TODO(), m2.h.ID()) + require.NoError(t, err) + + // 1->3 (fails) + m1.h.Peerstore().AddAddrs(m3.h.ID(), m3.h.Network().ListenAddresses(), peerstore.PermanentAddrTTL) + _, err = m1.h.Network().DialPeer(context.TODO(), m3.h.ID()) + require.NoError(t, err) + + // 2->3 (succeeds) + m2.h.Peerstore().AddAddrs(m3.h.ID(), m3.h.Network().ListenAddresses(), peerstore.PermanentAddrTTL) + _, err = m2.h.Network().DialPeer(context.TODO(), m3.h.ID()) + require.NoError(t, err) + + time.Sleep(2 * time.Second) + + // Verifying peer connections + require.Len(t, m1.h.Network().Peers(), 0) + require.Len(t, m2.h.Network().Peers(), 1) + require.Len(t, m3.h.Network().Peers(), 1) + require.Equal(t, []peer.ID{m3.h.ID()}, m2.h.Network().Peers()) + require.Equal(t, []peer.ID{m2.h.ID()}, m3.h.Network().Peers()) + +}