make Chat and Messenger thread-safe (#8)

This commit is contained in:
Adam Babik 2019-03-13 21:20:22 +01:00 committed by GitHub
parent 179241a572
commit 0d773353bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 234 additions and 115 deletions

50
chat.go
View File

@ -46,20 +46,33 @@ func (c *ChatViewController) readEventsLoop() {
defer close(c.done)
for {
log.Printf("[ChatViewController::readEventsLoops] waiting for events")
select {
case event := <-c.messenger.Events():
log.Printf("received an event: %+v", event)
log.Printf("[ChatViewController::readEventsLoops] received an event: %+v", event)
switch ev := event.(type) {
case client.EventError:
c.notifications.Error("error", ev.Error().Error()) // nolint: errcheck
case client.EventMessage:
c.printMessages(false, ev.Message())
case client.Event:
messages, err := c.messenger.Messages(c.contact)
if err != nil {
c.notifications.Error("getting messages", err.Error()) // nolint: errcheck
if ev.Type() != client.EventTypeInit {
break
}
c.printMessages(messages)
chat := c.messenger.Chat(c.contact)
if chat == nil {
c.notifications.Error("getting chat", "chat does not exist") // nolint: errcheck
break
}
messages := chat.Messages()
log.Printf("[ChatViewController::readEventsLoops] retrieved %d messages", len(messages))
c.printMessages(true, messages...)
}
case <-c.cancel:
return
@ -70,7 +83,7 @@ func (c *ChatViewController) readEventsLoop() {
// Select informs the chat view controller about a selected contact.
// The chat view controller setup subscribers and request recent messages.
func (c *ChatViewController) Select(contact client.Contact) error {
log.Printf("selected contact %s", contact.Name)
log.Printf("[ChatViewController::Select] contact %s", contact.Name)
if c.cancel == nil {
c.cancel = make(chan struct{})
@ -87,21 +100,32 @@ func (c *ChatViewController) RequestMessages(params protocol.RequestOptions) err
"REQUEST",
fmt.Sprintf("get historic messages: %+v", params),
)
return c.messenger.Request(c.contact, params)
chat := c.messenger.Chat(c.contact)
if chat == nil {
return fmt.Errorf("chat not found")
}
return chat.Request(params)
}
func (c *ChatViewController) Send(data []byte) error {
return c.messenger.Send(c.contact, data)
chat := c.messenger.Chat(c.contact)
if chat == nil {
return fmt.Errorf("chat not found")
}
return chat.Send(data)
}
func (c *ChatViewController) printMessages(messages []*protocol.Message) {
func (c *ChatViewController) printMessages(clear bool, messages ...*protocol.Message) {
c.g.Update(func(*gocui.Gui) error {
if err := c.Clear(); err != nil {
return err
if clear {
if err := c.Clear(); err != nil {
return err
}
}
for _, message := range messages {
if err := c.printMessage(message); err != nil {
if err := c.writeMessage(message); err != nil {
return err
}
}
@ -109,7 +133,7 @@ func (c *ChatViewController) printMessages(messages []*protocol.Message) {
})
}
func (c *ChatViewController) printMessage(message *protocol.Message) error {
func (c *ChatViewController) writeMessage(message *protocol.Message) error {
myPubKey := c.identity.PublicKey
pubKey := message.SigPubKey

12
main.go
View File

@ -207,7 +207,17 @@ func main() {
if !ok {
return errors.New("contact could not be found")
}
return chat.Select(contact)
// We need to call Select asynchronously,
// otherwise the main thread is blocked
// and nothing is rendered.
go func() {
if err := chat.Select(contact); err != nil {
log.Printf("[GetLineHandler] error selecting a chat: %v", err)
}
}()
return nil
}),
},
},

View File

@ -10,7 +10,7 @@ import (
)
func init() {
if err := logutils.OverrideRootLog(true, "DEBUG", "", false); err != nil {
if err := logutils.OverrideRootLog(true, "INFO", "", false); err != nil {
stdlog.Fatalf("failed to override root log: %v\n", err)
}
}

View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/ecdsa"
"encoding/hex"
"log"
"sort"
"strings"
"sync"
@ -33,7 +34,8 @@ type Chat struct {
lastClock int64
ownMessages chan *protocol.Message // my private messages channel
ownMessages chan *protocol.Message // my private messages channel
// TODO: make it a ring buffer
messages []*protocol.Message // all messages ordered by Clock
messagesByHash map[string]*protocol.Message // quick access to messages by hash
}
@ -73,12 +75,15 @@ func (c *Chat) Messages() []*protocol.Message {
}
// Subscribe reads messages from the network.
func (c *Chat) Subscribe() error {
c.Lock()
defer c.Unlock()
// TODO: change method name to Join().
func (c *Chat) Subscribe() (err error) {
c.RLock()
sub := c.sub
c.RUnlock()
if c.sub != nil {
return errors.New("already subscribed")
if sub != nil {
err = errors.New("already subscribed")
return
}
opts := protocol.SubscribeOptions{}
@ -88,26 +93,29 @@ func (c *Chat) Subscribe() error {
opts.Identity = c.identity
}
var err error
messages := make(chan *protocol.Message)
c.sub, err = c.proto.Subscribe(context.Background(), messages, opts)
sub, err = c.proto.Subscribe(context.Background(), messages, opts)
if err != nil {
return errors.Wrap(err, "failed to subscribe")
err = errors.Wrap(err, "failed to subscribe")
return
}
go func() {
// Send at least one event to kick off the logic.
// TODO: change type of the event.
c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage}
}()
c.Lock()
c.sub = sub
c.Unlock()
cancel := make(chan struct{}) // can be closed by any loop
go c.readLoop(messages, c.sub, cancel)
go c.readLoop(messages, sub, cancel)
go c.readOwnMessagesLoop(c.ownMessages, cancel)
// Load should have it's own lock.
return c.Load()
}
// Load loads messages from the database cache and the network.
func (c *Chat) Load() error {
params := protocol.DefaultRequestOptions()
// Get already cached messages from the database.
@ -120,10 +128,14 @@ func (c *Chat) Subscribe() error {
return errors.Wrap(err, "db failed to get messages")
}
c.Lock()
c.handleMessages(cachedMessages...)
c.Unlock()
go func() {
for _, m := range cachedMessages {
messages <- m
}
log.Printf("[Chat::Subscribe] sending EventTypeInit")
c.events <- baseEvent{contact: c.contact, typ: EventTypeInit}
log.Printf("[Chat::Subscribe] sent EventTypeInit")
}()
if c.contact.Type == ContactPublicChat {
@ -132,7 +144,7 @@ func (c *Chat) Subscribe() error {
params.Recipient = c.contact.PublicKey
}
// Request historic messages from the network.
if err := c.proto.Request(context.Background(), params); err != nil {
if err := c.request(params); err != nil {
return errors.Wrap(err, "failed to request for messages")
}
@ -143,15 +155,17 @@ func (c *Chat) Subscribe() error {
func (c *Chat) Unsubscribe() {
c.RLock()
defer c.RUnlock()
if c.sub == nil {
return
if c.sub != nil {
c.sub.Unsubscribe()
}
c.sub.Unsubscribe()
}
// Request sends a request for historic messages.
func (c *Chat) Request(params protocol.RequestOptions) error {
return c.request(params)
}
func (c *Chat) request(params protocol.RequestOptions) error {
return c.proto.Request(context.Background(), params)
}
@ -187,12 +201,16 @@ func (c *Chat) Send(data []byte) error {
return errors.Wrap(err, "failed to encode message")
}
c.Lock()
c.updateLastClock(clock)
c.Unlock()
hash, err := c.proto.Send(context.Background(), encodedMessage, opts)
// Own messages need to be pushed manually to the pipeline.
if c.contact.Type == ContactPrivateChat {
log.Printf("[Chat::Send] sent a private message")
c.ownMessages <- &protocol.Message{
Decoded: message,
SigPubKey: &c.identity.PublicKey,
@ -209,11 +227,28 @@ func (c *Chat) readLoop(messages <-chan *protocol.Message, sub *protocol.Subscri
for {
select {
case m := <-messages:
if err := c.handleMessage(m); err != nil {
if c.HasMessage(m) {
break
}
c.Lock()
c.handleMessages(m)
c.Unlock()
if err := c.saveMessages(m); err != nil {
c.Lock()
c.err = err
c.Unlock()
return
}
c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage}
c.events <- messageEvent{
baseEvent: baseEvent{
contact: c.contact,
typ: EventTypeMessage,
},
message: m,
}
case <-sub.Done():
c.err = sub.Err()
return
@ -229,39 +264,53 @@ func (c *Chat) readOwnMessagesLoop(messages <-chan *protocol.Message, cancel cha
for {
select {
case m := <-messages:
if err := c.handleMessage(m); err != nil {
if c.HasMessage(m) {
break
}
c.Lock()
c.handleMessages(m)
c.Unlock()
if err := c.saveMessages(m); err != nil {
c.Lock()
c.err = err
c.Unlock()
return
}
c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage}
c.events <- messageEvent{
baseEvent: baseEvent{
contact: c.contact,
typ: EventTypeMessage,
},
message: m,
}
case <-cancel:
return
}
}
}
func (c *Chat) handleMessage(message *protocol.Message) error {
lessFn := func(i, j int) bool {
return c.messages[i].Decoded.Clock < c.messages[j].Decoded.Clock
func (c *Chat) handleMessages(messages ...*protocol.Message) {
for _, message := range messages {
c.updateLastClock(message.Decoded.Clock)
hash := messageHashStr(message)
c.messagesByHash[hash] = message
c.messages = append(c.messages, message)
sort.Slice(c.messages, c.lessFn)
}
hash := hex.EncodeToString(message.Hash)
}
// the message already exists
if _, ok := c.messagesByHash[hash]; ok {
return nil
}
func (c *Chat) saveMessages(messages ...*protocol.Message) error {
return c.db.SaveMessages(c.contact, messages)
}
c.updateLastClock(message.Decoded.Clock)
c.messagesByHash[hash] = message
c.messages = append(c.messages, message)
isSorted := sort.SliceIsSorted(c.messages, lessFn)
if !isSorted {
sort.Slice(c.messages, lessFn)
}
return c.db.SaveMessages(c.contact, message)
func (c *Chat) lessFn(i, j int) bool {
return c.messages[i].Decoded.Clock < c.messages[j].Decoded.Clock
}
func (c *Chat) updateLastClock(clock int64) {
@ -269,3 +318,20 @@ func (c *Chat) updateLastClock(clock int64) {
c.lastClock = clock
}
}
func (c *Chat) hasMessage(m *protocol.Message) bool {
hash := messageHashStr(m)
_, ok := c.messagesByHash[hash]
return ok
}
// HasMessage returns true if a given message is already cached.
func (c *Chat) HasMessage(m *protocol.Message) bool {
c.Lock()
defer c.Unlock()
return c.hasMessage(m)
}
func messageHashStr(m *protocol.Message) string {
return hex.EncodeToString(m.Hash)
}

View File

@ -80,7 +80,7 @@ func (d *Database) Messages(c Contact, from, to int64) (result []*protocol.Messa
}
// SaveMessages stores messages on a disk.
func (d *Database) SaveMessages(c Contact, messages ...*protocol.Message) error {
func (d *Database) SaveMessages(c Contact, messages []*protocol.Message) error {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)

View File

@ -1,7 +1,10 @@
package client
import "github.com/status-im/status-console-client/protocol/v1"
const (
EventTypeMessage int = iota + 1
EventTypeInit int = iota + 1
EventTypeMessage
EventTypeError
)
@ -15,6 +18,11 @@ type EventError interface {
Error() error
}
type EventMessage interface {
Event
Message() *protocol.Message
}
type baseEvent struct {
contact Contact
typ int
@ -23,6 +31,20 @@ type baseEvent struct {
func (e baseEvent) Contact() Contact { return e.contact }
func (e baseEvent) Type() int { return e.typ }
type errorEvent struct {
baseEvent
err error
}
func (e errorEvent) Error() error { return e.err }
type messageEvent struct {
baseEvent
message *protocol.Message
}
func (e messageEvent) Message() *protocol.Message { return e.message }
// type eventError struct {
// Event
// err error

View File

@ -2,6 +2,7 @@ package client
import (
"crypto/ecdsa"
"log"
"sync"
"github.com/pkg/errors"
@ -10,12 +11,13 @@ import (
// Messenger coordinates chats.
type Messenger struct {
sync.RWMutex
proto protocol.Chat
identity *ecdsa.PrivateKey
db *Database
chats map[Contact]*Chat
wg sync.WaitGroup
events chan interface{}
}
@ -33,76 +35,70 @@ func NewMessenger(proto protocol.Chat, identity *ecdsa.PrivateKey, db *Database)
// Events returns a channel with chat events.
func (m *Messenger) Events() <-chan interface{} {
m.RLock()
defer m.RUnlock()
return m.events
}
func (m *Messenger) Chat(c Contact) *Chat {
m.RLock()
defer m.RUnlock()
return m.chats[c]
}
// Join creates a new chat and creates a subscription.
func (m *Messenger) Join(contact Contact) error {
chat := NewChat(m.proto, m.identity, contact, m.db)
m.RLock()
chat, found := m.chats[contact]
m.RUnlock()
if err := chat.Subscribe(); err != nil {
return err
if found {
return chat.Load()
}
chat = NewChat(m.proto, m.identity, contact, m.db)
m.Lock()
m.chats[contact] = chat
m.Unlock()
m.wg.Add(1)
go func() {
defer m.wg.Done()
go func(events <-chan interface{}) {
log.Printf("[Messenger::Join] waiting for events")
for ev := range chat.Events() {
for ev := range events {
log.Printf("[Messenger::Join] received an event: %+v", ev)
m.events <- ev
}
if err := chat.Err(); err != nil {
m.events <- baseEvent{contact: contact, typ: EventTypeError}
m.events <- errorEvent{
baseEvent: baseEvent{contact: contact, typ: EventTypeError},
err: err,
}
}
}()
}(chat.Events())
return nil
return chat.Subscribe()
}
// Leave unsubscribes from the chat.
func (m *Messenger) Leave(contact Contact) error {
m.RLock()
chat, ok := m.chats[contact]
m.RUnlock()
if !ok {
return errors.New("chat for the contact not found")
}
chat.Unsubscribe()
m.Lock()
delete(m.chats, contact)
m.Unlock()
return nil
}
// Messages returns a list of messages for a given contact.
func (m *Messenger) Messages(contact Contact) ([]*protocol.Message, error) {
chat, ok := m.chats[contact]
if !ok {
return nil, errors.New("chat for the contact not found")
}
return chat.Messages(), nil
}
func (m *Messenger) Request(contact Contact, params protocol.RequestOptions) error {
chat, ok := m.chats[contact]
if !ok {
return errors.New("chat for the contact not found")
}
return chat.Request(params)
}
func (m *Messenger) Send(contact Contact, data []byte) error {
chat, ok := m.chats[contact]
if !ok {
return errors.New("chat for the contact not found")
}
return chat.Send(data)
}
func (m *Messenger) Contacts() ([]Contact, error) {
return m.db.Contacts()
}
@ -118,7 +114,6 @@ func (m *Messenger) AddContact(c Contact) error {
}
contacts = append(contacts, c)
return m.db.SaveContacts(contacts)
}
@ -129,14 +124,17 @@ func (m *Messenger) RemoveContact(c Contact) error {
}
for i, item := range contacts {
if item == c {
copy(contacts[i:], contacts[i+1:])
contacts[len(contacts)-1] = Contact{}
contacts = contacts[:len(contacts)-1]
if item != c {
continue
}
copy(contacts[i:], contacts[i+1:])
contacts[len(contacts)-1] = Contact{}
contacts = contacts[:len(contacts)-1]
break
}
contacts = append(contacts, c)
return m.db.SaveContacts(contacts)
}

View File

@ -235,10 +235,11 @@ func (a *WhisperServiceAdapter) requestMessages(ctx context.Context, enode strin
return err
}
_, err = shhextAPI.RequestMessages(ctx, req)
// TODO: wait for the request to finish before returning.
// Use a different method or relay on signals.
return err
return shhextAPI.RequestMessagesSync(shhext.RetryConfig{
BaseTimeout: time.Second * 10,
StepTimeout: time.Second,
MaxRetries: 3,
}, req)
}
func (a *WhisperServiceAdapter) createMessagesRequest(
@ -297,8 +298,6 @@ func (s whisperSubscription) Messages() ([]*Message, error) {
result := make([]*Message, 0, len(items))
for _, item := range items {
log.Printf("retrieve a message with ID %s", item.EnvelopeHash.String())
decoded, err := DecodeMessage(item.Payload)
if err != nil {
log.Printf("failed to decode message: %v", err)