Merge pull request #151 from libp2p/feat/validator

Extend validator interface to include message source
This commit is contained in:
vyzo 2019-01-17 18:09:14 +02:00 committed by GitHub
commit bfd65a2f6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 12 deletions

View File

@ -357,7 +357,7 @@ func TestRegisterUnregisterValidator(t *testing.T) {
hosts := getNetHosts(t, ctx, 1) hosts := getNetHosts(t, ctx, 1)
psubs := getPubsubs(ctx, hosts) psubs := getPubsubs(ctx, hosts)
err := psubs[0].RegisterTopicValidator("foo", func(context.Context, *Message) bool { err := psubs[0].RegisterTopicValidator("foo", func(context.Context, peer.ID, *Message) bool {
return true return true
}) })
if err != nil { if err != nil {
@ -385,7 +385,7 @@ func TestValidate(t *testing.T) {
connect(t, hosts[0], hosts[1]) connect(t, hosts[0], hosts[1])
topic := "foobar" topic := "foobar"
err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool { err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, from peer.ID, msg *Message) bool {
return !bytes.Contains(msg.Data, []byte("illegal")) return !bytes.Contains(msg.Data, []byte("illegal"))
}) })
if err != nil { if err != nil {
@ -482,7 +482,7 @@ func TestValidateOverload(t *testing.T) {
block := make(chan struct{}) block := make(chan struct{})
err := psubs[1].RegisterTopicValidator(topic, err := psubs[1].RegisterTopicValidator(topic,
func(ctx context.Context, msg *Message) bool { func(ctx context.Context, from peer.ID, msg *Message) bool {
<-block <-block
return true return true
}, },

View File

@ -661,7 +661,7 @@ func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) {
} }
if len(vals) > 0 { if len(vals) > 0 {
if !p.validateTopic(vals, msg) { if !p.validateTopic(vals, src, msg) {
log.Warningf("message validation failed; dropping message from %s", src) log.Warningf("message validation failed; dropping message from %s", src)
return return
} }
@ -684,9 +684,9 @@ func (p *PubSub) validateSignature(msg *Message) bool {
return true return true
} }
func (p *PubSub) validateTopic(vals []*topicVal, msg *Message) bool { func (p *PubSub) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool {
if len(vals) == 1 { if len(vals) == 1 {
return p.validateSingleTopic(vals[0], msg) return p.validateSingleTopic(vals[0], src, msg)
} }
ctx, cancel := context.WithCancel(p.ctx) ctx, cancel := context.WithCancel(p.ctx)
@ -703,7 +703,7 @@ loop:
select { select {
case val.validateThrottle <- struct{}{}: case val.validateThrottle <- struct{}{}:
go func(val *topicVal) { go func(val *topicVal) {
rch <- val.validateMsg(ctx, msg) rch <- val.validateMsg(ctx, src, msg)
<-val.validateThrottle <-val.validateThrottle
}(val) }(val)
@ -729,13 +729,13 @@ loop:
} }
// fast path for single topic validation that avoids the extra goroutine // fast path for single topic validation that avoids the extra goroutine
func (p *PubSub) validateSingleTopic(val *topicVal, msg *Message) bool { func (p *PubSub) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) bool {
select { select {
case val.validateThrottle <- struct{}{}: case val.validateThrottle <- struct{}{}:
ctx, cancel := context.WithCancel(p.ctx) ctx, cancel := context.WithCancel(p.ctx)
defer cancel() defer cancel()
res := val.validateMsg(ctx, msg) res := val.validateMsg(ctx, src, msg)
<-val.validateThrottle <-val.validateThrottle
return res return res
@ -900,7 +900,7 @@ type topicVal struct {
} }
// Validator is a function that validates a message. // Validator is a function that validates a message.
type Validator func(context.Context, *Message) bool type Validator func(context.Context, peer.ID, *Message) bool
// ValidatorOpt is an option for RegisterTopicValidator. // ValidatorOpt is an option for RegisterTopicValidator.
type ValidatorOpt func(addVal *addValReq) error type ValidatorOpt func(addVal *addValReq) error
@ -992,11 +992,11 @@ func (ps *PubSub) rmValidator(req *rmValReq) {
} }
} }
func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool { func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool {
vctx, cancel := context.WithTimeout(ctx, val.validateTimeout) vctx, cancel := context.WithTimeout(ctx, val.validateTimeout)
defer cancel() defer cancel()
valid := val.validate(vctx, msg) valid := val.validate(vctx, src, msg)
if !valid { if !valid {
log.Debugf("validation failed for topic %s", val.topic) log.Debugf("validation failed for topic %s", val.topic)
} }