make Chat and Messenger thread-safe (#8)

This commit is contained in:
Adam Babik 2019-03-13 21:20:22 +01:00 committed by GitHub
parent 179241a572
commit 0d773353bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 234 additions and 115 deletions

46
chat.go
View File

@ -46,20 +46,33 @@ func (c *ChatViewController) readEventsLoop() {
defer close(c.done) defer close(c.done)
for { for {
log.Printf("[ChatViewController::readEventsLoops] waiting for events")
select { select {
case event := <-c.messenger.Events(): case event := <-c.messenger.Events():
log.Printf("received an event: %+v", event) log.Printf("[ChatViewController::readEventsLoops] received an event: %+v", event)
switch ev := event.(type) { switch ev := event.(type) {
case client.EventError: case client.EventError:
c.notifications.Error("error", ev.Error().Error()) // nolint: errcheck c.notifications.Error("error", ev.Error().Error()) // nolint: errcheck
case client.EventMessage:
c.printMessages(false, ev.Message())
case client.Event: case client.Event:
messages, err := c.messenger.Messages(c.contact) if ev.Type() != client.EventTypeInit {
if err != nil {
c.notifications.Error("getting messages", err.Error()) // nolint: errcheck
break break
} }
c.printMessages(messages)
chat := c.messenger.Chat(c.contact)
if chat == nil {
c.notifications.Error("getting chat", "chat does not exist") // nolint: errcheck
break
}
messages := chat.Messages()
log.Printf("[ChatViewController::readEventsLoops] retrieved %d messages", len(messages))
c.printMessages(true, messages...)
} }
case <-c.cancel: case <-c.cancel:
return return
@ -70,7 +83,7 @@ func (c *ChatViewController) readEventsLoop() {
// Select informs the chat view controller about a selected contact. // Select informs the chat view controller about a selected contact.
// The chat view controller setup subscribers and request recent messages. // The chat view controller setup subscribers and request recent messages.
func (c *ChatViewController) Select(contact client.Contact) error { func (c *ChatViewController) Select(contact client.Contact) error {
log.Printf("selected contact %s", contact.Name) log.Printf("[ChatViewController::Select] contact %s", contact.Name)
if c.cancel == nil { if c.cancel == nil {
c.cancel = make(chan struct{}) c.cancel = make(chan struct{})
@ -87,21 +100,32 @@ func (c *ChatViewController) RequestMessages(params protocol.RequestOptions) err
"REQUEST", "REQUEST",
fmt.Sprintf("get historic messages: %+v", params), fmt.Sprintf("get historic messages: %+v", params),
) )
return c.messenger.Request(c.contact, params)
chat := c.messenger.Chat(c.contact)
if chat == nil {
return fmt.Errorf("chat not found")
}
return chat.Request(params)
} }
func (c *ChatViewController) Send(data []byte) error { func (c *ChatViewController) Send(data []byte) error {
return c.messenger.Send(c.contact, data) chat := c.messenger.Chat(c.contact)
if chat == nil {
return fmt.Errorf("chat not found")
}
return chat.Send(data)
} }
func (c *ChatViewController) printMessages(messages []*protocol.Message) { func (c *ChatViewController) printMessages(clear bool, messages ...*protocol.Message) {
c.g.Update(func(*gocui.Gui) error { c.g.Update(func(*gocui.Gui) error {
if clear {
if err := c.Clear(); err != nil { if err := c.Clear(); err != nil {
return err return err
} }
}
for _, message := range messages { for _, message := range messages {
if err := c.printMessage(message); err != nil { if err := c.writeMessage(message); err != nil {
return err return err
} }
} }
@ -109,7 +133,7 @@ func (c *ChatViewController) printMessages(messages []*protocol.Message) {
}) })
} }
func (c *ChatViewController) printMessage(message *protocol.Message) error { func (c *ChatViewController) writeMessage(message *protocol.Message) error {
myPubKey := c.identity.PublicKey myPubKey := c.identity.PublicKey
pubKey := message.SigPubKey pubKey := message.SigPubKey

12
main.go
View File

@ -207,7 +207,17 @@ func main() {
if !ok { if !ok {
return errors.New("contact could not be found") return errors.New("contact could not be found")
} }
return chat.Select(contact)
// We need to call Select asynchronously,
// otherwise the main thread is blocked
// and nothing is rendered.
go func() {
if err := chat.Select(contact); err != nil {
log.Printf("[GetLineHandler] error selecting a chat: %v", err)
}
}()
return nil
}), }),
}, },
}, },

View File

@ -10,7 +10,7 @@ import (
) )
func init() { func init() {
if err := logutils.OverrideRootLog(true, "DEBUG", "", false); err != nil { if err := logutils.OverrideRootLog(true, "INFO", "", false); err != nil {
stdlog.Fatalf("failed to override root log: %v\n", err) stdlog.Fatalf("failed to override root log: %v\n", err)
} }
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/ecdsa" "crypto/ecdsa"
"encoding/hex" "encoding/hex"
"log"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -34,6 +35,7 @@ type Chat struct {
lastClock int64 lastClock int64
ownMessages chan *protocol.Message // my private messages channel ownMessages chan *protocol.Message // my private messages channel
// TODO: make it a ring buffer
messages []*protocol.Message // all messages ordered by Clock messages []*protocol.Message // all messages ordered by Clock
messagesByHash map[string]*protocol.Message // quick access to messages by hash messagesByHash map[string]*protocol.Message // quick access to messages by hash
} }
@ -73,12 +75,15 @@ func (c *Chat) Messages() []*protocol.Message {
} }
// Subscribe reads messages from the network. // Subscribe reads messages from the network.
func (c *Chat) Subscribe() error { // TODO: change method name to Join().
c.Lock() func (c *Chat) Subscribe() (err error) {
defer c.Unlock() c.RLock()
sub := c.sub
c.RUnlock()
if c.sub != nil { if sub != nil {
return errors.New("already subscribed") err = errors.New("already subscribed")
return
} }
opts := protocol.SubscribeOptions{} opts := protocol.SubscribeOptions{}
@ -88,26 +93,29 @@ func (c *Chat) Subscribe() error {
opts.Identity = c.identity opts.Identity = c.identity
} }
var err error
messages := make(chan *protocol.Message) messages := make(chan *protocol.Message)
c.sub, err = c.proto.Subscribe(context.Background(), messages, opts) sub, err = c.proto.Subscribe(context.Background(), messages, opts)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to subscribe") err = errors.Wrap(err, "failed to subscribe")
return
} }
go func() { c.Lock()
// Send at least one event to kick off the logic. c.sub = sub
// TODO: change type of the event. c.Unlock()
c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage}
}()
cancel := make(chan struct{}) // can be closed by any loop cancel := make(chan struct{}) // can be closed by any loop
go c.readLoop(messages, c.sub, cancel) go c.readLoop(messages, sub, cancel)
go c.readOwnMessagesLoop(c.ownMessages, cancel) go c.readOwnMessagesLoop(c.ownMessages, cancel)
// Load should have it's own lock.
return c.Load()
}
// Load loads messages from the database cache and the network.
func (c *Chat) Load() error {
params := protocol.DefaultRequestOptions() params := protocol.DefaultRequestOptions()
// Get already cached messages from the database. // Get already cached messages from the database.
@ -120,10 +128,14 @@ func (c *Chat) Subscribe() error {
return errors.Wrap(err, "db failed to get messages") return errors.Wrap(err, "db failed to get messages")
} }
c.Lock()
c.handleMessages(cachedMessages...)
c.Unlock()
go func() { go func() {
for _, m := range cachedMessages { log.Printf("[Chat::Subscribe] sending EventTypeInit")
messages <- m c.events <- baseEvent{contact: c.contact, typ: EventTypeInit}
} log.Printf("[Chat::Subscribe] sent EventTypeInit")
}() }()
if c.contact.Type == ContactPublicChat { if c.contact.Type == ContactPublicChat {
@ -132,7 +144,7 @@ func (c *Chat) Subscribe() error {
params.Recipient = c.contact.PublicKey params.Recipient = c.contact.PublicKey
} }
// Request historic messages from the network. // Request historic messages from the network.
if err := c.proto.Request(context.Background(), params); err != nil { if err := c.request(params); err != nil {
return errors.Wrap(err, "failed to request for messages") return errors.Wrap(err, "failed to request for messages")
} }
@ -143,15 +155,17 @@ func (c *Chat) Subscribe() error {
func (c *Chat) Unsubscribe() { func (c *Chat) Unsubscribe() {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
if c.sub != nil {
if c.sub == nil {
return
}
c.sub.Unsubscribe() c.sub.Unsubscribe()
} }
}
// Request sends a request for historic messages. // Request sends a request for historic messages.
func (c *Chat) Request(params protocol.RequestOptions) error { func (c *Chat) Request(params protocol.RequestOptions) error {
return c.request(params)
}
func (c *Chat) request(params protocol.RequestOptions) error {
return c.proto.Request(context.Background(), params) return c.proto.Request(context.Background(), params)
} }
@ -187,12 +201,16 @@ func (c *Chat) Send(data []byte) error {
return errors.Wrap(err, "failed to encode message") return errors.Wrap(err, "failed to encode message")
} }
c.Lock()
c.updateLastClock(clock) c.updateLastClock(clock)
c.Unlock()
hash, err := c.proto.Send(context.Background(), encodedMessage, opts) hash, err := c.proto.Send(context.Background(), encodedMessage, opts)
// Own messages need to be pushed manually to the pipeline. // Own messages need to be pushed manually to the pipeline.
if c.contact.Type == ContactPrivateChat { if c.contact.Type == ContactPrivateChat {
log.Printf("[Chat::Send] sent a private message")
c.ownMessages <- &protocol.Message{ c.ownMessages <- &protocol.Message{
Decoded: message, Decoded: message,
SigPubKey: &c.identity.PublicKey, SigPubKey: &c.identity.PublicKey,
@ -209,11 +227,28 @@ func (c *Chat) readLoop(messages <-chan *protocol.Message, sub *protocol.Subscri
for { for {
select { select {
case m := <-messages: case m := <-messages:
if err := c.handleMessage(m); err != nil { if c.HasMessage(m) {
break
}
c.Lock()
c.handleMessages(m)
c.Unlock()
if err := c.saveMessages(m); err != nil {
c.Lock()
c.err = err c.err = err
c.Unlock()
return return
} }
c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage}
c.events <- messageEvent{
baseEvent: baseEvent{
contact: c.contact,
typ: EventTypeMessage,
},
message: m,
}
case <-sub.Done(): case <-sub.Done():
c.err = sub.Err() c.err = sub.Err()
return return
@ -229,39 +264,53 @@ func (c *Chat) readOwnMessagesLoop(messages <-chan *protocol.Message, cancel cha
for { for {
select { select {
case m := <-messages: case m := <-messages:
if err := c.handleMessage(m); err != nil { if c.HasMessage(m) {
break
}
c.Lock()
c.handleMessages(m)
c.Unlock()
if err := c.saveMessages(m); err != nil {
c.Lock()
c.err = err c.err = err
c.Unlock()
return return
} }
c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage}
c.events <- messageEvent{
baseEvent: baseEvent{
contact: c.contact,
typ: EventTypeMessage,
},
message: m,
}
case <-cancel: case <-cancel:
return return
} }
} }
} }
func (c *Chat) handleMessage(message *protocol.Message) error { func (c *Chat) handleMessages(messages ...*protocol.Message) {
lessFn := func(i, j int) bool { for _, message := range messages {
return c.messages[i].Decoded.Clock < c.messages[j].Decoded.Clock
}
hash := hex.EncodeToString(message.Hash)
// the message already exists
if _, ok := c.messagesByHash[hash]; ok {
return nil
}
c.updateLastClock(message.Decoded.Clock) c.updateLastClock(message.Decoded.Clock)
hash := messageHashStr(message)
c.messagesByHash[hash] = message c.messagesByHash[hash] = message
c.messages = append(c.messages, message) c.messages = append(c.messages, message)
isSorted := sort.SliceIsSorted(c.messages, lessFn) sort.Slice(c.messages, c.lessFn)
if !isSorted { }
sort.Slice(c.messages, lessFn)
} }
return c.db.SaveMessages(c.contact, message) func (c *Chat) saveMessages(messages ...*protocol.Message) error {
return c.db.SaveMessages(c.contact, messages)
}
func (c *Chat) lessFn(i, j int) bool {
return c.messages[i].Decoded.Clock < c.messages[j].Decoded.Clock
} }
func (c *Chat) updateLastClock(clock int64) { func (c *Chat) updateLastClock(clock int64) {
@ -269,3 +318,20 @@ func (c *Chat) updateLastClock(clock int64) {
c.lastClock = clock c.lastClock = clock
} }
} }
func (c *Chat) hasMessage(m *protocol.Message) bool {
hash := messageHashStr(m)
_, ok := c.messagesByHash[hash]
return ok
}
// HasMessage returns true if a given message is already cached.
func (c *Chat) HasMessage(m *protocol.Message) bool {
c.Lock()
defer c.Unlock()
return c.hasMessage(m)
}
func messageHashStr(m *protocol.Message) string {
return hex.EncodeToString(m.Hash)
}

View File

@ -80,7 +80,7 @@ func (d *Database) Messages(c Contact, from, to int64) (result []*protocol.Messa
} }
// SaveMessages stores messages on a disk. // SaveMessages stores messages on a disk.
func (d *Database) SaveMessages(c Contact, messages ...*protocol.Message) error { func (d *Database) SaveMessages(c Contact, messages []*protocol.Message) error {
var buf bytes.Buffer var buf bytes.Buffer
enc := gob.NewEncoder(&buf) enc := gob.NewEncoder(&buf)

View File

@ -1,7 +1,10 @@
package client package client
import "github.com/status-im/status-console-client/protocol/v1"
const ( const (
EventTypeMessage int = iota + 1 EventTypeInit int = iota + 1
EventTypeMessage
EventTypeError EventTypeError
) )
@ -15,6 +18,11 @@ type EventError interface {
Error() error Error() error
} }
type EventMessage interface {
Event
Message() *protocol.Message
}
type baseEvent struct { type baseEvent struct {
contact Contact contact Contact
typ int typ int
@ -23,6 +31,20 @@ type baseEvent struct {
func (e baseEvent) Contact() Contact { return e.contact } func (e baseEvent) Contact() Contact { return e.contact }
func (e baseEvent) Type() int { return e.typ } func (e baseEvent) Type() int { return e.typ }
type errorEvent struct {
baseEvent
err error
}
func (e errorEvent) Error() error { return e.err }
type messageEvent struct {
baseEvent
message *protocol.Message
}
func (e messageEvent) Message() *protocol.Message { return e.message }
// type eventError struct { // type eventError struct {
// Event // Event
// err error // err error

View File

@ -2,6 +2,7 @@ package client
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"log"
"sync" "sync"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -10,12 +11,13 @@ import (
// Messenger coordinates chats. // Messenger coordinates chats.
type Messenger struct { type Messenger struct {
sync.RWMutex
proto protocol.Chat proto protocol.Chat
identity *ecdsa.PrivateKey identity *ecdsa.PrivateKey
db *Database db *Database
chats map[Contact]*Chat chats map[Contact]*Chat
wg sync.WaitGroup
events chan interface{} events chan interface{}
} }
@ -33,76 +35,70 @@ func NewMessenger(proto protocol.Chat, identity *ecdsa.PrivateKey, db *Database)
// Events returns a channel with chat events. // Events returns a channel with chat events.
func (m *Messenger) Events() <-chan interface{} { func (m *Messenger) Events() <-chan interface{} {
m.RLock()
defer m.RUnlock()
return m.events return m.events
} }
func (m *Messenger) Chat(c Contact) *Chat {
m.RLock()
defer m.RUnlock()
return m.chats[c]
}
// Join creates a new chat and creates a subscription. // Join creates a new chat and creates a subscription.
func (m *Messenger) Join(contact Contact) error { func (m *Messenger) Join(contact Contact) error {
chat := NewChat(m.proto, m.identity, contact, m.db) m.RLock()
chat, found := m.chats[contact]
m.RUnlock()
if err := chat.Subscribe(); err != nil { if found {
return err return chat.Load()
} }
chat = NewChat(m.proto, m.identity, contact, m.db)
m.Lock()
m.chats[contact] = chat m.chats[contact] = chat
m.Unlock()
m.wg.Add(1) go func(events <-chan interface{}) {
go func() { log.Printf("[Messenger::Join] waiting for events")
defer m.wg.Done()
for ev := range chat.Events() { for ev := range events {
log.Printf("[Messenger::Join] received an event: %+v", ev)
m.events <- ev m.events <- ev
} }
if err := chat.Err(); err != nil { if err := chat.Err(); err != nil {
m.events <- baseEvent{contact: contact, typ: EventTypeError} m.events <- errorEvent{
baseEvent: baseEvent{contact: contact, typ: EventTypeError},
err: err,
} }
}() }
}(chat.Events())
return nil return chat.Subscribe()
} }
// Leave unsubscribes from the chat. // Leave unsubscribes from the chat.
func (m *Messenger) Leave(contact Contact) error { func (m *Messenger) Leave(contact Contact) error {
m.RLock()
chat, ok := m.chats[contact] chat, ok := m.chats[contact]
m.RUnlock()
if !ok { if !ok {
return errors.New("chat for the contact not found") return errors.New("chat for the contact not found")
} }
chat.Unsubscribe() chat.Unsubscribe()
m.Lock()
delete(m.chats, contact) delete(m.chats, contact)
m.Unlock()
return nil return nil
} }
// Messages returns a list of messages for a given contact.
func (m *Messenger) Messages(contact Contact) ([]*protocol.Message, error) {
chat, ok := m.chats[contact]
if !ok {
return nil, errors.New("chat for the contact not found")
}
return chat.Messages(), nil
}
func (m *Messenger) Request(contact Contact, params protocol.RequestOptions) error {
chat, ok := m.chats[contact]
if !ok {
return errors.New("chat for the contact not found")
}
return chat.Request(params)
}
func (m *Messenger) Send(contact Contact, data []byte) error {
chat, ok := m.chats[contact]
if !ok {
return errors.New("chat for the contact not found")
}
return chat.Send(data)
}
func (m *Messenger) Contacts() ([]Contact, error) { func (m *Messenger) Contacts() ([]Contact, error) {
return m.db.Contacts() return m.db.Contacts()
} }
@ -118,7 +114,6 @@ func (m *Messenger) AddContact(c Contact) error {
} }
contacts = append(contacts, c) contacts = append(contacts, c)
return m.db.SaveContacts(contacts) return m.db.SaveContacts(contacts)
} }
@ -129,14 +124,17 @@ func (m *Messenger) RemoveContact(c Contact) error {
} }
for i, item := range contacts { for i, item := range contacts {
if item == c { if item != c {
continue
}
copy(contacts[i:], contacts[i+1:]) copy(contacts[i:], contacts[i+1:])
contacts[len(contacts)-1] = Contact{} contacts[len(contacts)-1] = Contact{}
contacts = contacts[:len(contacts)-1] contacts = contacts[:len(contacts)-1]
}
break
} }
contacts = append(contacts, c) contacts = append(contacts, c)
return m.db.SaveContacts(contacts) return m.db.SaveContacts(contacts)
} }

View File

@ -235,10 +235,11 @@ func (a *WhisperServiceAdapter) requestMessages(ctx context.Context, enode strin
return err return err
} }
_, err = shhextAPI.RequestMessages(ctx, req) return shhextAPI.RequestMessagesSync(shhext.RetryConfig{
// TODO: wait for the request to finish before returning. BaseTimeout: time.Second * 10,
// Use a different method or relay on signals. StepTimeout: time.Second,
return err MaxRetries: 3,
}, req)
} }
func (a *WhisperServiceAdapter) createMessagesRequest( func (a *WhisperServiceAdapter) createMessagesRequest(
@ -297,8 +298,6 @@ func (s whisperSubscription) Messages() ([]*Message, error) {
result := make([]*Message, 0, len(items)) result := make([]*Message, 0, len(items))
for _, item := range items { for _, item := range items {
log.Printf("retrieve a message with ID %s", item.EnvelopeHash.String())
decoded, err := DecodeMessage(item.Payload) decoded, err := DecodeMessage(item.Payload)
if err != nil { if err != nil {
log.Printf("failed to decode message: %v", err) log.Printf("failed to decode message: %v", err)