diff --git a/gossipsub.go b/gossipsub.go index 1fbf215..a5d55ce 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -287,6 +287,8 @@ func (gs *GossipSubRouter) Attach(p *PubSub) { // connect to direct peers if len(gs.direct) > 0 { go func() { + // add a small delay to make this unit-testable + time.Sleep(time.Second) for p := range gs.direct { gs.connect <- connectInfo{p: p} } diff --git a/gossipsub_test.go b/gossipsub_test.go index 98767e7..3ba51f2 100644 --- a/gossipsub_test.go +++ b/gossipsub_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" ) @@ -982,3 +983,73 @@ func TestGossipsubStarTopology(t *testing.T) { } } } + +func TestGossipSubDirectPeers(t *testing.T) { + originalGossipSubDirectConnectTicks := GossipSubDirectConnectTicks + GossipSubDirectConnectTicks = 2 + defer func() { + GossipSubDirectConnectTicks = originalGossipSubDirectConnectTicks + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := getNetHosts(t, ctx, 3) + psubs := []*PubSub{ + getGossipsub(ctx, h[0]), + getGossipsub(ctx, h[1], WithDirectPeers([]peer.AddrInfo{peer.AddrInfo{h[2].ID(), h[2].Addrs()}})), + getGossipsub(ctx, h[2], WithDirectPeers([]peer.AddrInfo{peer.AddrInfo{h[1].ID(), h[1].Addrs()}})), + } + + connect(t, h[0], h[1]) + connect(t, h[0], h[2]) + + // verify that the direct peers connected + time.Sleep(2 * time.Second) + if len(h[1].Network().ConnsToPeer(h[2].ID())) == 0 { + t.Fatal("expected a connection between direct peers") + } + + // build the mesh + var subs []*Subscription + for _, ps := range psubs { + sub, err := ps.Subscribe("test") + if err != nil { + t.Fatal(err) + } + subs = append(subs, sub) + } + + time.Sleep(time.Second) + + // publish some messages + for i := 0; i < 3; i++ { + msg := []byte(fmt.Sprintf("message %d", i)) + psubs[i].Publish("test", msg) + + for _, sub := range subs { + assertReceive(t, sub, msg) + } + } + + // disconnect the direct peers to test reconnection + for _, c := range h[1].Network().ConnsToPeer(h[2].ID()) { + c.Close() + } + + time.Sleep(3 * time.Second) + + if len(h[1].Network().ConnsToPeer(h[2].ID())) == 0 { + t.Fatal("expected a connection between direct peers") + } + + // publish some messages + for i := 0; i < 3; i++ { + msg := []byte(fmt.Sprintf("message %d", i)) + psubs[i].Publish("test", msg) + + for _, sub := range subs { + assertReceive(t, sub, msg) + } + } +}