mirror of
https://github.com/status-im/status-console-client.git
synced 2025-02-23 16:18:23 +00:00
make Chat and Messenger thread-safe (#8)
This commit is contained in:
parent
179241a572
commit
0d773353bc
50
chat.go
50
chat.go
@ -46,20 +46,33 @@ func (c *ChatViewController) readEventsLoop() {
|
||||
defer close(c.done)
|
||||
|
||||
for {
|
||||
log.Printf("[ChatViewController::readEventsLoops] waiting for events")
|
||||
|
||||
select {
|
||||
case event := <-c.messenger.Events():
|
||||
log.Printf("received an event: %+v", event)
|
||||
log.Printf("[ChatViewController::readEventsLoops] received an event: %+v", event)
|
||||
|
||||
switch ev := event.(type) {
|
||||
case client.EventError:
|
||||
c.notifications.Error("error", ev.Error().Error()) // nolint: errcheck
|
||||
case client.EventMessage:
|
||||
c.printMessages(false, ev.Message())
|
||||
case client.Event:
|
||||
messages, err := c.messenger.Messages(c.contact)
|
||||
if err != nil {
|
||||
c.notifications.Error("getting messages", err.Error()) // nolint: errcheck
|
||||
if ev.Type() != client.EventTypeInit {
|
||||
break
|
||||
}
|
||||
c.printMessages(messages)
|
||||
|
||||
chat := c.messenger.Chat(c.contact)
|
||||
if chat == nil {
|
||||
c.notifications.Error("getting chat", "chat does not exist") // nolint: errcheck
|
||||
break
|
||||
}
|
||||
|
||||
messages := chat.Messages()
|
||||
|
||||
log.Printf("[ChatViewController::readEventsLoops] retrieved %d messages", len(messages))
|
||||
|
||||
c.printMessages(true, messages...)
|
||||
}
|
||||
case <-c.cancel:
|
||||
return
|
||||
@ -70,7 +83,7 @@ func (c *ChatViewController) readEventsLoop() {
|
||||
// Select informs the chat view controller about a selected contact.
|
||||
// The chat view controller setup subscribers and request recent messages.
|
||||
func (c *ChatViewController) Select(contact client.Contact) error {
|
||||
log.Printf("selected contact %s", contact.Name)
|
||||
log.Printf("[ChatViewController::Select] contact %s", contact.Name)
|
||||
|
||||
if c.cancel == nil {
|
||||
c.cancel = make(chan struct{})
|
||||
@ -87,21 +100,32 @@ func (c *ChatViewController) RequestMessages(params protocol.RequestOptions) err
|
||||
"REQUEST",
|
||||
fmt.Sprintf("get historic messages: %+v", params),
|
||||
)
|
||||
return c.messenger.Request(c.contact, params)
|
||||
|
||||
chat := c.messenger.Chat(c.contact)
|
||||
if chat == nil {
|
||||
return fmt.Errorf("chat not found")
|
||||
}
|
||||
return chat.Request(params)
|
||||
}
|
||||
|
||||
func (c *ChatViewController) Send(data []byte) error {
|
||||
return c.messenger.Send(c.contact, data)
|
||||
chat := c.messenger.Chat(c.contact)
|
||||
if chat == nil {
|
||||
return fmt.Errorf("chat not found")
|
||||
}
|
||||
return chat.Send(data)
|
||||
}
|
||||
|
||||
func (c *ChatViewController) printMessages(messages []*protocol.Message) {
|
||||
func (c *ChatViewController) printMessages(clear bool, messages ...*protocol.Message) {
|
||||
c.g.Update(func(*gocui.Gui) error {
|
||||
if err := c.Clear(); err != nil {
|
||||
return err
|
||||
if clear {
|
||||
if err := c.Clear(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, message := range messages {
|
||||
if err := c.printMessage(message); err != nil {
|
||||
if err := c.writeMessage(message); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -109,7 +133,7 @@ func (c *ChatViewController) printMessages(messages []*protocol.Message) {
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ChatViewController) printMessage(message *protocol.Message) error {
|
||||
func (c *ChatViewController) writeMessage(message *protocol.Message) error {
|
||||
myPubKey := c.identity.PublicKey
|
||||
pubKey := message.SigPubKey
|
||||
|
||||
|
12
main.go
12
main.go
@ -207,7 +207,17 @@ func main() {
|
||||
if !ok {
|
||||
return errors.New("contact could not be found")
|
||||
}
|
||||
return chat.Select(contact)
|
||||
|
||||
// We need to call Select asynchronously,
|
||||
// otherwise the main thread is blocked
|
||||
// and nothing is rendered.
|
||||
go func() {
|
||||
if err := chat.Select(contact); err != nil {
|
||||
log.Printf("[GetLineHandler] error selecting a chat: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
},
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
if err := logutils.OverrideRootLog(true, "DEBUG", "", false); err != nil {
|
||||
if err := logutils.OverrideRootLog(true, "INFO", "", false); err != nil {
|
||||
stdlog.Fatalf("failed to override root log: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"encoding/hex"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -33,7 +34,8 @@ type Chat struct {
|
||||
|
||||
lastClock int64
|
||||
|
||||
ownMessages chan *protocol.Message // my private messages channel
|
||||
ownMessages chan *protocol.Message // my private messages channel
|
||||
// TODO: make it a ring buffer
|
||||
messages []*protocol.Message // all messages ordered by Clock
|
||||
messagesByHash map[string]*protocol.Message // quick access to messages by hash
|
||||
}
|
||||
@ -73,12 +75,15 @@ func (c *Chat) Messages() []*protocol.Message {
|
||||
}
|
||||
|
||||
// Subscribe reads messages from the network.
|
||||
func (c *Chat) Subscribe() error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
// TODO: change method name to Join().
|
||||
func (c *Chat) Subscribe() (err error) {
|
||||
c.RLock()
|
||||
sub := c.sub
|
||||
c.RUnlock()
|
||||
|
||||
if c.sub != nil {
|
||||
return errors.New("already subscribed")
|
||||
if sub != nil {
|
||||
err = errors.New("already subscribed")
|
||||
return
|
||||
}
|
||||
|
||||
opts := protocol.SubscribeOptions{}
|
||||
@ -88,26 +93,29 @@ func (c *Chat) Subscribe() error {
|
||||
opts.Identity = c.identity
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
messages := make(chan *protocol.Message)
|
||||
|
||||
c.sub, err = c.proto.Subscribe(context.Background(), messages, opts)
|
||||
sub, err = c.proto.Subscribe(context.Background(), messages, opts)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to subscribe")
|
||||
err = errors.Wrap(err, "failed to subscribe")
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Send at least one event to kick off the logic.
|
||||
// TODO: change type of the event.
|
||||
c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage}
|
||||
}()
|
||||
c.Lock()
|
||||
c.sub = sub
|
||||
c.Unlock()
|
||||
|
||||
cancel := make(chan struct{}) // can be closed by any loop
|
||||
|
||||
go c.readLoop(messages, c.sub, cancel)
|
||||
go c.readLoop(messages, sub, cancel)
|
||||
go c.readOwnMessagesLoop(c.ownMessages, cancel)
|
||||
|
||||
// Load should have it's own lock.
|
||||
return c.Load()
|
||||
}
|
||||
|
||||
// Load loads messages from the database cache and the network.
|
||||
func (c *Chat) Load() error {
|
||||
params := protocol.DefaultRequestOptions()
|
||||
|
||||
// Get already cached messages from the database.
|
||||
@ -120,10 +128,14 @@ func (c *Chat) Subscribe() error {
|
||||
return errors.Wrap(err, "db failed to get messages")
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
c.handleMessages(cachedMessages...)
|
||||
c.Unlock()
|
||||
|
||||
go func() {
|
||||
for _, m := range cachedMessages {
|
||||
messages <- m
|
||||
}
|
||||
log.Printf("[Chat::Subscribe] sending EventTypeInit")
|
||||
c.events <- baseEvent{contact: c.contact, typ: EventTypeInit}
|
||||
log.Printf("[Chat::Subscribe] sent EventTypeInit")
|
||||
}()
|
||||
|
||||
if c.contact.Type == ContactPublicChat {
|
||||
@ -132,7 +144,7 @@ func (c *Chat) Subscribe() error {
|
||||
params.Recipient = c.contact.PublicKey
|
||||
}
|
||||
// Request historic messages from the network.
|
||||
if err := c.proto.Request(context.Background(), params); err != nil {
|
||||
if err := c.request(params); err != nil {
|
||||
return errors.Wrap(err, "failed to request for messages")
|
||||
}
|
||||
|
||||
@ -143,15 +155,17 @@ func (c *Chat) Subscribe() error {
|
||||
func (c *Chat) Unsubscribe() {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
if c.sub == nil {
|
||||
return
|
||||
if c.sub != nil {
|
||||
c.sub.Unsubscribe()
|
||||
}
|
||||
c.sub.Unsubscribe()
|
||||
}
|
||||
|
||||
// Request sends a request for historic messages.
|
||||
func (c *Chat) Request(params protocol.RequestOptions) error {
|
||||
return c.request(params)
|
||||
}
|
||||
|
||||
func (c *Chat) request(params protocol.RequestOptions) error {
|
||||
return c.proto.Request(context.Background(), params)
|
||||
}
|
||||
|
||||
@ -187,12 +201,16 @@ func (c *Chat) Send(data []byte) error {
|
||||
return errors.Wrap(err, "failed to encode message")
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
c.updateLastClock(clock)
|
||||
c.Unlock()
|
||||
|
||||
hash, err := c.proto.Send(context.Background(), encodedMessage, opts)
|
||||
|
||||
// Own messages need to be pushed manually to the pipeline.
|
||||
if c.contact.Type == ContactPrivateChat {
|
||||
log.Printf("[Chat::Send] sent a private message")
|
||||
|
||||
c.ownMessages <- &protocol.Message{
|
||||
Decoded: message,
|
||||
SigPubKey: &c.identity.PublicKey,
|
||||
@ -209,11 +227,28 @@ func (c *Chat) readLoop(messages <-chan *protocol.Message, sub *protocol.Subscri
|
||||
for {
|
||||
select {
|
||||
case m := <-messages:
|
||||
if err := c.handleMessage(m); err != nil {
|
||||
if c.HasMessage(m) {
|
||||
break
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
c.handleMessages(m)
|
||||
c.Unlock()
|
||||
|
||||
if err := c.saveMessages(m); err != nil {
|
||||
c.Lock()
|
||||
c.err = err
|
||||
c.Unlock()
|
||||
return
|
||||
}
|
||||
c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage}
|
||||
|
||||
c.events <- messageEvent{
|
||||
baseEvent: baseEvent{
|
||||
contact: c.contact,
|
||||
typ: EventTypeMessage,
|
||||
},
|
||||
message: m,
|
||||
}
|
||||
case <-sub.Done():
|
||||
c.err = sub.Err()
|
||||
return
|
||||
@ -229,39 +264,53 @@ func (c *Chat) readOwnMessagesLoop(messages <-chan *protocol.Message, cancel cha
|
||||
for {
|
||||
select {
|
||||
case m := <-messages:
|
||||
if err := c.handleMessage(m); err != nil {
|
||||
if c.HasMessage(m) {
|
||||
break
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
c.handleMessages(m)
|
||||
c.Unlock()
|
||||
|
||||
if err := c.saveMessages(m); err != nil {
|
||||
c.Lock()
|
||||
c.err = err
|
||||
c.Unlock()
|
||||
return
|
||||
}
|
||||
c.events <- baseEvent{contact: c.contact, typ: EventTypeMessage}
|
||||
|
||||
c.events <- messageEvent{
|
||||
baseEvent: baseEvent{
|
||||
contact: c.contact,
|
||||
typ: EventTypeMessage,
|
||||
},
|
||||
message: m,
|
||||
}
|
||||
case <-cancel:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Chat) handleMessage(message *protocol.Message) error {
|
||||
lessFn := func(i, j int) bool {
|
||||
return c.messages[i].Decoded.Clock < c.messages[j].Decoded.Clock
|
||||
func (c *Chat) handleMessages(messages ...*protocol.Message) {
|
||||
for _, message := range messages {
|
||||
c.updateLastClock(message.Decoded.Clock)
|
||||
|
||||
hash := messageHashStr(message)
|
||||
|
||||
c.messagesByHash[hash] = message
|
||||
c.messages = append(c.messages, message)
|
||||
|
||||
sort.Slice(c.messages, c.lessFn)
|
||||
}
|
||||
hash := hex.EncodeToString(message.Hash)
|
||||
}
|
||||
|
||||
// the message already exists
|
||||
if _, ok := c.messagesByHash[hash]; ok {
|
||||
return nil
|
||||
}
|
||||
func (c *Chat) saveMessages(messages ...*protocol.Message) error {
|
||||
return c.db.SaveMessages(c.contact, messages)
|
||||
}
|
||||
|
||||
c.updateLastClock(message.Decoded.Clock)
|
||||
|
||||
c.messagesByHash[hash] = message
|
||||
c.messages = append(c.messages, message)
|
||||
|
||||
isSorted := sort.SliceIsSorted(c.messages, lessFn)
|
||||
if !isSorted {
|
||||
sort.Slice(c.messages, lessFn)
|
||||
}
|
||||
|
||||
return c.db.SaveMessages(c.contact, message)
|
||||
func (c *Chat) lessFn(i, j int) bool {
|
||||
return c.messages[i].Decoded.Clock < c.messages[j].Decoded.Clock
|
||||
}
|
||||
|
||||
func (c *Chat) updateLastClock(clock int64) {
|
||||
@ -269,3 +318,20 @@ func (c *Chat) updateLastClock(clock int64) {
|
||||
c.lastClock = clock
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Chat) hasMessage(m *protocol.Message) bool {
|
||||
hash := messageHashStr(m)
|
||||
_, ok := c.messagesByHash[hash]
|
||||
return ok
|
||||
}
|
||||
|
||||
// HasMessage returns true if a given message is already cached.
|
||||
func (c *Chat) HasMessage(m *protocol.Message) bool {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
return c.hasMessage(m)
|
||||
}
|
||||
|
||||
func messageHashStr(m *protocol.Message) string {
|
||||
return hex.EncodeToString(m.Hash)
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ func (d *Database) Messages(c Contact, from, to int64) (result []*protocol.Messa
|
||||
}
|
||||
|
||||
// SaveMessages stores messages on a disk.
|
||||
func (d *Database) SaveMessages(c Contact, messages ...*protocol.Message) error {
|
||||
func (d *Database) SaveMessages(c Contact, messages []*protocol.Message) error {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
|
||||
|
@ -1,7 +1,10 @@
|
||||
package client
|
||||
|
||||
import "github.com/status-im/status-console-client/protocol/v1"
|
||||
|
||||
const (
|
||||
EventTypeMessage int = iota + 1
|
||||
EventTypeInit int = iota + 1
|
||||
EventTypeMessage
|
||||
EventTypeError
|
||||
)
|
||||
|
||||
@ -15,6 +18,11 @@ type EventError interface {
|
||||
Error() error
|
||||
}
|
||||
|
||||
type EventMessage interface {
|
||||
Event
|
||||
Message() *protocol.Message
|
||||
}
|
||||
|
||||
type baseEvent struct {
|
||||
contact Contact
|
||||
typ int
|
||||
@ -23,6 +31,20 @@ type baseEvent struct {
|
||||
func (e baseEvent) Contact() Contact { return e.contact }
|
||||
func (e baseEvent) Type() int { return e.typ }
|
||||
|
||||
type errorEvent struct {
|
||||
baseEvent
|
||||
err error
|
||||
}
|
||||
|
||||
func (e errorEvent) Error() error { return e.err }
|
||||
|
||||
type messageEvent struct {
|
||||
baseEvent
|
||||
message *protocol.Message
|
||||
}
|
||||
|
||||
func (e messageEvent) Message() *protocol.Message { return e.message }
|
||||
|
||||
// type eventError struct {
|
||||
// Event
|
||||
// err error
|
||||
|
@ -2,6 +2,7 @@ package client
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@ -10,12 +11,13 @@ import (
|
||||
|
||||
// Messenger coordinates chats.
|
||||
type Messenger struct {
|
||||
sync.RWMutex
|
||||
|
||||
proto protocol.Chat
|
||||
identity *ecdsa.PrivateKey
|
||||
db *Database
|
||||
|
||||
chats map[Contact]*Chat
|
||||
wg sync.WaitGroup
|
||||
|
||||
events chan interface{}
|
||||
}
|
||||
@ -33,76 +35,70 @@ func NewMessenger(proto protocol.Chat, identity *ecdsa.PrivateKey, db *Database)
|
||||
|
||||
// Events returns a channel with chat events.
|
||||
func (m *Messenger) Events() <-chan interface{} {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return m.events
|
||||
}
|
||||
|
||||
func (m *Messenger) Chat(c Contact) *Chat {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return m.chats[c]
|
||||
}
|
||||
|
||||
// Join creates a new chat and creates a subscription.
|
||||
func (m *Messenger) Join(contact Contact) error {
|
||||
chat := NewChat(m.proto, m.identity, contact, m.db)
|
||||
m.RLock()
|
||||
chat, found := m.chats[contact]
|
||||
m.RUnlock()
|
||||
|
||||
if err := chat.Subscribe(); err != nil {
|
||||
return err
|
||||
if found {
|
||||
return chat.Load()
|
||||
}
|
||||
|
||||
chat = NewChat(m.proto, m.identity, contact, m.db)
|
||||
|
||||
m.Lock()
|
||||
m.chats[contact] = chat
|
||||
m.Unlock()
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
go func(events <-chan interface{}) {
|
||||
log.Printf("[Messenger::Join] waiting for events")
|
||||
|
||||
for ev := range chat.Events() {
|
||||
for ev := range events {
|
||||
log.Printf("[Messenger::Join] received an event: %+v", ev)
|
||||
m.events <- ev
|
||||
}
|
||||
|
||||
if err := chat.Err(); err != nil {
|
||||
m.events <- baseEvent{contact: contact, typ: EventTypeError}
|
||||
m.events <- errorEvent{
|
||||
baseEvent: baseEvent{contact: contact, typ: EventTypeError},
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
}()
|
||||
}(chat.Events())
|
||||
|
||||
return nil
|
||||
return chat.Subscribe()
|
||||
}
|
||||
|
||||
// Leave unsubscribes from the chat.
|
||||
func (m *Messenger) Leave(contact Contact) error {
|
||||
m.RLock()
|
||||
chat, ok := m.chats[contact]
|
||||
m.RUnlock()
|
||||
if !ok {
|
||||
return errors.New("chat for the contact not found")
|
||||
}
|
||||
|
||||
chat.Unsubscribe()
|
||||
|
||||
m.Lock()
|
||||
delete(m.chats, contact)
|
||||
m.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Messages returns a list of messages for a given contact.
|
||||
func (m *Messenger) Messages(contact Contact) ([]*protocol.Message, error) {
|
||||
chat, ok := m.chats[contact]
|
||||
if !ok {
|
||||
return nil, errors.New("chat for the contact not found")
|
||||
}
|
||||
|
||||
return chat.Messages(), nil
|
||||
}
|
||||
|
||||
func (m *Messenger) Request(contact Contact, params protocol.RequestOptions) error {
|
||||
chat, ok := m.chats[contact]
|
||||
if !ok {
|
||||
return errors.New("chat for the contact not found")
|
||||
}
|
||||
|
||||
return chat.Request(params)
|
||||
}
|
||||
|
||||
func (m *Messenger) Send(contact Contact, data []byte) error {
|
||||
chat, ok := m.chats[contact]
|
||||
if !ok {
|
||||
return errors.New("chat for the contact not found")
|
||||
}
|
||||
|
||||
return chat.Send(data)
|
||||
}
|
||||
|
||||
func (m *Messenger) Contacts() ([]Contact, error) {
|
||||
return m.db.Contacts()
|
||||
}
|
||||
@ -118,7 +114,6 @@ func (m *Messenger) AddContact(c Contact) error {
|
||||
}
|
||||
|
||||
contacts = append(contacts, c)
|
||||
|
||||
return m.db.SaveContacts(contacts)
|
||||
}
|
||||
|
||||
@ -129,14 +124,17 @@ func (m *Messenger) RemoveContact(c Contact) error {
|
||||
}
|
||||
|
||||
for i, item := range contacts {
|
||||
if item == c {
|
||||
copy(contacts[i:], contacts[i+1:])
|
||||
contacts[len(contacts)-1] = Contact{}
|
||||
contacts = contacts[:len(contacts)-1]
|
||||
if item != c {
|
||||
continue
|
||||
}
|
||||
|
||||
copy(contacts[i:], contacts[i+1:])
|
||||
contacts[len(contacts)-1] = Contact{}
|
||||
contacts = contacts[:len(contacts)-1]
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
contacts = append(contacts, c)
|
||||
|
||||
return m.db.SaveContacts(contacts)
|
||||
}
|
||||
|
@ -235,10 +235,11 @@ func (a *WhisperServiceAdapter) requestMessages(ctx context.Context, enode strin
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = shhextAPI.RequestMessages(ctx, req)
|
||||
// TODO: wait for the request to finish before returning.
|
||||
// Use a different method or relay on signals.
|
||||
return err
|
||||
return shhextAPI.RequestMessagesSync(shhext.RetryConfig{
|
||||
BaseTimeout: time.Second * 10,
|
||||
StepTimeout: time.Second,
|
||||
MaxRetries: 3,
|
||||
}, req)
|
||||
}
|
||||
|
||||
func (a *WhisperServiceAdapter) createMessagesRequest(
|
||||
@ -297,8 +298,6 @@ func (s whisperSubscription) Messages() ([]*Message, error) {
|
||||
result := make([]*Message, 0, len(items))
|
||||
|
||||
for _, item := range items {
|
||||
log.Printf("retrieve a message with ID %s", item.EnvelopeHash.String())
|
||||
|
||||
decoded, err := DecodeMessage(item.Payload)
|
||||
if err != nil {
|
||||
log.Printf("failed to decode message: %v", err)
|
||||
|
Loading…
x
Reference in New Issue
Block a user