Use ErrRecordNotFound instead of sql.ErrNoRows

We were checking for the wrong error kind when pulling messages from the
database, which resulted in the code not retrying to pull the message,
giving flaky tests / race condition (that's present in production as
well)
This commit is contained in:
Andrea Maria Piana 2020-10-08 12:46:03 +02:00
parent 34c41d4bc8
commit b11399ffc6
5 changed files with 19 additions and 21 deletions

View File

@ -0,0 +1,5 @@
package common
import "errors"
var ErrRecordNotFound = errors.New("record not found")

View File

@ -73,7 +73,7 @@ func (m *MessageHandler) HandleMembershipUpdate(messageState *ReceivedMessageSta
} }
groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID()) groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID())
if err != nil && err != errRecordNotFound { if err != nil && err != common.ErrRecordNotFound {
return err return err
} }
if groupChatInvitation != nil { if groupChatInvitation != nil {
@ -479,7 +479,7 @@ func (m *MessageHandler) HandleAcceptRequestAddressForTransaction(messageState *
// Hide previous message // Hide previous message
previousMessage, err := m.persistence.MessageByCommandID(messageState.CurrentMessageState.Contact.ID, command.Id) previousMessage, err := m.persistence.MessageByCommandID(messageState.CurrentMessageState.Contact.ID, command.Id)
if err != nil && err != errRecordNotFound { if err != nil && err != common.ErrRecordNotFound {
return err return err
} }
@ -691,7 +691,7 @@ func (m *MessageHandler) messageExists(messageID string, existingMessagesMap map
// Check against the database, this is probably a bit slow for // Check against the database, this is probably a bit slow for
// each message, but for now might do, we'll make it faster later // each message, but for now might do, we'll make it faster later
existingMessage, err := m.persistence.MessageByID(messageID) existingMessage, err := m.persistence.MessageByID(messageID)
if err != nil && err != errRecordNotFound { if err != nil && err != common.ErrRecordNotFound {
return false, err return false, err
} }
if existingMessage != nil { if existingMessage != nil {
@ -716,7 +716,7 @@ func (m *MessageHandler) HandleEmojiReaction(state *ReceivedMessageState, pbEmoj
} }
existingEmoji, err := m.persistence.EmojiReactionByID(emojiReaction.ID()) existingEmoji, err := m.persistence.EmojiReactionByID(emojiReaction.ID())
if err != errRecordNotFound && err != nil { if err != common.ErrRecordNotFound && err != nil {
return err return err
} }
@ -775,7 +775,7 @@ func (m *MessageHandler) HandleGroupChatInvitation(state *ReceivedMessageState,
} }
existingInvitation, err := m.persistence.InvitationByID(groupChatInvitation.ID()) existingInvitation, err := m.persistence.InvitationByID(groupChatInvitation.ID())
if err != errRecordNotFound && err != nil { if err != common.ErrRecordNotFound && err != nil {
return err return err
} }

View File

@ -9,12 +9,6 @@ import (
"github.com/status-im/status-go/protocol/common" "github.com/status-im/status-go/protocol/common"
"github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/protobuf"
"github.com/pkg/errors"
)
var (
errRecordNotFound = errors.New("record not found")
) )
func (db sqlitePersistence) tableUserMessagesAllFields() string { func (db sqlitePersistence) tableUserMessagesAllFields() string {
@ -315,7 +309,7 @@ func (db sqlitePersistence) messageByID(tx *sql.Tx, id string) (*common.Message,
err = db.tableUserMessagesScanAllFields(row, &message) err = db.tableUserMessagesScanAllFields(row, &message)
switch err { switch err {
case sql.ErrNoRows: case sql.ErrNoRows:
return nil, errRecordNotFound return nil, common.ErrRecordNotFound
case nil: case nil:
return &message, nil return &message, nil
default: default:
@ -356,7 +350,7 @@ func (db sqlitePersistence) MessageByCommandID(chatID, id string) (*common.Messa
err := db.tableUserMessagesScanAllFields(row, &message) err := db.tableUserMessagesScanAllFields(row, &message)
switch err { switch err {
case sql.ErrNoRows: case sql.ErrNoRows:
return nil, errRecordNotFound return nil, common.ErrRecordNotFound
case nil: case nil:
return &message, nil return &message, nil
default: default:
@ -854,7 +848,7 @@ func (db sqlitePersistence) EmojiReactionByID(id string) (*EmojiReaction, error)
switch err { switch err {
case sql.ErrNoRows: case sql.ErrNoRows:
return nil, errRecordNotFound return nil, common.ErrRecordNotFound
case nil: case nil:
return emojiReaction, nil return emojiReaction, nil
default: default:
@ -947,7 +941,7 @@ func (db sqlitePersistence) InvitationByID(id string) (*GroupChatInvitation, err
switch err { switch err {
case sql.ErrNoRows: case sql.ErrNoRows:
return nil, errRecordNotFound return nil, common.ErrRecordNotFound
case nil: case nil:
return chatInvitations, nil return chatInvitations, nil
default: default:

View File

@ -890,7 +890,7 @@ func (m *Messenger) AddMembersToGroupChat(ctx context.Context, chatID string, me
} }
groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID()) groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID())
if err != nil && err != errRecordNotFound { if err != nil && err != common.ErrRecordNotFound {
return nil, err return nil, err
} }
if groupChatInvitation != nil { if groupChatInvitation != nil {
@ -3199,7 +3199,7 @@ func (m *Messenger) AcceptRequestTransaction(ctx context.Context, transactionHas
// Hide previous message // Hide previous message
previousMessage, err := m.persistence.MessageByCommandID(chatID, messageID) previousMessage, err := m.persistence.MessageByCommandID(chatID, messageID)
if err != nil && err != errRecordNotFound { if err != nil && err != common.ErrRecordNotFound {
return nil, err return nil, err
} }
@ -3423,7 +3423,7 @@ func (m *Messenger) ValidateTransactions(ctx context.Context, addresses []types.
if len(message.CommandParameters.ID) != 0 { if len(message.CommandParameters.ID) != 0 {
// Hide previous message // Hide previous message
previousMessage, err := m.persistence.MessageByCommandID(chatID, message.CommandParameters.ID) previousMessage, err := m.persistence.MessageByCommandID(chatID, message.CommandParameters.ID)
if err != nil && err != errRecordNotFound { if err != nil && err != common.ErrRecordNotFound {
return nil, err return nil, err
} }

View File

@ -7,7 +7,6 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand" "crypto/rand"
"database/sql"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
@ -839,7 +838,7 @@ func (c *Client) getMessage(messageID string) (*common.Message, error) {
retries := 0 retries := 0
for retries < 10 { for retries < 10 {
message, err := c.messagePersistence.MessageByID(messageID) message, err := c.messagePersistence.MessageByID(messageID)
if err == sql.ErrNoRows { if err == common.ErrRecordNotFound {
retries++ retries++
time.Sleep(300 * time.Millisecond) time.Sleep(300 * time.Millisecond)
continue continue
@ -849,7 +848,7 @@ func (c *Client) getMessage(messageID string) (*common.Message, error) {
return message, nil return message, nil
} }
return nil, sql.ErrNoRows return nil, common.ErrRecordNotFound
} }
// handlePublicMessageSent handles public messages, we notify only on mentions // handlePublicMessageSent handles public messages, we notify only on mentions