diff --git a/waku/v2/protocol/enr/shards.go b/waku/v2/protocol/enr/shards.go index bec90b7a..d5f47b64 100644 --- a/waku/v2/protocol/enr/shards.go +++ b/waku/v2/protocol/enr/shards.go @@ -1,6 +1,8 @@ package enr import ( + "errors" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/waku-org/go-waku/waku/v2/protocol" @@ -24,7 +26,7 @@ func WithWakuRelayShardingBitVector(rs protocol.RelayShards) ENROption { } } -func WithtWakuRelaySharding(rs protocol.RelayShards) ENROption { +func WithWakuRelaySharding(rs protocol.RelayShards) ENROption { return func(localnode *enode.LocalNode) error { if len(rs.Indices) >= 64 { return WithWakuRelayShardingBitVector(rs)(localnode) @@ -34,6 +36,21 @@ func WithtWakuRelaySharding(rs protocol.RelayShards) ENROption { } } +func WithWakuRelayShardingTopics(topics ...string) ENROption { + return func(localnode *enode.LocalNode) error { + rs, err := protocol.TopicsToRelayShards(topics...) + if err != nil { + return err + } + + if len(rs) != 1 { + return errors.New("expected a single RelayShards") + } + + return WithWakuRelaySharding(rs[0])(localnode) + } +} + // ENR record accessors func RelayShardingIndicesList(localnode *enode.LocalNode) (*protocol.RelayShards, error) { diff --git a/waku/v2/protocol/shard.go b/waku/v2/protocol/shard.go index fec1ccb1..184817b8 100644 --- a/waku/v2/protocol/shard.go +++ b/waku/v2/protocol/shard.go @@ -72,6 +72,42 @@ func (rs RelayShards) ContainsNamespacedTopic(topic NamespacedPubsubTopic) bool return rs.Contains(shardedTopic.Cluster(), shardedTopic.Shard()) } +func TopicsToRelayShards(topic ...string) ([]RelayShards, error) { + result := make([]RelayShards, 0) + dict := make(map[uint16]map[uint16]struct{}) + for _, t := range topic { + var ps StaticShardingPubsubTopic + err := ps.Parse(t) + if err != nil { + return nil, err + } + + indices, ok := dict[ps.cluster] + if !ok { + indices = make(map[uint16]struct{}) + } + + indices[ps.shard] = struct{}{} + dict[ps.cluster] = indices + } + + for cluster, indices := range dict { + idx := make([]uint16, 0, len(indices)) + for index := range indices { + idx = append(idx, index) + } + + rs, err := NewRelayShards(cluster, idx...) + if err != nil { + return nil, err + } + + result = append(result, rs) + } + + return result, nil +} + func (rs RelayShards) ContainsTopic(topic string) bool { nsTopic, err := ToShardedPubsubTopic(topic) if err != nil {