diff --git a/discovery/muxer.go b/discovery/muxer.go new file mode 100644 index 000000000..f651ab994 --- /dev/null +++ b/discovery/muxer.go @@ -0,0 +1,130 @@ +package discovery + +import ( + "fmt" + "strings" + "sync" + "time" + + "github.com/ethereum/go-ethereum/p2p/discv5" +) + +// NewMultiplexer creates Multiplexer instance. +func NewMultiplexer(discoveries []Discovery) Multiplexer { + return Multiplexer{discoveries} +} + +// Multiplexer allows to use multiple discoveries behind single Discovery interface. +type Multiplexer struct { + discoveries []Discovery +} + +// Running should return true if at least one discovery is running +func (m Multiplexer) Running() (rst bool) { + for i := range m.discoveries { + rst = rst || m.discoveries[i].Running() + } + return rst +} + +// Start every discovery and stop every started in case if at least one fails. +func (m Multiplexer) Start() (err error) { + started := []int{} + for i := range m.discoveries { + if err = m.discoveries[i].Start(); err != nil { + break + } + started = append(started, i) + } + if err != nil { + for _, i := range started { + _ = m.discoveries[i].Stop() + } + } + return err +} + +// Stop every discovery. +func (m Multiplexer) Stop() (err error) { + messages := []string{} + for i := range m.discoveries { + if err = m.discoveries[i].Stop(); err != nil { + messages = append(messages, err.Error()) + } + } + if len(messages) != 0 { + return fmt.Errorf("failed to stop discoveries: %s", strings.Join(messages, "; ")) + } + return nil +} + +// Register passed topic and stop channel to every discovery and waits till it will return. +func (m Multiplexer) Register(topic string, stop chan struct{}) error { + errors := make(chan error, len(m.discoveries)) + for i := range m.discoveries { + i := i + go func() { + errors <- m.discoveries[i].Register(topic, stop) + }() + } + total := 0 + messages := []string{} + for err := range errors { + total++ + if err != nil { + messages = append(messages, err.Error()) + } + if total == len(m.discoveries) { + break + } + } + if len(messages) != 0 { + return fmt.Errorf("failed to register %s: %s", topic, strings.Join(messages, "; ")) + } + return nil +} + +// Discover shares topic and channles for receiving results. And multiplexer periods that are sent to period channel. +func (m Multiplexer) Discover(topic string, period <-chan time.Duration, found chan<- *discv5.Node, lookup chan<- bool) error { + var ( + periods = make([]chan time.Duration, len(m.discoveries)) + messages = []string{} + wg sync.WaitGroup + mu sync.Mutex + ) + wg.Add(len(m.discoveries) + 1) + for i := range m.discoveries { + i := i + periods[i] = make(chan time.Duration, 2) + go func() { + err := m.discoveries[i].Discover(topic, periods[i], found, lookup) + if err != nil { + mu.Lock() + messages = append(messages, err.Error()) + mu.Unlock() + } + wg.Done() + }() + } + go func() { + for { + newPeriod, ok := <-period + for i := range periods { + if !ok { + close(periods[i]) + } else { + periods[i] <- newPeriod + } + } + if !ok { + wg.Done() + return + } + } + }() + wg.Wait() + if len(messages) != 0 { + return fmt.Errorf("failed to discover topic %s: %s", topic, strings.Join(messages, "; ")) + } + return nil +} diff --git a/discovery/muxer_test.go b/discovery/muxer_test.go new file mode 100644 index 000000000..ed1805319 --- /dev/null +++ b/discovery/muxer_test.go @@ -0,0 +1,248 @@ +package discovery + +import ( + "errors" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p/discv5" + "github.com/stretchr/testify/require" +) + +func newRegistry() *registry { + return ®istry{ + storage: map[string][]int{}, + } +} + +type registry struct { + mu sync.Mutex + storage map[string][]int +} + +func (r *registry) Add(topic string, id int) { + r.mu.Lock() + defer r.mu.Unlock() + r.storage[topic] = append(r.storage[topic], id) +} + +func (r *registry) Get(topic string) []int { + r.mu.Lock() + defer r.mu.Unlock() + return r.storage[topic] +} + +type fake struct { + started bool + err error + id int + registry *registry +} + +func (f *fake) Start() error { + if f.err != nil { + return f.err + } + f.started = true + return nil +} + +func (f *fake) Stop() error { + f.started = false + if f.err != nil { + return f.err + } + return nil +} + +func (f *fake) Running() bool { + return f.started +} + +func (f *fake) Register(topic string, stop chan struct{}) error { + if f.err != nil { + return f.err + } + f.registry.Add(topic, f.id) + return nil +} + +func (f *fake) Discover(topic string, period <-chan time.Duration, found chan<- *discv5.Node, lookup chan<- bool) error { + if f.err != nil { + return f.err + } + for _, n := range f.registry.Get(topic) { + found <- discv5.NewNode(discv5.NodeID{byte(n)}, nil, 0, 0) + } + return nil +} + +type testErrorCase struct { + desc string + errors []error +} + +func errorCases() []testErrorCase { + return []testErrorCase{ + {desc: "SingleError", errors: []error{nil, errors.New("test")}}, + {desc: "NoErrors", errors: []error{nil, nil}}, + {desc: "AllErrors", errors: []error{errors.New("test"), errors.New("test")}}, + } +} + +func TestMuxerStart(t *testing.T) { + for _, tc := range errorCases() { + t.Run(tc.desc, func(t *testing.T) { + discoveries := make([]Discovery, len(tc.errors)) + erred := false + for i, err := range tc.errors { + if err != nil { + erred = true + } + discoveries[i] = &fake{err: err} + } + muxer := NewMultiplexer(discoveries) + if erred { + require.Error(t, muxer.Start()) + } else { + require.NoError(t, muxer.Start()) + } + for _, d := range discoveries { + require.Equal(t, !erred, d.Running()) + } + }) + } +} + +func TestMuxerStop(t *testing.T) { + for _, tc := range errorCases() { + t.Run(tc.desc, func(t *testing.T) { + discoveries := make([]Discovery, len(tc.errors)) + erred := false + for i, err := range tc.errors { + if err != nil { + erred = true + } + discoveries[i] = &fake{started: true, err: err} + } + muxer := NewMultiplexer(discoveries) + if erred { + require.Error(t, muxer.Stop()) + } else { + require.NoError(t, muxer.Stop()) + } + for _, d := range discoveries { + require.False(t, d.Running()) + } + }) + } +} + +func TestMuxerRunning(t *testing.T) { + for _, tc := range []struct { + desc string + started []bool + }{ + {desc: "FirstRunning", started: []bool{false, true}}, + {desc: "SecondRunning", started: []bool{true, false}}, + {desc: "AllRunning", started: []bool{true, true}}, + {desc: "NoRunning", started: []bool{false, false}}, + } { + t.Run(tc.desc, func(t *testing.T) { + discoveries := make([]Discovery, len(tc.started)) + allstarted := false + for i, start := range tc.started { + allstarted = start || allstarted + discoveries[i] = &fake{started: start} + } + require.Equal(t, allstarted, NewMultiplexer(discoveries).Running()) + }) + } +} + +func TestMuxerRegister(t *testing.T) { + for _, tc := range []struct { + desc string + errors []error + topics []string + }{ + {"NoErrors", []error{nil, nil, nil}, []string{"a"}}, + {"MultipleTopics", []error{nil, nil, nil}, []string{"a", "b", "c"}}, + {"SingleError", []error{nil, errors.New("test"), nil}, []string{"a"}}, + {"AllErrors", []error{errors.New("test"), errors.New("test"), errors.New("test")}, []string{"a"}}, + } { + t.Run(tc.desc, func(t *testing.T) { + reg := newRegistry() + discoveries := make([]Discovery, len(tc.errors)) + erred := 0 + for i := range discoveries { + if tc.errors[i] != nil { + erred++ + } + discoveries[i] = &fake{id: i, err: tc.errors[i], registry: reg} + } + muxer := NewMultiplexer(discoveries) + for _, topic := range tc.topics { + if erred != 0 { + require.Error(t, muxer.Register(topic, nil)) + } else { + require.NoError(t, muxer.Register(topic, nil)) + } + require.Equal(t, len(discoveries)-erred, len(reg.Get(topic))) + } + }) + } +} + +func TestMuxerDiscovery(t *testing.T) { + for _, tc := range []struct { + desc string + errors []error + topics []string + ids [][]int + }{ + {"EqualNoErrors", []error{nil, nil}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}}, + {"MultiTopicsSingleSource", []error{nil, nil}, []string{"a", "b"}, [][]int{{11, 22, 33}, {}}}, + {"SingleError", []error{nil, errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}}, + {"AllErrors", []error{errors.New("test"), errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}}, + } { + t.Run(tc.desc, func(t *testing.T) { + discoveries := make([]Discovery, len(tc.errors)) + erred := false + expected := 0 + for i := range discoveries { + if tc.errors[i] == nil { + expected += len(tc.ids[i]) + } else { + erred = true + } + reg := newRegistry() + discoveries[i] = &fake{id: i, err: tc.errors[i], registry: reg} + for _, topic := range tc.topics { + for _, id := range tc.ids[i] { + reg.Add(topic, id) + } + } + } + muxer := NewMultiplexer(discoveries) + for _, topic := range tc.topics { + found := make(chan *discv5.Node, expected) + period := make(chan time.Duration) + close(period) + if erred { + // TODO test period channel + require.Error(t, muxer.Discover(topic, period, found, nil)) + } else { + require.NoError(t, muxer.Discover(topic, period, found, nil)) + } + close(found) + count := 0 + for range found { + count++ + } + require.Equal(t, expected, count) + } + }) + } +} diff --git a/node/status_node.go b/node/status_node.go index 2d4ed8a87..5138a1d5f 100644 --- a/node/status_node.go +++ b/node/status_node.go @@ -184,40 +184,46 @@ func (n *StatusNode) discoveryEnabled() bool { return n.config != nil && (!n.config.NoDiscovery || n.config.Rendezvous) && n.config.ClusterConfig != nil } -func (n *StatusNode) startRendezvous() error { +func (n *StatusNode) startRendezvous() (discovery.Discovery, error) { if !n.config.Rendezvous { - return errors.New("rendezvous is not enabled") + return nil, errors.New("rendezvous is not enabled") } if len(n.config.ClusterConfig.RendezvousNodes) == 0 { - return errors.New("rendezvous node must be provided if rendezvous discovery is enabled") + return nil, errors.New("rendezvous node must be provided if rendezvous discovery is enabled") } maddrs := make([]ma.Multiaddr, len(n.config.ClusterConfig.RendezvousNodes)) for i, addr := range n.config.ClusterConfig.RendezvousNodes { var err error maddrs[i], err = ma.NewMultiaddr(addr) if err != nil { - return fmt.Errorf("failed to parse rendezvous node %s: %v", n.config.ClusterConfig.RendezvousNodes[0], err) + return nil, fmt.Errorf("failed to parse rendezvous node %s: %v", n.config.ClusterConfig.RendezvousNodes[0], err) } } srv := n.gethNode.Server() - var err error - n.discovery, err = discovery.NewRendezvous(maddrs, srv.PrivateKey, srv.Self()) - return err + return discovery.NewRendezvous(maddrs, srv.PrivateKey, srv.Self()) } func (n *StatusNode) startDiscovery() error { - if !n.config.NoDiscovery && n.config.Rendezvous { - return errors.New("only one discovery can be used (will be allowed to use more in next change)") - } + discoveries := []discovery.Discovery{} if !n.config.NoDiscovery { - n.discovery = discovery.NewDiscV5( + discoveries = append(discoveries, discovery.NewDiscV5( n.gethNode.Server().PrivateKey, n.config.ListenAddr, - parseNodesV5(n.config.ClusterConfig.BootNodes)) - } else if n.config.Rendezvous { - if err := n.startRendezvous(); err != nil { + parseNodesV5(n.config.ClusterConfig.BootNodes))) + } + if n.config.Rendezvous { + d, err := n.startRendezvous() + if err != nil { return err } + discoveries = append(discoveries, d) + } + if len(discoveries) == 0 { + return errors.New("wasn't able to register any discovery") + } else if len(discoveries) > 1 { + n.discovery = discovery.NewMultiplexer(discoveries) + } else { + n.discovery = discoveries[0] } n.register = peers.NewRegister(n.discovery, n.config.RegisterTopics...) options := peers.NewDefaultOptions()