mirror of
https://github.com/logos-messaging/go-libp2p-pubsub.git
synced 2026-01-02 12:53:09 +00:00
add test for multi-topic validation
This commit is contained in:
parent
957335ba52
commit
741b7e9b41
@ -4,11 +4,18 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pb "github.com/libp2p/go-libp2p-pubsub/pb"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/host"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
|
||||
ggio "github.com/gogo/protobuf/io"
|
||||
)
|
||||
|
||||
func TestRegisterUnregisterValidator(t *testing.T) {
|
||||
@ -264,3 +271,154 @@ func TestValidateAssortedOptions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMultitopic(t *testing.T) {
|
||||
// this test adds coverage for multi-topic validation
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hosts := getNetHosts(t, ctx, 3)
|
||||
psubs := getPubsubs(ctx, hosts[1:], WithMessageSigning(false))
|
||||
for _, ps := range psubs {
|
||||
err := ps.RegisterTopicValidator("test1", func(context.Context, peer.ID, *Message) bool {
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.RegisterTopicValidator("test2", func(context.Context, peer.ID, *Message) bool {
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.RegisterTopicValidator("test3", func(context.Context, peer.ID, *Message) bool {
|
||||
return false
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
publisher := &multiTopicPublisher{ctx: ctx, h: hosts[0]}
|
||||
hosts[0].SetStreamHandler(FloodSubID, publisher.handleStream)
|
||||
|
||||
connectAll(t, hosts)
|
||||
|
||||
var subs1, subs2, subs3 []*Subscription
|
||||
for _, ps := range psubs {
|
||||
sub, err := ps.Subscribe("test1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
subs1 = append(subs1, sub)
|
||||
|
||||
sub, err = ps.Subscribe("test2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
subs2 = append(subs2, sub)
|
||||
|
||||
sub, err = ps.Subscribe("test3")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
subs3 = append(subs2, sub)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
msg1 := "i am a walrus"
|
||||
|
||||
// this goes to test1 and test2, which is accepted and should be delivered
|
||||
publisher.publish(msg1, "test1", "test2")
|
||||
for _, sub := range subs1 {
|
||||
assertReceive(t, sub, []byte(msg1))
|
||||
}
|
||||
for _, sub := range subs2 {
|
||||
assertReceive(t, sub, []byte(msg1))
|
||||
}
|
||||
|
||||
// this goes to test2 and test3, which is rejected by the test3 validator and should not be delivered
|
||||
msg2 := "i am not a walrus"
|
||||
publisher.publish(msg2, "test2", "test3")
|
||||
|
||||
expectNoMessage := func(sub *Subscription) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
m, err := sub.Next(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected no message, but got ", string(m.Data))
|
||||
}
|
||||
}
|
||||
|
||||
for _, sub := range subs2 {
|
||||
expectNoMessage(sub)
|
||||
}
|
||||
|
||||
for _, sub := range subs3 {
|
||||
expectNoMessage(sub)
|
||||
}
|
||||
}
|
||||
|
||||
type multiTopicPublisher struct {
|
||||
ctx context.Context
|
||||
h host.Host
|
||||
|
||||
mx sync.Mutex
|
||||
out []network.Stream
|
||||
mcount int
|
||||
}
|
||||
|
||||
func (p *multiTopicPublisher) handleStream(s network.Stream) {
|
||||
defer s.Close()
|
||||
|
||||
os, err := p.h.NewStream(p.ctx, s.Conn().RemotePeer(), FloodSubID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
p.mx.Lock()
|
||||
p.out = append(p.out, os)
|
||||
p.mx.Unlock()
|
||||
|
||||
r := ggio.NewDelimitedReader(s, 1<<20)
|
||||
var rpc pb.RPC
|
||||
for {
|
||||
rpc.Reset()
|
||||
err = r.ReadMsg(&rpc)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
s.Reset()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *multiTopicPublisher) publish(msg string, topics ...string) {
|
||||
p.mcount++
|
||||
rpc := &pb.RPC{
|
||||
Publish: []*pb.Message{
|
||||
&pb.Message{
|
||||
From: []byte(p.h.ID()),
|
||||
Data: []byte(msg),
|
||||
Seqno: []byte{byte(p.mcount)},
|
||||
TopicIDs: topics,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
p.mx.Lock()
|
||||
defer p.mx.Unlock()
|
||||
for _, os := range p.out {
|
||||
w := ggio.NewDelimitedWriter(os)
|
||||
err := w.WriteMsg(rpc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user