From d99fdf1b672c2f8207371e7e2386a403196341b9 Mon Sep 17 00:00:00 2001 From: Jonathan Rainville Date: Fri, 25 Oct 2024 17:28:58 -0400 Subject: [PATCH] fix(contacts)_: fix trust status not being saved to cache when changed (#5965) Fixes https://github.com/status-im/status-desktop/issues/16392 --- protocol/messenger_contact_verification.go | 30 ++++++++---- .../messenger_contact_verification_test.go | 47 +++++++++++++++++++ 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/protocol/messenger_contact_verification.go b/protocol/messenger_contact_verification.go index b0d3328b6..af6c6c8fa 100644 --- a/protocol/messenger_contact_verification.go +++ b/protocol/messenger_contact_verification.go @@ -24,6 +24,10 @@ import ( const minContactVerificationMessageLen = 1 const maxContactVerificationMessageLen = 280 +var ( + ErrContactNotMutual = errors.New("must be a mutual contact") +) + func (m *Messenger) SendContactVerificationRequest(ctx context.Context, contactID string, challenge string) (*MessengerResponse, error) { if len(challenge) < minContactVerificationMessageLen || len(challenge) > maxContactVerificationMessageLen { return nil, errors.New("invalid verification request challenge length") @@ -31,7 +35,7 @@ func (m *Messenger) SendContactVerificationRequest(ctx context.Context, contactI contact, ok := m.allContacts.Load(contactID) if !ok || !contact.mutual() { - return nil, errors.New("must be a mutual contact") + return nil, ErrContactNotMutual } verifRequest := &verification.Request{ @@ -138,7 +142,7 @@ func (m *Messenger) SendContactVerificationRequest(ctx context.Context, contactI func (m *Messenger) GetVerificationRequestSentTo(ctx context.Context, contactID string) (*verification.Request, error) { _, ok := m.allContacts.Load(contactID) if !ok { - return nil, errors.New("contact not found") + return nil, ErrContactNotFound } return m.verificationDatabase.GetLatestVerificationRequestSentTo(contactID) @@ -279,7 +283,7 @@ func (m *Messenger) AcceptContactVerificationRequest(ctx context.Context, id str contact, ok := m.allContacts.Load(contactID) if !ok || !contact.mutual() { - return nil, errors.New("must be a mutual contact") + return nil, ErrContactNotMutual } chat, ok := m.allChats.Load(contactID) @@ -394,7 +398,7 @@ func (m *Messenger) VerifiedTrusted(ctx context.Context, request *requests.Verif contact, ok := m.allContacts.Load(contactID) if !ok || !contact.mutual() { - return nil, errors.New("must be a mutual contact") + return nil, ErrContactNotMutual } err = m.setTrustStatusForContact(context.Background(), contactID, verification.TrustStatusTRUSTED) @@ -589,7 +593,7 @@ func (m *Messenger) DeclineContactVerificationRequest(ctx context.Context, id st contact, ok := m.allContacts.Load(verifRequest.From) if !ok || !contact.mutual() { - return nil, errors.New("must be a mutual contact") + return nil, ErrContactNotMutual } contactID := verifRequest.From contact, err = m.setContactVerificationStatus(contactID, VerificationStatusVERIFIED) @@ -686,7 +690,7 @@ func (m *Messenger) DeclineContactVerificationRequest(ctx context.Context, id st func (m *Messenger) setContactVerificationStatus(contactID string, verificationStatus VerificationStatus) (*Contact, error) { contact, ok := m.allContacts.Load(contactID) if !ok || !contact.mutual() { - return nil, errors.New("must be a mutual contact") + return nil, ErrContactNotMutual } contact.VerificationStatus = verificationStatus @@ -714,6 +718,11 @@ func (m *Messenger) setContactVerificationStatus(contactID string, verificationS } func (m *Messenger) setTrustStatusForContact(ctx context.Context, contactID string, trustStatus verification.TrustStatus) error { + contact, ok := m.allContacts.Load(contactID) + if !ok { + return ErrContactNotFound + } + currentTime := m.getTimesource().GetCurrentTime() err := m.verificationDatabase.SetTrustStatus(contactID, trustStatus, currentTime) @@ -721,6 +730,9 @@ func (m *Messenger) setTrustStatusForContact(ctx context.Context, contactID stri return err } + contact.TrustStatus = trustStatus + m.allContacts.Store(contactID, contact) + return m.SyncTrustedUser(ctx, contactID, trustStatus, m.dispatchMessage) } @@ -784,7 +796,7 @@ func (m *Messenger) HandleRequestContactVerification(state *ReceivedMessageState contact := state.CurrentMessageState.Contact if !contact.mutual() { m.logger.Debug("Received a verification request for a non added mutual contact", zap.String("contactID", contactID)) - return errors.New("must be a mutual contact") + return ErrContactNotMutual } persistedVR, err := m.verificationDatabase.GetVerificationRequest(id) @@ -875,7 +887,7 @@ func (m *Messenger) HandleAcceptContactVerification(state *ReceivedMessageState, contact := state.CurrentMessageState.Contact if !contact.mutual() { m.logger.Debug("Received a verification response for a non mutual contact", zap.String("contactID", contactID)) - return errors.New("must be a mutual contact") + return ErrContactNotMutual } persistedVR, err := m.verificationDatabase.GetVerificationRequest(request.Id) @@ -964,7 +976,7 @@ func (m *Messenger) HandleDeclineContactVerification(state *ReceivedMessageState contact := state.CurrentMessageState.Contact if !contact.mutual() { m.logger.Debug("Received a verification decline for a non mutual contact", zap.String("contactID", contactID)) - return errors.New("must be a mutual contact") + return ErrContactNotMutual } persistedVR, err := m.verificationDatabase.GetVerificationRequest(request.Id) diff --git a/protocol/messenger_contact_verification_test.go b/protocol/messenger_contact_verification_test.go index ee46692bb..8ff0ff696 100644 --- a/protocol/messenger_contact_verification_test.go +++ b/protocol/messenger_contact_verification_test.go @@ -769,3 +769,50 @@ func (s *MessengerVerificationRequests) newMessenger(shh types.Waku) *Messenger s.Require().NoError(err) return messenger } + +func (s *MessengerVerificationRequests) TestTrustStatus() { + theirMessenger := s.newMessenger(s.shh) + defer TearDownMessenger(&s.Suite, theirMessenger) + + s.mutualContact(theirMessenger) + + theirPk := types.EncodeHex(crypto.FromECDSAPub(&theirMessenger.identity.PublicKey)) + + // Test Mark as Trusted + err := s.m.MarkAsTrusted(context.Background(), theirPk) + s.Require().NoError(err) + + contactFromCache, ok := s.m.allContacts.Load(theirPk) + s.Require().True(ok) + s.Require().Equal(verification.TrustStatusTRUSTED, contactFromCache.TrustStatus) + trustStatusFromDb, err := s.m.GetTrustStatus(theirPk) + s.Require().NoError(err) + s.Require().Equal(verification.TrustStatusTRUSTED, trustStatusFromDb) + + // Test Remove Trust Mark + err = s.m.RemoveTrustStatus(context.Background(), theirPk) + s.Require().NoError(err) + + contactFromCache, ok = s.m.allContacts.Load(theirPk) + s.Require().True(ok) + s.Require().Equal(verification.TrustStatusUNKNOWN, contactFromCache.TrustStatus) + trustStatusFromDb, err = s.m.GetTrustStatus(theirPk) + s.Require().NoError(err) + s.Require().Equal(verification.TrustStatusUNKNOWN, trustStatusFromDb) + + // Test Mark as Untrustoworthy + err = s.m.MarkAsUntrustworthy(context.Background(), theirPk) + s.Require().NoError(err) + + contactFromCache, ok = s.m.allContacts.Load(theirPk) + s.Require().True(ok) + s.Require().Equal(verification.TrustStatusUNTRUSTWORTHY, contactFromCache.TrustStatus) + trustStatusFromDb, err = s.m.GetTrustStatus(theirPk) + s.Require().NoError(err) + s.Require().Equal(verification.TrustStatusUNTRUSTWORTHY, trustStatusFromDb) + + // Test calling with an unknown contact + err = s.m.MarkAsTrusted(context.Background(), "0x00000123") + s.Require().Error(err) + s.Require().Equal("contact not found", err.Error()) +}