diff --git a/persistence.go b/persistence.go index bf3316e..8e14f29 100644 --- a/persistence.go +++ b/persistence.go @@ -14,10 +14,6 @@ import ( protocol "github.com/status-im/status-protocol-go/v1" ) -const ( - uniqueIDContstraint = "UNIQUE constraint failed: user_messages.id" -) - var ( // ErrMsgAlreadyExist returned if msg already exist. ErrMsgAlreadyExist = errors.New("message with given ID already exist") @@ -117,62 +113,61 @@ func (db sqlitePersistence) Chats() ([]*Chat, error) { return db.chats(nil) } -func (db sqlitePersistence) chats(tx *sql.Tx) ([]*Chat, error) { - var err error - +func (db sqlitePersistence) chats(tx *sql.Tx) (chats []*Chat, err error) { if tx == nil { tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { - return nil, err + return } defer func() { if err == nil { err = tx.Commit() return - } // don't shadow original error _ = tx.Rollback() }() } - rows, err := tx.Query(`SELECT - id, - name, - color, - active, - type, - timestamp, - deleted_at_clock_value, - public_key, - unviewed_message_count, - last_clock_value, - last_message_content_type, - last_message_content, - last_message_timestamp, - last_message_clock_value, - members, - membership_updates - FROM chats - ORDER BY chats.timestamp DESC`) + rows, err := tx.Query(` + SELECT + id, + name, + color, + active, + type, + timestamp, + deleted_at_clock_value, + public_key, + unviewed_message_count, + last_clock_value, + last_message_content_type, + last_message_content, + last_message_timestamp, + last_message_clock_value, + members, + membership_updates + FROM chats + ORDER BY chats.timestamp DESC + `) if err != nil { - return nil, err + return } defer rows.Close() - var response []*Chat - for rows.Next() { - var lastMessageContentType sql.NullString - var lastMessageContent sql.NullString - var lastMessageTimestamp sql.NullInt64 - var lastMessageClockValue sql.NullInt64 + var ( + lastMessageContentType sql.NullString + lastMessageContent sql.NullString + lastMessageTimestamp sql.NullInt64 + lastMessageClockValue sql.NullInt64 - chat := &Chat{} - encodedMembers := []byte{} - encodedMembershipUpdates := []byte{} - pkey := []byte{} - err := rows.Scan( + chat Chat + encodedMembers []byte + encodedMembershipUpdates []byte + pkey []byte + ) + err = rows.Scan( &chat.ID, &chat.Name, &chat.Color, @@ -191,7 +186,7 @@ func (db sqlitePersistence) chats(tx *sql.Tx) ([]*Chat, error) { &encodedMembershipUpdates, ) if err != nil { - return nil, err + return } chat.LastMessageContent = lastMessageContent.String chat.LastMessageContentType = lastMessageContentType.String @@ -200,44 +195,47 @@ func (db sqlitePersistence) chats(tx *sql.Tx) ([]*Chat, error) { // Restore members membersDecoder := gob.NewDecoder(bytes.NewBuffer(encodedMembers)) - if err := membersDecoder.Decode(&chat.Members); err != nil { - return nil, err + err = membersDecoder.Decode(&chat.Members) + if err != nil { + return } // Restore membership updates membershipUpdatesDecoder := gob.NewDecoder(bytes.NewBuffer(encodedMembershipUpdates)) - if err := membershipUpdatesDecoder.Decode(&chat.MembershipUpdates); err != nil { - return nil, err + err = membershipUpdatesDecoder.Decode(&chat.MembershipUpdates) + if err != nil { + return } if len(pkey) != 0 { chat.PublicKey, err = crypto.UnmarshalPubkey(pkey) if err != nil { - return nil, err + return } } - response = append(response, chat) + chats = append(chats, &chat) } - return response, nil + return } func (db sqlitePersistence) Contacts() ([]*Contact, error) { - rows, err := db.db.Query(`SELECT - id, - address, - name, - alias, - identicon, - photo, - last_updated, - system_tags, - device_info, - ens_verified, - ens_verified_at, - tribute_to_talk - FROM contacts`) - + rows, err := db.db.Query(` + SELECT + id, + address, + name, + alias, + identicon, + photo, + last_updated, + system_tags, + device_info, + ens_verified, + ens_verified_at, + tribute_to_talk + FROM contacts + `) if err != nil { return nil, err } @@ -246,9 +244,11 @@ func (db sqlitePersistence) Contacts() ([]*Contact, error) { var response []*Contact for rows.Next() { - contact := &Contact{} - encodedDeviceInfo := []byte{} - encodedSystemTags := []byte{} + var ( + contact Contact + encodedDeviceInfo []byte + encodedSystemTags []byte + ) err := rows.Scan( &contact.ID, &contact.Address, @@ -283,7 +283,7 @@ func (db sqlitePersistence) Contacts() ([]*Contact, error) { } } - response = append(response, contact) + response = append(response, &contact) } return response, nil @@ -298,7 +298,6 @@ func (db sqlitePersistence) SetContactsENSData(contacts []Contact) error { if err == nil { err = tx.Commit() return - } // don't shadow original error _ = tx.Rollback() @@ -324,8 +323,7 @@ func (db sqlitePersistence) SetContactsENSData(contacts []Contact) error { // SetContactsGeneratedData sets a contact generated data if not existing already // in the database -func (db sqlitePersistence) SetContactsGeneratedData(contacts []Contact, tx *sql.Tx) error { - var err error +func (db sqlitePersistence) SetContactsGeneratedData(contacts []Contact, tx *sql.Tx) (err error) { if tx == nil { tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { @@ -343,36 +341,35 @@ func (db sqlitePersistence) SetContactsGeneratedData(contacts []Contact, tx *sql } for _, contact := range contacts { - _, err := tx.Exec(`INSERT OR IGNORE INTO contacts( - id, - address, - name, - alias, - identicon, - photo, - last_updated, - tribute_to_talk) - VALUES (?, ?, "", ?, ?, "", 0, "")`, + _, err = tx.Exec(` + INSERT OR IGNORE INTO contacts( + id, + address, + name, + alias, + identicon, + photo, + last_updated, + tribute_to_talk + ) VALUES (?, ?, "", ?, ?, "", 0, "")`, contact.ID, contact.Address, contact.Alias, contact.Identicon, ) if err != nil { - return err + return } } - return nil + return } -func (db sqlitePersistence) SaveContact(contact Contact, tx *sql.Tx) error { - var err error - +func (db sqlitePersistence) SaveContact(contact Contact, tx *sql.Tx) (err error) { if tx == nil { tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { - return err + return } defer func() { if err == nil { @@ -387,37 +384,38 @@ func (db sqlitePersistence) SaveContact(contact Contact, tx *sql.Tx) error { // Encode device info var encodedDeviceInfo bytes.Buffer deviceInfoEncoder := gob.NewEncoder(&encodedDeviceInfo) - - if err := deviceInfoEncoder.Encode(contact.DeviceInfo); err != nil { - return err + err = deviceInfoEncoder.Encode(contact.DeviceInfo) + if err != nil { + return } // Encoded system tags var encodedSystemTags bytes.Buffer systemTagsEncoder := gob.NewEncoder(&encodedSystemTags) - - if err := systemTagsEncoder.Encode(contact.SystemTags); err != nil { - return err + err = systemTagsEncoder.Encode(contact.SystemTags) + if err != nil { + return } // Insert record - stmt, err := tx.Prepare(`INSERT INTO contacts( - id, - address, - name, - alias, - identicon, - photo, - last_updated, - system_tags, - device_info, - ens_verified, - ens_verified_at, - tribute_to_talk - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + stmt, err := tx.Prepare(` + INSERT INTO contacts( + id, + address, + name, + alias, + identicon, + photo, + last_updated, + system_tags, + device_info, + ens_verified, + ens_verified_at, + tribute_to_talk + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `) if err != nil { - return err + return } defer stmt.Close() @@ -435,12 +433,13 @@ func (db sqlitePersistence) SaveContact(contact Contact, tx *sql.Tx) error { contact.ENSVerifiedAt, contact.TributeToTalk, ) - return err + return } // Messages returns messages for a given contact, in a given period. Ordered by a timestamp. func (db sqlitePersistence) Messages(from, to time.Time) (result []*protocol.Message, err error) { - rows, err := db.db.Query(`SELECT + rows, err := db.db.Query(` + SELECT id, chat_id, content_type, @@ -452,47 +451,42 @@ func (db sqlitePersistence) Messages(from, to time.Time) (result []*protocol.Mes content_text, public_key, flags - FROM user_messages + FROM user_messages WHERE timestamp >= ? AND timestamp <= ? ORDER BY timestamp`, protocol.TimestampInMsFromTime(from), protocol.TimestampInMsFromTime(to), ) if err != nil { - return nil, err + return } defer rows.Close() - var rst []*protocol.Message for rows.Next() { msg := protocol.Message{ Content: protocol.Content{}, } - pkey := []byte{} + var pkey []byte err = rows.Scan( &msg.ID, &msg.ChatID, &msg.ContentT, &msg.MessageT, &msg.Text, &msg.Clock, &msg.Timestamp, &msg.Content.ChatID, &msg.Content.Text, &pkey, &msg.Flags, ) if err != nil { - return nil, err + return } if len(pkey) != 0 { msg.SigPubKey, err = crypto.UnmarshalPubkey(pkey) if err != nil { - return nil, err + return } } - rst = append(rst, &msg) + result = append(result, &msg) } - return rst, nil + return } func (db sqlitePersistence) SaveMessages(messages []*protocol.Message) (last int64, err error) { - var ( - tx *sql.Tx - stmt *sql.Stmt - ) - tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) + tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return } @@ -505,8 +499,8 @@ func (db sqlitePersistence) SaveMessages(messages []*protocol.Message) (last int _ = tx.Rollback() }() - stmt, err = tx.Prepare(`INSERT INTO - user_messages( + stmt, err := tx.Prepare(` + INSERT OR IGNORE INTO user_messages( id, chat_id, content_type, @@ -536,11 +530,6 @@ func (db sqlitePersistence) SaveMessages(messages []*protocol.Message) (last int msg.Content.ChatID, msg.Content.Text, pkey, msg.Flags, ) if err != nil { - if err.Error() == uniqueIDContstraint { - // skip duplicated messages - err = nil - continue - } return } @@ -549,5 +538,6 @@ func (db sqlitePersistence) SaveMessages(messages []*protocol.Message) (last int return } } + return } diff --git a/persistence_legacy.go b/persistence_legacy.go index a8c041b..b7a8642 100644 --- a/persistence_legacy.go +++ b/persistence_legacy.go @@ -267,21 +267,15 @@ func (db sqlitePersistence) MessageByChatID(chatID string, currCursor string, li return result, newCursor, nil } -func (db sqlitePersistence) SaveMessagesLegacy(messages []*Message) error { - var ( - tx *sql.Tx - stmt *sql.Stmt - err error - ) - tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) +func (db sqlitePersistence) SaveMessagesLegacy(messages []*Message) (err error) { + tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { - return err + return } defer func() { if err == nil { err = tx.Commit() return - } // don't shadow original error _ = tx.Rollback() @@ -289,21 +283,19 @@ func (db sqlitePersistence) SaveMessagesLegacy(messages []*Message) error { allFields := db.tableUserMessagesLegacyAllFields() valuesVector := strings.Repeat("?, ", db.tableUserMessagesLegacyAllFieldsCount()-1) + "?" - query := fmt.Sprintf(`INSERT INTO user_messages_legacy(%s) VALUES (%s)`, allFields, valuesVector) - - stmt, err = tx.Prepare(query) + stmt, err := tx.Prepare(query) if err != nil { - return err + return } for _, msg := range messages { - _, err := stmt.Exec(db.tableUserMessagesLegacyAllValues(msg)...) + _, err = stmt.Exec(db.tableUserMessagesLegacyAllValues(msg)...) if err != nil { - return err + return } } - return err + return } func (db sqlitePersistence) DeleteMessage(id string) error { @@ -343,20 +335,15 @@ func (db sqlitePersistence) UpdateMessageOutgoingStatus(id string, newOutgoingSt } // BlockContact updates a contact, deletes all the messages and 1-to-1 chat, updates the unread messages count and returns a map with the new count -func (db sqlitePersistence) BlockContact(contact Contact) ([]*Chat, error) { - var ( - tx *sql.Tx - err error - ) - tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) +func (db sqlitePersistence) BlockContact(contact Contact) (chats []*Chat, err error) { + tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { - return nil, err + return } defer func() { if err == nil { err = tx.Commit() return - } // don't shadow original error _ = tx.Rollback() @@ -370,34 +357,36 @@ func (db sqlitePersistence) BlockContact(contact Contact) ([]*Chat, error) { contact.ID, ) if err != nil { - return nil, err + return } // Update contact err = db.SaveContact(contact, tx) if err != nil { - return nil, err + return } // Delete one-to-one chat _, err = tx.Exec("DELETE FROM chats WHERE id = ?", contact.ID) if err != nil { - return nil, err + return } // Recalculate denormalized fields _, err = tx.Exec(` - UPDATE chats - SET - unviewed_message_count = (SELECT COUNT(1) FROM user_messages_legacy WHERE seen = 0 AND chat_id = chats.id), - last_message_content = (SELECT content from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1), - last_message_timestamp = (SELECT timestamp from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1), - last_message_clock_value = (SELECT clock_value from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1), - last_message_content_type = (SELECT content_type from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1)`) + UPDATE chats + SET + unviewed_message_count = (SELECT COUNT(1) FROM user_messages_legacy WHERE seen = 0 AND chat_id = chats.id), + last_message_content = (SELECT content from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1), + last_message_timestamp = (SELECT timestamp from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1), + last_message_clock_value = (SELECT clock_value from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1), + last_message_content_type = (SELECT content_type from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1) + `) if err != nil { - return nil, err + return } // return the updated chats - return db.chats(tx) + chats, err = db.chats(tx) + return }