add test for multi-topic validation

This commit is contained in:
vyzo 2020-04-23 16:41:14 +03:00
parent 957335ba52
commit 741b7e9b41

View File

@ -4,11 +4,18 @@ import (
"bytes"
"context"
"fmt"
"io"
"sync"
"testing"
"time"
pb "github.com/libp2p/go-libp2p-pubsub/pb"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
ggio "github.com/gogo/protobuf/io"
)
func TestRegisterUnregisterValidator(t *testing.T) {
@ -264,3 +271,154 @@ func TestValidateAssortedOptions(t *testing.T) {
}
}
}
func TestValidateMultitopic(t *testing.T) {
// this test adds coverage for multi-topic validation
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hosts := getNetHosts(t, ctx, 3)
psubs := getPubsubs(ctx, hosts[1:], WithMessageSigning(false))
for _, ps := range psubs {
err := ps.RegisterTopicValidator("test1", func(context.Context, peer.ID, *Message) bool {
return true
})
if err != nil {
t.Fatal(err)
}
err = ps.RegisterTopicValidator("test2", func(context.Context, peer.ID, *Message) bool {
return true
})
if err != nil {
t.Fatal(err)
}
err = ps.RegisterTopicValidator("test3", func(context.Context, peer.ID, *Message) bool {
return false
})
if err != nil {
t.Fatal(err)
}
}
publisher := &multiTopicPublisher{ctx: ctx, h: hosts[0]}
hosts[0].SetStreamHandler(FloodSubID, publisher.handleStream)
connectAll(t, hosts)
var subs1, subs2, subs3 []*Subscription
for _, ps := range psubs {
sub, err := ps.Subscribe("test1")
if err != nil {
t.Fatal(err)
}
subs1 = append(subs1, sub)
sub, err = ps.Subscribe("test2")
if err != nil {
t.Fatal(err)
}
subs2 = append(subs2, sub)
sub, err = ps.Subscribe("test3")
if err != nil {
t.Fatal(err)
}
subs3 = append(subs2, sub)
}
time.Sleep(100 * time.Millisecond)
msg1 := "i am a walrus"
// this goes to test1 and test2, which is accepted and should be delivered
publisher.publish(msg1, "test1", "test2")
for _, sub := range subs1 {
assertReceive(t, sub, []byte(msg1))
}
for _, sub := range subs2 {
assertReceive(t, sub, []byte(msg1))
}
// this goes to test2 and test3, which is rejected by the test3 validator and should not be delivered
msg2 := "i am not a walrus"
publisher.publish(msg2, "test2", "test3")
expectNoMessage := func(sub *Subscription) {
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
m, err := sub.Next(ctx)
if err == nil {
t.Fatal("expected no message, but got ", string(m.Data))
}
}
for _, sub := range subs2 {
expectNoMessage(sub)
}
for _, sub := range subs3 {
expectNoMessage(sub)
}
}
type multiTopicPublisher struct {
ctx context.Context
h host.Host
mx sync.Mutex
out []network.Stream
mcount int
}
func (p *multiTopicPublisher) handleStream(s network.Stream) {
defer s.Close()
os, err := p.h.NewStream(p.ctx, s.Conn().RemotePeer(), FloodSubID)
if err != nil {
panic(err)
}
p.mx.Lock()
p.out = append(p.out, os)
p.mx.Unlock()
r := ggio.NewDelimitedReader(s, 1<<20)
var rpc pb.RPC
for {
rpc.Reset()
err = r.ReadMsg(&rpc)
if err != nil {
if err != io.EOF {
s.Reset()
}
return
}
}
}
func (p *multiTopicPublisher) publish(msg string, topics ...string) {
p.mcount++
rpc := &pb.RPC{
Publish: []*pb.Message{
&pb.Message{
From: []byte(p.h.ID()),
Data: []byte(msg),
Seqno: []byte{byte(p.mcount)},
TopicIDs: topics,
},
},
}
p.mx.Lock()
defer p.mx.Unlock()
for _, os := range p.out {
w := ggio.NewDelimitedWriter(os)
err := w.WriteMsg(rpc)
if err != nil {
panic(err)
}
}
}