Added granual locking on Messenger

This commit is contained in:
Samuel Hawksby-Robinson 2021-03-29 16:41:30 +01:00 committed by Andrea Maria Piana
parent 9d09cb3e9a
commit 759f7bbeb3
9 changed files with 507 additions and 349 deletions

View File

@ -11,14 +11,19 @@ import (
v1protocol "github.com/status-im/status-go/protocol/v1" v1protocol "github.com/status-im/status-go/protocol/v1"
) )
var defaultSystemMessagesTranslations = map[protobuf.MembershipUpdateEvent_EventType]string{ var defaultSystemMessagesTranslations = new(systemMessageTranslationsMap)
protobuf.MembershipUpdateEvent_CHAT_CREATED: "{{from}} created the group {{name}}",
protobuf.MembershipUpdateEvent_NAME_CHANGED: "{{from}} changed the group's name to {{name}}", func init() {
protobuf.MembershipUpdateEvent_MEMBERS_ADDED: "{{from}} has invited {{members}}", defaultSystemMessagesTranslationSet := map[protobuf.MembershipUpdateEvent_EventType]string{
protobuf.MembershipUpdateEvent_MEMBER_JOINED: "{{from}} joined the group", protobuf.MembershipUpdateEvent_CHAT_CREATED: "{{from}} created the group {{name}}",
protobuf.MembershipUpdateEvent_ADMINS_ADDED: "{{from}} has made {{members}} admin", protobuf.MembershipUpdateEvent_NAME_CHANGED: "{{from}} changed the group's name to {{name}}",
protobuf.MembershipUpdateEvent_MEMBER_REMOVED: "{{member}} left the group", protobuf.MembershipUpdateEvent_MEMBERS_ADDED: "{{from}} has invited {{members}}",
protobuf.MembershipUpdateEvent_ADMIN_REMOVED: "{{member}} is not admin anymore", protobuf.MembershipUpdateEvent_MEMBER_JOINED: "{{from}} joined the group",
protobuf.MembershipUpdateEvent_ADMINS_ADDED: "{{from}} has made {{members}} admin",
protobuf.MembershipUpdateEvent_MEMBER_REMOVED: "{{member}} left the group",
protobuf.MembershipUpdateEvent_ADMIN_REMOVED: "{{member}} is not admin anymore",
}
defaultSystemMessagesTranslations.Init(defaultSystemMessagesTranslationSet)
} }
func tsprintf(format string, params map[string]string) string { func tsprintf(format string, params map[string]string) string {
@ -28,32 +33,39 @@ func tsprintf(format string, params map[string]string) string {
return format return format
} }
func eventToSystemMessage(e v1protocol.MembershipUpdateEvent, translations map[protobuf.MembershipUpdateEvent_EventType]string) *common.Message { func eventToSystemMessage(e v1protocol.MembershipUpdateEvent, translations *systemMessageTranslationsMap) *common.Message {
var text string var text string
switch e.Type { switch e.Type {
case protobuf.MembershipUpdateEvent_CHAT_CREATED: case protobuf.MembershipUpdateEvent_CHAT_CREATED:
text = tsprintf(translations[protobuf.MembershipUpdateEvent_CHAT_CREATED], map[string]string{"from": "@" + e.From, "name": e.Name}) message, _ := translations.Load(protobuf.MembershipUpdateEvent_CHAT_CREATED)
text = tsprintf(message, map[string]string{"from": "@" + e.From, "name": e.Name})
case protobuf.MembershipUpdateEvent_NAME_CHANGED: case protobuf.MembershipUpdateEvent_NAME_CHANGED:
text = tsprintf(translations[protobuf.MembershipUpdateEvent_NAME_CHANGED], map[string]string{"from": "@" + e.From, "name": e.Name}) message, _ := translations.Load(protobuf.MembershipUpdateEvent_NAME_CHANGED)
text = tsprintf(message, map[string]string{"from": "@" + e.From, "name": e.Name})
case protobuf.MembershipUpdateEvent_MEMBERS_ADDED: case protobuf.MembershipUpdateEvent_MEMBERS_ADDED:
var memberMentions []string var memberMentions []string
for _, s := range e.Members { for _, s := range e.Members {
memberMentions = append(memberMentions, "@"+s) memberMentions = append(memberMentions, "@"+s)
} }
text = tsprintf(translations[protobuf.MembershipUpdateEvent_MEMBERS_ADDED], map[string]string{"from": "@" + e.From, "members": strings.Join(memberMentions, ", ")}) message, _ := translations.Load(protobuf.MembershipUpdateEvent_MEMBERS_ADDED)
text = tsprintf(message, map[string]string{"from": "@" + e.From, "members": strings.Join(memberMentions, ", ")})
case protobuf.MembershipUpdateEvent_MEMBER_JOINED: case protobuf.MembershipUpdateEvent_MEMBER_JOINED:
text = tsprintf(translations[protobuf.MembershipUpdateEvent_MEMBER_JOINED], map[string]string{"from": "@" + e.From}) message, _ := translations.Load(protobuf.MembershipUpdateEvent_MEMBER_JOINED)
text = tsprintf(message, map[string]string{"from": "@" + e.From})
case protobuf.MembershipUpdateEvent_ADMINS_ADDED: case protobuf.MembershipUpdateEvent_ADMINS_ADDED:
var memberMentions []string var memberMentions []string
for _, s := range e.Members { for _, s := range e.Members {
memberMentions = append(memberMentions, "@"+s) memberMentions = append(memberMentions, "@"+s)
} }
text = tsprintf(translations[protobuf.MembershipUpdateEvent_ADMINS_ADDED], map[string]string{"from": "@" + e.From, "members": strings.Join(memberMentions, ", ")}) message, _ := translations.Load(protobuf.MembershipUpdateEvent_ADMINS_ADDED)
text = tsprintf(message, map[string]string{"from": "@" + e.From, "members": strings.Join(memberMentions, ", ")})
case protobuf.MembershipUpdateEvent_MEMBER_REMOVED: case protobuf.MembershipUpdateEvent_MEMBER_REMOVED:
text = tsprintf(translations[protobuf.MembershipUpdateEvent_MEMBER_REMOVED], map[string]string{"member": "@" + e.Members[0]}) message, _ := translations.Load(protobuf.MembershipUpdateEvent_MEMBER_REMOVED)
text = tsprintf(message, map[string]string{"member": "@" + e.Members[0]})
case protobuf.MembershipUpdateEvent_ADMIN_REMOVED: case protobuf.MembershipUpdateEvent_ADMIN_REMOVED:
text = tsprintf(translations[protobuf.MembershipUpdateEvent_ADMIN_REMOVED], map[string]string{"member": "@" + e.Members[0]}) message, _ := translations.Load(protobuf.MembershipUpdateEvent_ADMIN_REMOVED)
text = tsprintf(message, map[string]string{"member": "@" + e.Members[0]})
} }
timestamp := v1protocol.TimestampInMsFromTime(time.Now()) timestamp := v1protocol.TimestampInMsFromTime(time.Now())
@ -76,7 +88,7 @@ func eventToSystemMessage(e v1protocol.MembershipUpdateEvent, translations map[p
return message return message
} }
func buildSystemMessages(events []v1protocol.MembershipUpdateEvent, translations map[protobuf.MembershipUpdateEvent_EventType]string) []*common.Message { func buildSystemMessages(events []v1protocol.MembershipUpdateEvent, translations *systemMessageTranslationsMap) []*common.Message {
var messages []*common.Message var messages []*common.Message
for _, e := range events { for _, e := range events {

View File

@ -47,7 +47,7 @@ func (n NotificationBody) MarshalJSON() ([]byte, error) {
return json.Marshal(item) return json.Marshal(item)
} }
func NewMessageNotification(id string, message *common.Message, chat *Chat, contact *Contact, contacts map[string]*Contact) (*localnotifications.Notification, error) { func NewMessageNotification(id string, message *common.Message, chat *Chat, contact *Contact, contacts *contactMap) (*localnotifications.Notification, error) {
body := &NotificationBody{ body := &NotificationBody{
Message: message, Message: message,
Chat: chat, Chat: chat,
@ -66,7 +66,7 @@ func NewCommunityRequestToJoinNotification(id string, community *communities.Com
return body.toCommunityRequestToJoinNotification(id) return body.toCommunityRequestToJoinNotification(id)
} }
func (n NotificationBody) toMessageNotification(id string, contacts map[string]*Contact) (*localnotifications.Notification, error) { func (n NotificationBody) toMessageNotification(id string, contacts *contactMap) (*localnotifications.Notification, error) {
var title string var title string
if n.Chat.PrivateGroupChat() || n.Chat.Public() || n.Chat.CommunityChat() { if n.Chat.PrivateGroupChat() || n.Chat.Public() || n.Chat.CommunityChat() {
title = n.Chat.Name title = n.Chat.Name
@ -76,16 +76,16 @@ func (n NotificationBody) toMessageNotification(id string, contacts map[string]*
} }
canonicalNames := make(map[string]string) canonicalNames := make(map[string]string)
for _, id := range n.Message.Mentions { for _, mentionID := range n.Message.Mentions {
contact, ok := contacts[id] contact, ok := contacts.Load(mentionID)
if !ok { if !ok {
var err error var err error
contact, err = buildContactFromPkString(id) contact, err = buildContactFromPkString(mentionID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
canonicalNames[id] = contact.CanonicalName() canonicalNames[mentionID] = contact.CanonicalName()
} }
simplifiedText, err := n.Message.GetSimplifiedText(canonicalNames) simplifiedText, err := n.Message.GetSimplifiedText(canonicalNames)

View File

@ -49,7 +49,7 @@ func newMessageHandler(identity *ecdsa.PrivateKey, logger *zap.Logger, persisten
// HandleMembershipUpdate updates a Chat instance according to the membership updates. // HandleMembershipUpdate updates a Chat instance according to the membership updates.
// It retrieves chat, if exists, and merges membership updates from the message. // It retrieves chat, if exists, and merges membership updates from the message.
// Finally, the Chat is updated with the new group events. // Finally, the Chat is updated with the new group events.
func (m *MessageHandler) HandleMembershipUpdate(messageState *ReceivedMessageState, chat *Chat, rawMembershipUpdate protobuf.MembershipUpdateMessage, translations map[protobuf.MembershipUpdateEvent_EventType]string) error { func (m *MessageHandler) HandleMembershipUpdate(messageState *ReceivedMessageState, chat *Chat, rawMembershipUpdate protobuf.MembershipUpdateMessage, translations *systemMessageTranslationsMap) error {
var group *v1protocol.Group var group *v1protocol.Group
var err error var err error
@ -142,7 +142,7 @@ func (m *MessageHandler) HandleMembershipUpdate(messageState *ReceivedMessageSta
} }
// Store in chats map as it might be a new one // Store in chats map as it might be a new one
messageState.AllChats[chat.ID] = chat messageState.AllChats.Store(chat.ID, chat)
messageState.Response.AddChat(chat) messageState.Response.AddChat(chat)
if message.Message != nil { if message.Message != nil {
@ -179,7 +179,7 @@ func (m *MessageHandler) handleCommandMessage(state *ReceivedMessageState, messa
// Set the LocalChatID for the message // Set the LocalChatID for the message
message.LocalChatID = chat.ID message.LocalChatID = chat.ID
if c, ok := state.AllChats[chat.ID]; ok { if c, ok := state.AllChats.Load(chat.ID); ok {
chat = c chat = c
} }
@ -204,7 +204,8 @@ func (m *MessageHandler) handleCommandMessage(state *ReceivedMessageState, messa
chat.Active = true chat.Active = true
// Set in the modified maps chat // Set in the modified maps chat
state.Response.AddChat(chat) state.Response.AddChat(chat)
state.AllChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
state.AllChats.Store(chat.ID, chat)
// Add to response // Add to response
if message != nil { if message != nil {
@ -214,14 +215,14 @@ func (m *MessageHandler) handleCommandMessage(state *ReceivedMessageState, messa
} }
func (m *MessageHandler) HandleSyncInstallationContact(state *ReceivedMessageState, message protobuf.SyncInstallationContact) error { func (m *MessageHandler) HandleSyncInstallationContact(state *ReceivedMessageState, message protobuf.SyncInstallationContact) error {
chat, ok := state.AllChats[state.CurrentMessageState.Contact.ID] chat, ok := state.AllChats.Load(state.CurrentMessageState.Contact.ID)
if !ok { if !ok {
chat = OneToOneFromPublicKey(state.CurrentMessageState.PublicKey, state.Timesource) chat = OneToOneFromPublicKey(state.CurrentMessageState.PublicKey, state.Timesource)
// We don't want to show the chat to the user // We don't want to show the chat to the user
chat.Active = false chat.Active = false
} }
contact, ok := state.AllContacts[message.Id] contact, ok := state.AllContacts.Load(message.Id)
if !ok { if !ok {
var err error var err error
contact, err = buildContactFromPkString(message.Id) contact, err = buildContactFromPkString(message.Id)
@ -241,25 +242,26 @@ func (m *MessageHandler) HandleSyncInstallationContact(state *ReceivedMessageSta
contact.LastUpdated = message.Clock contact.LastUpdated = message.Clock
contact.LocalNickname = message.LocalNickname contact.LocalNickname = message.LocalNickname
state.ModifiedContacts[contact.ID] = true state.ModifiedContacts.Store(contact.ID, true)
state.AllContacts[contact.ID] = contact state.AllContacts.Store(contact.ID, contact)
} }
state.AllChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
state.AllChats.Store(chat.ID, chat)
return nil return nil
} }
func (m *MessageHandler) HandleSyncInstallationPublicChat(state *ReceivedMessageState, message protobuf.SyncInstallationPublicChat) bool { func (m *MessageHandler) HandleSyncInstallationPublicChat(state *ReceivedMessageState, message protobuf.SyncInstallationPublicChat) bool {
chatID := message.Id chatID := message.Id
_, ok := state.AllChats[chatID] _, ok := state.AllChats.Load(chatID)
if ok { if ok {
return false return false
} }
chat := CreatePublicChat(chatID, state.Timesource) chat := CreatePublicChat(chatID, state.Timesource)
state.AllChats[chat.ID] = chat state.AllChats.Store(chat.ID, chat)
state.Response.AddChat(chat) state.Response.AddChat(chat)
return true return true
@ -268,7 +270,7 @@ func (m *MessageHandler) HandleSyncInstallationPublicChat(state *ReceivedMessage
func (m *MessageHandler) HandleContactUpdate(state *ReceivedMessageState, message protobuf.ContactUpdate) error { func (m *MessageHandler) HandleContactUpdate(state *ReceivedMessageState, message protobuf.ContactUpdate) error {
logger := m.logger.With(zap.String("site", "HandleContactUpdate")) logger := m.logger.With(zap.String("site", "HandleContactUpdate"))
contact := state.CurrentMessageState.Contact contact := state.CurrentMessageState.Contact
chat, ok := state.AllChats[contact.ID] chat, ok := state.AllChats.Load(contact.ID)
if !ok { if !ok {
chat = OneToOneFromPublicKey(state.CurrentMessageState.PublicKey, state.Timesource) chat = OneToOneFromPublicKey(state.CurrentMessageState.PublicKey, state.Timesource)
// We don't want to show the chat to the user // We don't want to show the chat to the user
@ -287,8 +289,8 @@ func (m *MessageHandler) HandleContactUpdate(state *ReceivedMessageState, messag
contact.ENSVerified = false contact.ENSVerified = false
} }
contact.LastUpdated = message.Clock contact.LastUpdated = message.Clock
state.ModifiedContacts[contact.ID] = true state.ModifiedContacts.Store(contact.ID, true)
state.AllContacts[contact.ID] = contact state.AllContacts.Store(contact.ID, contact)
} }
if chat.LastClockValue < message.Clock { if chat.LastClockValue < message.Clock {
@ -296,7 +298,8 @@ func (m *MessageHandler) HandleContactUpdate(state *ReceivedMessageState, messag
} }
state.Response.AddChat(chat) state.Response.AddChat(chat)
state.AllChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
state.AllChats.Store(chat.ID, chat)
return nil return nil
} }
@ -308,7 +311,7 @@ func (m *MessageHandler) HandlePairInstallation(state *ReceivedMessageState, mes
return err return err
} }
installation, ok := state.AllInstallations[message.InstallationId] installation, ok := state.AllInstallations.Load(message.InstallationId)
if !ok { if !ok {
return errors.New("installation not found") return errors.New("installation not found")
} }
@ -319,8 +322,9 @@ func (m *MessageHandler) HandlePairInstallation(state *ReceivedMessageState, mes
} }
installation.InstallationMetadata = metadata installation.InstallationMetadata = metadata
state.AllInstallations[message.InstallationId] = installation // TODO(samyoul) remove storing of an updated reference pointer?
state.ModifiedInstallations[message.InstallationId] = true state.AllInstallations.Store(message.InstallationId, installation)
state.ModifiedInstallations.Store(message.InstallationId, true)
return nil return nil
} }
@ -348,16 +352,18 @@ func (m *MessageHandler) HandleCommunityDescription(state *ReceivedMessageState,
var chatIDs []string var chatIDs []string
for i, chat := range chats { for i, chat := range chats {
oldChat, ok := state.AllChats[chat.ID] oldChat, ok := state.AllChats.Load(chat.ID)
if !ok { if !ok {
// Beware, don't use the reference in the range (i.e chat) as it's a shallow copy // Beware, don't use the reference in the range (i.e chat) as it's a shallow copy
state.AllChats[chat.ID] = chats[i] state.AllChats.Store(chat.ID, chats[i])
state.Response.AddChat(chat) state.Response.AddChat(chat)
chatIDs = append(chatIDs, chat.ID) chatIDs = append(chatIDs, chat.ID)
// Update name, currently is the only field is mutable // Update name, currently is the only field is mutable
} else if oldChat.Name != chat.Name { } else if oldChat.Name != chat.Name {
state.AllChats[chat.ID].Name = chat.Name oldChat.Name = chat.Name
// TODO(samyoul) remove storing of an updated reference pointer?
state.AllChats.Store(chat.ID, oldChat)
state.Response.AddChat(chat) state.Response.AddChat(chat)
} }
} }
@ -422,7 +428,7 @@ func (m *MessageHandler) HandleCommunityRequestToJoin(state *ReceivedMessageStat
contactID := contactIDFromPublicKey(signer) contactID := contactIDFromPublicKey(signer)
contact := state.AllContacts[contactID] contact, _ := state.AllContacts.Load(contactID)
state.Response.AddNotification(NewCommunityRequestToJoinNotification(requestToJoin.ID.String(), community, contact)) state.Response.AddNotification(NewCommunityRequestToJoinNotification(requestToJoin.ID.String(), community, contact))
@ -478,7 +484,7 @@ func (m *MessageHandler) HandleChatMessage(state *ReceivedMessageState) error {
// Set the LocalChatID for the message // Set the LocalChatID for the message
receivedMessage.LocalChatID = chat.ID receivedMessage.LocalChatID = chat.ID
if c, ok := state.AllChats[chat.ID]; ok { if c, ok := state.AllChats.Load(chat.ID); ok {
chat = c chat = c
} }
@ -502,7 +508,8 @@ func (m *MessageHandler) HandleChatMessage(state *ReceivedMessageState) error {
chat.Active = true chat.Active = true
// Set in the modified maps chat // Set in the modified maps chat
state.Response.AddChat(chat) state.Response.AddChat(chat)
state.AllChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
state.AllChats.Store(chat.ID, chat)
contact := state.CurrentMessageState.Contact contact := state.CurrentMessageState.Contact
if receivedMessage.EnsName != "" { if receivedMessage.EnsName != "" {
@ -513,8 +520,8 @@ func (m *MessageHandler) HandleChatMessage(state *ReceivedMessageState) error {
// If oldRecord is nil, a new verification process will take place // If oldRecord is nil, a new verification process will take place
// so we reset the record // so we reset the record
contact.ENSVerified = false contact.ENSVerified = false
state.ModifiedContacts[contact.ID] = true state.ModifiedContacts.Store(contact.ID, true)
state.AllContacts[contact.ID] = contact state.AllContacts.Store(contact.ID, contact)
} }
} }
@ -738,7 +745,7 @@ func (m *MessageHandler) HandleDeclineRequestTransaction(messageState *ReceivedM
return m.handleCommandMessage(messageState, oldMessage) return m.handleCommandMessage(messageState, oldMessage)
} }
func (m *MessageHandler) matchChatEntity(chatEntity common.ChatEntity, chats map[string]*Chat, timesource common.TimeSource) (*Chat, error) { func (m *MessageHandler) matchChatEntity(chatEntity common.ChatEntity, chats *chatMap, timesource common.TimeSource) (*Chat, error) {
if chatEntity.GetSigPubKey() == nil { if chatEntity.GetSigPubKey() == nil {
m.logger.Error("public key can't be empty") m.logger.Error("public key can't be empty")
return nil, errors.New("received a chatEntity with empty public key") return nil, errors.New("received a chatEntity with empty public key")
@ -749,8 +756,8 @@ func (m *MessageHandler) matchChatEntity(chatEntity common.ChatEntity, chats map
// For public messages, all outgoing and incoming messages have the same chatID // For public messages, all outgoing and incoming messages have the same chatID
// equal to a public chat name. // equal to a public chat name.
chatID := chatEntity.GetChatId() chatID := chatEntity.GetChatId()
chat := chats[chatID] chat, ok := chats.Load(chatID)
if chat == nil { if !ok {
return nil, errors.New("received a public chatEntity from non-existing chat") return nil, errors.New("received a public chatEntity from non-existing chat")
} }
return chat, nil return chat, nil
@ -758,8 +765,8 @@ func (m *MessageHandler) matchChatEntity(chatEntity common.ChatEntity, chats map
// It's a private message coming from us so we rely on Message.ChatID // It's a private message coming from us so we rely on Message.ChatID
// If chat does not exist, it should be created to support multidevice synchronization. // If chat does not exist, it should be created to support multidevice synchronization.
chatID := chatEntity.GetChatId() chatID := chatEntity.GetChatId()
chat := chats[chatID] chat, ok := chats.Load(chatID)
if chat == nil { if !ok {
if len(chatID) != PubKeyStringLength { if len(chatID) != PubKeyStringLength {
return nil, errors.New("invalid pubkey length") return nil, errors.New("invalid pubkey length")
} }
@ -780,16 +787,16 @@ func (m *MessageHandler) matchChatEntity(chatEntity common.ChatEntity, chats map
// It's an incoming private chatEntity. ChatID is calculated from the signature. // It's an incoming private chatEntity. ChatID is calculated from the signature.
// If a chat does not exist, a new one is created and saved. // If a chat does not exist, a new one is created and saved.
chatID := contactIDFromPublicKey(chatEntity.GetSigPubKey()) chatID := contactIDFromPublicKey(chatEntity.GetSigPubKey())
chat := chats[chatID] chat, ok := chats.Load(chatID)
if chat == nil { if !ok {
// TODO: this should be a three-word name used in the mobile client // TODO: this should be a three-word name used in the mobile client
chat = CreateOneToOneChat(chatID[:8], chatEntity.GetSigPubKey(), timesource) chat = CreateOneToOneChat(chatID[:8], chatEntity.GetSigPubKey(), timesource)
} }
return chat, nil return chat, nil
case chatEntity.GetMessageType() == protobuf.MessageType_COMMUNITY_CHAT: case chatEntity.GetMessageType() == protobuf.MessageType_COMMUNITY_CHAT:
chatID := chatEntity.GetChatId() chatID := chatEntity.GetChatId()
chat := chats[chatID] chat, ok := chats.Load(chatID)
if chat == nil { if !ok {
return nil, errors.New("received community chat chatEntity for non-existing chat") return nil, errors.New("received community chat chatEntity for non-existing chat")
} }
@ -817,8 +824,8 @@ func (m *MessageHandler) matchChatEntity(chatEntity common.ChatEntity, chats map
// In the case of a group chatEntity, ChatID is the same for all messages belonging to a group. // In the case of a group chatEntity, ChatID is the same for all messages belonging to a group.
// It needs to be verified if the signature public key belongs to the chat. // It needs to be verified if the signature public key belongs to the chat.
chatID := chatEntity.GetChatId() chatID := chatEntity.GetChatId()
chat := chats[chatID] chat, ok := chats.Load(chatID)
if chat == nil { if !ok {
return nil, errors.New("received group chat chatEntity for non-existing chat") return nil, errors.New("received group chat chatEntity for non-existing chat")
} }
@ -906,7 +913,8 @@ func (m *MessageHandler) HandleEmojiReaction(state *ReceivedMessageState, pbEmoj
} }
state.Response.AddChat(chat) state.Response.AddChat(chat)
state.AllChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
state.AllChats.Store(chat.ID, chat)
// save emoji reaction // save emoji reaction
err = m.persistence.SaveEmojiReaction(emojiReaction) err = m.persistence.SaveEmojiReaction(emojiReaction)
@ -980,8 +988,8 @@ func (m *MessageHandler) HandleChatIdentity(state *ReceivedMessageState, ci prot
contact.Images[imageType] = images.IdentityImage{Name: imageType, Payload: image.Payload} contact.Images[imageType] = images.IdentityImage{Name: imageType, Payload: image.Payload}
} }
state.ModifiedContacts[contact.ID] = true state.ModifiedContacts.Store(contact.ID, true)
state.AllContacts[contact.ID] = contact state.AllContacts.Store(contact.ID, contact)
} }
return nil return nil

View File

@ -74,6 +74,7 @@ var messageCacheIntervalMs uint64 = 1000 * 60 * 60 * 48
// because installations are managed by the user. // because installations are managed by the user.
// Similarly, it needs to expose an interface to manage // Similarly, it needs to expose an interface to manage
// mailservers because they can also be managed by the user. // mailservers because they can also be managed by the user.
// TODO we may want to change the maps into sync.Maps if we start getting unexpected locks or weird collision bugs
type Messenger struct { type Messenger struct {
node types.Node node types.Node
config *config config *config
@ -92,11 +93,11 @@ type Messenger struct {
featureFlags common.FeatureFlags featureFlags common.FeatureFlags
shutdownTasks []func() error shutdownTasks []func() error
shouldPublishContactCode bool shouldPublishContactCode bool
systemMessagesTranslations map[protobuf.MembershipUpdateEvent_EventType]string systemMessagesTranslations *systemMessageTranslationsMap
allChats map[string]*Chat allChats *chatMap
allContacts map[string]*Contact allContacts *contactMap
allInstallations map[string]*multidevice.Installation allInstallations *installationMap
modifiedInstallations map[string]bool modifiedInstallations *stringBoolMap
installationID string installationID string
mailserver []byte mailserver []byte
database *sql.DB database *sql.DB
@ -106,6 +107,7 @@ type Messenger struct {
mailserversDatabase *mailservers.Database mailserversDatabase *mailservers.Database
quit chan struct{} quit chan struct{}
// TODO(samyoul) Determine if/how the remaining usage of this mutex can be removed
mutex sync.Mutex mutex sync.Mutex
} }
@ -299,11 +301,11 @@ func NewMessenger(
ensVerifier: ensVerifier, ensVerifier: ensVerifier,
featureFlags: c.featureFlags, featureFlags: c.featureFlags,
systemMessagesTranslations: c.systemMessagesTranslations, systemMessagesTranslations: c.systemMessagesTranslations,
allChats: make(map[string]*Chat), allChats: new(chatMap),
allContacts: make(map[string]*Contact), allContacts: new(contactMap),
allInstallations: make(map[string]*multidevice.Installation), allInstallations: new(installationMap),
installationID: installationID, installationID: installationID,
modifiedInstallations: make(map[string]bool), modifiedInstallations: new(stringBoolMap),
verifyTransactionClient: c.verifyTransactionClient, verifyTransactionClient: c.verifyTransactionClient,
database: database, database: database,
multiAccounts: c.multiAccount, multiAccounts: c.multiAccount,
@ -780,9 +782,9 @@ func (m *Messenger) handleSharedSecrets(secrets []*sharedsecret.Secret) error {
func (m *Messenger) handleInstallations(installations []*multidevice.Installation) { func (m *Messenger) handleInstallations(installations []*multidevice.Installation) {
for _, installation := range installations { for _, installation := range installations {
if installation.Identity == contactIDFromPublicKey(&m.identity.PublicKey) { if installation.Identity == contactIDFromPublicKey(&m.identity.PublicKey) {
if _, ok := m.allInstallations[installation.ID]; !ok { if _, ok := m.allInstallations.Load(installation.ID); !ok {
m.allInstallations[installation.ID] = installation m.allInstallations.Store(installation.ID, installation)
m.modifiedInstallations[installation.ID] = true m.modifiedInstallations.Store(installation.ID, true)
} }
} }
} }
@ -811,13 +813,10 @@ func (m *Messenger) handleEncryptionLayerSubscriptions(subscriptions *encryption
} }
func (m *Messenger) handleENSVerified(records []*ens.VerificationRecord) { func (m *Messenger) handleENSVerified(records []*ens.VerificationRecord) {
m.mutex.Lock()
defer m.mutex.Unlock()
var contacts []*Contact var contacts []*Contact
for _, record := range records { for _, record := range records {
m.logger.Info("handling record", zap.Any("record", record)) m.logger.Info("handling record", zap.Any("record", record))
contact, ok := m.allContacts[record.PublicKey] contact, ok := m.allContacts.Load(record.PublicKey)
if !ok { if !ok {
m.logger.Info("contact not found") m.logger.Info("contact not found")
continue continue
@ -1003,7 +1002,7 @@ func (m *Messenger) Init() error {
continue continue
} }
m.allChats[chat.ID] = chat m.allChats.Store(chat.ID, chat)
if !chat.Active || chat.Timeline() { if !chat.Active || chat.Timeline() {
continue continue
} }
@ -1038,7 +1037,7 @@ func (m *Messenger) Init() error {
return err return err
} }
for _, contact := range contacts { for _, contact := range contacts {
m.allContacts[contact.ID] = contact m.allContacts.Store(contact.ID, contact)
// We only need filters for contacts added by us and not blocked. // We only need filters for contacts added by us and not blocked.
if !contact.IsAdded() || contact.IsBlocked() { if !contact.IsAdded() || contact.IsBlocked() {
continue continue
@ -1057,7 +1056,7 @@ func (m *Messenger) Init() error {
} }
for _, installation := range installations { for _, installation := range installations {
m.allInstallations[installation.ID] = installation m.allInstallations.Store(installation.ID, installation)
} }
_, err = m.transport.InitFilters(publicChatIDs, publicKeys) _, err = m.transport.InitFilters(publicChatIDs, publicKeys)
@ -1086,10 +1085,7 @@ func (m *Messenger) Shutdown() (err error) {
} }
func (m *Messenger) EnableInstallation(id string) error { func (m *Messenger) EnableInstallation(id string) error {
m.mutex.Lock() installation, ok := m.allInstallations.Load(id)
defer m.mutex.Unlock()
installation, ok := m.allInstallations[id]
if !ok { if !ok {
return errors.New("no installation found") return errors.New("no installation found")
} }
@ -1099,15 +1095,13 @@ func (m *Messenger) EnableInstallation(id string) error {
return err return err
} }
installation.Enabled = true installation.Enabled = true
m.allInstallations[id] = installation // TODO(samyoul) remove storing of an updated reference pointer?
m.allInstallations.Store(id, installation)
return nil return nil
} }
func (m *Messenger) DisableInstallation(id string) error { func (m *Messenger) DisableInstallation(id string) error {
m.mutex.Lock() installation, ok := m.allInstallations.Load(id)
defer m.mutex.Unlock()
installation, ok := m.allInstallations[id]
if !ok { if !ok {
return errors.New("no installation found") return errors.New("no installation found")
} }
@ -1117,25 +1111,25 @@ func (m *Messenger) DisableInstallation(id string) error {
return err return err
} }
installation.Enabled = false installation.Enabled = false
m.allInstallations[id] = installation // TODO(samyoul) remove storing of an updated reference pointer?
m.allInstallations.Store(id, installation)
return nil return nil
} }
func (m *Messenger) Installations() []*multidevice.Installation { func (m *Messenger) Installations() []*multidevice.Installation {
m.mutex.Lock() installations := make([]*multidevice.Installation, m.allInstallations.Len())
defer m.mutex.Unlock()
installations := make([]*multidevice.Installation, len(m.allInstallations))
var i = 0 var i = 0
for _, installation := range m.allInstallations { m.allInstallations.Range(func(installationID string, installation *multidevice.Installation) (shouldContinue bool) {
installations[i] = installation installations[i] = installation
i++ i++
} return true
})
return installations return installations
} }
func (m *Messenger) setInstallationMetadata(id string, data *multidevice.InstallationMetadata) error { func (m *Messenger) setInstallationMetadata(id string, data *multidevice.InstallationMetadata) error {
installation, ok := m.allInstallations[id] installation, ok := m.allInstallations.Load(id)
if !ok { if !ok {
return errors.New("no installation found") return errors.New("no installation found")
} }
@ -1145,8 +1139,6 @@ func (m *Messenger) setInstallationMetadata(id string, data *multidevice.Install
} }
func (m *Messenger) SetInstallationMetadata(id string, data *multidevice.InstallationMetadata) error { func (m *Messenger) SetInstallationMetadata(id string, data *multidevice.InstallationMetadata) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.setInstallationMetadata(id, data) return m.setInstallationMetadata(id, data)
} }
@ -1194,9 +1186,6 @@ func (m *Messenger) Leave(chat Chat) error {
} }
func (m *Messenger) CreateGroupChatWithMembers(ctx context.Context, name string, members []string) (*MessengerResponse, error) { func (m *Messenger) CreateGroupChatWithMembers(ctx context.Context, name string, members []string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
logger := m.logger.With(zap.String("site", "CreateGroupChatWithMembers")) logger := m.logger.With(zap.String("site", "CreateGroupChatWithMembers"))
logger.Info("Creating group chat", zap.String("name", name), zap.Any("members", members)) logger.Info("Creating group chat", zap.String("name", name), zap.Any("members", members))
@ -1239,7 +1228,8 @@ func (m *Messenger) CreateGroupChatWithMembers(ctx context.Context, name string,
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.allChats[chat.ID] = &chat
m.allChats.Store(chat.ID, &chat)
_, err = m.dispatchMessage(ctx, common.RawMessage{ _, err = m.dispatchMessage(ctx, common.RawMessage{
LocalChatID: chat.ID, LocalChatID: chat.ID,
@ -1265,9 +1255,6 @@ func (m *Messenger) CreateGroupChatWithMembers(ctx context.Context, name string,
} }
func (m *Messenger) CreateGroupChatFromInvitation(name string, chatID string, adminPK string) (*MessengerResponse, error) { func (m *Messenger) CreateGroupChatFromInvitation(name string, chatID string, adminPK string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
logger := m.logger.With(zap.String("site", "CreateGroupChatFromInvitation")) logger := m.logger.With(zap.String("site", "CreateGroupChatFromInvitation"))
logger.Info("Creating group chat from invitation", zap.String("name", name)) logger.Info("Creating group chat from invitation", zap.String("name", name))
@ -1282,13 +1269,10 @@ func (m *Messenger) CreateGroupChatFromInvitation(name string, chatID string, ad
} }
func (m *Messenger) RemoveMemberFromGroupChat(ctx context.Context, chatID string, member string) (*MessengerResponse, error) { func (m *Messenger) RemoveMemberFromGroupChat(ctx context.Context, chatID string, member string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
logger := m.logger.With(zap.String("site", "RemoveMemberFromGroupChat")) logger := m.logger.With(zap.String("site", "RemoveMemberFromGroupChat"))
logger.Info("Removing member form group chat", zap.String("chatID", chatID), zap.String("member", member)) logger.Info("Removing member form group chat", zap.String("chatID", chatID), zap.String("member", member))
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -1346,13 +1330,10 @@ func (m *Messenger) RemoveMemberFromGroupChat(ctx context.Context, chatID string
} }
func (m *Messenger) AddMembersToGroupChat(ctx context.Context, chatID string, members []string) (*MessengerResponse, error) { func (m *Messenger) AddMembersToGroupChat(ctx context.Context, chatID string, members []string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
logger := m.logger.With(zap.String("site", "AddMembersFromGroupChat")) logger := m.logger.With(zap.String("site", "AddMembersFromGroupChat"))
logger.Info("Adding members form group chat", zap.String("chatID", chatID), zap.Any("members", members)) logger.Info("Adding members form group chat", zap.String("chatID", chatID), zap.Any("members", members))
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -1438,10 +1419,7 @@ func (m *Messenger) ChangeGroupChatName(ctx context.Context, chatID string, name
logger := m.logger.With(zap.String("site", "ChangeGroupChatName")) logger := m.logger.With(zap.String("site", "ChangeGroupChatName"))
logger.Info("Changing group chat name", zap.String("chatID", chatID), zap.String("name", name)) logger.Info("Changing group chat name", zap.String("chatID", chatID), zap.String("name", name))
m.mutex.Lock() chat, ok := m.allChats.Load(chatID)
defer m.mutex.Unlock()
chat, ok := m.allChats[chatID]
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -1505,13 +1483,10 @@ func (m *Messenger) SendGroupChatInvitationRequest(ctx context.Context, chatID s
logger.Info("Sending group chat invitation request", zap.String("chatID", chatID), logger.Info("Sending group chat invitation request", zap.String("chatID", chatID),
zap.String("adminPK", adminPK), zap.String("message", message)) zap.String("adminPK", adminPK), zap.String("message", message))
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
// Get chat and clock // Get chat and clock
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -1579,9 +1554,6 @@ func (m *Messenger) SendGroupChatInvitationRejection(ctx context.Context, invita
logger := m.logger.With(zap.String("site", "SendGroupChatInvitationRejection")) logger := m.logger.With(zap.String("site", "SendGroupChatInvitationRejection"))
logger.Info("Sending group chat invitation reject", zap.String("invitationRequestID", invitationRequestID)) logger.Info("Sending group chat invitation reject", zap.String("invitationRequestID", invitationRequestID))
m.mutex.Lock()
defer m.mutex.Unlock()
invitationR, err := m.persistence.InvitationByID(invitationRequestID) invitationR, err := m.persistence.InvitationByID(invitationRequestID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1590,7 +1562,7 @@ func (m *Messenger) SendGroupChatInvitationRejection(ctx context.Context, invita
invitationR.State = protobuf.GroupChatInvitation_REJECTED invitationR.State = protobuf.GroupChatInvitation_REJECTED
// Get chat and clock // Get chat and clock
chat, ok := m.allChats[invitationR.ChatId] chat, ok := m.allChats.Load(invitationR.ChatId)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -1645,14 +1617,11 @@ func (m *Messenger) SendGroupChatInvitationRejection(ctx context.Context, invita
} }
func (m *Messenger) AddAdminsToGroupChat(ctx context.Context, chatID string, members []string) (*MessengerResponse, error) { func (m *Messenger) AddAdminsToGroupChat(ctx context.Context, chatID string, members []string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
logger := m.logger.With(zap.String("site", "AddAdminsToGroupChat")) logger := m.logger.With(zap.String("site", "AddAdminsToGroupChat"))
logger.Info("Add admins to group chat", zap.String("chatID", chatID), zap.Any("members", members)) logger.Info("Add admins to group chat", zap.String("chatID", chatID), zap.Any("members", members))
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -1709,12 +1678,9 @@ func (m *Messenger) AddAdminsToGroupChat(ctx context.Context, chatID string, mem
} }
func (m *Messenger) ConfirmJoiningGroup(ctx context.Context, chatID string) (*MessengerResponse, error) { func (m *Messenger) ConfirmJoiningGroup(ctx context.Context, chatID string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -1775,12 +1741,9 @@ func (m *Messenger) ConfirmJoiningGroup(ctx context.Context, chatID string) (*Me
} }
func (m *Messenger) LeaveGroupChat(ctx context.Context, chatID string, remove bool) (*MessengerResponse, error) { func (m *Messenger) LeaveGroupChat(ctx context.Context, chatID string, remove bool) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -1861,7 +1824,7 @@ func (m *Messenger) reSendRawMessage(ctx context.Context, messageID string) erro
return err return err
} }
chat, ok := m.allChats[message.LocalChatID] chat, ok := m.allChats.Load(message.LocalChatID)
if !ok { if !ok {
return errors.New("chat not found") return errors.New("chat not found")
} }
@ -1879,19 +1842,17 @@ func (m *Messenger) reSendRawMessage(ctx context.Context, messageID string) erro
// ReSendChatMessage pulls a message from the database and sends it again // ReSendChatMessage pulls a message from the database and sends it again
func (m *Messenger) ReSendChatMessage(ctx context.Context, messageID string) error { func (m *Messenger) ReSendChatMessage(ctx context.Context, messageID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.reSendRawMessage(ctx, messageID) return m.reSendRawMessage(ctx, messageID)
} }
func (m *Messenger) hasPairedDevices() bool { func (m *Messenger) hasPairedDevices() bool {
var count int var count int
for _, i := range m.allInstallations { m.allInstallations.Range(func(installationID string, installation *multidevice.Installation) (shouldContinue bool) {
if i.Enabled { if installation.Enabled {
count++ count++
} }
} return true
})
return count > 1 return count > 1
} }
@ -1931,7 +1892,7 @@ func (m *Messenger) dispatchMessage(ctx context.Context, spec common.RawMessage)
var err error var err error
var id []byte var id []byte
logger := m.logger.With(zap.String("site", "dispatchMessage"), zap.String("chatID", spec.LocalChatID)) logger := m.logger.With(zap.String("site", "dispatchMessage"), zap.String("chatID", spec.LocalChatID))
chat, ok := m.allChats[spec.LocalChatID] chat, ok := m.allChats.Load(spec.LocalChatID)
if !ok { if !ok {
return spec, errors.New("no chat found") return spec, errors.New("no chat found")
} }
@ -2051,20 +2012,15 @@ func (m *Messenger) dispatchMessage(ctx context.Context, spec common.RawMessage)
// SendChatMessage takes a minimal message and sends it based on the corresponding chat // SendChatMessage takes a minimal message and sends it based on the corresponding chat
func (m *Messenger) SendChatMessage(ctx context.Context, message *common.Message) (*MessengerResponse, error) { func (m *Messenger) SendChatMessage(ctx context.Context, message *common.Message) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.sendChatMessage(ctx, message) return m.sendChatMessage(ctx, message)
} }
// SendChatMessages takes a array of messages and sends it based on the corresponding chats // SendChatMessages takes a array of messages and sends it based on the corresponding chats
func (m *Messenger) SendChatMessages(ctx context.Context, messages []*common.Message) (*MessengerResponse, error) { func (m *Messenger) SendChatMessages(ctx context.Context, messages []*common.Message) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
for _, message := range messages { for _, message := range messages {
messageResponse, err := m.sendChatMessage(ctx, message) messageResponse, err := m.SendChatMessage(ctx, message)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -2143,7 +2099,7 @@ func (m *Messenger) sendChatMessage(ctx context.Context, message *common.Message
var response MessengerResponse var response MessengerResponse
// A valid added chat is required. // A valid added chat is required.
chat, ok := m.allChats[message.ChatId] chat, ok := m.allChats.Load(message.ChatId)
if !ok { if !ok {
return nil, errors.New("Chat not found") return nil, errors.New("Chat not found")
} }
@ -2205,44 +2161,45 @@ func (m *Messenger) sendChatMessage(ctx context.Context, message *common.Message
// SyncDevices sends all public chats and contacts to paired devices // SyncDevices sends all public chats and contacts to paired devices
// TODO remove use of photoPath in contacts // TODO remove use of photoPath in contacts
func (m *Messenger) SyncDevices(ctx context.Context, ensName, photoPath string) error { func (m *Messenger) SyncDevices(ctx context.Context, ensName, photoPath string) (err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
myID := contactIDFromPublicKey(&m.identity.PublicKey) myID := contactIDFromPublicKey(&m.identity.PublicKey)
if _, err := m.sendContactUpdate(ctx, myID, ensName, photoPath); err != nil { if _, err = m.sendContactUpdate(ctx, myID, ensName, photoPath); err != nil {
return err return err
} }
for _, chat := range m.allChats { m.allChats.Range(func(chatID string, chat *Chat) (shouldContinue bool) {
if !chat.Timeline() && !chat.ProfileUpdates() && chat.Public() && chat.Active { if !chat.Timeline() && !chat.ProfileUpdates() && chat.Public() && chat.Active {
if err := m.syncPublicChat(ctx, chat); err != nil { err = m.syncPublicChat(ctx, chat)
return err if err != nil {
return false
} }
} }
return true
})
if err != nil {
return err
} }
for _, contact := range m.allContacts { m.allContacts.Range(func(contactID string, contact *Contact) (shouldContinue bool) {
if contact.IsAdded() && contact.ID != myID { if contact.IsAdded() && contact.ID != myID {
if err := m.syncContact(ctx, contact); err != nil { if err = m.syncContact(ctx, contact); err != nil {
return err return false
} }
} }
} return true
})
return nil return err
} }
// SendPairInstallation sends a pair installation message // SendPairInstallation sends a pair installation message
func (m *Messenger) SendPairInstallation(ctx context.Context) (*MessengerResponse, error) { func (m *Messenger) SendPairInstallation(ctx context.Context) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var err error var err error
var response MessengerResponse var response MessengerResponse
installation, ok := m.allInstallations[m.installationID] installation, ok := m.allInstallations.Load(m.installationID)
if !ok { if !ok {
return nil, errors.New("no installation found") return nil, errors.New("no installation found")
} }
@ -2253,14 +2210,14 @@ func (m *Messenger) SendPairInstallation(ctx context.Context) (*MessengerRespons
chatID := contactIDFromPublicKey(&m.identity.PublicKey) chatID := contactIDFromPublicKey(&m.identity.PublicKey)
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
chat = OneToOneFromPublicKey(&m.identity.PublicKey, m.getTimesource()) chat = OneToOneFromPublicKey(&m.identity.PublicKey, m.getTimesource())
// We don't want to show the chat to the user // We don't want to show the chat to the user
chat.Active = false chat.Active = false
} }
m.allChats[chat.ID] = chat m.allChats.Store(chat.ID, chat)
clock, _ := chat.NextClockAndTimestamp(m.getTimesource()) clock, _ := chat.NextClockAndTimestamp(m.getTimesource())
pairMessage := &protobuf.PairInstallation{ pairMessage := &protobuf.PairInstallation{
@ -2301,14 +2258,14 @@ func (m *Messenger) syncPublicChat(ctx context.Context, publicChat *Chat) error
} }
chatID := contactIDFromPublicKey(&m.identity.PublicKey) chatID := contactIDFromPublicKey(&m.identity.PublicKey)
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
chat = OneToOneFromPublicKey(&m.identity.PublicKey, m.getTimesource()) chat = OneToOneFromPublicKey(&m.identity.PublicKey, m.getTimesource())
// We don't want to show the chat to the user // We don't want to show the chat to the user
chat.Active = false chat.Active = false
} }
m.allChats[chat.ID] = chat m.allChats.Store(chat.ID, chat)
clock, _ := chat.NextClockAndTimestamp(m.getTimesource()) clock, _ := chat.NextClockAndTimestamp(m.getTimesource())
syncMessage := &protobuf.SyncInstallationPublicChat{ syncMessage := &protobuf.SyncInstallationPublicChat{
@ -2342,14 +2299,14 @@ func (m *Messenger) syncContact(ctx context.Context, contact *Contact) error {
} }
chatID := contactIDFromPublicKey(&m.identity.PublicKey) chatID := contactIDFromPublicKey(&m.identity.PublicKey)
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
chat = OneToOneFromPublicKey(&m.identity.PublicKey, m.getTimesource()) chat = OneToOneFromPublicKey(&m.identity.PublicKey, m.getTimesource())
// We don't want to show the chat to the user // We don't want to show the chat to the user
chat.Active = false chat.Active = false
} }
m.allChats[chat.ID] = chat m.allChats.Store(chat.ID, chat)
clock, _ := chat.NextClockAndTimestamp(m.getTimesource()) clock, _ := chat.NextClockAndTimestamp(m.getTimesource())
syncMessage := &protobuf.SyncInstallationContact{ syncMessage := &protobuf.SyncInstallationContact{
@ -2405,16 +2362,15 @@ type ReceivedMessageState struct {
// State on the message being processed // State on the message being processed
CurrentMessageState *CurrentMessageState CurrentMessageState *CurrentMessageState
// AllChats in memory // AllChats in memory
AllChats map[string]*Chat AllChats *chatMap
// All contacts in memory // All contacts in memory
AllContacts map[string]*Contact AllContacts *contactMap
// List of contacts modified // List of contacts modified
ModifiedContacts map[string]bool ModifiedContacts *stringBoolMap
// All installations in memory // All installations in memory
AllInstallations map[string]*multidevice.Installation AllInstallations *installationMap
// List of communities modified // List of communities modified
ModifiedInstallations map[string]bool ModifiedInstallations *stringBoolMap
// List of filters // List of filters
AllFilters map[string]*transport.Filter AllFilters map[string]*transport.Filter
// Map of existing messages // Map of existing messages
@ -2472,8 +2428,15 @@ func (r *ReceivedMessageState) addNewMessageNotification(publicKey ecdsa.PublicK
} }
contactID := contactIDFromPublicKey(pubKey) contactID := contactIDFromPublicKey(pubKey)
chat := r.AllChats[m.LocalChatID] chat, ok := r.AllChats.Load(m.LocalChatID)
contact := r.AllContacts[contactID] if !ok {
return fmt.Errorf("chat ID '%s' not present", m.LocalChatID)
}
contact, ok := r.AllContacts.Load(contactID)
if !ok {
return fmt.Errorf("contact ID '%s' not present", contactID)
}
if showMessageNotification(publicKey, m, chat, responseTo) { if showMessageNotification(publicKey, m, chat, responseTo) {
notification, err := NewMessageNotification(m.ID, m, chat, contact, r.AllContacts) notification, err := NewMessageNotification(m.ID, m, chat, contact, r.AllContacts)
@ -2487,12 +2450,10 @@ func (r *ReceivedMessageState) addNewMessageNotification(publicKey ecdsa.PublicK
} }
func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filter][]*types.Message) (*MessengerResponse, error) { func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filter][]*types.Message) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
messageState := &ReceivedMessageState{ messageState := &ReceivedMessageState{
AllChats: m.allChats, AllChats: m.allChats,
AllContacts: m.allContacts, AllContacts: m.allContacts,
ModifiedContacts: make(map[string]bool), ModifiedContacts: new(stringBoolMap),
AllInstallations: m.allInstallations, AllInstallations: m.allInstallations,
ModifiedInstallations: m.modifiedInstallations, ModifiedInstallations: m.modifiedInstallations,
ExistingMessagesMap: make(map[string]bool), ExistingMessagesMap: make(map[string]bool),
@ -2531,7 +2492,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
// Check for messages from blocked users // Check for messages from blocked users
senderID := contactIDFromPublicKey(publicKey) senderID := contactIDFromPublicKey(publicKey)
if _, ok := messageState.AllContacts[senderID]; ok && messageState.AllContacts[senderID].IsBlocked() { if contact, ok := messageState.AllContacts.Load(senderID); ok && contact.IsBlocked() {
continue continue
} }
@ -2547,7 +2508,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
} }
var contact *Contact var contact *Contact
if c, ok := messageState.AllContacts[senderID]; ok { if c, ok := messageState.AllContacts.Load(senderID); ok {
contact = c contact = c
} else { } else {
c, err := buildContact(senderID, publicKey) c, err := buildContact(senderID, publicKey)
@ -2557,8 +2518,8 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
continue continue
} }
contact = c contact = c
messageState.AllContacts[senderID] = c messageState.AllContacts.Store(senderID, contact)
messageState.ModifiedContacts[contact.ID] = true messageState.ModifiedContacts.Store(contact.ID, true)
} }
messageState.CurrentMessageState = &CurrentMessageState{ messageState.CurrentMessageState = &CurrentMessageState{
MessageID: messageID, MessageID: messageID,
@ -2576,7 +2537,8 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
rawMembershipUpdate := msg.ParsedMessage.Interface().(protobuf.MembershipUpdateMessage) rawMembershipUpdate := msg.ParsedMessage.Interface().(protobuf.MembershipUpdateMessage)
err = m.handler.HandleMembershipUpdate(messageState, messageState.AllChats[rawMembershipUpdate.ChatId], rawMembershipUpdate, m.systemMessagesTranslations) chat, _ := messageState.AllChats.Load(rawMembershipUpdate.ChatId)
err = m.handler.HandleMembershipUpdate(messageState, chat, rawMembershipUpdate, m.systemMessagesTranslations)
if err != nil { if err != nil {
logger.Warn("failed to handle MembershipUpdate", zap.Error(err)) logger.Warn("failed to handle MembershipUpdate", zap.Error(err))
allMessagesProcessed = false allMessagesProcessed = false
@ -2923,9 +2885,9 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
} }
var contactsToSave []*Contact var contactsToSave []*Contact
for id := range messageState.ModifiedContacts { messageState.ModifiedContacts.Range(func(id string, value bool) (shouldContinue bool) {
contact := messageState.AllContacts[id] contact, ok := messageState.AllContacts.Load(id)
if contact != nil { if ok {
// We save all contacts so we can pull back name/image, // We save all contacts so we can pull back name/image,
// but we only send to client those // but we only send to client those
// that have some custom fields // that have some custom fields
@ -2934,7 +2896,8 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
messageState.Response.Contacts = append(messageState.Response.Contacts, contact) messageState.Response.Contacts = append(messageState.Response.Contacts, contact)
} }
} }
} return true
})
for _, filter := range messageState.AllFilters { for _, filter := range messageState.AllFilters {
messageState.Response.Filters = append(messageState.Response.Filters, filter) messageState.Response.Filters = append(messageState.Response.Filters, filter)
@ -2942,9 +2905,9 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
// Hydrate chat alias and identicon // Hydrate chat alias and identicon
for id := range messageState.Response.chats { for id := range messageState.Response.chats {
chat := messageState.AllChats[id] chat, _ := messageState.AllChats.Load(id)
if chat.OneToOne() { if chat.OneToOne() {
contact, ok := m.allContacts[chat.ID] contact, ok := m.allContacts.Load(chat.ID)
if ok { if ok {
chat.Alias = contact.Alias chat.Alias = contact.Alias
chat.Identicon = contact.Identicon chat.Identicon = contact.Identicon
@ -2954,18 +2917,23 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
messageState.Response.AddChat(chat) messageState.Response.AddChat(chat)
} }
for id := range messageState.ModifiedInstallations { var err error
installation := messageState.AllInstallations[id] messageState.ModifiedInstallations.Range(func(id string, value bool) (shouldContinue bool) {
installation, _ := messageState.AllInstallations.Load(id)
messageState.Response.Installations = append(messageState.Response.Installations, installation) messageState.Response.Installations = append(messageState.Response.Installations, installation)
if installation.InstallationMetadata != nil { if installation.InstallationMetadata != nil {
err := m.setInstallationMetadata(id, installation.InstallationMetadata) err = m.setInstallationMetadata(id, installation.InstallationMetadata)
if err != nil { if err != nil {
return nil, err return false
} }
} }
return true
})
if err != nil {
return nil, err
} }
var err error
if len(messageState.Response.chats) > 0 { if len(messageState.Response.chats) > 0 {
err = m.saveChats(messageState.Response.Chats()) err = m.saveChats(messageState.Response.Chats())
if err != nil { if err != nil {
@ -3028,7 +2996,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
} }
// Reset installations // Reset installations
m.modifiedInstallations = make(map[string]bool) m.modifiedInstallations = new(stringBoolMap)
return messageState.Response, nil return messageState.Response, nil
} }
@ -3109,14 +3077,11 @@ func (m *Messenger) DeleteMessagesByChatID(id string) error {
} }
func (m *Messenger) ClearHistory(id string) (*MessengerResponse, error) { func (m *Messenger) ClearHistory(id string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.clearHistory(id) return m.clearHistory(id)
} }
func (m *Messenger) clearHistory(id string) (*MessengerResponse, error) { func (m *Messenger) clearHistory(id string) (*MessengerResponse, error) {
chat, ok := m.allChats[id] chat, ok := m.allChats.Load(id)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -3128,7 +3093,7 @@ func (m *Messenger) clearHistory(id string) (*MessengerResponse, error) {
return nil, err return nil, err
} }
m.allChats[id] = chat m.allChats.Store(id, chat)
response := &MessengerResponse{} response := &MessengerResponse{}
response.AddChat(chat) response.AddChat(chat)
@ -3139,9 +3104,6 @@ func (m *Messenger) clearHistory(id string) (*MessengerResponse, error) {
// It returns the number of affected messages or error. If there is an error, // It returns the number of affected messages or error. If there is an error,
// the number of affected messages is always zero. // the number of affected messages is always zero.
func (m *Messenger) MarkMessagesSeen(chatID string, ids []string) (uint64, error) { func (m *Messenger) MarkMessagesSeen(chatID string, ids []string) (uint64, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
count, err := m.persistence.MarkMessagesSeen(chatID, ids) count, err := m.persistence.MarkMessagesSeen(chatID, ids)
if err != nil { if err != nil {
return 0, err return 0, err
@ -3150,14 +3112,12 @@ func (m *Messenger) MarkMessagesSeen(chatID string, ids []string) (uint64, error
if err != nil { if err != nil {
return 0, err return 0, err
} }
m.allChats[chatID] = chat m.allChats.Store(chatID, chat)
return count, nil return count, nil
} }
func (m *Messenger) MarkAllRead(chatID string) error { func (m *Messenger) MarkAllRead(chatID string) error {
m.mutex.Lock() chat, ok := m.allChats.Load(chatID)
defer m.mutex.Unlock()
chat, ok := m.allChats[chatID]
if !ok { if !ok {
return errors.New("chat not found") return errors.New("chat not found")
} }
@ -3168,16 +3128,15 @@ func (m *Messenger) MarkAllRead(chatID string) error {
} }
chat.UnviewedMessagesCount = 0 chat.UnviewedMessagesCount = 0
m.allChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
m.allChats.Store(chat.ID, chat)
return nil return nil
} }
// MuteChat signals to the messenger that we don't want to be notified // MuteChat signals to the messenger that we don't want to be notified
// on new messages from this chat // on new messages from this chat
func (m *Messenger) MuteChat(chatID string) error { func (m *Messenger) MuteChat(chatID string) error {
m.mutex.Lock() chat, ok := m.allChats.Load(chatID)
defer m.mutex.Unlock()
chat, ok := m.allChats[chatID]
if !ok { if !ok {
return errors.New("chat not found") return errors.New("chat not found")
} }
@ -3188,7 +3147,8 @@ func (m *Messenger) MuteChat(chatID string) error {
} }
chat.Muted = true chat.Muted = true
m.allChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
m.allChats.Store(chat.ID, chat)
return m.reregisterForPushNotifications() return m.reregisterForPushNotifications()
} }
@ -3196,9 +3156,7 @@ func (m *Messenger) MuteChat(chatID string) error {
// UnmuteChat signals to the messenger that we want to be notified // UnmuteChat signals to the messenger that we want to be notified
// on new messages from this chat // on new messages from this chat
func (m *Messenger) UnmuteChat(chatID string) error { func (m *Messenger) UnmuteChat(chatID string) error {
m.mutex.Lock() chat, ok := m.allChats.Load(chatID)
defer m.mutex.Unlock()
chat, ok := m.allChats[chatID]
if !ok { if !ok {
return errors.New("chat not found") return errors.New("chat not found")
} }
@ -3209,7 +3167,8 @@ func (m *Messenger) UnmuteChat(chatID string) error {
} }
chat.Muted = false chat.Muted = false
m.allChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
m.allChats.Store(chat.ID, chat)
return m.reregisterForPushNotifications() return m.reregisterForPushNotifications()
} }
@ -3228,13 +3187,10 @@ func GenerateAlias(id string) (string, error) {
} }
func (m *Messenger) RequestTransaction(ctx context.Context, chatID, value, contract, address string) (*MessengerResponse, error) { func (m *Messenger) RequestTransaction(ctx context.Context, chatID, value, contract, address string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
// A valid added chat is required. // A valid added chat is required.
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, errors.New("Chat not found") return nil, errors.New("Chat not found")
} }
@ -3306,13 +3262,10 @@ func (m *Messenger) RequestTransaction(ctx context.Context, chatID, value, contr
} }
func (m *Messenger) RequestAddressForTransaction(ctx context.Context, chatID, from, value, contract string) (*MessengerResponse, error) { func (m *Messenger) RequestAddressForTransaction(ctx context.Context, chatID, from, value, contract string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
// A valid added chat is required. // A valid added chat is required.
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, errors.New("Chat not found") return nil, errors.New("Chat not found")
} }
@ -3383,9 +3336,6 @@ func (m *Messenger) RequestAddressForTransaction(ctx context.Context, chatID, fr
} }
func (m *Messenger) AcceptRequestAddressForTransaction(ctx context.Context, messageID, address string) (*MessengerResponse, error) { func (m *Messenger) AcceptRequestAddressForTransaction(ctx context.Context, messageID, address string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
message, err := m.MessageByID(messageID) message, err := m.MessageByID(messageID)
@ -3400,7 +3350,7 @@ func (m *Messenger) AcceptRequestAddressForTransaction(ctx context.Context, mess
chatID := message.LocalChatID chatID := message.LocalChatID
// A valid added chat is required. // A valid added chat is required.
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, errors.New("Chat not found") return nil, errors.New("Chat not found")
} }
@ -3479,9 +3429,6 @@ func (m *Messenger) AcceptRequestAddressForTransaction(ctx context.Context, mess
} }
func (m *Messenger) DeclineRequestTransaction(ctx context.Context, messageID string) (*MessengerResponse, error) { func (m *Messenger) DeclineRequestTransaction(ctx context.Context, messageID string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
message, err := m.MessageByID(messageID) message, err := m.MessageByID(messageID)
@ -3496,7 +3443,7 @@ func (m *Messenger) DeclineRequestTransaction(ctx context.Context, messageID str
chatID := message.LocalChatID chatID := message.LocalChatID
// A valid added chat is required. // A valid added chat is required.
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, errors.New("Chat not found") return nil, errors.New("Chat not found")
} }
@ -3562,9 +3509,6 @@ func (m *Messenger) DeclineRequestTransaction(ctx context.Context, messageID str
} }
func (m *Messenger) DeclineRequestAddressForTransaction(ctx context.Context, messageID string) (*MessengerResponse, error) { func (m *Messenger) DeclineRequestAddressForTransaction(ctx context.Context, messageID string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
message, err := m.MessageByID(messageID) message, err := m.MessageByID(messageID)
@ -3579,7 +3523,7 @@ func (m *Messenger) DeclineRequestAddressForTransaction(ctx context.Context, mes
chatID := message.LocalChatID chatID := message.LocalChatID
// A valid added chat is required. // A valid added chat is required.
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, errors.New("Chat not found") return nil, errors.New("Chat not found")
} }
@ -3645,9 +3589,6 @@ func (m *Messenger) DeclineRequestAddressForTransaction(ctx context.Context, mes
} }
func (m *Messenger) AcceptRequestTransaction(ctx context.Context, transactionHash, messageID string, signature []byte) (*MessengerResponse, error) { func (m *Messenger) AcceptRequestTransaction(ctx context.Context, transactionHash, messageID string, signature []byte) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
message, err := m.MessageByID(messageID) message, err := m.MessageByID(messageID)
@ -3662,7 +3603,7 @@ func (m *Messenger) AcceptRequestTransaction(ctx context.Context, transactionHas
chatID := message.LocalChatID chatID := message.LocalChatID
// A valid added chat is required. // A valid added chat is required.
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, errors.New("Chat not found") return nil, errors.New("Chat not found")
} }
@ -3745,13 +3686,10 @@ func (m *Messenger) AcceptRequestTransaction(ctx context.Context, transactionHas
} }
func (m *Messenger) SendTransaction(ctx context.Context, chatID, value, contract, transactionHash string, signature []byte) (*MessengerResponse, error) { func (m *Messenger) SendTransaction(ctx context.Context, chatID, value, contract, transactionHash string, signature []byte) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
// A valid added chat is required. // A valid added chat is required.
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, errors.New("Chat not found") return nil, errors.New("Chat not found")
} }
@ -3830,8 +3768,6 @@ func (m *Messenger) ValidateTransactions(ctx context.Context, addresses []types.
if m.verifyTransactionClient == nil { if m.verifyTransactionClient == nil {
return nil, nil return nil, nil
} }
m.mutex.Lock()
defer m.mutex.Unlock()
logger := m.logger.With(zap.String("site", "ValidateTransactions")) logger := m.logger.With(zap.String("site", "ValidateTransactions"))
logger.Debug("Validating transactions") logger.Debug("Validating transactions")
@ -3851,7 +3787,7 @@ func (m *Messenger) ValidateTransactions(ctx context.Context, addresses []types.
for _, validationResult := range responses { for _, validationResult := range responses {
var message *common.Message var message *common.Message
chatID := contactIDFromPublicKey(validationResult.Transaction.From) chatID := contactIDFromPublicKey(validationResult.Transaction.From)
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
chat = OneToOneFromPublicKey(validationResult.Transaction.From, m.transport) chat = OneToOneFromPublicKey(validationResult.Transaction.From, m.transport)
} }
@ -3917,7 +3853,7 @@ func (m *Messenger) ValidateTransactions(ctx context.Context, addresses []types.
} }
response.Messages = append(response.Messages, message) response.Messages = append(response.Messages, message)
m.allChats[chat.ID] = chat m.allChats.Store(chat.ID, chat)
response.AddChat(chat) response.AddChat(chat)
contact, err := m.getOrBuildContactFromMessage(message) contact, err := m.getOrBuildContactFromMessage(message)
@ -4017,19 +3953,21 @@ func (m *Messenger) pushNotificationOptions() *pushnotificationclient.Registrati
var mutedChatIDs []string var mutedChatIDs []string
var publicChatIDs []string var publicChatIDs []string
for _, contact := range m.allContacts { m.allContacts.Range(func(contactID string, contact *Contact) (shouldContinue bool) {
if contact.IsAdded() && !contact.IsBlocked() { if contact.IsAdded() && !contact.IsBlocked() {
pk, err := contact.PublicKey() pk, err := contact.PublicKey()
if err != nil { if err != nil {
m.logger.Warn("could not parse contact public key") m.logger.Warn("could not parse contact public key")
continue return true
} }
contactIDs = append(contactIDs, pk) contactIDs = append(contactIDs, pk)
} else if contact.IsBlocked() { } else if contact.IsBlocked() {
mutedChatIDs = append(mutedChatIDs, contact.ID) mutedChatIDs = append(mutedChatIDs, contact.ID)
} }
} return true
for _, chat := range m.allChats { })
m.allChats.Range(func(chatID string, chat *Chat) (shouldContinue bool) {
if chat.Muted { if chat.Muted {
mutedChatIDs = append(mutedChatIDs, chat.ID) mutedChatIDs = append(mutedChatIDs, chat.ID)
} }
@ -4037,7 +3975,8 @@ func (m *Messenger) pushNotificationOptions() *pushnotificationclient.Registrati
publicChatIDs = append(publicChatIDs, chat.ID) publicChatIDs = append(publicChatIDs, chat.ID)
} }
} return true
})
return &pushnotificationclient.RegistrationOptions{ return &pushnotificationclient.RegistrationOptions{
ContactIDs: contactIDs, ContactIDs: contactIDs,
MutedChatIDs: mutedChatIDs, MutedChatIDs: mutedChatIDs,
@ -4157,12 +4096,9 @@ func generateAliasAndIdenticon(pk string) (string, string, error) {
} }
func (m *Messenger) SendEmojiReaction(ctx context.Context, chatID, messageID string, emojiID protobuf.EmojiReaction_Type) (*MessengerResponse, error) { func (m *Messenger) SendEmojiReaction(ctx context.Context, chatID, messageID string, emojiID protobuf.EmojiReaction_Type) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var response MessengerResponse var response MessengerResponse
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -4215,20 +4151,18 @@ func (m *Messenger) EmojiReactionsByChatID(chatID string, cursor string, limit i
if chat.Timeline() { if chat.Timeline() {
var chatIDs = []string{"@" + contactIDFromPublicKey(&m.identity.PublicKey)} var chatIDs = []string{"@" + contactIDFromPublicKey(&m.identity.PublicKey)}
for _, contact := range m.allContacts { m.allContacts.Range(func(contactID string, contact *Contact) (shouldContinue bool) {
if contact.IsAdded() { if contact.IsAdded() {
chatIDs = append(chatIDs, "@"+contact.ID) chatIDs = append(chatIDs, "@"+contact.ID)
} }
} return true
})
return m.persistence.EmojiReactionsByChatIDs(chatIDs, cursor, limit) return m.persistence.EmojiReactionsByChatIDs(chatIDs, cursor, limit)
} }
return m.persistence.EmojiReactionsByChatID(chatID, cursor, limit) return m.persistence.EmojiReactionsByChatID(chatID, cursor, limit)
} }
func (m *Messenger) SendEmojiReactionRetraction(ctx context.Context, emojiReactionID string) (*MessengerResponse, error) { func (m *Messenger) SendEmojiReactionRetraction(ctx context.Context, emojiReactionID string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
emojiR, err := m.persistence.EmojiReactionByID(emojiReactionID) emojiR, err := m.persistence.EmojiReactionByID(emojiReactionID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -4245,7 +4179,7 @@ func (m *Messenger) SendEmojiReactionRetraction(ctx context.Context, emojiReacti
} }
// Get chat and clock // Get chat and clock
chat, ok := m.allChats[emojiR.GetChatId()] chat, ok := m.allChats.Load(emojiR.GetChatId())
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -4346,7 +4280,7 @@ func (m *Messenger) encodeChatEntity(chat *Chat, message common.ChatEntity) ([]b
} }
func (m *Messenger) getOrBuildContactFromMessage(msg *common.Message) (*Contact, error) { func (m *Messenger) getOrBuildContactFromMessage(msg *common.Message) (*Contact, error) {
if c, ok := m.allContacts[msg.From]; ok { if c, ok := m.allContacts.Load(msg.From); ok {
return c, nil return c, nil
} }
@ -4360,6 +4294,7 @@ func (m *Messenger) getOrBuildContactFromMessage(msg *common.Message) (*Contact,
return nil, err return nil, err
} }
m.allContacts[msg.From] = c // TODO(samyoul) remove storing of an updated reference pointer?
m.allContacts.Store(msg.From, c)
return c, nil return c, nil
} }

View File

@ -10,22 +10,17 @@ import (
) )
func (m *Messenger) Chats() []*Chat { func (m *Messenger) Chats() []*Chat {
m.mutex.Lock()
defer m.mutex.Unlock()
var chats []*Chat var chats []*Chat
for _, c := range m.allChats { m.allChats.Range(func(chatID string, chat *Chat) (shouldContinue bool) {
chats = append(chats, c) chats = append(chats, chat)
} return true
})
return chats return chats
} }
func (m *Messenger) CreateOneToOneChat(request *requests.CreateOneToOneChat) (*MessengerResponse, error) { func (m *Messenger) CreateOneToOneChat(request *requests.CreateOneToOneChat) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := request.Validate(); err != nil { if err := request.Validate(); err != nil {
return nil, err return nil, err
} }
@ -36,7 +31,7 @@ func (m *Messenger) CreateOneToOneChat(request *requests.CreateOneToOneChat) (*M
return nil, err return nil, err
} }
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
chat = CreateOneToOneChat(chatID, pk, m.getTimesource()) chat = CreateOneToOneChat(chatID, pk, m.getTimesource())
} }
@ -52,7 +47,8 @@ func (m *Messenger) CreateOneToOneChat(request *requests.CreateOneToOneChat) (*M
return nil, err return nil, err
} }
m.allChats[chatID] = chat // TODO(Samyoul) remove storing of an updated reference pointer?
m.allChats.Store(chatID, chat)
response := &MessengerResponse{ response := &MessengerResponse{
Filters: filters, Filters: filters,
@ -64,9 +60,6 @@ func (m *Messenger) CreateOneToOneChat(request *requests.CreateOneToOneChat) (*M
} }
func (m *Messenger) DeleteChat(chatID string) error { func (m *Messenger) DeleteChat(chatID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteChat(chatID) return m.deleteChat(chatID)
} }
@ -75,10 +68,10 @@ func (m *Messenger) deleteChat(chatID string) error {
if err != nil { if err != nil {
return err return err
} }
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if ok && chat.Active && chat.Public() { if ok && chat.Active && chat.Public() {
delete(m.allChats, chatID) m.allChats.Delete(chatID)
return m.reregisterForPushNotifications() return m.reregisterForPushNotifications()
} }
@ -86,20 +79,16 @@ func (m *Messenger) deleteChat(chatID string) error {
} }
func (m *Messenger) SaveChat(chat *Chat) error { func (m *Messenger) SaveChat(chat *Chat) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.saveChat(chat) return m.saveChat(chat)
} }
func (m *Messenger) DeactivateChat(chatID string) (*MessengerResponse, error) { func (m *Messenger) DeactivateChat(chatID string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deactivateChat(chatID) return m.deactivateChat(chatID)
} }
func (m *Messenger) deactivateChat(chatID string) (*MessengerResponse, error) { func (m *Messenger) deactivateChat(chatID string) (*MessengerResponse, error) {
var response MessengerResponse var response MessengerResponse
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
return nil, ErrChatNotFound return nil, ErrChatNotFound
} }
@ -121,7 +110,8 @@ func (m *Messenger) deactivateChat(chatID string) (*MessengerResponse, error) {
} }
} }
m.allChats[chatID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
m.allChats.Store(chatID, chat)
response.AddChat(chat) response.AddChat(chat)
// TODO: Remove filters // TODO: Remove filters
@ -135,7 +125,7 @@ func (m *Messenger) saveChats(chats []*Chat) error {
return err return err
} }
for _, chat := range chats { for _, chat := range chats {
m.allChats[chat.ID] = chat m.allChats.Store(chat.ID, chat)
} }
return nil return nil
@ -143,7 +133,7 @@ func (m *Messenger) saveChats(chats []*Chat) error {
} }
func (m *Messenger) saveChat(chat *Chat) error { func (m *Messenger) saveChat(chat *Chat) error {
previousChat, ok := m.allChats[chat.ID] previousChat, ok := m.allChats.Load(chat.ID)
if chat.OneToOne() { if chat.OneToOne() {
name, identicon, err := generateAliasAndIdenticon(chat.ID) name, identicon, err := generateAliasAndIdenticon(chat.ID)
if err != nil { if err != nil {
@ -170,7 +160,8 @@ func (m *Messenger) saveChat(chat *Chat) error {
if err != nil { if err != nil {
return err return err
} }
m.allChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
m.allChats.Store(chat.ID, chat)
if shouldRegisterForPushNotifications { if shouldRegisterForPushNotifications {
// Re-register for push notifications, as we want to receive mentions // Re-register for push notifications, as we want to receive mentions

View File

@ -24,7 +24,7 @@ type config struct {
onContactENSVerified func(*MessengerResponse) onContactENSVerified func(*MessengerResponse)
// systemMessagesTranslations holds translations for system-messages // systemMessagesTranslations holds translations for system-messages
systemMessagesTranslations map[protobuf.MembershipUpdateEvent_EventType]string systemMessagesTranslations *systemMessageTranslationsMap
// Config for the envelopes monitor // Config for the envelopes monitor
envelopesMonitorConfig *transport.EnvelopesMonitorConfig envelopesMonitorConfig *transport.EnvelopesMonitorConfig
@ -56,7 +56,7 @@ type Option func(*config) error
// nolint: unused // nolint: unused
func WithSystemMessagesTranslations(t map[protobuf.MembershipUpdateEvent_EventType]string) Option { func WithSystemMessagesTranslations(t map[protobuf.MembershipUpdateEvent_EventType]string) Option {
return func(c *config) error { return func(c *config) error {
c.systemMessagesTranslations = t c.systemMessagesTranslations.Init(t)
return nil return nil
} }
} }

View File

@ -11,15 +11,11 @@ import (
) )
func (m *Messenger) SaveContact(contact *Contact) error { func (m *Messenger) SaveContact(contact *Contact) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.saveContact(contact) return m.saveContact(contact)
} }
func (m *Messenger) AddContact(ctx context.Context, pubKey string) (*MessengerResponse, error) { func (m *Messenger) AddContact(ctx context.Context, pubKey string) (*MessengerResponse, error) {
m.mutex.Lock() contact, ok := m.allContacts.Load(pubKey)
defer m.mutex.Unlock()
contact, ok := m.allContacts[pubKey]
if !ok { if !ok {
var err error var err error
contact, err = buildContactFromPkString(pubKey) contact, err = buildContactFromPkString(pubKey)
@ -43,7 +39,8 @@ func (m *Messenger) AddContact(ctx context.Context, pubKey string) (*MessengerRe
return nil, err return nil, err
} }
m.allContacts[contact.ID] = contact // TODO(samyoul) remove storing of an updated reference pointer?
m.allContacts.Store(contact.ID, contact)
// And we re-register for push notications // And we re-register for push notications
err = m.reregisterForPushNotifications() err = m.reregisterForPushNotifications()
@ -53,7 +50,7 @@ func (m *Messenger) AddContact(ctx context.Context, pubKey string) (*MessengerRe
// Create the corresponding profile chat // Create the corresponding profile chat
profileChatID := buildProfileChatID(contact.ID) profileChatID := buildProfileChatID(contact.ID)
profileChat, ok := m.allChats[profileChatID] profileChat, ok := m.allChats.Load(profileChatID)
if !ok { if !ok {
profileChat = CreateProfileChat(profileChatID, contact.ID, m.getTimesource()) profileChat = CreateProfileChat(profileChatID, contact.ID, m.getTimesource())
@ -89,11 +86,9 @@ func (m *Messenger) AddContact(ctx context.Context, pubKey string) (*MessengerRe
} }
func (m *Messenger) RemoveContact(ctx context.Context, pubKey string) (*MessengerResponse, error) { func (m *Messenger) RemoveContact(ctx context.Context, pubKey string) (*MessengerResponse, error) {
m.mutex.Lock() response := new(MessengerResponse)
defer m.mutex.Unlock()
var response *MessengerResponse
contact, ok := m.allContacts[pubKey] contact, ok := m.allContacts.Load(pubKey)
if !ok { if !ok {
return nil, ErrContactNotFound return nil, ErrContactNotFound
} }
@ -105,7 +100,8 @@ func (m *Messenger) RemoveContact(ctx context.Context, pubKey string) (*Messenge
return nil, err return nil, err
} }
m.allContacts[contact.ID] = contact // TODO(samyoul) remove storing of an updated reference pointer?
m.allContacts.Store(contact.ID, contact)
// And we re-register for push notications // And we re-register for push notications
err = m.reregisterForPushNotifications() err = m.reregisterForPushNotifications()
@ -115,7 +111,7 @@ func (m *Messenger) RemoveContact(ctx context.Context, pubKey string) (*Messenge
// Create the corresponding profile chat // Create the corresponding profile chat
profileChatID := buildProfileChatID(contact.ID) profileChatID := buildProfileChatID(contact.ID)
_, ok = m.allChats[profileChatID] _, ok = m.allChats.Load(profileChatID)
if ok { if ok {
chatResponse, err := m.deactivateChat(profileChatID) chatResponse, err := m.deactivateChat(profileChatID)
@ -133,38 +129,33 @@ func (m *Messenger) RemoveContact(ctx context.Context, pubKey string) (*Messenge
} }
func (m *Messenger) Contacts() []*Contact { func (m *Messenger) Contacts() []*Contact {
m.mutex.Lock()
defer m.mutex.Unlock()
var contacts []*Contact var contacts []*Contact
for _, contact := range m.allContacts { m.allContacts.Range(func(contactID string, contact *Contact) (shouldContinue bool) {
if contact.HasCustomFields() { if contact.HasCustomFields() {
contacts = append(contacts, contact) contacts = append(contacts, contact)
} }
} return true
})
return contacts return contacts
} }
// GetContactByID assumes pubKey includes 0x prefix // GetContactByID assumes pubKey includes 0x prefix
func (m *Messenger) GetContactByID(pubKey string) *Contact { func (m *Messenger) GetContactByID(pubKey string) *Contact {
m.mutex.Lock() contact, _ := m.allContacts.Load(pubKey)
defer m.mutex.Unlock() return contact
return m.allContacts[pubKey]
} }
func (m *Messenger) BlockContact(contact *Contact) ([]*Chat, error) { func (m *Messenger) BlockContact(contact *Contact) ([]*Chat, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
chats, err := m.persistence.BlockContact(contact) chats, err := m.persistence.BlockContact(contact)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.allContacts[contact.ID] = contact m.allContacts.Store(contact.ID, contact)
for _, chat := range chats { for _, chat := range chats {
m.allChats[chat.ID] = chat m.allChats.Store(chat.ID, chat)
} }
delete(m.allChats, contact.ID) m.allChats.Delete(contact.ID)
// re-register for push notifications // re-register for push notifications
err = m.reregisterForPushNotifications() err = m.reregisterForPushNotifications()
@ -199,7 +190,7 @@ func (m *Messenger) saveContact(contact *Contact) error {
return err return err
} }
m.allContacts[contact.ID] = contact m.allContacts.Store(contact.ID, contact)
// Reregister only when data has changed // Reregister only when data has changed
if shouldReregisterForPushNotifications { if shouldReregisterForPushNotifications {
@ -210,25 +201,23 @@ func (m *Messenger) saveContact(contact *Contact) error {
} }
// Send contact updates to all contacts added by us // Send contact updates to all contacts added by us
func (m *Messenger) SendContactUpdates(ctx context.Context, ensName, profileImage string) error { func (m *Messenger) SendContactUpdates(ctx context.Context, ensName, profileImage string) (err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
myID := contactIDFromPublicKey(&m.identity.PublicKey) myID := contactIDFromPublicKey(&m.identity.PublicKey)
if _, err := m.sendContactUpdate(ctx, myID, ensName, profileImage); err != nil { if _, err = m.sendContactUpdate(ctx, myID, ensName, profileImage); err != nil {
return err return err
} }
// TODO: This should not be sending paired messages, as we do it above // TODO: This should not be sending paired messages, as we do it above
for _, contact := range m.allContacts { m.allContacts.Range(func(contactID string, contact *Contact) (shouldContinue bool) {
if contact.IsAdded() { if contact.IsAdded() {
if _, err := m.sendContactUpdate(ctx, contact.ID, ensName, profileImage); err != nil { if _, err = m.sendContactUpdate(ctx, contact.ID, ensName, profileImage); err != nil {
return err return false
} }
} }
} return true
return nil })
return err
} }
// NOTE: this endpoint does not add the contact, the reason being is that currently // NOTE: this endpoint does not add the contact, the reason being is that currently
@ -239,15 +228,13 @@ func (m *Messenger) SendContactUpdates(ctx context.Context, ensName, profileImag
// SendContactUpdate sends a contact update to a user and adds the user to contacts // SendContactUpdate sends a contact update to a user and adds the user to contacts
func (m *Messenger) SendContactUpdate(ctx context.Context, chatID, ensName, profileImage string) (*MessengerResponse, error) { func (m *Messenger) SendContactUpdate(ctx context.Context, chatID, ensName, profileImage string) (*MessengerResponse, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.sendContactUpdate(ctx, chatID, ensName, profileImage) return m.sendContactUpdate(ctx, chatID, ensName, profileImage)
} }
func (m *Messenger) sendContactUpdate(ctx context.Context, chatID, ensName, profileImage string) (*MessengerResponse, error) { func (m *Messenger) sendContactUpdate(ctx context.Context, chatID, ensName, profileImage string) (*MessengerResponse, error) {
var response MessengerResponse var response MessengerResponse
contact, ok := m.allContacts[chatID] contact, ok := m.allContacts.Load(chatID)
if !ok { if !ok {
var err error var err error
contact, err = buildContactFromPkString(chatID) contact, err = buildContactFromPkString(chatID)
@ -256,7 +243,7 @@ func (m *Messenger) sendContactUpdate(ctx context.Context, chatID, ensName, prof
} }
} }
chat, ok := m.allChats[chatID] chat, ok := m.allChats.Load(chatID)
if !ok { if !ok {
publicKey, err := contact.PublicKey() publicKey, err := contact.PublicKey()
if err != nil { if err != nil {
@ -267,7 +254,8 @@ func (m *Messenger) sendContactUpdate(ctx context.Context, chatID, ensName, prof
chat.Active = false chat.Active = false
} }
m.allChats[chat.ID] = chat // TODO(samyoul) remove storing of an updated reference pointer?
m.allChats.Store(chat.ID, chat)
clock, _ := chat.NextClockAndTimestamp(m.getTimesource()) clock, _ := chat.NextClockAndTimestamp(m.getTimesource())
contactUpdate := &protobuf.ContactUpdate{ contactUpdate := &protobuf.ContactUpdate{
@ -301,12 +289,12 @@ func (m *Messenger) sendContactUpdate(ctx context.Context, chatID, ensName, prof
} }
func (m *Messenger) isNewContact(contact *Contact) bool { func (m *Messenger) isNewContact(contact *Contact) bool {
previousContact, ok := m.allContacts[contact.ID] previousContact, ok := m.allContacts.Load(contact.ID)
return contact.IsAdded() && (!ok || !previousContact.IsAdded()) return contact.IsAdded() && (!ok || !previousContact.IsAdded())
} }
func (m *Messenger) hasNicknameChanged(contact *Contact) bool { func (m *Messenger) hasNicknameChanged(contact *Contact) bool {
previousContact, ok := m.allContacts[contact.ID] previousContact, ok := m.allContacts.Load(contact.ID)
if !ok { if !ok {
return false return false
} }
@ -314,7 +302,7 @@ func (m *Messenger) hasNicknameChanged(contact *Contact) bool {
} }
func (m *Messenger) removedContact(contact *Contact) bool { func (m *Messenger) removedContact(contact *Contact) bool {
previousContact, ok := m.allContacts[contact.ID] previousContact, ok := m.allContacts.Load(contact.ID)
if !ok { if !ok {
return false return false
} }

View File

@ -0,0 +1,224 @@
package protocol
import (
"sync"
"github.com/status-im/status-go/protocol/encryption/multidevice"
"github.com/status-im/status-go/protocol/protobuf"
)
/*
|--------------------------------------------------------------------------
| chatMap
|--------------------------------------------------------------------------
|
| A sync.Map wrapper for a specific mapping of map[string]*Chat
|
*/
type chatMap struct {
sm sync.Map
}
func (cm *chatMap) Load(chatID string) (*Chat, bool) {
chat, ok := cm.sm.Load(chatID)
if chat == nil {
return nil, ok
}
return chat.(*Chat), ok
}
func (cm *chatMap) Store(chatID string, chat *Chat) {
cm.sm.Store(chatID, chat)
}
func (cm *chatMap) Range(f func(chatID string, chat *Chat) (shouldContinue bool)) {
nf := func(key, value interface{}) (shouldContinue bool) {
return f(key.(string), value.(*Chat))
}
cm.sm.Range(nf)
}
func (cm *chatMap) Delete(chatID string) {
cm.sm.Delete(chatID)
}
/*
|--------------------------------------------------------------------------
| contactMap
|--------------------------------------------------------------------------
|
| A sync.Map wrapper for a specific mapping of map[string]*Contact
|
*/
type contactMap struct {
sm sync.Map
}
func (cm *contactMap) Load(contactID string) (*Contact, bool) {
contact, ok := cm.sm.Load(contactID)
if contact == nil {
return nil, ok
}
return contact.(*Contact), ok
}
func (cm *contactMap) Store(contactID string, contact *Contact) {
cm.sm.Store(contactID, contact)
}
func (cm *contactMap) Range(f func(contactID string, contact *Contact) (shouldContinue bool)) {
nf := func(key, value interface{}) (shouldContinue bool) {
return f(key.(string), value.(*Contact))
}
cm.sm.Range(nf)
}
func (cm *contactMap) Delete(contactID string) {
cm.sm.Delete(contactID)
}
/*
|--------------------------------------------------------------------------
| systemMessageTranslationsMap
|--------------------------------------------------------------------------
|
| A sync.Map wrapper for the specific mapping of map[protobuf.MembershipUpdateEvent_EventType]string
|
*/
type systemMessageTranslationsMap struct {
sm sync.Map
}
func (smtm *systemMessageTranslationsMap) Init(set map[protobuf.MembershipUpdateEvent_EventType]string) {
for eventType, message := range set {
smtm.Store(eventType, message)
}
}
func (smtm *systemMessageTranslationsMap) Load(eventType protobuf.MembershipUpdateEvent_EventType) (string, bool) {
message, ok := smtm.sm.Load(eventType)
if message == nil {
return "", ok
}
return message.(string), ok
}
func (smtm *systemMessageTranslationsMap) Store(eventType protobuf.MembershipUpdateEvent_EventType, message string) {
smtm.sm.Store(eventType, message)
}
func (smtm *systemMessageTranslationsMap) Range(f func(eventType protobuf.MembershipUpdateEvent_EventType, message string) (shouldContinue bool)) {
nf := func(key, value interface{}) (shouldContinue bool) {
return f(key.(protobuf.MembershipUpdateEvent_EventType), value.(string))
}
smtm.sm.Range(nf)
}
func (smtm *systemMessageTranslationsMap) Delete(eventType protobuf.MembershipUpdateEvent_EventType) {
smtm.sm.Delete(eventType)
}
/*
|--------------------------------------------------------------------------
| installationMap
|--------------------------------------------------------------------------
|
| A sync.Map wrapper for the specific mapping of map[string]*multidevice.Installation
|
*/
type installationMap struct {
sm sync.Map
}
func (im *installationMap) Load(installationID string) (*multidevice.Installation, bool) {
installation, ok := im.sm.Load(installationID)
if installation == nil {
return nil, ok
}
return installation.(*multidevice.Installation), ok
}
func (im *installationMap) Store(installationID string, installation *multidevice.Installation) {
im.sm.Store(installationID, installation)
}
func (im *installationMap) Range(f func(installationID string, installation *multidevice.Installation) (shouldContinue bool)) {
nf := func(key, value interface{}) (shouldContinue bool) {
return f(key.(string), value.(*multidevice.Installation))
}
im.sm.Range(nf)
}
func (im *installationMap) Delete(installationID string) {
im.sm.Delete(installationID)
}
func (im *installationMap) Empty() bool {
count := 0
im.Range(func(installationID string, installation *multidevice.Installation) (shouldContinue bool) {
count++
return false
})
return count == 0
}
func (im *installationMap) Len() int {
count := 0
im.Range(func(installationID string, installation *multidevice.Installation) (shouldContinue bool) {
count++
return true
})
return count
}
/*
|--------------------------------------------------------------------------
| stringBoolMap
|--------------------------------------------------------------------------
|
| A sync.Map wrapper for the specific mapping of map[string]bool
|
*/
type stringBoolMap struct {
sm sync.Map
}
func (sbm *stringBoolMap) Load(key string) (bool, bool) {
state, ok := sbm.sm.Load(key)
if state == nil {
return false, ok
}
return state.(bool), ok
}
func (sbm *stringBoolMap) Store(key string, value bool) {
sbm.sm.Store(key, value)
}
func (sbm *stringBoolMap) Range(f func(key string, value bool) (shouldContinue bool)) {
nf := func(key, value interface{}) (shouldContinue bool) {
return f(key.(string), value.(bool))
}
sbm.sm.Range(nf)
}
func (sbm *stringBoolMap) Delete(key string) {
sbm.sm.Delete(key)
}
func (sbm *stringBoolMap) Len() int {
count := 0
sbm.Range(func(key string, value bool) (shouldContinue bool) {
count++
return true
})
return count
}

View File

@ -2508,9 +2508,9 @@ func (s *MessageHandlerSuite) TestRun() {
for idx, tc := range testCases { for idx, tc := range testCases {
s.Run(tc.Name, func() { s.Run(tc.Name, func() {
chatsMap := make(map[string]*Chat) chatsMap := new(chatMap)
if tc.Chat != nil && tc.Chat.ID != "" { if tc.Chat != nil && tc.Chat.ID != "" {
chatsMap[tc.Chat.ID] = tc.Chat chatsMap.Store(tc.Chat.ID, tc.Chat)
} }
message := tc.Message message := tc.Message