go-waku/waku/v2/protocol/metadata/waku_metadata.go
2023-10-20 20:30:23 -04:00

244 lines
6.4 KiB
Go

package metadata
import (
"context"
"errors"
"math"
"strings"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
libp2pProtocol "github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-msgio/pbio"
"github.com/multiformats/go-multiaddr"
"github.com/waku-org/go-waku/logging"
"github.com/waku-org/go-waku/waku/v2/protocol"
"github.com/waku-org/go-waku/waku/v2/protocol/enr"
"github.com/waku-org/go-waku/waku/v2/protocol/metadata/pb"
"go.uber.org/zap"
)
// MetadataID_v1 is the current Waku Metadata protocol identifier.
// It is the libp2p protocol ID used both to open outgoing metadata streams
// and to register the incoming stream handler.
const MetadataID_v1 = libp2pProtocol.ID("/vac/waku/metadata/1.0.0")
// WakuMetadata is the implementation of the Waku Metadata protocol.
// It answers metadata requests from remote peers and, through its network
// notification callbacks, queries newly connected peers for their metadata.
type WakuMetadata struct {
// Embedded notifiee interface. The embedded value itself is not assigned
// anywhere in the visible code; the Listen/ListenClose/Connected/Disconnected
// methods defined on *WakuMetadata are the callbacks actually implemented.
network.Notifiee
// h is the libp2p host; it is assigned by SetHost and used by Start, Stop,
// Request and Connected.
h host.Host
// ctx is the protocol-scoped context created in Start; Connected uses it
// for the outgoing metadata request.
ctx context.Context
// cancel cancels ctx; it is nil until Start has run, which Stop relies on.
cancel context.CancelFunc
// clusterID is the local cluster ID. A value of 0 disables peer
// verification in Connected.
clusterID uint16
// localnode provides the local ENR record from which the relay shards are read.
localnode *enode.LocalNode
// log is the logger named "metadata" derived from the one given to the constructor.
log *zap.Logger
}
// NewWakuMetadata returns a new WakuMetadata instance for the given cluster.
//
// clusterID is the local cluster ID, localnode is the local ENR node from
// which the relay shards are read, and log is the parent logger (a child
// logger named "metadata" is derived from it).
//
// The returned instance has no host yet: call SetHost before Start.
func NewWakuMetadata(clusterID uint16, localnode *enode.LocalNode, log *zap.Logger) *WakuMetadata {
	return &WakuMetadata{
		log:       log.Named("metadata"),
		clusterID: clusterID,
		localnode: localnode,
	}
}
// SetHost sets the libp2p host used to mount and consume the protocol.
// The host is dereferenced by Start, so it must be set beforehand.
func (wakuM *WakuMetadata) SetHost(h host.Host) {
wakuM.h = h
}
// Start initializes the metadata protocol: it derives a cancellable context
// from ctx, registers the receiver for network notifications and mounts the
// stream handler for MetadataID_v1. It always returns nil.
//
// NOTE(review): when clusterID is 0 a warning is logged saying the protocol
// will not be initialized, yet initialization proceeds regardless. Confirm
// whether the message or the behaviour is the intended one.
func (wakuM *WakuMetadata) Start(ctx context.Context) error {
	if wakuM.clusterID == 0 {
		wakuM.log.Warn("no clusterID is specified. Protocol will not be initialized")
	}

	// The derived context is kept so Stop can cancel it and Connected can use it.
	wakuM.ctx, wakuM.cancel = context.WithCancel(ctx)

	wakuM.h.Network().Notify(wakuM)

	matcher := protocol.PrefixTextMatch(string(MetadataID_v1))
	handler := wakuM.onRequest(wakuM.ctx)
	wakuM.h.SetStreamHandlerMatch(MetadataID_v1, matcher, handler)

	wakuM.log.Info("metadata protocol started")
	return nil
}
// getClusterAndShards returns the local cluster ID (as a *uint32) together
// with the shard indices read from the local ENR record.
//
// Shard indices are only reported when the ENR carries relay sharding
// information whose cluster matches the configured clusterID; otherwise the
// returned slice is nil. An error is returned if the ENR cannot be decoded.
func (wakuM *WakuMetadata) getClusterAndShards() (*uint32, []uint32, error) {
	relayShards, err := enr.RelaySharding(wakuM.localnode.Node().Record())
	if err != nil {
		return nil, nil, err
	}

	cluster := uint32(wakuM.clusterID)

	var indices []uint32
	if relayShards != nil && relayShards.Cluster == wakuM.clusterID {
		for i := range relayShards.Indices {
			indices = append(indices, uint32(relayShards.Indices[i]))
		}
	}

	return &cluster, indices, nil
}
// Request opens a metadata stream to peerID, sends the local cluster ID and
// shards, and returns the cluster and shards reported by the peer.
//
// It returns (nil, nil) when the peer's response carries no cluster ID, i.e.
// the peer is not using sharding. An error is returned when the stream cannot
// be opened, the local shard information cannot be obtained, the exchange
// fails, or the peer reports a cluster ID or shard index that does not fit
// in a uint16.
//
// The stream is closed gracefully once a well-formed exchange has completed
// and is reset on every failure path.
func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protocol.RelayShards, error) {
	logger := wakuM.log.With(logging.HostID("peer", peerID))

	stream, err := wakuM.h.NewStream(ctx, peerID, MetadataID_v1)
	if err != nil {
		logger.Error("creating stream to peer", zap.Error(err))
		return nil, err
	}

	// Reset only when bailing out early; a completed exchange is closed
	// normally so the remote side does not observe an aborted stream.
	completed := false
	defer func() {
		if completed {
			if err := stream.Close(); err != nil {
				logger.Debug("closing stream", zap.Error(err))
			}
			return
		}
		if err := stream.Reset(); err != nil {
			logger.Error("resetting connection", zap.Error(err))
		}
	}()

	clusterID, shards, err := wakuM.getClusterAndShards()
	if err != nil {
		return nil, err
	}

	request := &pb.WakuMetadataRequest{
		ClusterId: clusterID,
		Shards:    shards,
	}

	if err := pbio.NewDelimitedWriter(stream).WriteMsg(request); err != nil {
		logger.Error("writing request", zap.Error(err))
		return nil, err
	}

	response := &pb.WakuMetadataResponse{}
	if err := pbio.NewDelimitedReader(stream, math.MaxInt32).ReadMsg(response); err != nil {
		logger.Error("reading response", zap.Error(err))
		return nil, err
	}

	if response.ClusterId == nil {
		completed = true
		return nil, nil // Node is not using sharding
	}

	// The wire format carries 32-bit values while RelayShards stores 16-bit
	// ones. Reject anything that would be silently truncated, otherwise a
	// peer could report e.g. 0x10000+N and be taken for cluster N.
	if *response.ClusterId > math.MaxUint16 {
		return nil, errors.New("cluster id out of range")
	}

	result := &protocol.RelayShards{Cluster: uint16(*response.ClusterId)}
	for _, idx := range response.Shards {
		if idx > math.MaxUint16 {
			return nil, errors.New("shard index out of range")
		}
		result.Indices = append(result.Indices, uint16(idx))
	}

	completed = true
	return result, nil
}
// onRequest returns the stream handler that serves incoming metadata
// requests. The handler reads one request from the stream and replies with
// the local cluster ID and shards; if the local shard information cannot be
// obtained, the error is logged and an empty response is sent instead.
//
// The stream is always closed when the handler returns, and it is reset if
// writing the response fails. The ctx parameter is not used by the handler.
func (wakuM *WakuMetadata) onRequest(ctx context.Context) func(s network.Stream) {
	return func(s network.Stream) {
		defer s.Close()

		logger := wakuM.log.With(logging.HostID("peer", s.Conn().RemotePeer()))

		// The request content is not inspected; it only has to be readable.
		incoming := &pb.WakuMetadataRequest{}
		if err := pbio.NewDelimitedReader(s, math.MaxInt32).ReadMsg(incoming); err != nil {
			logger.Error("reading request", zap.Error(err))
			return
		}

		reply := &pb.WakuMetadataResponse{}
		if cluster, indices, err := wakuM.getClusterAndShards(); err == nil {
			reply.ClusterId = cluster
			reply.Shards = indices
		} else {
			logger.Error("obtaining shard info", zap.Error(err))
		}

		if err := pbio.NewDelimitedWriter(s).WriteMsg(reply); err != nil {
			logger.Error("writing response", zap.Error(err))
			_ = s.Reset()
		}
	}
}
// Stop unmounts the metadata protocol: it unregisters the receiver from
// network notifications, cancels the context created by Start and removes
// the stream handler. It does nothing if no cancel function has been set.
func (wakuM *WakuMetadata) Stop() {
	stop := wakuM.cancel
	if stop == nil {
		return
	}

	net := wakuM.h.Network()
	net.StopNotify(wakuM)
	stop()
	wakuM.h.RemoveStreamHandler(MetadataID_v1)
}
// Listen is called when network starts listening on an addr.
// The metadata protocol does not react to this event.
func (wakuM *WakuMetadata) Listen(n network.Network, m multiaddr.Multiaddr) {
// Intentionally empty: nothing to do for this notification.
}
// ListenClose is called when network stops listening on an address.
// The metadata protocol does not react to this event.
func (wakuM *WakuMetadata) ListenClose(n network.Network, m multiaddr.Multiaddr) {
// Intentionally empty: nothing to do for this notification.
}
// Connected is called when a connection is opened. In a separate goroutine it
// requests the remote peer's metadata and drops the peer (removing it from
// the peerstore and closing its connections) when:
//   - the request fails for any reason other than the peer not supporting
//     the protocol,
//   - the peer reports no shard information, or
//   - the peer reports a cluster ID different from the local one.
//
// No verification is performed when the local clusterID is 0.
func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) {
	go func() {
		// Metadata verification is done only if a clusterID is specified
		if wakuM.clusterID == 0 {
			return
		}

		remote := cc.RemotePeer()
		logger := wakuM.log.With(logging.HostID("peerID", remote))

		shard, err := wakuM.Request(wakuM.ctx, remote)
		switch {
		case err != nil:
			// Only disconnect from peers if they support the protocol
			// TODO: open a PR in go-libp2p to create a var with this error to not have to compare strings but use errors.Is instead
			if strings.Contains(err.Error(), "protocols not supported") {
				return
			}
		case shard == nil:
			err = errors.New("no shard reported")
		case shard.Cluster != wakuM.clusterID:
			err = errors.New("different clusterID reported")
		default:
			// Peer is on the same cluster: keep the connection.
			return
		}

		logger.Error("disconnecting from peer", zap.Error(err))
		wakuM.h.Peerstore().RemovePeer(remote)
		if closeErr := wakuM.h.Network().ClosePeer(remote); closeErr != nil {
			logger.Error("could not disconnect from peer", zap.Error(closeErr))
		}
	}()
}
// Disconnected is called when a connection is closed.
// The metadata protocol does not react to this event.
func (wakuM *WakuMetadata) Disconnected(n network.Network, cc network.Conn) {
// Intentionally empty: nothing to do for this notification.
}