diff --git a/svc_test.go b/svc_test.go index 5b2c6b4..a722e7f 100644 --- a/svc_test.go +++ b/svc_test.go @@ -12,6 +12,7 @@ import ( inet "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" db "github.com/libp2p/go-libp2p-rendezvous/db/sqlite" pb "github.com/libp2p/go-libp2p-rendezvous/pb" @@ -49,58 +50,36 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) { hosts := getRendezvousHosts(t, ctx, m, 5) svc, err := makeRendezvousService(ctx, hosts[0], ":memory:") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer svc.DB.Close() clients := getRendezvousPointsTest(t, hosts) const registerTTL = 60 recordTTL, err := clients[0].Register(ctx, "foo1", registerTTL) - if err != nil { - t.Fatal(err) - } - if recordTTL != registerTTL*time.Second { - t.Fatalf("Expected record TTL to be %d seconds", DefaultTTL) - } + require.NoError(t, err) + require.Equalf(t, registerTTL*time.Second, recordTTL, "expected record TTL to be %d seconds", DefaultTTL) rrs, cookie, err := clients[0].Discover(ctx, "foo1", 10, nil) - if err != nil { - t.Fatal(err) - } - if len(rrs) != 1 { - t.Fatal("Expected 1 registration") - } + require.NoError(t, err) + require.Len(t, rrs, 1) checkHostRegistration(t, rrs[0], hosts[1]) for i, client := range clients[1:] { recordTTL, err = client.Register(ctx, "foo1", registerTTL) - if err != nil { - t.Fatal(err) - } - if recordTTL != registerTTL*time.Second { - t.Fatalf("Expected record TTL to be %d seconds", DefaultTTL) - } + require.NoError(t, err) + require.Equalf(t, registerTTL*time.Second, recordTTL, "expected record TTL to be %d seconds", DefaultTTL) rrs, cookie, err = clients[0].Discover(ctx, "foo1", 10, cookie) - if err != nil { - t.Fatal(err) - } - if len(rrs) != 1 { - t.Fatal("Expected 1 registration") - } + require.NoError(t, err) + require.Len(t, rrs, 1) checkHostRegistration(t, rrs[0], hosts[2+i]) } for _, client := range clients[1:] { rrs, _, err = client.Discover(ctx, "foo1", 10, nil) - if err != nil { - t.Fatal(err) - } - if len(rrs) != 4 { - t.Fatal("Expected 4 registrations") - } + require.NoError(t, err) + require.Len(t, rrs, 4) for j, rr := range rrs { checkHostRegistration(t, rr, hosts[1+j]) @@ -108,18 +87,12 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) { } err = clients[0].Unregister(ctx, "foo1") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) for _, client := range clients[0:] { rrs, _, err = client.Discover(ctx, "foo1", 10, nil) - if err != nil { - t.Fatal(err) - } - if len(rrs) != 3 { - t.Fatalf("Expected 3 registrations, got %d", len(rrs)) - } + require.NoError(t, err) + require.Lenf(t, rrs, 3, "Expected 3 registrations, got %d", len(rrs)) for j, rr := range rrs { checkHostRegistration(t, rr, hosts[2+j]) @@ -127,18 +100,12 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) { } err = clients[1].Unregister(ctx, "") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) for _, client := range clients[0:] { rrs, _, err = client.Discover(ctx, "foo1", 10, nil) - if err != nil { - t.Fatal(err) - } - if len(rrs) != 2 { - t.Fatal("Expected 2 registrations") - } + require.NoError(t, err) + require.Len(t, rrs, 2) for j, rr := range rrs { checkHostRegistration(t, rr, hosts[3+j]) @@ -152,14 +119,11 @@ func checkHostRegistration(t *testing.T, rr Registration, host host.Host) { } addrs := host.Addrs() raddrs := rr.Peer.Addrs - if len(addrs) != len(raddrs) { - t.Fatal("bad registration: peer address length mismatch") - } + require.Equal(t, len(addrs), len(raddrs), "bad registration: peer address length mismatch") + for i, addr := range addrs { raddr := raddrs[i] - if !addr.Equal(raddr) { - t.Fatal("bad registration: peer address mismatch") - } + require.True(t, addr.Equal(raddr), "bad registration: peer address mismatch") } } @@ -173,81 +137,49 @@ func TestSVCErrors(t *testing.T) { hosts := getRendezvousHosts(t, ctx, m, 2) svc, err := makeRendezvousService(ctx, hosts[0], ":memory:") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer svc.DB.Close() // testable registration errors res, err := doTestRequest(ctx, hosts[1], hosts[0].ID(), newRegisterMessage("", peer.AddrInfo{}, 0)) - if err != nil { - t.Fatal(err) - } - if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_NAMESPACE { - t.Fatal("expected E_INVALID_NAMESPACE") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_NAMESPACE, res.GetRegisterResponse().GetStatus()) badns := make([]byte, 2*MaxNamespaceLength) rand.Read(badns) res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newRegisterMessage(string(badns), peer.AddrInfo{}, 0)) - if err != nil { - t.Fatal(err) - } - if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_NAMESPACE { - t.Fatal("expected E_INVALID_NAMESPACE") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_NAMESPACE, res.GetRegisterResponse().GetStatus()) res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newRegisterMessage("foo", peer.AddrInfo{}, 0)) - if err != nil { - t.Fatal(err) - } - if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_PEER_INFO { - t.Fatal("expected E_INVALID_PEER_INFO") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_PEER_INFO, res.GetRegisterResponse().GetStatus()) res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newRegisterMessage("foo", peer.AddrInfo{ID: peer.ID("blah")}, 0)) - if err != nil { - t.Fatal(err) - } - if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_PEER_INFO { - t.Fatal("expected E_INVALID_PEER_INFO") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_PEER_INFO, res.GetRegisterResponse().GetStatus()) p, err := peer.Decode("QmVr26fY1tKyspEJBniVhqxQeEjhF78XerGiqWAwraVLQH") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newRegisterMessage("foo", peer.AddrInfo{ID: p}, 0)) - if err != nil { - t.Fatal(err) - } - if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_PEER_INFO { - t.Fatal("expected E_INVALID_PEER_INFO") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_PEER_INFO, res.GetRegisterResponse().GetStatus()) res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newRegisterMessage("foo", peer.AddrInfo{ID: hosts[1].ID()}, 0)) - if err != nil { - t.Fatal(err) - } - if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_PEER_INFO { - t.Fatal("expected E_INVALID_PEER_INFO") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_PEER_INFO, res.GetRegisterResponse().GetStatus()) res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newRegisterMessage("foo", peer.AddrInfo{ID: hosts[1].ID(), Addrs: hosts[1].Addrs()}, 2*MaxTTL)) - if err != nil { - t.Fatal(err) - } - if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_TTL { - t.Fatal("expected E_INVALID_TTL") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_TTL, res.GetRegisterResponse().GetStatus()) // do MaxRegistrations for i := 0; i < MaxRegistrations+1; i++ { @@ -264,45 +196,28 @@ func TestSVCErrors(t *testing.T) { // and now fail res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newRegisterMessage("foo", peer.AddrInfo{ID: hosts[1].ID(), Addrs: hosts[1].Addrs()}, 0)) - if err != nil { - t.Fatal(err) - } - if res.GetRegisterResponse().GetStatus() != pb.Message_E_NOT_AUTHORIZED { - t.Fatal("expected E_NOT_AUTHORIZED") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_NOT_AUTHORIZED, res.GetRegisterResponse().GetStatus()) // testable discovery errors res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newDiscoverMessage(string(badns), 0, nil)) - if err != nil { - t.Fatal(err) - } - if res.GetDiscoverResponse().GetStatus() != pb.Message_E_INVALID_NAMESPACE { - t.Fatal("expected E_INVALID_NAMESPACE") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_NAMESPACE, res.GetDiscoverResponse().GetStatus()) badcookie := make([]byte, 10) rand.Read(badcookie) res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newDiscoverMessage("foo", 0, badcookie)) - if err != nil { - t.Fatal(err) - } - if res.GetDiscoverResponse().GetStatus() != pb.Message_E_INVALID_COOKIE { - t.Fatal("expected E_INVALID_COOKIE") - } + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_COOKIE, res.GetDiscoverResponse().GetStatus()) badcookie = make([]byte, 40) rand.Read(badcookie) res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), newDiscoverMessage("foo", 0, badcookie)) - if err != nil { - t.Fatal(err) - } - if res.GetDiscoverResponse().GetStatus() != pb.Message_E_INVALID_COOKIE { - t.Fatal("expected E_INVALID_COOKIE") - } - + require.NoError(t, err) + require.Equal(t, pb.Message_E_INVALID_COOKIE, res.GetDiscoverResponse().GetStatus()) } func doTestRequest(ctx context.Context, host host.Host, rp peer.ID, m *pb.Message) (*pb.Message, error) {