Add index to hash ratchet & cache processed description

This commit is contained in:
Andrea Maria Piana 2024-02-29 09:51:38 +00:00
parent 8c0e24dc26
commit 67dfff2324
13 changed files with 465 additions and 57 deletions

View File

@ -107,6 +107,7 @@ type CommunityChat struct {
ViewersCanPostReactions bool `json:"viewersCanPostReactions"`
Position int `json:"position"`
CategoryID string `json:"categoryID"`
TokenGated bool `json:"tokenGated"`
}
type CommunityCategory struct {
@ -194,6 +195,7 @@ func (o *Community) MarshalPublicAPIJSON() ([]byte, error) {
Members: c.Members,
CanPost: canPost,
ViewersCanPostReactions: c.ViewersCanPostReactions,
TokenGated: o.channelEncrypted(id),
CategoryID: c.CategoryId,
Position: int(c.Position),
}
@ -279,8 +281,10 @@ func (o *Community) MarshalJSON() ([]byte, error) {
PubsubTopicKey string `json:"pubsubTopicKey"`
Shard *shard.Shard `json:"shard"`
LastOpenedAt int64 `json:"lastOpenedAt"`
Clock uint64 `json:"clock"`
}{
ID: o.ID(),
Clock: o.Clock(),
MemberRole: o.MemberRole(o.MemberIdentity()),
IsControlNode: o.IsControlNode(),
Verified: o.config.Verified,
@ -331,6 +335,7 @@ func (o *Community) MarshalJSON() ([]byte, error) {
Members: c.Members,
CanPost: canPost,
ViewersCanPostReactions: c.ViewersCanPostReactions,
TokenGated: o.channelEncrypted(id),
CategoryID: c.CategoryId,
Position: int(c.Position),
}

View File

@ -1723,11 +1723,10 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des
id = crypto.CompressPubkey(signer)
}
failedToDecrypt, err := m.preprocessDescription(id, description)
failedToDecrypt, processedDescription, err := m.preprocessDescription(id, description)
if err != nil {
return nil, err
}
m.communityLock.Lock(id)
defer m.communityLock.Unlock(id)
community, err := m.GetByID(id)
@ -1737,12 +1736,12 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des
// We don't process failed to decrypt if the whole metadata is encrypted
// and we joined the community already
if community != nil && community.Joined() && len(failedToDecrypt) != 0 && description != nil && len(description.Members) == 0 {
if community != nil && community.Joined() && len(failedToDecrypt) != 0 && processedDescription != nil && len(processedDescription.Members) == 0 {
return &CommunityResponse{FailedToDecrypt: failedToDecrypt}, nil
}
// We should queue only if the community has a token owner, and the owner has been verified
hasTokenOwnership := HasTokenOwnership(description)
hasTokenOwnership := HasTokenOwnership(processedDescription)
shouldQueue := hasTokenOwnership && verifiedOwner == nil
if community == nil {
@ -1751,7 +1750,7 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des
return nil, err
}
config := Config{
CommunityDescription: description,
CommunityDescription: processedDescription,
Logger: m.logger,
CommunityDescriptionProtocolMessage: payload,
MemberIdentity: &m.identity.PublicKey,
@ -1772,15 +1771,15 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des
// A new community, we need to check if we need to validate async.
// That would be the case if it has a contract. We queue everything and process separately.
if shouldQueue {
return nil, m.Queue(signer, community, description.Clock, payload)
return nil, m.Queue(signer, community, processedDescription.Clock, payload)
}
} else {
// only queue if already known control node is different than the signer
// and if the clock is greater
shouldQueue = shouldQueue && !common.IsPubKeyEqual(community.ControlNode(), signer) &&
community.config.CommunityDescription.Clock < description.Clock
community.config.CommunityDescription.Clock < processedDescription.Clock
if shouldQueue {
return nil, m.Queue(signer, community, description.Clock, payload)
return nil, m.Queue(signer, community, processedDescription.Clock, payload)
}
}
@ -1806,7 +1805,7 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des
return nil, ErrNotAuthorized
}
r, err := m.handleCommunityDescriptionMessageCommon(community, description, payload, verifiedOwner)
r, err := m.handleCommunityDescriptionMessageCommon(community, processedDescription, payload, verifiedOwner)
if err != nil {
return nil, err
}
@ -1814,10 +1813,22 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des
return r, nil
}
func (m *Manager) preprocessDescription(id types.HexBytes, description *protobuf.CommunityDescription) ([]*CommunityPrivateDataFailedToDecrypt, error) {
func (m *Manager) NewHashRatchetKeys(keys []*encryption.HashRatchetInfo) error {
return m.persistence.InvalidateDecryptedCommunityCacheForKeys(keys)
}
func (m *Manager) preprocessDescription(id types.HexBytes, description *protobuf.CommunityDescription) ([]*CommunityPrivateDataFailedToDecrypt, *protobuf.CommunityDescription, error) {
decryptedCommunity, err := m.persistence.GetDecryptedCommunityDescription(id, description.Clock)
if err != nil {
return nil, nil, err
}
if decryptedCommunity != nil {
return nil, decryptedCommunity, nil
}
response, err := decryptDescription(id, m, description, m.logger)
if err != nil {
return response, err
return response, description, err
}
upgradeTokenPermissions(description)
@ -1825,7 +1836,7 @@ func (m *Manager) preprocessDescription(id types.HexBytes, description *protobuf
// Workaround for https://github.com/status-im/status-desktop/issues/12188
hydrateChannelsMembers(types.EncodeHex(id), description)
return response, nil
return response, description, m.persistence.SaveDecryptedCommunityDescription(id, response, description)
}
func (m *Manager) handleCommunityDescriptionMessageCommon(community *Community, description *protobuf.CommunityDescription, payload []byte, newControlNode *ecdsa.PublicKey) (*CommunityResponse, error) {
@ -3035,12 +3046,12 @@ func (m *Manager) HandleCommunityRequestToJoinResponse(signer *ecdsa.PublicKey,
return nil, ErrNotAuthorized
}
_, err = m.preprocessDescription(community.ID(), request.Community)
_, processedDescription, err := m.preprocessDescription(community.ID(), request.Community)
if err != nil {
return nil, err
}
_, err = community.UpdateCommunityDescription(request.Community, appMetadataMsg, nil)
_, err = community.UpdateCommunityDescription(processedDescription, appMetadataMsg, nil)
if err != nil {
return nil, err
}
@ -3391,11 +3402,13 @@ func (m *Manager) dbRecordBundleToCommunity(r *CommunityRecordBundle) (*Communit
}
return recordBundleToCommunity(r, &m.identity.PublicKey, m.installationID, m.logger, m.timesource, descriptionEncryptor, func(community *Community) error {
_, err := m.preprocessDescription(community.ID(), community.config.CommunityDescription)
_, description, err := m.preprocessDescription(community.ID(), community.config.CommunityDescription)
if err != nil {
return err
}
community.config.CommunityDescription = description
if community.config.EventsData != nil {
eventsDescription, err := validateAndGetEventsMessageCommunityDescription(community.config.EventsData.EventsBaseCommunityDescription, community.ControlNode())
if err != nil {

View File

@ -18,6 +18,7 @@ import (
"github.com/status-im/status-go/protocol/common"
"github.com/status-im/status-go/protocol/common/shard"
"github.com/status-im/status-go/protocol/communities/token"
"github.com/status-im/status-go/protocol/encryption"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/services/wallet/bigint"
)
@ -1872,3 +1873,153 @@ func (p *Persistence) UpsertAppliedCommunityEvents(communityID types.HexBytes, p
}
return err
}
func (p *Persistence) InvalidateDecryptedCommunityCacheForKeys(keys []*encryption.HashRatchetInfo) error {
tx, err := p.db.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return err
}
defer func() {
if err == nil {
err = tx.Commit()
return
}
// don't shadow original error
_ = tx.Rollback()
}()
if len(keys) == 0 {
return nil
}
idsArgs := make([]interface{}, 0, len(keys))
for _, k := range keys {
idsArgs = append(idsArgs, k.KeyID)
}
inVector := strings.Repeat("?, ", len(keys)-1) + "?"
query := "SELECT DISTINCT(community_id) FROM encrypted_community_description_missing_keys WHERE key_id IN (" + inVector + ")" // nolint: gosec
var communityIDs []interface{}
rows, err := tx.Query(query, idsArgs...)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var communityID []byte
err = rows.Scan(&communityID)
if err != nil {
return err
}
communityIDs = append(communityIDs, communityID)
}
if len(communityIDs) == 0 {
return nil
}
inVector = strings.Repeat("?, ", len(communityIDs)-1) + "?"
query = "DELETE FROM encrypted_community_description_cache WHERE community_id IN (" + inVector + ")" //nolint: gosec
_, err = tx.Exec(query, communityIDs...)
return err
}
func (p *Persistence) SaveDecryptedCommunityDescription(communityID []byte, missingKeys []*CommunityPrivateDataFailedToDecrypt, description *protobuf.CommunityDescription) error {
if description == nil {
return nil
}
marshaledDescription, err := proto.Marshal(description)
if err != nil {
return err
}
tx, err := p.db.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return err
}
defer func() {
if err == nil {
err = tx.Commit()
return
}
// don't shadow original error
_ = tx.Rollback()
}()
previousCommunity, err := p.getDecryptedCommunityDescriptionByID(tx, communityID)
if err != nil {
return err
}
if previousCommunity != nil && previousCommunity.Clock >= description.Clock {
return nil
}
insertCommunity := "INSERT INTO encrypted_community_description_cache (community_id, clock, description) VALUES (?, ?, ?);"
_, err = tx.Exec(insertCommunity, communityID, description.Clock, marshaledDescription)
if err != nil {
return err
}
for _, key := range missingKeys {
insertKey := "INSERT INTO encrypted_community_description_missing_keys (community_id, key_id) VALUES(?, ?)"
_, err = tx.Exec(insertKey, communityID, key.KeyID)
if err != nil {
return err
}
}
return nil
}
func (p *Persistence) GetDecryptedCommunityDescription(communityID []byte, clock uint64) (*protobuf.CommunityDescription, error) {
return p.getDecryptedCommunityDescriptionByIDAndClock(communityID, clock)
}
func (p *Persistence) getDecryptedCommunityDescriptionByIDAndClock(communityID []byte, clock uint64) (*protobuf.CommunityDescription, error) {
query := "SELECT description FROM encrypted_community_description_cache WHERE community_id = ? AND clock = ?"
qr := p.db.QueryRow(query, communityID, clock)
var descriptionBytes []byte
err := qr.Scan(&descriptionBytes)
switch err {
case sql.ErrNoRows:
return nil, nil
case nil:
var communityDescription protobuf.CommunityDescription
err := proto.Unmarshal(descriptionBytes, &communityDescription)
if err != nil {
return nil, err
}
return &communityDescription, nil
default:
return nil, err
}
}
func (p *Persistence) getDecryptedCommunityDescriptionByID(tx *sql.Tx, communityID []byte) (*protobuf.CommunityDescription, error) {
query := "SELECT description FROM encrypted_community_description_cache WHERE community_id = ?"
qr := tx.QueryRow(query, communityID)
var descriptionBytes []byte
err := qr.Scan(&descriptionBytes)
switch err {
case sql.ErrNoRows:
return nil, nil
case nil:
var communityDescription protobuf.CommunityDescription
err := proto.Unmarshal(descriptionBytes, &communityDescription)
if err != nil {
return nil, err
}
return &communityDescription, nil
default:
return nil, err
}
}

View File

@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"github.com/status-im/status-go/appdatabase"
@ -16,6 +17,7 @@ import (
"github.com/status-im/status-go/protocol/common"
"github.com/status-im/status-go/protocol/common/shard"
"github.com/status-im/status-go/protocol/communities/token"
"github.com/status-im/status-go/protocol/encryption"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/sqlite"
"github.com/status-im/status-go/services/wallet/bigint"
@ -944,3 +946,109 @@ func (s *PersistenceSuite) TestProcessedCommunityEvents() {
s.Require().Len(events, 3)
s.Require().True(reflect.DeepEqual(events, map[string]uint64{"a": 2, "b": 10, "c": 1}))
}
func (s *PersistenceSuite) TestDecryptedCommunityCache() {
communityDescription := &protobuf.CommunityDescription{
Clock: 1000,
}
keyID1 := []byte("key-id-1")
keyID2 := []byte("key-id-2")
missingKeys := []*CommunityPrivateDataFailedToDecrypt{
{KeyID: keyID1},
{KeyID: keyID2},
}
communityID := []byte("id")
err := s.db.SaveDecryptedCommunityDescription(communityID, missingKeys, communityDescription)
s.Require().NoError(err)
// Can be retrieved
retrievedCommunity, err := s.db.GetDecryptedCommunityDescription(communityID, 1000)
s.Require().NoError(err)
s.Require().True(proto.Equal(communityDescription, retrievedCommunity))
// Retrieving a random one doesn't throw an error
retrievedCommunity, err = s.db.GetDecryptedCommunityDescription([]byte("non-existent-id"), 1000)
s.Require().NoError(err)
s.Require().Nil(retrievedCommunity)
// Retrieving a random one doesn't throw an error
retrievedCommunity, err = s.db.GetDecryptedCommunityDescription(communityID, 999)
s.Require().NoError(err)
s.Require().Nil(retrievedCommunity)
// invalidating the cache
err = s.db.InvalidateDecryptedCommunityCacheForKeys([]*encryption.HashRatchetInfo{{KeyID: keyID1}})
s.Require().NoError(err)
// community cannot be retrieved anymore
retrievedCommunity, err = s.db.GetDecryptedCommunityDescription(communityID, 1000)
s.Require().NoError(err)
s.Require().Nil(retrievedCommunity)
// make sure everything is cleaned up
qr := s.db.db.QueryRow("SELECT COUNT(*) FROM encrypted_community_description_missing_keys")
var count int
err = qr.Scan(&count)
s.Require().NoError(err)
s.Require().Equal(count, 0)
}
func (s *PersistenceSuite) TestDecryptedCommunityCacheClock() {
communityDescription := &protobuf.CommunityDescription{
Clock: 1000,
}
keyID1 := []byte("key-id-1")
keyID2 := []byte("key-id-2")
keyID3 := []byte("key-id-3")
missingKeys := []*CommunityPrivateDataFailedToDecrypt{
{KeyID: keyID1},
{KeyID: keyID2},
}
communityID := []byte("id")
err := s.db.SaveDecryptedCommunityDescription(communityID, missingKeys, communityDescription)
s.Require().NoError(err)
// Can be retrieved
retrievedCommunity, err := s.db.GetDecryptedCommunityDescription(communityID, 1000)
s.Require().NoError(err)
s.Require().True(proto.Equal(communityDescription, retrievedCommunity))
// Save an earlier community
communityDescription.Clock = 999
err = s.db.SaveDecryptedCommunityDescription(communityID, missingKeys, communityDescription)
s.Require().NoError(err)
// The old one should be retrieved
retrievedCommunity, err = s.db.GetDecryptedCommunityDescription(communityID, 1000)
s.Require().NoError(err)
s.Require().NotNil(retrievedCommunity)
s.Require().Equal(uint64(1000), retrievedCommunity.Clock)
// Save a later community, with a single key
missingKeys = []*CommunityPrivateDataFailedToDecrypt{
{KeyID: keyID3},
}
communityDescription.Clock = 1001
err = s.db.SaveDecryptedCommunityDescription(communityID, missingKeys, communityDescription)
s.Require().NoError(err)
// The new one should be retrieved
retrievedCommunity, err = s.db.GetDecryptedCommunityDescription(communityID, 1001)
s.Require().NoError(err)
s.Require().Equal(uint64(1001), retrievedCommunity.Clock)
// Make sure the previous two are cleaned up and there's only one left
qr := s.db.db.QueryRow("SELECT COUNT(*) FROM encrypted_community_description_missing_keys")
var count int
err = qr.Scan(&count)
s.Require().NoError(err)
s.Require().Equal(count, 1)
}

View File

@ -19,6 +19,7 @@
// 1632236298_add_communities.down.sql (151B)
// 1632236298_add_communities.up.sql (584B)
// 1636536507_add_index_bundles.up.sql (347B)
// 1709200114_add_migration_index.up.sql (483B)
// doc.go (397B)
package migrations
@ -467,6 +468,26 @@ func _1636536507_add_index_bundlesUpSql() (*asset, error) {
return a, nil
}
var __1709200114_add_migration_indexUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xa4\x90\xc1\xaa\xc2\x30\x10\x45\xf7\xfd\x8a\x59\xbe\x07\xfe\x81\x2b\x69\x23\x74\xd3\x82\xed\xa2\xbb\x21\x24\x83\x09\xa5\x6d\x4c\x46\x30\x7f\x2f\x25\x16\x41\xb1\x2a\x6e\x66\x36\xf7\xdc\x03\x37\x3f\x88\x5d\x2b\xa0\xac\x0a\xd1\x41\xb9\x87\xaa\x6e\x41\x74\x65\xd3\x36\x60\xf5\x05\x8f\x7e\x3a\x3b\x64\x3b\x50\x60\x39\x38\xd4\x14\x54\x56\x57\x60\x64\x30\xe8\x25\x2b\x43\x8c\x34\x2a\x1f\x1d\xdb\x69\x84\xbf\x04\x58\xbd\x81\x9e\xe2\x1d\x84\x42\x34\xf9\xff\x36\xcb\xde\xf8\x5e\xf4\x62\x4f\x31\xcc\x07\xad\x5e\xd5\xa7\xc8\xaf\x22\x4d\xce\x93\x92\x4c\xfa\x13\xe7\x53\xfa\x5b\xbd\x92\xca\x10\x2e\xcb\xdd\x4a\x30\xd0\x09\xc7\x69\xc5\x9c\xb8\xc7\xc9\xe7\x9f\xd0\x65\xf3\x6b\x00\x00\x00\xff\xff\x97\xf4\x28\xe3\xe3\x01\x00\x00")
func _1709200114_add_migration_indexUpSqlBytes() ([]byte, error) {
return bindataRead(
__1709200114_add_migration_indexUpSql,
"1709200114_add_migration_index.up.sql",
)
}
func _1709200114_add_migration_indexUpSql() (*asset, error) {
bytes, err := _1709200114_add_migration_indexUpSqlBytes()
if err != nil {
return nil, err
}
info := bindataFileInfo{name: "1709200114_add_migration_index.up.sql", size: 483, mode: os.FileMode(0644), modTime: time.Unix(1700000000, 0)}
a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xe2, 0xec, 0xd4, 0x54, 0xff, 0x5e, 0x6e, 0xaf, 0x3f, 0x2b, 0xb5, 0x76, 0xe9, 0x84, 0x2a, 0x4d, 0x1f, 0xd8, 0x22, 0x8b, 0x4b, 0x5c, 0xf1, 0xe0, 0x3a, 0x34, 0xc5, 0xed, 0xef, 0x74, 0xe4, 0x2b}}
return a, nil
}
var _docGo = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x84\x8f\xbd\x6a\x2b\x31\x10\x85\xfb\x7d\x8a\x83\x1b\x37\x77\xa5\x1b\x08\x04\x02\x29\x52\xa6\xcf\x0b\x8c\xa5\x59\x69\xf0\x4a\xda\x68\x66\xfd\xf3\xf6\x61\x1d\x43\xdc\x65\xca\x0f\xbe\x73\xce\x78\x8f\xcf\x2c\x8a\x49\x66\x86\x28\x2a\x07\x56\xa5\x7e\xc5\x81\x03\xad\xca\xd8\x25\xb1\xbc\x1e\x5c\x68\xc5\xab\x91\xad\x3a\x4a\xf1\x45\x52\x27\x63\x7f\x7a\xde\x0d\xde\x23\x50\xdd\x1b\x32\xd5\x38\xf3\x2d\x4b\xa1\x46\xdd\xa4\x26\x9c\xc5\x32\x08\x4b\xe7\x49\x2e\x0e\xef\x86\x99\x49\x0d\x96\xc9\xf6\x0a\xcb\x8c\x40\xca\x5b\xcc\xd4\x3a\x52\x1b\x0f\x52\x23\x19\xb9\x0d\x7d\x4c\x0f\x64\x5b\x18\x68\x9e\x39\x62\xea\xad\xdc\x5c\xa5\xc2\x88\xd2\x39\x58\xeb\xd7\x7f\x20\x55\x36\x54\x2a\xac\x9b\x9f\xe9\xc4\xa8\xed\x5e\x0f\xaa\xf1\xef\x8f\x70\x6e\xfd\xa8\x20\x05\x5f\x16\x0e\xc6\xd1\x0d\xc3\x42\xe1\x48\x89\xa1\x5f\xb3\x18\x0f\x83\xf7\xa9\xbd\x26\xae\xbc\x59\x8f\x1b\xc7\xd2\xa2\x49\xe1\xb7\xa7\x97\xff\xf7\xc3\xb8\x1c\x13\x7e\x1a\xa4\x55\xc5\xd8\xe0\x9c\xff\x05\x2e\x35\xb8\xe1\x3b\x00\x00\xff\xff\x73\x18\x09\xa7\x8d\x01\x00\x00")
func docGoBytes() ([]byte, error) {
@ -597,6 +618,7 @@ var _bindata = map[string]func() (*asset, error){
"1632236298_add_communities.down.sql": _1632236298_add_communitiesDownSql,
"1632236298_add_communities.up.sql": _1632236298_add_communitiesUpSql,
"1636536507_add_index_bundles.up.sql": _1636536507_add_index_bundlesUpSql,
"1709200114_add_migration_index.up.sql": _1709200114_add_migration_indexUpSql,
"doc.go": docGo,
}
@ -665,6 +687,7 @@ var _bintree = &bintree{nil, map[string]*bintree{
"1632236298_add_communities.down.sql": {_1632236298_add_communitiesDownSql, map[string]*bintree{}},
"1632236298_add_communities.up.sql": {_1632236298_add_communitiesUpSql, map[string]*bintree{}},
"1636536507_add_index_bundles.up.sql": {_1636536507_add_index_bundlesUpSql, map[string]*bintree{}},
"1709200114_add_migration_index.up.sql": {_1709200114_add_migration_indexUpSql, map[string]*bintree{}},
"doc.go": {docGo, map[string]*bintree{}},
}}

View File

@ -0,0 +1,11 @@
CREATE INDEX IF NOT EXISTS idx_group_timestamp_desc
ON hash_ratchet_encryption (group_id, key_timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_hash_ratchet_encryption_keys_key_id
ON hash_ratchet_encryption (key_id);
CREATE INDEX IF NOT EXISTS idx_hash_ratchet_encryption_keys_deprecated_key_id
ON hash_ratchet_encryption (deprecated_key_id);
CREATE INDEX IF NOT EXISTS idx_hash_ratchet_cache_group_id_key_id_seq_no
ON hash_ratchet_encryption_cache (group_id, key_id, seq_no DESC);

View File

@ -1,6 +1,7 @@
package encryption
import (
"context"
"crypto/ecdsa"
"database/sql"
"errors"
@ -742,52 +743,53 @@ type HRCache struct {
// If cache data with given seqNo (e.g. 0) is not found,
// then the query will return the cache data with the latest seqNo
func (s *sqlitePersistence) GetHashRatchetCache(ratchet *HashRatchetKeyCompatibility, seqNo uint32) (*HRCache, error) {
stmt, err := s.DB.Prepare(`WITH input AS (
select ? AS group_id, ? AS key_id, ? as seq_no, ? AS old_key_id
),
cec AS (
SELECT e.key, c.seq_no, c.hash FROM hash_ratchet_encryption e, input i
LEFT JOIN hash_ratchet_encryption_cache c ON e.group_id=c.group_id AND (e.key_id=c.key_id OR e.deprecated_key_id=c.key_id)
WHERE (e.key_id=i.key_id OR e.deprecated_key_id=i.old_key_id) AND e.group_id=i.group_id),
seq_nos AS (
select CASE
WHEN EXISTS (SELECT c.seq_no from cec c, input i where c.seq_no=i.seq_no)
THEN i.seq_no
ELSE (select max(seq_no) from cec)
END as seq_no from input i
)
SELECT c.key, c.seq_no, c.hash FROM cec c, input i, seq_nos s
where case when not exists(select seq_no from seq_nos where seq_no is not null)
then 1 else c.seq_no = s.seq_no end`)
tx, err := s.DB.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return nil, err
}
defer stmt.Close()
defer func() {
if err == nil {
err = tx.Commit()
return
}
// don't shadow original error
_ = tx.Rollback()
}()
var key, hash []byte
var seqNoPtr *uint32
oldFormat := ratchet.IsOldFormat()
if oldFormat {
// Query using the deprecated format
err = stmt.QueryRow(ratchet.GroupID, nil, seqNo, ratchet.DeprecatedKeyID()).Scan(&key, &seqNoPtr, &hash) //nolint: ineffassign
} else {
keyID, err := ratchet.GetKeyID()
var key, keyID []byte
if !ratchet.IsOldFormat() {
keyID, err = ratchet.GetKeyID()
if err != nil {
return nil, err
}
err = stmt.QueryRow(ratchet.GroupID, keyID, seqNo, ratchet.DeprecatedKeyID()).Scan(&key, &seqNoPtr, &hash) //nolint: ineffassign,staticcheck
}
if len(hash) == 0 && len(key) == 0 {
err = tx.QueryRow("SELECT key FROM hash_ratchet_encryption WHERE key_id = ? OR deprecated_key_id = ?", keyID, ratchet.DeprecatedKeyID()).Scan(&key)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
args := make([]interface{}, 0)
args = append(args, ratchet.GroupID)
args = append(args, keyID)
args = append(args, ratchet.DeprecatedKeyID())
var query string
if seqNo == 0 {
query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) ORDER BY seq_no DESC limit 1"
} else {
query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) AND seq_no == ? ORDER BY seq_no DESC limit 1"
args = append(args, seqNo)
}
var hash []byte
var seqNoPtr *uint32
err = tx.QueryRow(query, args...).Scan(&seqNoPtr, &hash) //nolint: ineffassign,staticcheck
switch err {
case sql.ErrNoRows:
return nil, nil
case nil:
case sql.ErrNoRows, nil:
var seqNoResult uint32
if seqNoPtr == nil {
seqNoResult = 0

View File

@ -360,4 +360,32 @@ func (s *SQLLitePersistenceTestSuite) TestGetHashRatchetKeyByID() {
s.Require().True(reflect.DeepEqual(key.keyID, cachedKey.KeyID))
s.Require().True(reflect.DeepEqual(key.Key, cachedKey.Key))
s.Require().EqualValues(0, cachedKey.SeqNo)
var newSeqNo uint32 = 1
newHash := []byte{10, 11, 12}
err = s.service.SaveHashRatchetKeyHash(key, newHash, newSeqNo)
s.Require().NoError(err)
cachedKey, err = s.service.GetHashRatchetCache(retrievedKey, 0)
s.Require().NoError(err)
s.Require().True(reflect.DeepEqual(key.keyID, cachedKey.KeyID))
s.Require().True(reflect.DeepEqual(key.Key, cachedKey.Key))
s.Require().EqualValues(1, cachedKey.SeqNo)
newSeqNo = 4
newHash = []byte{10, 11, 13}
err = s.service.SaveHashRatchetKeyHash(key, newHash, newSeqNo)
s.Require().NoError(err)
cachedKey, err = s.service.GetHashRatchetCache(retrievedKey, 0)
s.Require().NoError(err)
s.Require().True(reflect.DeepEqual(key.keyID, cachedKey.KeyID))
s.Require().True(reflect.DeepEqual(key.Key, cachedKey.Key))
s.Require().EqualValues(4, cachedKey.SeqNo)
cachedKey, err = s.service.GetHashRatchetCache(retrievedKey, 1)
s.Require().NoError(err)
s.Require().True(reflect.DeepEqual(key.keyID, cachedKey.KeyID))
s.Require().True(reflect.DeepEqual(key.Key, cachedKey.Key))
s.Require().EqualValues(1, cachedKey.SeqNo)
}

View File

@ -26,6 +26,7 @@ const (
sharedSecretNegotiationVersion = 1
partitionedTopicMinVersion = 1
defaultMinVersion = 0
maxKeysChannelSize = 10000
)
type PartitionTopicMode int
@ -121,9 +122,10 @@ func NewWithEncryptorConfig(
}
type Subscriptions struct {
SharedSecrets []*sharedsecret.Secret
SendContactCode <-chan struct{}
Quit chan struct{}
SharedSecrets []*sharedsecret.Secret
SendContactCode <-chan struct{}
NewHashRatchetKeys chan []*HashRatchetInfo
Quit chan struct{}
}
func (p *Protocol) Start(myIdentity *ecdsa.PrivateKey) (*Subscriptions, error) {
@ -133,9 +135,10 @@ func (p *Protocol) Start(myIdentity *ecdsa.PrivateKey) (*Subscriptions, error) {
return nil, errors.Wrap(err, "failed to get all secrets")
}
p.subscriptions = &Subscriptions{
SharedSecrets: secrets,
SendContactCode: p.publisher.Start(),
Quit: make(chan struct{}),
SharedSecrets: secrets,
SendContactCode: p.publisher.Start(),
NewHashRatchetKeys: make(chan []*HashRatchetInfo, maxKeysChannelSize),
Quit: make(chan struct{}),
}
return p.subscriptions, nil
}
@ -691,6 +694,10 @@ func (p *Protocol) HandleHashRatchetKeys(groupID []byte, keys *HRKeys, myIdentit
}
}
if p.subscriptions != nil {
p.subscriptions.NewHashRatchetKeys <- info
}
return info, nil
}

View File

@ -1428,6 +1428,13 @@ func (m *Messenger) handleEncryptionLayerSubscriptions(subscriptions *encryption
m.logger.Error("failed to clean processed messages", zap.Error(err))
}
case keys := <-subscriptions.NewHashRatchetKeys:
if m.communitiesManager == nil {
continue
}
if err := m.communitiesManager.NewHashRatchetKeys(keys); err != nil {
m.logger.Error("failed to invalidate cache for decrypted communities", zap.Error(err))
}
case <-subscriptions.Quit:
m.logger.Debug("quitting encryption subscription loop")
return
@ -3720,6 +3727,13 @@ func (m *Messenger) handleImportedMessages(messagesToHandle map[transport.Filter
publicKey := msg.SigPubKey()
senderID := contactIDFromPublicKey(publicKey)
if len(msg.EncryptionLayer.HashRatchetInfo) != 0 {
err := m.communitiesManager.NewHashRatchetKeys(msg.EncryptionLayer.HashRatchetInfo)
if err != nil {
m.logger.Warn("failed to invalidate communities description cache", zap.Error(err))
}
}
// Don't process duplicates
messageID := msg.TransportLayer.Message.ThirdPartyID
exists, err := m.messageExists(messageID, messageState.ExistingMessagesMap)

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/waku-org/go-waku/waku/v2/protocol/store"
@ -270,8 +271,14 @@ func (s *MessengerStoreNodeRequestSuite) requireCommunitiesEqual(c *communities.
s.Require().Equal(expected.Color(), c.Color())
s.Require().Equal(expected.Tags(), c.Tags())
s.Require().Equal(expected.Shard(), c.Shard())
s.Require().Equal(expected.TokenPermissions(), c.TokenPermissions())
s.Require().Equal(expected.CommunityTokensMetadata(), c.CommunityTokensMetadata())
s.Require().Equal(len(expected.TokenPermissions()), len(c.TokenPermissions()))
for k, v := range expected.TokenPermissions() {
s.Require().True(proto.Equal(v, c.TokenPermissions()[k]))
}
s.Require().Equal(len(expected.CommunityTokensMetadata()), len(c.CommunityTokensMetadata()))
for i, v := range expected.CommunityTokensMetadata() {
s.Require().True(proto.Equal(v, c.CommunityTokensMetadata()[i]))
}
}
func (s *MessengerStoreNodeRequestSuite) requireContactsEqual(c *Contact, expected *Contact) {

View File

@ -127,6 +127,7 @@
// 1708423707_applied_community_events.up.sql (201B)
// 1708440786_profile_showcase_social_links.up.sql (906B)
// 1709805967_simplify_profile_showcase_preferences.up.sql (701B)
// 1709828431_add_community_description_cache.up.sql (730B)
// README.md (554B)
// doc.go (870B)
@ -2736,6 +2737,26 @@ func _1709805967_simplify_profile_showcase_preferencesUpSql() (*asset, error) {
return a, nil
}
var __1709828431_add_community_description_cacheUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xa4\x91\xb1\x6e\x83\x30\x10\x86\x77\x3f\xc5\x8d\x20\x31\x56\x5d\x98\x8c\x39\x2a\xab\xd4\x4e\x8d\x91\xc8\x64\x45\xc6\x6a\x2d\x0a\x44\x81\x0e\xbc\x7d\x45\x89\x1a\x2a\x3a\xa4\x89\x37\xdf\xfd\x3a\x7d\xf7\x1d\x53\x48\x35\x82\xa6\x49\x8e\xc0\x33\x10\x52\x03\x56\xbc\xd0\x05\xb8\xce\x9e\xa6\xe3\xe8\x6a\x63\xfb\xb6\xfd\xec\xfc\x38\x99\xda\x0d\xf6\xe4\x8f\xa3\xef\x3b\x63\x0f\xf6\xdd\x41\x40\x00\x2e\x7d\x5f\x83\xc6\x4a\xc3\x4e\xf1\x17\xaa\xf6\xf0\x8c\xfb\x68\x0e\x7c\xf4\xb6\x81\x92\x0b\xfd\xf8\x30\xff\x57\x63\x20\xc9\x65\x32\xd7\x4a\xc1\x5f\x4b\x0c\xd6\xb3\x42\x90\x02\x98\x14\x59\xce\x99\x06\x85\xbb\x9c\x32\x24\x00\x61\x4c\xc8\x1d\xe0\xad\x1f\x06\xdf\xbd\x99\xc6\x4d\xc3\x37\xff\xf2\x36\x5b\x44\x3f\xad\xc6\xfd\x51\x5c\xed\x08\xbf\xb0\xa3\x73\x3e\xbc\x64\x33\xa9\x90\x3f\x89\x6d\x36\x04\x85\x19\x2a\x14\x0c\xaf\x34\xbe\x35\x94\x62\x8e\x1a\x81\xd1\x82\xd1\x14\xc9\xca\x0e\x17\x29\x56\xff\xb4\xe3\x6b\x73\xe8\x6a\xb3\x9c\x4c\x8a\x1b\xa0\xa2\xe5\xde\x61\x7c\x0f\xc6\xa2\x70\xb8\x86\x60\x7d\xcf\xe0\xac\x3e\x26\x5f\x01\x00\x00\xff\xff\x71\xc3\x1d\x10\xda\x02\x00\x00")
func _1709828431_add_community_description_cacheUpSqlBytes() ([]byte, error) {
return bindataRead(
__1709828431_add_community_description_cacheUpSql,
"1709828431_add_community_description_cache.up.sql",
)
}
func _1709828431_add_community_description_cacheUpSql() (*asset, error) {
bytes, err := _1709828431_add_community_description_cacheUpSqlBytes()
if err != nil {
return nil, err
}
info := bindataFileInfo{name: "1709828431_add_community_description_cache.up.sql", size: 730, mode: os.FileMode(0644), modTime: time.Unix(1700000000, 0)}
a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xfc, 0xe4, 0x66, 0xd6, 0x9d, 0xb8, 0x87, 0x6e, 0x70, 0xfd, 0x78, 0xa, 0x8c, 0xfb, 0xb2, 0xbc, 0xc4, 0x8c, 0x8d, 0x77, 0xc2, 0xf, 0xe1, 0x68, 0xf3, 0xd6, 0xf3, 0xb0, 0x42, 0x86, 0x3f, 0xf4}}
return a, nil
}
var _readmeMd = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x54\x91\xc1\xce\xd3\x30\x10\x84\xef\x7e\x8a\x91\x7a\x01\xa9\x2a\x8f\xc0\x0d\x71\x82\x03\x48\x1c\xc9\x36\x9e\x36\x96\x1c\x6f\xf0\xae\x93\xe6\xed\x91\xa3\xc2\xdf\xff\x66\xed\xd8\x33\xdf\x78\x4f\xa7\x13\xbe\xea\x06\x57\x6c\x35\x39\x31\xa7\x7b\x15\x4f\x5a\xec\x73\x08\xbf\x08\x2d\x79\x7f\x4a\x43\x5b\x86\x17\xfd\x8c\x21\xea\x56\x5e\x47\x90\x4a\x14\x75\x48\xde\x64\x37\x2c\x6a\x96\xae\x99\x48\x05\xf6\x27\x77\x13\xad\x08\xae\x8a\x51\xe7\x25\xf3\xf1\xa9\x9f\xf9\x58\x58\x2c\xad\xbc\xe0\x8b\x56\xf0\x21\x5d\xeb\x4c\x95\xb3\xae\x84\x60\xd4\xdc\xe6\x82\x5d\x1b\x36\x6d\x39\x62\x92\xf5\xb8\x11\xdb\x92\xd3\x28\xce\xe0\x13\xe1\x72\xcd\x3c\x63\xd4\x65\x87\xae\xac\xe8\xc3\x28\x2e\x67\x44\x66\x3a\x21\x25\xa2\x72\xac\x14\x67\xbc\x84\x9f\x53\x32\x8c\x52\x70\x25\x56\xd6\xfd\x8d\x05\x37\xad\x30\x9d\x9f\xa6\x86\x0f\xcd\x58\x7f\xcf\x34\x93\x3b\xed\x90\x9f\xa4\x1f\xcf\x30\x85\x4d\x07\x58\xaf\x7f\x25\xc4\x9d\xf3\x72\x64\x84\xd0\x7f\xf9\x9b\x3a\x2d\x84\xef\x85\x48\x66\x8d\xd8\x88\x9b\x8c\x8c\x98\x5b\xf6\x74\x14\x4e\x33\x0d\xc9\xe0\x93\x38\xda\x12\xc5\x69\xbd\xe4\xf0\x2e\x7a\x78\x07\x1c\xfe\x13\x9f\x91\x29\x31\x95\x7b\x7f\x62\x59\x37\xb4\xe5\x5e\x25\xfe\x33\xee\xd5\x53\x71\xd6\xda\x3a\xd8\xcb\xde\x2e\xf8\xa1\x90\x55\x53\x0c\xc7\xaa\x0d\xe9\x76\x14\x29\x1c\x7b\x68\xdd\x2f\xe1\x6f\x00\x00\x00\xff\xff\x3c\x0a\xc2\xfe\x2a\x02\x00\x00")
func readmeMdBytes() ([]byte, error) {
@ -2994,6 +3015,7 @@ var _bindata = map[string]func() (*asset, error){
"1708423707_applied_community_events.up.sql": _1708423707_applied_community_eventsUpSql,
"1708440786_profile_showcase_social_links.up.sql": _1708440786_profile_showcase_social_linksUpSql,
"1709805967_simplify_profile_showcase_preferences.up.sql": _1709805967_simplify_profile_showcase_preferencesUpSql,
"1709828431_add_community_description_cache.up.sql": _1709828431_add_community_description_cacheUpSql,
"README.md": readmeMd,
"doc.go": docGo,
}
@ -3171,6 +3193,7 @@ var _bintree = &bintree{nil, map[string]*bintree{
"1708423707_applied_community_events.up.sql": {_1708423707_applied_community_eventsUpSql, map[string]*bintree{}},
"1708440786_profile_showcase_social_links.up.sql": {_1708440786_profile_showcase_social_linksUpSql, map[string]*bintree{}},
"1709805967_simplify_profile_showcase_preferences.up.sql": {_1709805967_simplify_profile_showcase_preferencesUpSql, map[string]*bintree{}},
"1709828431_add_community_description_cache.up.sql": {_1709828431_add_community_description_cacheUpSql, map[string]*bintree{}},
"README.md": {readmeMd, map[string]*bintree{}},
"doc.go": {docGo, map[string]*bintree{}},
}}

View File

@ -0,0 +1,16 @@
CREATE TABLE IF NOT EXISTS encrypted_community_description_cache (
community_id TEXT PRIMARY KEY,
clock UINT64,
description BLOB,
UNIQUE(community_id) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS encrypted_community_description_missing_keys (
community_id TEXT,
key_id TEXT,
PRIMARY KEY (community_id, key_id),
FOREIGN KEY (community_id) REFERENCES encrypted_community_description_cache(community_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS encrypted_community_description_id_and_clock ON encrypted_community_description_cache(community_id, clock);
CREATE INDEX IF NOT EXISTS encrypted_community_description_key_ids ON encrypted_community_description_missing_keys(key_id);