chore: use testify for svc test

Signed-off-by: gfanton <8671905+gfanton@users.noreply.github.com>
This commit is contained in:
gfanton 2022-11-08 19:27:30 +01:00
parent c70235c26b
commit a153da08b3

View File

@ -12,6 +12,7 @@ import (
inet "github.com/libp2p/go-libp2p/core/network" inet "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peer"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
"github.com/stretchr/testify/require"
db "github.com/libp2p/go-libp2p-rendezvous/db/sqlite" db "github.com/libp2p/go-libp2p-rendezvous/db/sqlite"
pb "github.com/libp2p/go-libp2p-rendezvous/pb" pb "github.com/libp2p/go-libp2p-rendezvous/pb"
@ -49,58 +50,36 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) {
hosts := getRendezvousHosts(t, ctx, m, 5) hosts := getRendezvousHosts(t, ctx, m, 5)
svc, err := makeRendezvousService(ctx, hosts[0], ":memory:") svc, err := makeRendezvousService(ctx, hosts[0], ":memory:")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer svc.DB.Close() defer svc.DB.Close()
clients := getRendezvousPointsTest(t, hosts) clients := getRendezvousPointsTest(t, hosts)
const registerTTL = 60 const registerTTL = 60
recordTTL, err := clients[0].Register(ctx, "foo1", registerTTL) recordTTL, err := clients[0].Register(ctx, "foo1", registerTTL)
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equalf(t, registerTTL*time.Second, recordTTL, "expected record TTL to be %d seconds", DefaultTTL)
}
if recordTTL != registerTTL*time.Second {
t.Fatalf("Expected record TTL to be %d seconds", DefaultTTL)
}
rrs, cookie, err := clients[0].Discover(ctx, "foo1", 10, nil) rrs, cookie, err := clients[0].Discover(ctx, "foo1", 10, nil)
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Len(t, rrs, 1)
}
if len(rrs) != 1 {
t.Fatal("Expected 1 registration")
}
checkHostRegistration(t, rrs[0], hosts[1]) checkHostRegistration(t, rrs[0], hosts[1])
for i, client := range clients[1:] { for i, client := range clients[1:] {
recordTTL, err = client.Register(ctx, "foo1", registerTTL) recordTTL, err = client.Register(ctx, "foo1", registerTTL)
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equalf(t, registerTTL*time.Second, recordTTL, "expected record TTL to be %d seconds", DefaultTTL)
}
if recordTTL != registerTTL*time.Second {
t.Fatalf("Expected record TTL to be %d seconds", DefaultTTL)
}
rrs, cookie, err = clients[0].Discover(ctx, "foo1", 10, cookie) rrs, cookie, err = clients[0].Discover(ctx, "foo1", 10, cookie)
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Len(t, rrs, 1)
}
if len(rrs) != 1 {
t.Fatal("Expected 1 registration")
}
checkHostRegistration(t, rrs[0], hosts[2+i]) checkHostRegistration(t, rrs[0], hosts[2+i])
} }
for _, client := range clients[1:] { for _, client := range clients[1:] {
rrs, _, err = client.Discover(ctx, "foo1", 10, nil) rrs, _, err = client.Discover(ctx, "foo1", 10, nil)
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Len(t, rrs, 4)
}
if len(rrs) != 4 {
t.Fatal("Expected 4 registrations")
}
for j, rr := range rrs { for j, rr := range rrs {
checkHostRegistration(t, rr, hosts[1+j]) checkHostRegistration(t, rr, hosts[1+j])
@ -108,18 +87,12 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) {
} }
err = clients[0].Unregister(ctx, "foo1") err = clients[0].Unregister(ctx, "foo1")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
for _, client := range clients[0:] { for _, client := range clients[0:] {
rrs, _, err = client.Discover(ctx, "foo1", 10, nil) rrs, _, err = client.Discover(ctx, "foo1", 10, nil)
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Lenf(t, rrs, 3, "Expected 3 registrations, got %d", len(rrs))
}
if len(rrs) != 3 {
t.Fatalf("Expected 3 registrations, got %d", len(rrs))
}
for j, rr := range rrs { for j, rr := range rrs {
checkHostRegistration(t, rr, hosts[2+j]) checkHostRegistration(t, rr, hosts[2+j])
@ -127,18 +100,12 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) {
} }
err = clients[1].Unregister(ctx, "") err = clients[1].Unregister(ctx, "")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
for _, client := range clients[0:] { for _, client := range clients[0:] {
rrs, _, err = client.Discover(ctx, "foo1", 10, nil) rrs, _, err = client.Discover(ctx, "foo1", 10, nil)
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Len(t, rrs, 2)
}
if len(rrs) != 2 {
t.Fatal("Expected 2 registrations")
}
for j, rr := range rrs { for j, rr := range rrs {
checkHostRegistration(t, rr, hosts[3+j]) checkHostRegistration(t, rr, hosts[3+j])
@ -152,14 +119,11 @@ func checkHostRegistration(t *testing.T, rr Registration, host host.Host) {
} }
addrs := host.Addrs() addrs := host.Addrs()
raddrs := rr.Peer.Addrs raddrs := rr.Peer.Addrs
if len(addrs) != len(raddrs) { require.Equal(t, len(addrs), len(raddrs), "bad registration: peer address length mismatch")
t.Fatal("bad registration: peer address length mismatch")
}
for i, addr := range addrs { for i, addr := range addrs {
raddr := raddrs[i] raddr := raddrs[i]
if !addr.Equal(raddr) { require.True(t, addr.Equal(raddr), "bad registration: peer address mismatch")
t.Fatal("bad registration: peer address mismatch")
}
} }
} }
@ -173,81 +137,49 @@ func TestSVCErrors(t *testing.T) {
hosts := getRendezvousHosts(t, ctx, m, 2) hosts := getRendezvousHosts(t, ctx, m, 2)
svc, err := makeRendezvousService(ctx, hosts[0], ":memory:") svc, err := makeRendezvousService(ctx, hosts[0], ":memory:")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer svc.DB.Close() defer svc.DB.Close()
// testable registration errors // testable registration errors
res, err := doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err := doTestRequest(ctx, hosts[1], hosts[0].ID(),
newRegisterMessage("", peer.AddrInfo{}, 0)) newRegisterMessage("", peer.AddrInfo{}, 0))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_NAMESPACE, res.GetRegisterResponse().GetStatus())
}
if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_NAMESPACE {
t.Fatal("expected E_INVALID_NAMESPACE")
}
badns := make([]byte, 2*MaxNamespaceLength) badns := make([]byte, 2*MaxNamespaceLength)
rand.Read(badns) rand.Read(badns)
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newRegisterMessage(string(badns), peer.AddrInfo{}, 0)) newRegisterMessage(string(badns), peer.AddrInfo{}, 0))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_NAMESPACE, res.GetRegisterResponse().GetStatus())
}
if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_NAMESPACE {
t.Fatal("expected E_INVALID_NAMESPACE")
}
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newRegisterMessage("foo", peer.AddrInfo{}, 0)) newRegisterMessage("foo", peer.AddrInfo{}, 0))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_PEER_INFO, res.GetRegisterResponse().GetStatus())
}
if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_PEER_INFO {
t.Fatal("expected E_INVALID_PEER_INFO")
}
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newRegisterMessage("foo", peer.AddrInfo{ID: peer.ID("blah")}, 0)) newRegisterMessage("foo", peer.AddrInfo{ID: peer.ID("blah")}, 0))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_PEER_INFO, res.GetRegisterResponse().GetStatus())
}
if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_PEER_INFO {
t.Fatal("expected E_INVALID_PEER_INFO")
}
p, err := peer.Decode("QmVr26fY1tKyspEJBniVhqxQeEjhF78XerGiqWAwraVLQH") p, err := peer.Decode("QmVr26fY1tKyspEJBniVhqxQeEjhF78XerGiqWAwraVLQH")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newRegisterMessage("foo", peer.AddrInfo{ID: p}, 0)) newRegisterMessage("foo", peer.AddrInfo{ID: p}, 0))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_PEER_INFO, res.GetRegisterResponse().GetStatus())
}
if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_PEER_INFO {
t.Fatal("expected E_INVALID_PEER_INFO")
}
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newRegisterMessage("foo", peer.AddrInfo{ID: hosts[1].ID()}, 0)) newRegisterMessage("foo", peer.AddrInfo{ID: hosts[1].ID()}, 0))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_PEER_INFO, res.GetRegisterResponse().GetStatus())
}
if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_PEER_INFO {
t.Fatal("expected E_INVALID_PEER_INFO")
}
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newRegisterMessage("foo", peer.AddrInfo{ID: hosts[1].ID(), Addrs: hosts[1].Addrs()}, 2*MaxTTL)) newRegisterMessage("foo", peer.AddrInfo{ID: hosts[1].ID(), Addrs: hosts[1].Addrs()}, 2*MaxTTL))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_TTL, res.GetRegisterResponse().GetStatus())
}
if res.GetRegisterResponse().GetStatus() != pb.Message_E_INVALID_TTL {
t.Fatal("expected E_INVALID_TTL")
}
// do MaxRegistrations // do MaxRegistrations
for i := 0; i < MaxRegistrations+1; i++ { for i := 0; i < MaxRegistrations+1; i++ {
@ -264,45 +196,28 @@ func TestSVCErrors(t *testing.T) {
// and now fail // and now fail
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newRegisterMessage("foo", peer.AddrInfo{ID: hosts[1].ID(), Addrs: hosts[1].Addrs()}, 0)) newRegisterMessage("foo", peer.AddrInfo{ID: hosts[1].ID(), Addrs: hosts[1].Addrs()}, 0))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_NOT_AUTHORIZED, res.GetRegisterResponse().GetStatus())
}
if res.GetRegisterResponse().GetStatus() != pb.Message_E_NOT_AUTHORIZED {
t.Fatal("expected E_NOT_AUTHORIZED")
}
// testable discovery errors // testable discovery errors
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newDiscoverMessage(string(badns), 0, nil)) newDiscoverMessage(string(badns), 0, nil))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_NAMESPACE, res.GetDiscoverResponse().GetStatus())
}
if res.GetDiscoverResponse().GetStatus() != pb.Message_E_INVALID_NAMESPACE {
t.Fatal("expected E_INVALID_NAMESPACE")
}
badcookie := make([]byte, 10) badcookie := make([]byte, 10)
rand.Read(badcookie) rand.Read(badcookie)
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newDiscoverMessage("foo", 0, badcookie)) newDiscoverMessage("foo", 0, badcookie))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_COOKIE, res.GetDiscoverResponse().GetStatus())
}
if res.GetDiscoverResponse().GetStatus() != pb.Message_E_INVALID_COOKIE {
t.Fatal("expected E_INVALID_COOKIE")
}
badcookie = make([]byte, 40) badcookie = make([]byte, 40)
rand.Read(badcookie) rand.Read(badcookie)
res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(), res, err = doTestRequest(ctx, hosts[1], hosts[0].ID(),
newDiscoverMessage("foo", 0, badcookie)) newDiscoverMessage("foo", 0, badcookie))
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Equal(t, pb.Message_E_INVALID_COOKIE, res.GetDiscoverResponse().GetStatus())
}
if res.GetDiscoverResponse().GetStatus() != pb.Message_E_INVALID_COOKIE {
t.Fatal("expected E_INVALID_COOKIE")
}
} }
func doTestRequest(ctx context.Context, host host.Host, rp peer.ID, m *pb.Message) (*pb.Message, error) { func doTestRequest(ctx context.Context, host host.Host, rp peer.ID, m *pb.Message) (*pb.Message, error) {