fix_: a message might match more than one filter
This commit is contained in:
parent
e07182b3f3
commit
cadc3998b8
|
@ -235,23 +235,24 @@ func (fs *Filters) NotifyWatchers(recvMessage *ReceivedMessage) bool {
|
|||
}
|
||||
|
||||
for _, watcher := range candidates {
|
||||
matched = true
|
||||
// Messages are decrypted successfully only once
|
||||
if decodedMsg == nil {
|
||||
decodedMsg = recvMessage.Open(watcher)
|
||||
if decodedMsg == nil {
|
||||
log.Debug("processing message: failed to open", "message", recvMessage.Hash().Hex(), "filter", watcher.id)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
matched = watcher.MatchMessage(decodedMsg)
|
||||
}
|
||||
|
||||
if matched && decodedMsg != nil {
|
||||
if watcher.MatchMessage(decodedMsg) {
|
||||
matched = true
|
||||
log.Debug("processing message: decrypted", "envelopeHash", recvMessage.Hash().Hex())
|
||||
if watcher.Src == nil || IsPubKeyEqual(decodedMsg.Src, watcher.Src) {
|
||||
watcher.Trigger(decodedMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matched
|
||||
}
|
||||
|
||||
|
@ -292,6 +293,8 @@ func (f *Filter) MatchMessage(msg *ReceivedMessage) bool {
|
|||
return IsPubKeyEqual(&f.KeyAsym.PublicKey, msg.Dst)
|
||||
} else if f.expectsSymmetricEncryption() && msg.isSymmetricEncryption() {
|
||||
return f.SymKeyHash == msg.SymKeyHash
|
||||
} else if !f.expectsAsymmetricEncryption() && !f.expectsSymmetricEncryption() && !msg.isAsymmetricEncryption() && !msg.isSymmetricEncryption() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -0,0 +1,315 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
mrand "math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/waku-org/go-waku/waku/v2/payload"
|
||||
"github.com/waku-org/go-waku/waku/v2/protocol"
|
||||
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
|
||||
)
|
||||
|
||||
const testShard = "/waku/2/rs/16/32"
|
||||
|
||||
type FilterTestCase struct {
|
||||
f *Filter
|
||||
id string
|
||||
alive bool
|
||||
msgCnt int
|
||||
}
|
||||
|
||||
func createLogger(t *testing.T) *zap.Logger {
|
||||
config := zap.NewDevelopmentConfig()
|
||||
config.Level = zap.NewAtomicLevelAt(zap.DebugLevel)
|
||||
logger, err := config.Build()
|
||||
require.NoError(t, err)
|
||||
return logger
|
||||
}
|
||||
|
||||
func generateFilter(t *testing.T, symmetric bool) (*Filter, error) {
|
||||
var f Filter
|
||||
f.Messages = NewMemoryMessageStore()
|
||||
|
||||
f.PubsubTopic = "test"
|
||||
|
||||
const topicNum = 8
|
||||
f.ContentTopics = make(TopicSet, topicNum)
|
||||
for i := 0; i < topicNum; i++ {
|
||||
topic := make([]byte, 4)
|
||||
_, err := crand.Read(topic) // nolint: gosec
|
||||
require.NoError(t, err)
|
||||
topic[0] = 0x01
|
||||
|
||||
f.ContentTopics[BytesToTopic(topic)] = struct{}{}
|
||||
}
|
||||
|
||||
key, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
f.Src = &key.PublicKey
|
||||
|
||||
if symmetric {
|
||||
f.KeySym = make([]byte, AESKeyLength)
|
||||
_, err := crand.Read(f.KeySym) // nolint: gosec
|
||||
require.NoError(t, err)
|
||||
f.SymKeyHash = crypto.Keccak256Hash(f.KeySym)
|
||||
} else {
|
||||
f.KeyAsym, err = crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
func generateTestCases(t *testing.T, SizeTestFilters int) []FilterTestCase {
|
||||
cases := make([]FilterTestCase, SizeTestFilters)
|
||||
for i := 0; i < SizeTestFilters; i++ {
|
||||
f, _ := generateFilter(t, true)
|
||||
cases[i].f = f
|
||||
cases[i].alive = mrand.Int()&1 == 0 // nolint: gosec
|
||||
}
|
||||
return cases
|
||||
}
|
||||
|
||||
func TestInstallFilters(t *testing.T) {
|
||||
const SizeTestFilters = 256
|
||||
filters := NewFilters(testShard, createLogger(t))
|
||||
tst := generateTestCases(t, SizeTestFilters)
|
||||
|
||||
var err error
|
||||
var j string
|
||||
for i := 0; i < SizeTestFilters; i++ {
|
||||
j, err = filters.Install(tst[i].f)
|
||||
require.NoError(t, err)
|
||||
|
||||
tst[i].id = j
|
||||
require.Len(t, j, KeyIDSize*2)
|
||||
}
|
||||
|
||||
for _, testCase := range tst {
|
||||
if !testCase.alive {
|
||||
filters.Uninstall(testCase.id)
|
||||
}
|
||||
}
|
||||
|
||||
for _, testCase := range tst {
|
||||
fil := filters.Get(testCase.id)
|
||||
exist := fil != nil
|
||||
require.Equal(t, exist, testCase.alive)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallSymKeyGeneratesHash(t *testing.T) {
|
||||
filters := NewFilters(testShard, createLogger(t))
|
||||
filter, _ := generateFilter(t, true)
|
||||
|
||||
// save the current SymKeyHash for comparison
|
||||
initialSymKeyHash := filter.SymKeyHash
|
||||
|
||||
// ensure the SymKeyHash is invalid, for Install to recreate it
|
||||
var invalid common.Hash
|
||||
filter.SymKeyHash = invalid
|
||||
|
||||
_, err := filters.Install(filter)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, b := range filter.SymKeyHash {
|
||||
require.Equal(t, b, initialSymKeyHash[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallIdenticalFilters(t *testing.T) {
|
||||
filters := NewFilters(testShard, createLogger(t))
|
||||
filter1, _ := generateFilter(t, true)
|
||||
|
||||
// Copy the first filter since some of its fields
|
||||
// are randomly gnerated.
|
||||
filter2 := &Filter{
|
||||
KeySym: filter1.KeySym,
|
||||
PubsubTopic: filter1.PubsubTopic,
|
||||
ContentTopics: filter1.ContentTopics,
|
||||
Messages: NewMemoryMessageStore(),
|
||||
}
|
||||
|
||||
_, err := filters.Install(filter1)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = filters.Install(filter2)
|
||||
require.NoError(t, err)
|
||||
|
||||
recvMessage := generateCompatibleReceivedMessage(t, filter1)
|
||||
msg := recvMessage.Open(filter1)
|
||||
require.NotNil(t, msg)
|
||||
}
|
||||
|
||||
func TestInstallFilterWithSymAndAsymKeys(t *testing.T) {
|
||||
filters := NewFilters(testShard, createLogger(t))
|
||||
filter1, _ := generateFilter(t, true)
|
||||
|
||||
asymKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Copy the first filter since some of its fields
|
||||
// are randomly gnerated.
|
||||
filter := &Filter{
|
||||
KeySym: filter1.KeySym,
|
||||
KeyAsym: asymKey,
|
||||
PubsubTopic: filter1.PubsubTopic,
|
||||
ContentTopics: filter1.ContentTopics,
|
||||
Messages: NewMemoryMessageStore(),
|
||||
}
|
||||
|
||||
_, err = filters.Install(filter)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func cloneFilter(orig *Filter) *Filter {
|
||||
var clone Filter
|
||||
clone.Messages = NewMemoryMessageStore()
|
||||
clone.Src = orig.Src
|
||||
clone.KeyAsym = orig.KeyAsym
|
||||
clone.KeySym = orig.KeySym
|
||||
clone.PubsubTopic = orig.PubsubTopic
|
||||
clone.ContentTopics = orig.ContentTopics
|
||||
clone.SymKeyHash = orig.SymKeyHash
|
||||
return &clone
|
||||
}
|
||||
|
||||
func generateCompatibleReceivedMessage(t *testing.T, f *Filter) *ReceivedMessage {
|
||||
keyInfo := &payload.KeyInfo{}
|
||||
keyInfo.Kind = payload.Symmetric
|
||||
keyInfo.SymKey = f.KeySym
|
||||
|
||||
var version uint32 = 1
|
||||
p := new(payload.Payload)
|
||||
p.Data = make([]byte, 20)
|
||||
_, err := crand.Read(p.Data) // nolint: gosec
|
||||
require.NoError(t, err)
|
||||
p.Key = keyInfo
|
||||
payload, err := p.Encode(version)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := &pb.WakuMessage{
|
||||
Payload: payload,
|
||||
Version: &version,
|
||||
ContentTopic: maps.Keys(f.ContentTopics)[2].ContentTopic(),
|
||||
Timestamp: proto.Int64(time.Now().UnixNano()),
|
||||
Meta: []byte{},
|
||||
}
|
||||
envelope := protocol.NewEnvelope(msg, time.Now().UnixNano(), f.PubsubTopic)
|
||||
|
||||
result := NewReceivedMessage(envelope, "test")
|
||||
result.SymKeyHash = crypto.Keccak256Hash(f.KeySym)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func TestWatchers(t *testing.T) {
|
||||
const NumFilters = 16
|
||||
const NumMessages = 256
|
||||
var i int
|
||||
var j uint32
|
||||
var e *ReceivedMessage
|
||||
var x, firstID string
|
||||
var err error
|
||||
|
||||
filters := NewFilters("/waku/2/rs/16/32", createLogger(t))
|
||||
tst := generateTestCases(t, NumFilters)
|
||||
for i = 0; i < NumFilters; i++ {
|
||||
tst[i].f.Src = nil
|
||||
x, err = filters.Install(tst[i].f)
|
||||
require.NoError(t, err)
|
||||
|
||||
tst[i].id = x
|
||||
if len(firstID) == 0 {
|
||||
firstID = x
|
||||
}
|
||||
}
|
||||
|
||||
lastID := x
|
||||
|
||||
var envelopes [NumMessages]*ReceivedMessage
|
||||
for i = 0; i < NumMessages; i++ {
|
||||
j = mrand.Uint32() % NumFilters // nolint: gosec
|
||||
e = generateCompatibleReceivedMessage(t, tst[j].f)
|
||||
envelopes[i] = e
|
||||
tst[j].msgCnt++
|
||||
}
|
||||
|
||||
for i = 0; i < NumMessages; i++ {
|
||||
filters.NotifyWatchers(envelopes[i])
|
||||
}
|
||||
|
||||
var total int
|
||||
var mail []*ReceivedMessage
|
||||
var count [NumFilters]int
|
||||
|
||||
for i = 0; i < NumFilters; i++ {
|
||||
mail = tst[i].f.Retrieve()
|
||||
count[i] = len(mail)
|
||||
total += len(mail)
|
||||
}
|
||||
require.Equal(t, total, NumMessages)
|
||||
|
||||
for i = 0; i < NumFilters; i++ {
|
||||
mail = tst[i].f.Retrieve()
|
||||
require.Zero(t, len(mail))
|
||||
require.Equal(t, tst[i].msgCnt, count[i])
|
||||
}
|
||||
|
||||
// another round with a cloned filter
|
||||
|
||||
clone := cloneFilter(tst[0].f)
|
||||
filters.Uninstall(lastID)
|
||||
total = 0
|
||||
last := NumFilters - 1
|
||||
tst[last].f = clone
|
||||
_, err = filters.Install(clone)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i = 0; i < NumFilters; i++ {
|
||||
tst[i].msgCnt = 0
|
||||
count[i] = 0
|
||||
}
|
||||
|
||||
// make sure that the first watcher receives at least one message
|
||||
e = generateCompatibleReceivedMessage(t, tst[0].f)
|
||||
envelopes[0] = e
|
||||
tst[0].msgCnt++
|
||||
for i = 1; i < NumMessages; i++ {
|
||||
j = mrand.Uint32() % NumFilters // nolint: gosec
|
||||
e = generateCompatibleReceivedMessage(t, tst[j].f)
|
||||
envelopes[i] = e
|
||||
tst[j].msgCnt++
|
||||
}
|
||||
|
||||
for i = 0; i < NumMessages; i++ {
|
||||
filters.NotifyWatchers(envelopes[i])
|
||||
}
|
||||
|
||||
for i = 0; i < NumFilters; i++ {
|
||||
mail = tst[i].f.Retrieve()
|
||||
count[i] = len(mail)
|
||||
total += len(mail)
|
||||
}
|
||||
|
||||
combined := tst[0].msgCnt + tst[last].msgCnt
|
||||
require.Equal(t, total, NumMessages+count[0])
|
||||
require.Equal(t, combined, count[0])
|
||||
require.Equal(t, combined, count[last])
|
||||
|
||||
for i = 1; i < NumFilters-1; i++ {
|
||||
mail = tst[i].f.Retrieve()
|
||||
require.Zero(t, len(mail))
|
||||
require.Equal(t, tst[i].msgCnt, count[i])
|
||||
}
|
||||
}
|
|
@ -185,7 +185,8 @@ func (msg *ReceivedMessage) Open(watcher *Filter) (result *ReceivedMessage) {
|
|||
result.Padding = raw.Padding
|
||||
result.Signature = raw.Signature
|
||||
result.Src = raw.PubKey
|
||||
|
||||
result.SymKeyHash = msg.SymKeyHash
|
||||
result.Dst = msg.Dst
|
||||
result.Sent = uint32(msg.Envelope.Message().GetTimestamp() / int64(time.Second))
|
||||
|
||||
ct, err := ExtractTopicFromContentTopic(msg.Envelope.Message().ContentTopic)
|
||||
|
|
Loading…
Reference in New Issue