From 0d773353bc2eab855b4b83e14c63be48ed25e48f Mon Sep 17 00:00:00 2001 From: Adam Babik Date: Wed, 13 Mar 2019 21:20:22 +0100 Subject: [PATCH] make Chat and Messenger thread-safe (#8) --- chat.go | 50 ++++++++--- main.go | 12 ++- node_config.go | 2 +- protocol/client/chat.go | 160 +++++++++++++++++++++++---------- protocol/client/database.go | 2 +- protocol/client/event.go | 24 ++++- protocol/client/messenger.go | 88 +++++++++--------- protocol/v1/whisper_service.go | 11 ++- 8 files changed, 234 insertions(+), 115 deletions(-) diff --git a/chat.go b/chat.go index ba05e50..e58464d 100644 --- a/chat.go +++ b/chat.go @@ -46,20 +46,33 @@ func (c *ChatViewController) readEventsLoop() { defer close(c.done) for { + log.Printf("[ChatViewController::readEventsLoops] waiting for events") + select { case event := <-c.messenger.Events(): - log.Printf("received an event: %+v", event) + log.Printf("[ChatViewController::readEventsLoops] received an event: %+v", event) switch ev := event.(type) { case client.EventError: c.notifications.Error("error", ev.Error().Error()) // nolint: errcheck + case client.EventMessage: + c.printMessages(false, ev.Message()) case client.Event: - messages, err := c.messenger.Messages(c.contact) - if err != nil { - c.notifications.Error("getting messages", err.Error()) // nolint: errcheck + if ev.Type() != client.EventTypeInit { break } - c.printMessages(messages) + + chat := c.messenger.Chat(c.contact) + if chat == nil { + c.notifications.Error("getting chat", "chat does not exist") // nolint: errcheck + break + } + + messages := chat.Messages() + + log.Printf("[ChatViewController::readEventsLoops] retrieved %d messages", len(messages)) + + c.printMessages(true, messages...) } case <-c.cancel: return @@ -70,7 +83,7 @@ func (c *ChatViewController) readEventsLoop() { // Select informs the chat view controller about a selected contact. // The chat view controller setup subscribers and request recent messages. func (c *ChatViewController) Select(contact client.Contact) error { - log.Printf("selected contact %s", contact.Name) + log.Printf("[ChatViewController::Select] contact %s", contact.Name) if c.cancel == nil { c.cancel = make(chan struct{}) @@ -87,21 +100,32 @@ func (c *ChatViewController) RequestMessages(params protocol.RequestOptions) err "REQUEST", fmt.Sprintf("get historic messages: %+v", params), ) - return c.messenger.Request(c.contact, params) + + chat := c.messenger.Chat(c.contact) + if chat == nil { + return fmt.Errorf("chat not found") + } + return chat.Request(params) } func (c *ChatViewController) Send(data []byte) error { - return c.messenger.Send(c.contact, data) + chat := c.messenger.Chat(c.contact) + if chat == nil { + return fmt.Errorf("chat not found") + } + return chat.Send(data) } -func (c *ChatViewController) printMessages(messages []*protocol.Message) { +func (c *ChatViewController) printMessages(clear bool, messages ...*protocol.Message) { c.g.Update(func(*gocui.Gui) error { - if err := c.Clear(); err != nil { - return err + if clear { + if err := c.Clear(); err != nil { + return err + } } for _, message := range messages { - if err := c.printMessage(message); err != nil { + if err := c.writeMessage(message); err != nil { return err } } @@ -109,7 +133,7 @@ func (c *ChatViewController) printMessages(messages []*protocol.Message) { }) } -func (c *ChatViewController) printMessage(message *protocol.Message) error { +func (c *ChatViewController) writeMessage(message *protocol.Message) error { myPubKey := c.identity.PublicKey pubKey := message.SigPubKey diff --git a/main.go b/main.go index 600fc8c..a8cd614 100644 --- a/main.go +++ b/main.go @@ -207,7 +207,17 @@ func main() { if !ok { return errors.New("contact could not be found") } - return chat.Select(contact) + + // We need to call Select asynchronously, + // otherwise the main thread is blocked + // and nothing is rendered. + go func() { + if err := chat.Select(contact); err != nil { + log.Printf("[GetLineHandler] error selecting a chat: %v", err) + } + }() + + return nil }), }, }, diff --git a/node_config.go b/node_config.go index 5a03594..422120f 100644 --- a/node_config.go +++ b/node_config.go @@ -10,7 +10,7 @@ import ( ) func init() { - if err := logutils.OverrideRootLog(true, "DEBUG", "", false); err != nil { + if err := logutils.OverrideRootLog(true, "INFO", "", false); err != nil { stdlog.Fatalf("failed to override root log: %v\n", err) } } diff --git a/protocol/client/chat.go b/protocol/client/chat.go index 9bdc19d..212e789 100644 --- a/protocol/client/chat.go +++ b/protocol/client/chat.go @@ -4,6 +4,7 @@ import ( "context" "crypto/ecdsa" "encoding/hex" + "log" "sort" "strings" "sync" @@ -33,7 +34,8 @@ type Chat struct { lastClock int64 - ownMessages chan *protocol.Message // my private messages channel + ownMessages chan *protocol.Message // my private messages channel + // TODO: make it a ring buffer messages []*protocol.Message // all messages ordered by Clock messagesByHash map[string]*protocol.Message // quick access to messages by hash } @@ -73,12 +75,15 @@ func (c *Chat) Messages() []*protocol.Message { } // Subscribe reads messages from the network. -func (c *Chat) Subscribe() error { - c.Lock() - defer c.Unlock() +// TODO: change method name to Join(). +func (c *Chat) Subscribe() (err error) { + c.RLock() + sub := c.sub + c.RUnlock() - if c.sub != nil { - return errors.New("already subscribed") + if sub != nil { + err = errors.New("already subscribed") + return } opts := protocol.SubscribeOptions{} @@ -88,26 +93,29 @@ func (c *Chat) Subscribe() error { opts.Identity = c.identity } - var err error - messages := make(chan *protocol.Message) - c.sub, err = c.proto.Subscribe(context.Background(), messages, opts) + sub, err = c.proto.Subscribe(context.Background(), messages, opts) if err != nil { - return errors.Wrap(err, "failed to subscribe") + err = errors.Wrap(err, "failed to subscribe") + return } - go func() { - // Send at least one event to kick off the logic. - // TODO: change type of the event. - c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage} - }() + c.Lock() + c.sub = sub + c.Unlock() cancel := make(chan struct{}) // can be closed by any loop - go c.readLoop(messages, c.sub, cancel) + go c.readLoop(messages, sub, cancel) go c.readOwnMessagesLoop(c.ownMessages, cancel) + // Load should have it's own lock. + return c.Load() +} + +// Load loads messages from the database cache and the network. +func (c *Chat) Load() error { params := protocol.DefaultRequestOptions() // Get already cached messages from the database. @@ -120,10 +128,14 @@ func (c *Chat) Subscribe() error { return errors.Wrap(err, "db failed to get messages") } + c.Lock() + c.handleMessages(cachedMessages...) + c.Unlock() + go func() { - for _, m := range cachedMessages { - messages <- m - } + log.Printf("[Chat::Subscribe] sending EventTypeInit") + c.events <- baseEvent{contact: c.contact, typ: EventTypeInit} + log.Printf("[Chat::Subscribe] sent EventTypeInit") }() if c.contact.Type == ContactPublicChat { @@ -132,7 +144,7 @@ func (c *Chat) Subscribe() error { params.Recipient = c.contact.PublicKey } // Request historic messages from the network. - if err := c.proto.Request(context.Background(), params); err != nil { + if err := c.request(params); err != nil { return errors.Wrap(err, "failed to request for messages") } @@ -143,15 +155,17 @@ func (c *Chat) Subscribe() error { func (c *Chat) Unsubscribe() { c.RLock() defer c.RUnlock() - - if c.sub == nil { - return + if c.sub != nil { + c.sub.Unsubscribe() } - c.sub.Unsubscribe() } // Request sends a request for historic messages. func (c *Chat) Request(params protocol.RequestOptions) error { + return c.request(params) +} + +func (c *Chat) request(params protocol.RequestOptions) error { return c.proto.Request(context.Background(), params) } @@ -187,12 +201,16 @@ func (c *Chat) Send(data []byte) error { return errors.Wrap(err, "failed to encode message") } + c.Lock() c.updateLastClock(clock) + c.Unlock() hash, err := c.proto.Send(context.Background(), encodedMessage, opts) // Own messages need to be pushed manually to the pipeline. if c.contact.Type == ContactPrivateChat { + log.Printf("[Chat::Send] sent a private message") + c.ownMessages <- &protocol.Message{ Decoded: message, SigPubKey: &c.identity.PublicKey, @@ -209,11 +227,28 @@ func (c *Chat) readLoop(messages <-chan *protocol.Message, sub *protocol.Subscri for { select { case m := <-messages: - if err := c.handleMessage(m); err != nil { + if c.HasMessage(m) { + break + } + + c.Lock() + c.handleMessages(m) + c.Unlock() + + if err := c.saveMessages(m); err != nil { + c.Lock() c.err = err + c.Unlock() return } - c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage} + + c.events <- messageEvent{ + baseEvent: baseEvent{ + contact: c.contact, + typ: EventTypeMessage, + }, + message: m, + } case <-sub.Done(): c.err = sub.Err() return @@ -229,39 +264,53 @@ func (c *Chat) readOwnMessagesLoop(messages <-chan *protocol.Message, cancel cha for { select { case m := <-messages: - if err := c.handleMessage(m); err != nil { + if c.HasMessage(m) { + break + } + + c.Lock() + c.handleMessages(m) + c.Unlock() + + if err := c.saveMessages(m); err != nil { + c.Lock() c.err = err + c.Unlock() return } - c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage} + + c.events <- messageEvent{ + baseEvent: baseEvent{ + contact: c.contact, + typ: EventTypeMessage, + }, + message: m, + } case <-cancel: return } } } -func (c *Chat) handleMessage(message *protocol.Message) error { - lessFn := func(i, j int) bool { - return c.messages[i].Decoded.Clock < c.messages[j].Decoded.Clock +func (c *Chat) handleMessages(messages ...*protocol.Message) { + for _, message := range messages { + c.updateLastClock(message.Decoded.Clock) + + hash := messageHashStr(message) + + c.messagesByHash[hash] = message + c.messages = append(c.messages, message) + + sort.Slice(c.messages, c.lessFn) } - hash := hex.EncodeToString(message.Hash) +} - // the message already exists - if _, ok := c.messagesByHash[hash]; ok { - return nil - } +func (c *Chat) saveMessages(messages ...*protocol.Message) error { + return c.db.SaveMessages(c.contact, messages) +} - c.updateLastClock(message.Decoded.Clock) - - c.messagesByHash[hash] = message - c.messages = append(c.messages, message) - - isSorted := sort.SliceIsSorted(c.messages, lessFn) - if !isSorted { - sort.Slice(c.messages, lessFn) - } - - return c.db.SaveMessages(c.contact, message) +func (c *Chat) lessFn(i, j int) bool { + return c.messages[i].Decoded.Clock < c.messages[j].Decoded.Clock } func (c *Chat) updateLastClock(clock int64) { @@ -269,3 +318,20 @@ func (c *Chat) updateLastClock(clock int64) { c.lastClock = clock } } + +func (c *Chat) hasMessage(m *protocol.Message) bool { + hash := messageHashStr(m) + _, ok := c.messagesByHash[hash] + return ok +} + +// HasMessage returns true if a given message is already cached. +func (c *Chat) HasMessage(m *protocol.Message) bool { + c.Lock() + defer c.Unlock() + return c.hasMessage(m) +} + +func messageHashStr(m *protocol.Message) string { + return hex.EncodeToString(m.Hash) +} diff --git a/protocol/client/database.go b/protocol/client/database.go index aa9a95e..7dd38c7 100644 --- a/protocol/client/database.go +++ b/protocol/client/database.go @@ -80,7 +80,7 @@ func (d *Database) Messages(c Contact, from, to int64) (result []*protocol.Messa } // SaveMessages stores messages on a disk. -func (d *Database) SaveMessages(c Contact, messages ...*protocol.Message) error { +func (d *Database) SaveMessages(c Contact, messages []*protocol.Message) error { var buf bytes.Buffer enc := gob.NewEncoder(&buf) diff --git a/protocol/client/event.go b/protocol/client/event.go index 4d86d93..7d7ef49 100644 --- a/protocol/client/event.go +++ b/protocol/client/event.go @@ -1,7 +1,10 @@ package client +import "github.com/status-im/status-console-client/protocol/v1" + const ( - EventTypeMessage int = iota + 1 + EventTypeInit int = iota + 1 + EventTypeMessage EventTypeError ) @@ -15,6 +18,11 @@ type EventError interface { Error() error } +type EventMessage interface { + Event + Message() *protocol.Message +} + type baseEvent struct { contact Contact typ int @@ -23,6 +31,20 @@ type baseEvent struct { func (e baseEvent) Contact() Contact { return e.contact } func (e baseEvent) Type() int { return e.typ } +type errorEvent struct { + baseEvent + err error +} + +func (e errorEvent) Error() error { return e.err } + +type messageEvent struct { + baseEvent + message *protocol.Message +} + +func (e messageEvent) Message() *protocol.Message { return e.message } + // type eventError struct { // Event // err error diff --git a/protocol/client/messenger.go b/protocol/client/messenger.go index 773279f..ebd80de 100644 --- a/protocol/client/messenger.go +++ b/protocol/client/messenger.go @@ -2,6 +2,7 @@ package client import ( "crypto/ecdsa" + "log" "sync" "github.com/pkg/errors" @@ -10,12 +11,13 @@ import ( // Messenger coordinates chats. type Messenger struct { + sync.RWMutex + proto protocol.Chat identity *ecdsa.PrivateKey db *Database chats map[Contact]*Chat - wg sync.WaitGroup events chan interface{} } @@ -33,76 +35,70 @@ func NewMessenger(proto protocol.Chat, identity *ecdsa.PrivateKey, db *Database) // Events returns a channel with chat events. func (m *Messenger) Events() <-chan interface{} { + m.RLock() + defer m.RUnlock() return m.events } +func (m *Messenger) Chat(c Contact) *Chat { + m.RLock() + defer m.RUnlock() + return m.chats[c] +} + // Join creates a new chat and creates a subscription. func (m *Messenger) Join(contact Contact) error { - chat := NewChat(m.proto, m.identity, contact, m.db) + m.RLock() + chat, found := m.chats[contact] + m.RUnlock() - if err := chat.Subscribe(); err != nil { - return err + if found { + return chat.Load() } + chat = NewChat(m.proto, m.identity, contact, m.db) + + m.Lock() m.chats[contact] = chat + m.Unlock() - m.wg.Add(1) - go func() { - defer m.wg.Done() + go func(events <-chan interface{}) { + log.Printf("[Messenger::Join] waiting for events") - for ev := range chat.Events() { + for ev := range events { + log.Printf("[Messenger::Join] received an event: %+v", ev) m.events <- ev } if err := chat.Err(); err != nil { - m.events <- baseEvent{contact: contact, typ: EventTypeError} + m.events <- errorEvent{ + baseEvent: baseEvent{contact: contact, typ: EventTypeError}, + err: err, + } } - }() + }(chat.Events()) - return nil + return chat.Subscribe() } // Leave unsubscribes from the chat. func (m *Messenger) Leave(contact Contact) error { + m.RLock() chat, ok := m.chats[contact] + m.RUnlock() if !ok { return errors.New("chat for the contact not found") } chat.Unsubscribe() + + m.Lock() delete(m.chats, contact) + m.Unlock() return nil } -// Messages returns a list of messages for a given contact. -func (m *Messenger) Messages(contact Contact) ([]*protocol.Message, error) { - chat, ok := m.chats[contact] - if !ok { - return nil, errors.New("chat for the contact not found") - } - - return chat.Messages(), nil -} - -func (m *Messenger) Request(contact Contact, params protocol.RequestOptions) error { - chat, ok := m.chats[contact] - if !ok { - return errors.New("chat for the contact not found") - } - - return chat.Request(params) -} - -func (m *Messenger) Send(contact Contact, data []byte) error { - chat, ok := m.chats[contact] - if !ok { - return errors.New("chat for the contact not found") - } - - return chat.Send(data) -} - func (m *Messenger) Contacts() ([]Contact, error) { return m.db.Contacts() } @@ -118,7 +114,6 @@ func (m *Messenger) AddContact(c Contact) error { } contacts = append(contacts, c) - return m.db.SaveContacts(contacts) } @@ -129,14 +124,17 @@ func (m *Messenger) RemoveContact(c Contact) error { } for i, item := range contacts { - if item == c { - copy(contacts[i:], contacts[i+1:]) - contacts[len(contacts)-1] = Contact{} - contacts = contacts[:len(contacts)-1] + if item != c { + continue } + + copy(contacts[i:], contacts[i+1:]) + contacts[len(contacts)-1] = Contact{} + contacts = contacts[:len(contacts)-1] + + break } contacts = append(contacts, c) - return m.db.SaveContacts(contacts) } diff --git a/protocol/v1/whisper_service.go b/protocol/v1/whisper_service.go index b1c5437..1556cee 100644 --- a/protocol/v1/whisper_service.go +++ b/protocol/v1/whisper_service.go @@ -235,10 +235,11 @@ func (a *WhisperServiceAdapter) requestMessages(ctx context.Context, enode strin return err } - _, err = shhextAPI.RequestMessages(ctx, req) - // TODO: wait for the request to finish before returning. - // Use a different method or relay on signals. - return err + return shhextAPI.RequestMessagesSync(shhext.RetryConfig{ + BaseTimeout: time.Second * 10, + StepTimeout: time.Second, + MaxRetries: 3, + }, req) } func (a *WhisperServiceAdapter) createMessagesRequest( @@ -297,8 +298,6 @@ func (s whisperSubscription) Messages() ([]*Message, error) { result := make([]*Message, 0, len(items)) for _, item := range items { - log.Printf("retrieve a message with ID %s", item.EnvelopeHash.String()) - decoded, err := DecodeMessage(item.Payload) if err != nil { log.Printf("failed to decode message: %v", err)