package statusproto

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"database/sql"
	"encoding/gob"
	"encoding/hex"
	"time"

	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/crypto/secp256k1"
	"github.com/pkg/errors"

	protocol "github.com/status-im/status-protocol-go/v1"
)

const (
	// uniqueIDContstraint is the error text that SaveMessages compares against
	// to recognise an insert that collided with an existing message ID.
	uniqueIDContstraint = "UNIQUE constraint failed: user_messages.id"
)

var (
	// ErrMsgAlreadyExist returned if msg already exist.
	ErrMsgAlreadyExist = errors.New("message with given ID already exist")
)

// sqlitePersistence wrapper around sql db with operations common for a client.
type sqlitePersistence struct {
	db *sql.DB
}

// LastMessageClock returns the largest clock value stored in user_messages for
// the given chat. An empty chatID is rejected with an error.
func (db sqlitePersistence) LastMessageClock(chatID string) (int64, error) {
	if chatID == "" {
		return 0, errors.New("chat ID is empty")
	}

	// A NULL aggregate result leaves last invalid, so last.Int64 is 0.
	var last sql.NullInt64
	err := db.db.QueryRow("SELECT max(clock) FROM user_messages WHERE chat_id = ?", chatID).Scan(&last)
	if err != nil {
		return 0, err
	}
	return last.Int64, nil
}

// SaveChat inserts chat into the chats table. Members and membership updates
// are stored gob-encoded. For one-to-one chats the public key is derived from
// the chat ID (hex, with the first two characters skipped) and validated
// before anything is written.
func (db sqlitePersistence) SaveChat(chat Chat) error {
	var err error

	// Kept as an empty, non-nil slice so chats without a key bind the same
	// value as before.
	pkey := []byte{}

	// For one to one chatID is an encoded public key
	if chat.ChatType == ChatTypeOneToOne {
		// Slicing a shorter ID would panic; report it as an error instead.
		if len(chat.ID) < 2 {
			return errors.New("one-to-one chat ID is too short")
		}
		pkey, err = hex.DecodeString(chat.ID[2:])
		if err != nil {
			return err
		}
		// Safety check, make sure is well formed
		if _, err = crypto.UnmarshalPubkey(pkey); err != nil {
			return err
		}
	}

	// Encode members
	var encodedMembers bytes.Buffer
	if err = gob.NewEncoder(&encodedMembers).Encode(chat.Members); err != nil {
		return err
	}

	// Encode membership updates
	var encodedMembershipUpdates bytes.Buffer
	if err = gob.NewEncoder(&encodedMembershipUpdates).Encode(chat.MembershipUpdates); err != nil {
		return err
	}

	// Insert record
	stmt, err := db.db.Prepare(`INSERT INTO chats(id, name, color, active, type, timestamp, deleted_at_clock_value, public_key, unviewed_message_count, last_clock_value, last_message_content_type, last_message_content, last_message_timestamp, members, membership_updates)
	    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	_, err = stmt.Exec(
		chat.ID,
		chat.Name,
		chat.Color,
		chat.Active,
		chat.ChatType,
		chat.Timestamp,
		chat.DeletedAtClockValue,
		pkey,
		chat.UnviewedMessagesCount,
		chat.LastClockValue,
		chat.LastMessageContentType,
		chat.LastMessageContent,
		chat.LastMessageTimestamp,
		encodedMembers.Bytes(),
		encodedMembershipUpdates.Bytes(),
	)
	return err
}

// DeleteChat removes the chat with the given ID from the chats table.
func (db sqlitePersistence) DeleteChat(chatID string) error {
	_, err := db.db.Exec("DELETE FROM chats WHERE id = ?", chatID)
	return err
}

// Chats returns all stored chats, most recent timestamp first.
func (db sqlitePersistence) Chats() ([]*Chat, error) {
	return db.chats(nil)
}

// chats loads every row of the chats table ordered by timestamp descending.
// When tx is nil a transaction is opened for the duration of the call and
// committed on success or rolled back on failure; a caller-supplied tx is
// neither committed nor rolled back here.
//
// The results are named so that the deferred commit can report its own
// failure to the caller.
func (db sqlitePersistence) chats(tx *sql.Tx) (response []*Chat, err error) {
	if tx == nil {
		tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{})
		if err != nil {
			return nil, err
		}
		defer func() {
			if err == nil {
				if err = tx.Commit(); err != nil {
					response = nil
				}
				return
			}
			// don't shadow original error
			_ = tx.Rollback()
		}()
	}

	rows, err := tx.Query(`SELECT id, name, color, active, type, timestamp, deleted_at_clock_value, public_key, unviewed_message_count, last_clock_value, last_message_content_type, last_message_content, last_message_timestamp, members, membership_updates
	    FROM chats
	    ORDER BY chats.timestamp DESC`)
	if err != nil {
		return nil, err
	}
	// Deferred after the commit/rollback handler, so rows are closed first.
	defer rows.Close()

	for rows.Next() {
		var (
			lastMessageContentType   sql.NullString
			lastMessageContent       sql.NullString
			lastMessageTimestamp     sql.NullInt64
			encodedMembers           []byte
			encodedMembershipUpdates []byte
			pkey                     []byte
		)
		chat := &Chat{}

		err = rows.Scan(
			&chat.ID,
			&chat.Name,
			&chat.Color,
			&chat.Active,
			&chat.ChatType,
			&chat.Timestamp,
			&chat.DeletedAtClockValue,
			&pkey,
			&chat.UnviewedMessagesCount,
			&chat.LastClockValue,
			&lastMessageContentType,
			&lastMessageContent,
			&lastMessageTimestamp,
			&encodedMembers,
			&encodedMembershipUpdates,
		)
		if err != nil {
			return nil, err
		}

		// NULL columns collapse to the zero value of the field.
		chat.LastMessageContent = lastMessageContent.String
		chat.LastMessageContentType = lastMessageContentType.String
		chat.LastMessageTimestamp = lastMessageTimestamp.Int64

		// Restore members
		if err = gob.NewDecoder(bytes.NewBuffer(encodedMembers)).Decode(&chat.Members); err != nil {
			return nil, err
		}

		// Restore membership updates
		if err = gob.NewDecoder(bytes.NewBuffer(encodedMembershipUpdates)).Decode(&chat.MembershipUpdates); err != nil {
			return nil, err
		}

		if len(pkey) != 0 {
			chat.PublicKey, err = crypto.UnmarshalPubkey(pkey)
			if err != nil {
				return nil, err
			}
		}

		response = append(response, chat)
	}

	// Next returning false may mean an iteration error rather than end of data.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	return response, nil
}

// Contacts returns every row of the contacts table, decoding the gob-encoded
// device info and system tags columns.
func (db sqlitePersistence) Contacts() ([]*Contact, error) {
	rows, err := db.db.Query(`SELECT id, address, name, photo, last_updated, system_tags, device_info, tribute_to_talk
	    FROM contacts`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var response []*Contact
	for rows.Next() {
		var (
			encodedDeviceInfo []byte
			encodedSystemTags []byte
		)
		contact := &Contact{}

		err = rows.Scan(
			&contact.ID,
			&contact.Address,
			&contact.Name,
			&contact.Photo,
			&contact.LastUpdated,
			&encodedSystemTags,
			&encodedDeviceInfo,
			&contact.TributeToTalk,
		)
		if err != nil {
			return nil, err
		}

		// Restore device info
		if err = gob.NewDecoder(bytes.NewBuffer(encodedDeviceInfo)).Decode(&contact.DeviceInfo); err != nil {
			return nil, err
		}

		// Restore system tags
		if err = gob.NewDecoder(bytes.NewBuffer(encodedSystemTags)).Decode(&contact.SystemTags); err != nil {
			return nil, err
		}

		response = append(response, contact)
	}

	// Next returning false may mean an iteration error rather than end of data.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	return response, nil
}

// SaveContact inserts contact into the contacts table, gob-encoding its device
// info and system tags. When tx is nil a transaction is opened for the call
// and committed on success or rolled back on failure; a caller-supplied tx is
// left for the caller to finish.
//
// The result is named so that a failed commit is reported to the caller.
func (db sqlitePersistence) SaveContact(contact Contact, tx *sql.Tx) (err error) {
	if tx == nil {
		tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{})
		if err != nil {
			return err
		}
		defer func() {
			if err == nil {
				err = tx.Commit()
				return
			}
			// don't shadow original error
			_ = tx.Rollback()
		}()
	}

	// Encode device info
	var encodedDeviceInfo bytes.Buffer
	if err = gob.NewEncoder(&encodedDeviceInfo).Encode(contact.DeviceInfo); err != nil {
		return err
	}

	// Encoded system tags
	var encodedSystemTags bytes.Buffer
	if err = gob.NewEncoder(&encodedSystemTags).Encode(contact.SystemTags); err != nil {
		return err
	}

	// Insert record
	stmt, err := tx.Prepare(`INSERT INTO contacts(
	      id,
	      address,
	      name,
	      photo,
	      last_updated,
	      system_tags,
	      device_info,
	      tribute_to_talk
	    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`)
	if err != nil {
		return err
	}
	// Deferred after the commit/rollback handler, so it is closed first.
	defer stmt.Close()

	_, err = stmt.Exec(
		contact.ID,
		contact.Address,
		contact.Name,
		contact.Photo,
		contact.LastUpdated,
		encodedSystemTags.Bytes(),
		encodedDeviceInfo.Bytes(),
		contact.TributeToTalk,
	)
	return err
}

// Messages returns the messages whose timestamp lies in the inclusive range
// [from, to], ordered by timestamp. No chat or contact filter is applied.
func (db sqlitePersistence) Messages(from, to time.Time) (result []*protocol.Message, err error) {
	rows, err := db.db.Query(`SELECT
	      id,
	      chat_id,
	      content_type,
	      message_type,
	      text,
	      clock,
	      timestamp,
	      content_chat_id,
	      content_text,
	      public_key,
	      flags
	    FROM user_messages
	    WHERE timestamp >= ? AND timestamp <= ?
	    ORDER BY timestamp`,
		protocol.TimestampInMsFromTime(from),
		protocol.TimestampInMsFromTime(to),
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		msg := protocol.Message{
			Content: protocol.Content{},
		}
		var pkey []byte

		err = rows.Scan(
			&msg.ID,
			&msg.ChatID,
			&msg.ContentT,
			&msg.MessageT,
			&msg.Text,
			&msg.Clock,
			&msg.Timestamp,
			&msg.Content.ChatID,
			&msg.Content.Text,
			&pkey,
			&msg.Flags,
		)
		if err != nil {
			return nil, err
		}

		// An empty column means the message was stored without a signer key.
		if len(pkey) != 0 {
			msg.SigPubKey, err = unmarshalECDSAPub(pkey)
			if err != nil {
				return nil, err
			}
		}

		result = append(result, &msg)
	}

	// Next returning false may mean an iteration error rather than end of data.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	return result, nil
}

// SaveMessages inserts messages into user_messages inside one transaction and
// returns the LastInsertId of the last successful insert. Messages whose ID
// already exists are skipped. Any other failure, including a signer key that
// cannot be serialised, aborts the call and rolls the transaction back.
func (db sqlitePersistence) SaveMessages(messages []*protocol.Message) (last int64, err error) {
	var (
		tx   *sql.Tx
		stmt *sql.Stmt
	)

	tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return last, err
	}
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	stmt, err = tx.Prepare(`INSERT INTO user_messages(
	      id,
	      chat_id,
	      content_type,
	      message_type,
	      text,
	      clock,
	      timestamp,
	      content_chat_id,
	      content_text,
	      public_key,
	      flags
	    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
	if err != nil {
		return last, err
	}
	// Deferred after the commit/rollback handler, so it is closed first.
	defer stmt.Close()

	var rst sql.Result
	for _, msg := range messages {
		// Kept as an empty, non-nil slice so unsigned messages bind the same
		// value as before.
		pkey := []byte{}
		if msg.SigPubKey != nil {
			pkey, err = marshalECDSAPub(msg.SigPubKey)
			if err != nil {
				// Previously overwritten by the Exec below and lost.
				return last, err
			}
		}

		rst, err = stmt.Exec(
			msg.ID,
			msg.ChatID,
			msg.ContentT,
			msg.MessageT,
			msg.Text,
			msg.Clock,
			msg.Timestamp,
			msg.Content.ChatID,
			msg.Content.Text,
			pkey,
			msg.Flags,
		)
		if err != nil {
			if err.Error() == uniqueIDContstraint {
				// skip duplicated messages
				err = nil
				continue
			}
			return last, err
		}

		last, err = rst.LastInsertId()
		if err != nil {
			return last, err
		}
	}
	return last, err
}

// marshalECDSAPub serialises pub as a single tag byte (1 for secp256k1)
// followed by the compressed key, 34 bytes in total. A nil key or any other
// curve is an error.
func marshalECDSAPub(pub *ecdsa.PublicKey) (rst []byte, err error) {
	if pub == nil {
		return nil, errors.New("public key is nil")
	}
	switch pub.Curve.(type) {
	case *secp256k1.BitCurve:
		rst = make([]byte, 34)
		rst[0] = 1
		copy(rst[1:], secp256k1.CompressPubkey(pub.X, pub.Y))
		return rst, nil
	default:
		return nil, errors.New("unknown curve")
	}
}

// unmarshalECDSAPub is the inverse of marshalECDSAPub: it reads the tag byte,
// decompresses the remainder and checks that the point is on the curve.
func unmarshalECDSAPub(buf []byte) (*ecdsa.PublicKey, error) {
	pub := &ecdsa.PublicKey{}
	if len(buf) < 1 {
		return nil, errors.New("too small")
	}
	switch buf[0] {
	case 1:
		pub.Curve = secp256k1.S256()
		pub.X, pub.Y = secp256k1.DecompressPubkey(buf[1:])
		// DecompressPubkey yields nil coordinates for malformed input; passing
		// those to IsOnCurve would dereference nil.
		if pub.X == nil || pub.Y == nil {
			return nil, errors.New("invalid public key")
		}
		if !pub.IsOnCurve(pub.X, pub.Y) {
			return nil, errors.New("not on curve")
		}
		return pub, nil
	default:
		return nil, errors.New("unknown curve")
	}
}