From 36beb9de757a29c20308b5e6e11ee571e93fe666 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?rich=CE=9Brd?= Date: Tue, 31 Oct 2023 06:50:13 -0400 Subject: [PATCH] refactor: fix nomenclature for shards (#849) --- waku/v2/discv5/discover.go | 6 +- waku/v2/node/localnode.go | 2 +- waku/v2/protocol/enr/shards.go | 14 +-- waku/v2/protocol/metadata/waku_metadata.go | 10 +- .../protocol/metadata/waku_metadata_test.go | 20 ++-- waku/v2/protocol/pubsub_topic.go | 28 +++--- waku/v2/protocol/shard.go | 97 ++++++++++--------- waku/v2/rendezvous/rendezvous.go | 4 +- 8 files changed, 94 insertions(+), 87 deletions(-) diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index 95bb8c78..1f2e9677 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -400,13 +400,13 @@ func (d *DiscoveryV5) defaultPredicate() Predicate { return false } - if nodeRS.Cluster != localRS.Cluster { + if nodeRS.ClusterID != localRS.ClusterID { return false } // Contains any - for _, idx := range localRS.Indices { - if nodeRS.Contains(localRS.Cluster, idx) { + for _, idx := range localRS.ShardIDs { + if nodeRS.Contains(localRS.ClusterID, idx) { return true } } diff --git a/waku/v2/node/localnode.go b/waku/v2/node/localnode.go index fe64e558..1d12d535 100644 --- a/waku/v2/node/localnode.go +++ b/waku/v2/node/localnode.go @@ -319,7 +319,7 @@ func (w *WakuNode) watchTopicShards(ctx context.Context) error { if len(rs) == 1 { w.log.Info("updating advertised relay shards in ENR") - if len(rs[0].Indices) != len(topics) { + if len(rs[0].ShardIDs) != len(topics) { w.log.Warn("A mix of named and static shards found. ENR shard will contain only the following shards", zap.Any("shards", rs[0])) } diff --git a/waku/v2/protocol/enr/shards.go b/waku/v2/protocol/enr/shards.go index 0e57eda2..29adc2d5 100644 --- a/waku/v2/protocol/enr/shards.go +++ b/waku/v2/protocol/enr/shards.go @@ -13,9 +13,9 @@ func deleteShardingENREntries(localnode *enode.LocalNode) { localnode.Delete(enr.WithEntry(ShardingIndicesListEnrField, struct{}{})) } -func WithWakuRelayShardingIndicesList(rs protocol.RelayShards) ENROption { +func WithWakuRelayShardList(rs protocol.RelayShards) ENROption { return func(localnode *enode.LocalNode) error { - value, err := rs.IndicesList() + value, err := rs.ShardList() if err != nil { return err } @@ -35,11 +35,11 @@ func WithWakuRelayShardingBitVector(rs protocol.RelayShards) ENROption { func WithWakuRelaySharding(rs protocol.RelayShards) ENROption { return func(localnode *enode.LocalNode) error { - if len(rs.Indices) >= 64 { + if len(rs.ShardIDs) >= 64 { return WithWakuRelayShardingBitVector(rs)(localnode) } - return WithWakuRelayShardingIndicesList(rs)(localnode) + return WithWakuRelayShardList(rs)(localnode) } } @@ -60,7 +60,7 @@ func WithWakuRelayShardingTopics(topics ...string) ENROption { // ENR record accessors -func RelayShardingIndicesList(record *enr.Record) (*protocol.RelayShards, error) { +func RelayShardList(record *enr.Record) (*protocol.RelayShards, error) { var field []byte if err := record.Load(enr.WithEntry(ShardingIndicesListEnrField, &field)); err != nil { if enr.IsNotFound(err) { @@ -69,7 +69,7 @@ func RelayShardingIndicesList(record *enr.Record) (*protocol.RelayShards, error) return nil, err } - res, err := protocol.FromIndicesList(field) + res, err := protocol.FromShardList(field) if err != nil { return nil, err } @@ -95,7 +95,7 @@ func RelayShardingBitVector(record *enr.Record) (*protocol.RelayShards, error) { } func RelaySharding(record *enr.Record) (*protocol.RelayShards, error) { - res, err := RelayShardingIndicesList(record) + res, err := RelayShardList(record) if err != nil { return nil, err } diff --git a/waku/v2/protocol/metadata/waku_metadata.go b/waku/v2/protocol/metadata/waku_metadata.go index 3f0295ea..aae9db62 100644 --- a/waku/v2/protocol/metadata/waku_metadata.go +++ b/waku/v2/protocol/metadata/waku_metadata.go @@ -78,8 +78,8 @@ func (wakuM *WakuMetadata) getClusterAndShards() (*uint32, []uint32, error) { } var shards []uint32 - if shard != nil && shard.Cluster == uint16(wakuM.clusterID) { - for _, idx := range shard.Indices { + if shard != nil && shard.ClusterID == uint16(wakuM.clusterID) { + for _, idx := range shard.ShardIDs { shards = append(shards, uint32(idx)) } } @@ -139,9 +139,9 @@ func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protoc } result := &protocol.RelayShards{} - result.Cluster = uint16(*response.ClusterId) + result.ClusterID = uint16(*response.ClusterId) for _, i := range response.Shards { - result.Indices = append(result.Indices, uint16(i)) + result.ShardIDs = append(result.ShardIDs, uint16(i)) } return result, nil @@ -226,7 +226,7 @@ func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) { if err == nil { if shard == nil { err = errors.New("no shard reported") - } else if shard.Cluster != wakuM.clusterID { + } else if shard.ClusterID != wakuM.clusterID { err = errors.New("different clusterID reported") } } else { diff --git a/waku/v2/protocol/metadata/waku_metadata_test.go b/waku/v2/protocol/metadata/waku_metadata_test.go index 18eaf990..58b462b4 100644 --- a/waku/v2/protocol/metadata/waku_metadata_test.go +++ b/waku/v2/protocol/metadata/waku_metadata_test.go @@ -27,14 +27,14 @@ func createWakuMetadata(t *testing.T, rs *protocol.RelayShards) *WakuMetadata { localNode, err := enr.NewLocalnode(key) require.NoError(t, err) - cluster := uint16(0) + clusterID := uint16(0) if rs != nil { err = enr.WithWakuRelaySharding(*rs)(localNode) require.NoError(t, err) - cluster = rs.Cluster + clusterID = rs.ClusterID } - m1 := NewWakuMetadata(cluster, localNode, utils.Logger()) + m1 := NewWakuMetadata(clusterID, localNode, utils.Logger()) m1.SetHost(host) err = m1.Start(context.TODO()) require.NoError(t, err) @@ -65,25 +65,25 @@ func TestWakuMetadataRequest(t *testing.T) { // Query a peer that is subscribed to a shard result, err := m16_1.Request(context.Background(), m16_2.h.ID()) require.NoError(t, err) - require.Equal(t, testShard16, result.Cluster) - require.Equal(t, rs16_2.Indices, result.Indices) + require.Equal(t, testShard16, result.ClusterID) + require.Equal(t, rs16_2.ShardIDs, result.ShardIDs) // Updating the peer shards - rs16_2.Indices = append(rs16_2.Indices, 3, 4) + rs16_2.ShardIDs = append(rs16_2.ShardIDs, 3, 4) err = enr.WithWakuRelaySharding(rs16_2)(m16_2.localnode) require.NoError(t, err) // Query same peer, after that peer subscribes to more shards result, err = m16_1.Request(context.Background(), m16_2.h.ID()) require.NoError(t, err) - require.Equal(t, testShard16, result.Cluster) - require.ElementsMatch(t, rs16_2.Indices, result.Indices) + require.Equal(t, testShard16, result.ClusterID) + require.ElementsMatch(t, rs16_2.ShardIDs, result.ShardIDs) // Query a peer not subscribed to a shard result, err = m16_1.Request(context.Background(), m_noRS.h.ID()) require.NoError(t, err) - require.Equal(t, uint16(0), result.Cluster) - require.Len(t, result.Indices, 0) + require.Equal(t, uint16(0), result.ClusterID) + require.Len(t, result.ShardIDs, 0) } func TestNoNetwork(t *testing.T) { diff --git a/waku/v2/protocol/pubsub_topic.go b/waku/v2/protocol/pubsub_topic.go index b3fcd395..6b2fd94e 100644 --- a/waku/v2/protocol/pubsub_topic.go +++ b/waku/v2/protocol/pubsub_topic.go @@ -37,26 +37,26 @@ var ErrInvalidNumberFormat = errors.New("only 2^16 numbers are allowed") // StaticShardingPubsubTopic describes a pubSub topic as per StaticSharding type StaticShardingPubsubTopic struct { - cluster uint16 - shard uint16 + clusterID uint16 + shardID uint16 } // NewStaticShardingPubsubTopic creates a new pubSub topic func NewStaticShardingPubsubTopic(cluster uint16, shard uint16) StaticShardingPubsubTopic { return StaticShardingPubsubTopic{ - cluster: cluster, - shard: shard, + clusterID: cluster, + shardID: shard, } } // Cluster returns the sharded cluster index func (s StaticShardingPubsubTopic) Cluster() uint16 { - return s.cluster + return s.clusterID } // Shard returns the shard number func (s StaticShardingPubsubTopic) Shard() uint16 { - return s.shard + return s.shardID } // Equal compares StaticShardingPubsubTopic @@ -66,7 +66,7 @@ func (s StaticShardingPubsubTopic) Equal(t2 StaticShardingPubsubTopic) bool { // String formats StaticShardingPubsubTopic to RFC 23 specific string format for pubsub topic. func (s StaticShardingPubsubTopic) String() string { - return fmt.Sprintf("%s/%d/%d", StaticShardingPubsubTopicPrefix, s.cluster, s.shard) + return fmt.Sprintf("%s/%d/%d", StaticShardingPubsubTopicPrefix, s.clusterID, s.shardID) } // Parse parses a topic string into a StaticShardingPubsubTopic @@ -100,18 +100,18 @@ func (s *StaticShardingPubsubTopic) Parse(topic string) error { return ErrInvalidNumberFormat } - s.shard = uint16(shardInt) - s.cluster = uint16(clusterInt) + s.shardID = uint16(shardInt) + s.clusterID = uint16(clusterInt) return nil } func ToShardPubsubTopic(topic WakuPubSubTopic) (StaticShardingPubsubTopic, error) { - result, ok := topic.(StaticShardingPubsubTopic) - if !ok { - return StaticShardingPubsubTopic{}, ErrNotShardPubsubTopic - } - return result, nil + result, ok := topic.(StaticShardingPubsubTopic) + if !ok { + return StaticShardingPubsubTopic{}, ErrNotShardPubsubTopic + } + return result, nil } // ToWakuPubsubTopic takes a pubSub topic string and creates a WakuPubsubTopic object. diff --git a/waku/v2/protocol/shard.go b/waku/v2/protocol/shard.go index a49c6e30..5a426182 100644 --- a/waku/v2/protocol/shard.go +++ b/waku/v2/protocol/shard.go @@ -13,57 +13,64 @@ import ( const MaxShardIndex = uint16(1023) // ClusterIndex is the clusterID used in sharding space. -// For indices allocation and other magic numbers refer to RFC 51 +// For shardIDs allocation and other magic numbers refer to RFC 51 const ClusterIndex = 1 // GenerationZeroShardsCount is number of shards supported in generation-0 const GenerationZeroShardsCount = 8 +var ( + ErrTooManyShards = errors.New("too many shards") + ErrInvalidShard = errors.New("invalid shard") + ErrInvalidShardCount = errors.New("invalid shard count") + ErrExpected130Bytes = errors.New("invalid data: expected 130 bytes") +) + type RelayShards struct { - Cluster uint16 `json:"cluster"` - Indices []uint16 `json:"indices"` + ClusterID uint16 `json:"clusterID"` + ShardIDs []uint16 `json:"shardIDs"` } -func NewRelayShards(cluster uint16, indices ...uint16) (RelayShards, error) { - if len(indices) > math.MaxUint8 { - return RelayShards{}, errors.New("too many indices") +func NewRelayShards(clusterID uint16, shardIDs ...uint16) (RelayShards, error) { + if len(shardIDs) > math.MaxUint8 { + return RelayShards{}, ErrTooManyShards } - indiceSet := make(map[uint16]struct{}) - for _, index := range indices { + shardIDSet := make(map[uint16]struct{}) + for _, index := range shardIDs { if index > MaxShardIndex { - return RelayShards{}, errors.New("invalid index") + return RelayShards{}, ErrInvalidShard } - indiceSet[index] = struct{}{} // dedup + shardIDSet[index] = struct{}{} // dedup } - if len(indiceSet) == 0 { - return RelayShards{}, errors.New("invalid index count") + if len(shardIDSet) == 0 { + return RelayShards{}, ErrInvalidShardCount } - indices = []uint16{} - for index := range indiceSet { - indices = append(indices, index) + shardIDs = []uint16{} + for index := range shardIDSet { + shardIDs = append(shardIDs, index) } - return RelayShards{Cluster: cluster, Indices: indices}, nil + return RelayShards{ClusterID: clusterID, ShardIDs: shardIDs}, nil } func (rs RelayShards) Topics() []WakuPubSubTopic { var result []WakuPubSubTopic - for _, i := range rs.Indices { - result = append(result, NewStaticShardingPubsubTopic(rs.Cluster, i)) + for _, i := range rs.ShardIDs { + result = append(result, NewStaticShardingPubsubTopic(rs.ClusterID, i)) } return result } func (rs RelayShards) Contains(cluster uint16, index uint16) bool { - if rs.Cluster != cluster { + if rs.ClusterID != cluster { return false } found := false - for _, idx := range rs.Indices { + for _, idx := range rs.ShardIDs { if idx == index { found = true } @@ -94,22 +101,22 @@ func TopicsToRelayShards(topic ...string) ([]RelayShards, error) { return nil, err } - indices, ok := dict[ps.cluster] + shardIDs, ok := dict[ps.clusterID] if !ok { - indices = make(map[uint16]struct{}) + shardIDs = make(map[uint16]struct{}) } - indices[ps.shard] = struct{}{} - dict[ps.cluster] = indices + shardIDs[ps.shardID] = struct{}{} + dict[ps.clusterID] = shardIDs } - for cluster, indices := range dict { - idx := make([]uint16, 0, len(indices)) - for index := range indices { - idx = append(idx, index) + for clusterID, shardIDs := range dict { + idx := make([]uint16, 0, len(shardIDs)) + for shardID := range shardIDs { + idx = append(idx, shardID) } - rs, err := NewRelayShards(cluster, idx...) + rs, err := NewRelayShards(clusterID, idx...) if err != nil { return nil, err } @@ -128,23 +135,23 @@ func (rs RelayShards) ContainsTopic(topic string) bool { return rs.ContainsShardPubsubTopic(wTopic) } -func (rs RelayShards) IndicesList() ([]byte, error) { - if len(rs.Indices) > math.MaxUint8 { - return nil, errors.New("indices list too long") +func (rs RelayShards) ShardList() ([]byte, error) { + if len(rs.ShardIDs) > math.MaxUint8 { + return nil, ErrTooManyShards } var result []byte - result = binary.BigEndian.AppendUint16(result, rs.Cluster) - result = append(result, uint8(len(rs.Indices))) - for _, index := range rs.Indices { + result = binary.BigEndian.AppendUint16(result, rs.ClusterID) + result = append(result, uint8(len(rs.ShardIDs))) + for _, index := range rs.ShardIDs { result = binary.BigEndian.AppendUint16(result, index) } return result, nil } -func FromIndicesList(buf []byte) (RelayShards, error) { +func FromShardList(buf []byte) (RelayShards, error) { if len(buf) < 3 { return RelayShards{}, fmt.Errorf("insufficient data: expected at least 3 bytes, got %d bytes", len(buf)) } @@ -156,12 +163,12 @@ func FromIndicesList(buf []byte) (RelayShards, error) { return RelayShards{}, fmt.Errorf("invalid data: `length` field is %d but %d bytes were provided", length, len(buf)) } - var indices []uint16 + shardIDs := make([]uint16, length) for i := 0; i < length; i++ { - indices = append(indices, binary.BigEndian.Uint16(buf[3+2*i:5+2*i])) + shardIDs[i] = binary.BigEndian.Uint16(buf[3+2*i : 5+2*i]) } - return NewRelayShards(cluster, indices...) + return NewRelayShards(cluster, shardIDs...) } func setBit(n byte, pos uint) byte { @@ -181,10 +188,10 @@ func (rs RelayShards) BitVector() []byte { // of. The right-most bit in the bit vector represents shard 0, the left-most // bit represents shard 1023. var result []byte - result = binary.BigEndian.AppendUint16(result, rs.Cluster) + result = binary.BigEndian.AppendUint16(result, rs.ClusterID) vec := make([]byte, 128) - for _, index := range rs.Indices { + for _, index := range rs.ShardIDs { n := vec[index/8] vec[index/8] = byte(setBit(n, uint(index%8))) } @@ -195,11 +202,11 @@ func (rs RelayShards) BitVector() []byte { // Generate a RelayShards from a byte slice func FromBitVector(buf []byte) (RelayShards, error) { if len(buf) != 130 { - return RelayShards{}, errors.New("invalid data: expected 130 bytes") + return RelayShards{}, ErrExpected130Bytes } cluster := binary.BigEndian.Uint16(buf[0:2]) - var indices []uint16 + var shardIDs []uint16 for i := uint16(0); i < 128; i++ { for j := uint(0); j < 8; j++ { @@ -207,11 +214,11 @@ func FromBitVector(buf []byte) (RelayShards, error) { continue } - indices = append(indices, uint16(j)+8*i) + shardIDs = append(shardIDs, uint16(j)+8*i) } } - return RelayShards{Cluster: cluster, Indices: indices}, nil + return RelayShards{ClusterID: cluster, ShardIDs: shardIDs}, nil } // GetShardFromContentTopic runs Autosharding logic and returns a pubSubTopic diff --git a/waku/v2/rendezvous/rendezvous.go b/waku/v2/rendezvous/rendezvous.go index 28efa899..db48faec 100644 --- a/waku/v2/rendezvous/rendezvous.go +++ b/waku/v2/rendezvous/rendezvous.go @@ -148,8 +148,8 @@ func (r *Rendezvous) RegisterShard(ctx context.Context, cluster uint16, shard ui // RegisterRelayShards registers the node in the rendezvous point by specifying a RelayShards struct (more than one shard index can be registered) func (r *Rendezvous) RegisterRelayShards(ctx context.Context, rs protocol.RelayShards, rendezvousPoints []*RendezvousPoint) { - for _, idx := range rs.Indices { - go r.RegisterShard(ctx, rs.Cluster, idx, rendezvousPoints) + for _, idx := range rs.ShardIDs { + go r.RegisterShard(ctx, rs.ClusterID, idx, rendezvousPoints) } }