Use ErrRecordNotFound instead of sql.ErrNoRows
We were checking for the wrong error kind when pulling messages from the database, which resulted in the code not retrying to pull the message, giving flaky tests / race condition (that's present in production as well)
This commit is contained in:
parent
34c41d4bc8
commit
b11399ffc6
|
@ -0,0 +1,5 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var ErrRecordNotFound = errors.New("record not found")
|
|
@ -73,7 +73,7 @@ func (m *MessageHandler) HandleMembershipUpdate(messageState *ReceivedMessageSta
|
||||||
}
|
}
|
||||||
|
|
||||||
groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID())
|
groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID())
|
||||||
if err != nil && err != errRecordNotFound {
|
if err != nil && err != common.ErrRecordNotFound {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if groupChatInvitation != nil {
|
if groupChatInvitation != nil {
|
||||||
|
@ -479,7 +479,7 @@ func (m *MessageHandler) HandleAcceptRequestAddressForTransaction(messageState *
|
||||||
|
|
||||||
// Hide previous message
|
// Hide previous message
|
||||||
previousMessage, err := m.persistence.MessageByCommandID(messageState.CurrentMessageState.Contact.ID, command.Id)
|
previousMessage, err := m.persistence.MessageByCommandID(messageState.CurrentMessageState.Contact.ID, command.Id)
|
||||||
if err != nil && err != errRecordNotFound {
|
if err != nil && err != common.ErrRecordNotFound {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -691,7 +691,7 @@ func (m *MessageHandler) messageExists(messageID string, existingMessagesMap map
|
||||||
// Check against the database, this is probably a bit slow for
|
// Check against the database, this is probably a bit slow for
|
||||||
// each message, but for now might do, we'll make it faster later
|
// each message, but for now might do, we'll make it faster later
|
||||||
existingMessage, err := m.persistence.MessageByID(messageID)
|
existingMessage, err := m.persistence.MessageByID(messageID)
|
||||||
if err != nil && err != errRecordNotFound {
|
if err != nil && err != common.ErrRecordNotFound {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
if existingMessage != nil {
|
if existingMessage != nil {
|
||||||
|
@ -716,7 +716,7 @@ func (m *MessageHandler) HandleEmojiReaction(state *ReceivedMessageState, pbEmoj
|
||||||
}
|
}
|
||||||
|
|
||||||
existingEmoji, err := m.persistence.EmojiReactionByID(emojiReaction.ID())
|
existingEmoji, err := m.persistence.EmojiReactionByID(emojiReaction.ID())
|
||||||
if err != errRecordNotFound && err != nil {
|
if err != common.ErrRecordNotFound && err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -775,7 +775,7 @@ func (m *MessageHandler) HandleGroupChatInvitation(state *ReceivedMessageState,
|
||||||
}
|
}
|
||||||
|
|
||||||
existingInvitation, err := m.persistence.InvitationByID(groupChatInvitation.ID())
|
existingInvitation, err := m.persistence.InvitationByID(groupChatInvitation.ID())
|
||||||
if err != errRecordNotFound && err != nil {
|
if err != common.ErrRecordNotFound && err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,12 +9,6 @@ import (
|
||||||
|
|
||||||
"github.com/status-im/status-go/protocol/common"
|
"github.com/status-im/status-go/protocol/common"
|
||||||
"github.com/status-im/status-go/protocol/protobuf"
|
"github.com/status-im/status-go/protocol/protobuf"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errRecordNotFound = errors.New("record not found")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (db sqlitePersistence) tableUserMessagesAllFields() string {
|
func (db sqlitePersistence) tableUserMessagesAllFields() string {
|
||||||
|
@ -315,7 +309,7 @@ func (db sqlitePersistence) messageByID(tx *sql.Tx, id string) (*common.Message,
|
||||||
err = db.tableUserMessagesScanAllFields(row, &message)
|
err = db.tableUserMessagesScanAllFields(row, &message)
|
||||||
switch err {
|
switch err {
|
||||||
case sql.ErrNoRows:
|
case sql.ErrNoRows:
|
||||||
return nil, errRecordNotFound
|
return nil, common.ErrRecordNotFound
|
||||||
case nil:
|
case nil:
|
||||||
return &message, nil
|
return &message, nil
|
||||||
default:
|
default:
|
||||||
|
@ -356,7 +350,7 @@ func (db sqlitePersistence) MessageByCommandID(chatID, id string) (*common.Messa
|
||||||
err := db.tableUserMessagesScanAllFields(row, &message)
|
err := db.tableUserMessagesScanAllFields(row, &message)
|
||||||
switch err {
|
switch err {
|
||||||
case sql.ErrNoRows:
|
case sql.ErrNoRows:
|
||||||
return nil, errRecordNotFound
|
return nil, common.ErrRecordNotFound
|
||||||
case nil:
|
case nil:
|
||||||
return &message, nil
|
return &message, nil
|
||||||
default:
|
default:
|
||||||
|
@ -854,7 +848,7 @@ func (db sqlitePersistence) EmojiReactionByID(id string) (*EmojiReaction, error)
|
||||||
|
|
||||||
switch err {
|
switch err {
|
||||||
case sql.ErrNoRows:
|
case sql.ErrNoRows:
|
||||||
return nil, errRecordNotFound
|
return nil, common.ErrRecordNotFound
|
||||||
case nil:
|
case nil:
|
||||||
return emojiReaction, nil
|
return emojiReaction, nil
|
||||||
default:
|
default:
|
||||||
|
@ -947,7 +941,7 @@ func (db sqlitePersistence) InvitationByID(id string) (*GroupChatInvitation, err
|
||||||
|
|
||||||
switch err {
|
switch err {
|
||||||
case sql.ErrNoRows:
|
case sql.ErrNoRows:
|
||||||
return nil, errRecordNotFound
|
return nil, common.ErrRecordNotFound
|
||||||
case nil:
|
case nil:
|
||||||
return chatInvitations, nil
|
return chatInvitations, nil
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -890,7 +890,7 @@ func (m *Messenger) AddMembersToGroupChat(ctx context.Context, chatID string, me
|
||||||
}
|
}
|
||||||
|
|
||||||
groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID())
|
groupChatInvitation, err = m.persistence.InvitationByID(groupChatInvitation.ID())
|
||||||
if err != nil && err != errRecordNotFound {
|
if err != nil && err != common.ErrRecordNotFound {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if groupChatInvitation != nil {
|
if groupChatInvitation != nil {
|
||||||
|
@ -3199,7 +3199,7 @@ func (m *Messenger) AcceptRequestTransaction(ctx context.Context, transactionHas
|
||||||
|
|
||||||
// Hide previous message
|
// Hide previous message
|
||||||
previousMessage, err := m.persistence.MessageByCommandID(chatID, messageID)
|
previousMessage, err := m.persistence.MessageByCommandID(chatID, messageID)
|
||||||
if err != nil && err != errRecordNotFound {
|
if err != nil && err != common.ErrRecordNotFound {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3423,7 +3423,7 @@ func (m *Messenger) ValidateTransactions(ctx context.Context, addresses []types.
|
||||||
if len(message.CommandParameters.ID) != 0 {
|
if len(message.CommandParameters.ID) != 0 {
|
||||||
// Hide previous message
|
// Hide previous message
|
||||||
previousMessage, err := m.persistence.MessageByCommandID(chatID, message.CommandParameters.ID)
|
previousMessage, err := m.persistence.MessageByCommandID(chatID, message.CommandParameters.ID)
|
||||||
if err != nil && err != errRecordNotFound {
|
if err != nil && err != common.ErrRecordNotFound {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -839,7 +838,7 @@ func (c *Client) getMessage(messageID string) (*common.Message, error) {
|
||||||
retries := 0
|
retries := 0
|
||||||
for retries < 10 {
|
for retries < 10 {
|
||||||
message, err := c.messagePersistence.MessageByID(messageID)
|
message, err := c.messagePersistence.MessageByID(messageID)
|
||||||
if err == sql.ErrNoRows {
|
if err == common.ErrRecordNotFound {
|
||||||
retries++
|
retries++
|
||||||
time.Sleep(300 * time.Millisecond)
|
time.Sleep(300 * time.Millisecond)
|
||||||
continue
|
continue
|
||||||
|
@ -849,7 +848,7 @@ func (c *Client) getMessage(messageID string) (*common.Message, error) {
|
||||||
|
|
||||||
return message, nil
|
return message, nil
|
||||||
}
|
}
|
||||||
return nil, sql.ErrNoRows
|
return nil, common.ErrRecordNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// handlePublicMessageSent handles public messages, we notify only on mentions
|
// handlePublicMessageSent handles public messages, we notify only on mentions
|
||||||
|
|
Loading…
Reference in New Issue