diff --git a/p2p/protocol/identify/id_glass_test.go b/p2p/protocol/identify/id_glass_test.go index 8d7bf5eb..273391fe 100644 --- a/p2p/protocol/identify/id_glass_test.go +++ b/p2p/protocol/identify/id_glass_test.go @@ -5,12 +5,13 @@ import ( "testing" "time" + blhost "github.com/libp2p/go-libp2p-blankhost" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/require" - - blhost "github.com/libp2p/go-libp2p-blankhost" swarmt "github.com/libp2p/go-libp2p-swarm/testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestFastDisconnect(t *testing.T) { @@ -27,7 +28,7 @@ func TestFastDisconnect(t *testing.T) { sync := make(chan struct{}) target.SetStreamHandler(ID, func(s network.Stream) { - // Wait till the stream is setup on both sides. + // Wait till the stream is set up on both sides. select { case <-sync: case <-ctx.Done(): @@ -35,11 +36,16 @@ func TestFastDisconnect(t *testing.T) { } // Kill the connection, and make sure we're completely disconnected. - s.Conn().Close() - for target.Network().Connectedness(s.Conn().RemotePeer()) == network.Connected { - // let something else run - time.Sleep(time.Millisecond) - } + assert.Eventually(t, + func() bool { + for _, conn := range target.Network().ConnsToPeer(s.Conn().RemotePeer()) { + conn.Close() + } + return target.Network().Connectedness(s.Conn().RemotePeer()) != network.Connected + }, + 2*time.Second, + time.Millisecond, + ) // Now try to handle the response. // This should not block indefinitely, or panic, or anything like that. // @@ -57,14 +63,10 @@ func TestFastDisconnect(t *testing.T) { source := blhost.NewBlankHost(swarmt.GenSwarm(t)) defer source.Close() - err = source.Connect(ctx, peer.AddrInfo{ID: target.ID(), Addrs: target.Addrs()}) - if err != nil { - t.Fatal(err) - } + // only connect to the first address, to make sure we only end up with one connection + require.NoError(t, source.Connect(ctx, peer.AddrInfo{ID: target.ID(), Addrs: target.Addrs()})) s, err := source.NewStream(ctx, target.ID(), ID) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) select { case sync <- struct{}{}: case <-ctx.Done(): @@ -77,7 +79,5 @@ func TestFastDisconnect(t *testing.T) { t.Fatal(ctx.Err()) } // double-check to make sure we didn't actually timeout somewhere. - if ctx.Err() != nil { - t.Fatal(ctx.Err()) - } + require.NoError(t, ctx.Err()) }