go-waku/waku/v2/protocol/metadata/waku_metadata.go

342 lines
9.2 KiB
Go

package metadata
import (
"context"
"errors"
"math"
"sync"
"time"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
libp2pProtocol "github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-msgio/pbio"
"github.com/multiformats/go-multiaddr"
"github.com/waku-org/go-waku/logging"
"github.com/waku-org/go-waku/waku/v2/protocol"
"github.com/waku-org/go-waku/waku/v2/protocol/enr"
"github.com/waku-org/go-waku/waku/v2/protocol/metadata/pb"
"go.uber.org/zap"
)
// MetadataID_v1 is the current Waku Metadata protocol identifier
const MetadataID_v1 = libp2pProtocol.ID("/vac/waku/metadata/1.0.0")

// WakuMetadata is the implementation of the Waku Metadata protocol.
// It answers incoming metadata requests with the local cluster ID and shards.
// It also registers itself as a network notifiee. On every new connection it
// requests the remote peer's metadata and disconnects peers reporting a
// different cluster ID.
type WakuMetadata struct {
	// Embedded (and left nil); the notifiee callbacks this type relies on
	// (Listen, ListenClose, Connected, Disconnected) are defined as methods on
	// WakuMetadata in this file.
	network.Notifiee
	// h is the libp2p host, provided through SetHost.
	h host.Host
	// ctx is derived in Start and used for the requests issued from Connected.
	ctx context.Context
	// cancel cancels ctx; Stop treats a nil cancel as "never started".
	cancel context.CancelFunc
	// clusterID is the local cluster; 0 means the protocol is not mounted.
	clusterID uint16
	// localnode holds the ENR from which the local relay shards are read.
	localnode *enode.LocalNode
	// peerShardsMutex guards peerShards.
	peerShardsMutex sync.RWMutex
	// peerShards caches the shard IDs reported by connected peers; created in Start.
	peerShards map[peer.ID][]uint16
	log *zap.Logger
}
// NewWakuMetadata returns a new, not yet started, Waku Metadata protocol instance.
//
// clusterID is the cluster this node belongs to; when it is 0, Start leaves the
// protocol unmounted. localnode provides the ENR from which the local relay
// shards are read. log is used as the parent logger (a child named "metadata"
// is derived from it). Start uses the host provided through SetHost, so SetHost
// must be called first.
func NewWakuMetadata(clusterID uint16, localnode *enode.LocalNode, log *zap.Logger) *WakuMetadata {
	return &WakuMetadata{
		clusterID: clusterID,
		localnode: localnode,
		log:       log.Named("metadata"),
	}
}
// SetHost sets the libp2p host used to mount the protocol in Start and to open
// metadata streams to peers in Request. It must be called before Start.
func (wakuM *WakuMetadata) SetHost(h host.Host) {
	wakuM.h = h
}
// Start mounts the metadata protocol. It registers the stream handler for
// MetadataID_v1 and subscribes to connection notifications, so that every newly
// connected peer is asked for its metadata.
//
// When no clusterID was configured, a warning is logged, nothing is mounted,
// and Start still returns nil.
func (wakuM *WakuMetadata) Start(ctx context.Context) error {
	if wakuM.clusterID == 0 {
		wakuM.log.Warn("no clusterID is specified. Protocol will not be initialized")
		return nil
	}

	// The derived context lives until Stop cancels it.
	wakuM.ctx, wakuM.cancel = context.WithCancel(ctx)
	wakuM.peerShards = make(map[peer.ID][]uint16)

	matchMetadata := protocol.PrefixTextMatch(string(MetadataID_v1))
	wakuM.h.SetStreamHandlerMatch(MetadataID_v1, matchMetadata, wakuM.onRequest(wakuM.ctx))
	wakuM.h.Network().Notify(wakuM)

	wakuM.log.Info("metadata protocol started")
	return nil
}
// RelayShard returns the relay sharding information advertised in the local
// node's ENR record. Callers should be prepared for a nil result (ClusterAndShards
// treats nil as "no shards to report").
func (wakuM *WakuMetadata) RelayShard() (*protocol.RelayShards, error) {
	record := wakuM.localnode.Node().Record()
	return enr.RelaySharding(record)
}
// ClusterAndShards returns the configured cluster ID and the local shard indices,
// widened to uint32 as used by the metadata protobuf messages.
//
// Shards are taken from the local ENR and reported only when that ENR's relay
// sharding belongs to the configured cluster; otherwise the shard slice is nil.
// An error is returned only when the ENR sharding cannot be read.
func (wakuM *WakuMetadata) ClusterAndShards() (*uint32, []uint32, error) {
	relayShards, err := wakuM.RelayShard()
	if err != nil {
		return nil, nil, err
	}

	clusterID := uint32(wakuM.clusterID)
	if relayShards == nil || relayShards.ClusterID != wakuM.clusterID {
		return &clusterID, nil, nil
	}

	var shardIDs []uint32
	for _, shardID := range relayShards.ShardIDs {
		shardIDs = append(shardIDs, uint32(shardID))
	}
	return &clusterID, shardIDs, nil
}
// Request performs one metadata exchange with peerID. It sends the local
// cluster ID and shards and returns the cluster and shards the peer reports.
//
// An error is returned when:
//   - the local shard info cannot be read;
//   - the stream cannot be opened, or the write or read fails (the stream is
//     then reset);
//   - the peer omits its cluster ID;
//   - the peer reports a cluster ID or shard index that does not fit in 16 bits,
//     the width used by protocol.RelayShards;
//   - protocol.NewRelayShards rejects the reported shards.
func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protocol.RelayShards, error) {
	logger := wakuM.log.With(logging.HostID("peer", peerID))

	// Build our side of the exchange first: if the local ENR cannot be read
	// there is no point in opening a stream to the peer.
	clusterID, shards, err := wakuM.ClusterAndShards()
	if err != nil {
		return nil, err
	}

	stream, err := wakuM.h.NewStream(ctx, peerID, MetadataID_v1)
	if err != nil {
		logger.Error("creating stream to peer", zap.Error(err))
		return nil, err
	}

	resetStream := func() {
		if err := stream.Reset(); err != nil {
			logger.Error("resetting connection", zap.Error(err))
		}
	}

	request := &pb.WakuMetadataRequest{}
	request.ClusterId = clusterID
	request.Shards = shards
	// TODO: remove with nwaku 0.28 deployment
	request.ShardsDeprecated = shards // nolint: staticcheck

	writer := pbio.NewDelimitedWriter(stream)
	reader := pbio.NewDelimitedReader(stream, math.MaxInt32)

	if err := writer.WriteMsg(request); err != nil {
		logger.Error("writing request", zap.Error(err))
		resetStream()
		return nil, err
	}

	response := &pb.WakuMetadataResponse{}
	if err := reader.ReadMsg(response); err != nil {
		logger.Error("reading response", zap.Error(err))
		resetStream()
		return nil, err
	}

	if err := stream.Close(); err != nil {
		logger.Debug("closing stream", zap.Error(err))
	}

	if response.ClusterId == nil {
		return nil, errors.New("node did not provide a waku clusterid")
	}
	// The wire format carries uint32 values; narrowing without a range check
	// would let e.g. cluster 65537 masquerade as cluster 1.
	if *response.ClusterId > math.MaxUint16 {
		return nil, errors.New("node provided an out of range waku clusterid")
	}
	rClusterID := uint16(*response.ClusterId)

	reportedShards := response.Shards
	if len(reportedShards) == 0 {
		// TODO: remove with nwaku 0.28 deployment
		reportedShards = response.ShardsDeprecated // nolint: staticcheck
	}

	rShardIDs := make([]uint16, 0, len(reportedShards))
	for _, shardID := range reportedShards {
		if shardID > math.MaxUint16 {
			return nil, errors.New("node provided an out of range shard index")
		}
		rShardIDs = append(rShardIDs, uint16(shardID))
	}

	rs, err := protocol.NewRelayShards(rClusterID, rShardIDs...)
	if err != nil {
		return nil, err
	}
	return &rs, nil
}
// onRequest builds the stream handler for incoming metadata requests.
//
// The handler reads one delimited WakuMetadataRequest, answers with the local
// cluster ID and shards, and closes the stream. If the local shard info cannot
// be obtained, the error is logged and an empty response is still written.
// Read and write failures reset the stream. The ctx argument is currently
// unused by the handler.
func (wakuM *WakuMetadata) onRequest(ctx context.Context) func(network.Stream) {
	return func(stream network.Stream) {
		remote := stream.Conn().RemotePeer()
		logger := wakuM.log.With(logging.HostID("peer", remote))

		abort := func() {
			if resetErr := stream.Reset(); resetErr != nil {
				wakuM.log.Error("resetting connection", zap.Error(resetErr))
			}
		}

		request := new(pb.WakuMetadataRequest)
		reader := pbio.NewDelimitedReader(stream, math.MaxInt32)
		writer := pbio.NewDelimitedWriter(stream)

		if readErr := reader.ReadMsg(request); readErr != nil {
			logger.Error("reading request", zap.Error(readErr))
			abort()
			return
		}

		response := &pb.WakuMetadataResponse{}
		if clusterID, shards, infoErr := wakuM.ClusterAndShards(); infoErr != nil {
			logger.Error("obtaining shard info", zap.Error(infoErr))
		} else {
			response.ClusterId = clusterID
			response.Shards = shards
			// TODO: remove with nwaku 0.28 deployment
			response.ShardsDeprecated = shards // nolint: staticcheck
		}

		if writeErr := writer.WriteMsg(response); writeErr != nil {
			logger.Error("writing response", zap.Error(writeErr))
			abort()
			return
		}

		stream.Close()
	}
}
// Stop unmounts the metadata protocol. It unsubscribes from connection
// notifications, cancels the context derived in Start, and removes the
// MetadataID_v1 stream handler. It does nothing if Start never mounted the
// protocol.
func (wakuM *WakuMetadata) Stop() {
	cancel := wakuM.cancel
	if cancel == nil {
		return
	}

	h := wakuM.h
	h.Network().StopNotify(wakuM)
	cancel()
	h.RemoveStreamHandler(MetadataID_v1)
}
// Listen is called when network starts listening on an addr.
// WakuMetadata has nothing to do on this event; the method is part of the
// network.Notifiee callbacks registered in Start.
func (wakuM *WakuMetadata) Listen(n network.Network, m multiaddr.Multiaddr) {
	// Do nothing
}
// ListenClose is called when network stops listening on an address.
// WakuMetadata has nothing to do on this event; the method is part of the
// network.Notifiee callbacks registered in Start.
func (wakuM *WakuMetadata) ListenClose(n network.Network, m multiaddr.Multiaddr) {
	// Do nothing
}
// disconnectPeer logs why peerID is being dropped. It then removes the peer
// from the peerstore and closes the connections to it. A failure to close is
// logged, not returned.
func (wakuM *WakuMetadata) disconnectPeer(peerID peer.ID, reason error) {
	peerLogger := wakuM.log.With(logging.HostID("peerID", peerID))
	peerLogger.Error("disconnecting from peer", zap.Error(reason))

	wakuM.h.Peerstore().RemovePeer(peerID)

	closeErr := wakuM.h.Network().ClosePeer(peerID)
	if closeErr == nil {
		return
	}
	peerLogger.Error("could not disconnect from peer", zap.Error(closeErr))
}
// Connected is called when a connection is opened.
//
// It asynchronously requests the peer's metadata:
//   - if the exchange fails, the peer is disconnected;
//   - if the peer reports a different cluster ID, the peer is disconnected;
//   - otherwise its shards are cached for GetPeerShards.
//
// Failures caused by our own context being cancelled (Stop) are not held
// against the peer.
func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) {
	go func() {
		// Metadata verification is done only if a clusterID is specified
		if wakuM.clusterID == 0 {
			return
		}

		peerID := cc.RemotePeer()
		shard, err := wakuM.Request(wakuM.ctx, peerID)
		if err != nil {
			// The protocol is shutting down: the request failed because of our
			// cancelled context, not because of the peer, so don't drop it.
			if wakuM.ctx.Err() != nil {
				return
			}
			wakuM.disconnectPeer(peerID, err)
			return
		}

		if shard.ClusterID != wakuM.clusterID {
			wakuM.disconnectPeer(peerID, errors.New("different clusterID reported"))
			return
		}

		// Store shards so they're used to verify if a relay peer supports the same shards we do
		wakuM.peerShardsMutex.Lock()
		defer wakuM.peerShardsMutex.Unlock()

		// The peer may have gone away while the request was in flight. If
		// Disconnected already ran, storing now would leave a stale entry that
		// nothing removes.
		if n.Connectedness(peerID) != network.Connected {
			return
		}
		wakuM.peerShards[peerID] = shard.ShardIDs
	}()
}
// Disconnected is called when a connection is closed.
//
// libp2p notifies once per closed connection, and a peer may hold several
// connections at once. The cached shard info is therefore dropped only when no
// connection to the peer remains. Otherwise GetPeerShards would wait for shard
// info that is never re-fetched for a still-connected peer.
func (wakuM *WakuMetadata) Disconnected(n network.Network, cc network.Conn) {
	peerID := cc.RemotePeer()

	wakuM.peerShardsMutex.Lock()
	defer wakuM.peerShardsMutex.Unlock()

	if n.Connectedness(peerID) == network.Connected {
		return
	}
	// We no longer need the shard info for that peer
	delete(wakuM.peerShards, peerID)
}
// GetPeerShards returns the shard IDs cached for peerID.
//
// If no entry is cached yet (for example, the metadata exchange started from
// Connected has not completed), it re-checks every 200ms. It keeps checking
// until an entry appears or ctx is done, in which case ctx.Err() is returned.
func (wakuM *WakuMetadata) GetPeerShards(ctx context.Context, peerID peer.ID) ([]uint16, error) {
	lookup := func() ([]uint16, bool) {
		wakuM.peerShardsMutex.RLock()
		defer wakuM.peerShardsMutex.RUnlock()
		cached, found := wakuM.peerShards[peerID]
		return cached, found
	}

	// Fast path: the shard info is already known.
	if cached, found := lookup(); found {
		return cached, nil
	}

	// Shard info pending: poll until it shows up or the caller gives up.
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-ticker.C:
			if cached, found := lookup(); found {
				return cached, nil
			}
		}
	}
}
// disconnect removes peerID from the peerstore and closes the connections to
// it. A failure to close is logged, not returned.
func (wakuM *WakuMetadata) disconnect(peerID peer.ID) {
	wakuM.h.Peerstore().RemovePeer(peerID)

	if closeErr := wakuM.h.Network().ClosePeer(peerID); closeErr != nil {
		wakuM.log.Error("disconnecting peer", logging.HostID("peerID", peerID), zap.Error(closeErr))
	}
}
// DisconnectPeerOnShardMismatch verifies that peerID shares at least one relay
// shard with the local node. It returns nil when it does.
//
// The peer's shards are obtained through GetPeerShards, which waits until they
// are known or ctx is done. The peer is removed from the peerstore and
// disconnected, and an error is returned, in these cases:
//   - its shards cannot be obtained;
//   - the local shard info cannot be read or is absent from the ENR;
//   - no shard is shared.
func (wakuM *WakuMetadata) DisconnectPeerOnShardMismatch(ctx context.Context, peerID peer.ID) error {
	peerShards, err := wakuM.GetPeerShards(ctx, peerID)
	if err != nil {
		wakuM.log.Error("could not obtain peer shards", zap.Error(err), logging.HostID("peerID", peerID))
		wakuM.disconnect(peerID)
		return err
	}

	rs, err := wakuM.RelayShard()
	if err == nil && rs == nil {
		// RelayShard may return no shard info without an error (ClusterAndShards
		// guards the same case); dereferencing rs below would panic.
		err = errors.New("local node has no relay shards")
	}
	if err != nil {
		wakuM.log.Error("could not obtain shards", zap.Error(err))
		wakuM.disconnect(peerID)
		return err
	}

	if !rs.ContainsAnyShard(rs.ClusterID, peerShards) {
		wakuM.log.Info("shard mismatch", logging.HostID("peerID", peerID), zap.Uint16("clusterID", rs.ClusterID), zap.Uint16s("ourShardIDs", rs.ShardIDs), zap.Uint16s("theirShardIDs", peerShards))
		wakuM.disconnect(peerID)
		return errors.New("shard mismatch")
	}

	return nil
}