diff --git a/protocol/messenger_contact_verification.go b/protocol/messenger_contact_verification.go index 32a0cbfae..f7d868e39 100644 --- a/protocol/messenger_contact_verification.go +++ b/protocol/messenger_contact_verification.go @@ -397,12 +397,8 @@ func (m *Messenger) VerifiedTrusted(ctx context.Context, request *requests.Verif return nil, errors.New("must be a mutual contact") } - err = m.verificationDatabase.SetTrustStatus(contactID, verification.TrustStatusTRUSTED, m.getTimesource().GetCurrentTime()) - if err != nil { - return nil, err - } + err = m.setTrustStatusForContact(context.Background(), contactID, verification.TrustStatusTRUSTED) - err = m.SyncTrustedUser(context.Background(), contactID, verification.TrustStatusTRUSTED, m.dispatchMessage) if err != nil { return nil, err } @@ -507,24 +503,12 @@ func (m *Messenger) VerifiedUntrustworthy(ctx context.Context, request *requests contactID := notification.ReplyMessage.From - contact, ok := m.allContacts.Load(contactID) - if !ok || !contact.mutual() { - return nil, errors.New("must be a mutual contact") - } - - err = m.verificationDatabase.SetTrustStatus(contactID, verification.TrustStatusUNTRUSTWORTHY, m.getTimesource().GetCurrentTime()) + err = m.setTrustStatusForContact(context.Background(), contactID, verification.TrustStatusUNTRUSTWORTHY) if err != nil { return nil, err } - err = m.SyncTrustedUser(context.Background(), contactID, verification.TrustStatusUNTRUSTWORTHY, m.dispatchMessage) - if err != nil { - return nil, err - } - - contact.VerificationStatus = VerificationStatusVERIFIED - contact.LastUpdatedLocally = m.getTimesource().GetCurrentTime() - err = m.persistence.SaveContact(contact, nil) + contact, err := m.setContactVerificationStatus(contactID, VerificationStatusVERIFIED) if err != nil { return nil, err } @@ -562,18 +546,6 @@ func (m *Messenger) VerifiedUntrustworthy(ctx context.Context, request *requests return nil, err } - // We sync the contact with the other devices - err = m.syncContact(context.Background(), contact, m.dispatchMessage) - if err != nil { - return nil, err - } - - // Dispatch profile message to remove a contact from the encrypted profile part - err = m.DispatchProfileShowcase() - if err != nil { - return nil, err - } - response := &MessengerResponse{} notification.ContactVerificationStatus = verification.RequestStatusUNTRUSTWORTHY @@ -619,6 +591,12 @@ func (m *Messenger) DeclineContactVerificationRequest(ctx context.Context, id st if !ok || !contact.mutual() { return nil, errors.New("must be a mutual contact") } + contactID := verifRequest.From + contact, err = m.setContactVerificationStatus(contactID, VerificationStatusVERIFIED) + + if err != nil { + return nil, err + } if verifRequest == nil { return nil, errors.New("no contact verification found") @@ -705,37 +683,74 @@ func (m *Messenger) DeclineContactVerificationRequest(ctx context.Context, id st return response, nil } -func (m *Messenger) MarkAsTrusted(ctx context.Context, contactID string) error { - err := m.verificationDatabase.SetTrustStatus(contactID, verification.TrustStatusTRUSTED, m.getTimesource().GetCurrentTime()) +func (m *Messenger) setContactVerificationStatus(contactID string, verificationStatus VerificationStatus) (*Contact, error) { + contact, ok := m.allContacts.Load(contactID) + if !ok || !contact.mutual() { + return nil, errors.New("must be a mutual contact") + } + + contact.VerificationStatus = verificationStatus + contact.LastUpdatedLocally = m.getTimesource().GetCurrentTime() + + err := m.persistence.SaveContact(contact, nil) + if err != nil { + return nil, err + } + + err = m.syncContact(context.Background(), contact, m.dispatchMessage) + if err != nil { + return nil, err + } + + m.allContacts.Store(contact.ID, contact) + + // Dispatch profile message to save a contact to the encrypted profile part + err = m.DispatchProfileShowcase() + if err != nil { + return nil, err + } + + return contact, nil +} + +func (m *Messenger) setTrustStatusForContact(ctx context.Context, contactID string, trustStatus verification.TrustStatus) error { + currentTime := m.getTimesource().GetCurrentTime() + + err := m.verificationDatabase.SetTrustStatus(contactID, trustStatus, currentTime) if err != nil { return err } - return m.SyncTrustedUser(ctx, contactID, verification.TrustStatusTRUSTED, m.dispatchMessage) + return m.SyncTrustedUser(ctx, contactID, trustStatus, m.dispatchMessage) +} + +func (m *Messenger) MarkAsTrusted(ctx context.Context, contactID string) error { + return m.setTrustStatusForContact(ctx, contactID, verification.TrustStatusTRUSTED) } func (m *Messenger) MarkAsUntrustworthy(ctx context.Context, contactID string) error { - err := m.verificationDatabase.SetTrustStatus(contactID, verification.TrustStatusUNTRUSTWORTHY, m.getTimesource().GetCurrentTime()) - if err != nil { - return err - } - - return m.SyncTrustedUser(ctx, contactID, verification.TrustStatusUNTRUSTWORTHY, m.dispatchMessage) + return m.setTrustStatusForContact(ctx, contactID, verification.TrustStatusUNTRUSTWORTHY) } func (m *Messenger) RemoveTrustStatus(ctx context.Context, contactID string) error { - err := m.verificationDatabase.SetTrustStatus(contactID, verification.TrustStatusUNKNOWN, m.getTimesource().GetCurrentTime()) + return m.setTrustStatusForContact(ctx, contactID, verification.TrustStatusUNKNOWN) +} + +func (m *Messenger) RemoveTrustVerificationStatus(ctx context.Context, contactID string) (*MessengerResponse, error) { + err := m.setTrustStatusForContact(ctx, contactID, verification.TrustStatusUNKNOWN) if err != nil { - return err + return nil, err } - // Dispatch profile message to remove a contact from the encrypted profile part - err = m.DispatchProfileShowcase() + contact, err := m.setContactVerificationStatus(contactID, VerificationStatusUNVERIFIED) if err != nil { - return err + return nil, err } - return m.SyncTrustedUser(ctx, contactID, verification.TrustStatusUNKNOWN, m.dispatchMessage) + response := &MessengerResponse{} + response.AddContact(contact) + + return response, nil } func (m *Messenger) GetTrustStatus(contactID string) (verification.TrustStatus, error) { diff --git a/protocol/messenger_contact_verification_test.go b/protocol/messenger_contact_verification_test.go index 1079ec2ed..ee46692bb 100644 --- a/protocol/messenger_contact_verification_test.go +++ b/protocol/messenger_contact_verification_test.go @@ -116,7 +116,7 @@ func (s *MessengerVerificationRequests) mutualContact(theirMessenger *Messenger) s.Require().NotNil(resp.ActivityCenterNotifications()[0].Message) s.Require().Equal(common.ContactRequestStateAccepted, resp.ActivityCenterNotifications()[0].Message.ContactRequestState) - // Make sure the message is updated, sender s2de + // Make sure the message is updated, sender side s.Require().NotNil(resp) s.Require().Len(resp.Messages(), 2) @@ -494,6 +494,65 @@ func (s *MessengerVerificationRequests) TestUnthrustworthyVerificationRequests() s.Require().Equal(common.ContactVerificationStateUntrustworthy, resp.Messages()[0].ContactVerificationState) } +func (s *MessengerVerificationRequests) TestRemoveTrustVerificationStatus() { + // GIVEN + theirMessenger := s.newMessenger(s.shh) + defer TearDownMessenger(&s.Suite, theirMessenger) + + s.mutualContact(theirMessenger) + + theirPk := types.EncodeHex(crypto.FromECDSAPub(&theirMessenger.identity.PublicKey)) + challenge := "challenge" + + resp, err := s.m.SendContactVerificationRequest(context.Background(), theirPk, challenge) + s.Require().NoError(err) + s.Require().Len(resp.VerificationRequests(), 1) + verificationRequestID := resp.VerificationRequests()[0].ID + + // Wait for the message to reach its destination + resp, err = WaitOnMessengerResponse( + theirMessenger, + func(r *MessengerResponse) bool { + return len(r.VerificationRequests()) == 1 && len(r.ActivityCenterNotifications()) == 1 + }, + "no messages", + ) + s.Require().NoError(err) + s.Require().Len(resp.VerificationRequests(), 1) + + resp, err = theirMessenger.AcceptContactVerificationRequest(context.Background(), verificationRequestID, "hello back") + s.Require().NoError(err) + s.Require().Len(resp.VerificationRequests(), 1) + + // Wait for the message to reach its destination + _, err = WaitOnMessengerResponse( + s.m, + func(r *MessengerResponse) bool { + return len(r.VerificationRequests()) == 1 + }, + "no messages", + ) + s.Require().NoError(err) + + // Mark as trusted + _, err = s.m.VerifiedTrusted(context.Background(), &requests.VerifiedTrusted{ID: types.FromHex(verificationRequestID)}) + s.Require().NoError(err) + + // WHEN + _, err = s.m.RemoveTrustVerificationStatus(context.Background(), theirPk) + s.Require().NoError(err) + + // THEN + trustStatus, err := s.m.GetTrustStatus(theirPk) + s.Require().NoError(err) + s.Require().Equal(verification.TrustStatusUNKNOWN, trustStatus) + + contact, _ := s.m.allContacts.Load(theirPk) + s.Require().NotNil(contact) + s.Require().Equal(VerificationStatusUNVERIFIED, contact.VerificationStatus) + s.Require().Equal(verification.TrustStatusUNKNOWN, contact.TrustStatus) +} + func (s *MessengerVerificationRequests) TestDeclineVerificationRequests() { theirMessenger := s.newMessenger(s.shh) defer TearDownMessenger(&s.Suite, theirMessenger) diff --git a/services/ext/api.go b/services/ext/api.go index a204b14fa..2365a363f 100644 --- a/services/ext/api.go +++ b/services/ext/api.go @@ -999,6 +999,10 @@ func (api *PublicAPI) RemoveTrustStatus(ctx context.Context, contactID string) e return api.service.messenger.RemoveTrustStatus(ctx, contactID) } +func (api *PublicAPI) RemoveTrustVerificationStatus(ctx context.Context, contactID string) (*protocol.MessengerResponse, error) { + return api.service.messenger.RemoveTrustVerificationStatus(ctx, contactID) +} + func (api *PublicAPI) GetTrustStatus(ctx context.Context, contactID string) (verification.TrustStatus, error) { return api.service.messenger.GetTrustStatus(contactID) }