refactor: fix nomenclature for shards (#849)

This commit is contained in:
richΛrd 2023-10-31 06:50:13 -04:00 committed by GitHub
parent 48acff4a5c
commit 36beb9de75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 94 additions and 87 deletions

View File

@ -400,13 +400,13 @@ func (d *DiscoveryV5) defaultPredicate() Predicate {
return false return false
} }
if nodeRS.Cluster != localRS.Cluster { if nodeRS.ClusterID != localRS.ClusterID {
return false return false
} }
// Contains any // Contains any
for _, idx := range localRS.Indices { for _, idx := range localRS.ShardIDs {
if nodeRS.Contains(localRS.Cluster, idx) { if nodeRS.Contains(localRS.ClusterID, idx) {
return true return true
} }
} }

View File

@ -319,7 +319,7 @@ func (w *WakuNode) watchTopicShards(ctx context.Context) error {
if len(rs) == 1 { if len(rs) == 1 {
w.log.Info("updating advertised relay shards in ENR") w.log.Info("updating advertised relay shards in ENR")
if len(rs[0].Indices) != len(topics) { if len(rs[0].ShardIDs) != len(topics) {
w.log.Warn("A mix of named and static shards found. ENR shard will contain only the following shards", zap.Any("shards", rs[0])) w.log.Warn("A mix of named and static shards found. ENR shard will contain only the following shards", zap.Any("shards", rs[0]))
} }

View File

@ -13,9 +13,9 @@ func deleteShardingENREntries(localnode *enode.LocalNode) {
localnode.Delete(enr.WithEntry(ShardingIndicesListEnrField, struct{}{})) localnode.Delete(enr.WithEntry(ShardingIndicesListEnrField, struct{}{}))
} }
func WithWakuRelayShardingIndicesList(rs protocol.RelayShards) ENROption { func WithWakuRelayShardList(rs protocol.RelayShards) ENROption {
return func(localnode *enode.LocalNode) error { return func(localnode *enode.LocalNode) error {
value, err := rs.IndicesList() value, err := rs.ShardList()
if err != nil { if err != nil {
return err return err
} }
@ -35,11 +35,11 @@ func WithWakuRelayShardingBitVector(rs protocol.RelayShards) ENROption {
func WithWakuRelaySharding(rs protocol.RelayShards) ENROption { func WithWakuRelaySharding(rs protocol.RelayShards) ENROption {
return func(localnode *enode.LocalNode) error { return func(localnode *enode.LocalNode) error {
if len(rs.Indices) >= 64 { if len(rs.ShardIDs) >= 64 {
return WithWakuRelayShardingBitVector(rs)(localnode) return WithWakuRelayShardingBitVector(rs)(localnode)
} }
return WithWakuRelayShardingIndicesList(rs)(localnode) return WithWakuRelayShardList(rs)(localnode)
} }
} }
@ -60,7 +60,7 @@ func WithWakuRelayShardingTopics(topics ...string) ENROption {
// ENR record accessors // ENR record accessors
func RelayShardingIndicesList(record *enr.Record) (*protocol.RelayShards, error) { func RelayShardList(record *enr.Record) (*protocol.RelayShards, error) {
var field []byte var field []byte
if err := record.Load(enr.WithEntry(ShardingIndicesListEnrField, &field)); err != nil { if err := record.Load(enr.WithEntry(ShardingIndicesListEnrField, &field)); err != nil {
if enr.IsNotFound(err) { if enr.IsNotFound(err) {
@ -69,7 +69,7 @@ func RelayShardingIndicesList(record *enr.Record) (*protocol.RelayShards, error)
return nil, err return nil, err
} }
res, err := protocol.FromIndicesList(field) res, err := protocol.FromShardList(field)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -95,7 +95,7 @@ func RelayShardingBitVector(record *enr.Record) (*protocol.RelayShards, error) {
} }
func RelaySharding(record *enr.Record) (*protocol.RelayShards, error) { func RelaySharding(record *enr.Record) (*protocol.RelayShards, error) {
res, err := RelayShardingIndicesList(record) res, err := RelayShardList(record)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -78,8 +78,8 @@ func (wakuM *WakuMetadata) getClusterAndShards() (*uint32, []uint32, error) {
} }
var shards []uint32 var shards []uint32
if shard != nil && shard.Cluster == uint16(wakuM.clusterID) { if shard != nil && shard.ClusterID == uint16(wakuM.clusterID) {
for _, idx := range shard.Indices { for _, idx := range shard.ShardIDs {
shards = append(shards, uint32(idx)) shards = append(shards, uint32(idx))
} }
} }
@ -139,9 +139,9 @@ func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protoc
} }
result := &protocol.RelayShards{} result := &protocol.RelayShards{}
result.Cluster = uint16(*response.ClusterId) result.ClusterID = uint16(*response.ClusterId)
for _, i := range response.Shards { for _, i := range response.Shards {
result.Indices = append(result.Indices, uint16(i)) result.ShardIDs = append(result.ShardIDs, uint16(i))
} }
return result, nil return result, nil
@ -226,7 +226,7 @@ func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) {
if err == nil { if err == nil {
if shard == nil { if shard == nil {
err = errors.New("no shard reported") err = errors.New("no shard reported")
} else if shard.Cluster != wakuM.clusterID { } else if shard.ClusterID != wakuM.clusterID {
err = errors.New("different clusterID reported") err = errors.New("different clusterID reported")
} }
} else { } else {

View File

@ -27,14 +27,14 @@ func createWakuMetadata(t *testing.T, rs *protocol.RelayShards) *WakuMetadata {
localNode, err := enr.NewLocalnode(key) localNode, err := enr.NewLocalnode(key)
require.NoError(t, err) require.NoError(t, err)
cluster := uint16(0) clusterID := uint16(0)
if rs != nil { if rs != nil {
err = enr.WithWakuRelaySharding(*rs)(localNode) err = enr.WithWakuRelaySharding(*rs)(localNode)
require.NoError(t, err) require.NoError(t, err)
cluster = rs.Cluster clusterID = rs.ClusterID
} }
m1 := NewWakuMetadata(cluster, localNode, utils.Logger()) m1 := NewWakuMetadata(clusterID, localNode, utils.Logger())
m1.SetHost(host) m1.SetHost(host)
err = m1.Start(context.TODO()) err = m1.Start(context.TODO())
require.NoError(t, err) require.NoError(t, err)
@ -65,25 +65,25 @@ func TestWakuMetadataRequest(t *testing.T) {
// Query a peer that is subscribed to a shard // Query a peer that is subscribed to a shard
result, err := m16_1.Request(context.Background(), m16_2.h.ID()) result, err := m16_1.Request(context.Background(), m16_2.h.ID())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testShard16, result.Cluster) require.Equal(t, testShard16, result.ClusterID)
require.Equal(t, rs16_2.Indices, result.Indices) require.Equal(t, rs16_2.ShardIDs, result.ShardIDs)
// Updating the peer shards // Updating the peer shards
rs16_2.Indices = append(rs16_2.Indices, 3, 4) rs16_2.ShardIDs = append(rs16_2.ShardIDs, 3, 4)
err = enr.WithWakuRelaySharding(rs16_2)(m16_2.localnode) err = enr.WithWakuRelaySharding(rs16_2)(m16_2.localnode)
require.NoError(t, err) require.NoError(t, err)
// Query same peer, after that peer subscribes to more shards // Query same peer, after that peer subscribes to more shards
result, err = m16_1.Request(context.Background(), m16_2.h.ID()) result, err = m16_1.Request(context.Background(), m16_2.h.ID())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testShard16, result.Cluster) require.Equal(t, testShard16, result.ClusterID)
require.ElementsMatch(t, rs16_2.Indices, result.Indices) require.ElementsMatch(t, rs16_2.ShardIDs, result.ShardIDs)
// Query a peer not subscribed to a shard // Query a peer not subscribed to a shard
result, err = m16_1.Request(context.Background(), m_noRS.h.ID()) result, err = m16_1.Request(context.Background(), m_noRS.h.ID())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, uint16(0), result.Cluster) require.Equal(t, uint16(0), result.ClusterID)
require.Len(t, result.Indices, 0) require.Len(t, result.ShardIDs, 0)
} }
func TestNoNetwork(t *testing.T) { func TestNoNetwork(t *testing.T) {

View File

@ -37,26 +37,26 @@ var ErrInvalidNumberFormat = errors.New("only 2^16 numbers are allowed")
// StaticShardingPubsubTopic describes a pubSub topic as per StaticSharding // StaticShardingPubsubTopic describes a pubSub topic as per StaticSharding
type StaticShardingPubsubTopic struct { type StaticShardingPubsubTopic struct {
cluster uint16 clusterID uint16
shard uint16 shardID uint16
} }
// NewStaticShardingPubsubTopic creates a new pubSub topic // NewStaticShardingPubsubTopic creates a new pubSub topic
func NewStaticShardingPubsubTopic(cluster uint16, shard uint16) StaticShardingPubsubTopic { func NewStaticShardingPubsubTopic(cluster uint16, shard uint16) StaticShardingPubsubTopic {
return StaticShardingPubsubTopic{ return StaticShardingPubsubTopic{
cluster: cluster, clusterID: cluster,
shard: shard, shardID: shard,
} }
} }
// Cluster returns the sharded cluster index // Cluster returns the sharded cluster index
func (s StaticShardingPubsubTopic) Cluster() uint16 { func (s StaticShardingPubsubTopic) Cluster() uint16 {
return s.cluster return s.clusterID
} }
// Shard returns the shard number // Shard returns the shard number
func (s StaticShardingPubsubTopic) Shard() uint16 { func (s StaticShardingPubsubTopic) Shard() uint16 {
return s.shard return s.shardID
} }
// Equal compares StaticShardingPubsubTopic // Equal compares StaticShardingPubsubTopic
@ -66,7 +66,7 @@ func (s StaticShardingPubsubTopic) Equal(t2 StaticShardingPubsubTopic) bool {
// String formats StaticShardingPubsubTopic to RFC 23 specific string format for pubsub topic. // String formats StaticShardingPubsubTopic to RFC 23 specific string format for pubsub topic.
func (s StaticShardingPubsubTopic) String() string { func (s StaticShardingPubsubTopic) String() string {
return fmt.Sprintf("%s/%d/%d", StaticShardingPubsubTopicPrefix, s.cluster, s.shard) return fmt.Sprintf("%s/%d/%d", StaticShardingPubsubTopicPrefix, s.clusterID, s.shardID)
} }
// Parse parses a topic string into a StaticShardingPubsubTopic // Parse parses a topic string into a StaticShardingPubsubTopic
@ -100,18 +100,18 @@ func (s *StaticShardingPubsubTopic) Parse(topic string) error {
return ErrInvalidNumberFormat return ErrInvalidNumberFormat
} }
s.shard = uint16(shardInt) s.shardID = uint16(shardInt)
s.cluster = uint16(clusterInt) s.clusterID = uint16(clusterInt)
return nil return nil
} }
func ToShardPubsubTopic(topic WakuPubSubTopic) (StaticShardingPubsubTopic, error) { func ToShardPubsubTopic(topic WakuPubSubTopic) (StaticShardingPubsubTopic, error) {
result, ok := topic.(StaticShardingPubsubTopic) result, ok := topic.(StaticShardingPubsubTopic)
if !ok { if !ok {
return StaticShardingPubsubTopic{}, ErrNotShardPubsubTopic return StaticShardingPubsubTopic{}, ErrNotShardPubsubTopic
} }
return result, nil return result, nil
} }
// ToWakuPubsubTopic takes a pubSub topic string and creates a WakuPubsubTopic object. // ToWakuPubsubTopic takes a pubSub topic string and creates a WakuPubsubTopic object.

View File

@ -13,57 +13,64 @@ import (
const MaxShardIndex = uint16(1023) const MaxShardIndex = uint16(1023)
// ClusterIndex is the clusterID used in sharding space. // ClusterIndex is the clusterID used in sharding space.
// For indices allocation and other magic numbers refer to RFC 51 // For shardIDs allocation and other magic numbers refer to RFC 51
const ClusterIndex = 1 const ClusterIndex = 1
// GenerationZeroShardsCount is number of shards supported in generation-0 // GenerationZeroShardsCount is number of shards supported in generation-0
const GenerationZeroShardsCount = 8 const GenerationZeroShardsCount = 8
var (
ErrTooManyShards = errors.New("too many shards")
ErrInvalidShard = errors.New("invalid shard")
ErrInvalidShardCount = errors.New("invalid shard count")
ErrExpected130Bytes = errors.New("invalid data: expected 130 bytes")
)
type RelayShards struct { type RelayShards struct {
Cluster uint16 `json:"cluster"` ClusterID uint16 `json:"clusterID"`
Indices []uint16 `json:"indices"` ShardIDs []uint16 `json:"shardIDs"`
} }
func NewRelayShards(cluster uint16, indices ...uint16) (RelayShards, error) { func NewRelayShards(clusterID uint16, shardIDs ...uint16) (RelayShards, error) {
if len(indices) > math.MaxUint8 { if len(shardIDs) > math.MaxUint8 {
return RelayShards{}, errors.New("too many indices") return RelayShards{}, ErrTooManyShards
} }
indiceSet := make(map[uint16]struct{}) shardIDSet := make(map[uint16]struct{})
for _, index := range indices { for _, index := range shardIDs {
if index > MaxShardIndex { if index > MaxShardIndex {
return RelayShards{}, errors.New("invalid index") return RelayShards{}, ErrInvalidShard
} }
indiceSet[index] = struct{}{} // dedup shardIDSet[index] = struct{}{} // dedup
} }
if len(indiceSet) == 0 { if len(shardIDSet) == 0 {
return RelayShards{}, errors.New("invalid index count") return RelayShards{}, ErrInvalidShardCount
} }
indices = []uint16{} shardIDs = []uint16{}
for index := range indiceSet { for index := range shardIDSet {
indices = append(indices, index) shardIDs = append(shardIDs, index)
} }
return RelayShards{Cluster: cluster, Indices: indices}, nil return RelayShards{ClusterID: clusterID, ShardIDs: shardIDs}, nil
} }
func (rs RelayShards) Topics() []WakuPubSubTopic { func (rs RelayShards) Topics() []WakuPubSubTopic {
var result []WakuPubSubTopic var result []WakuPubSubTopic
for _, i := range rs.Indices { for _, i := range rs.ShardIDs {
result = append(result, NewStaticShardingPubsubTopic(rs.Cluster, i)) result = append(result, NewStaticShardingPubsubTopic(rs.ClusterID, i))
} }
return result return result
} }
func (rs RelayShards) Contains(cluster uint16, index uint16) bool { func (rs RelayShards) Contains(cluster uint16, index uint16) bool {
if rs.Cluster != cluster { if rs.ClusterID != cluster {
return false return false
} }
found := false found := false
for _, idx := range rs.Indices { for _, idx := range rs.ShardIDs {
if idx == index { if idx == index {
found = true found = true
} }
@ -94,22 +101,22 @@ func TopicsToRelayShards(topic ...string) ([]RelayShards, error) {
return nil, err return nil, err
} }
indices, ok := dict[ps.cluster] shardIDs, ok := dict[ps.clusterID]
if !ok { if !ok {
indices = make(map[uint16]struct{}) shardIDs = make(map[uint16]struct{})
} }
indices[ps.shard] = struct{}{} shardIDs[ps.shardID] = struct{}{}
dict[ps.cluster] = indices dict[ps.clusterID] = shardIDs
} }
for cluster, indices := range dict { for clusterID, shardIDs := range dict {
idx := make([]uint16, 0, len(indices)) idx := make([]uint16, 0, len(shardIDs))
for index := range indices { for shardID := range shardIDs {
idx = append(idx, index) idx = append(idx, shardID)
} }
rs, err := NewRelayShards(cluster, idx...) rs, err := NewRelayShards(clusterID, idx...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -128,23 +135,23 @@ func (rs RelayShards) ContainsTopic(topic string) bool {
return rs.ContainsShardPubsubTopic(wTopic) return rs.ContainsShardPubsubTopic(wTopic)
} }
func (rs RelayShards) IndicesList() ([]byte, error) { func (rs RelayShards) ShardList() ([]byte, error) {
if len(rs.Indices) > math.MaxUint8 { if len(rs.ShardIDs) > math.MaxUint8 {
return nil, errors.New("indices list too long") return nil, ErrTooManyShards
} }
var result []byte var result []byte
result = binary.BigEndian.AppendUint16(result, rs.Cluster) result = binary.BigEndian.AppendUint16(result, rs.ClusterID)
result = append(result, uint8(len(rs.Indices))) result = append(result, uint8(len(rs.ShardIDs)))
for _, index := range rs.Indices { for _, index := range rs.ShardIDs {
result = binary.BigEndian.AppendUint16(result, index) result = binary.BigEndian.AppendUint16(result, index)
} }
return result, nil return result, nil
} }
func FromIndicesList(buf []byte) (RelayShards, error) { func FromShardList(buf []byte) (RelayShards, error) {
if len(buf) < 3 { if len(buf) < 3 {
return RelayShards{}, fmt.Errorf("insufficient data: expected at least 3 bytes, got %d bytes", len(buf)) return RelayShards{}, fmt.Errorf("insufficient data: expected at least 3 bytes, got %d bytes", len(buf))
} }
@ -156,12 +163,12 @@ func FromIndicesList(buf []byte) (RelayShards, error) {
return RelayShards{}, fmt.Errorf("invalid data: `length` field is %d but %d bytes were provided", length, len(buf)) return RelayShards{}, fmt.Errorf("invalid data: `length` field is %d but %d bytes were provided", length, len(buf))
} }
var indices []uint16 shardIDs := make([]uint16, length)
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
indices = append(indices, binary.BigEndian.Uint16(buf[3+2*i:5+2*i])) shardIDs[i] = binary.BigEndian.Uint16(buf[3+2*i : 5+2*i])
} }
return NewRelayShards(cluster, indices...) return NewRelayShards(cluster, shardIDs...)
} }
func setBit(n byte, pos uint) byte { func setBit(n byte, pos uint) byte {
@ -181,10 +188,10 @@ func (rs RelayShards) BitVector() []byte {
// of. The right-most bit in the bit vector represents shard 0, the left-most // of. The right-most bit in the bit vector represents shard 0, the left-most
// bit represents shard 1023. // bit represents shard 1023.
var result []byte var result []byte
result = binary.BigEndian.AppendUint16(result, rs.Cluster) result = binary.BigEndian.AppendUint16(result, rs.ClusterID)
vec := make([]byte, 128) vec := make([]byte, 128)
for _, index := range rs.Indices { for _, index := range rs.ShardIDs {
n := vec[index/8] n := vec[index/8]
vec[index/8] = byte(setBit(n, uint(index%8))) vec[index/8] = byte(setBit(n, uint(index%8)))
} }
@ -195,11 +202,11 @@ func (rs RelayShards) BitVector() []byte {
// Generate a RelayShards from a byte slice // Generate a RelayShards from a byte slice
func FromBitVector(buf []byte) (RelayShards, error) { func FromBitVector(buf []byte) (RelayShards, error) {
if len(buf) != 130 { if len(buf) != 130 {
return RelayShards{}, errors.New("invalid data: expected 130 bytes") return RelayShards{}, ErrExpected130Bytes
} }
cluster := binary.BigEndian.Uint16(buf[0:2]) cluster := binary.BigEndian.Uint16(buf[0:2])
var indices []uint16 var shardIDs []uint16
for i := uint16(0); i < 128; i++ { for i := uint16(0); i < 128; i++ {
for j := uint(0); j < 8; j++ { for j := uint(0); j < 8; j++ {
@ -207,11 +214,11 @@ func FromBitVector(buf []byte) (RelayShards, error) {
continue continue
} }
indices = append(indices, uint16(j)+8*i) shardIDs = append(shardIDs, uint16(j)+8*i)
} }
} }
return RelayShards{Cluster: cluster, Indices: indices}, nil return RelayShards{ClusterID: cluster, ShardIDs: shardIDs}, nil
} }
// GetShardFromContentTopic runs Autosharding logic and returns a pubSubTopic // GetShardFromContentTopic runs Autosharding logic and returns a pubSubTopic

View File

@ -148,8 +148,8 @@ func (r *Rendezvous) RegisterShard(ctx context.Context, cluster uint16, shard ui
// RegisterRelayShards registers the node in the rendezvous point by specifying a RelayShards struct (more than one shard index can be registered) // RegisterRelayShards registers the node in the rendezvous point by specifying a RelayShards struct (more than one shard index can be registered)
func (r *Rendezvous) RegisterRelayShards(ctx context.Context, rs protocol.RelayShards, rendezvousPoints []*RendezvousPoint) { func (r *Rendezvous) RegisterRelayShards(ctx context.Context, rs protocol.RelayShards, rendezvousPoints []*RendezvousPoint) {
for _, idx := range rs.Indices { for _, idx := range rs.ShardIDs {
go r.RegisterShard(ctx, rs.Cluster, idx, rendezvousPoints) go r.RegisterShard(ctx, rs.ClusterID, idx, rendezvousPoints)
} }
} }