249 lines
5.8 KiB
Go
249 lines
5.8 KiB
Go
|
package discovery
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/ethereum/go-ethereum/p2p/discv5"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
// newRegistry returns an empty registry whose methods are safe for
// concurrent use.
func newRegistry() *registry {
	return &registry{
		storage: map[string][]int{},
	}
}

// registry is an in-memory, mutex-guarded record of the ids stored under
// each topic. The fakes in this file use it as the shared backing store for
// Register and Discover.
type registry struct {
	mu      sync.Mutex
	storage map[string][]int
}

// Add appends id to the list recorded for topic.
func (r *registry) Add(topic string, id int) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.storage[topic] = append(r.storage[topic], id)
}

// Get returns the ids recorded for topic in insertion order, or nil if none
// were added. The result is a copy, so callers may read or modify it without
// holding the lock and without aliasing the registry's internal storage.
func (r *registry) Get(topic string) []int {
	r.mu.Lock()
	defer r.mu.Unlock()
	return append([]int(nil), r.storage[topic]...)
}
|
||
|
|
||
|
type fake struct {
|
||
|
started bool
|
||
|
err error
|
||
|
id int
|
||
|
registry *registry
|
||
|
}
|
||
|
|
||
|
func (f *fake) Start() error {
|
||
|
if f.err != nil {
|
||
|
return f.err
|
||
|
}
|
||
|
f.started = true
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (f *fake) Stop() error {
|
||
|
f.started = false
|
||
|
if f.err != nil {
|
||
|
return f.err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (f *fake) Running() bool {
|
||
|
return f.started
|
||
|
}
|
||
|
|
||
|
func (f *fake) Register(topic string, stop chan struct{}) error {
|
||
|
if f.err != nil {
|
||
|
return f.err
|
||
|
}
|
||
|
f.registry.Add(topic, f.id)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (f *fake) Discover(topic string, period <-chan time.Duration, found chan<- *discv5.Node, lookup chan<- bool) error {
|
||
|
if f.err != nil {
|
||
|
return f.err
|
||
|
}
|
||
|
for _, n := range f.registry.Get(topic) {
|
||
|
found <- discv5.NewNode(discv5.NodeID{byte(n)}, nil, 0, 0)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// testErrorCase is one table-driven scenario. desc names the subtest, and
// errors holds, per fake discovery, the error that fake is configured with
// (nil for a healthy fake).
type testErrorCase struct {
	desc   string
	errors []error
}

// errorCases returns the two-discovery scenarios shared by the Start and
// Stop multiplexer tests: one failing source, no failing sources, and both
// sources failing.
func errorCases() []testErrorCase {
	failure := func() error { return errors.New("test") }
	return []testErrorCase{
		{desc: "SingleError", errors: []error{nil, failure()}},
		{desc: "NoErrors", errors: []error{nil, nil}},
		{desc: "AllErrors", errors: []error{failure(), failure()}},
	}
}
|
||
|
|
||
|
func TestMuxerStart(t *testing.T) {
|
||
|
for _, tc := range errorCases() {
|
||
|
t.Run(tc.desc, func(t *testing.T) {
|
||
|
discoveries := make([]Discovery, len(tc.errors))
|
||
|
erred := false
|
||
|
for i, err := range tc.errors {
|
||
|
if err != nil {
|
||
|
erred = true
|
||
|
}
|
||
|
discoveries[i] = &fake{err: err}
|
||
|
}
|
||
|
muxer := NewMultiplexer(discoveries)
|
||
|
if erred {
|
||
|
require.Error(t, muxer.Start())
|
||
|
} else {
|
||
|
require.NoError(t, muxer.Start())
|
||
|
}
|
||
|
for _, d := range discoveries {
|
||
|
require.Equal(t, !erred, d.Running())
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMuxerStop(t *testing.T) {
|
||
|
for _, tc := range errorCases() {
|
||
|
t.Run(tc.desc, func(t *testing.T) {
|
||
|
discoveries := make([]Discovery, len(tc.errors))
|
||
|
erred := false
|
||
|
for i, err := range tc.errors {
|
||
|
if err != nil {
|
||
|
erred = true
|
||
|
}
|
||
|
discoveries[i] = &fake{started: true, err: err}
|
||
|
}
|
||
|
muxer := NewMultiplexer(discoveries)
|
||
|
if erred {
|
||
|
require.Error(t, muxer.Stop())
|
||
|
} else {
|
||
|
require.NoError(t, muxer.Stop())
|
||
|
}
|
||
|
for _, d := range discoveries {
|
||
|
require.False(t, d.Running())
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMuxerRunning(t *testing.T) {
|
||
|
for _, tc := range []struct {
|
||
|
desc string
|
||
|
started []bool
|
||
|
}{
|
||
|
{desc: "FirstRunning", started: []bool{false, true}},
|
||
|
{desc: "SecondRunning", started: []bool{true, false}},
|
||
|
{desc: "AllRunning", started: []bool{true, true}},
|
||
|
{desc: "NoRunning", started: []bool{false, false}},
|
||
|
} {
|
||
|
t.Run(tc.desc, func(t *testing.T) {
|
||
|
discoveries := make([]Discovery, len(tc.started))
|
||
|
allstarted := false
|
||
|
for i, start := range tc.started {
|
||
|
allstarted = start || allstarted
|
||
|
discoveries[i] = &fake{started: start}
|
||
|
}
|
||
|
require.Equal(t, allstarted, NewMultiplexer(discoveries).Running())
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMuxerRegister(t *testing.T) {
|
||
|
for _, tc := range []struct {
|
||
|
desc string
|
||
|
errors []error
|
||
|
topics []string
|
||
|
}{
|
||
|
{"NoErrors", []error{nil, nil, nil}, []string{"a"}},
|
||
|
{"MultipleTopics", []error{nil, nil, nil}, []string{"a", "b", "c"}},
|
||
|
{"SingleError", []error{nil, errors.New("test"), nil}, []string{"a"}},
|
||
|
{"AllErrors", []error{errors.New("test"), errors.New("test"), errors.New("test")}, []string{"a"}},
|
||
|
} {
|
||
|
t.Run(tc.desc, func(t *testing.T) {
|
||
|
reg := newRegistry()
|
||
|
discoveries := make([]Discovery, len(tc.errors))
|
||
|
erred := 0
|
||
|
for i := range discoveries {
|
||
|
if tc.errors[i] != nil {
|
||
|
erred++
|
||
|
}
|
||
|
discoveries[i] = &fake{id: i, err: tc.errors[i], registry: reg}
|
||
|
}
|
||
|
muxer := NewMultiplexer(discoveries)
|
||
|
for _, topic := range tc.topics {
|
||
|
if erred != 0 {
|
||
|
require.Error(t, muxer.Register(topic, nil))
|
||
|
} else {
|
||
|
require.NoError(t, muxer.Register(topic, nil))
|
||
|
}
|
||
|
require.Equal(t, len(discoveries)-erred, len(reg.Get(topic)))
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMuxerDiscovery(t *testing.T) {
|
||
|
for _, tc := range []struct {
|
||
|
desc string
|
||
|
errors []error
|
||
|
topics []string
|
||
|
ids [][]int
|
||
|
}{
|
||
|
{"EqualNoErrors", []error{nil, nil}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
|
||
|
{"MultiTopicsSingleSource", []error{nil, nil}, []string{"a", "b"}, [][]int{{11, 22, 33}, {}}},
|
||
|
{"SingleError", []error{nil, errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
|
||
|
{"AllErrors", []error{errors.New("test"), errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
|
||
|
} {
|
||
|
t.Run(tc.desc, func(t *testing.T) {
|
||
|
discoveries := make([]Discovery, len(tc.errors))
|
||
|
erred := false
|
||
|
expected := 0
|
||
|
for i := range discoveries {
|
||
|
if tc.errors[i] == nil {
|
||
|
expected += len(tc.ids[i])
|
||
|
} else {
|
||
|
erred = true
|
||
|
}
|
||
|
reg := newRegistry()
|
||
|
discoveries[i] = &fake{id: i, err: tc.errors[i], registry: reg}
|
||
|
for _, topic := range tc.topics {
|
||
|
for _, id := range tc.ids[i] {
|
||
|
reg.Add(topic, id)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
muxer := NewMultiplexer(discoveries)
|
||
|
for _, topic := range tc.topics {
|
||
|
found := make(chan *discv5.Node, expected)
|
||
|
period := make(chan time.Duration)
|
||
|
close(period)
|
||
|
if erred {
|
||
|
// TODO test period channel
|
||
|
require.Error(t, muxer.Discover(topic, period, found, nil))
|
||
|
} else {
|
||
|
require.NoError(t, muxer.Discover(topic, period, found, nil))
|
||
|
}
|
||
|
close(found)
|
||
|
count := 0
|
||
|
for range found {
|
||
|
count++
|
||
|
}
|
||
|
require.Equal(t, expected, count)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|