2
0
mirror of https://github.com/status-im/status-go.git synced 2025-01-12 15:45:07 +00:00

250 lines
5.8 KiB
Go
Raw Normal View History

2018-07-20 09:46:29 +03:00
package discovery
import (
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
2020-01-02 10:10:19 +01:00
"github.com/ethereum/go-ethereum/p2p/discv5"
2018-07-20 09:46:29 +03:00
)
// newRegistry returns a registry with an empty, initialized topic map.
func newRegistry() *registry {
	r := new(registry)
	r.storage = make(map[string][]int)
	return r
}

// registry is an in-memory, mutex-guarded store shared by fake discoveries:
// each topic maps to the ids added for it, in the order they were added.
type registry struct {
	mu      sync.Mutex       // guards storage
	storage map[string][]int // topic -> ids in insertion order
}

// Add appends id to the ids recorded under topic.
func (r *registry) Add(topic string, id int) {
	r.mu.Lock()
	defer r.mu.Unlock()
	existing := r.storage[topic]
	r.storage[topic] = append(existing, id)
}

// Get returns the ids recorded under topic, or nil when none were added.
// The returned slice is the stored one (not a copy), read after the lock
// is released by the caller.
func (r *registry) Get(topic string) []int {
	r.mu.Lock()
	defer r.mu.Unlock()
	ids := r.storage[topic]
	return ids
}
// fake is a scriptable Discovery used by the multiplexer tests. When err is
// non-nil every method returns it; otherwise Register records id in the shared
// registry and Discover replays the registry's ids as nodes.
type fake struct {
	started  bool      // value reported by Running
	err      error     // error returned by every method when non-nil
	id       int       // id that Register records under a topic
	registry *registry // store consulted by Register and Discover
}

// Start marks the fake as running. When err is set it returns err and leaves
// the running state unchanged.
func (f *fake) Start() error {
	if f.err == nil {
		f.started = true
	}
	return f.err
}

// Stop always clears the running state, then returns err (nil when unset).
func (f *fake) Stop() error {
	f.started = false
	return f.err
}

// Running reports the started flag as last set by Start, Stop or the literal.
func (f *fake) Running() bool { return f.started }

// Register records f.id under topic in the shared registry unless err is set,
// in which case it returns err without recording anything. stop is ignored.
func (f *fake) Register(topic string, stop chan struct{}) error {
	if f.err == nil {
		f.registry.Add(topic, f.id)
	}
	return f.err
}

// Discover returns err immediately when set. Otherwise it sends one node per id
// registered under topic on found, using the id as the first byte of the node
// ID (ids above 255 are truncated by the byte conversion), and returns nil.
// Each send blocks until found accepts it. period and lookup are ignored.
func (f *fake) Discover(topic string, period <-chan time.Duration, found chan<- *discv5.Node, lookup chan<- bool) error {
	if f.err != nil {
		return f.err
	}
	ids := f.registry.Get(topic)
	for i := range ids {
		node := discv5.NewNode(discv5.NodeID{byte(ids[i])}, nil, 0, 0)
		found <- node
	}
	return nil
}
// testErrorCase is one error scenario for the multiplexer tests.
type testErrorCase struct {
	desc   string  // subtest name
	errors []error // errors[i] is injected into the i-th fake; nil means healthy
}

// errorCases returns the shared two-discovery scenarios: one failing, none
// failing, and both failing. Every injected error is a fresh errors.New("test").
func errorCases() []testErrorCase {
	failure := func() error { return errors.New("test") }
	return []testErrorCase{
		{desc: "SingleError", errors: []error{nil, failure()}},
		{desc: "NoErrors", errors: []error{nil, nil}},
		{desc: "AllErrors", errors: []error{failure(), failure()}},
	}
}
// TestMuxerStart checks two things. Starting the multiplexer fails when any
// underlying discovery fails to start. Afterwards, either every discovery is
// running (no failures) or none is (any failure).
func TestMuxerStart(t *testing.T) {
	for _, tc := range errorCases() {
		t.Run(tc.desc, func(t *testing.T) {
			discoveries := make([]Discovery, 0, len(tc.errors))
			anyFailed := false
			for _, injected := range tc.errors {
				anyFailed = anyFailed || injected != nil
				discoveries = append(discoveries, &fake{err: injected})
			}

			startErr := NewMultiplexer(discoveries).Start()
			if anyFailed {
				require.Error(t, startErr)
			} else {
				require.NoError(t, startErr)
			}

			// A single failure must leave nothing running.
			for _, d := range discoveries {
				require.Equal(t, !anyFailed, d.Running())
			}
		})
	}
}
// TestMuxerStop checks that stopping the multiplexer surfaces an error when any
// discovery fails to stop. It also checks that every discovery ends up not
// running either way.
func TestMuxerStop(t *testing.T) {
	for _, tc := range errorCases() {
		t.Run(tc.desc, func(t *testing.T) {
			n := len(tc.errors)
			discoveries := make([]Discovery, n)
			failed := false
			for i := 0; i < n; i++ {
				failed = failed || tc.errors[i] != nil
				discoveries[i] = &fake{started: true, err: tc.errors[i]}
			}

			stopErr := NewMultiplexer(discoveries).Stop()
			if !failed {
				require.NoError(t, stopErr)
			} else {
				require.Error(t, stopErr)
			}

			for _, d := range discoveries {
				require.False(t, d.Running())
			}
		})
	}
}
// TestMuxerRunning checks that the multiplexer reports running exactly when at
// least one underlying discovery is running.
func TestMuxerRunning(t *testing.T) {
	cases := []struct {
		desc    string
		started []bool
	}{
		{desc: "FirstRunning", started: []bool{false, true}},
		{desc: "SecondRunning", started: []bool{true, false}},
		{desc: "AllRunning", started: []bool{true, true}},
		{desc: "NoRunning", started: []bool{false, false}},
	}
	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			discoveries := make([]Discovery, len(tc.started))
			anyStarted := false
			for i := range tc.started {
				discoveries[i] = &fake{started: tc.started[i]}
				if tc.started[i] {
					anyStarted = true
				}
			}
			muxer := NewMultiplexer(discoveries)
			require.Equal(t, anyStarted, muxer.Running())
		})
	}
}
// TestMuxerRegister checks that registering a topic through the multiplexer
// fails when any discovery fails. It also checks that exactly the healthy
// discoveries record themselves under that topic.
func TestMuxerRegister(t *testing.T) {
	cases := []struct {
		desc   string
		errors []error
		topics []string
	}{
		{"NoErrors", []error{nil, nil, nil}, []string{"a"}},
		{"MultipleTopics", []error{nil, nil, nil}, []string{"a", "b", "c"}},
		{"SingleError", []error{nil, errors.New("test"), nil}, []string{"a"}},
		{"AllErrors", []error{errors.New("test"), errors.New("test"), errors.New("test")}, []string{"a"}},
	}
	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			reg := newRegistry()
			discoveries := make([]Discovery, len(tc.errors))
			failing := 0
			for i, injected := range tc.errors {
				if injected != nil {
					failing++
				}
				discoveries[i] = &fake{id: i, err: injected, registry: reg}
			}

			muxer := NewMultiplexer(discoveries)
			healthy := len(discoveries) - failing
			for _, topic := range tc.topics {
				regErr := muxer.Register(topic, nil)
				if failing > 0 {
					require.Error(t, regErr)
				} else {
					require.NoError(t, regErr)
				}
				// Only fakes without an injected error add their id to reg.
				require.Equal(t, healthy, len(reg.Get(topic)))
			}
		})
	}
}
// TestMuxerDiscovery checks, for every topic, that the multiplexer delivers
// exactly as many nodes as the non-failing discoveries hold. It also checks
// that Discover returns an error whenever at least one discovery fails.
func TestMuxerDiscovery(t *testing.T) {
	cases := []struct {
		desc   string
		errors []error
		topics []string
		ids    [][]int
	}{
		{"EqualNoErrors", []error{nil, nil}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
		{"MultiTopicsSingleSource", []error{nil, nil}, []string{"a", "b"}, [][]int{{11, 22, 33}, {}}},
		{"SingleError", []error{nil, errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
		{"AllErrors", []error{errors.New("test"), errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
	}
	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			discoveries := make([]Discovery, len(tc.errors))
			anyFailed := false
			want := 0
			for i, injected := range tc.errors {
				if injected != nil {
					anyFailed = true
				} else {
					want += len(tc.ids[i])
				}
				// Each fake gets a private registry holding ids[i] under every topic.
				reg := newRegistry()
				discoveries[i] = &fake{id: i, err: injected, registry: reg}
				for _, topic := range tc.topics {
					for _, id := range tc.ids[i] {
						reg.Add(topic, id)
					}
				}
			}

			muxer := NewMultiplexer(discoveries)
			for _, topic := range tc.topics {
				found := make(chan *discv5.Node, want)
				// TODO test period channel: it is closed up front and the fake
				// discoveries ignore it.
				period := make(chan time.Duration)
				close(period)

				discErr := muxer.Discover(topic, period, found, nil)
				if anyFailed {
					require.Error(t, discErr)
				} else {
					require.NoError(t, discErr)
				}

				close(found)
				got := 0
				for range found {
					got++
				}
				require.Equal(t, want, got)
			}
		})
	}
}