fix flaky BasicHost tests

This commit is contained in:
Marten Seemann 2021-09-25 17:25:42 +01:00
parent eba91ac63e
commit bf0203c6d3
1 changed files with 75 additions and 167 deletions

View File

@ -1,7 +1,6 @@
package basichost package basichost
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io" "io"
@ -49,9 +48,7 @@ func TestHostSimple(t *testing.T) {
defer h2.Close() defer h2.Close()
h2pi := h2.Peerstore().PeerInfo(h2.ID()) h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil { require.NoError(t, h1.Connect(ctx, h2pi))
t.Fatal(err)
}
piper, pipew := io.Pipe() piper, pipew := io.Pipe()
h2.SetStreamHandler(protocol.TestingID, func(s network.Stream) { h2.SetStreamHandler(protocol.TestingID, func(s network.Stream) {
@ -61,33 +58,24 @@ func TestHostSimple(t *testing.T) {
}) })
s, err := h1.NewStream(ctx, h2pi.ID, protocol.TestingID) s, err := h1.NewStream(ctx, h2pi.ID, protocol.TestingID)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// write to the stream // write to the stream
buf1 := []byte("abcdefghijkl") buf1 := []byte("abcdefghijkl")
if _, err := s.Write(buf1); err != nil { _, err = s.Write(buf1)
t.Fatal(err) require.NoError(t, err)
}
// get it from the stream (echoed) // get it from the stream (echoed)
buf2 := make([]byte, len(buf1)) buf2 := make([]byte, len(buf1))
if _, err := io.ReadFull(s, buf2); err != nil { _, err = io.ReadFull(s, buf2)
t.Fatal(err) require.NoError(t, err)
} require.Equal(t, buf1, buf2)
if !bytes.Equal(buf1, buf2) {
t.Fatalf("buf1 != buf2 -- %x != %x", buf1, buf2)
}
// get it from the pipe (tee) // get it from the pipe (tee)
buf3 := make([]byte, len(buf1)) buf3 := make([]byte, len(buf1))
if _, err := io.ReadFull(piper, buf3); err != nil { _, err = io.ReadFull(piper, buf3)
t.Fatal(err) require.NoError(t, err)
} require.Equal(t, buf1, buf3)
if !bytes.Equal(buf1, buf3) {
t.Fatalf("buf1 != buf3 -- %x != %x", buf1, buf3)
}
} }
func TestMultipleClose(t *testing.T) { func TestMultipleClose(t *testing.T) {
@ -109,9 +97,7 @@ func TestSignedPeerRecordWithNoListenAddrs(t *testing.T) {
} }
// now add a listen addr // now add a listen addr
if err := h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0")); err != nil { require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0")))
t.Fatal(err)
}
if len(h.Addrs()) < 1 { if len(h.Addrs()) < 1 {
t.Errorf("expected at least 1 listen addr, got %d", len(h.Addrs())) t.Errorf("expected at least 1 listen addr, got %d", len(h.Addrs()))
} }
@ -135,9 +121,7 @@ func TestProtocolHandlerEvents(t *testing.T) {
defer h.Close() defer h.Close()
sub, err := h.EventBus().Subscribe(&event.EvtLocalProtocolsUpdated{}, eventbus.BufSize(16)) sub, err := h.EventBus().Subscribe(&event.EvtLocalProtocolsUpdated{}, eventbus.BufSize(16))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer sub.Close() defer sub.Close()
// the identify service adds new protocol handlers shortly after the host // the identify service adds new protocol handlers shortly after the host
@ -264,6 +248,8 @@ func TestAllAddrs(t *testing.T) {
t.Fatal("expected addrs to contain original addr") t.Fatal("expected addrs to contain original addr")
} }
// getHostPair gets a new pair of hosts.
// The first host initiates the connection to the second host.
func getHostPair(t *testing.T) (host.Host, host.Host) { func getHostPair(t *testing.T) (host.Host, host.Host) {
t.Helper() t.Helper()
@ -275,9 +261,7 @@ func getHostPair(t *testing.T) (host.Host, host.Host) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
h2pi := h2.Peerstore().PeerInfo(h2.ID()) h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil { require.NoError(t, h1.Connect(ctx, h2pi))
t.Fatal(err)
}
return h1, h2 return h1, h2
} }
@ -301,70 +285,55 @@ func TestHostProtoPreference(t *testing.T) {
defer h1.Close() defer h1.Close()
defer h2.Close() defer h2.Close()
protoOld := protocol.ID("/testing") const (
protoNew := protocol.ID("/testing/1.1.0") protoOld = protocol.ID("/testing")
protoMinor := protocol.ID("/testing/1.2.0") protoNew = protocol.ID("/testing/1.1.0")
protoMinor = protocol.ID("/testing/1.2.0")
)
connectedOn := make(chan protocol.ID) connectedOn := make(chan protocol.ID)
handler := func(s network.Stream) { handler := func(s network.Stream) {
connectedOn <- s.Protocol() connectedOn <- s.Protocol()
s.Close() s.Close()
} }
// Prevent pushing identify information so this test works. // Prevent pushing identify information so this test works.
h2.RemoveStreamHandler(identify.IDPush) h1.RemoveStreamHandler(identify.IDPush)
h2.RemoveStreamHandler(identify.IDDelta) h1.RemoveStreamHandler(identify.IDDelta)
h1.SetStreamHandler(protoOld, handler) h2.SetStreamHandler(protoOld, handler)
s, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld) s, err := h1.NewStream(ctx, h2.ID(), protoMinor, protoNew, protoOld)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// force the lazy negotiation to complete // force the lazy negotiation to complete
_, err = s.Write(nil) _, err = s.Write(nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
assertWait(t, connectedOn, protoOld) assertWait(t, connectedOn, protoOld)
s.Close() s.Close()
mfunc, err := helpers.MultistreamSemverMatcher(protoMinor) mfunc, err := helpers.MultistreamSemverMatcher(protoMinor)
if err != nil { require.NoError(t, err)
t.Fatal(err) h2.SetStreamHandlerMatch(protoMinor, mfunc, handler)
}
h1.SetStreamHandlerMatch(protoMinor, mfunc, handler)
// remembered preference will be chosen first, even when the other side newly supports it // remembered preference will be chosen first, even when the other side newly supports it
s2, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld) s2, err := h1.NewStream(ctx, h2.ID(), protoMinor, protoNew, protoOld)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// required to force 'lazy' handshake // required to force 'lazy' handshake
_, err = s2.Write([]byte("hello")) _, err = s2.Write([]byte("hello"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
assertWait(t, connectedOn, protoOld) assertWait(t, connectedOn, protoOld)
s2.Close() s2.Close()
s3, err := h2.NewStream(ctx, h1.ID(), protoMinor) s3, err := h1.NewStream(ctx, h2.ID(), protoMinor)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// Force a lazy handshake as we may have received a protocol update by this point. // Force a lazy handshake as we may have received a protocol update by this point.
_, err = s3.Write([]byte("hello")) _, err = s3.Write([]byte("hello"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
assertWait(t, connectedOn, protoMinor) assertWait(t, connectedOn, protoMinor)
s3.Close() s3.Close()
@ -397,6 +366,8 @@ func TestHostProtoPreknowledge(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
h2, err := NewHost(swarmt.GenSwarm(t), nil) h2, err := NewHost(swarmt.GenSwarm(t), nil)
require.NoError(t, err) require.NoError(t, err)
defer h1.Close()
defer h2.Close()
conn := make(chan protocol.ID) conn := make(chan protocol.ID)
handler := func(s network.Stream) { handler := func(s network.Stream) {
@ -404,18 +375,13 @@ func TestHostProtoPreknowledge(t *testing.T) {
s.Close() s.Close()
} }
h1.SetStreamHandler("/super", handler) h2.SetStreamHandler("/super", handler)
// Prevent pushing identify information so this test actually _uses_ the super protocol. // Prevent pushing identify information so this test actually _uses_ the super protocol.
h2.RemoveStreamHandler(identify.IDPush) h1.RemoveStreamHandler(identify.IDPush)
h2.RemoveStreamHandler(identify.IDDelta) h1.RemoveStreamHandler(identify.IDDelta)
h2pi := h2.Peerstore().PeerInfo(h2.ID()) h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil { require.NoError(t, h1.Connect(ctx, h2pi))
t.Fatal(err)
}
defer h1.Close()
defer h2.Close()
// wait for identify handshake to finish completely // wait for identify handshake to finish completely
select { select {
@ -430,12 +396,10 @@ func TestHostProtoPreknowledge(t *testing.T) {
t.Fatal("timed out waiting for identify") t.Fatal("timed out waiting for identify")
} }
h1.SetStreamHandler("/foo", handler) h2.SetStreamHandler("/foo", handler)
s, err := h2.NewStream(ctx, h1.ID(), "/foo", "/bar", "/super") s, err := h1.NewStream(ctx, h2.ID(), "/foo", "/bar", "/super")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
select { select {
case p := <-conn: case p := <-conn:
@ -444,10 +408,7 @@ func TestHostProtoPreknowledge(t *testing.T) {
} }
_, err = s.Read(nil) _, err = s.Read(nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
assertWait(t, conn, "/super") assertWait(t, conn, "/super")
s.Close() s.Close()
@ -462,29 +423,20 @@ func TestNewDialOld(t *testing.T) {
defer h2.Close() defer h2.Close()
connectedOn := make(chan protocol.ID) connectedOn := make(chan protocol.ID)
h1.SetStreamHandler("/testing", func(s network.Stream) { h2.SetStreamHandler("/testing", func(s network.Stream) {
connectedOn <- s.Protocol() connectedOn <- s.Protocol()
s.Close() s.Close()
}) })
s, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing") s, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// force the lazy negotiation to complete // force the lazy negotiation to complete
_, err = s.Write(nil) _, err = s.Write(nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
assertWait(t, connectedOn, "/testing") assertWait(t, connectedOn, "/testing")
if s.Protocol() != "/testing" { require.Equal(t, s.Protocol(), protocol.ID("/testing"), "should have gotten /testing")
t.Fatal("should have gotten /testing")
}
s.Close()
} }
func TestProtoDowngrade(t *testing.T) { func TestProtoDowngrade(t *testing.T) {
@ -496,51 +448,32 @@ func TestProtoDowngrade(t *testing.T) {
defer h2.Close() defer h2.Close()
connectedOn := make(chan protocol.ID) connectedOn := make(chan protocol.ID)
h1.SetStreamHandler("/testing/1.0.0", func(s network.Stream) { h2.SetStreamHandler("/testing/1.0.0", func(s network.Stream) {
defer s.Close() defer s.Close()
result, err := ioutil.ReadAll(s) result, err := ioutil.ReadAll(s)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.Equal(t, string(result), "bar")
} else if string(result) != "bar" {
t.Error("wrong result")
}
connectedOn <- s.Protocol() connectedOn <- s.Protocol()
}) })
s, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing") s, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, s.Protocol(), protocol.ID("/testing/1.0.0"), "should have gotten /testing/1.0.0, got %s", s.Protocol())
}
if s.Protocol() != "/testing/1.0.0" {
t.Fatalf("should have gotten /testing/1.0.0, got %s", s.Protocol())
}
_, err = s.Write([]byte("bar")) _, err = s.Write([]byte("bar"))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.NoError(t, s.CloseWrite())
}
err = s.CloseWrite()
if err != nil {
t.Fatal(err)
}
assertWait(t, connectedOn, "/testing/1.0.0") assertWait(t, connectedOn, "/testing/1.0.0")
if err := s.Close(); err != nil { require.NoError(t, s.Close())
t.Error(err)
}
h2.Network().ClosePeer(h1.ID()) h1.Network().ClosePeer(h2.ID())
h2.RemoveStreamHandler("/testing/1.0.0")
h1.RemoveStreamHandler("/testing/1.0.0") h2.SetStreamHandler("/testing", func(s network.Stream) {
h1.SetStreamHandler("/testing", func(s network.Stream) {
defer s.Close() defer s.Close()
result, err := ioutil.ReadAll(s) result, err := ioutil.ReadAll(s)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.Equal(t, string(result), "foo")
} else if string(result) != "foo" {
t.Error("wrong result")
}
connectedOn <- s.Protocol() connectedOn <- s.Protocol()
}) })
@ -549,45 +482,24 @@ func TestProtoDowngrade(t *testing.T) {
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
h2pi := h2.Peerstore().PeerInfo(h2.ID()) h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil { require.NoError(t, h1.Connect(ctx, h2pi))
t.Fatal(err)
}
s2, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing") s2, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, s2.Protocol(), protocol.ID("/testing"), "should have gotten /testing, got %s, %s", s.Protocol(), s.Conn())
}
if s2.Protocol() != "/testing" {
t.Errorf("should have gotten /testing, got %s, %s", s.Protocol(), s.Conn())
}
_, err = s2.Write([]byte("foo")) _, err = s2.Write([]byte("foo"))
if err != nil { require.NoError(t, err)
t.Error(err) require.NoError(t, s2.CloseWrite())
}
err = s2.CloseWrite()
if err != nil {
t.Error(err)
}
assertWait(t, connectedOn, "/testing") assertWait(t, connectedOn, "/testing")
if err := s.Close(); err != nil {
t.Error(err)
}
} }
func TestAddrResolution(t *testing.T) { func TestAddrResolution(t *testing.T) {
ctx := context.Background() ctx := context.Background()
p1, err := test.RandPeerID() p1 := test.RandPeerIDFatal(t)
if err != nil { p2 := test.RandPeerIDFatal(t)
t.Error(err)
}
p2, err := test.RandPeerID()
if err != nil {
t.Error(err)
}
addr1 := ma.StringCast("/dnsaddr/example.com") addr1 := ma.StringCast("/dnsaddr/example.com")
addr2 := ma.StringCast("/ip4/192.0.2.1/tcp/123") addr2 := ma.StringCast("/ip4/192.0.2.1/tcp/123")
p2paddr1 := ma.StringCast("/dnsaddr/example.com/p2p/" + p1.Pretty()) p2paddr1 := ma.StringCast("/dnsaddr/example.com/p2p/" + p1.Pretty())
@ -600,18 +512,14 @@ func TestAddrResolution(t *testing.T) {
}}, }},
} }
resolver, err := madns.NewResolver(madns.WithDefaultResolver(backend)) resolver, err := madns.NewResolver(madns.WithDefaultResolver(backend))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{MultiaddrResolver: resolver}) h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{MultiaddrResolver: resolver})
require.NoError(t, err) require.NoError(t, err)
defer h.Close() defer h.Close()
pi, err := peer.AddrInfoFromP2pAddr(p2paddr1) pi, err := peer.AddrInfoFromP2pAddr(p2paddr1)
if err != nil { require.NoError(t, err)
t.Error(err)
}
tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
defer cancel() defer cancel()
@ -855,7 +763,7 @@ func TestNegotiationCancel(t *testing.T) {
defer h2.Close() defer h2.Close()
// pre-negotiation so we can make the negotiation hang. // pre-negotiation so we can make the negotiation hang.
h1.Network().SetStreamHandler(func(s network.Stream) { h2.Network().SetStreamHandler(func(s network.Stream) {
<-ctx.Done() // wait till the test is done. <-ctx.Done() // wait till the test is done.
s.Reset() s.Reset()
}) })
@ -865,7 +773,7 @@ func TestNegotiationCancel(t *testing.T) {
errCh := make(chan error, 1) errCh := make(chan error, 1)
go func() { go func() {
s, err := h2.NewStream(ctx2, h1.ID(), "/testing") s, err := h1.NewStream(ctx2, h2.ID(), "/testing")
if s != nil { if s != nil {
errCh <- fmt.Errorf("expected to fail negotiation") errCh <- fmt.Errorf("expected to fail negotiation")
return return