chore_: adapt tracking for segmented messages

closes: #4310
This commit is contained in:
Patryk Osmaczko 2023-11-17 15:45:19 +01:00 committed by osmaczko
parent 71f2d63a71
commit d73d1e2488
6 changed files with 230 additions and 74 deletions

View File

@ -94,6 +94,15 @@ func EncodeHex(b []byte) string {
return string(enc)
}
// EncodeHex encodes bs as a hex strings with 0x prefix.
func EncodeHexes(bs [][]byte) []string {
result := make([]string, len(bs))
for i, b := range bs {
result[i] = EncodeHex(b)
}
return result
}
// DecodeHex decodes a hex string with 0x prefix.
func DecodeHex(input string) ([]byte, error) {
if len(input) == 0 {

View File

@ -291,7 +291,6 @@ func (s *MessageSender) sendCommunity(
return nil, err
}
rawMessage.ID = types.EncodeHex(messageID)
messageIDs := [][]byte{messageID}
if rawMessage.BeforeDispatch != nil {
if err := rawMessage.BeforeDispatch(rawMessage); err != nil {
@ -320,7 +319,7 @@ func (s *MessageSender) sendCommunity(
for i, spec := range keyExMessageSpecs {
recipient := rawMessage.Recipients[i]
_, _, err = s.sendMessageSpec(ctx, recipient, spec, messageIDs)
_, _, err = s.sendMessageSpec(ctx, recipient, spec, [][]byte{messageID})
if err != nil {
return nil, err
}
@ -404,7 +403,7 @@ func (s *MessageSender) sendCommunity(
sentMessage := &SentMessage{
Spec: messageSpec,
MessageIDs: messageIDs,
MessageIDs: [][]byte{messageID},
}
s.notifyOnSentMessage(sentMessage)
@ -417,17 +416,15 @@ func (s *MessageSender) sendCommunity(
if err != nil {
return nil, errors.Wrap(err, "failed to decompress pubkey")
}
hashes, newMessages, err = s.dispatchCommunityMessage(ctx, pubkey, payload, messageIDs, rawMessage.PubsubTopic)
hashes, newMessages, err = s.dispatchCommunityMessage(ctx, pubkey, payload, rawMessage.PubsubTopic)
if err != nil {
s.logger.Error("failed to send a community message", zap.Error(err))
return nil, errors.Wrap(err, "failed to send a message spec")
}
}
for i, newMessage := range newMessages {
s.logger.Debug("sent community message ", zap.String("messageID", messageID.String()), zap.String("hash", types.EncodeHex(hashes[i])))
s.transport.Track(messageIDs, hashes[i], newMessage)
}
s.logger.Debug("sent community message ", zap.String("messageID", messageID.String()), zap.Strings("hashes", types.EncodeHexes(hashes)))
s.transport.Track(messageID, hashes, newMessages)
return messageID, nil
}
@ -479,17 +476,14 @@ func (s *MessageSender) sendPrivate(
}
} else if rawMessage.SkipEncryptionLayer {
// When SkipProtocolLayer is set we don't pass the message to the encryption layer
messageIDs := [][]byte{messageID}
hashes, newMessages, err := s.sendPrivateRawMessage(ctx, rawMessage, recipient, wrappedMessage, messageIDs)
hashes, newMessages, err := s.sendPrivateRawMessage(ctx, rawMessage, recipient, wrappedMessage)
if err != nil {
s.logger.Error("failed to send a private message", zap.Error(err))
return nil, errors.Wrap(err, "failed to send a message spec")
}
for i, newMessage := range newMessages {
s.logger.Debug("sent private message skipProtocolLayer", zap.String("messageID", messageID.String()), zap.String("hash", types.EncodeHex(hashes[i])))
s.transport.Track(messageIDs, hashes[i], newMessage)
}
s.logger.Debug("sent private message skipProtocolLayer", zap.String("messageID", messageID.String()), zap.Strings("hashes", types.EncodeHexes(hashes)))
s.transport.Track(messageID, hashes, newMessages)
} else {
messageSpec, err := s.protocol.BuildEncryptedMessage(rawMessage.Sender, recipient, wrappedMessage)
@ -507,17 +501,14 @@ func (s *MessageSender) sendPrivate(
}
messageIDs := [][]byte{messageID}
hashes, newMessages, err := s.sendMessageSpec(ctx, recipient, messageSpec, messageIDs)
hashes, newMessages, err := s.sendMessageSpec(ctx, recipient, messageSpec, [][]byte{messageID})
if err != nil {
s.logger.Error("failed to send a private message", zap.Error(err))
return nil, errors.Wrap(err, "failed to send a message spec")
}
for i, newMessage := range newMessages {
s.logger.Debug("sent private message without datasync", zap.String("messageID", messageID.String()), zap.String("hash", types.EncodeHex(hashes[i])))
s.transport.Track(messageIDs, hashes[i], newMessage)
}
s.logger.Debug("sent private message without datasync", zap.String("messageID", messageID.String()), zap.Strings("hashes", types.EncodeHexes(hashes)))
s.transport.Track(messageID, hashes, newMessages)
}
return messageID, nil
@ -542,16 +533,13 @@ func (s *MessageSender) SendPairInstallation(
}
messageID := v1protocol.MessageID(&s.identity.PublicKey, wrappedMessage)
messageIDs := [][]byte{messageID}
hashes, newMessages, err := s.sendMessageSpec(ctx, recipient, messageSpec, messageIDs)
hashes, newMessages, err := s.sendMessageSpec(ctx, recipient, messageSpec, [][]byte{messageID})
if err != nil {
return nil, errors.Wrap(err, "failed to send a message spec")
}
for i, newMessage := range newMessages {
s.transport.Track(messageIDs, hashes[i], newMessage)
}
s.transport.Track(messageID, hashes, newMessages)
return messageID, nil
}
@ -708,7 +696,6 @@ func (s *MessageSender) SendPublic(
return nil, err
}
hashes = append(hashes, hash)
s.logger.Debug("sent public message", zap.String("messageID", messageID.String()), zap.String("hash", types.EncodeHex(hash)))
}
sentMessage := &SentMessage{
@ -718,9 +705,8 @@ func (s *MessageSender) SendPublic(
s.notifyOnSentMessage(sentMessage)
for i, newMessage := range newMessages {
s.transport.Track([][]byte{messageID}, hashes[i], newMessage)
}
s.logger.Debug("sent public message", zap.String("messageID", messageID.String()), zap.Strings("hashes", types.EncodeHexes(hashes)))
s.transport.Track(messageID, hashes, newMessages)
return messageID, nil
}
@ -1011,16 +997,14 @@ func (s *MessageSender) sendDataSync(ctx context.Context, publicKey *ecdsa.Publi
return err
}
for i, newMessage := range newMessages {
s.logger.Debug("sent private messages", zap.Any("messageIDs", hexMessageIDs), zap.String("hash", types.EncodeHex(hashes[i])))
s.transport.Track(messageIDs, hashes[i], newMessage)
}
s.logger.Debug("sent private messages", zap.Any("messageIDs", hexMessageIDs), zap.Strings("hashes", types.EncodeHexes(hashes)))
s.transport.TrackMany(messageIDs, hashes, newMessages)
return nil
}
// sendPrivateRawMessage sends a message not wrapped in an encryption layer
func (s *MessageSender) sendPrivateRawMessage(ctx context.Context, rawMessage *RawMessage, publicKey *ecdsa.PublicKey, payload []byte, messageIDs [][]byte) ([][]byte, []*types.NewMessage, error) {
func (s *MessageSender) sendPrivateRawMessage(ctx context.Context, rawMessage *RawMessage, publicKey *ecdsa.PublicKey, payload []byte) ([][]byte, []*types.NewMessage, error) {
newMessage := &types.NewMessage{
TTL: whisperTTL,
Payload: payload,
@ -1053,7 +1037,7 @@ func (s *MessageSender) sendPrivateRawMessage(ctx context.Context, rawMessage *R
// sendCommunityMessage sends a message not wrapped in an encryption layer
// to a community
func (s *MessageSender) dispatchCommunityMessage(ctx context.Context, publicKey *ecdsa.PublicKey, payload []byte, messageIDs [][]byte, pubsubTopic string) ([][]byte, []*types.NewMessage, error) {
func (s *MessageSender) dispatchCommunityMessage(ctx context.Context, publicKey *ecdsa.PublicKey, payload []byte, pubsubTopic string) ([][]byte, []*types.NewMessage, error) {
newMessage := &types.NewMessage{
TTL: whisperTTL,
Payload: payload,

View File

@ -151,12 +151,13 @@ func (s *MessengerMessagesTrackingSuite) newMessenger(waku types.Waku, logger *z
return messenger, interceptor
}
func (s *MessengerMessagesTrackingSuite) TestMessageMarkedAsSent() {
func (s *MessengerMessagesTrackingSuite) testMessageMarkedAsSent(textSize int) {
//when message sent, its sent field should be "false" until we got confirmation
chat := CreatePublicChat("test-chat", s.bob.getTimesource())
err := s.bob.SaveChat(chat)
s.Require().NoError(err)
inputMessage := buildTestMessage(*chat)
inputMessage.Text = string(make([]byte, textSize))
_, err = s.bob.SendChatMessage(context.Background(), inputMessage)
s.Require().NoError(err)
@ -182,3 +183,10 @@ func (s *MessengerMessagesTrackingSuite) TestMessageMarkedAsSent() {
}, options)
s.Require().NoError(err)
}
func (s *MessengerMessagesTrackingSuite) TestMessageMarkedAsSent() {
s.testMessageMarkedAsSent(1)
}
func (s *MessengerMessagesTrackingSuite) TestSegmentedMessageMarkedAsSent() {
s.testMessageMarkedAsSent(4 * 1024 * 1024) // 4MB - ensure message is segmented
}

View File

@ -65,6 +65,9 @@ func NewEnvelopesMonitor(w types.Waku, config EnvelopesMonitorConfig) *Envelopes
// key is hash of the batch (event.Batch)
batches: map[types.Hash]map[types.Hash]struct{}{},
// key is stringified message identifier
identifierHashes: make(map[string][]types.Hash),
}
}
@ -86,6 +89,7 @@ type EnvelopesMonitor struct {
envelopes map[types.Hash]*monitoredEnvelope
batches map[types.Hash]map[types.Hash]struct{}
identifierHashes map[string][]types.Hash
awaitOnlyMailServerConfirmations bool
@ -112,24 +116,31 @@ func (m *EnvelopesMonitor) Stop() {
m.wg.Wait()
}
// Add hash to a tracker.
func (m *EnvelopesMonitor) Add(identifiers [][]byte, envelopeHash types.Hash, message types.NewMessage) {
m.mu.Lock()
defer m.mu.Unlock()
// Add hashes to a tracker.
// Identifiers may be backed by multiple envelopes. It happens when message is split in segmentation layer.
func (m *EnvelopesMonitor) Add(identifiers [][]byte, envelopeHashes []types.Hash, messages []*types.NewMessage) error {
if len(envelopeHashes) != len(messages) {
return errors.New("hashes don't match messages")
}
if envelope, ok := m.envelopes[envelopeHash]; !ok {
for _, identifier := range identifiers {
m.identifierHashes[string(identifier)] = envelopeHashes
}
for i, envelopeHash := range envelopeHashes {
if _, ok := m.envelopes[envelopeHash]; !ok {
m.envelopes[envelopeHash] = &monitoredEnvelope{
state: EnvelopePosted,
attempts: 1,
message: &message,
message: messages[i],
identifiers: identifiers,
}
} else if envelope.state == EnvelopeSent {
// If it's already been marked as sent, we notify the client
if m.handler != nil {
m.handler.EnvelopeSent(envelope.identifiers)
}
}
m.processIdentifiers(identifiers)
return nil
}
func (m *EnvelopesMonitor) GetState(hash types.Hash) EnvelopeState {
@ -210,9 +221,7 @@ func (m *EnvelopesMonitor) handleEventEnvelopeSent(event types.EnvelopeEvent) {
} else {
m.logger.Debug("confirmation not expected, marking as sent")
envelope.state = EnvelopeSent
if m.handler != nil {
m.handler.EnvelopeSent(envelope.identifiers)
}
m.processIdentifiers(envelope.identifiers)
}
}
@ -259,9 +268,7 @@ func (m *EnvelopesMonitor) handleAcknowledgedBatch(event types.EnvelopeEvent) {
continue
}
envelope.state = EnvelopeSent
if m.handler != nil {
m.handler.EnvelopeSent(envelope.identifiers)
}
m.processIdentifiers(envelope.identifiers)
}
delete(m.batches, event.Batch)
}
@ -318,13 +325,46 @@ func (m *EnvelopesMonitor) handleEventEnvelopeReceived(event types.EnvelopeEvent
}
m.logger.Debug("expected envelope received", zap.String("hash", event.Hash.String()), zap.String("peer", event.Peer.String()))
envelope.state = EnvelopeSent
if m.handler != nil {
m.handler.EnvelopeSent(envelope.identifiers)
m.processIdentifiers(envelope.identifiers)
}
func (m *EnvelopesMonitor) processIdentifiers(identifiers [][]byte) {
sentIdentifiers := make([][]byte, 0, len(identifiers))
for _, identifier := range identifiers {
hashes, ok := m.identifierHashes[string(identifier)]
if !ok {
continue
}
sent := true
// Consider identifier as sent if all corresponding envelopes are in EnvelopeSent state
for _, hash := range hashes {
envelope, ok := m.envelopes[hash]
if !ok || envelope.state != EnvelopeSent {
sent = false
break
}
}
if sent {
sentIdentifiers = append(sentIdentifiers, identifier)
}
}
if len(sentIdentifiers) > 0 && m.handler != nil {
m.handler.EnvelopeSent(sentIdentifiers)
}
}
// clearMessageState removes all message and envelope state.
// not thread-safe, should be protected on a higher level.
func (m *EnvelopesMonitor) clearMessageState(envelopeID types.Hash) {
delete(m.envelopes, envelopeID)
envelope, ok := m.envelopes[envelopeID]
if !ok {
return
}
delete(m.envelopes, envelopeID)
for _, identifier := range envelope.identifiers {
delete(m.identifierHashes, string(identifier))
}
}

View File

@ -1,6 +1,7 @@
package transport
import (
"reflect"
"testing"
"go.uber.org/zap"
@ -15,13 +16,29 @@ import (
var (
testHash = types.Hash{0x01}
testHashes = []types.Hash{testHash}
testIDs = [][]byte{[]byte("id")}
)
type envelopeEventsHandlerMock struct {
envelopeSentCalls [][][]byte // slice of EnvelopeSent arguments
}
func (h *envelopeEventsHandlerMock) EnvelopeSent(identifiers [][]byte) {
h.envelopeSentCalls = append(h.envelopeSentCalls, identifiers)
}
func (h *envelopeEventsHandlerMock) EnvelopeExpired([][]byte, error) {
}
func (h *envelopeEventsHandlerMock) MailServerRequestCompleted(types.Hash, types.Hash, []byte, error) {
}
func (h *envelopeEventsHandlerMock) MailServerRequestExpired(types.Hash) {
}
type EnvelopesMonitorSuite struct {
suite.Suite
monitor *EnvelopesMonitor
eventsHandlerMock *envelopeEventsHandlerMock
}
func TestEnvelopesMonitorSuite(t *testing.T) {
@ -29,10 +46,11 @@ func TestEnvelopesMonitorSuite(t *testing.T) {
}
func (s *EnvelopesMonitorSuite) SetupTest() {
s.eventsHandlerMock = &envelopeEventsHandlerMock{}
s.monitor = NewEnvelopesMonitor(
nil,
EnvelopesMonitorConfig{
EnvelopeEventsHandler: nil,
EnvelopeEventsHandler: s.eventsHandlerMock,
MaxAttempts: 0,
AwaitOnlyMailServerConfirmations: false,
IsMailserver: func(types.EnodeID) bool { return false },
@ -42,7 +60,8 @@ func (s *EnvelopesMonitorSuite) SetupTest() {
}
func (s *EnvelopesMonitorSuite) TestEnvelopePosted() {
s.monitor.Add(testIDs, testHash, types.NewMessage{})
err := s.monitor.Add(testIDs, testHashes, []*types.NewMessage{{}})
s.Require().NoError(err)
s.Contains(s.monitor.envelopes, testHash)
s.Equal(EnvelopePosted, s.monitor.envelopes[testHash].state)
s.monitor.handleEvent(types.EnvelopeEvent{
@ -59,7 +78,8 @@ func (s *EnvelopesMonitorSuite) TestEnvelopePostedOutOfOrder() {
Hash: testHash,
})
s.monitor.Add(testIDs, testHash, types.NewMessage{})
err := s.monitor.Add(testIDs, testHashes, []*types.NewMessage{{}})
s.Require().NoError(err)
s.Require().Contains(s.monitor.envelopes, testHash)
s.Require().Equal(EnvelopeSent, s.monitor.envelopes[testHash].state)
}
@ -69,7 +89,8 @@ func (s *EnvelopesMonitorSuite) TestConfirmedWithAcknowledge() {
pkey, err := crypto.GenerateKey()
s.Require().NoError(err)
node := enode.NewV4(&pkey.PublicKey, nil, 0, 0)
s.monitor.Add(testIDs, testHash, types.NewMessage{})
err = s.monitor.Add(testIDs, testHashes, []*types.NewMessage{{}})
s.Require().NoError(err)
s.Contains(s.monitor.envelopes, testHash)
s.Equal(EnvelopePosted, s.monitor.envelopes[testHash].state)
s.monitor.handleEvent(types.EnvelopeEvent{
@ -88,7 +109,8 @@ func (s *EnvelopesMonitorSuite) TestConfirmedWithAcknowledge() {
}
func (s *EnvelopesMonitorSuite) TestRemoved() {
s.monitor.Add(testIDs, testHash, types.NewMessage{})
err := s.monitor.Add(testIDs, testHashes, []*types.NewMessage{{}})
s.Require().NoError(err)
s.Contains(s.monitor.envelopes, testHash)
s.monitor.handleEvent(types.EnvelopeEvent{
Event: types.EventEnvelopeExpired,
@ -100,7 +122,8 @@ func (s *EnvelopesMonitorSuite) TestRemoved() {
func (s *EnvelopesMonitorSuite) TestIgnoreNotFromMailserver() {
// enables filter in the tracker to drop confirmations from non-mailserver peers
s.monitor.awaitOnlyMailServerConfirmations = true
s.monitor.Add(testIDs, testHash, types.NewMessage{})
err := s.monitor.Add(testIDs, testHashes, []*types.NewMessage{{}})
s.Require().NoError(err)
s.monitor.handleEvent(types.EnvelopeEvent{
Event: types.EventEnvelopeSent,
Hash: testHash,
@ -113,7 +136,8 @@ func (s *EnvelopesMonitorSuite) TestReceived() {
s.monitor.isMailserver = func(peer types.EnodeID) bool {
return true
}
s.monitor.Add(testIDs, testHash, types.NewMessage{})
err := s.monitor.Add(testIDs, testHashes, []*types.NewMessage{{}})
s.Require().NoError(err)
s.Contains(s.monitor.envelopes, testHash)
s.monitor.handleEvent(types.EnvelopeEvent{
Event: types.EventEnvelopeReceived,
@ -121,3 +145,80 @@ func (s *EnvelopesMonitorSuite) TestReceived() {
})
s.Require().Equal(EnvelopeSent, s.monitor.GetState(testHash))
}
func (s *EnvelopesMonitorSuite) TestMultipleHashes() {
messageIDs := [][]byte{[]byte("id1"), []byte("id2")}
hashes := []types.Hash{{0x01}, {0x02}, {0x03}}
messages := []*types.NewMessage{{}, {}, {}}
err := s.monitor.Add(messageIDs, hashes, messages)
s.Require().NoError(err)
for _, hash := range hashes {
s.Contains(s.monitor.envelopes, hash)
}
s.Require().Empty(s.eventsHandlerMock.envelopeSentCalls)
s.Require().Equal(EnvelopePosted, s.monitor.envelopes[hashes[0]].state)
s.Require().Equal(EnvelopePosted, s.monitor.envelopes[hashes[1]].state)
s.Require().Equal(EnvelopePosted, s.monitor.envelopes[hashes[2]].state)
s.monitor.handleEvent(types.EnvelopeEvent{
Event: types.EventEnvelopeSent,
Hash: hashes[0],
})
s.Require().Empty(s.eventsHandlerMock.envelopeSentCalls)
s.Require().Equal(EnvelopeSent, s.monitor.envelopes[hashes[0]].state)
s.Require().Equal(EnvelopePosted, s.monitor.envelopes[hashes[1]].state)
s.Require().Equal(EnvelopePosted, s.monitor.envelopes[hashes[2]].state)
s.monitor.handleEvent(types.EnvelopeEvent{
Event: types.EventEnvelopeSent,
Hash: hashes[1],
})
s.Require().Empty(s.eventsHandlerMock.envelopeSentCalls)
s.Require().Equal(EnvelopeSent, s.monitor.envelopes[hashes[0]].state)
s.Require().Equal(EnvelopeSent, s.monitor.envelopes[hashes[1]].state)
s.Require().Equal(EnvelopePosted, s.monitor.envelopes[hashes[2]].state)
s.monitor.handleEvent(types.EnvelopeEvent{
Event: types.EventEnvelopeSent,
Hash: hashes[2],
})
// Identifiers should be marked as sent only if all corresponding envelopes are sent
s.Require().Len(s.eventsHandlerMock.envelopeSentCalls, 1)
s.Require().True(reflect.DeepEqual(messageIDs, s.eventsHandlerMock.envelopeSentCalls[0]))
s.Require().Equal(EnvelopeSent, s.monitor.envelopes[hashes[0]].state)
s.Require().Equal(EnvelopeSent, s.monitor.envelopes[hashes[1]].state)
s.Require().Equal(EnvelopeSent, s.monitor.envelopes[hashes[2]].state)
}
func (s *EnvelopesMonitorSuite) TestMultipleHashes_EnvelopeExpired() {
messageIDs := [][]byte{[]byte("id1"), []byte("id2")}
hashes := []types.Hash{{0x01}, {0x02}, {0x03}}
messages := []*types.NewMessage{{}, {}, {}}
err := s.monitor.Add(messageIDs, hashes, messages)
s.Require().NoError(err)
// If any envelope fails, then identifiers are considered as not sent
s.monitor.handleEvent(types.EnvelopeEvent{
Event: types.EventEnvelopeExpired,
Hash: hashes[0],
})
s.monitor.handleEvent(types.EnvelopeEvent{
Event: types.EventEnvelopeSent,
Hash: hashes[1],
})
s.monitor.handleEvent(types.EnvelopeEvent{
Event: types.EventEnvelopeSent,
Hash: hashes[2],
})
s.Require().Empty(s.eventsHandlerMock.envelopeSentCalls)
s.Require().Empty(s.monitor.identifierHashes)
s.Require().Len(s.monitor.envelopes, 2)
}
func (s *EnvelopesMonitorSuite) TestMultipleHashes_Failure() {
err := s.monitor.Add(testIDs, []types.Hash{{0x01}, {0x02}}, []*types.NewMessage{{}})
s.Require().Error(err)
}

View File

@ -387,9 +387,23 @@ func (t *Transport) addSig(newMessage *types.NewMessage) error {
return nil
}
func (t *Transport) Track(identifiers [][]byte, hash []byte, newMessage *types.NewMessage) {
if t.envelopesMonitor != nil {
t.envelopesMonitor.Add(identifiers, types.BytesToHash(hash), *newMessage)
func (t *Transport) Track(identifier []byte, hashes [][]byte, newMessages []*types.NewMessage) {
t.TrackMany([][]byte{identifier}, hashes, newMessages)
}
func (t *Transport) TrackMany(identifiers [][]byte, hashes [][]byte, newMessages []*types.NewMessage) {
if t.envelopesMonitor == nil {
return
}
envelopeHashes := make([]types.Hash, len(hashes))
for i, hash := range hashes {
envelopeHashes[i] = types.BytesToHash(hash)
}
err := t.envelopesMonitor.Add(identifiers, envelopeHashes, newMessages)
if err != nil {
t.logger.Error("failed to track messages", zap.Error(err))
}
}