From b11399ffc6b85125f0c8ced2d389c72ffc022f99 Mon Sep 17 00:00:00 2001 From: Andrea Maria Piana Date: Thu, 8 Oct 2020 12:46:03 +0200 Subject: [PATCH] Use ErrRecordNotFound instead of sql.ErrNoRows We were checking for the wrong error kind when pulling messages from the database, which resulted in the code not retrying to pull the message, giving flaky tests / race condition (that's present in production as well) --- protocol/common/errors.go | 5 +++++ protocol/message_handler.go | 10 +++++----- protocol/message_persistence.go | 14 ++++---------- protocol/messenger.go | 6 +++--- protocol/pushnotificationclient/client.go | 5 ++--- 5 files changed, 19 insertions(+), 21 deletions(-) create mode 100644 protocol/common/errors.go diff --git a/protocol/common/errors.go b/protocol/common/errors.go new file mode 100644 index 000000000..e40a89922 --- /dev/null +++ b/protocol/common/errors.go @@ -0,0 +1,5 @@ +package common + +import "errors" + +var ErrRecordNotFound = errors.New("record not found") diff --git a/protocol/message_handler.go b/protocol/message_handler.go index 2be57ae2c..3fb5f5a52 100644 --- a/protocol/message_handler.go +++ b/protocol/message_handler.go @@ -73,7 +73,7 @@ func (m *MessageHandler) HandleMembershipUpdate(messageState *ReceivedMessageSta } groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID()) - if err != nil && err != errRecordNotFound { + if err != nil && err != common.ErrRecordNotFound { return err } if groupChatInvitation != nil { @@ -479,7 +479,7 @@ func (m *MessageHandler) HandleAcceptRequestAddressForTransaction(messageState * // Hide previous message previousMessage, err := m.persistence.MessageByCommandID(messageState.CurrentMessageState.Contact.ID, command.Id) - if err != nil && err != errRecordNotFound { + if err != nil && err != common.ErrRecordNotFound { return err } @@ -691,7 +691,7 @@ func (m *MessageHandler) messageExists(messageID string, existingMessagesMap map // Check against the database, this is probably a bit slow for // each message, but for now might do, we'll make it faster later existingMessage, err := m.persistence.MessageByID(messageID) - if err != nil && err != errRecordNotFound { + if err != nil && err != common.ErrRecordNotFound { return false, err } if existingMessage != nil { @@ -716,7 +716,7 @@ func (m *MessageHandler) HandleEmojiReaction(state *ReceivedMessageState, pbEmoj } existingEmoji, err := m.persistence.EmojiReactionByID(emojiReaction.ID()) - if err != errRecordNotFound && err != nil { + if err != common.ErrRecordNotFound && err != nil { return err } @@ -775,7 +775,7 @@ func (m *MessageHandler) HandleGroupChatInvitation(state *ReceivedMessageState, } existingInvitation, err := m.persistence.InvitationByID(groupChatInvitation.ID()) - if err != errRecordNotFound && err != nil { + if err != common.ErrRecordNotFound && err != nil { return err } diff --git a/protocol/message_persistence.go b/protocol/message_persistence.go index d0e31243b..ae7da5848 100644 --- a/protocol/message_persistence.go +++ b/protocol/message_persistence.go @@ -9,12 +9,6 @@ import ( "github.com/status-im/status-go/protocol/common" "github.com/status-im/status-go/protocol/protobuf" - - "github.com/pkg/errors" -) - -var ( - errRecordNotFound = errors.New("record not found") ) func (db sqlitePersistence) tableUserMessagesAllFields() string { @@ -315,7 +309,7 @@ func (db sqlitePersistence) messageByID(tx *sql.Tx, id string) (*common.Message, err = db.tableUserMessagesScanAllFields(row, &message) switch err { case sql.ErrNoRows: - return nil, errRecordNotFound + return nil, common.ErrRecordNotFound case nil: return &message, nil default: @@ -356,7 +350,7 @@ func (db sqlitePersistence) MessageByCommandID(chatID, id string) (*common.Messa err := db.tableUserMessagesScanAllFields(row, &message) switch err { case sql.ErrNoRows: - return nil, errRecordNotFound + return nil, common.ErrRecordNotFound case nil: return &message, nil default: @@ -854,7 +848,7 @@ func (db sqlitePersistence) EmojiReactionByID(id string) (*EmojiReaction, error) switch err { case sql.ErrNoRows: - return nil, errRecordNotFound + return nil, common.ErrRecordNotFound case nil: return emojiReaction, nil default: @@ -947,7 +941,7 @@ func (db sqlitePersistence) InvitationByID(id string) (*GroupChatInvitation, err switch err { case sql.ErrNoRows: - return nil, errRecordNotFound + return nil, common.ErrRecordNotFound case nil: return chatInvitations, nil default: diff --git a/protocol/messenger.go b/protocol/messenger.go index 6a0bc753d..e14ec5c5c 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -890,7 +890,7 @@ func (m *Messenger) AddMembersToGroupChat(ctx context.Context, chatID string, me } groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID()) - if err != nil && err != errRecordNotFound { + if err != nil && err != common.ErrRecordNotFound { return nil, err } if groupChatInvitation != nil { @@ -3199,7 +3199,7 @@ func (m *Messenger) AcceptRequestTransaction(ctx context.Context, transactionHas // Hide previous message previousMessage, err := m.persistence.MessageByCommandID(chatID, messageID) - if err != nil && err != errRecordNotFound { + if err != nil && err != common.ErrRecordNotFound { return nil, err } @@ -3423,7 +3423,7 @@ func (m *Messenger) ValidateTransactions(ctx context.Context, addresses []types. if len(message.CommandParameters.ID) != 0 { // Hide previous message previousMessage, err := m.persistence.MessageByCommandID(chatID, message.CommandParameters.ID) - if err != nil && err != errRecordNotFound { + if err != nil && err != common.ErrRecordNotFound { return nil, err } diff --git a/protocol/pushnotificationclient/client.go b/protocol/pushnotificationclient/client.go index c22f28fd4..bd91c856f 100644 --- a/protocol/pushnotificationclient/client.go +++ b/protocol/pushnotificationclient/client.go @@ -7,7 +7,6 @@ import ( "crypto/cipher" "crypto/ecdsa" "crypto/rand" - "database/sql" "encoding/hex" "encoding/json" "errors" @@ -839,7 +838,7 @@ func (c *Client) getMessage(messageID string) (*common.Message, error) { retries := 0 for retries < 10 { message, err := c.messagePersistence.MessageByID(messageID) - if err == sql.ErrNoRows { + if err == common.ErrRecordNotFound { retries++ time.Sleep(300 * time.Millisecond) continue @@ -849,7 +848,7 @@ func (c *Client) getMessage(messageID string) (*common.Message, error) { return message, nil } - return nil, sql.ErrNoRows + return nil, common.ErrRecordNotFound } // handlePublicMessageSent handles public messages, we notify only on mentions