diff --git a/go.mod b/go.mod index 86c4ce960..0685b16f2 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/status-im/doubleratchet v2.0.0+incompatible github.com/status-im/migrate/v4 v4.3.1-status github.com/status-im/rendezvous v1.3.0 - github.com/status-im/status-protocol-go v0.0.0-20190701094942-c2b7b022b722d7bebe1c6d6f05cdead79f1b57bd + github.com/status-im/status-protocol-go v0.0.0-20190701094942-9f0db157bf5a1ac9b52c47bbea37fc7dbe14e8fe github.com/status-im/whisper v1.4.14 github.com/stretchr/testify v1.3.0 github.com/syndtr/goleveldb v1.0.0 diff --git a/go.sum b/go.sum index 01b582953..7ad9f12d5 100644 --- a/go.sum +++ b/go.sum @@ -450,8 +450,8 @@ github.com/status-im/migrate/v4 v4.3.1-status h1:tJwsEYLgbFkvlTSMk89APwRDfpr4yG8 github.com/status-im/migrate/v4 v4.3.1-status/go.mod h1:r8HggRBZ/k7TRwByq/Hp3P/ubFppIna0nvyavVK0pjA= github.com/status-im/rendezvous v1.3.0 h1:7RK/MXXW+tlm0asKm1u7Qp7Yni6AO29a7j8+E4Lbjg4= github.com/status-im/rendezvous v1.3.0/go.mod h1:+hzjuP+j/XzLPeF6E50b88pWOTLdTcwjvNYt+Gh1W1s= -github.com/status-im/status-protocol-go v0.0.0-20190701094942-c2b7b022b722d7bebe1c6d6f05cdead79f1b57bd h1:ZGCzGQ41kPy5oNpHColf3ZTNN9DXWZATgJoV2cQZaC4= -github.com/status-im/status-protocol-go v0.0.0-20190701094942-c2b7b022b722d7bebe1c6d6f05cdead79f1b57bd/go.mod h1:thrQ4V0ZUmLZPDf74xVzub1gxgSNFaSTeTQdxtRJnTU= +github.com/status-im/status-protocol-go v0.0.0-20190701094942-9f0db157bf5a1ac9b52c47bbea37fc7dbe14e8fe h1:QqpJe7fgZk8nKWfjUiYi9SCoN3Ozveyn60b8xWq3rik= +github.com/status-im/status-protocol-go v0.0.0-20190701094942-9f0db157bf5a1ac9b52c47bbea37fc7dbe14e8fe/go.mod h1:thrQ4V0ZUmLZPDf74xVzub1gxgSNFaSTeTQdxtRJnTU= github.com/status-im/whisper v1.4.14 h1:9VHqx4+PUYfhDnYYtDxHkg/3cfVvkHjPNciY4LO83yc= github.com/status-im/whisper v1.4.14/go.mod h1:WS6z39YJQ8WJa9s+DmTuEM/s2nVF6Iz3B1SZYw5cYf0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/services/shhext/api.go b/services/shhext/api.go index c0767de98..1a756a445 100644 --- a/services/shhext/api.go +++ b/services/shhext/api.go @@ -9,14 +9,13 @@ import ( "math/big" "time" - "github.com/status-im/status-go/services/shhext/dedup" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" + "github.com/status-im/status-go/services/shhext/dedup" "github.com/status-im/status-go/db" "github.com/status-im/status-go/mailserver" @@ -400,46 +399,24 @@ func (api *PublicAPI) SyncMessages(ctx context.Context, r SyncMessagesRequest) ( } } -// GetNewFilterMessages is a prototype method with deduplication. -func (api *PublicAPI) GetNewFilterMessages(filterID string) ([]dedup.DeduplicateMessage, error) { - messages, err := api.publicAPI.GetFilterMessages(filterID) - if err != nil { - return nil, err - } - return api.service.deduplicator.Deduplicate(messages), nil -} - -// ConfirmMessagesProcessed is a method to confirm that messages was consumed by -// the client side. -func (api *PublicAPI) ConfirmMessagesProcessed(messages []*whisper.Message) (err error) { - tx := api.service.storage.NewTx() - defer func() { - if err == nil { - err = tx.Commit() - } - }() - ctx := NewContextFromService(context.Background(), api.service, tx) - for _, msg := range messages { - if msg.P2P { - err = api.service.historyUpdates.UpdateTopicHistory(ctx, msg.Topic, time.Unix(int64(msg.Timestamp), 0)) - if err != nil { - return err - } - } - } - err = api.service.deduplicator.AddMessages(messages) - return err -} - // ConfirmMessagesProcessedByID is a method to confirm that messages was consumed by // the client side. // TODO: this is broken now as it requires dedup ID while a message hash should be used. -func (api *PublicAPI) ConfirmMessagesProcessedByID(messageIDs [][]byte) error { - /*if err := api.service.ConfirmMessagesProcessed(messageIDs); err != nil { - return err - }*/ +func (api *PublicAPI) ConfirmMessagesProcessedByID(messageConfirmations []*dedup.Metadata) error { + confirmationCount := len(messageConfirmations) + dedupIDs := make([][]byte, confirmationCount) + encryptionIDs := make([][]byte, confirmationCount) - return api.service.deduplicator.AddMessageByID(messageIDs) + for i, confirmation := range messageConfirmations { + dedupIDs[i] = confirmation.DedupID + encryptionIDs[i] = confirmation.EncryptionID + } + + if err := api.service.ConfirmMessagesProcessed(encryptionIDs); err != nil { + return err + } + + return api.service.deduplicator.AddMessageByID(dedupIDs) } // Post is used to send one-to-one for those who did not enabled device-to-device sync, @@ -590,6 +567,7 @@ func (api *PublicAPI) LoadFilters(parent context.Context, chats []*statustransp. } func (api *PublicAPI) SaveChat(parent context.Context, chat statusproto.Chat) error { + api.log.Info("saving chat", "chat", chat) return api.service.messenger.SaveChat(chat) } @@ -597,14 +575,19 @@ func (api *PublicAPI) Chats(parent context.Context, to, from int) ([]*statusprot return api.service.messenger.Chats(to, from) } -func (api *PublicAPI) DeleteChat(parent context.Context, chatID string, chatType statusproto.ChatType) error { - return api.service.messenger.DeleteChat(chatID, chatType) +func (api *PublicAPI) DeleteChat(parent context.Context, chatID string) error { + return api.service.messenger.DeleteChat(chatID) } func (api *PublicAPI) SaveContact(parent context.Context, contact statusproto.Contact) error { return api.service.messenger.SaveContact(contact) } +func (api *PublicAPI) BlockContact(parent context.Context, contact statusproto.Contact) ([]*statusproto.Chat, error) { + api.log.Info("blocking contact", "contact", contact.ID) + return api.service.messenger.BlockContact(contact) +} + func (api *PublicAPI) Contacts(parent context.Context) ([]*statusproto.Contact, error) { return api.service.messenger.Contacts() } @@ -633,23 +616,36 @@ func (api *PublicAPI) SetInstallationMetadata(installationID string, data *multi return api.service.messenger.SetInstallationMetadata(installationID, data) } -func (api *PublicAPI) MessageByChatID(chatID, cursor string, limit int) ([]*statusproto.Message, string, error) { - return api.service.messenger.MessageByChatID(chatID, cursor, limit) +type ApplicationMessagesResponse struct { + Messages []*statusproto.Message `json:"messages"` + Cursor string `json:"cursor"` } -func (api *PublicAPI) MessagesFrom(from string) ([]*statusproto.Message, error) { - return api.service.messenger.MessagesFrom(from) +func (api *PublicAPI) ChatMessages(chatID, cursor string, limit int) (*ApplicationMessagesResponse, error) { + messages, cursor, err := api.service.messenger.MessageByChatID(chatID, cursor, limit) + if err != nil { + return nil, err + } + + return &ApplicationMessagesResponse{ + Messages: messages, + Cursor: cursor, + }, nil } -func (api *PublicAPI) SaveMessage(message *statusproto.Message) error { - return api.service.messenger.SaveMessage(message) +func (api *PublicAPI) SaveMessages(messages []*statusproto.Message) error { + return api.service.messenger.SaveMessages(messages) } func (api *PublicAPI) DeleteMessage(id string) error { return api.service.messenger.DeleteMessage(id) } -func (api *PublicAPI) MarkMessagesSeen(ids ...string) error { +func (api *PublicAPI) DeleteMessagesByChatID(id string) error { + return api.service.messenger.DeleteMessagesByChatID(id) +} + +func (api *PublicAPI) MarkMessagesSeen(ids []string) error { return api.service.messenger.MarkMessagesSeen(ids...) } diff --git a/services/shhext/dedup/deduplicator.go b/services/shhext/dedup/deduplicator.go index e1a7560b2..67bdc4182 100644 --- a/services/shhext/dedup/deduplicator.go +++ b/services/shhext/dedup/deduplicator.go @@ -2,6 +2,8 @@ package dedup import ( "github.com/ethereum/go-ethereum/log" + + "github.com/ethereum/go-ethereum/common/hexutil" whisper "github.com/status-im/whisper/whisperv6" "github.com/syndtr/goleveldb/leveldb" ) @@ -18,9 +20,15 @@ type Deduplicator struct { log log.Logger } +type Metadata struct { + DedupID []byte `json:"dedupId"` + EncryptionID hexutil.Bytes `json:"encryptionId"` + MessageID hexutil.Bytes `json:"messageId"` +} + type DeduplicateMessage struct { - DedupID []byte `json:"id"` - Message *whisper.Message `json:"message"` + Message *whisper.Message `json:"message"` + Metadata Metadata `json:"metadata"` } // NewDeduplicator creates a new deduplicator @@ -35,19 +43,17 @@ func NewDeduplicator(keyPairProvider keyPairProvider, db *leveldb.DB) *Deduplica // Deduplicate receives a list of whisper messages and // returns the list of the messages that weren't filtered previously for the // specified filter. -func (d *Deduplicator) Deduplicate(messages []*whisper.Message) []DeduplicateMessage { - result := make([]DeduplicateMessage, 0) +func (d *Deduplicator) Deduplicate(messages []*DeduplicateMessage) []*DeduplicateMessage { + result := make([]*DeduplicateMessage, 0) selectedKeyPairID := d.keyPairProvider.SelectedKeyPairID() for _, message := range messages { - if has, err := d.cache.Has(selectedKeyPairID, message); !has { + if has, err := d.cache.Has(selectedKeyPairID, message.Message); !has { if err != nil { d.log.Error("error while deduplicating messages: search cache failed", "err", err) } - result = append(result, DeduplicateMessage{ - DedupID: d.cache.KeyToday(selectedKeyPairID, message), - Message: message, - }) + message.Metadata.DedupID = d.cache.KeyToday(selectedKeyPairID, message.Message) + result = append(result, message) } } @@ -56,8 +62,8 @@ func (d *Deduplicator) Deduplicate(messages []*whisper.Message) []DeduplicateMes // AddMessages adds a message to the deduplicator DB, so it will be filtered // out. -func (d *Deduplicator) AddMessages(messages []*whisper.Message) error { - return d.cache.Put(d.keyPairProvider.SelectedKeyPairID(), messages) +func (d *Deduplicator) AddMessagesByID(messageIDs [][]byte) error { + return d.cache.PutIDs(messageIDs) } // AddMessageByID adds a message to the deduplicator DB, so it will be filtered diff --git a/services/shhext/dedup/deduplicator_test.go b/services/shhext/dedup/deduplicator_test.go index 89bfc823a..ebc8a0a53 100644 --- a/services/shhext/dedup/deduplicator_test.go +++ b/services/shhext/dedup/deduplicator_test.go @@ -43,7 +43,7 @@ func BenchmarkDeduplicate30000MessagesADay(b *testing.B) { d := NewDeduplicator(dummyKeyPairProvider{}, db) b.Log("generating messages") - messagesOld := generateMessages(100000) + messagesOld := generateDedupMessages(100000) b.Log("generation is done") // pre-fill deduplicator @@ -62,8 +62,12 @@ func BenchmarkDeduplicate30000MessagesADay(b *testing.B) { } messages := messagesOld[start:(start + length)] start += length - d.Deduplicate(messages) - assert.NoError(b, d.AddMessages(messages)) + dedupMessages := d.Deduplicate(messages) + ids := make([][]byte, len(dedupMessages)) + for i, m := range dedupMessages { + ids[i] = m.Metadata.DedupID + } + assert.NoError(b, d.AddMessagesByID(ids)) } } @@ -92,34 +96,48 @@ func (s *DeduplicatorTestSuite) TearDownTest() { func (s *DeduplicatorTestSuite) TestDeduplicateSingleFilter() { s.d.keyPairProvider = dummyKeyPairProvider{"acc1"} - messages1 := generateMessages(10) - messages2 := generateMessages(12) + messages1 := generateDedupMessages(10) + messages2 := generateDedupMessages(12) result := s.d.Deduplicate(messages1) s.Equal(len(messages1), len(result)) - s.NoError(s.d.AddMessages(messages1)) + + ids := make([][]byte, len(result)) + for i, m := range result { + ids[i] = m.Metadata.DedupID + } + s.NoError(s.d.AddMessagesByID(ids)) result = s.d.Deduplicate(messages1) s.Equal(0, len(result)) result = s.d.Deduplicate(messages2) s.Equal(len(messages2), len(result)) - s.NoError(s.d.AddMessages(messages2)) - messages3 := append(messages2, generateMessages(11)...) + ids = make([][]byte, len(result)) + for i, m := range result { + ids[i] = m.Metadata.DedupID + } + s.NoError(s.d.AddMessagesByID(ids)) + + messages3 := append(messages2, generateDedupMessages(11)...) result = s.d.Deduplicate(messages3) s.Equal(11, len(result)) } func (s *DeduplicatorTestSuite) TestDeduplicateMultipleFilters() { - messages1 := generateMessages(10) + messages1 := generateDedupMessages(10) s.d.keyPairProvider = dummyKeyPairProvider{"acc1"} result := s.d.Deduplicate(messages1) s.Equal(len(messages1), len(result)) + ids := make([][]byte, len(result)) + for i, m := range result { + ids[i] = m.Metadata.DedupID + } - s.NoError(s.d.AddMessages(messages1)) + s.NoError(s.d.AddMessagesByID(ids)) result = s.d.Deduplicate(messages1) s.Equal(0, len(result)) diff --git a/services/shhext/dedup/utils_test.go b/services/shhext/dedup/utils_test.go index 8c33e9eb9..64372d2f6 100644 --- a/services/shhext/dedup/utils_test.go +++ b/services/shhext/dedup/utils_test.go @@ -15,6 +15,18 @@ func generateMessages(count int) []*whisper.Message { return result } +func generateDedupMessages(count int) []*DeduplicateMessage { + result := []*DeduplicateMessage{} + for ; count > 0; count-- { + content := mustGenerateRandomBytes() + result = append(result, &DeduplicateMessage{ + Metadata: Metadata{}, + Message: &whisper.Message{Payload: content}, + }) + } + return result +} + func mustGenerateRandomBytes() []byte { c := 2048 b := make([]byte, c) diff --git a/services/shhext/service.go b/services/shhext/service.go index 6950f21f5..c1569cad5 100644 --- a/services/shhext/service.go +++ b/services/shhext/service.go @@ -235,23 +235,50 @@ func (s *Service) retrieveMessagesLoop(tick time.Duration, cancel <-chan struct{ log.Error("failed to retrieve raw messages", "err", err) continue } + var messageIDs []string + + for _, messages := range chatWithMessages { + for _, message := range messages { + messageIDs = append(messageIDs, message.ID.String()) + } + } + + existingMessages, err := s.messenger.MessagesExist(messageIDs) + if err != nil { + log.Error("failed to check existing messages", "err", err) + continue + } var signalMessages []*signal.Messages for chat, messages := range chatWithMessages { - var retrievedMessages []*whisper.Message - for _, message := range messages { - whisperMessage := message.TransportMessage - whisperMessage.Payload = message.DecryptedPayload - retrievedMessages = append(retrievedMessages, whisperMessage) - } - signalMessage := &signal.Messages{ - Chat: chat, - Error: nil, // TODO: what is it needed for? - Messages: s.deduplicator.Deduplicate(retrievedMessages), + var dedupMessages []*dedup.DeduplicateMessage + // Filter out already saved messages + for _, message := range messages { + if !existingMessages[message.ID.String()] { + dedupMessage := &dedup.DeduplicateMessage{ + Metadata: dedup.Metadata{ + MessageID: message.ID, + EncryptionID: message.Hash, + }, + Message: message.TransportMessage, + } + dedupMessage.Message.Payload = message.DecryptedPayload + dedupMessages = append(dedupMessages, dedupMessage) + } + } + dedupMessages = s.deduplicator.Deduplicate(dedupMessages) + + if len(dedupMessages) != 0 { + signalMessage := &signal.Messages{ + Chat: chat, + Error: nil, // TODO: what is it needed for? + Messages: dedupMessages, + } + + signalMessages = append(signalMessages, signalMessage) } - signalMessages = append(signalMessages, signalMessage) } log.Debug("retrieve messages loop", "messages", len(signalMessages)) diff --git a/services/shhext/service_test.go b/services/shhext/service_test.go index 3f89e69e6..3191446f7 100644 --- a/services/shhext/service_test.go +++ b/services/shhext/service_test.go @@ -3,15 +3,12 @@ package shhext import ( "context" "encoding/hex" - "encoding/json" "errors" "fmt" - "github.com/status-im/status-go/signal" "io/ioutil" "math" "net" "os" - "sync/atomic" "testing" "time" @@ -360,94 +357,6 @@ func (s *ShhExtSuite) TestRequestMessagesSuccess() { s.Require().NotNil(hash) } -// TestRetrieveMessageLoopNoMessages verifies that there are no signals sent -// if there are no messages. -func (s *ShhExtSuite) TestRetrieveMessageLoopNoMessages() { - shhConfig := whisper.DefaultConfig - shhConfig.MinimumAcceptedPOW = 0 // accept all messages - shh := whisper.New(&shhConfig) - privateKey, err := crypto.GenerateKey() - s.Require().NoError(err) - err = shh.SelectKeyPair(privateKey) - s.Require().NoError(err) - aNode, err := node.New(&node.Config{ - P2P: p2p.Config{ - MaxPeers: math.MaxInt32, - NoDiscovery: true, - }, - NoUSB: true, - }) // in-memory node as no data dir - s.Require().NoError(err) - err = aNode.Register(func(*node.ServiceContext) (node.Service, error) { return shh, nil }) - s.Require().NoError(err) - - err = aNode.Start() - s.Require().NoError(err) - defer func() { err := aNode.Stop(); s.NoError(err) }() - - mock := newHandlerMock(1) - config := params.ShhextConfig{ - InstallationID: "1", - BackupDisabledDataDir: os.TempDir(), - PFSEnabled: true, - } - db, err := leveldb.Open(storage.NewMemStorage(), nil) - s.Require().NoError(err) - service := New(shh, mock, db, config) - s.Require().NoError(service.InitProtocolWithPassword("abc", "password")) - - testCases := []struct { - name string - signalName string - action func() - expectedValue int - }{ - { - name: "send one public message", - signalName: signal.EventNewMessages, - action: func() { - api := NewPublicAPI(service) - _, err = api.SendPublicMessage(context.Background(), SendPublicMessageRPC{ - Chat: "test", - Payload: []byte("abc"), - }) - s.Require().NoError(err) - }, - expectedValue: 1, - }, - { - name: "no messages", - action: func() {}, - expectedValue: 0, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // Verify a proper signal is sent when a message is received. - var counter int64 - signal.SetDefaultNodeNotificationHandler(func(jsonEvent string) { - var envelope signal.Envelope - err := json.Unmarshal([]byte(jsonEvent), &envelope) - s.NoError(err) - - switch envelope.Type { - case signal.EventNewMessages: - atomic.AddInt64(&counter, 1) - } - }) - - tc.action() - - cancel := make(chan struct{}) - go service.retrieveMessagesLoop(time.Millisecond*10, cancel) - time.Sleep(time.Millisecond * 100) - close(cancel) - s.Require().EqualValues(tc.expectedValue, counter) - }) - } -} - func (s *ShhExtSuite) TearDown() { for _, n := range s.nodes { s.NoError(n.Stop()) diff --git a/signal/events_shhext.go b/signal/events_shhext.go index 026f4d314..82fec2560 100644 --- a/signal/events_shhext.go +++ b/signal/events_shhext.go @@ -147,9 +147,9 @@ type EnodeDiscoveredSignal struct { } type Messages struct { - Error error `json:"error"` - Messages []dedup.DeduplicateMessage `json:"messages"` - Chat statustransp.Filter `json:"chat"` // not a mistake, it's called chat in status-react + Error error `json:"error"` + Messages []*dedup.DeduplicateMessage `json:"messages"` + Chat statustransp.Filter `json:"chat"` // not a mistake, it's called chat in status-react } // SendEnodeDiscovered tiggered when an enode is discovered. diff --git a/vendor/github.com/status-im/status-protocol-go/chat.go b/vendor/github.com/status-im/status-protocol-go/chat.go index b71174de0..f378dd63d 100644 --- a/vendor/github.com/status-im/status-protocol-go/chat.go +++ b/vendor/github.com/status-im/status-protocol-go/chat.go @@ -57,7 +57,7 @@ type ChatMembershipUpdate struct { // Type indicates the kind of event (i.e changed-name, added-member, etc) Type string `json:"type"` // Name represents the name in the event of changing name events - Name string `json:"name"` + Name string `json:"name,omitempty"` // Clock value of the event ClockValue uint64 `json:"clockValue"` // Signature of the event @@ -65,9 +65,9 @@ type ChatMembershipUpdate struct { // Hex encoded public key of the creator of the event From string `json:"from"` // Target of the event for single-target events - Member string `json:"member"` + Member string `json:"member,omitempty"` // Target of the event for multi-target events - Members []string `json:"members"` + Members []string `json:"members,omitempty"` } // ChatMember represents a member who participates in a group chat diff --git a/vendor/github.com/status-im/status-protocol-go/encryption/encryptor.go b/vendor/github.com/status-im/status-protocol-go/encryption/encryptor.go index d4db00bae..64a3ca9f1 100644 --- a/vendor/github.com/status-im/status-protocol-go/encryption/encryptor.go +++ b/vendor/github.com/status-im/status-protocol-go/encryption/encryptor.go @@ -5,7 +5,6 @@ import ( "database/sql" "encoding/hex" "errors" - "fmt" "sync" "time" @@ -122,8 +121,9 @@ func (s *encryptor) ConfirmMessageProcessed(messageID []byte) error { id := confirmationIDString(messageID) confirmationData, ok := s.messageIDs[id] if !ok { - s.logger.Debug("could not confirm message", zap.String("messageID", id)) - return fmt.Errorf("message with ID %#x not found", messageID) + s.logger.Debug("could not confirm message or message already confirmed", zap.String("messageID", id)) + // We are ok with this, means no key material is stored (public message, or already confirmed) + return nil } // Load session from store first @@ -136,6 +136,9 @@ func (s *encryptor) ConfirmMessageProcessed(messageID []byte) error { return err } + // Clean up + delete(s.messageIDs, id) + return nil } diff --git a/vendor/github.com/status-im/status-protocol-go/encryption/protocol.go b/vendor/github.com/status-im/status-protocol-go/encryption/protocol.go index dd19ee285..ce4d3c493 100644 --- a/vendor/github.com/status-im/status-protocol-go/encryption/protocol.go +++ b/vendor/github.com/status-im/status-protocol-go/encryption/protocol.go @@ -405,6 +405,8 @@ func (p *Protocol) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey) (*Bundle, // ConfirmMessageProcessed confirms and deletes message keys for the given messages func (p *Protocol) ConfirmMessageProcessed(messageID []byte) error { + logger := p.logger.With(zap.String("site", "ConfirmMessageProcessed")) + logger.Debug("confirming message", zap.Binary("message-id", messageID)) return p.encryptor.ConfirmMessageProcessed(messageID) } @@ -417,7 +419,7 @@ func (p *Protocol) HandleMessage( ) ([]byte, error) { logger := p.logger.With(zap.String("site", "HandleMessage")) - logger.Debug("received a protocol message", zap.Binary("sender-public-key", crypto.FromECDSAPub(theirPublicKey))) + logger.Debug("received a protocol message", zap.Binary("sender-public-key", crypto.FromECDSAPub(theirPublicKey)), zap.Binary("message-id", messageID)) if p.encryptor == nil { return nil, errors.New("encryption service not initialized") diff --git a/vendor/github.com/status-im/status-protocol-go/message.go b/vendor/github.com/status-im/status-protocol-go/message.go index a4c83024e..7e16778d1 100644 --- a/vendor/github.com/status-im/status-protocol-go/message.go +++ b/vendor/github.com/status-im/status-protocol-go/message.go @@ -13,6 +13,10 @@ func (h hexutilSQL) Value() (driver.Value, error) { return []byte(h), nil } +func (h hexutilSQL) String() string { + return hexutil.Encode(h) +} + func (h *hexutilSQL) Scan(value interface{}) error { if value == nil { return nil @@ -24,6 +28,13 @@ func (h *hexutilSQL) Scan(value interface{}) error { return errors.New("failed to scan hexutilSQL") } +// QuotedMessage contains the original text of the message replied to +type QuotedMessage struct { + // From is a public key of the author of the message. + From string `json:"from"` + Content string `json:"content"` +} + // Message represents a message record in the database, // more specifically in user_messages_legacy table. // Encoding and decoding of byte blobs should be performed @@ -31,12 +42,10 @@ func (h *hexutilSQL) Scan(value interface{}) error { type Message struct { // ID calculated as keccak256(compressedAuthorPubKey, data) where data is unencrypted payload. ID string `json:"id"` - // RawPayloadHash is a Whisper envelope hash. - RawPayloadHash string `json:"rawPayloadHash"` // WhisperTimestamp is a timestamp of a Whisper envelope. WhisperTimestamp int64 `json:"whisperTimestamp"` // From is a public key of the author of the message. - From hexutilSQL `json:"from"` + From string `json:"from"` // To is a public key of the recipient unless it's a public message then it's empty. To hexutilSQL `json:"to,omitempty"` // BEGIN: fields from protocol.Message. @@ -53,4 +62,7 @@ type Message struct { Show bool `json:"show"` // default true Seen bool `json:"seen"` OutgoingStatus string `json:"outgoingStatus,omitempty"` + // MessageID of the replied message + ReplyTo string `json:"replyTo"` + QuotedMessage *QuotedMessage `json:"quotedMessage"` } diff --git a/vendor/github.com/status-im/status-protocol-go/messenger.go b/vendor/github.com/status-im/status-protocol-go/messenger.go index 923354a05..6bdcbfa54 100644 --- a/vendor/github.com/status-im/status-protocol-go/messenger.go +++ b/vendor/github.com/status-im/status-protocol-go/messenger.go @@ -6,8 +6,6 @@ import ( "database/sql" "time" - "github.com/ethereum/go-ethereum/common/hexutil" - "go.uber.org/zap" "github.com/pkg/errors" @@ -402,12 +400,16 @@ func (m *Messenger) Chats(from, to int) ([]*Chat, error) { return m.persistence.Chats(from, to) } -func (m *Messenger) DeleteChat(chatID string, chatType ChatType) error { - return m.persistence.DeleteChat(chatID, chatType) +func (m *Messenger) DeleteChat(chatID string) error { + return m.persistence.DeleteChat(chatID) } func (m *Messenger) SaveContact(contact Contact) error { - return m.persistence.SaveContact(contact) + return m.persistence.SaveContact(contact, nil) +} + +func (m *Messenger) BlockContact(contact Contact) ([]*Chat, error) { + return m.persistence.BlockContact(contact) } func (m *Messenger) Contacts() ([]*Contact, error) { @@ -617,8 +619,8 @@ func (m *Messenger) MessageByID(id string) (*Message, error) { } // DEPRECATED: required by status-react. -func (m *Messenger) MessageExists(id string) (bool, error) { - return m.persistence.MessageExists(id) +func (m *Messenger) MessagesExist(ids []string) (map[string]bool, error) { + return m.persistence.MessagesExist(ids) } // DEPRECATED: required by status-react. @@ -627,31 +629,8 @@ func (m *Messenger) MessageByChatID(chatID, cursor string, limit int) ([]*Messag } // DEPRECATED: required by status-react. -func (m *Messenger) MessagesFrom(from string) ([]*Message, error) { - publicKeyBytes, err := hexutil.Decode(from) - if err != nil { - return nil, errors.Wrap(err, "failed to decode from argument") - } - return m.persistence.MessagesFrom(publicKeyBytes) -} - -// DEPRECATED: required by status-react. -func (m *Messenger) UnseenMessageIDs() ([]string, error) { - ids, err := m.persistence.UnseenMessageIDs() - if err != nil { - return nil, err - } - - result := make([]string, 0, len(ids)) - for _, id := range ids { - result = append(result, hexutil.Encode(id)) - } - return result, nil -} - -// DEPRECATED: required by status-react. -func (m *Messenger) SaveMessage(message *Message) error { - return m.persistence.SaveMessage(message) +func (m *Messenger) SaveMessages(messages []*Message) error { + return m.persistence.SaveMessagesLegacy(messages) } // DEPRECATED: required by status-react. @@ -659,6 +638,11 @@ func (m *Messenger) DeleteMessage(id string) error { return m.persistence.DeleteMessage(id) } +// DEPRECATED: required by status-react. +func (m *Messenger) DeleteMessagesByChatID(id string) error { + return m.persistence.DeleteMessagesByChatID(id) +} + // DEPRECATED: required by status-react. func (m *Messenger) MarkMessagesSeen(ids ...string) error { return m.persistence.MarkMessagesSeen(ids...) diff --git a/vendor/github.com/status-im/status-protocol-go/migrations/migrations.go b/vendor/github.com/status-im/status-protocol-go/migrations/migrations.go index 5e10dd12b..6aef1b1b9 100644 --- a/vendor/github.com/status-im/status-protocol-go/migrations/migrations.go +++ b/vendor/github.com/status-im/status-protocol-go/migrations/migrations.go @@ -7,7 +7,7 @@ // 000003_add_contacts.down.db.sql (21B) // 000003_add_contacts.up.db.sql (251B) // 000004_user_messages_compatibility.down.sql (33B) -// 000004_user_messages_compatibility.up.sql (945B) +// 000004_user_messages_compatibility.up.sql (928B) // doc.go (377B) package migrations @@ -212,12 +212,12 @@ func _000004_user_messages_compatibilityDownSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "000004_user_messages_compatibility.down.sql", size: 33, mode: os.FileMode(0644), modTime: time.Unix(1565597680, 0)} + info := bindataFileInfo{name: "000004_user_messages_compatibility.down.sql", size: 33, mode: os.FileMode(0644), modTime: time.Unix(1565631683, 0)} a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xb9, 0xaf, 0x48, 0x80, 0x3d, 0x54, 0x5e, 0x53, 0xee, 0x98, 0x26, 0xbb, 0x99, 0x6a, 0xd8, 0x37, 0x94, 0xf2, 0xf, 0x82, 0xfa, 0xb7, 0x6a, 0x68, 0xcd, 0x8b, 0xe2, 0xc4, 0x6, 0x25, 0xdc, 0x6}} return a, nil } -var __000004_user_messages_compatibilityUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xa4\x92\x41\x6f\x9b\x4e\x10\xc5\xef\xfe\x14\x73\xb3\x2d\x99\xbf\x72\x88\x72\xc9\x09\x3b\xeb\x7f\x51\x29\x44\x18\x57\xc9\x69\xb5\x86\x29\xac\x0a\xbb\x68\x67\x28\x45\xca\x87\xaf\x08\x38\x0a\xae\x7d\x2a\x07\x0e\xfb\x7b\xb3\xf3\xf6\xcd\x78\x1e\x04\xbc\x24\xd0\x75\x63\x1d\x2b\xc3\xc0\xa5\x1a\x7e\x9a\x80\xd5\xa9\x42\x28\x15\x81\xb3\x9d\xce\x41\x11\x74\x08\x0e\xab\x1e\xac\x01\xcd\x0b\xcf\x83\xae\x44\x33\x14\x57\x58\xa3\x61\x6d\x0a\xd0\xe6\x87\x36\x9a\xd1\xa3\xcc\xd9\xaa\xfa\x6f\xb1\x4b\x84\x9f\x0a\x48\xfd\x6d\x28\x20\xd8\x43\x14\xa7\x20\x5e\x82\x43\x7a\x80\x96\xd0\xc9\x1a\x89\x54\x81\x24\x2b\x2c\x54\xd6\xc3\x6a\x01\x00\xa0\x73\xf8\xee\x27\xbb\x2f\x7e\x02\xcf\x49\xf0\xcd\x4f\x5e\xe1\xab\x78\x85\x38\x82\x5d\x1c\xed\xc3\x60\x97\x42\x22\x9e\x43\x7f\x27\x36\xef\x7a\xa7\x3a\xd9\xa8\xbe\xb2\x2a\x97\xa5\xa2\xf2\xa3\x7a\x68\x17\x1d\xc3\x70\x94\x75\xa5\xa6\x06\x9d\x64\x5d\x23\xb1\xaa\x1b\x08\xa2\x54\xfc\x2f\x2e\x75\x64\x5b\x97\x21\x6c\xc3\x78\x7b\x41\x72\x24\xd6\x46\xb1\xb6\xe6\x1d\x8f\xa7\x99\x35\x8c\x86\x6f\x74\x9d\xa8\xe4\xbe\xc1\x1b\x92\x21\x0a\xa3\xea\x0f\x3c\x9e\xce\x6c\x5e\x5e\x5a\x2a\x96\x9f\x62\x9a\x53\x87\xec\x7a\x99\xd9\xd6\xf0\xac\x16\x9e\xc4\xde\x3f\x86\x29\xdc\x8d\xba\x29\xfd\x99\xb5\x39\x21\x56\xdc\xd2\x9c\x65\x95\xcd\x7e\xca\x5f\xaa\x6a\xf1\x8a\x33\x2a\x6d\x07\xdb\x38\x0e\x85\x1f\xfd\xdd\x38\x4d\x8e\xd3\xcc\x08\xd1\xdc\xd6\xed\xfd\xf0\x30\x09\x6d\xcb\x85\xd5\xa6\xb8\xf0\xb2\x58\x3f\x2e\xce\xeb\x15\x44\x4f\xe2\x05\x74\xfe\x5b\x4e\xa3\x8b\xa3\xab\xeb\xb5\x1a\xf1\xfa\xf1\x4a\x21\x2a\x97\x95\xf2\xd4\xcb\x73\xb4\x71\x04\xd7\x2f\x19\xed\xb7\x27\x62\xb7\x5a\xde\xfd\xe3\xb7\x84\xb7\xb7\xcf\x89\x6e\xc0\x7b\xb8\xdf\xc0\xc3\xfd\x7a\x00\x3a\xdf\x9c\x47\x3d\xbc\xf7\x4f\x00\x00\x00\xff\xff\x6b\xae\x37\x6d\xb1\x03\x00\x00") +var __000004_user_messages_compatibilityUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xa4\x92\xcf\x6e\x9b\x4c\x14\xc5\xf7\x7e\x8a\xb3\xb3\x2d\x99\x4f\x59\x44\xd9\x64\x85\x9d\xf1\x57\x54\x0a\x11\xc6\x55\xb2\x1a\x8d\xf1\xad\x19\x15\x66\xac\x99\x4b\x5d\xa4\x3c\x7c\x85\xc1\x51\x70\x9d\x55\x59\xb0\x98\xdf\x39\x33\xe7\xfe\x09\x02\x44\x3c\xf5\xd0\xf5\xd1\x3a\x56\x86\xc1\xa5\xea\x7e\xda\x83\xd5\xae\x22\x94\xca\xc3\xd9\x93\xde\x43\x79\x9c\x08\x8e\xaa\x16\xd6\x40\xf3\x24\x08\x70\x2a\xc9\x74\xe6\x8a\x6a\x32\xac\xcd\x01\xda\xfc\xd0\x46\x33\x05\xbe\x70\xb6\xaa\xfe\x9b\xac\x32\x11\xe6\x02\x79\xb8\x8c\x05\xa2\x35\x92\x34\x87\x78\x89\x36\xf9\x06\x8d\x27\x27\x6b\xf2\x5e\x1d\xc8\xcb\x8a\x0e\xaa\x68\x31\x9b\x00\x80\xde\xe3\x7b\x98\xad\xbe\x84\x19\x9e\xb3\xe8\x5b\x98\xbd\xe2\xab\x78\x45\x9a\x60\x95\x26\xeb\x38\x5a\xe5\xc8\xc4\x73\x1c\xae\xc4\xe2\xac\x3f\x95\xda\x1f\xc9\x49\xd6\x35\x79\x56\xf5\x11\x51\x92\x8b\xff\x45\x76\x7e\x2f\xd9\xc6\x71\xaf\xf3\xb6\x71\x05\x61\x19\xa7\xcb\x2b\xb2\x27\xcf\xda\x28\xd6\xd6\x9c\x71\x7f\x5a\x58\xc3\x64\xf8\x3d\xcc\xd8\x33\x50\xc9\xed\x91\x3e\x91\x74\x35\x1a\x55\xbf\xe3\xfe\x74\x14\xf3\xfa\xd2\x52\xb1\xfc\x50\xff\x98\x3a\x62\xd7\xca\xc2\x36\x86\x47\x5e\x3c\x89\x75\xb8\x8d\x73\xdc\x5d\x74\xc7\xaa\x95\x6c\xc7\xef\x0e\xcd\x1e\x05\x1e\x13\xcf\x8a\x1b\x3f\x66\x45\x65\x8b\x9f\xf2\x97\xaa\x1a\xba\x91\xd7\x97\xf6\x84\x65\x9a\xc6\x22\x4c\xfe\x8e\x93\x67\xdb\x61\x44\x9e\xc8\x7c\xae\x5b\x87\xf1\x66\x10\xda\x86\x0f\x56\x9b\xc3\x55\x96\xc9\xfc\x71\x72\xd9\xa6\x28\x79\x12\x2f\xd0\xfb\xdf\x72\x18\x68\x9a\xdc\xdc\xa6\x59\x8f\xe7\x8f\x37\x8c\xa4\x5c\x51\xca\x5d\x2b\x2f\x0d\x4f\x13\xdc\xbe\xa4\x8f\xdf\xec\x3c\xbb\xd9\xf4\xee\x1f\xbf\x29\xde\xde\x3e\x76\x74\x81\xe0\xe1\x7e\x81\x87\xfb\x79\x07\xf4\x7e\x71\x59\x80\xae\xde\x3f\x01\x00\x00\xff\xff\xba\x8f\x77\x72\xa0\x03\x00\x00") func _000004_user_messages_compatibilityUpSqlBytes() ([]byte, error) { return bindataRead( @@ -232,8 +232,8 @@ func _000004_user_messages_compatibilityUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "000004_user_messages_compatibility.up.sql", size: 945, mode: os.FileMode(0644), modTime: time.Unix(1565597764, 0)} - a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xde, 0x18, 0x87, 0x4e, 0x97, 0xf9, 0x2f, 0x18, 0x2, 0x34, 0x53, 0x24, 0x23, 0x2a, 0x6c, 0xe1, 0xba, 0x34, 0x44, 0x72, 0x24, 0x14, 0xf0, 0x3a, 0x2e, 0x56, 0x77, 0xd0, 0xfe, 0x5f, 0x45, 0x63}} + info := bindataFileInfo{name: "000004_user_messages_compatibility.up.sql", size: 928, mode: os.FileMode(0644), modTime: time.Unix(1565697832, 0)} + a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xdf, 0xc4, 0x5c, 0xed, 0x4, 0x26, 0xb1, 0xb2, 0x53, 0xac, 0x1, 0x20, 0xf3, 0x17, 0x37, 0xb3, 0x3d, 0x84, 0x5e, 0xd8, 0x1, 0x53, 0x88, 0x9a, 0x9c, 0xaf, 0x9, 0xdf, 0x58, 0x2e, 0xf0, 0x19}} return a, nil } diff --git a/vendor/github.com/status-im/status-protocol-go/persistence.go b/vendor/github.com/status-im/status-protocol-go/persistence.go index 175706a5c..66cc2f530 100644 --- a/vendor/github.com/status-im/status-protocol-go/persistence.go +++ b/vendor/github.com/status-im/status-protocol-go/persistence.go @@ -7,7 +7,6 @@ import ( "database/sql" "encoding/gob" "encoding/hex" - "fmt" "time" "github.com/ethereum/go-ethereum/crypto" @@ -44,14 +43,8 @@ func (db sqlitePersistence) LastMessageClock(chatID string) (int64, error) { return last.Int64, nil } -func formatChatID(chatID string, chatType ChatType) string { - return fmt.Sprintf("%s-%d", chatID, chatType) -} - func (db sqlitePersistence) SaveChat(chat Chat) error { var err error - // We build the db chatID using the type, so that we have no clashes - chatID := formatChatID(chat.ID, chat.ChatType) pkey := []byte{} // For one to one chatID is an encoded public key @@ -93,7 +86,7 @@ func (db sqlitePersistence) SaveChat(chat Chat) error { defer stmt.Close() _, err = stmt.Exec( - chatID, + chat.ID, chat.Name, chat.Color, chat.Active, @@ -115,15 +108,35 @@ func (db sqlitePersistence) SaveChat(chat Chat) error { return err } -func (db sqlitePersistence) DeleteChat(chatID string, chatType ChatType) error { - dbChatID := formatChatID(chatID, chatType) - _, err := db.db.Exec("DELETE FROM chats WHERE id = ?", dbChatID) +func (db sqlitePersistence) DeleteChat(chatID string) error { + _, err := db.db.Exec("DELETE FROM chats WHERE id = ?", chatID) return err } func (db sqlitePersistence) Chats(from, to int) ([]*Chat, error) { + return db.chats(from, to, nil) +} - rows, err := db.db.Query(`SELECT +func (db sqlitePersistence) chats(from, to int, tx *sql.Tx) ([]*Chat, error) { + var err error + + if tx == nil { + tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return nil, err + } + defer func() { + if err == nil { + err = tx.Commit() + return + + } + // don't shadow original error + _ = tx.Rollback() + }() + } + + rows, err := tx.Query(`SELECT id, name, color, @@ -148,6 +161,9 @@ func (db sqlitePersistence) Chats(from, to int) ([]*Chat, error) { var response []*Chat for rows.Next() { + var lastMessageContentType sql.NullString + var lastMessageContent sql.NullString + chat := &Chat{} encodedMembers := []byte{} encodedMembershipUpdates := []byte{} @@ -163,17 +179,16 @@ func (db sqlitePersistence) Chats(from, to int) ([]*Chat, error) { &pkey, &chat.UnviewedMessagesCount, &chat.LastClockValue, - &chat.LastMessageContentType, - &chat.LastMessageContent, + &lastMessageContentType, + &lastMessageContent, &encodedMembers, &encodedMembershipUpdates, ) if err != nil { return nil, err } - - // Restore the backward compatible ID - chat.ID = chat.ID[:len(chat.ID)-2] + chat.LastMessageContent = lastMessageContent.String + chat.LastMessageContentType = lastMessageContentType.String // Restore members membersDecoder := gob.NewDecoder(bytes.NewBuffer(encodedMembers)) @@ -254,7 +269,25 @@ func (db sqlitePersistence) Contacts() ([]*Contact, error) { return response, nil } -func (db sqlitePersistence) SaveContact(contact Contact) error { +func (db sqlitePersistence) SaveContact(contact Contact, tx *sql.Tx) error { + var err error + + if tx == nil { + tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return err + } + defer func() { + if err == nil { + err = tx.Commit() + return + + } + // don't shadow original error + _ = tx.Rollback() + }() + } + // Encode device info var encodedDeviceInfo bytes.Buffer deviceInfoEncoder := gob.NewEncoder(&encodedDeviceInfo) @@ -272,7 +305,7 @@ func (db sqlitePersistence) SaveContact(contact Contact) error { } // Insert record - stmt, err := db.db.Prepare(`INSERT INTO contacts( + stmt, err := tx.Prepare(`INSERT INTO contacts( id, address, name, diff --git a/vendor/github.com/status-im/status-protocol-go/persistence_legacy.go b/vendor/github.com/status-im/status-protocol-go/persistence_legacy.go index 6338469a4..6e75a5e29 100644 --- a/vendor/github.com/status-im/status-protocol-go/persistence_legacy.go +++ b/vendor/github.com/status-im/status-protocol-go/persistence_legacy.go @@ -1,6 +1,7 @@ package statusproto import ( + "context" "database/sql" "fmt" "strings" @@ -14,7 +15,6 @@ var ( func (db sqlitePersistence) tableUserMessagesLegacyAllFields() string { return `id, - raw_payload_hash, whisper_timestamp, source, destination, @@ -29,7 +29,30 @@ func (db sqlitePersistence) tableUserMessagesLegacyAllFields() string { clock_value, show, seen, - outgoing_status` + outgoing_status, + reply_to` +} + +func (db sqlitePersistence) tableUserMessagesLegacyAllFieldsJoin() string { + return `m1.id, + m1.whisper_timestamp, + m1.source, + m1.destination, + m1.content, + m1.content_type, + m1.username, + m1.timestamp, + m1.chat_id, + m1.retry_count, + m1.message_type, + m1.message_status, + m1.clock_value, + m1.show, + m1.seen, + m1.outgoing_status, + m1.reply_to, + m2.source, + m2.content` } func (db sqlitePersistence) tableUserMessagesLegacyAllFieldsCount() int { @@ -41,9 +64,10 @@ type scanner interface { } func (db sqlitePersistence) tableUserMessagesLegacyScanAllFields(row scanner, message *Message, others ...interface{}) error { + var quotedContent sql.NullString + var quotedFrom sql.NullString args := []interface{}{ &message.ID, - &message.RawPayloadHash, &message.WhisperTimestamp, &message.From, // source in table &message.To, // destination in table @@ -59,14 +83,27 @@ func (db sqlitePersistence) tableUserMessagesLegacyScanAllFields(row scanner, me &message.Show, &message.Seen, &message.OutgoingStatus, + &message.ReplyTo, + "edFrom, + "edContent, } - return row.Scan(append(args, others...)...) + err := row.Scan(append(args, others...)...) + if err != nil { + return err + } + + if quotedContent.Valid { + message.QuotedMessage = &QuotedMessage{ + From: quotedFrom.String, + Content: quotedContent.String, + } + } + return nil } func (db sqlitePersistence) tableUserMessagesLegacyAllValues(message *Message) []interface{} { return []interface{}{ message.ID, - message.RawPayloadHash, message.WhisperTimestamp, message.From, // source in table message.To, // destination in table @@ -82,21 +119,26 @@ func (db sqlitePersistence) tableUserMessagesLegacyAllValues(message *Message) [ message.Show, message.Seen, message.OutgoingStatus, + message.ReplyTo, } } func (db sqlitePersistence) MessageByID(id string) (*Message, error) { var message Message - allFields := db.tableUserMessagesLegacyAllFields() + allFields := db.tableUserMessagesLegacyAllFieldsJoin() row := db.db.QueryRow( fmt.Sprintf(` SELECT %s FROM - user_messages_legacy + user_messages_legacy m1 + LEFT JOIN + user_messages_legacy m2 + ON + m1.reply_to = m2.id WHERE - id = ? + m1.id = ? `, allFields), id, ) @@ -111,17 +153,35 @@ func (db sqlitePersistence) MessageByID(id string) (*Message, error) { } } -func (db sqlitePersistence) MessageExists(id string) (bool, error) { - var result bool - err := db.db.QueryRow(`SELECT EXISTS(SELECT 1 FROM user_messages_legacy WHERE id = ?)`, id).Scan(&result) - switch err { - case sql.ErrNoRows: - return false, errRecordNotFound - case nil: +func (db sqlitePersistence) MessagesExist(ids []string) (map[string]bool, error) { + result := make(map[string]bool) + if len(ids) == 0 { return result, nil - default: - return false, err } + + idsArgs := make([]interface{}, 0, len(ids)) + for _, id := range ids { + idsArgs = append(idsArgs, id) + } + + inVector := strings.Repeat("?, ", len(ids)-1) + "?" + query := fmt.Sprintf(`SELECT id FROM user_messages_legacy WHERE id IN (%s)`, inVector) + rows, err := db.db.Query(query, idsArgs...) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var id string + err := rows.Scan(&id) + if err != nil { + return nil, err + } + result[id] = true + } + + return result, nil } // MessageByChatID returns all messages for a given chatID in descending order. @@ -132,7 +192,7 @@ func (db sqlitePersistence) MessageByChatID(chatID string, currCursor string, li if currCursor != "" { cursorWhere = "AND cursor <= ?" } - allFields := db.tableUserMessagesLegacyAllFields() + allFields := db.tableUserMessagesLegacyAllFieldsJoin() args := []interface{}{chatID} if currCursor != "" { args = append(args, currCursor) @@ -144,11 +204,15 @@ func (db sqlitePersistence) MessageByChatID(chatID string, currCursor string, li fmt.Sprintf(` SELECT %s, - substr('0000000000000000000000000000000000000000000000000000000000000000' || clock_value, -64, 64) || id as cursor + substr('0000000000000000000000000000000000000000000000000000000000000000' || m1.clock_value, -64, 64) || m1.id as cursor FROM - user_messages_legacy + user_messages_legacy m1 + LEFT JOIN + user_messages_legacy m2 + ON + m1.reply_to = m2.id WHERE - chat_id = ? %s + m1.chat_id = ? %s ORDER BY cursor DESC LIMIT ? `, allFields, cursorWhere), @@ -183,61 +247,42 @@ func (db sqlitePersistence) MessageByChatID(chatID string, currCursor string, li return result, newCursor, nil } -func (db sqlitePersistence) MessagesFrom(from []byte) ([]*Message, error) { - allFields := db.tableUserMessagesLegacyAllFields() - rows, err := db.db.Query( - fmt.Sprintf(` - SELECT - %s - FROM - user_messages_legacy - WHERE - source = ? - `, allFields), - from, +func (db sqlitePersistence) SaveMessagesLegacy(messages []*Message) error { + var ( + tx *sql.Tx + stmt *sql.Stmt + err error ) + tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { - return nil, err + return err } - defer rows.Close() + defer func() { + if err == nil { + err = tx.Commit() + return - var result []*Message - for rows.Next() { - var message Message - if err := db.tableUserMessagesLegacyScanAllFields(rows, &message); err != nil { - return nil, err } - result = append(result, &message) - } - return result, nil -} + // don't shadow original error + _ = tx.Rollback() + }() -func (db sqlitePersistence) UnseenMessageIDs() ([][]byte, error) { - rows, err := db.db.Query(`SELECT id FROM user_messages_legacy WHERE seen = 0`) - if err != nil { - return nil, err - } - defer rows.Close() - - var result [][]byte - for rows.Next() { - var id []byte - if err := rows.Scan(&id); err != nil { - return nil, err - } - result = append(result, id) - } - return result, nil -} - -func (db sqlitePersistence) SaveMessage(m *Message) error { allFields := db.tableUserMessagesLegacyAllFields() valuesVector := strings.Repeat("?, ", db.tableUserMessagesLegacyAllFieldsCount()-1) + "?" + query := fmt.Sprintf(`INSERT INTO user_messages_legacy(%s) VALUES (%s)`, allFields, valuesVector) - _, err := db.db.Exec( - query, - db.tableUserMessagesLegacyAllValues(m)..., - ) + + stmt, err = tx.Prepare(query) + if err != nil { + return err + } + + for _, msg := range messages { + _, err := stmt.Exec(db.tableUserMessagesLegacyAllValues(msg)...) + if err != nil { + return err + } + } return err } @@ -246,6 +291,11 @@ func (db sqlitePersistence) DeleteMessage(id string) error { return err } +func (db sqlitePersistence) DeleteMessagesByChatID(id string) error { + _, err := db.db.Exec(`DELETE FROM user_messages_legacy WHERE chat_id = ?`, id) + return err +} + func (db sqlitePersistence) MarkMessagesSeen(ids ...string) error { idsArgs := make([]interface{}, 0, len(ids)) for _, id := range ids { @@ -271,3 +321,61 @@ func (db sqlitePersistence) UpdateMessageOutgoingStatus(id string, newOutgoingSt `, newOutgoingStatus, id) return err } + +// BlockContact updates a contact, deletes all the messages and 1-to-1 chat, updates the unread messages count and returns a map with the new count +func (db sqlitePersistence) BlockContact(contact Contact) ([]*Chat, error) { + var ( + tx *sql.Tx + err error + ) + tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return nil, err + } + defer func() { + if err == nil { + err = tx.Commit() + return + + } + // don't shadow original error + _ = tx.Rollback() + }() + + // Delete messages + _, err = tx.Exec( + `DELETE + FROM user_messages_legacy + WHERE source = ?`, + contact.ID, + ) + if err != nil { + return nil, err + } + + // Update contact + err = db.SaveContact(contact, tx) + if err != nil { + return nil, err + } + + // Delete one-to-one chat + _, err = tx.Exec("DELETE FROM chats WHERE id = ?", contact.ID) + if err != nil { + return nil, err + } + + // Recalculate denormalized fields + _, err = tx.Exec(` + UPDATE chats + SET + unviewed_message_count = (SELECT COUNT(1) FROM user_messages_legacy WHERE seen = 0 AND chat_id = chats.id), + last_message_content = (SELECT content from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1), + last_message_content_type = (SELECT content_type from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1)`) + if err != nil { + return nil, err + } + + // return the updated chats + return db.chats(0, -1, tx) +} diff --git a/vendor/github.com/status-im/status-protocol-go/v1/message.go b/vendor/github.com/status-im/status-protocol-go/v1/message.go index 0881718bc..641b7fbeb 100644 --- a/vendor/github.com/status-im/status-protocol-go/v1/message.go +++ b/vendor/github.com/status-im/status-protocol-go/v1/message.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/crypto" "github.com/golang/protobuf/proto" "github.com/pkg/errors" @@ -146,7 +147,7 @@ func EncodeMessage(value Message) ([]byte, error) { // MessageID calculates the messageID from author's compressed public key // and not encrypted but encoded payload. -func MessageID(author *ecdsa.PublicKey, data []byte) []byte { +func MessageID(author *ecdsa.PublicKey, data []byte) hexutil.Bytes { keyBytes := crypto.FromECDSAPub(author) return crypto.Keccak256(append(keyBytes, data...)) } diff --git a/vendor/github.com/status-im/status-protocol-go/v1/status_message.go b/vendor/github.com/status-im/status-protocol-go/v1/status_message.go index c8fd5b425..f68359387 100644 --- a/vendor/github.com/status-im/status-protocol-go/v1/status_message.go +++ b/vendor/github.com/status-im/status-protocol-go/v1/status_message.go @@ -5,6 +5,7 @@ import ( "github.com/pkg/errors" "log" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/crypto" "github.com/golang/protobuf/proto" "github.com/jinzhu/copier" @@ -27,7 +28,7 @@ type StatusMessage struct { DecryptedPayload []byte // ID is the canonical ID of the message - ID []byte + ID hexutil.Bytes // Hash is the transport layer hash Hash []byte diff --git a/vendor/modules.txt b/vendor/modules.txt index c478e2e69..656f555d3 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -317,7 +317,7 @@ github.com/status-im/migrate/v4/database/sqlcipher github.com/status-im/rendezvous github.com/status-im/rendezvous/protocol github.com/status-im/rendezvous/server -# github.com/status-im/status-protocol-go v0.0.0-20190701094942-c2b7b022b722d7bebe1c6d6f05cdead79f1b57bd +# github.com/status-im/status-protocol-go v0.0.0-20190701094942-9f0db157bf5a1ac9b52c47bbea37fc7dbe14e8fe github.com/status-im/status-protocol-go/zaputil github.com/status-im/status-protocol-go github.com/status-im/status-protocol-go/encryption/multidevice