From 67dfff232433890b9d316dce763305359056c321 Mon Sep 17 00:00:00 2001 From: Andrea Maria Piana Date: Thu, 29 Feb 2024 09:51:38 +0000 Subject: [PATCH] Add index to hash ratchet & cache processed description --- protocol/communities/community.go | 5 + protocol/communities/manager.go | 43 +++-- protocol/communities/persistence.go | 151 ++++++++++++++++++ protocol/communities/persistence_test.go | 108 +++++++++++++ protocol/encryption/migrations/migrations.go | 23 +++ .../1709200114_add_migration_index.up.sql | 11 ++ protocol/encryption/persistence.go | 70 ++++---- protocol/encryption/persistence_test.go | 28 ++++ protocol/encryption/protocol.go | 19 ++- protocol/messenger.go | 14 ++ protocol/messenger_storenode_request_test.go | 11 +- protocol/migrations/migrations.go | 23 +++ ...431_add_community_description_cache.up.sql | 16 ++ 13 files changed, 465 insertions(+), 57 deletions(-) create mode 100644 protocol/encryption/migrations/sqlite/1709200114_add_migration_index.up.sql create mode 100644 protocol/migrations/sqlite/1709828431_add_community_description_cache.up.sql diff --git a/protocol/communities/community.go b/protocol/communities/community.go index f7613beae..c55a3e2d7 100644 --- a/protocol/communities/community.go +++ b/protocol/communities/community.go @@ -107,6 +107,7 @@ type CommunityChat struct { ViewersCanPostReactions bool `json:"viewersCanPostReactions"` Position int `json:"position"` CategoryID string `json:"categoryID"` + TokenGated bool `json:"tokenGated"` } type CommunityCategory struct { @@ -194,6 +195,7 @@ func (o *Community) MarshalPublicAPIJSON() ([]byte, error) { Members: c.Members, CanPost: canPost, ViewersCanPostReactions: c.ViewersCanPostReactions, + TokenGated: o.channelEncrypted(id), CategoryID: c.CategoryId, Position: int(c.Position), } @@ -279,8 +281,10 @@ func (o *Community) MarshalJSON() ([]byte, error) { PubsubTopicKey string `json:"pubsubTopicKey"` Shard *shard.Shard `json:"shard"` LastOpenedAt int64 `json:"lastOpenedAt"` + Clock uint64 `json:"clock"` }{ ID: o.ID(), + Clock: o.Clock(), MemberRole: o.MemberRole(o.MemberIdentity()), IsControlNode: o.IsControlNode(), Verified: o.config.Verified, @@ -331,6 +335,7 @@ func (o *Community) MarshalJSON() ([]byte, error) { Members: c.Members, CanPost: canPost, ViewersCanPostReactions: c.ViewersCanPostReactions, + TokenGated: o.channelEncrypted(id), CategoryID: c.CategoryId, Position: int(c.Position), } diff --git a/protocol/communities/manager.go b/protocol/communities/manager.go index a0c1d039d..2c08951cf 100644 --- a/protocol/communities/manager.go +++ b/protocol/communities/manager.go @@ -1723,11 +1723,10 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des id = crypto.CompressPubkey(signer) } - failedToDecrypt, err := m.preprocessDescription(id, description) + failedToDecrypt, processedDescription, err := m.preprocessDescription(id, description) if err != nil { return nil, err } - m.communityLock.Lock(id) defer m.communityLock.Unlock(id) community, err := m.GetByID(id) @@ -1737,12 +1736,12 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des // We don't process failed to decrypt if the whole metadata is encrypted // and we joined the community already - if community != nil && community.Joined() && len(failedToDecrypt) != 0 && description != nil && len(description.Members) == 0 { + if community != nil && community.Joined() && len(failedToDecrypt) != 0 && processedDescription != nil && len(processedDescription.Members) == 0 { return &CommunityResponse{FailedToDecrypt: failedToDecrypt}, nil } // We should queue only if the community has a token owner, and the owner has been verified - hasTokenOwnership := HasTokenOwnership(description) + hasTokenOwnership := HasTokenOwnership(processedDescription) shouldQueue := hasTokenOwnership && verifiedOwner == nil if community == nil { @@ -1751,7 +1750,7 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des return nil, err } config := Config{ - CommunityDescription: description, + CommunityDescription: processedDescription, Logger: m.logger, CommunityDescriptionProtocolMessage: payload, MemberIdentity: &m.identity.PublicKey, @@ -1772,15 +1771,15 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des // A new community, we need to check if we need to validate async. // That would be the case if it has a contract. We queue everything and process separately. if shouldQueue { - return nil, m.Queue(signer, community, description.Clock, payload) + return nil, m.Queue(signer, community, processedDescription.Clock, payload) } } else { // only queue if already known control node is different than the signer // and if the clock is greater shouldQueue = shouldQueue && !common.IsPubKeyEqual(community.ControlNode(), signer) && - community.config.CommunityDescription.Clock < description.Clock + community.config.CommunityDescription.Clock < processedDescription.Clock if shouldQueue { - return nil, m.Queue(signer, community, description.Clock, payload) + return nil, m.Queue(signer, community, processedDescription.Clock, payload) } } @@ -1806,7 +1805,7 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des return nil, ErrNotAuthorized } - r, err := m.handleCommunityDescriptionMessageCommon(community, description, payload, verifiedOwner) + r, err := m.handleCommunityDescriptionMessageCommon(community, processedDescription, payload, verifiedOwner) if err != nil { return nil, err } @@ -1814,10 +1813,22 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des return r, nil } -func (m *Manager) preprocessDescription(id types.HexBytes, description *protobuf.CommunityDescription) ([]*CommunityPrivateDataFailedToDecrypt, error) { +func (m *Manager) NewHashRatchetKeys(keys []*encryption.HashRatchetInfo) error { + return m.persistence.InvalidateDecryptedCommunityCacheForKeys(keys) +} + +func (m *Manager) preprocessDescription(id types.HexBytes, description *protobuf.CommunityDescription) ([]*CommunityPrivateDataFailedToDecrypt, *protobuf.CommunityDescription, error) { + decryptedCommunity, err := m.persistence.GetDecryptedCommunityDescription(id, description.Clock) + if err != nil { + return nil, nil, err + } + if decryptedCommunity != nil { + return nil, decryptedCommunity, nil + } + response, err := decryptDescription(id, m, description, m.logger) if err != nil { - return response, err + return response, description, err } upgradeTokenPermissions(description) @@ -1825,7 +1836,7 @@ func (m *Manager) preprocessDescription(id types.HexBytes, description *protobuf // Workaround for https://github.com/status-im/status-desktop/issues/12188 hydrateChannelsMembers(types.EncodeHex(id), description) - return response, nil + return response, description, m.persistence.SaveDecryptedCommunityDescription(id, response, description) } func (m *Manager) handleCommunityDescriptionMessageCommon(community *Community, description *protobuf.CommunityDescription, payload []byte, newControlNode *ecdsa.PublicKey) (*CommunityResponse, error) { @@ -3035,12 +3046,12 @@ func (m *Manager) HandleCommunityRequestToJoinResponse(signer *ecdsa.PublicKey, return nil, ErrNotAuthorized } - _, err = m.preprocessDescription(community.ID(), request.Community) + _, processedDescription, err := m.preprocessDescription(community.ID(), request.Community) if err != nil { return nil, err } - _, err = community.UpdateCommunityDescription(request.Community, appMetadataMsg, nil) + _, err = community.UpdateCommunityDescription(processedDescription, appMetadataMsg, nil) if err != nil { return nil, err } @@ -3391,11 +3402,13 @@ func (m *Manager) dbRecordBundleToCommunity(r *CommunityRecordBundle) (*Communit } return recordBundleToCommunity(r, &m.identity.PublicKey, m.installationID, m.logger, m.timesource, descriptionEncryptor, func(community *Community) error { - _, err := m.preprocessDescription(community.ID(), community.config.CommunityDescription) + _, description, err := m.preprocessDescription(community.ID(), community.config.CommunityDescription) if err != nil { return err } + community.config.CommunityDescription = description + if community.config.EventsData != nil { eventsDescription, err := validateAndGetEventsMessageCommunityDescription(community.config.EventsData.EventsBaseCommunityDescription, community.ControlNode()) if err != nil { diff --git a/protocol/communities/persistence.go b/protocol/communities/persistence.go index 10267394b..12a5acb2d 100644 --- a/protocol/communities/persistence.go +++ b/protocol/communities/persistence.go @@ -18,6 +18,7 @@ import ( "github.com/status-im/status-go/protocol/common" "github.com/status-im/status-go/protocol/common/shard" "github.com/status-im/status-go/protocol/communities/token" + "github.com/status-im/status-go/protocol/encryption" "github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/services/wallet/bigint" ) @@ -1872,3 +1873,153 @@ func (p *Persistence) UpsertAppliedCommunityEvents(communityID types.HexBytes, p } return err } + +func (p *Persistence) InvalidateDecryptedCommunityCacheForKeys(keys []*encryption.HashRatchetInfo) error { + tx, err := p.db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return err + } + + defer func() { + if err == nil { + err = tx.Commit() + return + } + // don't shadow original error + _ = tx.Rollback() + }() + + if len(keys) == 0 { + return nil + } + idsArgs := make([]interface{}, 0, len(keys)) + for _, k := range keys { + idsArgs = append(idsArgs, k.KeyID) + } + + inVector := strings.Repeat("?, ", len(keys)-1) + "?" + + query := "SELECT DISTINCT(community_id) FROM encrypted_community_description_missing_keys WHERE key_id IN (" + inVector + ")" // nolint: gosec + + var communityIDs []interface{} + rows, err := tx.Query(query, idsArgs...) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var communityID []byte + err = rows.Scan(&communityID) + if err != nil { + return err + } + communityIDs = append(communityIDs, communityID) + } + if len(communityIDs) == 0 { + return nil + } + + inVector = strings.Repeat("?, ", len(communityIDs)-1) + "?" + + query = "DELETE FROM encrypted_community_description_cache WHERE community_id IN (" + inVector + ")" //nolint: gosec + _, err = tx.Exec(query, communityIDs...) + + return err +} + +func (p *Persistence) SaveDecryptedCommunityDescription(communityID []byte, missingKeys []*CommunityPrivateDataFailedToDecrypt, description *protobuf.CommunityDescription) error { + if description == nil { + return nil + } + marshaledDescription, err := proto.Marshal(description) + if err != nil { + return err + } + tx, err := p.db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return err + } + + defer func() { + if err == nil { + err = tx.Commit() + return + } + // don't shadow original error + _ = tx.Rollback() + }() + previousCommunity, err := p.getDecryptedCommunityDescriptionByID(tx, communityID) + if err != nil { + return err + } + + if previousCommunity != nil && previousCommunity.Clock >= description.Clock { + return nil + } + + insertCommunity := "INSERT INTO encrypted_community_description_cache (community_id, clock, description) VALUES (?, ?, ?);" + _, err = tx.Exec(insertCommunity, communityID, description.Clock, marshaledDescription) + if err != nil { + return err + } + for _, key := range missingKeys { + insertKey := "INSERT INTO encrypted_community_description_missing_keys (community_id, key_id) VALUES(?, ?)" + _, err = tx.Exec(insertKey, communityID, key.KeyID) + if err != nil { + return err + } + } + + return nil +} + +func (p *Persistence) GetDecryptedCommunityDescription(communityID []byte, clock uint64) (*protobuf.CommunityDescription, error) { + return p.getDecryptedCommunityDescriptionByIDAndClock(communityID, clock) +} + +func (p *Persistence) getDecryptedCommunityDescriptionByIDAndClock(communityID []byte, clock uint64) (*protobuf.CommunityDescription, error) { + query := "SELECT description FROM encrypted_community_description_cache WHERE community_id = ? AND clock = ?" + + qr := p.db.QueryRow(query, communityID, clock) + + var descriptionBytes []byte + + err := qr.Scan(&descriptionBytes) + switch err { + case sql.ErrNoRows: + return nil, nil + case nil: + var communityDescription protobuf.CommunityDescription + err := proto.Unmarshal(descriptionBytes, &communityDescription) + if err != nil { + return nil, err + } + return &communityDescription, nil + default: + return nil, err + } +} + +func (p *Persistence) getDecryptedCommunityDescriptionByID(tx *sql.Tx, communityID []byte) (*protobuf.CommunityDescription, error) { + query := "SELECT description FROM encrypted_community_description_cache WHERE community_id = ?" + + qr := tx.QueryRow(query, communityID) + + var descriptionBytes []byte + + err := qr.Scan(&descriptionBytes) + switch err { + case sql.ErrNoRows: + return nil, nil + case nil: + var communityDescription protobuf.CommunityDescription + err := proto.Unmarshal(descriptionBytes, &communityDescription) + if err != nil { + return nil, err + } + return &communityDescription, nil + default: + return nil, err + } +} diff --git a/protocol/communities/persistence_test.go b/protocol/communities/persistence_test.go index e42f6a21d..d86b279bd 100644 --- a/protocol/communities/persistence_test.go +++ b/protocol/communities/persistence_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/suite" "github.com/status-im/status-go/appdatabase" @@ -16,6 +17,7 @@ import ( "github.com/status-im/status-go/protocol/common" "github.com/status-im/status-go/protocol/common/shard" "github.com/status-im/status-go/protocol/communities/token" + "github.com/status-im/status-go/protocol/encryption" "github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/sqlite" "github.com/status-im/status-go/services/wallet/bigint" @@ -944,3 +946,109 @@ func (s *PersistenceSuite) TestProcessedCommunityEvents() { s.Require().Len(events, 3) s.Require().True(reflect.DeepEqual(events, map[string]uint64{"a": 2, "b": 10, "c": 1})) } + +func (s *PersistenceSuite) TestDecryptedCommunityCache() { + communityDescription := &protobuf.CommunityDescription{ + Clock: 1000, + } + keyID1 := []byte("key-id-1") + keyID2 := []byte("key-id-2") + missingKeys := []*CommunityPrivateDataFailedToDecrypt{ + {KeyID: keyID1}, + {KeyID: keyID2}, + } + communityID := []byte("id") + err := s.db.SaveDecryptedCommunityDescription(communityID, missingKeys, communityDescription) + s.Require().NoError(err) + + // Can be retrieved + retrievedCommunity, err := s.db.GetDecryptedCommunityDescription(communityID, 1000) + s.Require().NoError(err) + s.Require().True(proto.Equal(communityDescription, retrievedCommunity)) + + // Retrieving a random one doesn't throw an error + retrievedCommunity, err = s.db.GetDecryptedCommunityDescription([]byte("non-existent-id"), 1000) + s.Require().NoError(err) + s.Require().Nil(retrievedCommunity) + + // Retrieving a random one doesn't throw an error + retrievedCommunity, err = s.db.GetDecryptedCommunityDescription(communityID, 999) + s.Require().NoError(err) + s.Require().Nil(retrievedCommunity) + + // invalidating the cache + err = s.db.InvalidateDecryptedCommunityCacheForKeys([]*encryption.HashRatchetInfo{{KeyID: keyID1}}) + s.Require().NoError(err) + + // community cannot be retrieved anymore + retrievedCommunity, err = s.db.GetDecryptedCommunityDescription(communityID, 1000) + s.Require().NoError(err) + s.Require().Nil(retrievedCommunity) + + // make sure everything is cleaned up + + qr := s.db.db.QueryRow("SELECT COUNT(*) FROM encrypted_community_description_missing_keys") + + var count int + + err = qr.Scan(&count) + s.Require().NoError(err) + s.Require().Equal(count, 0) + +} + +func (s *PersistenceSuite) TestDecryptedCommunityCacheClock() { + communityDescription := &protobuf.CommunityDescription{ + Clock: 1000, + } + keyID1 := []byte("key-id-1") + keyID2 := []byte("key-id-2") + keyID3 := []byte("key-id-3") + + missingKeys := []*CommunityPrivateDataFailedToDecrypt{ + {KeyID: keyID1}, + {KeyID: keyID2}, + } + communityID := []byte("id") + err := s.db.SaveDecryptedCommunityDescription(communityID, missingKeys, communityDescription) + s.Require().NoError(err) + + // Can be retrieved + retrievedCommunity, err := s.db.GetDecryptedCommunityDescription(communityID, 1000) + s.Require().NoError(err) + s.Require().True(proto.Equal(communityDescription, retrievedCommunity)) + + // Save an earlier community + communityDescription.Clock = 999 + err = s.db.SaveDecryptedCommunityDescription(communityID, missingKeys, communityDescription) + s.Require().NoError(err) + + // The old one should be retrieved + retrievedCommunity, err = s.db.GetDecryptedCommunityDescription(communityID, 1000) + s.Require().NoError(err) + s.Require().NotNil(retrievedCommunity) + s.Require().Equal(uint64(1000), retrievedCommunity.Clock) + + // Save a later community, with a single key + missingKeys = []*CommunityPrivateDataFailedToDecrypt{ + {KeyID: keyID3}, + } + + communityDescription.Clock = 1001 + err = s.db.SaveDecryptedCommunityDescription(communityID, missingKeys, communityDescription) + s.Require().NoError(err) + + // The new one should be retrieved + retrievedCommunity, err = s.db.GetDecryptedCommunityDescription(communityID, 1001) + s.Require().NoError(err) + s.Require().Equal(uint64(1001), retrievedCommunity.Clock) + + // Make sure the previous two are cleaned up and there's only one left + qr := s.db.db.QueryRow("SELECT COUNT(*) FROM encrypted_community_description_missing_keys") + + var count int + + err = qr.Scan(&count) + s.Require().NoError(err) + s.Require().Equal(count, 1) +} diff --git a/protocol/encryption/migrations/migrations.go b/protocol/encryption/migrations/migrations.go index 3b1495a35..68b7d5e24 100644 --- a/protocol/encryption/migrations/migrations.go +++ b/protocol/encryption/migrations/migrations.go @@ -19,6 +19,7 @@ // 1632236298_add_communities.down.sql (151B) // 1632236298_add_communities.up.sql (584B) // 1636536507_add_index_bundles.up.sql (347B) +// 1709200114_add_migration_index.up.sql (483B) // doc.go (397B) package migrations @@ -467,6 +468,26 @@ func _1636536507_add_index_bundlesUpSql() (*asset, error) { return a, nil } +var __1709200114_add_migration_indexUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xa4\x90\xc1\xaa\xc2\x30\x10\x45\xf7\xfd\x8a\x59\xbe\x07\xfe\x81\x2b\x69\x23\x74\xd3\x82\xed\xa2\xbb\x21\x24\x83\x09\xa5\x6d\x4c\x46\x30\x7f\x2f\x25\x16\x41\xb1\x2a\x6e\x66\x36\xf7\xdc\x03\x37\x3f\x88\x5d\x2b\xa0\xac\x0a\xd1\x41\xb9\x87\xaa\x6e\x41\x74\x65\xd3\x36\x60\xf5\x05\x8f\x7e\x3a\x3b\x64\x3b\x50\x60\x39\x38\xd4\x14\x54\x56\x57\x60\x64\x30\xe8\x25\x2b\x43\x8c\x34\x2a\x1f\x1d\xdb\x69\x84\xbf\x04\x58\xbd\x81\x9e\xe2\x1d\x84\x42\x34\xf9\xff\x36\xcb\xde\xf8\x5e\xf4\x62\x4f\x31\xcc\x07\xad\x5e\xd5\xa7\xc8\xaf\x22\x4d\xce\x93\x92\x4c\xfa\x13\xe7\x53\xfa\x5b\xbd\x92\xca\x10\x2e\xcb\xdd\x4a\x30\xd0\x09\xc7\x69\xc5\x9c\xb8\xc7\xc9\xe7\x9f\xd0\x65\xf3\x6b\x00\x00\x00\xff\xff\x97\xf4\x28\xe3\xe3\x01\x00\x00") + +func _1709200114_add_migration_indexUpSqlBytes() ([]byte, error) { + return bindataRead( + __1709200114_add_migration_indexUpSql, + "1709200114_add_migration_index.up.sql", + ) +} + +func _1709200114_add_migration_indexUpSql() (*asset, error) { + bytes, err := _1709200114_add_migration_indexUpSqlBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "1709200114_add_migration_index.up.sql", size: 483, mode: os.FileMode(0644), modTime: time.Unix(1700000000, 0)} + a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xe2, 0xec, 0xd4, 0x54, 0xff, 0x5e, 0x6e, 0xaf, 0x3f, 0x2b, 0xb5, 0x76, 0xe9, 0x84, 0x2a, 0x4d, 0x1f, 0xd8, 0x22, 0x8b, 0x4b, 0x5c, 0xf1, 0xe0, 0x3a, 0x34, 0xc5, 0xed, 0xef, 0x74, 0xe4, 0x2b}} + return a, nil +} + var _docGo = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x84\x8f\xbd\x6a\x2b\x31\x10\x85\xfb\x7d\x8a\x83\x1b\x37\x77\xa5\x1b\x08\x04\x02\x29\x52\xa6\xcf\x0b\x8c\xa5\x59\x69\xf0\x4a\xda\x68\x66\xfd\xf3\xf6\x61\x1d\x43\xdc\x65\xca\x0f\xbe\x73\xce\x78\x8f\xcf\x2c\x8a\x49\x66\x86\x28\x2a\x07\x56\xa5\x7e\xc5\x81\x03\xad\xca\xd8\x25\xb1\xbc\x1e\x5c\x68\xc5\xab\x91\xad\x3a\x4a\xf1\x45\x52\x27\x63\x7f\x7a\xde\x0d\xde\x23\x50\xdd\x1b\x32\xd5\x38\xf3\x2d\x4b\xa1\x46\xdd\xa4\x26\x9c\xc5\x32\x08\x4b\xe7\x49\x2e\x0e\xef\x86\x99\x49\x0d\x96\xc9\xf6\x0a\xcb\x8c\x40\xca\x5b\xcc\xd4\x3a\x52\x1b\x0f\x52\x23\x19\xb9\x0d\x7d\x4c\x0f\x64\x5b\x18\x68\x9e\x39\x62\xea\xad\xdc\x5c\xa5\xc2\x88\xd2\x39\x58\xeb\xd7\x7f\x20\x55\x36\x54\x2a\xac\x9b\x9f\xe9\xc4\xa8\xed\x5e\x0f\xaa\xf1\xef\x8f\x70\x6e\xfd\xa8\x20\x05\x5f\x16\x0e\xc6\xd1\x0d\xc3\x42\xe1\x48\x89\xa1\x5f\xb3\x18\x0f\x83\xf7\xa9\xbd\x26\xae\xbc\x59\x8f\x1b\xc7\xd2\xa2\x49\xe1\xb7\xa7\x97\xff\xf7\xc3\xb8\x1c\x13\x7e\x1a\xa4\x55\xc5\xd8\xe0\x9c\xff\x05\x2e\x35\xb8\xe1\x3b\x00\x00\xff\xff\x73\x18\x09\xa7\x8d\x01\x00\x00") func docGoBytes() ([]byte, error) { @@ -597,6 +618,7 @@ var _bindata = map[string]func() (*asset, error){ "1632236298_add_communities.down.sql": _1632236298_add_communitiesDownSql, "1632236298_add_communities.up.sql": _1632236298_add_communitiesUpSql, "1636536507_add_index_bundles.up.sql": _1636536507_add_index_bundlesUpSql, + "1709200114_add_migration_index.up.sql": _1709200114_add_migration_indexUpSql, "doc.go": docGo, } @@ -665,6 +687,7 @@ var _bintree = &bintree{nil, map[string]*bintree{ "1632236298_add_communities.down.sql": {_1632236298_add_communitiesDownSql, map[string]*bintree{}}, "1632236298_add_communities.up.sql": {_1632236298_add_communitiesUpSql, map[string]*bintree{}}, "1636536507_add_index_bundles.up.sql": {_1636536507_add_index_bundlesUpSql, map[string]*bintree{}}, + "1709200114_add_migration_index.up.sql": {_1709200114_add_migration_indexUpSql, map[string]*bintree{}}, "doc.go": {docGo, map[string]*bintree{}}, }} diff --git a/protocol/encryption/migrations/sqlite/1709200114_add_migration_index.up.sql b/protocol/encryption/migrations/sqlite/1709200114_add_migration_index.up.sql new file mode 100644 index 000000000..c1bf16532 --- /dev/null +++ b/protocol/encryption/migrations/sqlite/1709200114_add_migration_index.up.sql @@ -0,0 +1,11 @@ +CREATE INDEX IF NOT EXISTS idx_group_timestamp_desc +ON hash_ratchet_encryption (group_id, key_timestamp DESC); + +CREATE INDEX IF NOT EXISTS idx_hash_ratchet_encryption_keys_key_id +ON hash_ratchet_encryption (key_id); + +CREATE INDEX IF NOT EXISTS idx_hash_ratchet_encryption_keys_deprecated_key_id +ON hash_ratchet_encryption (deprecated_key_id); + +CREATE INDEX IF NOT EXISTS idx_hash_ratchet_cache_group_id_key_id_seq_no +ON hash_ratchet_encryption_cache (group_id, key_id, seq_no DESC); diff --git a/protocol/encryption/persistence.go b/protocol/encryption/persistence.go index 2d7b48db7..3988aea03 100644 --- a/protocol/encryption/persistence.go +++ b/protocol/encryption/persistence.go @@ -1,6 +1,7 @@ package encryption import ( + "context" "crypto/ecdsa" "database/sql" "errors" @@ -742,52 +743,53 @@ type HRCache struct { // If cache data with given seqNo (e.g. 0) is not found, // then the query will return the cache data with the latest seqNo func (s *sqlitePersistence) GetHashRatchetCache(ratchet *HashRatchetKeyCompatibility, seqNo uint32) (*HRCache, error) { - stmt, err := s.DB.Prepare(`WITH input AS ( - select ? AS group_id, ? AS key_id, ? as seq_no, ? AS old_key_id - ), - cec AS ( - SELECT e.key, c.seq_no, c.hash FROM hash_ratchet_encryption e, input i - LEFT JOIN hash_ratchet_encryption_cache c ON e.group_id=c.group_id AND (e.key_id=c.key_id OR e.deprecated_key_id=c.key_id) - WHERE (e.key_id=i.key_id OR e.deprecated_key_id=i.old_key_id) AND e.group_id=i.group_id), - seq_nos AS ( - select CASE - WHEN EXISTS (SELECT c.seq_no from cec c, input i where c.seq_no=i.seq_no) - THEN i.seq_no - ELSE (select max(seq_no) from cec) - END as seq_no from input i - ) - SELECT c.key, c.seq_no, c.hash FROM cec c, input i, seq_nos s - where case when not exists(select seq_no from seq_nos where seq_no is not null) - then 1 else c.seq_no = s.seq_no end`) + tx, err := s.DB.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return nil, err } - defer stmt.Close() + defer func() { + if err == nil { + err = tx.Commit() + return + } + // don't shadow original error + _ = tx.Rollback() + }() - var key, hash []byte - var seqNoPtr *uint32 - - oldFormat := ratchet.IsOldFormat() - if oldFormat { - // Query using the deprecated format - err = stmt.QueryRow(ratchet.GroupID, nil, seqNo, ratchet.DeprecatedKeyID()).Scan(&key, &seqNoPtr, &hash) //nolint: ineffassign - - } else { - keyID, err := ratchet.GetKeyID() + var key, keyID []byte + if !ratchet.IsOldFormat() { + keyID, err = ratchet.GetKeyID() if err != nil { return nil, err } - - err = stmt.QueryRow(ratchet.GroupID, keyID, seqNo, ratchet.DeprecatedKeyID()).Scan(&key, &seqNoPtr, &hash) //nolint: ineffassign,staticcheck } - if len(hash) == 0 && len(key) == 0 { + + err = tx.QueryRow("SELECT key FROM hash_ratchet_encryption WHERE key_id = ? OR deprecated_key_id = ?", keyID, ratchet.DeprecatedKeyID()).Scan(&key) + if err == sql.ErrNoRows { return nil, nil } + if err != nil { + return nil, err + } + args := make([]interface{}, 0) + args = append(args, ratchet.GroupID) + args = append(args, keyID) + args = append(args, ratchet.DeprecatedKeyID()) + var query string + if seqNo == 0 { + query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) ORDER BY seq_no DESC limit 1" + } else { + query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) AND seq_no == ? ORDER BY seq_no DESC limit 1" + args = append(args, seqNo) + } + + var hash []byte + var seqNoPtr *uint32 + + err = tx.QueryRow(query, args...).Scan(&seqNoPtr, &hash) //nolint: ineffassign,staticcheck switch err { - case sql.ErrNoRows: - return nil, nil - case nil: + case sql.ErrNoRows, nil: var seqNoResult uint32 if seqNoPtr == nil { seqNoResult = 0 diff --git a/protocol/encryption/persistence_test.go b/protocol/encryption/persistence_test.go index 5edda8391..8c8d9cdb5 100644 --- a/protocol/encryption/persistence_test.go +++ b/protocol/encryption/persistence_test.go @@ -360,4 +360,32 @@ func (s *SQLLitePersistenceTestSuite) TestGetHashRatchetKeyByID() { s.Require().True(reflect.DeepEqual(key.keyID, cachedKey.KeyID)) s.Require().True(reflect.DeepEqual(key.Key, cachedKey.Key)) s.Require().EqualValues(0, cachedKey.SeqNo) + + var newSeqNo uint32 = 1 + newHash := []byte{10, 11, 12} + err = s.service.SaveHashRatchetKeyHash(key, newHash, newSeqNo) + s.Require().NoError(err) + + cachedKey, err = s.service.GetHashRatchetCache(retrievedKey, 0) + s.Require().NoError(err) + s.Require().True(reflect.DeepEqual(key.keyID, cachedKey.KeyID)) + s.Require().True(reflect.DeepEqual(key.Key, cachedKey.Key)) + s.Require().EqualValues(1, cachedKey.SeqNo) + + newSeqNo = 4 + newHash = []byte{10, 11, 13} + err = s.service.SaveHashRatchetKeyHash(key, newHash, newSeqNo) + s.Require().NoError(err) + + cachedKey, err = s.service.GetHashRatchetCache(retrievedKey, 0) + s.Require().NoError(err) + s.Require().True(reflect.DeepEqual(key.keyID, cachedKey.KeyID)) + s.Require().True(reflect.DeepEqual(key.Key, cachedKey.Key)) + s.Require().EqualValues(4, cachedKey.SeqNo) + + cachedKey, err = s.service.GetHashRatchetCache(retrievedKey, 1) + s.Require().NoError(err) + s.Require().True(reflect.DeepEqual(key.keyID, cachedKey.KeyID)) + s.Require().True(reflect.DeepEqual(key.Key, cachedKey.Key)) + s.Require().EqualValues(1, cachedKey.SeqNo) } diff --git a/protocol/encryption/protocol.go b/protocol/encryption/protocol.go index 82a4c951f..28907c407 100644 --- a/protocol/encryption/protocol.go +++ b/protocol/encryption/protocol.go @@ -26,6 +26,7 @@ const ( sharedSecretNegotiationVersion = 1 partitionedTopicMinVersion = 1 defaultMinVersion = 0 + maxKeysChannelSize = 10000 ) type PartitionTopicMode int @@ -121,9 +122,10 @@ func NewWithEncryptorConfig( } type Subscriptions struct { - SharedSecrets []*sharedsecret.Secret - SendContactCode <-chan struct{} - Quit chan struct{} + SharedSecrets []*sharedsecret.Secret + SendContactCode <-chan struct{} + NewHashRatchetKeys chan []*HashRatchetInfo + Quit chan struct{} } func (p *Protocol) Start(myIdentity *ecdsa.PrivateKey) (*Subscriptions, error) { @@ -133,9 +135,10 @@ func (p *Protocol) Start(myIdentity *ecdsa.PrivateKey) (*Subscriptions, error) { return nil, errors.Wrap(err, "failed to get all secrets") } p.subscriptions = &Subscriptions{ - SharedSecrets: secrets, - SendContactCode: p.publisher.Start(), - Quit: make(chan struct{}), + SharedSecrets: secrets, + SendContactCode: p.publisher.Start(), + NewHashRatchetKeys: make(chan []*HashRatchetInfo, maxKeysChannelSize), + Quit: make(chan struct{}), } return p.subscriptions, nil } @@ -691,6 +694,10 @@ func (p *Protocol) HandleHashRatchetKeys(groupID []byte, keys *HRKeys, myIdentit } } + if p.subscriptions != nil { + p.subscriptions.NewHashRatchetKeys <- info + } + return info, nil } diff --git a/protocol/messenger.go b/protocol/messenger.go index 41a0e4470..c158ef5f1 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -1428,6 +1428,13 @@ func (m *Messenger) handleEncryptionLayerSubscriptions(subscriptions *encryption m.logger.Error("failed to clean processed messages", zap.Error(err)) } + case keys := <-subscriptions.NewHashRatchetKeys: + if m.communitiesManager == nil { + continue + } + if err := m.communitiesManager.NewHashRatchetKeys(keys); err != nil { + m.logger.Error("failed to invalidate cache for decrypted communities", zap.Error(err)) + } case <-subscriptions.Quit: m.logger.Debug("quitting encryption subscription loop") return @@ -3720,6 +3727,13 @@ func (m *Messenger) handleImportedMessages(messagesToHandle map[transport.Filter publicKey := msg.SigPubKey() senderID := contactIDFromPublicKey(publicKey) + if len(msg.EncryptionLayer.HashRatchetInfo) != 0 { + err := m.communitiesManager.NewHashRatchetKeys(msg.EncryptionLayer.HashRatchetInfo) + if err != nil { + m.logger.Warn("failed to invalidate communities description cache", zap.Error(err)) + } + + } // Don't process duplicates messageID := msg.TransportLayer.Message.ThirdPartyID exists, err := m.messageExists(messageID, messageState.ExistingMessagesMap) diff --git a/protocol/messenger_storenode_request_test.go b/protocol/messenger_storenode_request_test.go index 55cc0e4b7..132840906 100644 --- a/protocol/messenger_storenode_request_test.go +++ b/protocol/messenger_storenode_request_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/suite" "go.uber.org/zap" + "google.golang.org/protobuf/proto" "github.com/waku-org/go-waku/waku/v2/protocol/store" @@ -270,8 +271,14 @@ func (s *MessengerStoreNodeRequestSuite) requireCommunitiesEqual(c *communities. s.Require().Equal(expected.Color(), c.Color()) s.Require().Equal(expected.Tags(), c.Tags()) s.Require().Equal(expected.Shard(), c.Shard()) - s.Require().Equal(expected.TokenPermissions(), c.TokenPermissions()) - s.Require().Equal(expected.CommunityTokensMetadata(), c.CommunityTokensMetadata()) + s.Require().Equal(len(expected.TokenPermissions()), len(c.TokenPermissions())) + for k, v := range expected.TokenPermissions() { + s.Require().True(proto.Equal(v, c.TokenPermissions()[k])) + } + s.Require().Equal(len(expected.CommunityTokensMetadata()), len(c.CommunityTokensMetadata())) + for i, v := range expected.CommunityTokensMetadata() { + s.Require().True(proto.Equal(v, c.CommunityTokensMetadata()[i])) + } } func (s *MessengerStoreNodeRequestSuite) requireContactsEqual(c *Contact, expected *Contact) { diff --git a/protocol/migrations/migrations.go b/protocol/migrations/migrations.go index 2fdb878e4..e68772f6c 100644 --- a/protocol/migrations/migrations.go +++ b/protocol/migrations/migrations.go @@ -127,6 +127,7 @@ // 1708423707_applied_community_events.up.sql (201B) // 1708440786_profile_showcase_social_links.up.sql (906B) // 1709805967_simplify_profile_showcase_preferences.up.sql (701B) +// 1709828431_add_community_description_cache.up.sql (730B) // README.md (554B) // doc.go (870B) @@ -2736,6 +2737,26 @@ func _1709805967_simplify_profile_showcase_preferencesUpSql() (*asset, error) { return a, nil } +var __1709828431_add_community_description_cacheUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xa4\x91\xb1\x6e\x83\x30\x10\x86\x77\x3f\xc5\x8d\x20\x31\x56\x5d\x98\x8c\x39\x2a\xab\xd4\x4e\x8d\x91\xc8\x64\x45\xc6\x6a\x2d\x0a\x44\x81\x0e\xbc\x7d\x45\x89\x1a\x2a\x3a\xa4\x89\x37\xdf\xfd\x3a\x7d\xf7\x1d\x53\x48\x35\x82\xa6\x49\x8e\xc0\x33\x10\x52\x03\x56\xbc\xd0\x05\xb8\xce\x9e\xa6\xe3\xe8\x6a\x63\xfb\xb6\xfd\xec\xfc\x38\x99\xda\x0d\xf6\xe4\x8f\xa3\xef\x3b\x63\x0f\xf6\xdd\x41\x40\x00\x2e\x7d\x5f\x83\xc6\x4a\xc3\x4e\xf1\x17\xaa\xf6\xf0\x8c\xfb\x68\x0e\x7c\xf4\xb6\x81\x92\x0b\xfd\xf8\x30\xff\x57\x63\x20\xc9\x65\x32\xd7\x4a\xc1\x5f\x4b\x0c\xd6\xb3\x42\x90\x02\x98\x14\x59\xce\x99\x06\x85\xbb\x9c\x32\x24\x00\x61\x4c\xc8\x1d\xe0\xad\x1f\x06\xdf\xbd\x99\xc6\x4d\xc3\x37\xff\xf2\x36\x5b\x44\x3f\xad\xc6\xfd\x51\x5c\xed\x08\xbf\xb0\xa3\x73\x3e\xbc\x64\x33\xa9\x90\x3f\x89\x6d\x36\x04\x85\x19\x2a\x14\x0c\xaf\x34\xbe\x35\x94\x62\x8e\x1a\x81\xd1\x82\xd1\x14\xc9\xca\x0e\x17\x29\x56\xff\xb4\xe3\x6b\x73\xe8\x6a\xb3\x9c\x4c\x8a\x1b\xa0\xa2\xe5\xde\x61\x7c\x0f\xc6\xa2\x70\xb8\x86\x60\x7d\xcf\xe0\xac\x3e\x26\x5f\x01\x00\x00\xff\xff\x71\xc3\x1d\x10\xda\x02\x00\x00") + +func _1709828431_add_community_description_cacheUpSqlBytes() ([]byte, error) { + return bindataRead( + __1709828431_add_community_description_cacheUpSql, + "1709828431_add_community_description_cache.up.sql", + ) +} + +func _1709828431_add_community_description_cacheUpSql() (*asset, error) { + bytes, err := _1709828431_add_community_description_cacheUpSqlBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "1709828431_add_community_description_cache.up.sql", size: 730, mode: os.FileMode(0644), modTime: time.Unix(1700000000, 0)} + a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xfc, 0xe4, 0x66, 0xd6, 0x9d, 0xb8, 0x87, 0x6e, 0x70, 0xfd, 0x78, 0xa, 0x8c, 0xfb, 0xb2, 0xbc, 0xc4, 0x8c, 0x8d, 0x77, 0xc2, 0xf, 0xe1, 0x68, 0xf3, 0xd6, 0xf3, 0xb0, 0x42, 0x86, 0x3f, 0xf4}} + return a, nil +} + var _readmeMd = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x54\x91\xc1\xce\xd3\x30\x10\x84\xef\x7e\x8a\x91\x7a\x01\xa9\x2a\x8f\xc0\x0d\x71\x82\x03\x48\x1c\xc9\x36\x9e\x36\x96\x1c\x6f\xf0\xae\x93\xe6\xed\x91\xa3\xc2\xdf\xff\x66\xed\xd8\x33\xdf\x78\x4f\xa7\x13\xbe\xea\x06\x57\x6c\x35\x39\x31\xa7\x7b\x15\x4f\x5a\xec\x73\x08\xbf\x08\x2d\x79\x7f\x4a\x43\x5b\x86\x17\xfd\x8c\x21\xea\x56\x5e\x47\x90\x4a\x14\x75\x48\xde\x64\x37\x2c\x6a\x96\xae\x99\x48\x05\xf6\x27\x77\x13\xad\x08\xae\x8a\x51\xe7\x25\xf3\xf1\xa9\x9f\xf9\x58\x58\x2c\xad\xbc\xe0\x8b\x56\xf0\x21\x5d\xeb\x4c\x95\xb3\xae\x84\x60\xd4\xdc\xe6\x82\x5d\x1b\x36\x6d\x39\x62\x92\xf5\xb8\x11\xdb\x92\xd3\x28\xce\xe0\x13\xe1\x72\xcd\x3c\x63\xd4\x65\x87\xae\xac\xe8\xc3\x28\x2e\x67\x44\x66\x3a\x21\x25\xa2\x72\xac\x14\x67\xbc\x84\x9f\x53\x32\x8c\x52\x70\x25\x56\xd6\xfd\x8d\x05\x37\xad\x30\x9d\x9f\xa6\x86\x0f\xcd\x58\x7f\xcf\x34\x93\x3b\xed\x90\x9f\xa4\x1f\xcf\x30\x85\x4d\x07\x58\xaf\x7f\x25\xc4\x9d\xf3\x72\x64\x84\xd0\x7f\xf9\x9b\x3a\x2d\x84\xef\x85\x48\x66\x8d\xd8\x88\x9b\x8c\x8c\x98\x5b\xf6\x74\x14\x4e\x33\x0d\xc9\xe0\x93\x38\xda\x12\xc5\x69\xbd\xe4\xf0\x2e\x7a\x78\x07\x1c\xfe\x13\x9f\x91\x29\x31\x95\x7b\x7f\x62\x59\x37\xb4\xe5\x5e\x25\xfe\x33\xee\xd5\x53\x71\xd6\xda\x3a\xd8\xcb\xde\x2e\xf8\xa1\x90\x55\x53\x0c\xc7\xaa\x0d\xe9\x76\x14\x29\x1c\x7b\x68\xdd\x2f\xe1\x6f\x00\x00\x00\xff\xff\x3c\x0a\xc2\xfe\x2a\x02\x00\x00") func readmeMdBytes() ([]byte, error) { @@ -2994,6 +3015,7 @@ var _bindata = map[string]func() (*asset, error){ "1708423707_applied_community_events.up.sql": _1708423707_applied_community_eventsUpSql, "1708440786_profile_showcase_social_links.up.sql": _1708440786_profile_showcase_social_linksUpSql, "1709805967_simplify_profile_showcase_preferences.up.sql": _1709805967_simplify_profile_showcase_preferencesUpSql, + "1709828431_add_community_description_cache.up.sql": _1709828431_add_community_description_cacheUpSql, "README.md": readmeMd, "doc.go": docGo, } @@ -3171,6 +3193,7 @@ var _bintree = &bintree{nil, map[string]*bintree{ "1708423707_applied_community_events.up.sql": {_1708423707_applied_community_eventsUpSql, map[string]*bintree{}}, "1708440786_profile_showcase_social_links.up.sql": {_1708440786_profile_showcase_social_linksUpSql, map[string]*bintree{}}, "1709805967_simplify_profile_showcase_preferences.up.sql": {_1709805967_simplify_profile_showcase_preferencesUpSql, map[string]*bintree{}}, + "1709828431_add_community_description_cache.up.sql": {_1709828431_add_community_description_cacheUpSql, map[string]*bintree{}}, "README.md": {readmeMd, map[string]*bintree{}}, "doc.go": {docGo, map[string]*bintree{}}, }} diff --git a/protocol/migrations/sqlite/1709828431_add_community_description_cache.up.sql b/protocol/migrations/sqlite/1709828431_add_community_description_cache.up.sql new file mode 100644 index 000000000..b7435760b --- /dev/null +++ b/protocol/migrations/sqlite/1709828431_add_community_description_cache.up.sql @@ -0,0 +1,16 @@ +CREATE TABLE IF NOT EXISTS encrypted_community_description_cache ( + community_id TEXT PRIMARY KEY, + clock UINT64, + description BLOB, + UNIQUE(community_id) ON CONFLICT REPLACE + ); + +CREATE TABLE IF NOT EXISTS encrypted_community_description_missing_keys ( + community_id TEXT, + key_id TEXT, + PRIMARY KEY (community_id, key_id), + FOREIGN KEY (community_id) REFERENCES encrypted_community_description_cache(community_id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS encrypted_community_description_id_and_clock ON encrypted_community_description_cache(community_id, clock); +CREATE INDEX IF NOT EXISTS encrypted_community_description_key_ids ON encrypted_community_description_missing_keys(key_id);