mirror of
https://github.com/logos-messaging/go-libp2p-pubsub.git
synced 2026-05-23 17:09:31 +00:00
Merge pull request #151 from libp2p/feat/validator
Extend validator interface to include message source
This commit is contained in:
commit
bfd65a2f6b
@ -357,7 +357,7 @@ func TestRegisterUnregisterValidator(t *testing.T) {
|
|||||||
hosts := getNetHosts(t, ctx, 1)
|
hosts := getNetHosts(t, ctx, 1)
|
||||||
psubs := getPubsubs(ctx, hosts)
|
psubs := getPubsubs(ctx, hosts)
|
||||||
|
|
||||||
err := psubs[0].RegisterTopicValidator("foo", func(context.Context, *Message) bool {
|
err := psubs[0].RegisterTopicValidator("foo", func(context.Context, peer.ID, *Message) bool {
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -385,7 +385,7 @@ func TestValidate(t *testing.T) {
|
|||||||
connect(t, hosts[0], hosts[1])
|
connect(t, hosts[0], hosts[1])
|
||||||
topic := "foobar"
|
topic := "foobar"
|
||||||
|
|
||||||
err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool {
|
err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, from peer.ID, msg *Message) bool {
|
||||||
return !bytes.Contains(msg.Data, []byte("illegal"))
|
return !bytes.Contains(msg.Data, []byte("illegal"))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -482,7 +482,7 @@ func TestValidateOverload(t *testing.T) {
|
|||||||
block := make(chan struct{})
|
block := make(chan struct{})
|
||||||
|
|
||||||
err := psubs[1].RegisterTopicValidator(topic,
|
err := psubs[1].RegisterTopicValidator(topic,
|
||||||
func(ctx context.Context, msg *Message) bool {
|
func(ctx context.Context, from peer.ID, msg *Message) bool {
|
||||||
<-block
|
<-block
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
|
|||||||
18
pubsub.go
18
pubsub.go
@ -661,7 +661,7 @@ func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(vals) > 0 {
|
if len(vals) > 0 {
|
||||||
if !p.validateTopic(vals, msg) {
|
if !p.validateTopic(vals, src, msg) {
|
||||||
log.Warningf("message validation failed; dropping message from %s", src)
|
log.Warningf("message validation failed; dropping message from %s", src)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -684,9 +684,9 @@ func (p *PubSub) validateSignature(msg *Message) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PubSub) validateTopic(vals []*topicVal, msg *Message) bool {
|
func (p *PubSub) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool {
|
||||||
if len(vals) == 1 {
|
if len(vals) == 1 {
|
||||||
return p.validateSingleTopic(vals[0], msg)
|
return p.validateSingleTopic(vals[0], src, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(p.ctx)
|
ctx, cancel := context.WithCancel(p.ctx)
|
||||||
@ -703,7 +703,7 @@ loop:
|
|||||||
select {
|
select {
|
||||||
case val.validateThrottle <- struct{}{}:
|
case val.validateThrottle <- struct{}{}:
|
||||||
go func(val *topicVal) {
|
go func(val *topicVal) {
|
||||||
rch <- val.validateMsg(ctx, msg)
|
rch <- val.validateMsg(ctx, src, msg)
|
||||||
<-val.validateThrottle
|
<-val.validateThrottle
|
||||||
}(val)
|
}(val)
|
||||||
|
|
||||||
@ -729,13 +729,13 @@ loop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// fast path for single topic validation that avoids the extra goroutine
|
// fast path for single topic validation that avoids the extra goroutine
|
||||||
func (p *PubSub) validateSingleTopic(val *topicVal, msg *Message) bool {
|
func (p *PubSub) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) bool {
|
||||||
select {
|
select {
|
||||||
case val.validateThrottle <- struct{}{}:
|
case val.validateThrottle <- struct{}{}:
|
||||||
ctx, cancel := context.WithCancel(p.ctx)
|
ctx, cancel := context.WithCancel(p.ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
res := val.validateMsg(ctx, msg)
|
res := val.validateMsg(ctx, src, msg)
|
||||||
<-val.validateThrottle
|
<-val.validateThrottle
|
||||||
|
|
||||||
return res
|
return res
|
||||||
@ -900,7 +900,7 @@ type topicVal struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validator is a function that validates a message.
|
// Validator is a function that validates a message.
|
||||||
type Validator func(context.Context, *Message) bool
|
type Validator func(context.Context, peer.ID, *Message) bool
|
||||||
|
|
||||||
// ValidatorOpt is an option for RegisterTopicValidator.
|
// ValidatorOpt is an option for RegisterTopicValidator.
|
||||||
type ValidatorOpt func(addVal *addValReq) error
|
type ValidatorOpt func(addVal *addValReq) error
|
||||||
@ -992,11 +992,11 @@ func (ps *PubSub) rmValidator(req *rmValReq) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool {
|
func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool {
|
||||||
vctx, cancel := context.WithTimeout(ctx, val.validateTimeout)
|
vctx, cancel := context.WithTimeout(ctx, val.validateTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
valid := val.validate(vctx, msg)
|
valid := val.validate(vctx, src, msg)
|
||||||
if !valid {
|
if !valid {
|
||||||
log.Debugf("validation failed for topic %s", val.topic)
|
log.Debugf("validation failed for topic %s", val.topic)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user