From 782897ea41f376916c9f7fab6e1e344597f8bc94 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 5 Sep 2021 19:00:36 +0100 Subject: [PATCH 1/3] stop using goprocess to control teardown --- p2p/net/swarm/dial_test.go | 86 ++++++++++++--------------- p2p/net/swarm/peers_test.go | 8 +-- p2p/net/swarm/simul_test.go | 11 ++-- p2p/net/swarm/swarm.go | 49 ++++----------- p2p/net/swarm/swarm_addr_test.go | 32 +++------- p2p/net/swarm/swarm_listen.go | 2 +- p2p/net/swarm/swarm_net_test.go | 35 +++-------- p2p/net/swarm/swarm_notif_test.go | 14 ++--- p2p/net/swarm/swarm_test.go | 61 +++++++------------ p2p/net/swarm/testing/testing.go | 25 +++++--- p2p/net/swarm/testing/testing_test.go | 3 +- p2p/net/swarm/transport_test.go | 29 ++++----- 12 files changed, 137 insertions(+), 218 deletions(-) diff --git a/p2p/net/swarm/dial_test.go b/p2p/net/swarm/dial_test.go index 3bbc5e2b..de41e7f7 100644 --- a/p2p/net/swarm/dial_test.go +++ b/p2p/net/swarm/dial_test.go @@ -7,28 +7,33 @@ import ( "testing" "time" - addrutil "github.com/libp2p/go-addr-util" + . "github.com/libp2p/go-libp2p-swarm" + addrutil "github.com/libp2p/go-addr-util" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" - "github.com/libp2p/go-libp2p-core/transport" - testutil "github.com/libp2p/go-libp2p-core/test" + "github.com/libp2p/go-libp2p-core/transport" swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-libp2p-testing/ci" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" - . "github.com/libp2p/go-libp2p-swarm" + "github.com/stretchr/testify/require" ) func init() { transport.DialTimeout = time.Second } -func closeSwarms(swarms []*Swarm) { +type swarmWithBackoff interface { + network.Network + Backoff() *DialBackoff +} + +func closeSwarms(swarms []network.Network) { for _, s := range swarms { s.Close() } @@ -36,50 +41,37 @@ func closeSwarms(swarms []*Swarm) { func TestBasicDialPeer(t *testing.T) { t.Parallel() - ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) defer closeSwarms(swarms) s1 := swarms[0] s2 := swarms[1] s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) - c, err := s1.DialPeer(ctx, s2.LocalPeer()) - if err != nil { - t.Fatal(err) - } - - s, err := c.NewStream(ctx) - if err != nil { - t.Fatal(err) - } + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + s, err := c.NewStream(context.Background()) + require.NoError(t, err) s.Close() } func TestDialWithNoListeners(t *testing.T) { t.Parallel() - ctx := context.Background() - s1 := makeDialOnlySwarm(ctx, t) - - swarms := makeSwarms(ctx, t, 1) + s1 := makeDialOnlySwarm(t) + swarms := makeSwarms(t, 1) defer closeSwarms(swarms) s2 := swarms[0] s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) - c, err := s1.DialPeer(ctx, s2.LocalPeer()) - if err != nil { - t.Fatal(err) - } - - s, err := c.NewStream(ctx) - if err != nil { - t.Fatal(err) - } + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + s, err := c.NewStream(context.Background()) + require.NoError(t, err) s.Close() } @@ -104,12 +96,12 @@ func TestSimultDials(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 2, swarmt.OptDisableReuseport) + swarms := makeSwarms(t, 2, swarmt.OptDisableReuseport) // connect everyone { var wg sync.WaitGroup - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { // copy for other peer log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) s.Peerstore().AddAddr(dst, addr, peerstore.TempAddrTTL) @@ -175,7 +167,7 @@ func TestDialWait(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 1) + swarms := makeSwarms(t, 1) s1 := swarms[0] defer s1.Close() @@ -201,7 +193,7 @@ func TestDialWait(t *testing.T) { t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) } - if !s1.Backoff().Backoff(s2p, s2addr) { + if !s1.(swarmWithBackoff).Backoff().Backoff(s2p, s2addr) { t.Error("s2 should now be on backoff") } } @@ -215,7 +207,7 @@ func TestDialBackoff(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] s2 := swarms[1] defer s1.Close() @@ -338,10 +330,10 @@ func TestDialBackoff(t *testing.T) { } // check backoff state - if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { t.Error("s2 should not be on backoff") } - if !s1.Backoff().Backoff(s3p, s3addr) { + if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { t.Error("s3 should be on backoff") } @@ -408,10 +400,10 @@ func TestDialBackoff(t *testing.T) { } // check backoff state (the same) - if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { t.Error("s2 should not be on backoff") } - if !s1.Backoff().Backoff(s3p, s3addr) { + if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { t.Error("s3 should be on backoff") } } @@ -422,7 +414,7 @@ func TestDialBackoffClears(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] s2 := swarms[1] defer s1.Close() @@ -453,7 +445,7 @@ func TestDialBackoffClears(t *testing.T) { t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) } - if !s1.Backoff().Backoff(s2.LocalPeer(), s2bad) { + if !s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { t.Error("s2 should now be on backoff") } else { t.Log("correctly added to backoff") @@ -480,7 +472,7 @@ func TestDialBackoffClears(t *testing.T) { t.Log("correctly connected") } - if s1.Backoff().Backoff(s2.LocalPeer(), s2bad) { + if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { t.Error("s2 should no longer be on backoff") } else { t.Log("correctly cleared backoff") @@ -491,7 +483,7 @@ func TestDialPeerFailed(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) defer closeSwarms(swarms) testedSwarm, targetSwarm := swarms[0], swarms[1] @@ -530,7 +522,7 @@ func TestDialPeerFailed(t *testing.T) { func TestDialExistingConnection(t *testing.T) { ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) defer closeSwarms(swarms) s1 := swarms[0] s2 := swarms[1] @@ -574,7 +566,7 @@ func TestDialSimultaneousJoin(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] s2 := swarms[1] defer s1.Close() @@ -676,12 +668,10 @@ func TestDialSelf(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] defer s1.Close() _, err := s1.DialPeer(ctx, s1.LocalPeer()) - if err != ErrDialToSelf { - t.Fatal("expected error from self dial") - } + require.ErrorIs(t, err, ErrDialToSelf, "expected error from self dial") } diff --git a/p2p/net/swarm/peers_test.go b/p2p/net/swarm/peers_test.go index 8e82bf5b..908abe91 100644 --- a/p2p/net/swarm/peers_test.go +++ b/p2p/net/swarm/peers_test.go @@ -9,17 +9,15 @@ import ( "github.com/libp2p/go-libp2p-core/peerstore" ma "github.com/multiformats/go-multiaddr" - - . "github.com/libp2p/go-libp2p-swarm" ) func TestPeers(t *testing.T) { ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] s2 := swarms[1] - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { // TODO: make a DialAddr func. s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) // t.Logf("connections from %s", s.LocalPeer()) @@ -55,7 +53,7 @@ func TestPeers(t *testing.T) { log.Infof("%s swarm routing table: %s", s.LocalPeer(), s.Peers()) } - test := func(s *Swarm) { + test := func(s network.Network) { expect := 1 actual := len(s.Peers()) if actual != expect { diff --git a/p2p/net/swarm/simul_test.go b/p2p/net/swarm/simul_test.go index 0373e37d..326c4e21 100644 --- a/p2p/net/swarm/simul_test.go +++ b/p2p/net/swarm/simul_test.go @@ -7,32 +7,29 @@ import ( "testing" "time" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" ma "github.com/multiformats/go-multiaddr" - . "github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-libp2p-testing/ci" ) func TestSimultOpen(t *testing.T) { - t.Parallel() - - ctx := context.Background() - swarms := makeSwarms(ctx, t, 2, swarmt.OptDisableReuseport) + swarms := makeSwarms(t, 2, swarmt.OptDisableReuseport) // connect everyone { var wg sync.WaitGroup - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { defer wg.Done() // copy for other peer log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) - if _, err := s.DialPeer(ctx, dst); err != nil { + if _, err := s.DialPeer(context.Background(), dst); err != nil { t.Error("error swarm dialing to peer", err) } } diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index c34497f7..0f4e3c75 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -18,8 +18,6 @@ import ( "github.com/libp2p/go-libp2p-core/transport" logging "github.com/ipfs/go-log" - "github.com/jbenet/goprocess" - goprocessctx "github.com/jbenet/goprocess/context" ma "github.com/multiformats/go-multiaddr" ) @@ -92,9 +90,10 @@ type Swarm struct { limiter *dialLimiter gater connmgr.ConnectionGater - proc goprocess.Process - ctx context.Context - bwc metrics.Reporter + ctx context.Context // is canceled when Close is called + ctxCancel context.CancelFunc + + bwc metrics.Reporter } // NewSwarm constructs a Swarm. @@ -103,11 +102,14 @@ type Swarm struct { // `extra` interface{} parameter facilitates the future migration. Supported // elements are: // - connmgr.ConnectionGater -func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, extra ...interface{}) *Swarm { +func NewSwarm(local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, extra ...interface{}) *Swarm { + ctx, cancel := context.WithCancel(context.Background()) s := &Swarm{ - local: local, - peers: peers, - bwc: bwc, + local: local, + peers: peers, + bwc: bwc, + ctx: ctx, + ctxCancel: cancel, } s.conns.m = make(map[peer.ID][]*Conn) @@ -124,23 +126,11 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc s.dsync = newDialSync(s.dialWorkerLoop) s.limiter = newDialLimiter(s.dialAddr) - s.proc = goprocessctx.WithContext(ctx) - s.ctx = goprocessctx.OnClosingContext(s.proc) s.backf.init(s.ctx) - - // Set teardown after setting the context/process so we don't start the - // teardown process early. - s.proc.SetTeardown(s.teardown) - return s } -func (s *Swarm) teardown() error { - // Wait for the context to be canceled. - // This allows other parts of the swarm to detect that we're shutting - // down. - <-s.ctx.Done() - +func (s *Swarm) Close() error { // Prevents new connections and/or listeners from being added to the swarm. s.listeners.Lock() @@ -201,11 +191,6 @@ func (s *Swarm) teardown() error { return nil } -// Process returns the Process of the swarm -func (s *Swarm) Process() goprocess.Process { - return s.proc -} - func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) { var ( p = tc.RemotePeer() @@ -293,16 +278,6 @@ func (s *Swarm) Peerstore() peerstore.Peerstore { return s.peers } -// Context returns the context of the swarm -func (s *Swarm) Context() context.Context { - return s.ctx -} - -// Close stops the Swarm. -func (s *Swarm) Close() error { - return s.proc.Close() -} - // TODO: We probably don't need the conn handlers. // SetConnHandler assigns the handler for new connections. diff --git a/p2p/net/swarm/swarm_addr_test.go b/p2p/net/swarm/swarm_addr_test.go index baeac462..21ecafa2 100644 --- a/p2p/net/swarm/swarm_addr_test.go +++ b/p2p/net/swarm/swarm_addr_test.go @@ -6,6 +6,7 @@ import ( "github.com/libp2p/go-libp2p-core/peerstore" "github.com/libp2p/go-libp2p-core/test" + "github.com/stretchr/testify/require" ma "github.com/multiformats/go-multiaddr" @@ -13,7 +14,6 @@ import ( ) func TestDialBadAddrs(t *testing.T) { - m := func(s string) ma.Multiaddr { maddr, err := ma.NewMultiaddr(s) if err != nil { @@ -22,13 +22,12 @@ func TestDialBadAddrs(t *testing.T) { return maddr } - ctx := context.Background() - s := makeSwarms(ctx, t, 1)[0] + s := makeSwarms(t, 1)[0] test := func(a ma.Multiaddr) { p := test.RandPeerIDFatal(t) s.Peerstore().AddAddr(p, a, peerstore.PermanentAddrTTL) - if _, err := s.DialPeer(ctx, p); err == nil { + if _, err := s.DialPeer(context.Background(), p); err == nil { t.Errorf("swarm should not dial: %s", p) } } @@ -39,19 +38,13 @@ func TestDialBadAddrs(t *testing.T) { } func TestAddrRace(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := makeSwarms(ctx, t, 1)[0] + s := makeSwarms(t, 1)[0] defer s.Close() a1, err := s.InterfaceListenAddresses() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) a2, err := s.InterfaceListenAddresses() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(a1) > 0 && len(a2) > 0 && &a1[0] == &a2[0] { t.Fatal("got the exact same address set twice; this could lead to data races") @@ -59,15 +52,8 @@ func TestAddrRace(t *testing.T) { } func TestAddressesWithoutListening(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := swarmt.GenSwarm(t, ctx, swarmt.OptDialOnly) - + s := swarmt.GenSwarm(t, swarmt.OptDialOnly) a1, err := s.InterfaceListenAddresses() - if err != nil { - t.Fatal(err) - } - if len(a1) != 0 { - t.Fatalf("expected to be listening on no addresses, was listening on %d", len(a1)) - } + require.NoError(t, err) + require.Empty(t, a1, "expected to be listening on no addresses") } diff --git a/p2p/net/swarm/swarm_listen.go b/p2p/net/swarm/swarm_listen.go index c064ae85..ca54280c 100644 --- a/p2p/net/swarm/swarm_listen.go +++ b/p2p/net/swarm/swarm_listen.go @@ -46,7 +46,7 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { // // Distinguish between these two cases to avoid confusing users. select { - case <-s.proc.Closing(): + case <-s.ctx.Done(): return ErrSwarmClosed default: return ErrNoTransport diff --git a/p2p/net/swarm/swarm_net_test.go b/p2p/net/swarm/swarm_net_test.go index 05984f6b..1f1d0454 100644 --- a/p2p/net/swarm/swarm_net_test.go +++ b/p2p/net/swarm/swarm_net_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/libp2p/go-libp2p-core/network" . "github.com/libp2p/go-libp2p-swarm/testing" @@ -15,19 +17,16 @@ import ( // TestConnectednessCorrect starts a few networks, connects a few // and tests Connectedness value is correct. func TestConnectednessCorrect(t *testing.T) { - - ctx := context.Background() - nets := make([]network.Network, 4) for i := 0; i < 4; i++ { - nets[i] = GenSwarm(t, ctx) + nets[i] = GenSwarm(t) } // connect 0-1, 0-2, 0-3, 1-2, 2-3 dial := func(a, b network.Network) { DivulgeAddresses(b, a) - if _, err := a.DialPeer(ctx, b.LocalPeer()); err != nil { + if _, err := a.DialPeer(context.Background(), b.LocalPeer()); err != nil { t.Fatalf("Failed to dial: %s", err) } } @@ -54,33 +53,17 @@ func TestConnectednessCorrect(t *testing.T) { expectConnectedness(t, nets[0], nets[2], network.NotConnected) expectConnectedness(t, nets[1], nets[3], network.NotConnected) - if len(nets[0].Peers()) != 2 { - t.Fatal("expected net 0 to have two peers") - } - - if len(nets[2].Peers()) != 2 { - t.Fatal("expected net 2 to have two peers") - } - - if len(nets[1].ConnsToPeer(nets[3].LocalPeer())) != 0 { - t.Fatal("net 1 should have no connections to net 3") - } - - if err := nets[2].ClosePeer(nets[1].LocalPeer()); err != nil { - t.Fatal(err) - } + require.Len(t, nets[0].Peers(), 2, "expected net 0 to have two peers") + require.Len(t, nets[2].Peers(), 2, "expected net 2 to have two peers") + require.NotZerof(t, nets[1].ConnsToPeer(nets[3].LocalPeer()), "net 1 should have no connections to net 3") + require.NoError(t, nets[2].ClosePeer(nets[1].LocalPeer())) time.Sleep(time.Millisecond * 50) - expectConnectedness(t, nets[2], nets[1], network.NotConnected) for _, n := range nets { n.Close() } - - for _, n := range nets { - <-n.Process().Closed() - } } func expectConnectedness(t *testing.T, a, b network.Network, expected network.Connectedness) { @@ -113,7 +96,7 @@ func TestNetworkOpenStream(t *testing.T) { nets := make([]network.Network, 4) for i := 0; i < 4; i++ { - nets[i] = GenSwarm(t, ctx) + nets[i] = GenSwarm(t) } dial := func(a, b network.Network) { diff --git a/p2p/net/swarm/swarm_notif_test.go b/p2p/net/swarm/swarm_notif_test.go index 33836172..157405cb 100644 --- a/p2p/net/swarm/swarm_notif_test.go +++ b/p2p/net/swarm/swarm_notif_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -18,8 +20,7 @@ func TestNotifications(t *testing.T) { notifiees := make([]*netNotifiee, swarmSize) - ctx := context.Background() - swarms := makeSwarms(ctx, t, swarmSize) + swarms := makeSwarms(t, swarmSize) defer func() { for i, s := range swarms { select { @@ -27,10 +28,7 @@ func TestNotifications(t *testing.T) { t.Error("should not have been closed") default: } - err := s.Close() - if err != nil { - t.Error(err) - } + require.NoError(t, s.Close()) select { case <-notifiees[i].listenClose: default: @@ -48,7 +46,7 @@ func TestNotifications(t *testing.T) { notifiees[i] = n } - connectSwarms(t, ctx, swarms) + connectSwarms(t, context.Background(), swarms) time.Sleep(50 * time.Millisecond) // should've gotten 5 by now. @@ -96,7 +94,7 @@ func TestNotifications(t *testing.T) { } } - complement := func(c network.Conn) (*Swarm, *netNotifiee, *Conn) { + complement := func(c network.Conn) (network.Network, *netNotifiee, *Conn) { for i, s := range swarms { for _, c2 := range s.Conns() { if c.LocalMultiaddr().Equal(c2.RemoteMultiaddr()) && diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index a94281b1..e6cb0ba3 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -58,29 +58,25 @@ func EchoStreamHandler(stream network.Stream) { }() } -func makeDialOnlySwarm(ctx context.Context, t *testing.T) *Swarm { - swarm := GenSwarm(t, ctx, OptDialOnly) +func makeDialOnlySwarm(t *testing.T) network.Network { + swarm := GenSwarm(t, OptDialOnly) swarm.SetStreamHandler(EchoStreamHandler) - return swarm } -func makeSwarms(ctx context.Context, t *testing.T, num int, opts ...Option) []*Swarm { - swarms := make([]*Swarm, 0, num) - +func makeSwarms(t *testing.T, num int, opts ...Option) []network.Network { + swarms := make([]network.Network, 0, num) for i := 0; i < num; i++ { - swarm := GenSwarm(t, ctx, opts...) + swarm := GenSwarm(t, opts...) swarm.SetStreamHandler(EchoStreamHandler) swarms = append(swarms, swarm) } - return swarms } -func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) { - +func connectSwarms(t *testing.T, ctx context.Context, swarms []network.Network) { var wg sync.WaitGroup - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { // TODO: make a DialAddr func. s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) if _, err := s.DialPeer(ctx, dst); err != nil { @@ -104,13 +100,10 @@ func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) { } func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) { - // t.Skip("skipping for another test") - - ctx := context.Background() - swarms := makeSwarms(ctx, t, SwarmNum, OptDisableReuseport) + swarms := makeSwarms(t, SwarmNum, OptDisableReuseport) // connect everyone - connectSwarms(t, ctx, swarms) + connectSwarms(t, context.Background(), swarms) // ping/pong for _, s1 := range swarms { @@ -118,7 +111,7 @@ func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) { log.Debugf("%s ping pong round", s1.LocalPeer()) log.Debugf("-------------------------------------------------------") - _, cancel := context.WithCancel(ctx) + _, cancel := context.WithCancel(context.Background()) got := map[peer.ID]int{} errChan := make(chan error, MsgNum*len(swarms)) streamChan := make(chan network.Stream, MsgNum) @@ -132,7 +125,7 @@ func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) { defer wg.Done() // first, one stream per peer (nice) - stream, err := s1.NewStream(ctx, p) + stream, err := s1.NewStream(context.Background(), p) if err != nil { errChan <- err return @@ -253,7 +246,7 @@ func TestConnHandler(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 5) + swarms := makeSwarms(t, 5) gotconn := make(chan struct{}, 10) swarms[0].SetConnHandler(func(conn network.Conn) { @@ -387,8 +380,8 @@ func TestConnectionGating(t *testing.T) { p2Gater = tc.p2Gater(p2Gater) } - sw1 := GenSwarm(t, ctx, OptConnGater(p1Gater), optTransport) - sw2 := GenSwarm(t, ctx, OptConnGater(p2Gater), optTransport) + sw1 := GenSwarm(t, OptConnGater(p1Gater), optTransport) + sw2 := GenSwarm(t, OptConnGater(p2Gater), optTransport) p1 := sw1.LocalPeer() p2 := sw2.LocalPeer() @@ -408,10 +401,9 @@ func TestConnectionGating(t *testing.T) { } func TestNoDial(t *testing.T) { - ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) - _, err := swarms[0].NewStream(network.WithNoDial(ctx, "swarm test"), swarms[1].LocalPeer()) + _, err := swarms[0].NewStream(network.WithNoDial(context.Background(), "swarm test"), swarms[1].LocalPeer()) if err != network.ErrNoConn { t.Fatal("should have failed with ErrNoConn") } @@ -419,36 +411,29 @@ func TestNoDial(t *testing.T) { func TestCloseWithOpenStreams(t *testing.T) { ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) connectSwarms(t, ctx, swarms) s, err := swarms[0].NewStream(ctx, swarms[1].LocalPeer()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer s.Close() // close swarm before stream. - err = swarms[0].Close() - if err != nil { - t.Fatal(err) - } + require.NoError(t, swarms[0].Close()) } func TestTypedNilConn(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := GenSwarm(t, ctx) + s := GenSwarm(t) defer s.Close() // We can't dial ourselves. - c, err := s.DialPeer(ctx, s.LocalPeer()) + c, err := s.DialPeer(context.Background(), s.LocalPeer()) require.Error(t, err) // If we fail to dial, the connection should be nil. - require.True(t, c == nil) + require.Nil(t, c) } func TestPreventDialListenAddr(t *testing.T) { - s := GenSwarm(t, context.Background(), OptDialOnly) + s := GenSwarm(t, OptDialOnly) if err := s.Listen(ma.StringCast("/ip4/0.0.0.0/udp/0/quic")); err != nil { t.Fatal(err) } diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index ba517769..d6091354 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -1,7 +1,6 @@ package testing import ( - "context" "testing" csms "github.com/libp2p/go-conn-security-multistream" @@ -22,7 +21,6 @@ import ( msmux "github.com/libp2p/go-stream-muxer-multistream" "github.com/libp2p/go-tcp-transport" - "github.com/jbenet/goprocess" ma "github.com/multiformats/go-multiaddr" ) @@ -73,7 +71,7 @@ func OptPeerPrivateKey(sk crypto.PrivKey) Option { } // GenUpgrader creates a new connection upgrader for use with this swarm. -func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader { +func GenUpgrader(n network.Network) *tptu.Upgrader { id := n.LocalPeer() pk := n.Peerstore().PrivKey(id) secMuxer := new(csms.SSMuxer) @@ -88,8 +86,18 @@ func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader { } } +type mSwarm struct { + *swarm.Swarm + ps peerstore.Peerstore +} + +func (s *mSwarm) Close() error { + s.ps.Close() + return s.Swarm.Close() +} + // GenSwarm generates a new test swarm. -func GenSwarm(t *testing.T, ctx context.Context, opts ...Option) *swarm.Swarm { +func GenSwarm(t *testing.T, opts ...Option) network.Network { var cfg config for _, o := range opts { o(t, &cfg) @@ -113,11 +121,10 @@ func GenSwarm(t *testing.T, ctx context.Context, opts ...Option) *swarm.Swarm { ps := pstoremem.NewPeerstore() ps.AddPubKey(p.ID, p.PubKey) ps.AddPrivKey(p.ID, p.PrivKey) - s := swarm.NewSwarm(ctx, p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater) - - // Call AddChildNoWait because we can't call AddChild after the process - // may have been closed (e.g., if the context was canceled). - s.Process().AddChildNoWait(goprocess.WithTeardown(ps.Close)) + s := &mSwarm{ + Swarm: swarm.NewSwarm(p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater), + ps: ps, + } upgrader := GenUpgrader(s) upgrader.ConnGater = cfg.connectionGater diff --git a/p2p/net/swarm/testing/testing_test.go b/p2p/net/swarm/testing/testing_test.go index a80cca17..60cd2787 100644 --- a/p2p/net/swarm/testing/testing_test.go +++ b/p2p/net/swarm/testing/testing_test.go @@ -1,14 +1,13 @@ package testing import ( - "context" "testing" "github.com/stretchr/testify/require" ) func TestGenSwarm(t *testing.T) { - swarm := GenSwarm(t, context.Background()) + swarm := GenSwarm(t) require.NoError(t, swarm.Close()) GenUpgrader(swarm) } diff --git a/p2p/net/swarm/transport_test.go b/p2p/net/swarm/transport_test.go index 82225840..52726026 100644 --- a/p2p/net/swarm/transport_test.go +++ b/p2p/net/swarm/transport_test.go @@ -7,9 +7,13 @@ import ( swarm "github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/transport" + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" ) type dummyTransport struct { @@ -42,24 +46,23 @@ func (dt *dummyTransport) Close() error { return nil } +type swarmWithTransport interface { + network.Network + AddTransport(transport.Transport) error +} + func TestUselessTransport(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := swarmt.GenSwarm(t, ctx) - err := s.AddTransport(new(dummyTransport)) + s := swarmt.GenSwarm(t) + err := s.(swarmWithTransport).AddTransport(new(dummyTransport)) if err == nil { t.Fatal("adding a transport that supports no protocols should have failed") } } func TestTransportClose(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := swarmt.GenSwarm(t, ctx) + s := swarmt.GenSwarm(t) tpt := &dummyTransport{protocols: []int{1}} - if err := s.AddTransport(tpt); err != nil { - t.Fatal(err) - } + require.NoError(t, s.(swarmWithTransport).AddTransport(tpt)) _ = s.Close() if !tpt.closed { t.Fatal("expected transport to be closed") @@ -68,13 +71,11 @@ func TestTransportClose(t *testing.T) { } func TestTransportAfterClose(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := swarmt.GenSwarm(t, ctx) + s := swarmt.GenSwarm(t) s.Close() tpt := &dummyTransport{protocols: []int{1}} - if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed { + if err := s.(swarmWithTransport).AddTransport(tpt); err != swarm.ErrSwarmClosed { t.Fatal("expected swarm closed error, got: ", err) } } From a872d26b7c566429e744fb3130879454f64ee3a6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 7 Sep 2021 10:42:50 +0100 Subject: [PATCH 2/3] use testing.Cleanup to shut down peerstore and revert most test changes --- p2p/net/swarm/dial_test.go | 23 +++++++++-------------- p2p/net/swarm/peers_test.go | 6 ++++-- p2p/net/swarm/simul_test.go | 4 ++-- p2p/net/swarm/swarm_notif_test.go | 2 +- p2p/net/swarm/swarm_test.go | 10 +++++----- p2p/net/swarm/testing/testing.go | 21 +++++---------------- p2p/net/swarm/transport_test.go | 16 +++------------- 7 files changed, 29 insertions(+), 53 deletions(-) diff --git a/p2p/net/swarm/dial_test.go b/p2p/net/swarm/dial_test.go index de41e7f7..5fe4bdc2 100644 --- a/p2p/net/swarm/dial_test.go +++ b/p2p/net/swarm/dial_test.go @@ -28,12 +28,7 @@ func init() { transport.DialTimeout = time.Second } -type swarmWithBackoff interface { - network.Network - Backoff() *DialBackoff -} - -func closeSwarms(swarms []network.Network) { +func closeSwarms(swarms []*Swarm) { for _, s := range swarms { s.Close() } @@ -101,7 +96,7 @@ func TestSimultDials(t *testing.T) { // connect everyone { var wg sync.WaitGroup - connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { // copy for other peer log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) s.Peerstore().AddAddr(dst, addr, peerstore.TempAddrTTL) @@ -193,7 +188,7 @@ func TestDialWait(t *testing.T) { t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) } - if !s1.(swarmWithBackoff).Backoff().Backoff(s2p, s2addr) { + if !s1.Backoff().Backoff(s2p, s2addr) { t.Error("s2 should now be on backoff") } } @@ -330,10 +325,10 @@ func TestDialBackoff(t *testing.T) { } // check backoff state - if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { t.Error("s2 should not be on backoff") } - if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { + if !s1.Backoff().Backoff(s3p, s3addr) { t.Error("s3 should be on backoff") } @@ -400,10 +395,10 @@ func TestDialBackoff(t *testing.T) { } // check backoff state (the same) - if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { t.Error("s2 should not be on backoff") } - if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { + if !s1.Backoff().Backoff(s3p, s3addr) { t.Error("s3 should be on backoff") } } @@ -445,7 +440,7 @@ func TestDialBackoffClears(t *testing.T) { t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) } - if !s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { + if !s1.Backoff().Backoff(s2.LocalPeer(), s2bad) { t.Error("s2 should now be on backoff") } else { t.Log("correctly added to backoff") @@ -472,7 +467,7 @@ func TestDialBackoffClears(t *testing.T) { t.Log("correctly connected") } - if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { + if s1.Backoff().Backoff(s2.LocalPeer(), s2bad) { t.Error("s2 should no longer be on backoff") } else { t.Log("correctly cleared backoff") diff --git a/p2p/net/swarm/peers_test.go b/p2p/net/swarm/peers_test.go index 908abe91..3145d862 100644 --- a/p2p/net/swarm/peers_test.go +++ b/p2p/net/swarm/peers_test.go @@ -9,6 +9,8 @@ import ( "github.com/libp2p/go-libp2p-core/peerstore" ma "github.com/multiformats/go-multiaddr" + + . "github.com/libp2p/go-libp2p-swarm" ) func TestPeers(t *testing.T) { @@ -17,7 +19,7 @@ func TestPeers(t *testing.T) { s1 := swarms[0] s2 := swarms[1] - connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { // TODO: make a DialAddr func. s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) // t.Logf("connections from %s", s.LocalPeer()) @@ -53,7 +55,7 @@ func TestPeers(t *testing.T) { log.Infof("%s swarm routing table: %s", s.LocalPeer(), s.Peers()) } - test := func(s network.Network) { + test := func(s *Swarm) { expect := 1 actual := len(s.Peers()) if actual != expect { diff --git a/p2p/net/swarm/simul_test.go b/p2p/net/swarm/simul_test.go index 326c4e21..aa4eb590 100644 --- a/p2p/net/swarm/simul_test.go +++ b/p2p/net/swarm/simul_test.go @@ -7,12 +7,12 @@ import ( "testing" "time" - "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" ma "github.com/multiformats/go-multiaddr" + . "github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-libp2p-testing/ci" ) @@ -24,7 +24,7 @@ func TestSimultOpen(t *testing.T) { // connect everyone { var wg sync.WaitGroup - connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { defer wg.Done() // copy for other peer log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) diff --git a/p2p/net/swarm/swarm_notif_test.go b/p2p/net/swarm/swarm_notif_test.go index 157405cb..c0c6f82d 100644 --- a/p2p/net/swarm/swarm_notif_test.go +++ b/p2p/net/swarm/swarm_notif_test.go @@ -94,7 +94,7 @@ func TestNotifications(t *testing.T) { } } - complement := func(c network.Conn) (network.Network, *netNotifiee, *Conn) { + complement := func(c network.Conn) (*Swarm, *netNotifiee, *Conn) { for i, s := range swarms { for _, c2 := range s.Conns() { if c.LocalMultiaddr().Equal(c2.RemoteMultiaddr()) && diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index e6cb0ba3..9f799f66 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -58,14 +58,14 @@ func EchoStreamHandler(stream network.Stream) { }() } -func makeDialOnlySwarm(t *testing.T) network.Network { +func makeDialOnlySwarm(t *testing.T) *Swarm { swarm := GenSwarm(t, OptDialOnly) swarm.SetStreamHandler(EchoStreamHandler) return swarm } -func makeSwarms(t *testing.T, num int, opts ...Option) []network.Network { - swarms := make([]network.Network, 0, num) +func makeSwarms(t *testing.T, num int, opts ...Option) []*Swarm { + swarms := make([]*Swarm, 0, num) for i := 0; i < num; i++ { swarm := GenSwarm(t, opts...) swarm.SetStreamHandler(EchoStreamHandler) @@ -74,9 +74,9 @@ func makeSwarms(t *testing.T, num int, opts ...Option) []network.Network { return swarms } -func connectSwarms(t *testing.T, ctx context.Context, swarms []network.Network) { +func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) { var wg sync.WaitGroup - connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { // TODO: make a DialAddr func. s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) if _, err := s.DialPeer(ctx, dst); err != nil { diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index d6091354..201b4f0f 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -71,7 +71,7 @@ func OptPeerPrivateKey(sk crypto.PrivKey) Option { } // GenUpgrader creates a new connection upgrader for use with this swarm. -func GenUpgrader(n network.Network) *tptu.Upgrader { +func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader { id := n.LocalPeer() pk := n.Peerstore().PrivKey(id) secMuxer := new(csms.SSMuxer) @@ -86,18 +86,8 @@ func GenUpgrader(n network.Network) *tptu.Upgrader { } } -type mSwarm struct { - *swarm.Swarm - ps peerstore.Peerstore -} - -func (s *mSwarm) Close() error { - s.ps.Close() - return s.Swarm.Close() -} - // GenSwarm generates a new test swarm. -func GenSwarm(t *testing.T, opts ...Option) network.Network { +func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm { var cfg config for _, o := range opts { o(t, &cfg) @@ -121,10 +111,9 @@ func GenSwarm(t *testing.T, opts ...Option) network.Network { ps := pstoremem.NewPeerstore() ps.AddPubKey(p.ID, p.PubKey) ps.AddPrivKey(p.ID, p.PrivKey) - s := &mSwarm{ - Swarm: swarm.NewSwarm(p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater), - ps: ps, - } + t.Cleanup(func() { ps.Close() }) + + s := swarm.NewSwarm(p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater) upgrader := GenUpgrader(s) upgrader.ConnGater = cfg.connectionGater diff --git a/p2p/net/swarm/transport_test.go b/p2p/net/swarm/transport_test.go index 52726026..6d5913cf 100644 --- a/p2p/net/swarm/transport_test.go +++ b/p2p/net/swarm/transport_test.go @@ -7,7 +7,6 @@ import ( swarm "github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" - "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/transport" @@ -46,28 +45,19 @@ func (dt *dummyTransport) Close() error { return nil } -type swarmWithTransport interface { - network.Network - AddTransport(transport.Transport) error -} - func TestUselessTransport(t *testing.T) { s := swarmt.GenSwarm(t) - err := s.(swarmWithTransport).AddTransport(new(dummyTransport)) - if err == nil { - t.Fatal("adding a transport that supports no protocols should have failed") - } + require.Error(t, s.AddTransport(new(dummyTransport)), "adding a transport that supports no protocols should have failed") } func TestTransportClose(t *testing.T) { s := swarmt.GenSwarm(t) tpt := &dummyTransport{protocols: []int{1}} - require.NoError(t, s.(swarmWithTransport).AddTransport(tpt)) + require.NoError(t, s.AddTransport(tpt)) _ = s.Close() if !tpt.closed { t.Fatal("expected transport to be closed") } - } func TestTransportAfterClose(t *testing.T) { @@ -75,7 +65,7 @@ func TestTransportAfterClose(t *testing.T) { s.Close() tpt := &dummyTransport{protocols: []int{1}} - if err := s.(swarmWithTransport).AddTransport(tpt); err != swarm.ErrSwarmClosed { + if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed { t.Fatal("expected swarm closed error, got: ", err) } } From 0537306605873e8e18bc658045a6022cc20d5beb Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 7 Sep 2021 10:56:55 +0100 Subject: [PATCH 3/3] cancel the ctx when closing, use a sync.Once to only close once --- p2p/net/swarm/swarm.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 0f4e3c75..43f7a36f 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -90,6 +90,7 @@ type Swarm struct { limiter *dialLimiter gater connmgr.ConnectionGater + closeOnce sync.Once ctx context.Context // is canceled when Close is called ctxCancel context.CancelFunc @@ -131,8 +132,14 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, ex } func (s *Swarm) Close() error { - // Prevents new connections and/or listeners from being added to the swarm. + s.closeOnce.Do(s.close) + return nil +} +func (s *Swarm) close() { + s.ctxCancel() + + // Prevents new connections and/or listeners from being added to the swarm. s.listeners.Lock() listeners := s.listeners.m s.listeners.m = nil @@ -187,8 +194,6 @@ func (s *Swarm) Close() error { } } wg.Wait() - - return nil } func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) {