refactor: fix nomenclature for shards (#849)

Author: richΛrd, 2023-10-31 06:50:13 -04:00 (committed by GitHub)
parent 48acff4a5c
commit 36beb9de75
8 changed files with 94 additions and 87 deletions


@@ -400,13 +400,13 @@ func (d *DiscoveryV5) defaultPredicate() Predicate {
 return false
 }
-if nodeRS.Cluster != localRS.Cluster {
+if nodeRS.ClusterID != localRS.ClusterID {
 return false
 }
 // Contains any
-for _, idx := range localRS.Indices {
-if nodeRS.Contains(localRS.Cluster, idx) {
+for _, idx := range localRS.ShardIDs {
+if nodeRS.Contains(localRS.ClusterID, idx) {
 return true
 }
 }
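For context, the renamed predicate keeps only discovered peers that are in the same cluster as the local node and advertise at least one of the local shard IDs. A minimal sketch of the same overlap check written against the public RelayShards API shown later in this diff; the import path and the sharesShard helper are assumptions, not part of this commit:

```go
package main

import (
	"fmt"

	"github.com/waku-org/go-waku/waku/v2/protocol"
)

// sharesShard mirrors the predicate above: keep a peer only when it is in the
// same cluster and covers at least one of the local shard IDs.
func sharesShard(nodeRS, localRS protocol.RelayShards) bool {
	if nodeRS.ClusterID != localRS.ClusterID {
		return false
	}
	for _, idx := range localRS.ShardIDs {
		if nodeRS.Contains(localRS.ClusterID, idx) {
			return true
		}
	}
	return false
}

func main() {
	local, _ := protocol.NewRelayShards(1, 0, 1)
	remote, _ := protocol.NewRelayShards(1, 1, 5)
	fmt.Println(sharesShard(remote, local)) // true: shard 1 is common to both
}
```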


@@ -319,7 +319,7 @@ func (w *WakuNode) watchTopicShards(ctx context.Context) error {
 if len(rs) == 1 {
 w.log.Info("updating advertised relay shards in ENR")
-if len(rs[0].Indices) != len(topics) {
+if len(rs[0].ShardIDs) != len(topics) {
 w.log.Warn("A mix of named and static shards found. ENR shard will contain only the following shards", zap.Any("shards", rs[0]))
 }


@@ -13,9 +13,9 @@ func deleteShardingENREntries(localnode *enode.LocalNode) {
 localnode.Delete(enr.WithEntry(ShardingIndicesListEnrField, struct{}{}))
 }
-func WithWakuRelayShardingIndicesList(rs protocol.RelayShards) ENROption {
+func WithWakuRelayShardList(rs protocol.RelayShards) ENROption {
 return func(localnode *enode.LocalNode) error {
-value, err := rs.IndicesList()
+value, err := rs.ShardList()
 if err != nil {
 return err
 }
@@ -35,11 +35,11 @@ func WithWakuRelayShardingBitVector(rs protocol.RelayShards) ENROption {
 func WithWakuRelaySharding(rs protocol.RelayShards) ENROption {
 return func(localnode *enode.LocalNode) error {
-if len(rs.Indices) >= 64 {
+if len(rs.ShardIDs) >= 64 {
 return WithWakuRelayShardingBitVector(rs)(localnode)
 }
-return WithWakuRelayShardingIndicesList(rs)(localnode)
+return WithWakuRelayShardList(rs)(localnode)
 }
 }
@@ -60,7 +60,7 @@ func WithWakuRelayShardingTopics(topics ...string) ENROption {
 // ENR record accessors
-func RelayShardingIndicesList(record *enr.Record) (*protocol.RelayShards, error) {
+func RelayShardList(record *enr.Record) (*protocol.RelayShards, error) {
 var field []byte
 if err := record.Load(enr.WithEntry(ShardingIndicesListEnrField, &field)); err != nil {
 if enr.IsNotFound(err) {
@@ -69,7 +69,7 @@ func RelayShardingIndicesList(record *enr.Record) (*protocol.RelayShards, error)
 return nil, err
 }
-res, err := protocol.FromIndicesList(field)
+res, err := protocol.FromShardList(field)
 if err != nil {
 return nil, err
 }
@@ -95,7 +95,7 @@ func RelayShardingBitVector(record *enr.Record) (*protocol.RelayShards, error) {
 }
 func RelaySharding(record *enr.Record) (*protocol.RelayShards, error) {
-res, err := RelayShardingIndicesList(record)
+res, err := RelayShardList(record)
 if err != nil {
 return nil, err
 }
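As a usage sketch, the renamed ENR helpers compose roughly as follows: write the local RelayShards into the node record with WithWakuRelaySharding (which uses the shard-list entry below 64 shards and falls back to the bit vector otherwise), then read them back with RelaySharding. The import paths, the key generation via go-ethereum's crypto package, and the panic-style error handling are assumptions made only to keep the example self-contained:

```go
package main

import (
	"fmt"

	gcrypto "github.com/ethereum/go-ethereum/crypto"
	"github.com/waku-org/go-waku/waku/v2/protocol"
	wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
)

func main() {
	// A throwaway secp256k1 key for the local ENR.
	key, err := gcrypto.GenerateKey()
	if err != nil {
		panic(err)
	}
	localNode, err := wenr.NewLocalnode(key)
	if err != nil {
		panic(err)
	}

	// clusterID 16 with shardIDs 1 and 2; fewer than 64 shards, so the
	// shard-list ENR entry is written rather than the bit vector.
	rs, err := protocol.NewRelayShards(16, 1, 2)
	if err != nil {
		panic(err)
	}
	if err := wenr.WithWakuRelaySharding(rs)(localNode); err != nil {
		panic(err)
	}

	// Read the shards back from the signed record.
	decoded, err := wenr.RelaySharding(localNode.Node().Record())
	if err != nil {
		panic(err)
	}
	fmt.Println(decoded.ClusterID, decoded.ShardIDs)
}
```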


@@ -78,8 +78,8 @@ func (wakuM *WakuMetadata) getClusterAndShards() (*uint32, []uint32, error) {
 }
 var shards []uint32
-if shard != nil && shard.Cluster == uint16(wakuM.clusterID) {
-for _, idx := range shard.Indices {
+if shard != nil && shard.ClusterID == uint16(wakuM.clusterID) {
+for _, idx := range shard.ShardIDs {
 shards = append(shards, uint32(idx))
 }
 }
@@ -139,9 +139,9 @@ func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protoc
 }
 result := &protocol.RelayShards{}
-result.Cluster = uint16(*response.ClusterId)
+result.ClusterID = uint16(*response.ClusterId)
 for _, i := range response.Shards {
-result.Indices = append(result.Indices, uint16(i))
+result.ShardIDs = append(result.ShardIDs, uint16(i))
 }
 return result, nil
@@ -226,7 +226,7 @@ func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) {
 if err == nil {
 if shard == nil {
 err = errors.New("no shard reported")
-} else if shard.Cluster != wakuM.clusterID {
+} else if shard.ClusterID != wakuM.clusterID {
 err = errors.New("different clusterID reported")
 }
 } else {


@@ -27,14 +27,14 @@ func createWakuMetadata(t *testing.T, rs *protocol.RelayShards) *WakuMetadata {
 localNode, err := enr.NewLocalnode(key)
 require.NoError(t, err)
-cluster := uint16(0)
+clusterID := uint16(0)
 if rs != nil {
 err = enr.WithWakuRelaySharding(*rs)(localNode)
 require.NoError(t, err)
-cluster = rs.Cluster
+clusterID = rs.ClusterID
 }
-m1 := NewWakuMetadata(cluster, localNode, utils.Logger())
+m1 := NewWakuMetadata(clusterID, localNode, utils.Logger())
 m1.SetHost(host)
 err = m1.Start(context.TODO())
 require.NoError(t, err)
@@ -65,25 +65,25 @@ func TestWakuMetadataRequest(t *testing.T) {
 // Query a peer that is subscribed to a shard
 result, err := m16_1.Request(context.Background(), m16_2.h.ID())
 require.NoError(t, err)
-require.Equal(t, testShard16, result.Cluster)
-require.Equal(t, rs16_2.Indices, result.Indices)
+require.Equal(t, testShard16, result.ClusterID)
+require.Equal(t, rs16_2.ShardIDs, result.ShardIDs)
 // Updating the peer shards
-rs16_2.Indices = append(rs16_2.Indices, 3, 4)
+rs16_2.ShardIDs = append(rs16_2.ShardIDs, 3, 4)
 err = enr.WithWakuRelaySharding(rs16_2)(m16_2.localnode)
 require.NoError(t, err)
 // Query same peer, after that peer subscribes to more shards
 result, err = m16_1.Request(context.Background(), m16_2.h.ID())
 require.NoError(t, err)
-require.Equal(t, testShard16, result.Cluster)
-require.ElementsMatch(t, rs16_2.Indices, result.Indices)
+require.Equal(t, testShard16, result.ClusterID)
+require.ElementsMatch(t, rs16_2.ShardIDs, result.ShardIDs)
 // Query a peer not subscribed to a shard
 result, err = m16_1.Request(context.Background(), m_noRS.h.ID())
 require.NoError(t, err)
-require.Equal(t, uint16(0), result.Cluster)
-require.Len(t, result.Indices, 0)
+require.Equal(t, uint16(0), result.ClusterID)
+require.Len(t, result.ShardIDs, 0)
 }
 func TestNoNetwork(t *testing.T) {


@@ -37,26 +37,26 @@ var ErrInvalidNumberFormat = errors.New("only 2^16 numbers are allowed")
 // StaticShardingPubsubTopic describes a pubSub topic as per StaticSharding
 type StaticShardingPubsubTopic struct {
-cluster uint16
-shard uint16
+clusterID uint16
+shardID uint16
 }
 // NewStaticShardingPubsubTopic creates a new pubSub topic
 func NewStaticShardingPubsubTopic(cluster uint16, shard uint16) StaticShardingPubsubTopic {
 return StaticShardingPubsubTopic{
-cluster: cluster,
-shard: shard,
+clusterID: cluster,
+shardID: shard,
 }
 }
 // Cluster returns the sharded cluster index
 func (s StaticShardingPubsubTopic) Cluster() uint16 {
-return s.cluster
+return s.clusterID
 }
 // Shard returns the shard number
 func (s StaticShardingPubsubTopic) Shard() uint16 {
-return s.shard
+return s.shardID
 }
 // Equal compares StaticShardingPubsubTopic
@@ -66,7 +66,7 @@ func (s StaticShardingPubsubTopic) Equal(t2 StaticShardingPubsubTopic) bool {
 // String formats StaticShardingPubsubTopic to RFC 23 specific string format for pubsub topic.
 func (s StaticShardingPubsubTopic) String() string {
-return fmt.Sprintf("%s/%d/%d", StaticShardingPubsubTopicPrefix, s.cluster, s.shard)
+return fmt.Sprintf("%s/%d/%d", StaticShardingPubsubTopicPrefix, s.clusterID, s.shardID)
 }
 // Parse parses a topic string into a StaticShardingPubsubTopic
@@ -100,18 +100,18 @@ func (s *StaticShardingPubsubTopic) Parse(topic string) error {
 return ErrInvalidNumberFormat
 }
-s.shard = uint16(shardInt)
-s.cluster = uint16(clusterInt)
+s.shardID = uint16(shardInt)
+s.clusterID = uint16(clusterInt)
 return nil
 }
 func ToShardPubsubTopic(topic WakuPubSubTopic) (StaticShardingPubsubTopic, error) {
-result, ok := topic.(StaticShardingPubsubTopic)
-if !ok {
-return StaticShardingPubsubTopic{}, ErrNotShardPubsubTopic
-}
-return result, nil
+result, ok := topic.(StaticShardingPubsubTopic)
+if !ok {
+return StaticShardingPubsubTopic{}, ErrNotShardPubsubTopic
+}
+return result, nil
 }
 // ToWakuPubsubTopic takes a pubSub topic string and creates a WakuPubsubTopic object.
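A quick round trip with the renamed fields, for reference: String produces the RFC 23 form of the topic (a /waku/2/rs/<clusterID>/<shardID> style string, assuming the usual value of StaticShardingPubsubTopicPrefix), and Parse recovers both IDs through the unchanged Cluster()/Shard() accessors. The import path is an assumption:

```go
package main

import (
	"fmt"

	"github.com/waku-org/go-waku/waku/v2/protocol"
)

func main() {
	// Build a static sharding topic for clusterID 16, shardID 2 and print its
	// string form (the prefix constant is defined by the package).
	t := protocol.NewStaticShardingPubsubTopic(16, 2)
	fmt.Println(t.String())

	// Parse it back and read the IDs through the accessors.
	var parsed protocol.StaticShardingPubsubTopic
	if err := parsed.Parse(t.String()); err != nil {
		panic(err)
	}
	fmt.Println(parsed.Cluster(), parsed.Shard()) // 16 2
}
```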


@@ -13,57 +13,64 @@ import (
 const MaxShardIndex = uint16(1023)
 // ClusterIndex is the clusterID used in sharding space.
-// For indices allocation and other magic numbers refer to RFC 51
+// For shardIDs allocation and other magic numbers refer to RFC 51
 const ClusterIndex = 1
 // GenerationZeroShardsCount is number of shards supported in generation-0
 const GenerationZeroShardsCount = 8
+var (
+ErrTooManyShards = errors.New("too many shards")
+ErrInvalidShard = errors.New("invalid shard")
+ErrInvalidShardCount = errors.New("invalid shard count")
+ErrExpected130Bytes = errors.New("invalid data: expected 130 bytes")
+)
 type RelayShards struct {
-Cluster uint16 `json:"cluster"`
-Indices []uint16 `json:"indices"`
+ClusterID uint16 `json:"clusterID"`
+ShardIDs []uint16 `json:"shardIDs"`
 }
-func NewRelayShards(cluster uint16, indices ...uint16) (RelayShards, error) {
-if len(indices) > math.MaxUint8 {
-return RelayShards{}, errors.New("too many indices")
+func NewRelayShards(clusterID uint16, shardIDs ...uint16) (RelayShards, error) {
+if len(shardIDs) > math.MaxUint8 {
+return RelayShards{}, ErrTooManyShards
 }
-indiceSet := make(map[uint16]struct{})
-for _, index := range indices {
+shardIDSet := make(map[uint16]struct{})
+for _, index := range shardIDs {
 if index > MaxShardIndex {
-return RelayShards{}, errors.New("invalid index")
+return RelayShards{}, ErrInvalidShard
 }
-indiceSet[index] = struct{}{} // dedup
+shardIDSet[index] = struct{}{} // dedup
 }
-if len(indiceSet) == 0 {
-return RelayShards{}, errors.New("invalid index count")
+if len(shardIDSet) == 0 {
+return RelayShards{}, ErrInvalidShardCount
 }
-indices = []uint16{}
-for index := range indiceSet {
-indices = append(indices, index)
+shardIDs = []uint16{}
+for index := range shardIDSet {
+shardIDs = append(shardIDs, index)
 }
-return RelayShards{Cluster: cluster, Indices: indices}, nil
+return RelayShards{ClusterID: clusterID, ShardIDs: shardIDs}, nil
 }
 func (rs RelayShards) Topics() []WakuPubSubTopic {
 var result []WakuPubSubTopic
-for _, i := range rs.Indices {
-result = append(result, NewStaticShardingPubsubTopic(rs.Cluster, i))
+for _, i := range rs.ShardIDs {
+result = append(result, NewStaticShardingPubsubTopic(rs.ClusterID, i))
 }
 return result
 }
 func (rs RelayShards) Contains(cluster uint16, index uint16) bool {
-if rs.Cluster != cluster {
+if rs.ClusterID != cluster {
 return false
 }
 found := false
-for _, idx := range rs.Indices {
+for _, idx := range rs.ShardIDs {
 if idx == index {
 found = true
 }
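A small sketch of the renamed constructor and fields: duplicate shard IDs are collapsed, an empty list yields ErrInvalidShardCount, and IDs above MaxShardIndex (1023) yield ErrInvalidShard. The import path is an assumption:

```go
package main

import (
	"fmt"

	"github.com/waku-org/go-waku/waku/v2/protocol"
)

func main() {
	// Duplicate shard IDs are deduplicated by the constructor.
	rs, err := protocol.NewRelayShards(1, 2, 2, 7)
	if err != nil {
		panic(err)
	}
	fmt.Println(rs.ClusterID, rs.ShardIDs) // clusterID 1, shardIDs 2 and 7 (order may vary)

	fmt.Println(rs.Contains(1, 7)) // true
	fmt.Println(rs.Contains(2, 7)) // false: different cluster

	// Each (clusterID, shardID) pair maps to one static sharding pubsub topic.
	for _, topic := range rs.Topics() {
		fmt.Println(topic)
	}

	// Out-of-range shard IDs are rejected with the new sentinel error.
	_, err = protocol.NewRelayShards(1, 2000)
	fmt.Println(err) // invalid shard
}
```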
@@ -94,22 +101,22 @@ func TopicsToRelayShards(topic ...string) ([]RelayShards, error) {
 return nil, err
 }
-indices, ok := dict[ps.cluster]
+shardIDs, ok := dict[ps.clusterID]
 if !ok {
-indices = make(map[uint16]struct{})
+shardIDs = make(map[uint16]struct{})
 }
-indices[ps.shard] = struct{}{}
-dict[ps.cluster] = indices
+shardIDs[ps.shardID] = struct{}{}
+dict[ps.clusterID] = shardIDs
 }
-for cluster, indices := range dict {
-idx := make([]uint16, 0, len(indices))
-for index := range indices {
-idx = append(idx, index)
+for clusterID, shardIDs := range dict {
+idx := make([]uint16, 0, len(shardIDs))
+for shardID := range shardIDs {
+idx = append(idx, shardID)
 }
-rs, err := NewRelayShards(cluster, idx...)
+rs, err := NewRelayShards(clusterID, idx...)
 if err != nil {
 return nil, err
 }
@@ -128,23 +135,23 @@ func (rs RelayShards) ContainsTopic(topic string) bool {
 return rs.ContainsShardPubsubTopic(wTopic)
 }
-func (rs RelayShards) IndicesList() ([]byte, error) {
-if len(rs.Indices) > math.MaxUint8 {
-return nil, errors.New("indices list too long")
+func (rs RelayShards) ShardList() ([]byte, error) {
+if len(rs.ShardIDs) > math.MaxUint8 {
+return nil, ErrTooManyShards
 }
 var result []byte
-result = binary.BigEndian.AppendUint16(result, rs.Cluster)
-result = append(result, uint8(len(rs.Indices)))
-for _, index := range rs.Indices {
+result = binary.BigEndian.AppendUint16(result, rs.ClusterID)
+result = append(result, uint8(len(rs.ShardIDs)))
+for _, index := range rs.ShardIDs {
 result = binary.BigEndian.AppendUint16(result, index)
 }
 return result, nil
 }
-func FromIndicesList(buf []byte) (RelayShards, error) {
+func FromShardList(buf []byte) (RelayShards, error) {
 if len(buf) < 3 {
 return RelayShards{}, fmt.Errorf("insufficient data: expected at least 3 bytes, got %d bytes", len(buf))
 }
@@ -156,12 +163,12 @@ func FromIndicesList(buf []byte) (RelayShards, error) {
 return RelayShards{}, fmt.Errorf("invalid data: `length` field is %d but %d bytes were provided", length, len(buf))
 }
-var indices []uint16
+shardIDs := make([]uint16, length)
 for i := 0; i < length; i++ {
-indices = append(indices, binary.BigEndian.Uint16(buf[3+2*i:5+2*i]))
+shardIDs[i] = binary.BigEndian.Uint16(buf[3+2*i : 5+2*i])
 }
-return NewRelayShards(cluster, indices...)
+return NewRelayShards(cluster, shardIDs...)
 }
 func setBit(n byte, pos uint) byte {
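For reference, the shard-list wire format produced by ShardList (and consumed by FromShardList) is a 2-byte big-endian cluster ID, a 1-byte shard count, then 2 bytes per shard ID. A round-trip sketch, with an assumed import path:

```go
package main

import (
	"encoding/hex"
	"fmt"

	"github.com/waku-org/go-waku/waku/v2/protocol"
)

func main() {
	rs, err := protocol.NewRelayShards(1, 0, 5)
	if err != nil {
		panic(err)
	}

	// Encoding: 2-byte big-endian cluster ID, 1-byte count, 2 bytes per shard ID.
	buf, err := rs.ShardList()
	if err != nil {
		panic(err)
	}
	fmt.Println(hex.EncodeToString(buf)) // e.g. 00010200000005 (shard order may vary)

	decoded, err := protocol.FromShardList(buf)
	if err != nil {
		panic(err)
	}
	fmt.Println(decoded.ClusterID, decoded.ShardIDs) // 1 and shardIDs 0, 5
}
```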
@@ -181,10 +188,10 @@ func (rs RelayShards) BitVector() []byte {
 // of. The right-most bit in the bit vector represents shard 0, the left-most
 // bit represents shard 1023.
 var result []byte
-result = binary.BigEndian.AppendUint16(result, rs.Cluster)
+result = binary.BigEndian.AppendUint16(result, rs.ClusterID)
 vec := make([]byte, 128)
-for _, index := range rs.Indices {
+for _, index := range rs.ShardIDs {
 n := vec[index/8]
 vec[index/8] = byte(setBit(n, uint(index%8)))
 }
@@ -195,11 +202,11 @@ func (rs RelayShards) BitVector() []byte {
 // Generate a RelayShards from a byte slice
 func FromBitVector(buf []byte) (RelayShards, error) {
 if len(buf) != 130 {
-return RelayShards{}, errors.New("invalid data: expected 130 bytes")
+return RelayShards{}, ErrExpected130Bytes
 }
 cluster := binary.BigEndian.Uint16(buf[0:2])
-var indices []uint16
+var shardIDs []uint16
 for i := uint16(0); i < 128; i++ {
 for j := uint(0); j < 8; j++ {
@@ -207,11 +214,11 @@ func FromBitVector(buf []byte) (RelayShards, error) {
 continue
 }
-indices = append(indices, uint16(j)+8*i)
+shardIDs = append(shardIDs, uint16(j)+8*i)
 }
 }
-return RelayShards{Cluster: cluster, Indices: indices}, nil
+return RelayShards{ClusterID: cluster, ShardIDs: shardIDs}, nil
 }
 // GetShardFromContentTopic runs Autosharding logic and returns a pubSubTopic
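The bit-vector form, used by the ENR option when 64 or more shards are advertised, is fixed-size: a 2-byte big-endian cluster ID followed by 128 bytes covering all 1024 possible shard IDs, which is why FromBitVector insists on exactly 130 bytes. A round-trip sketch, with an assumed import path:

```go
package main

import (
	"fmt"

	"github.com/waku-org/go-waku/waku/v2/protocol"
)

func main() {
	rs, err := protocol.NewRelayShards(1, 0, 9, 1023)
	if err != nil {
		panic(err)
	}

	// Bit-vector encoding: 2-byte big-endian cluster ID followed by 128 bytes
	// (1024 bits), one bit per possible shard ID.
	buf := rs.BitVector()
	fmt.Println(len(buf)) // 130

	decoded, err := protocol.FromBitVector(buf)
	if err != nil {
		panic(err)
	}
	fmt.Println(decoded.ClusterID, decoded.ShardIDs) // 1 [0 9 1023]
}
```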


@@ -148,8 +148,8 @@ func (r *Rendezvous) RegisterShard(ctx context.Context, cluster uint16, shard ui
 // RegisterRelayShards registers the node in the rendezvous point by specifying a RelayShards struct (more than one shard index can be registered)
 func (r *Rendezvous) RegisterRelayShards(ctx context.Context, rs protocol.RelayShards, rendezvousPoints []*RendezvousPoint) {
-for _, idx := range rs.Indices {
-go r.RegisterShard(ctx, rs.Cluster, idx, rendezvousPoints)
+for _, idx := range rs.ShardIDs {
+go r.RegisterShard(ctx, rs.ClusterID, idx, rendezvousPoints)
 }
 }