fix flaky BasicHost tests

This commit is contained in:
Marten Seemann 2021-09-25 17:25:42 +01:00
parent eba91ac63e
commit bf0203c6d3
1 changed files with 75 additions and 167 deletions

View File

@ -1,7 +1,6 @@
package basichost
import (
"bytes"
"context"
"fmt"
"io"
@ -49,9 +48,7 @@ func TestHostSimple(t *testing.T) {
defer h2.Close()
h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil {
t.Fatal(err)
}
require.NoError(t, h1.Connect(ctx, h2pi))
piper, pipew := io.Pipe()
h2.SetStreamHandler(protocol.TestingID, func(s network.Stream) {
@ -61,33 +58,24 @@ func TestHostSimple(t *testing.T) {
})
s, err := h1.NewStream(ctx, h2pi.ID, protocol.TestingID)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
// write to the stream
buf1 := []byte("abcdefghijkl")
if _, err := s.Write(buf1); err != nil {
t.Fatal(err)
}
_, err = s.Write(buf1)
require.NoError(t, err)
// get it from the stream (echoed)
buf2 := make([]byte, len(buf1))
if _, err := io.ReadFull(s, buf2); err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf1, buf2) {
t.Fatalf("buf1 != buf2 -- %x != %x", buf1, buf2)
}
_, err = io.ReadFull(s, buf2)
require.NoError(t, err)
require.Equal(t, buf1, buf2)
// get it from the pipe (tee)
buf3 := make([]byte, len(buf1))
if _, err := io.ReadFull(piper, buf3); err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf1, buf3) {
t.Fatalf("buf1 != buf3 -- %x != %x", buf1, buf3)
}
_, err = io.ReadFull(piper, buf3)
require.NoError(t, err)
require.Equal(t, buf1, buf3)
}
func TestMultipleClose(t *testing.T) {
@ -109,9 +97,7 @@ func TestSignedPeerRecordWithNoListenAddrs(t *testing.T) {
}
// now add a listen addr
if err := h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0")); err != nil {
t.Fatal(err)
}
require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0")))
if len(h.Addrs()) < 1 {
t.Errorf("expected at least 1 listen addr, got %d", len(h.Addrs()))
}
@ -135,9 +121,7 @@ func TestProtocolHandlerEvents(t *testing.T) {
defer h.Close()
sub, err := h.EventBus().Subscribe(&event.EvtLocalProtocolsUpdated{}, eventbus.BufSize(16))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer sub.Close()
// the identify service adds new protocol handlers shortly after the host
@ -264,6 +248,8 @@ func TestAllAddrs(t *testing.T) {
t.Fatal("expected addrs to contain original addr")
}
// getHostPair gets a new pair of hosts.
// The first host initiates the connection to the second host.
func getHostPair(t *testing.T) (host.Host, host.Host) {
t.Helper()
@ -275,9 +261,7 @@ func getHostPair(t *testing.T) (host.Host, host.Host) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil {
t.Fatal(err)
}
require.NoError(t, h1.Connect(ctx, h2pi))
return h1, h2
}
@ -301,70 +285,55 @@ func TestHostProtoPreference(t *testing.T) {
defer h1.Close()
defer h2.Close()
protoOld := protocol.ID("/testing")
protoNew := protocol.ID("/testing/1.1.0")
protoMinor := protocol.ID("/testing/1.2.0")
const (
protoOld = protocol.ID("/testing")
protoNew = protocol.ID("/testing/1.1.0")
protoMinor = protocol.ID("/testing/1.2.0")
)
connectedOn := make(chan protocol.ID)
handler := func(s network.Stream) {
connectedOn <- s.Protocol()
s.Close()
}
// Prevent pushing identify information so this test works.
h2.RemoveStreamHandler(identify.IDPush)
h2.RemoveStreamHandler(identify.IDDelta)
h1.RemoveStreamHandler(identify.IDPush)
h1.RemoveStreamHandler(identify.IDDelta)
h1.SetStreamHandler(protoOld, handler)
h2.SetStreamHandler(protoOld, handler)
s, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld)
if err != nil {
t.Fatal(err)
}
s, err := h1.NewStream(ctx, h2.ID(), protoMinor, protoNew, protoOld)
require.NoError(t, err)
// force the lazy negotiation to complete
_, err = s.Write(nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assertWait(t, connectedOn, protoOld)
s.Close()
mfunc, err := helpers.MultistreamSemverMatcher(protoMinor)
if err != nil {
t.Fatal(err)
}
h1.SetStreamHandlerMatch(protoMinor, mfunc, handler)
require.NoError(t, err)
h2.SetStreamHandlerMatch(protoMinor, mfunc, handler)
// remembered preference will be chosen first, even when the other side newly supports it
s2, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld)
if err != nil {
t.Fatal(err)
}
s2, err := h1.NewStream(ctx, h2.ID(), protoMinor, protoNew, protoOld)
require.NoError(t, err)
// required to force 'lazy' handshake
_, err = s2.Write([]byte("hello"))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assertWait(t, connectedOn, protoOld)
s2.Close()
s3, err := h2.NewStream(ctx, h1.ID(), protoMinor)
if err != nil {
t.Fatal(err)
}
s3, err := h1.NewStream(ctx, h2.ID(), protoMinor)
require.NoError(t, err)
// Force a lazy handshake as we may have received a protocol update by this point.
_, err = s3.Write([]byte("hello"))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assertWait(t, connectedOn, protoMinor)
s3.Close()
@ -397,6 +366,8 @@ func TestHostProtoPreknowledge(t *testing.T) {
require.NoError(t, err)
h2, err := NewHost(swarmt.GenSwarm(t), nil)
require.NoError(t, err)
defer h1.Close()
defer h2.Close()
conn := make(chan protocol.ID)
handler := func(s network.Stream) {
@ -404,18 +375,13 @@ func TestHostProtoPreknowledge(t *testing.T) {
s.Close()
}
h1.SetStreamHandler("/super", handler)
h2.SetStreamHandler("/super", handler)
// Prevent pushing identify information so this test actually _uses_ the super protocol.
h2.RemoveStreamHandler(identify.IDPush)
h2.RemoveStreamHandler(identify.IDDelta)
h1.RemoveStreamHandler(identify.IDPush)
h1.RemoveStreamHandler(identify.IDDelta)
h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil {
t.Fatal(err)
}
defer h1.Close()
defer h2.Close()
require.NoError(t, h1.Connect(ctx, h2pi))
// wait for identify handshake to finish completely
select {
@ -430,12 +396,10 @@ func TestHostProtoPreknowledge(t *testing.T) {
t.Fatal("timed out waiting for identify")
}
h1.SetStreamHandler("/foo", handler)
h2.SetStreamHandler("/foo", handler)
s, err := h2.NewStream(ctx, h1.ID(), "/foo", "/bar", "/super")
if err != nil {
t.Fatal(err)
}
s, err := h1.NewStream(ctx, h2.ID(), "/foo", "/bar", "/super")
require.NoError(t, err)
select {
case p := <-conn:
@ -444,10 +408,7 @@ func TestHostProtoPreknowledge(t *testing.T) {
}
_, err = s.Read(nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assertWait(t, conn, "/super")
s.Close()
@ -462,29 +423,20 @@ func TestNewDialOld(t *testing.T) {
defer h2.Close()
connectedOn := make(chan protocol.ID)
h1.SetStreamHandler("/testing", func(s network.Stream) {
h2.SetStreamHandler("/testing", func(s network.Stream) {
connectedOn <- s.Protocol()
s.Close()
})
s, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing")
if err != nil {
t.Fatal(err)
}
s, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
require.NoError(t, err)
// force the lazy negotiation to complete
_, err = s.Write(nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assertWait(t, connectedOn, "/testing")
if s.Protocol() != "/testing" {
t.Fatal("should have gotten /testing")
}
s.Close()
require.Equal(t, s.Protocol(), protocol.ID("/testing"), "should have gotten /testing")
}
func TestProtoDowngrade(t *testing.T) {
@ -496,51 +448,32 @@ func TestProtoDowngrade(t *testing.T) {
defer h2.Close()
connectedOn := make(chan protocol.ID)
h1.SetStreamHandler("/testing/1.0.0", func(s network.Stream) {
h2.SetStreamHandler("/testing/1.0.0", func(s network.Stream) {
defer s.Close()
result, err := ioutil.ReadAll(s)
if err != nil {
t.Error(err)
} else if string(result) != "bar" {
t.Error("wrong result")
}
assert.NoError(t, err)
assert.Equal(t, string(result), "bar")
connectedOn <- s.Protocol()
})
s, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing")
if err != nil {
t.Fatal(err)
}
if s.Protocol() != "/testing/1.0.0" {
t.Fatalf("should have gotten /testing/1.0.0, got %s", s.Protocol())
}
s, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
require.NoError(t, err)
require.Equal(t, s.Protocol(), protocol.ID("/testing/1.0.0"), "should have gotten /testing/1.0.0, got %s", s.Protocol())
_, err = s.Write([]byte("bar"))
if err != nil {
t.Fatal(err)
}
err = s.CloseWrite()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
require.NoError(t, s.CloseWrite())
assertWait(t, connectedOn, "/testing/1.0.0")
if err := s.Close(); err != nil {
t.Error(err)
}
require.NoError(t, s.Close())
h2.Network().ClosePeer(h1.ID())
h1.RemoveStreamHandler("/testing/1.0.0")
h1.SetStreamHandler("/testing", func(s network.Stream) {
h1.Network().ClosePeer(h2.ID())
h2.RemoveStreamHandler("/testing/1.0.0")
h2.SetStreamHandler("/testing", func(s network.Stream) {
defer s.Close()
result, err := ioutil.ReadAll(s)
if err != nil {
t.Error(err)
} else if string(result) != "foo" {
t.Error("wrong result")
}
assert.NoError(t, err)
assert.Equal(t, string(result), "foo")
connectedOn <- s.Protocol()
})
@ -549,45 +482,24 @@ func TestProtoDowngrade(t *testing.T) {
time.Sleep(time.Millisecond)
h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil {
t.Fatal(err)
}
require.NoError(t, h1.Connect(ctx, h2pi))
s2, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing")
if err != nil {
t.Fatal(err)
}
if s2.Protocol() != "/testing" {
t.Errorf("should have gotten /testing, got %s, %s", s.Protocol(), s.Conn())
}
s2, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
require.NoError(t, err)
require.Equal(t, s2.Protocol(), protocol.ID("/testing"), "should have gotten /testing, got %s, %s", s.Protocol(), s.Conn())
_, err = s2.Write([]byte("foo"))
if err != nil {
t.Error(err)
}
err = s2.CloseWrite()
if err != nil {
t.Error(err)
}
require.NoError(t, err)
require.NoError(t, s2.CloseWrite())
assertWait(t, connectedOn, "/testing")
if err := s.Close(); err != nil {
t.Error(err)
}
}
func TestAddrResolution(t *testing.T) {
ctx := context.Background()
p1, err := test.RandPeerID()
if err != nil {
t.Error(err)
}
p2, err := test.RandPeerID()
if err != nil {
t.Error(err)
}
p1 := test.RandPeerIDFatal(t)
p2 := test.RandPeerIDFatal(t)
addr1 := ma.StringCast("/dnsaddr/example.com")
addr2 := ma.StringCast("/ip4/192.0.2.1/tcp/123")
p2paddr1 := ma.StringCast("/dnsaddr/example.com/p2p/" + p1.Pretty())
@ -600,18 +512,14 @@ func TestAddrResolution(t *testing.T) {
}},
}
resolver, err := madns.NewResolver(madns.WithDefaultResolver(backend))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{MultiaddrResolver: resolver})
require.NoError(t, err)
defer h.Close()
pi, err := peer.AddrInfoFromP2pAddr(p2paddr1)
if err != nil {
t.Error(err)
}
require.NoError(t, err)
tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
defer cancel()
@ -855,7 +763,7 @@ func TestNegotiationCancel(t *testing.T) {
defer h2.Close()
// pre-negotiation so we can make the negotiation hang.
h1.Network().SetStreamHandler(func(s network.Stream) {
h2.Network().SetStreamHandler(func(s network.Stream) {
<-ctx.Done() // wait till the test is done.
s.Reset()
})
@ -865,7 +773,7 @@ func TestNegotiationCancel(t *testing.T) {
errCh := make(chan error, 1)
go func() {
s, err := h2.NewStream(ctx2, h1.ID(), "/testing")
s, err := h1.NewStream(ctx2, h2.ID(), "/testing")
if s != nil {
errCh <- fmt.Errorf("expected to fail negotiation")
return