fix_: a message might match more than one filter
This commit is contained in:
parent
e07182b3f3
commit
cadc3998b8
|
@ -235,23 +235,24 @@ func (fs *Filters) NotifyWatchers(recvMessage *ReceivedMessage) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, watcher := range candidates {
|
for _, watcher := range candidates {
|
||||||
matched = true
|
// Messages are decrypted successfully only once
|
||||||
if decodedMsg == nil {
|
if decodedMsg == nil {
|
||||||
decodedMsg = recvMessage.Open(watcher)
|
decodedMsg = recvMessage.Open(watcher)
|
||||||
if decodedMsg == nil {
|
if decodedMsg == nil {
|
||||||
log.Debug("processing message: failed to open", "message", recvMessage.Hash().Hex(), "filter", watcher.id)
|
log.Debug("processing message: failed to open", "message", recvMessage.Hash().Hex(), "filter", watcher.id)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
matched = watcher.MatchMessage(decodedMsg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if matched && decodedMsg != nil {
|
if watcher.MatchMessage(decodedMsg) {
|
||||||
|
matched = true
|
||||||
log.Debug("processing message: decrypted", "envelopeHash", recvMessage.Hash().Hex())
|
log.Debug("processing message: decrypted", "envelopeHash", recvMessage.Hash().Hex())
|
||||||
if watcher.Src == nil || IsPubKeyEqual(decodedMsg.Src, watcher.Src) {
|
if watcher.Src == nil || IsPubKeyEqual(decodedMsg.Src, watcher.Src) {
|
||||||
watcher.Trigger(decodedMsg)
|
watcher.Trigger(decodedMsg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return matched
|
return matched
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -292,6 +293,8 @@ func (f *Filter) MatchMessage(msg *ReceivedMessage) bool {
|
||||||
return IsPubKeyEqual(&f.KeyAsym.PublicKey, msg.Dst)
|
return IsPubKeyEqual(&f.KeyAsym.PublicKey, msg.Dst)
|
||||||
} else if f.expectsSymmetricEncryption() && msg.isSymmetricEncryption() {
|
} else if f.expectsSymmetricEncryption() && msg.isSymmetricEncryption() {
|
||||||
return f.SymKeyHash == msg.SymKeyHash
|
return f.SymKeyHash == msg.SymKeyHash
|
||||||
|
} else if !f.expectsAsymmetricEncryption() && !f.expectsSymmetricEncryption() && !msg.isAsymmetricEncryption() && !msg.isSymmetricEncryption() {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,315 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
crand "crypto/rand"
|
||||||
|
mrand "math/rand"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/common"
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/waku-org/go-waku/waku/v2/payload"
|
||||||
|
"github.com/waku-org/go-waku/waku/v2/protocol"
|
||||||
|
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testShard = "/waku/2/rs/16/32"
|
||||||
|
|
||||||
|
type FilterTestCase struct {
|
||||||
|
f *Filter
|
||||||
|
id string
|
||||||
|
alive bool
|
||||||
|
msgCnt int
|
||||||
|
}
|
||||||
|
|
||||||
|
func createLogger(t *testing.T) *zap.Logger {
|
||||||
|
config := zap.NewDevelopmentConfig()
|
||||||
|
config.Level = zap.NewAtomicLevelAt(zap.DebugLevel)
|
||||||
|
logger, err := config.Build()
|
||||||
|
require.NoError(t, err)
|
||||||
|
return logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateFilter(t *testing.T, symmetric bool) (*Filter, error) {
|
||||||
|
var f Filter
|
||||||
|
f.Messages = NewMemoryMessageStore()
|
||||||
|
|
||||||
|
f.PubsubTopic = "test"
|
||||||
|
|
||||||
|
const topicNum = 8
|
||||||
|
f.ContentTopics = make(TopicSet, topicNum)
|
||||||
|
for i := 0; i < topicNum; i++ {
|
||||||
|
topic := make([]byte, 4)
|
||||||
|
_, err := crand.Read(topic) // nolint: gosec
|
||||||
|
require.NoError(t, err)
|
||||||
|
topic[0] = 0x01
|
||||||
|
|
||||||
|
f.ContentTopics[BytesToTopic(topic)] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := crypto.GenerateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
f.Src = &key.PublicKey
|
||||||
|
|
||||||
|
if symmetric {
|
||||||
|
f.KeySym = make([]byte, AESKeyLength)
|
||||||
|
_, err := crand.Read(f.KeySym) // nolint: gosec
|
||||||
|
require.NoError(t, err)
|
||||||
|
f.SymKeyHash = crypto.Keccak256Hash(f.KeySym)
|
||||||
|
} else {
|
||||||
|
f.KeyAsym, err = crypto.GenerateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestCases(t *testing.T, SizeTestFilters int) []FilterTestCase {
|
||||||
|
cases := make([]FilterTestCase, SizeTestFilters)
|
||||||
|
for i := 0; i < SizeTestFilters; i++ {
|
||||||
|
f, _ := generateFilter(t, true)
|
||||||
|
cases[i].f = f
|
||||||
|
cases[i].alive = mrand.Int()&1 == 0 // nolint: gosec
|
||||||
|
}
|
||||||
|
return cases
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstallFilters(t *testing.T) {
|
||||||
|
const SizeTestFilters = 256
|
||||||
|
filters := NewFilters(testShard, createLogger(t))
|
||||||
|
tst := generateTestCases(t, SizeTestFilters)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var j string
|
||||||
|
for i := 0; i < SizeTestFilters; i++ {
|
||||||
|
j, err = filters.Install(tst[i].f)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tst[i].id = j
|
||||||
|
require.Len(t, j, KeyIDSize*2)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range tst {
|
||||||
|
if !testCase.alive {
|
||||||
|
filters.Uninstall(testCase.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range tst {
|
||||||
|
fil := filters.Get(testCase.id)
|
||||||
|
exist := fil != nil
|
||||||
|
require.Equal(t, exist, testCase.alive)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstallSymKeyGeneratesHash(t *testing.T) {
|
||||||
|
filters := NewFilters(testShard, createLogger(t))
|
||||||
|
filter, _ := generateFilter(t, true)
|
||||||
|
|
||||||
|
// save the current SymKeyHash for comparison
|
||||||
|
initialSymKeyHash := filter.SymKeyHash
|
||||||
|
|
||||||
|
// ensure the SymKeyHash is invalid, for Install to recreate it
|
||||||
|
var invalid common.Hash
|
||||||
|
filter.SymKeyHash = invalid
|
||||||
|
|
||||||
|
_, err := filters.Install(filter)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for i, b := range filter.SymKeyHash {
|
||||||
|
require.Equal(t, b, initialSymKeyHash[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstallIdenticalFilters(t *testing.T) {
|
||||||
|
filters := NewFilters(testShard, createLogger(t))
|
||||||
|
filter1, _ := generateFilter(t, true)
|
||||||
|
|
||||||
|
// Copy the first filter since some of its fields
|
||||||
|
// are randomly gnerated.
|
||||||
|
filter2 := &Filter{
|
||||||
|
KeySym: filter1.KeySym,
|
||||||
|
PubsubTopic: filter1.PubsubTopic,
|
||||||
|
ContentTopics: filter1.ContentTopics,
|
||||||
|
Messages: NewMemoryMessageStore(),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := filters.Install(filter1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = filters.Install(filter2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recvMessage := generateCompatibleReceivedMessage(t, filter1)
|
||||||
|
msg := recvMessage.Open(filter1)
|
||||||
|
require.NotNil(t, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstallFilterWithSymAndAsymKeys(t *testing.T) {
|
||||||
|
filters := NewFilters(testShard, createLogger(t))
|
||||||
|
filter1, _ := generateFilter(t, true)
|
||||||
|
|
||||||
|
asymKey, err := crypto.GenerateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Copy the first filter since some of its fields
|
||||||
|
// are randomly gnerated.
|
||||||
|
filter := &Filter{
|
||||||
|
KeySym: filter1.KeySym,
|
||||||
|
KeyAsym: asymKey,
|
||||||
|
PubsubTopic: filter1.PubsubTopic,
|
||||||
|
ContentTopics: filter1.ContentTopics,
|
||||||
|
Messages: NewMemoryMessageStore(),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = filters.Install(filter)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneFilter(orig *Filter) *Filter {
|
||||||
|
var clone Filter
|
||||||
|
clone.Messages = NewMemoryMessageStore()
|
||||||
|
clone.Src = orig.Src
|
||||||
|
clone.KeyAsym = orig.KeyAsym
|
||||||
|
clone.KeySym = orig.KeySym
|
||||||
|
clone.PubsubTopic = orig.PubsubTopic
|
||||||
|
clone.ContentTopics = orig.ContentTopics
|
||||||
|
clone.SymKeyHash = orig.SymKeyHash
|
||||||
|
return &clone
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateCompatibleReceivedMessage(t *testing.T, f *Filter) *ReceivedMessage {
|
||||||
|
keyInfo := &payload.KeyInfo{}
|
||||||
|
keyInfo.Kind = payload.Symmetric
|
||||||
|
keyInfo.SymKey = f.KeySym
|
||||||
|
|
||||||
|
var version uint32 = 1
|
||||||
|
p := new(payload.Payload)
|
||||||
|
p.Data = make([]byte, 20)
|
||||||
|
_, err := crand.Read(p.Data) // nolint: gosec
|
||||||
|
require.NoError(t, err)
|
||||||
|
p.Key = keyInfo
|
||||||
|
payload, err := p.Encode(version)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg := &pb.WakuMessage{
|
||||||
|
Payload: payload,
|
||||||
|
Version: &version,
|
||||||
|
ContentTopic: maps.Keys(f.ContentTopics)[2].ContentTopic(),
|
||||||
|
Timestamp: proto.Int64(time.Now().UnixNano()),
|
||||||
|
Meta: []byte{},
|
||||||
|
}
|
||||||
|
envelope := protocol.NewEnvelope(msg, time.Now().UnixNano(), f.PubsubTopic)
|
||||||
|
|
||||||
|
result := NewReceivedMessage(envelope, "test")
|
||||||
|
result.SymKeyHash = crypto.Keccak256Hash(f.KeySym)
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatchers(t *testing.T) {
|
||||||
|
const NumFilters = 16
|
||||||
|
const NumMessages = 256
|
||||||
|
var i int
|
||||||
|
var j uint32
|
||||||
|
var e *ReceivedMessage
|
||||||
|
var x, firstID string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
filters := NewFilters("/waku/2/rs/16/32", createLogger(t))
|
||||||
|
tst := generateTestCases(t, NumFilters)
|
||||||
|
for i = 0; i < NumFilters; i++ {
|
||||||
|
tst[i].f.Src = nil
|
||||||
|
x, err = filters.Install(tst[i].f)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tst[i].id = x
|
||||||
|
if len(firstID) == 0 {
|
||||||
|
firstID = x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lastID := x
|
||||||
|
|
||||||
|
var envelopes [NumMessages]*ReceivedMessage
|
||||||
|
for i = 0; i < NumMessages; i++ {
|
||||||
|
j = mrand.Uint32() % NumFilters // nolint: gosec
|
||||||
|
e = generateCompatibleReceivedMessage(t, tst[j].f)
|
||||||
|
envelopes[i] = e
|
||||||
|
tst[j].msgCnt++
|
||||||
|
}
|
||||||
|
|
||||||
|
for i = 0; i < NumMessages; i++ {
|
||||||
|
filters.NotifyWatchers(envelopes[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int
|
||||||
|
var mail []*ReceivedMessage
|
||||||
|
var count [NumFilters]int
|
||||||
|
|
||||||
|
for i = 0; i < NumFilters; i++ {
|
||||||
|
mail = tst[i].f.Retrieve()
|
||||||
|
count[i] = len(mail)
|
||||||
|
total += len(mail)
|
||||||
|
}
|
||||||
|
require.Equal(t, total, NumMessages)
|
||||||
|
|
||||||
|
for i = 0; i < NumFilters; i++ {
|
||||||
|
mail = tst[i].f.Retrieve()
|
||||||
|
require.Zero(t, len(mail))
|
||||||
|
require.Equal(t, tst[i].msgCnt, count[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// another round with a cloned filter
|
||||||
|
|
||||||
|
clone := cloneFilter(tst[0].f)
|
||||||
|
filters.Uninstall(lastID)
|
||||||
|
total = 0
|
||||||
|
last := NumFilters - 1
|
||||||
|
tst[last].f = clone
|
||||||
|
_, err = filters.Install(clone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for i = 0; i < NumFilters; i++ {
|
||||||
|
tst[i].msgCnt = 0
|
||||||
|
count[i] = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure that the first watcher receives at least one message
|
||||||
|
e = generateCompatibleReceivedMessage(t, tst[0].f)
|
||||||
|
envelopes[0] = e
|
||||||
|
tst[0].msgCnt++
|
||||||
|
for i = 1; i < NumMessages; i++ {
|
||||||
|
j = mrand.Uint32() % NumFilters // nolint: gosec
|
||||||
|
e = generateCompatibleReceivedMessage(t, tst[j].f)
|
||||||
|
envelopes[i] = e
|
||||||
|
tst[j].msgCnt++
|
||||||
|
}
|
||||||
|
|
||||||
|
for i = 0; i < NumMessages; i++ {
|
||||||
|
filters.NotifyWatchers(envelopes[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
for i = 0; i < NumFilters; i++ {
|
||||||
|
mail = tst[i].f.Retrieve()
|
||||||
|
count[i] = len(mail)
|
||||||
|
total += len(mail)
|
||||||
|
}
|
||||||
|
|
||||||
|
combined := tst[0].msgCnt + tst[last].msgCnt
|
||||||
|
require.Equal(t, total, NumMessages+count[0])
|
||||||
|
require.Equal(t, combined, count[0])
|
||||||
|
require.Equal(t, combined, count[last])
|
||||||
|
|
||||||
|
for i = 1; i < NumFilters-1; i++ {
|
||||||
|
mail = tst[i].f.Retrieve()
|
||||||
|
require.Zero(t, len(mail))
|
||||||
|
require.Equal(t, tst[i].msgCnt, count[i])
|
||||||
|
}
|
||||||
|
}
|
|
@ -185,7 +185,8 @@ func (msg *ReceivedMessage) Open(watcher *Filter) (result *ReceivedMessage) {
|
||||||
result.Padding = raw.Padding
|
result.Padding = raw.Padding
|
||||||
result.Signature = raw.Signature
|
result.Signature = raw.Signature
|
||||||
result.Src = raw.PubKey
|
result.Src = raw.PubKey
|
||||||
|
result.SymKeyHash = msg.SymKeyHash
|
||||||
|
result.Dst = msg.Dst
|
||||||
result.Sent = uint32(msg.Envelope.Message().GetTimestamp() / int64(time.Second))
|
result.Sent = uint32(msg.Envelope.Message().GetTimestamp() / int64(time.Second))
|
||||||
|
|
||||||
ct, err := ExtractTopicFromContentTopic(msg.Envelope.Message().ContentTopic)
|
ct, err := ExtractTopicFromContentTopic(msg.Envelope.Message().ContentTopic)
|
||||||
|
|
Loading…
Reference in New Issue