diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index c710ce54..01c98767 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -15,6 +15,7 @@ import ( "github.com/waku-org/go-discover/discover" "github.com/waku-org/go-waku/logging" "github.com/waku-org/go-waku/waku/v2/metrics" + "github.com/waku-org/go-waku/waku/v2/protocol/enr" "github.com/waku-org/go-waku/waku/v2/utils" "go.uber.org/zap" @@ -243,7 +244,7 @@ func evaluateNode(node *enode.Node) bool { return false }*/ - _, err := utils.EnodeToPeerInfo(node) + _, err := enr.EnodeToPeerInfo(node) if err != nil { metrics.RecordDiscV5Error(context.Background(), "peer_info_failure") @@ -295,7 +296,7 @@ func (d *DiscoveryV5) iterate(ctx context.Context) error { break } - _, addresses, err := utils.Multiaddress(iterator.Node()) + _, addresses, err := enr.Multiaddress(iterator.Node()) if err != nil { metrics.RecordDiscV5Error(context.Background(), "peer_info_failure") d.log.Error("extracting multiaddrs from enr", zap.Error(err)) diff --git a/waku/v2/discv5/discover_test.go b/waku/v2/discv5/discover_test.go index a06792ff..1a3b3925 100644 --- a/waku/v2/discv5/discover_test.go +++ b/waku/v2/discv5/discover_test.go @@ -13,6 +13,8 @@ import ( gcrypto "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" + wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" + "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" "github.com/waku-org/go-waku/tests" @@ -45,14 +47,14 @@ func createHost(t *testing.T) (host.Host, int, *ecdsa.PrivateKey) { return host, port, privKey } -func newLocalnode(priv *ecdsa.PrivateKey, ipAddr *net.TCPAddr, udpPort int, wakuFlags utils.WakuEnrBitfield, advertiseAddr *net.IP, log *zap.Logger) (*enode.LocalNode, error) { +func newLocalnode(priv *ecdsa.PrivateKey, ipAddr *net.TCPAddr, udpPort int, wakuFlags wenr.WakuEnrBitfield, advertiseAddr *net.IP, log *zap.Logger) (*enode.LocalNode, error) { db, err := enode.OpenDB("") if err != nil { return nil, err } localnode := enode.NewLocalNode(db, priv) localnode.SetFallbackUDP(udpPort) - localnode.Set(enr.WithEntry(utils.WakuENRField, wakuFlags)) + localnode.Set(enr.WithEntry(wenr.WakuENRField, wakuFlags)) localnode.SetFallbackIP(net.IP{127, 0, 0, 1}) localnode.SetStaticIP(ipAddr.IP) @@ -103,7 +105,7 @@ func TestDiscV5(t *testing.T) { udpPort1, err := tests.FindFreeUDPPort(t, "127.0.0.1", 3) require.NoError(t, err) ip1, _ := extractIP(host1.Addrs()[0]) - l1, err := newLocalnode(prvKey1, ip1, udpPort1, utils.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) + l1, err := newLocalnode(prvKey1, ip1, udpPort1, wenr.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) require.NoError(t, err) peerconn1 := tests.NewTestPeerDiscoverer() d1, err := NewDiscoveryV5(prvKey1, l1, peerconn1, utils.Logger(), WithUDPPort(uint(udpPort1))) @@ -115,7 +117,7 @@ func TestDiscV5(t *testing.T) { ip2, _ := extractIP(host2.Addrs()[0]) udpPort2, err := tests.FindFreeUDPPort(t, "127.0.0.1", 3) require.NoError(t, err) - l2, err := newLocalnode(prvKey2, ip2, udpPort2, utils.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) + l2, err := newLocalnode(prvKey2, ip2, udpPort2, wenr.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) require.NoError(t, err) peerconn2 := tests.NewTestPeerDiscoverer() d2, err := NewDiscoveryV5(prvKey2, l2, peerconn2, utils.Logger(), WithUDPPort(uint(udpPort2)), WithBootnodes([]*enode.Node{d1.localnode.Node()})) @@ -127,7 +129,7 @@ func TestDiscV5(t *testing.T) { ip3, _ := extractIP(host3.Addrs()[0]) udpPort3, err := tests.FindFreeUDPPort(t, "127.0.0.1", 3) require.NoError(t, err) - l3, err := newLocalnode(prvKey3, ip3, udpPort3, utils.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) + l3, err := newLocalnode(prvKey3, ip3, udpPort3, wenr.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) require.NoError(t, err) peerconn3 := tests.NewTestPeerDiscoverer() d3, err := NewDiscoveryV5(prvKey3, l3, peerconn3, utils.Logger(), WithUDPPort(uint(udpPort3)), WithBootnodes([]*enode.Node{d2.localnode.Node()})) diff --git a/waku/v2/dnsdisc/enr.go b/waku/v2/dnsdisc/enr.go index 0ef33f1a..ac32ef38 100644 --- a/waku/v2/dnsdisc/enr.go +++ b/waku/v2/dnsdisc/enr.go @@ -8,7 +8,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/enr" "github.com/libp2p/go-libp2p/core/peer" "github.com/waku-org/go-waku/waku/v2/metrics" - "github.com/waku-org/go-waku/waku/v2/utils" + wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" ma "github.com/multiformats/go-multiaddr" ) @@ -52,7 +52,7 @@ func RetrieveNodes(ctx context.Context, url string, opts ...DnsDiscoveryOption) } for _, node := range tree.Nodes() { - peerID, m, err := utils.Multiaddress(node) + peerID, m, err := wenr.Multiaddress(node) if err != nil { metrics.RecordDnsDiscoveryError(ctx, "peer_info_failure") return nil, err diff --git a/waku/v2/node/localnode.go b/waku/v2/node/localnode.go index f866e77b..5cb6a8e9 100644 --- a/waku/v2/node/localnode.go +++ b/waku/v2/node/localnode.go @@ -14,7 +14,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" ma "github.com/multiformats/go-multiaddr" - "github.com/waku-org/go-waku/waku/v2/utils" + wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" "go.uber.org/zap" ) @@ -30,7 +30,7 @@ func writeMultiaddressField(localnode *enode.LocalNode, addrAggr []ma.Multiaddr) defer func() { if e := recover(); e != nil { // Deleting the multiaddr entry, as we could not write it succesfully - localnode.Delete(enr.WithEntry(utils.MultiaddrENRField, struct{}{})) + localnode.Delete(enr.WithEntry(wenr.MultiaddrENRField, struct{}{})) err = errors.New("could not write enr record") } }() @@ -46,7 +46,7 @@ func writeMultiaddressField(localnode *enode.LocalNode, addrAggr []ma.Multiaddr) } if len(fieldRaw) != 0 && len(fieldRaw) <= 100 { // Max length for multiaddr field before triggering the 300 bytes limit - localnode.Set(enr.WithEntry(utils.MultiaddrENRField, fieldRaw)) + localnode.Set(enr.WithEntry(wenr.MultiaddrENRField, fieldRaw)) } // This is to trigger the signing record err due to exceeding 300bytes limit @@ -55,9 +55,9 @@ func writeMultiaddressField(localnode *enode.LocalNode, addrAggr []ma.Multiaddr) return nil } -func (w *WakuNode) updateLocalNode(localnode *enode.LocalNode, multiaddrs []ma.Multiaddr, ipAddr *net.TCPAddr, udpPort uint, wakuFlags utils.WakuEnrBitfield, advertiseAddr []ma.Multiaddr, shouldAutoUpdate bool, log *zap.Logger) error { +func (w *WakuNode) updateLocalNode(localnode *enode.LocalNode, multiaddrs []ma.Multiaddr, ipAddr *net.TCPAddr, udpPort uint, wakuFlags wenr.WakuEnrBitfield, advertiseAddr []ma.Multiaddr, shouldAutoUpdate bool, log *zap.Logger) error { localnode.SetFallbackUDP(int(udpPort)) - localnode.Set(enr.WithEntry(utils.WakuENRField, wakuFlags)) + localnode.Set(enr.WithEntry(wenr.WakuENRField, wakuFlags)) localnode.SetFallbackIP(net.IP{127, 0, 0, 1}) if udpPort > math.MaxUint16 { diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go index deb251c6..fb14f323 100644 --- a/waku/v2/node/wakunode2.go +++ b/waku/v2/node/wakunode2.go @@ -32,6 +32,7 @@ import ( v2 "github.com/waku-org/go-waku/waku/v2" "github.com/waku-org/go-waku/waku/v2/discv5" "github.com/waku-org/go-waku/waku/v2/metrics" + "github.com/waku-org/go-waku/waku/v2/protocol/enr" "github.com/waku-org/go-waku/waku/v2/protocol/filter" "github.com/waku-org/go-waku/waku/v2/protocol/legacy_filter" "github.com/waku-org/go-waku/waku/v2/protocol/lightpush" @@ -88,7 +89,7 @@ type WakuNode struct { store ReceptorService rlnRelay RLNRelay - wakuFlag utils.WakuEnrBitfield + wakuFlag enr.WakuEnrBitfield localNode *enode.LocalNode @@ -175,7 +176,7 @@ func New(opts ...WakuNodeOption) (*WakuNode, error) { w.log = params.logger.Named("node2") w.wg = &sync.WaitGroup{} w.keepAliveFails = make(map[peer.ID]int) - w.wakuFlag = utils.NewWakuEnrBitfield(w.opts.enableLightPush, w.opts.enableLegacyFilter, w.opts.enableStore, w.opts.enableRelay) + w.wakuFlag = enr.NewWakuEnrBitfield(w.opts.enableLightPush, w.opts.enableLegacyFilter, w.opts.enableStore, w.opts.enableRelay) if params.enableNTP { w.timesource = timesource.NewNTPTimesource(w.opts.ntpURLs, w.log) diff --git a/waku/v2/utils/enr.go b/waku/v2/protocol/enr/enr.go similarity index 92% rename from waku/v2/utils/enr.go rename to waku/v2/protocol/enr/enr.go index 4f7f4707..9d6bec32 100644 --- a/waku/v2/utils/enr.go +++ b/waku/v2/protocol/enr/enr.go @@ -1,4 +1,4 @@ -package utils +package enr import ( "encoding/binary" @@ -9,6 +9,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/enr" "github.com/libp2p/go-libp2p/core/peer" "github.com/multiformats/go-multiaddr" + "github.com/waku-org/go-waku/waku/v2/utils" ) // WakuENRField is the name of the ENR field that contains information about which protocols are supported by the node @@ -18,6 +19,10 @@ const WakuENRField = "waku2" // already available ENR fields (i.e. in the case of websocket connections) const MultiaddrENRField = "multiaddrs" +const ShardingIndicesListEnrField = "rs" + +const ShardingBitVectorEnrField = "rsv" + // WakuEnrBitfield is a8-bit flag field to indicate Waku capabilities. Only the 4 LSBs are currently defined according to RFC31 (https://rfc.vac.dev/spec/31/). type WakuEnrBitfield = uint8 @@ -46,7 +51,7 @@ func NewWakuEnrBitfield(lightpush, filter, store, relay bool) WakuEnrBitfield { // EnodeToMultiaddress converts an enode into a multiaddress func enodeToMultiAddr(node *enode.Node) (multiaddr.Multiaddr, error) { - pubKey := EcdsaPubKeyToSecp256k1PublicKey(node.Pubkey()) + pubKey := utils.EcdsaPubKeyToSecp256k1PublicKey(node.Pubkey()) peerID, err := peer.IDFromPublicKey(pubKey) if err != nil { return nil, err @@ -57,7 +62,7 @@ func enodeToMultiAddr(node *enode.Node) (multiaddr.Multiaddr, error) { // Multiaddress is used to extract all the multiaddresses that are part of a ENR record func Multiaddress(node *enode.Node) (peer.ID, []multiaddr.Multiaddr, error) { - pubKey := EcdsaPubKeyToSecp256k1PublicKey(node.Pubkey()) + pubKey := utils.EcdsaPubKeyToSecp256k1PublicKey(node.Pubkey()) peerID, err := peer.IDFromPublicKey(pubKey) if err != nil { return "", nil, err diff --git a/waku/v2/utils/enr_test.go b/waku/v2/protocol/enr/enr_test.go similarity index 98% rename from waku/v2/utils/enr_test.go rename to waku/v2/protocol/enr/enr_test.go index 00e486bc..5370c614 100644 --- a/waku/v2/utils/enr_test.go +++ b/waku/v2/protocol/enr/enr_test.go @@ -1,4 +1,4 @@ -package utils +package enr import ( "encoding/binary" @@ -14,6 +14,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/enr" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" + "github.com/waku-org/go-waku/waku/v2/utils" "go.uber.org/zap" ) @@ -151,7 +152,7 @@ func TestMultiaddr(t *testing.T) { db, _ := enode.OpenDB("") localNode := enode.NewLocalNode(db, key) - err := updateLocalNode(localNode, multiaddrValues, &net.TCPAddr{IP: net.IPv4(192, 168, 1, 241), Port: 60000}, 50000, wakuFlag, nil, false, Logger()) + err := updateLocalNode(localNode, multiaddrValues, &net.TCPAddr{IP: net.IPv4(192, 168, 1, 241), Port: 60000}, 50000, wakuFlag, nil, false, utils.Logger()) require.NoError(t, err) _ = localNode.Node() // Should not panic diff --git a/waku/v2/protocol/enr/shards.go b/waku/v2/protocol/enr/shards.go new file mode 100644 index 00000000..09930522 --- /dev/null +++ b/waku/v2/protocol/enr/shards.go @@ -0,0 +1,104 @@ +package enr + +import ( + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/waku-org/go-waku/waku/v2/protocol" +) + +func SetWakuRelayShardingIndicesList(localnode *enode.LocalNode, rs protocol.RelayShards) error { + value, err := rs.IndicesList() + if err != nil { + return err + } + localnode.Set(enr.WithEntry(ShardingIndicesListEnrField, value)) + return nil +} + +func SetWakuRelayShardingBitVector(localnode *enode.LocalNode, rs protocol.RelayShards) error { + localnode.Set(enr.WithEntry(ShardingBitVectorEnrField, rs.BitVector())) + return nil +} + +func SetWakuRelaySharding(localnode *enode.LocalNode, rs protocol.RelayShards) error { + if len(rs.Indices) >= 64 { + return SetWakuRelayShardingBitVector(localnode, rs) + } else { + return SetWakuRelayShardingIndicesList(localnode, rs) + } +} + +// ENR record accessors + +func RelayShardingIndicesList(localnode *enode.LocalNode) (*protocol.RelayShards, error) { + var field []byte + if err := localnode.Node().Record().Load(enr.WithEntry(ShardingIndicesListEnrField, field)); err != nil { + return nil, nil + } + + res, err := protocol.FromIndicesList(field) + if err != nil { + return nil, err + } + + return &res, nil +} + +func RelayShardingBitVector(localnode *enode.LocalNode) (*protocol.RelayShards, error) { + var field []byte + if err := localnode.Node().Record().Load(enr.WithEntry(ShardingBitVectorEnrField, field)); err != nil { + return nil, nil + } + + res, err := protocol.FromBitVector(field) + if err != nil { + return nil, err + } + + return &res, nil +} + +func RelaySharding(localnode *enode.LocalNode) (*protocol.RelayShards, error) { + res, err := RelayShardingIndicesList(localnode) + if err != nil { + return nil, err + } + + if res != nil { + return res, nil + } + + return RelayShardingBitVector(localnode) +} + +// Utils + +func ContainsShard(localnode *enode.LocalNode, cluster uint16, index uint16) bool { + if index > protocol.MaxShardIndex { + return false + } + + rs, err := RelaySharding(localnode) + if err != nil { + return false + } + + return rs.Contains(cluster, index) +} + +func ContainsShardWithNsTopic(localnode *enode.LocalNode, topic protocol.NamespacedPubsubTopic) bool { + if topic.Kind() != protocol.StaticSharding { + return false + } + shardTopic := topic.(protocol.StaticShardingPubsubTopic) + return ContainsShard(localnode, shardTopic.Cluster(), shardTopic.Shard()) + +} + +func ContainsShardTopic(localnode *enode.LocalNode, topic string) bool { + shardTopic, err := protocol.ToShardedPubsubTopic(topic) + if err != nil { + return false + } + return ContainsShardWithNsTopic(localnode, shardTopic) +} diff --git a/waku/v2/protocol/peer_exchange/client.go b/waku/v2/protocol/peer_exchange/client.go index 3c97e5f1..9f151c97 100644 --- a/waku/v2/protocol/peer_exchange/client.go +++ b/waku/v2/protocol/peer_exchange/client.go @@ -11,8 +11,8 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-msgio/pbio" "github.com/waku-org/go-waku/waku/v2/metrics" + wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" "github.com/waku-org/go-waku/waku/v2/protocol/peer_exchange/pb" - "github.com/waku-org/go-waku/waku/v2/utils" "go.uber.org/zap" ) @@ -85,7 +85,7 @@ func (wakuPX *WakuPeerExchange) handleResponse(ctx context.Context, response *pb return err } - peerInfo, err := utils.EnodeToPeerInfo(enodeRecord) + peerInfo, err := wenr.EnodeToPeerInfo(enodeRecord) if err != nil { return err } diff --git a/waku/v2/protocol/peer_exchange/protocol.go b/waku/v2/protocol/peer_exchange/protocol.go index 24c9293a..ae48a0cd 100644 --- a/waku/v2/protocol/peer_exchange/protocol.go +++ b/waku/v2/protocol/peer_exchange/protocol.go @@ -21,8 +21,8 @@ import ( "github.com/waku-org/go-waku/waku/v2/discv5" "github.com/waku-org/go-waku/waku/v2/metrics" "github.com/waku-org/go-waku/waku/v2/protocol" + "github.com/waku-org/go-waku/waku/v2/protocol/enr" "github.com/waku-org/go-waku/waku/v2/protocol/peer_exchange/pb" - "github.com/waku-org/go-waku/waku/v2/utils" "go.uber.org/zap" ) @@ -237,7 +237,7 @@ func (wakuPX *WakuPeerExchange) iterate(ctx context.Context) error { break } - _, addresses, err := utils.Multiaddress(iterator.Node()) + _, addresses, err := enr.Multiaddress(iterator.Node()) if err != nil { wakuPX.log.Error("extracting multiaddrs from enr", zap.Error(err)) continue diff --git a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go index b4c53014..8db5db8f 100644 --- a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go +++ b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" "github.com/waku-org/go-waku/tests" "github.com/waku-org/go-waku/waku/v2/discv5" + wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" "github.com/waku-org/go-waku/waku/v2/utils" "go.uber.org/zap" @@ -67,14 +68,14 @@ func extractIP(addr multiaddr.Multiaddr) (*net.TCPAddr, error) { }, nil } -func newLocalnode(priv *ecdsa.PrivateKey, ipAddr *net.TCPAddr, udpPort int, wakuFlags utils.WakuEnrBitfield, advertiseAddr *net.IP, log *zap.Logger) (*enode.LocalNode, error) { +func newLocalnode(priv *ecdsa.PrivateKey, ipAddr *net.TCPAddr, udpPort int, wakuFlags wenr.WakuEnrBitfield, advertiseAddr *net.IP, log *zap.Logger) (*enode.LocalNode, error) { db, err := enode.OpenDB("") if err != nil { return nil, err } localnode := enode.NewLocalNode(db, priv) localnode.SetFallbackUDP(udpPort) - localnode.Set(enr.WithEntry(utils.WakuENRField, wakuFlags)) + localnode.Set(enr.WithEntry(wenr.WakuENRField, wakuFlags)) localnode.SetFallbackIP(net.IP{127, 0, 0, 1}) localnode.SetStaticIP(ipAddr.IP) @@ -103,7 +104,7 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { udpPort1, err := tests.FindFreePort(t, "127.0.0.1", 3) require.NoError(t, err) ip1, _ := extractIP(host1.Addrs()[0]) - l1, err := newLocalnode(prvKey1, ip1, udpPort1, utils.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) + l1, err := newLocalnode(prvKey1, ip1, udpPort1, wenr.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) require.NoError(t, err) discv5PeerConn1 := tests.NewTestPeerDiscoverer() d1, err := discv5.NewDiscoveryV5(prvKey1, l1, discv5PeerConn1, utils.Logger(), discv5.WithUDPPort(uint(udpPort1))) @@ -115,7 +116,7 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { ip2, _ := extractIP(host2.Addrs()[0]) udpPort2, err := tests.FindFreePort(t, "127.0.0.1", 3) require.NoError(t, err) - l2, err := newLocalnode(prvKey2, ip2, udpPort2, utils.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) + l2, err := newLocalnode(prvKey2, ip2, udpPort2, wenr.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) require.NoError(t, err) discv5PeerConn2 := tests.NewTestPeerDiscoverer() d2, err := discv5.NewDiscoveryV5(prvKey2, l2, discv5PeerConn2, utils.Logger(), discv5.WithUDPPort(uint(udpPort2)), discv5.WithBootnodes([]*enode.Node{d1.Node()})) diff --git a/waku/v2/protocol/shard.go b/waku/v2/protocol/shard.go new file mode 100644 index 00000000..fec1ccb1 --- /dev/null +++ b/waku/v2/protocol/shard.go @@ -0,0 +1,166 @@ +package protocol + +import ( + "encoding/binary" + "errors" + "fmt" + "math" +) + +const MaxShardIndex = uint16(1023) + +type RelayShards struct { + Cluster uint16 + Indices []uint16 +} + +func NewRelayShards(cluster uint16, indices ...uint16) (RelayShards, error) { + if len(indices) > math.MaxUint8 { + return RelayShards{}, errors.New("too many indices") + } + + indiceSet := make(map[uint16]struct{}) + for _, index := range indices { + if index > MaxShardIndex { + return RelayShards{}, errors.New("invalid index") + } + indiceSet[index] = struct{}{} // dedup + } + + if len(indiceSet) == 0 { + return RelayShards{}, errors.New("invalid index count") + } + + indices = []uint16{} + for index := range indiceSet { + indices = append(indices, index) + } + + return RelayShards{Cluster: cluster, Indices: indices}, nil +} + +func (rs RelayShards) Topics() []NamespacedPubsubTopic { + var result []NamespacedPubsubTopic + for _, i := range rs.Indices { + result = append(result, NewStaticShardingPubsubTopic(rs.Cluster, i)) + } + return result +} + +func (rs RelayShards) Contains(cluster uint16, index uint16) bool { + if rs.Cluster != cluster { + return false + } + + found := false + for _, i := range rs.Indices { + if i == index { + found = true + } + } + + return found +} + +func (rs RelayShards) ContainsNamespacedTopic(topic NamespacedPubsubTopic) bool { + if topic.Kind() != StaticSharding { + return false + } + + shardedTopic := topic.(StaticShardingPubsubTopic) + + return rs.Contains(shardedTopic.Cluster(), shardedTopic.Shard()) +} + +func (rs RelayShards) ContainsTopic(topic string) bool { + nsTopic, err := ToShardedPubsubTopic(topic) + if err != nil { + return false + } + return rs.ContainsNamespacedTopic(nsTopic) +} + +func (rs RelayShards) IndicesList() ([]byte, error) { + if len(rs.Indices) > math.MaxUint8 { + return nil, errors.New("indices list too long") + } + + var result []byte + + result = binary.BigEndian.AppendUint16(result, rs.Cluster) + result = append(result, uint8(len(rs.Indices))) + for _, index := range rs.Indices { + result = binary.BigEndian.AppendUint16(result, index) + } + + return result, nil +} + +func FromIndicesList(buf []byte) (RelayShards, error) { + if len(buf) < 3 { + return RelayShards{}, fmt.Errorf("insufficient data: expected at least 3 bytes, got %d bytes", len(buf)) + } + + cluster := binary.BigEndian.Uint16(buf[0:2]) + length := int(buf[2]) + + if len(buf) != 3+2*length { + return RelayShards{}, fmt.Errorf("invalid data: `length` field is %d but %d bytes were provided", length, len(buf)) + } + + var indices []uint16 + for i := 0; i < length; i++ { + indices = append(indices, binary.BigEndian.Uint16(buf[3+2*i:5+2*i])) + } + + return NewRelayShards(cluster, indices...) +} + +func setBit(n byte, pos uint) byte { + n |= (1 << pos) + return n +} + +func hasBit(n byte, pos uint) bool { + val := n & (1 << pos) + return (val > 0) +} + +func (rs RelayShards) BitVector() []byte { + // The value is comprised of a two-byte shard cluster index in network byte + // order concatenated with a 128-byte wide bit vector. The bit vector + // indicates which shards of the respective shard cluster the node is part + // of. The right-most bit in the bit vector represents shard 0, the left-most + // bit represents shard 1023. + var result []byte + result = binary.BigEndian.AppendUint16(result, rs.Cluster) + + vec := make([]byte, 128) + for _, index := range rs.Indices { + n := vec[index/8] + vec[index/8] = byte(setBit(n, uint(index%8))) + } + + return append(result, vec...) +} + +func FromBitVector(buf []byte) (RelayShards, error) { + if len(buf) != 130 { + return RelayShards{}, errors.New("invalid data: expected 130 bytes") + } + + cluster := binary.BigEndian.Uint16(buf[0:2]) + var indices []uint16 + + for i := uint16(0); i < 128; i++ { + for j := uint(0); j < 8; j++ { + if !hasBit(buf[2+i], j) { + continue + } + + indices = append(indices, uint16(j)+8*i) + } + } + + return RelayShards{Cluster: cluster, Indices: indices}, nil +} diff --git a/waku/v2/protocol/topic.go b/waku/v2/protocol/topic.go index b5f6a7fb..abbdc7b1 100644 --- a/waku/v2/protocol/topic.go +++ b/waku/v2/protocol/topic.go @@ -3,7 +3,6 @@ package protocol import ( "errors" "fmt" - "runtime/debug" "strconv" "strings" ) @@ -72,19 +71,19 @@ const ( NamedSharding ) -type ShardedPubsubTopic interface { +type NamespacedPubsubTopic interface { String() string Kind() NamespacedPubsubTopicKind - Equal(ShardedPubsubTopic) bool + Equal(NamespacedPubsubTopic) bool } type NamedShardingPubsubTopic struct { - ShardedPubsubTopic + NamespacedPubsubTopic kind NamespacedPubsubTopicKind name string } -func NewNamedShardingPubsubTopic(name string) ShardedPubsubTopic { +func NewNamedShardingPubsubTopic(name string) NamespacedPubsubTopic { return NamedShardingPubsubTopic{ kind: NamedSharding, name: name, @@ -99,7 +98,7 @@ func (n NamedShardingPubsubTopic) Name() string { return n.name } -func (s NamedShardingPubsubTopic) Equal(t2 ShardedPubsubTopic) bool { +func (s NamedShardingPubsubTopic) Equal(t2 NamespacedPubsubTopic) bool { return s.String() == t2.String() } @@ -124,13 +123,13 @@ func (s *NamedShardingPubsubTopic) Parse(topic string) error { } type StaticShardingPubsubTopic struct { - ShardedPubsubTopic + NamespacedPubsubTopic kind NamespacedPubsubTopicKind cluster uint16 shard uint16 } -func NewStaticShardingPubsubTopic(cluster uint16, shard uint16) ShardedPubsubTopic { +func NewStaticShardingPubsubTopic(cluster uint16, shard uint16) NamespacedPubsubTopic { return StaticShardingPubsubTopic{ kind: StaticSharding, cluster: cluster, @@ -150,7 +149,7 @@ func (n StaticShardingPubsubTopic) Kind() NamespacedPubsubTopicKind { return n.kind } -func (s StaticShardingPubsubTopic) Equal(t2 ShardedPubsubTopic) bool { +func (s StaticShardingPubsubTopic) Equal(t2 NamespacedPubsubTopic) bool { return s.String() == t2.String() } @@ -196,7 +195,7 @@ func (s *StaticShardingPubsubTopic) Parse(topic string) error { return nil } -func ToShardedPubsubTopic(topic string) (ShardedPubsubTopic, error) { +func ToShardedPubsubTopic(topic string) (NamespacedPubsubTopic, error) { if strings.HasPrefix(topic, StaticShardingPubsubTopicPrefix) { s := StaticShardingPubsubTopic{} err := s.Parse(topic) @@ -205,7 +204,6 @@ func ToShardedPubsubTopic(topic string) (ShardedPubsubTopic, error) { } return s, nil } else { - debug.PrintStack() s := NamedShardingPubsubTopic{} err := s.Parse(topic) if err != nil { @@ -215,6 +213,6 @@ func ToShardedPubsubTopic(topic string) (ShardedPubsubTopic, error) { } } -func DefaultPubsubTopic() ShardedPubsubTopic { +func DefaultPubsubTopic() NamespacedPubsubTopic { return NewNamedShardingPubsubTopic("default-waku/proto") }