Add support for partitioned topic

This commit is contained in:
Andrea Maria Piana 2019-05-23 10:47:20 +02:00
parent 78ed35d2fe
commit 1aa3e2812a
40 changed files with 2459 additions and 1791 deletions

View File

@ -172,7 +172,7 @@ setup: setup-build setup-dev tidy ##@other Prepare project for development and b
generate: ##@other Regenerate assets and other auto-generated stuff generate: ##@other Regenerate assets and other auto-generated stuff
go generate ./static ./static/chat_db_migrations ./static/mailserver_db_migrations ./t go generate ./static ./static/chat_db_migrations ./static/mailserver_db_migrations ./t
$(shell cd ./services/shhext/chat && exec protoc --go_out=. ./*.proto) $(shell cd ./services/shhext/chat/protobuf && exec protoc --go_out=. ./*.proto)
prepare-release: clean-release prepare-release: clean-release
mkdir -p $(RELEASE_DIR) mkdir -p $(RELEASE_DIR)

View File

@ -2,7 +2,6 @@ package api
import ( import (
"context" "context"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
@ -26,8 +25,8 @@ import (
"github.com/status-im/status-go/rpc" "github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/services/personal" "github.com/status-im/status-go/services/personal"
"github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/rpcfilters"
"github.com/status-im/status-go/services/shhext/chat"
"github.com/status-im/status-go/services/shhext/chat/crypto" "github.com/status-im/status-go/services/shhext/chat/crypto"
"github.com/status-im/status-go/services/shhext/filter"
"github.com/status-im/status-go/services/subscriptions" "github.com/status-im/status-go/services/subscriptions"
"github.com/status-im/status-go/services/typeddata" "github.com/status-im/status-go/services/typeddata"
"github.com/status-im/status-go/signal" "github.com/status-im/status-go/signal"
@ -645,91 +644,6 @@ func appendIf(condition bool, services []gethnode.ServiceConstructor, service ge
return append(services, service) return append(services, service)
} }
// CreateContactCode create or return the latest contact code
func (b *StatusBackend) CreateContactCode() (string, error) {
selectedChatAccount, err := b.AccountManager().SelectedChatAccount()
if err != nil {
return "", err
}
st, err := b.statusNode.ShhExtService()
if err != nil {
return "", err
}
bundle, err := st.GetBundle(selectedChatAccount.AccountKey.PrivateKey)
if err != nil {
return "", err
}
return bundle.ToBase64()
}
// GetContactCode return the latest contact code
func (b *StatusBackend) GetContactCode(identity string) (string, error) {
st, err := b.statusNode.ShhExtService()
if err != nil {
return "", err
}
publicKeyBytes, err := hex.DecodeString(identity)
if err != nil {
return "", err
}
publicKey, err := ethcrypto.UnmarshalPubkey(publicKeyBytes)
if err != nil {
return "", err
}
bundle, err := st.GetPublicBundle(publicKey)
if err != nil {
return "", err
}
if bundle == nil {
return "", nil
}
return bundle.ToBase64()
}
// ProcessContactCode process and adds the someone else's bundle
func (b *StatusBackend) ProcessContactCode(contactCode string) error {
selectedChatAccount, err := b.AccountManager().SelectedChatAccount()
if err != nil {
return err
}
st, err := b.statusNode.ShhExtService()
if err != nil {
return err
}
bundle, err := chat.FromBase64(contactCode)
if err != nil {
b.log.Error("error decoding base64", "err", err)
return err
}
if _, err := st.ProcessPublicBundle(selectedChatAccount.AccountKey.PrivateKey, bundle); err != nil {
b.log.Error("error adding bundle", "err", err)
return err
}
return nil
}
// ExtractIdentityFromContactCode extract the identity of the user generating the contact code
func (b *StatusBackend) ExtractIdentityFromContactCode(contactCode string) (string, error) {
bundle, err := chat.FromBase64(contactCode)
if err != nil {
return "", err
}
return chat.ExtractIdentity(bundle)
}
// ExtractGroupMembershipSignatures extract signatures from tuples of content/signature // ExtractGroupMembershipSignatures extract signatures from tuples of content/signature
func (b *StatusBackend) ExtractGroupMembershipSignatures(signaturePairs [][2]string) ([]string, error) { func (b *StatusBackend) ExtractGroupMembershipSignatures(signaturePairs [][2]string) ([]string, error) {
return crypto.ExtractSignatures(signaturePairs) return crypto.ExtractSignatures(signaturePairs)
@ -745,6 +659,36 @@ func (b *StatusBackend) SignGroupMembership(content string) (string, error) {
return crypto.Sign(content, selectedChatAccount.AccountKey.PrivateKey) return crypto.Sign(content, selectedChatAccount.AccountKey.PrivateKey)
} }
// LoadFilters loads filter on sshext
func (b *StatusBackend) LoadFilters(chats []*filter.Chat) ([]*filter.Chat, error) {
st, err := b.statusNode.ShhExtService()
if err != nil {
return nil, err
}
return st.LoadFilters(chats)
}
// LoadFilter loads filter on sshext
func (b *StatusBackend) LoadFilter(chat *filter.Chat) ([]*filter.Chat, error) {
st, err := b.statusNode.ShhExtService()
if err != nil {
return nil, err
}
return st.LoadFilter(chat)
}
// RemoveFilter remove a filter
func (b *StatusBackend) RemoveFilter(chat *filter.Chat) error {
st, err := b.statusNode.ShhExtService()
if err != nil {
return err
}
return st.RemoveFilter(chat)
}
// EnableInstallation enables an installation for multi-device sync. // EnableInstallation enables an installation for multi-device sync.
func (b *StatusBackend) EnableInstallation(installationID string) error { func (b *StatusBackend) EnableInstallation(installationID string) error {
selectedChatAccount, err := b.AccountManager().SelectedChatAccount() selectedChatAccount, err := b.AccountManager().SelectedChatAccount()

View File

@ -18,6 +18,7 @@ import (
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/profiling" "github.com/status-im/status-go/profiling"
"github.com/status-im/status-go/services/personal" "github.com/status-im/status-go/services/personal"
"github.com/status-im/status-go/services/shhext/filter"
"github.com/status-im/status-go/services/typeddata" "github.com/status-im/status-go/services/typeddata"
"github.com/status-im/status-go/signal" "github.com/status-im/status-go/signal"
"github.com/status-im/status-go/transactions" "github.com/status-im/status-go/transactions"
@ -51,40 +52,23 @@ func StopNode() *C.char {
return makeJSONResponse(nil) return makeJSONResponse(nil)
} }
// Create an X3DH bundle // LoadFilters load all whisper filters
//export CreateContactCode //export LoadFilters
func CreateContactCode() *C.char { func LoadFilters(chatsStr *C.char) *C.char {
bundle, err := statusBackend.CreateContactCode() var chats []*filter.Chat
if err != nil {
if err := json.Unmarshal([]byte(C.GoString(chatsStr)), &chats); err != nil {
return makeJSONResponse(err) return makeJSONResponse(err)
} }
cstr := C.CString(bundle) response, err := statusBackend.LoadFilters(chats)
return cstr
}
//export ProcessContactCode
func ProcessContactCode(bundleString *C.char) *C.char {
err := statusBackend.ProcessContactCode(C.GoString(bundleString))
if err != nil {
return makeJSONResponse(err)
}
return nil
}
// Get an X3DH bundle
//export GetContactCode
func GetContactCode(identityString *C.char) *C.char {
bundle, err := statusBackend.GetContactCode(C.GoString(identityString))
if err != nil { if err != nil {
return makeJSONResponse(err) return makeJSONResponse(err)
} }
data, err := json.Marshal(struct { data, err := json.Marshal(struct {
ContactCode string `json:"code"` Chats []*filter.Chat `json:"result"`
}{ContactCode: bundle}) }{Chats: response})
if err != nil { if err != nil {
return makeJSONResponse(err) return makeJSONResponse(err)
} }
@ -92,22 +76,48 @@ func GetContactCode(identityString *C.char) *C.char {
return C.CString(string(data)) return C.CString(string(data))
} }
//export ExtractIdentityFromContactCode // LoadFilter load a whisper filter
func ExtractIdentityFromContactCode(bundleString *C.char) *C.char { //export LoadFilter
bundle := C.GoString(bundleString) func LoadFilter(chatStr *C.char) *C.char {
var chat *filter.Chat
identity, err := statusBackend.ExtractIdentityFromContactCode(bundle) if err := json.Unmarshal([]byte(C.GoString(chatStr)), &chat); err != nil {
return makeJSONResponse(err)
}
response, err := statusBackend.LoadFilter(chat)
if err != nil { if err != nil {
return makeJSONResponse(err) return makeJSONResponse(err)
} }
if err := statusBackend.ProcessContactCode(bundle); err != nil { data, err := json.Marshal(struct {
Chats []*filter.Chat `json:"result"`
}{Chats: response})
if err != nil {
return makeJSONResponse(err)
}
return C.CString(string(data))
}
// RemoveFilter load a whisper filter
//export RemoveFilter
func RemoveFilter(chatStr *C.char) *C.char {
var chat *filter.Chat
if err := json.Unmarshal([]byte(C.GoString(chatStr)), &chat); err != nil {
return makeJSONResponse(err)
}
err := statusBackend.RemoveFilter(chat)
if err != nil {
return makeJSONResponse(err) return makeJSONResponse(err)
} }
data, err := json.Marshal(struct { data, err := json.Marshal(struct {
Identity string `json:"identity"` Response string `json:"response"`
}{Identity: identity}) }{Response: "ok"})
if err != nil { if err != nil {
return makeJSONResponse(err) return makeJSONResponse(err)
} }

View File

@ -16,6 +16,7 @@ import (
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/profiling" "github.com/status-im/status-go/profiling"
"github.com/status-im/status-go/services/personal" "github.com/status-im/status-go/services/personal"
"github.com/status-im/status-go/services/shhext/filter"
"github.com/status-im/status-go/services/typeddata" "github.com/status-im/status-go/services/typeddata"
"github.com/status-im/status-go/signal" "github.com/status-im/status-go/signal"
"github.com/status-im/status-go/transactions" "github.com/status-im/status-go/transactions"
@ -64,48 +65,6 @@ func StopNode() string {
return makeJSONResponse(nil) return makeJSONResponse(nil)
} }
// CreateContactCode creates an X3DH bundle.
func CreateContactCode() string {
bundle, err := statusBackend.CreateContactCode()
if err != nil {
return makeJSONResponse(err)
}
return bundle
}
// ProcessContactCode processes an X3DH bundle.
// TODO(adam): it looks like the return should be error.
func ProcessContactCode(bundle string) string {
err := statusBackend.ProcessContactCode(bundle)
if err != nil {
return makeJSONResponse(err)
}
return ""
}
// ExtractIdentityFromContactCode extracts an identity from an X3DH bundle.
func ExtractIdentityFromContactCode(bundle string) string {
identity, err := statusBackend.ExtractIdentityFromContactCode(bundle)
if err != nil {
return makeJSONResponse(err)
}
if err := statusBackend.ProcessContactCode(bundle); err != nil {
return makeJSONResponse(err)
}
data, err := json.Marshal(struct {
Identity string `json:"identity"`
}{Identity: identity})
if err != nil {
return makeJSONResponse(err)
}
return string(data)
}
// ExtractGroupMembershipSignatures extract public keys from tuples of content/signature. // ExtractGroupMembershipSignatures extract public keys from tuples of content/signature.
func ExtractGroupMembershipSignatures(signaturePairsStr string) string { func ExtractGroupMembershipSignatures(signaturePairsStr string) string {
var signaturePairs [][2]string var signaturePairs [][2]string
@ -617,24 +576,6 @@ func SetSignalEventCallback(cb unsafe.Pointer) {
signal.SetSignalEventCallback(cb) signal.SetSignalEventCallback(cb)
} }
// Get an X3DH bundle
//export GetContactCode
func GetContactCode(identity string) string {
bundle, err := statusBackend.GetContactCode(identity)
if err != nil {
return makeJSONResponse(err)
}
data, err := json.Marshal(struct {
ContactCode string `json:"code"`
}{ContactCode: bundle})
if err != nil {
return makeJSONResponse(err)
}
return string(data)
}
// ExportNodeLogs reads current node log and returns content to a caller. // ExportNodeLogs reads current node log and returns content to a caller.
//export ExportNodeLogs //export ExportNodeLogs
func ExportNodeLogs() string { func ExportNodeLogs() string {
@ -673,3 +614,73 @@ func SignHash(hexEncodedHash string) string {
return hexEncodedSignature return hexEncodedSignature
} }
// LoadFilters load all whisper filters
func LoadFilters(chatsStr string) string {
var chats []*filter.Chat
if err := json.Unmarshal([]byte(chatsStr), &chats); err != nil {
return makeJSONResponse(err)
}
response, err := statusBackend.LoadFilters(chats)
if err != nil {
return makeJSONResponse(err)
}
data, err := json.Marshal(struct {
Chats []*filter.Chat `json:"result"`
}{Chats: response})
if err != nil {
return makeJSONResponse(err)
}
return string(data)
}
// LoadFilter load a whisper filter
func LoadFilter(chatStr string) string {
var chat *filter.Chat
if err := json.Unmarshal([]byte(chatStr), &chat); err != nil {
return makeJSONResponse(err)
}
response, err := statusBackend.LoadFilter(chat)
if err != nil {
return makeJSONResponse(err)
}
data, err := json.Marshal(struct {
Chats []*filter.Chat `json:"result"`
}{Chats: response})
if err != nil {
return makeJSONResponse(err)
}
return string(data)
}
// RemoveFilter load a whisper filter
//export RemoveFilter
func RemoveFilter(chatStr string) string {
var chat *filter.Chat
if err := json.Unmarshal([]byte(chatStr), &chat); err != nil {
return makeJSONResponse(err)
}
err := statusBackend.RemoveFilter(chat)
if err != nil {
return makeJSONResponse(err)
}
data, err := json.Marshal(struct {
Response string `json:"response"`
}{Response: "ok"})
if err != nil {
return makeJSONResponse(err)
}
return string(data)
}

View File

@ -368,8 +368,6 @@ type WalletConfig struct {
// ShhextConfig defines options used by shhext service. // ShhextConfig defines options used by shhext service.
type ShhextConfig struct { type ShhextConfig struct {
// AsymKeyID the key id of the selected account
AsymKeyID string
PFSEnabled bool PFSEnabled bool
// BackupDisabledDataDir is the file system folder the node should use for any data storage needs that it doesn't want backed up. // BackupDisabledDataDir is the file system folder the node should use for any data storage needs that it doesn't want backed up.
BackupDisabledDataDir string BackupDisabledDataDir string

View File

@ -9,15 +9,12 @@ import (
"math/big" "math/big"
"time" "time"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/golang/protobuf/proto" "github.com/ethereum/go-ethereum/rlp"
"github.com/status-im/status-go/db" "github.com/status-im/status-go/db"
"github.com/status-im/status-go/mailserver" "github.com/status-im/status-go/mailserver"
"github.com/status-im/status-go/services/shhext/chat" "github.com/status-im/status-go/services/shhext/chat"
@ -424,15 +421,13 @@ func (api *PublicAPI) GetNewFilterMessages(filterID string) ([]dedup.Deduplicate
dedupMessages := api.service.deduplicator.Deduplicate(msgs) dedupMessages := api.service.deduplicator.Deduplicate(msgs)
if api.service.pfsEnabled {
// Attempt to decrypt message, otherwise leave unchanged // Attempt to decrypt message, otherwise leave unchanged
for _, dedupMessage := range dedupMessages { for _, dedupMessage := range dedupMessages {
if err := api.processPFSMessage(dedupMessage); err != nil { if err := api.service.ProcessMessage(dedupMessage); err != nil {
return nil, err return nil, err
} }
} }
}
return dedupMessages, nil return dedupMessages, nil
} }
@ -462,7 +457,7 @@ func (api *PublicAPI) ConfirmMessagesProcessed(messages []*whisper.Message) (err
// ConfirmMessagesProcessedByID is a method to confirm that messages was consumed by // ConfirmMessagesProcessedByID is a method to confirm that messages was consumed by
// the client side. // the client side.
func (api *PublicAPI) ConfirmMessagesProcessedByID(messageIDs [][]byte) error { func (api *PublicAPI) ConfirmMessagesProcessedByID(messageIDs [][]byte) error {
if err := api.service.protocol.ConfirmMessagesProcessed(messageIDs); err != nil { if err := api.service.ConfirmMessagesProcessed(messageIDs); err != nil {
return err return err
} }
@ -471,97 +466,12 @@ func (api *PublicAPI) ConfirmMessagesProcessedByID(messageIDs [][]byte) error {
// SendPublicMessage sends a public chat message to the underlying transport // SendPublicMessage sends a public chat message to the underlying transport
func (api *PublicAPI) SendPublicMessage(ctx context.Context, msg chat.SendPublicMessageRPC) (hexutil.Bytes, error) { func (api *PublicAPI) SendPublicMessage(ctx context.Context, msg chat.SendPublicMessageRPC) (hexutil.Bytes, error) {
privateKey, err := api.service.w.GetPrivateKey(msg.Sig) return api.service.SendPublicMessage(ctx, msg)
if err != nil {
return nil, err
}
// This is transport layer agnostic
protocolMessage, err := api.service.protocol.BuildPublicMessage(privateKey, msg.Payload)
if err != nil {
return nil, err
}
symKeyID, err := api.service.w.AddSymKeyFromPassword(msg.Chat)
if err != nil {
return nil, err
}
// marshal for sending to wire
marshaledMessage, err := proto.Marshal(protocolMessage)
if err != nil {
api.log.Error("encryption-service", "error marshaling message", err)
return nil, err
}
// Enrich with transport layer info
whisperMessage := chat.PublicMessageToWhisper(msg, marshaledMessage)
whisperMessage.SymKeyID = symKeyID
// And dispatch
return api.Post(ctx, whisperMessage)
} }
// SendDirectMessage sends a 1:1 chat message to the underlying transport // SendDirectMessage sends a 1:1 chat message to the underlying transport
func (api *PublicAPI) SendDirectMessage(ctx context.Context, msg chat.SendDirectMessageRPC) (hexutil.Bytes, error) { func (api *PublicAPI) SendDirectMessage(ctx context.Context, msg chat.SendDirectMessageRPC) (hexutil.Bytes, error) {
if !api.service.pfsEnabled { return api.service.SendDirectMessage(ctx, msg)
return nil, ErrPFSNotEnabled
}
// To be completely agnostic from whisper we should not be using whisper to store the key
privateKey, err := api.service.w.GetPrivateKey(msg.Sig)
if err != nil {
return nil, err
}
publicKey, err := crypto.UnmarshalPubkey(msg.PubKey)
if err != nil {
return nil, err
}
// This is transport layer-agnostic
var protocolMessage *chat.ProtocolMessage
// The negotiated secret
var topic []byte
api.log.Info("BUILDING MESSAGE")
if msg.DH {
protocolMessage, topic, err = api.service.protocol.BuildDHMessage(privateKey, &privateKey.PublicKey, msg.Payload)
} else {
protocolMessage, topic, err = api.service.protocol.BuildDirectMessage(privateKey, publicKey, msg.Payload)
}
api.log.Info("BUILT MESSAGE", "topic", topic)
if err != nil {
return nil, err
}
// marshal for sending to wire
marshaledMessage, err := proto.Marshal(protocolMessage)
if err != nil {
api.log.Error("encryption-service", "error marshaling message", err)
return nil, err
}
// TODO: Refactor this as it's not quite the right abstraction anymore
whisperMessage := chat.DirectMessageToWhisper(msg, marshaledMessage, topic)
// Enrich with transport layer info
if topic != nil {
api.log.Info("GETTING SYM KEY", "symkey", api.service.GetNegotiatedChat(publicKey))
chat := api.service.GetNegotiatedChat(publicKey)
if chat != nil {
whisperMessage.SymKeyID = chat.SymKeyID
whisperMessage.Topic = whisper.BytesToTopic(chat.Topic)
whisperMessage.PublicKey = nil
}
}
api.log.Info("WHISPER MESSAGE", "message", whisperMessage)
// And dispatch
return api.Post(ctx, whisperMessage)
} }
func (api *PublicAPI) requestMessagesUsingPayload(request db.HistoryRequest, peer, symkeyID string, payload []byte, force bool, timeout time.Duration, topics []whisper.TopicType) (hash common.Hash, err error) { func (api *PublicAPI) requestMessagesUsingPayload(request db.HistoryRequest, peer, symkeyID string, payload []byte, force bool, timeout time.Duration, topics []whisper.TopicType) (hash common.Hash, err error) {
@ -672,54 +582,6 @@ func (api *PublicAPI) CompleteRequest(parent context.Context, hex string) (err e
return err return err
} }
func (api *PublicAPI) processPFSMessage(dedupMessage dedup.DeduplicateMessage) error {
msg := dedupMessage.Message
privateKeyID := api.service.w.SelectedKeyPairID()
if privateKeyID == "" {
return errors.New("no key selected")
}
privateKey, err := api.service.w.GetPrivateKey(privateKeyID)
if err != nil {
return err
}
publicKey, err := crypto.UnmarshalPubkey(msg.Sig)
if err != nil {
return err
}
// Unmarshal message
protocolMessage := &chat.ProtocolMessage{}
if err := proto.Unmarshal(msg.Payload, protocolMessage); err != nil {
api.log.Debug("Not a protocol message", "err", err)
return nil
}
response, err := api.service.protocol.HandleMessage(privateKey, publicKey, protocolMessage, dedupMessage.DedupID)
switch err {
case nil:
// Set the decrypted payload
msg.Payload = response
case chat.ErrDeviceNotFound:
// Notify that someone tried to contact us using an invalid bundle
if privateKey.PublicKey != *publicKey {
api.log.Warn("Device not found, sending signal", "err", err)
keyString := fmt.Sprintf("0x%x", crypto.FromECDSAPub(publicKey))
handler := EnvelopeSignalHandler{}
handler.DecryptMessageFailed(keyString)
}
default:
// Log and pass to the client, even if failed to decrypt
api.log.Error("Failed handling message with error", "err", err)
}
return nil
}
// ----- // -----
// HELPER // HELPER
// ----- // -----

View File

@ -8,8 +8,8 @@
// 1540715431_add_version.up.sql // 1540715431_add_version.up.sql
// 1541164797_add_installations.down.sql // 1541164797_add_installations.down.sql
// 1541164797_add_installations.up.sql // 1541164797_add_installations.up.sql
// 1558084410_add_topic.down.sql // 1558084410_add_secret.down.sql
// 1558084410_add_topic.up.sql // 1558084410_add_secret.up.sql
// 1558588866_add_version.up.sql // 1558588866_add_version.up.sql
// static.go // static.go
// DO NOT EDIT! // DO NOT EDIT!
@ -239,42 +239,42 @@ func _1541164797_add_installationsUpSql() (*asset, error) {
return a, nil return a, nil
} }
var __1558084410_add_topicDownSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x72\x09\xf2\x0f\x50\x08\x71\x74\xf2\x71\x55\x28\xc9\x2f\xc8\x4c\x8e\xcf\xcc\x2b\x2e\x49\xcc\xc9\x49\x2c\xc9\xcc\xcf\x8b\xcf\x4c\x29\xb6\xe6\x42\x57\x52\x6c\xcd\x05\x08\x00\x00\xff\xff\xf0\xe3\x8a\xc7\x36\x00\x00\x00") var __1558084410_add_secretDownSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x72\x09\xf2\x0f\x50\x08\x71\x74\xf2\x71\x55\x28\x4e\x4d\x2e\x4a\x2d\x89\xcf\xcc\x2b\x2e\x49\xcc\xc9\x49\x2c\xc9\xcc\xcf\x8b\xcf\x4c\x29\xb6\xe6\xc2\x50\x53\x6c\xcd\x05\x08\x00\x00\xff\xff\xd3\xcd\x41\x83\x38\x00\x00\x00")
func _1558084410_add_topicDownSqlBytes() ([]byte, error) { func _1558084410_add_secretDownSqlBytes() ([]byte, error) {
return bindataRead( return bindataRead(
__1558084410_add_topicDownSql, __1558084410_add_secretDownSql,
"1558084410_add_topic.down.sql", "1558084410_add_secret.down.sql",
) )
} }
func _1558084410_add_topicDownSql() (*asset, error) { func _1558084410_add_secretDownSql() (*asset, error) {
bytes, err := _1558084410_add_topicDownSqlBytes() bytes, err := _1558084410_add_secretDownSqlBytes()
if err != nil { if err != nil {
return nil, err return nil, err
} }
info := bindataFileInfo{name: "1558084410_add_topic.down.sql", size: 54, mode: os.FileMode(420), modTime: time.Unix(1560418030, 0)} info := bindataFileInfo{name: "1558084410_add_secret.down.sql", size: 56, mode: os.FileMode(420), modTime: time.Unix(1560418252, 0)}
a := &asset{bytes: bytes, info: info} a := &asset{bytes: bytes, info: info}
return a, nil return a, nil
} }
var __1558084410_add_topicUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x74\x90\x41\x6b\x85\x30\x10\x84\xef\xf9\x15\x73\x54\xf0\x1f\xf4\xa4\x61\x95\xd0\x74\xd3\xa6\x11\xea\x49\xc4\x78\x58\x10\x2d\x35\x97\xfe\xfb\x62\x79\x4f\x94\xc7\x3b\xcf\xcc\xce\x7c\xab\x3d\x95\x81\x10\xca\xca\x12\xd2\xfa\x2d\xe3\x86\x4c\x01\x12\xa7\x25\x49\xfa\x45\x65\x5d\x05\x76\x01\xdc\x5a\x8b\x77\x6f\xde\x4a\xdf\xe1\x95\x3a\x38\x86\x76\x5c\x5b\xa3\x03\x4c\xc3\xce\x53\xa1\x80\x6d\x1a\x7f\xa6\x74\x8d\xa9\xfc\x45\xa9\xc7\xa6\x5e\x96\x2d\x0d\xf3\x3c\x24\x59\x97\x5e\xe2\xbd\x19\x81\xbe\xc2\x11\x2e\x4e\x6b\x7a\x89\xd7\xcb\xbb\xd8\xb2\xf9\x68\x29\x93\x58\x9c\x7d\xf9\x93\x7d\xb5\xf3\x64\x1a\xfe\x27\xc8\x2e\x7e\x4f\x35\x79\x62\x4d\x9f\xb7\x47\x1c\x72\xbe\x03\xfc\x05\x00\x00\xff\xff\xf3\xa6\x3d\xc3\x2a\x01\x00\x00") var __1558084410_add_secretUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x74\x50\xcf\x0a\x82\x30\x1c\xbe\xef\x29\xbe\xa3\x82\x6f\xd0\x49\xc7\x4f\x19\xad\xdf\x6a\x4d\xc8\x93\x48\xf3\x30\x10\x83\xdc\xa5\xb7\x0f\x23\x45\xa1\xce\xdf\xff\x4f\x5a\xca\x1d\xc1\xe5\x85\x26\x4c\xfd\xfd\xd9\xc7\x09\x89\x00\x82\xef\xc7\x18\xe2\x0b\x85\x36\x05\xd8\x38\x70\xad\x35\xce\x56\x9d\x72\xdb\xe0\x48\x0d\x0c\x43\x1a\x2e\xb5\x92\x0e\xaa\x62\x63\x29\x13\xf8\x9a\xec\x65\x22\x3d\x08\xf1\x23\xaa\x0d\xe3\x14\xbb\x61\xe8\x62\x78\x8c\x6d\xf0\x4b\x34\x1c\xdd\xdc\xaa\xce\x36\x75\xda\xe0\xf7\xd6\x33\x58\xb3\xba\xd4\x94\x04\x9f\x6d\x79\xe9\x9f\x82\xa5\xb1\xa4\x2a\xfe\x4c\x48\x76\x7c\x4b\x25\x59\x62\x49\xd7\xe5\x8a\x15\x4f\xe7\x09\xef\x00\x00\x00\xff\xff\xa6\xbb\x2c\x23\x2d\x01\x00\x00")
func _1558084410_add_topicUpSqlBytes() ([]byte, error) { func _1558084410_add_secretUpSqlBytes() ([]byte, error) {
return bindataRead( return bindataRead(
__1558084410_add_topicUpSql, __1558084410_add_secretUpSql,
"1558084410_add_topic.up.sql", "1558084410_add_secret.up.sql",
) )
} }
func _1558084410_add_topicUpSql() (*asset, error) { func _1558084410_add_secretUpSql() (*asset, error) {
bytes, err := _1558084410_add_topicUpSqlBytes() bytes, err := _1558084410_add_secretUpSqlBytes()
if err != nil { if err != nil {
return nil, err return nil, err
} }
info := bindataFileInfo{name: "1558084410_add_topic.up.sql", size: 298, mode: os.FileMode(420), modTime: time.Unix(1560418030, 0)} info := bindataFileInfo{name: "1558084410_add_secret.up.sql", size: 301, mode: os.FileMode(420), modTime: time.Unix(1560418252, 0)}
a := &asset{bytes: bytes, info: info} a := &asset{bytes: bytes, info: info}
return a, nil return a, nil
} }
@ -294,7 +294,7 @@ func _1558588866_add_versionUpSql() (*asset, error) {
return nil, err return nil, err
} }
info := bindataFileInfo{name: "1558588866_add_version.up.sql", size: 57, mode: os.FileMode(420), modTime: time.Unix(1558588995, 0)} info := bindataFileInfo{name: "1558588866_add_version.up.sql", size: 57, mode: os.FileMode(420), modTime: time.Unix(1560418251, 0)}
a := &asset{bytes: bytes, info: info} a := &asset{bytes: bytes, info: info}
return a, nil return a, nil
} }
@ -379,8 +379,8 @@ var _bindata = map[string]func() (*asset, error){
"1540715431_add_version.up.sql": _1540715431_add_versionUpSql, "1540715431_add_version.up.sql": _1540715431_add_versionUpSql,
"1541164797_add_installations.down.sql": _1541164797_add_installationsDownSql, "1541164797_add_installations.down.sql": _1541164797_add_installationsDownSql,
"1541164797_add_installations.up.sql": _1541164797_add_installationsUpSql, "1541164797_add_installations.up.sql": _1541164797_add_installationsUpSql,
"1558084410_add_topic.down.sql": _1558084410_add_topicDownSql, "1558084410_add_secret.down.sql": _1558084410_add_secretDownSql,
"1558084410_add_topic.up.sql": _1558084410_add_topicUpSql, "1558084410_add_secret.up.sql": _1558084410_add_secretUpSql,
"1558588866_add_version.up.sql": _1558588866_add_versionUpSql, "1558588866_add_version.up.sql": _1558588866_add_versionUpSql,
"static.go": staticGo, "static.go": staticGo,
} }
@ -433,8 +433,8 @@ var _bintree = &bintree{nil, map[string]*bintree{
"1540715431_add_version.up.sql": &bintree{_1540715431_add_versionUpSql, map[string]*bintree{}}, "1540715431_add_version.up.sql": &bintree{_1540715431_add_versionUpSql, map[string]*bintree{}},
"1541164797_add_installations.down.sql": &bintree{_1541164797_add_installationsDownSql, map[string]*bintree{}}, "1541164797_add_installations.down.sql": &bintree{_1541164797_add_installationsDownSql, map[string]*bintree{}},
"1541164797_add_installations.up.sql": &bintree{_1541164797_add_installationsUpSql, map[string]*bintree{}}, "1541164797_add_installations.up.sql": &bintree{_1541164797_add_installationsUpSql, map[string]*bintree{}},
"1558084410_add_topic.down.sql": &bintree{_1558084410_add_topicDownSql, map[string]*bintree{}}, "1558084410_add_secret.down.sql": &bintree{_1558084410_add_secretDownSql, map[string]*bintree{}},
"1558084410_add_topic.up.sql": &bintree{_1558084410_add_topicUpSql, map[string]*bintree{}}, "1558084410_add_secret.up.sql": &bintree{_1558084410_add_secretUpSql, map[string]*bintree{}},
"1558588866_add_version.up.sql": &bintree{_1558588866_add_versionUpSql, map[string]*bintree{}}, "1558588866_add_version.up.sql": &bintree{_1558588866_add_versionUpSql, map[string]*bintree{}},
"static.go": &bintree{staticGo, map[string]*bintree{}}, "static.go": &bintree{staticGo, map[string]*bintree{}},
}} }}

View File

@ -1,11 +1,9 @@
package chat package chat
import ( import (
"bytes"
"crypto/ecdsa" "crypto/ecdsa"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
@ -15,6 +13,8 @@ import (
dr "github.com/status-im/doubleratchet" dr "github.com/status-im/doubleratchet"
"github.com/status-im/status-go/services/shhext/chat/crypto" "github.com/status-im/status-go/services/shhext/chat/crypto"
"github.com/status-im/status-go/services/shhext/chat/multidevice"
"github.com/status-im/status-go/services/shhext/chat/protobuf"
) )
var ErrSessionNotFound = errors.New("session not found") var ErrSessionNotFound = errors.New("session not found")
@ -51,8 +51,6 @@ type EncryptionServiceConfig struct {
BundleRefreshInterval int64 BundleRefreshInterval int64
} }
type IdentityAndIDPair [2]string
// DefaultEncryptionServiceConfig returns the default values used by the encryption service // DefaultEncryptionServiceConfig returns the default values used by the encryption service
func DefaultEncryptionServiceConfig(installationID string) EncryptionServiceConfig { func DefaultEncryptionServiceConfig(installationID string) EncryptionServiceConfig {
return EncryptionServiceConfig{ return EncryptionServiceConfig{
@ -132,19 +130,9 @@ func (s *EncryptionService) ConfirmMessagesProcessed(messageIDs [][]byte) error
} }
// CreateBundle retrieves or creates an X3DH bundle given a private key // CreateBundle retrieves or creates an X3DH bundle given a private key
func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle, error) { func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey, installations []*multidevice.Installation) (*protobuf.Bundle, error) {
ourIdentityKeyC := ecrypto.CompressPubkey(&privateKey.PublicKey) ourIdentityKeyC := ecrypto.CompressPubkey(&privateKey.PublicKey)
installations, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations-1, ourIdentityKeyC)
if err != nil {
return nil, err
}
installations = append(installations, &Installation{
ID: s.config.InstallationID,
Version: protocolCurrentVersion,
})
bundleContainer, err := s.persistence.GetAnyPrivateBundle(ourIdentityKeyC, installations) bundleContainer, err := s.persistence.GetAnyPrivateBundle(ourIdentityKeyC, installations)
if err != nil { if err != nil {
return nil, err return nil, err
@ -176,7 +164,7 @@ func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle,
return nil, err return nil, err
} }
return s.CreateBundle(privateKey) return s.CreateBundle(privateKey, installations)
} }
// DecryptWithDH decrypts message sent with a DH key exchange, and throws away the key after decryption // DecryptWithDH decrypts message sent with a DH key exchange, and throws away the key after decryption
@ -224,55 +212,13 @@ func (s *EncryptionService) keyFromPassiveX3DH(myIdentityKey *ecdsa.PrivateKey,
return key, nil return key, nil
} }
func (s *EncryptionService) EnableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error { // ProcessPublicBundle persists a bundle
myIdentityKeyC := ecrypto.CompressPubkey(myIdentityKey) func (s *EncryptionService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, b *protobuf.Bundle) error {
return s.persistence.EnableInstallation(myIdentityKeyC, installationID) return s.persistence.AddPublicBundle(b)
}
func (s *EncryptionService) DisableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error {
myIdentityKeyC := ecrypto.CompressPubkey(myIdentityKey)
return s.persistence.DisableInstallation(myIdentityKeyC, installationID)
}
// ProcessPublicBundle persists a bundle and returns a list of tuples identity/installationID
func (s *EncryptionService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, b *Bundle) ([]IdentityAndIDPair, error) {
// Make sure the bundle belongs to who signed it
identity, err := ExtractIdentity(b)
if err != nil {
return nil, err
}
signedPreKeys := b.GetSignedPreKeys()
var response []IdentityAndIDPair
var installations []*Installation
myIdentityStr := fmt.Sprintf("0x%x", ecrypto.FromECDSAPub(&myIdentityKey.PublicKey))
// Any device from other peers will be considered enabled, ours needs to
// be explicitly enabled
fromOurIdentity := identity != myIdentityStr
for installationID, signedPreKey := range signedPreKeys {
if installationID != s.config.InstallationID {
installations = append(installations, &Installation{
ID: installationID,
Version: signedPreKey.GetProtocolVersion(),
})
response = append(response, IdentityAndIDPair{identity, installationID})
}
}
if err = s.persistence.AddInstallations(b.GetIdentity(), b.GetTimestamp(), installations, fromOurIdentity); err != nil {
return nil, err
}
if err = s.persistence.AddPublicBundle(b); err != nil {
return nil, err
}
return response, nil
} }
// DecryptPayload decrypts the payload of a DirectMessageProtocol, given an identity private key and the sender's public key // DecryptPayload decrypts the payload of a DirectMessageProtocol, given an identity private key and the sender's public key
func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, theirIdentityKey *ecdsa.PublicKey, theirInstallationID string, msgs map[string]*DirectMessageProtocol, messageID []byte) ([]byte, error) { func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, theirIdentityKey *ecdsa.PublicKey, theirInstallationID string, msgs map[string]*protobuf.DirectMessageProtocol, messageID []byte) ([]byte, error) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
@ -328,11 +274,6 @@ func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, thei
return nil, err return nil, err
} }
// Add installations with a timestamp of 0, as we don't have bundle informations
if err = s.persistence.AddInstallations(theirIdentityKeyC, 0, []*Installation{{ID: theirInstallationID, Version: 0}}, true); err != nil {
return nil, err
}
// We mark the exchange as successful so we stop sending x3dh header // We mark the exchange as successful so we stop sending x3dh header
if err = s.persistence.RatchetInfoConfirmed(drHeader.GetId(), theirIdentityKeyC, theirInstallationID); err != nil { if err = s.persistence.RatchetInfoConfirmed(drHeader.GetId(), theirIdentityKeyC, theirInstallationID); err != nil {
s.log.Error("Could not confirm ratchet info", "err", err) s.log.Error("Could not confirm ratchet info", "err", err)
@ -396,7 +337,7 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k
return session, err return session, err
} }
func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, drInfo *RatchetInfo, payload []byte) ([]byte, *DRHeader, error) { func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, drInfo *RatchetInfo, payload []byte) ([]byte, *protobuf.DRHeader, error) {
var err error var err error
var session dr.Session var session dr.Session
@ -430,7 +371,7 @@ func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
return nil, nil, err return nil, nil, err
} }
header := &DRHeader{ header := &protobuf.DRHeader{
Id: drInfo.BundleID, Id: drInfo.BundleID,
Key: response.Header.DH[:], Key: response.Header.DH[:],
N: response.Header.N, N: response.Header.N,
@ -474,7 +415,7 @@ func (s *EncryptionService) decryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
return plaintext, nil return plaintext, nil
} }
func (s *EncryptionService) encryptWithDH(theirIdentityKey *ecdsa.PublicKey, payload []byte) (*DirectMessageProtocol, error) { func (s *EncryptionService) encryptWithDH(theirIdentityKey *ecdsa.PublicKey, payload []byte) (*protobuf.DirectMessageProtocol, error) {
symmetricKey, ourEphemeralKey, err := PerformActiveDH(theirIdentityKey) symmetricKey, ourEphemeralKey, err := PerformActiveDH(theirIdentityKey)
if err != nil { if err != nil {
return nil, err return nil, err
@ -485,16 +426,16 @@ func (s *EncryptionService) encryptWithDH(theirIdentityKey *ecdsa.PublicKey, pay
return nil, err return nil, err
} }
return &DirectMessageProtocol{ return &protobuf.DirectMessageProtocol{
DHHeader: &DHHeader{ DHHeader: &protobuf.DHHeader{
Key: ecrypto.CompressPubkey(ourEphemeralKey), Key: ecrypto.CompressPubkey(ourEphemeralKey),
}, },
Payload: encryptedPayload, Payload: encryptedPayload,
}, nil }, nil
} }
func (s *EncryptionService) EncryptPayloadWithDH(theirIdentityKey *ecdsa.PublicKey, payload []byte) (map[string]*DirectMessageProtocol, error) { func (s *EncryptionService) EncryptPayloadWithDH(theirIdentityKey *ecdsa.PublicKey, payload []byte) (map[string]*protobuf.DirectMessageProtocol, error) {
response := make(map[string]*DirectMessageProtocol) response := make(map[string]*protobuf.DirectMessageProtocol)
dmp, err := s.encryptWithDH(theirIdentityKey, payload) dmp, err := s.encryptWithDH(theirIdentityKey, payload)
if err != nil { if err != nil {
return nil, err return nil, err
@ -505,21 +446,15 @@ func (s *EncryptionService) EncryptPayloadWithDH(theirIdentityKey *ecdsa.PublicK
} }
// GetPublicBundle returns the active installations bundles for a given user // GetPublicBundle returns the active installations bundles for a given user
func (s *EncryptionService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey) (*Bundle, error) { func (s *EncryptionService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey, installations []*multidevice.Installation) (*protobuf.Bundle, error) {
theirIdentityKeyC := ecrypto.CompressPubkey(theirIdentityKey)
installations, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations, theirIdentityKeyC)
if err != nil {
return nil, err
}
return s.persistence.GetPublicBundle(theirIdentityKey, installations) return s.persistence.GetPublicBundle(theirIdentityKey, installations)
} }
// EncryptPayload returns a new DirectMessageProtocol with a given payload encrypted, given a recipient's public key and the sender private identity key // EncryptPayload returns a new DirectMessageProtocol with a given payload encrypted, given a recipient's public key and the sender private identity key
// TODO: refactor this func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, myIdentityKey *ecdsa.PrivateKey, installations []*multidevice.Installation, payload []byte) (map[string]*protobuf.DirectMessageProtocol, []*multidevice.Installation, error) {
// nolint: gocyclo // Which installations we are sending the message to
func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, myIdentityKey *ecdsa.PrivateKey, payload []byte) (map[string]*DirectMessageProtocol, error) { var targetedInstallations []*multidevice.Installation
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
@ -527,18 +462,13 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my
theirIdentityKeyC := ecrypto.CompressPubkey(theirIdentityKey) theirIdentityKeyC := ecrypto.CompressPubkey(theirIdentityKey)
// Get their installationIds
installations, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations, theirIdentityKeyC)
if err != nil {
return nil, err
}
// We don't have any, send a message with DH // We don't have any, send a message with DH
if installations == nil && !bytes.Equal(theirIdentityKeyC, ecrypto.CompressPubkey(&myIdentityKey.PublicKey)) { if len(installations) == 0 {
return s.EncryptPayloadWithDH(theirIdentityKey, payload) encryptedPayload, err := s.EncryptPayloadWithDH(theirIdentityKey, payload)
return encryptedPayload, targetedInstallations, err
} }
response := make(map[string]*DirectMessageProtocol) response := make(map[string]*protobuf.DirectMessageProtocol)
for _, installation := range installations { for _, installation := range installations {
installationID := installation.ID installationID := installation.ID
@ -546,31 +476,33 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my
if s.config.InstallationID == installationID { if s.config.InstallationID == installationID {
continue continue
} }
bundle, err := s.persistence.GetPublicBundle(theirIdentityKey, []*Installation{installation}) bundle, err := s.persistence.GetPublicBundle(theirIdentityKey, []*multidevice.Installation{installation})
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
// See if a session is there already // See if a session is there already
drInfo, err := s.persistence.GetAnyRatchetInfo(theirIdentityKeyC, installationID) drInfo, err := s.persistence.GetAnyRatchetInfo(theirIdentityKeyC, installationID)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
targetedInstallations = append(targetedInstallations, installation)
if drInfo != nil { if drInfo != nil {
s.log.Debug("Found DR info", "installationID", installationID) s.log.Debug("Found DR info", "installationID", installationID)
encryptedPayload, drHeader, err := s.encryptUsingDR(theirIdentityKey, drInfo, payload) encryptedPayload, drHeader, err := s.encryptUsingDR(theirIdentityKey, drInfo, payload)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
dmp := DirectMessageProtocol{ dmp := protobuf.DirectMessageProtocol{
Payload: encryptedPayload, Payload: encryptedPayload,
DRHeader: drHeader, DRHeader: drHeader,
} }
if drInfo.EphemeralKey != nil { if drInfo.EphemeralKey != nil {
dmp.X3DHHeader = &X3DHHeader{ dmp.X3DHHeader = &protobuf.X3DHHeader{
Key: drInfo.EphemeralKey, Key: drInfo.EphemeralKey,
Id: drInfo.BundleID, Id: drInfo.BundleID,
} }
@ -594,33 +526,33 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my
sharedKey, ourEphemeralKey, err := s.keyFromActiveX3DH(theirIdentityKeyC, theirSignedPreKey, myIdentityKey) sharedKey, ourEphemeralKey, err := s.keyFromActiveX3DH(theirIdentityKeyC, theirSignedPreKey, myIdentityKey)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
theirIdentityKeyC := ecrypto.CompressPubkey(theirIdentityKey) theirIdentityKeyC := ecrypto.CompressPubkey(theirIdentityKey)
ourEphemeralKeyC := ecrypto.CompressPubkey(ourEphemeralKey) ourEphemeralKeyC := ecrypto.CompressPubkey(ourEphemeralKey)
err = s.persistence.AddRatchetInfo(sharedKey, theirIdentityKeyC, theirSignedPreKey, ourEphemeralKeyC, installationID) err = s.persistence.AddRatchetInfo(sharedKey, theirIdentityKeyC, theirSignedPreKey, ourEphemeralKeyC, installationID)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
x3dhHeader := &X3DHHeader{ x3dhHeader := &protobuf.X3DHHeader{
Key: ourEphemeralKeyC, Key: ourEphemeralKeyC,
Id: theirSignedPreKey, Id: theirSignedPreKey,
} }
drInfo, err = s.persistence.GetRatchetInfo(theirSignedPreKey, theirIdentityKeyC, installationID) drInfo, err = s.persistence.GetRatchetInfo(theirSignedPreKey, theirIdentityKeyC, installationID)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if drInfo != nil { if drInfo != nil {
encryptedPayload, drHeader, err := s.encryptUsingDR(theirIdentityKey, drInfo, payload) encryptedPayload, drHeader, err := s.encryptUsingDR(theirIdentityKey, drInfo, payload)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
dmp := &DirectMessageProtocol{ dmp := &protobuf.DirectMessageProtocol{
Payload: encryptedPayload, Payload: encryptedPayload,
X3DHHeader: x3dhHeader, X3DHHeader: x3dhHeader,
DRHeader: drHeader, DRHeader: drHeader,
@ -632,5 +564,5 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my
s.log.Debug("Built message", "theirKey", theirIdentityKey) s.log.Debug("Built message", "theirKey", theirIdentityKey)
return response, nil return response, targetedInstallations, nil
} }

View File

@ -8,6 +8,9 @@ import (
"os" "os"
"sort" "sort"
"testing" "testing"
"github.com/status-im/status-go/services/shhext/chat/multidevice"
"github.com/status-im/status-go/services/shhext/chat/sharedsecret"
) )
const ( const (
@ -20,7 +23,7 @@ func TestEncryptionServiceMultiDeviceSuite(t *testing.T) {
} }
type serviceAndKey struct { type serviceAndKey struct {
encryptionServices []*EncryptionService services []*ProtocolService
key *ecdsa.PrivateKey key *ecdsa.PrivateKey
} }
@ -37,7 +40,7 @@ func setupUser(user string, s *EncryptionServiceMultiDeviceSuite, n int) error {
s.services[user] = &serviceAndKey{ s.services[user] = &serviceAndKey{
key: key, key: key,
encryptionServices: make([]*EncryptionService, n), services: make([]*ProtocolService, n),
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@ -50,11 +53,27 @@ func setupUser(user string, s *EncryptionServiceMultiDeviceSuite, n int) error {
if err != nil { if err != nil {
return err return err
} }
// Initialize sharedsecret
multideviceConfig := &multidevice.Config{
MaxInstallations: n - 1,
InstallationID: installationID,
ProtocolVersion: 1,
}
config := DefaultEncryptionServiceConfig(installationID) sharedSecretService := sharedsecret.NewService(persistence.GetSharedSecretStorage())
config.MaxInstallations = n - 1 multideviceService := multidevice.New(multideviceConfig, persistence.GetMultideviceStorage())
s.services[user].encryptionServices[i] = NewEncryptionService(persistence, config) protocol := NewProtocolService(
NewEncryptionService(
persistence,
DefaultEncryptionServiceConfig(installationID)),
sharedSecretService,
multideviceService,
func(s []multidevice.IdentityAndIDPair) {},
func(s []*sharedsecret.Secret) {},
)
s.services[user].services[i] = protocol
} }
@ -73,43 +92,47 @@ func (s *EncryptionServiceMultiDeviceSuite) SetupTest() {
func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() { func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() {
aliceKey := s.services[aliceUser].key aliceKey := s.services[aliceUser].key
alice2Bundle, err := s.services[aliceUser].encryptionServices[1].CreateBundle(aliceKey) alice2Bundle, err := s.services[aliceUser].services[1].GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
alice2Identity, err := ExtractIdentity(alice2Bundle) alice2IdentityPK, err := ExtractIdentity(alice2Bundle)
s.Require().NoError(err) s.Require().NoError(err)
alice3Bundle, err := s.services[aliceUser].encryptionServices[2].CreateBundle(aliceKey) alice2Identity := fmt.Sprintf("0x%x", crypto.FromECDSAPub(alice2IdentityPK))
alice3Bundle, err := s.services[aliceUser].services[2].GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
alice3Identity, err := ExtractIdentity(alice2Bundle) alice3IdentityPK, err := ExtractIdentity(alice2Bundle)
s.Require().NoError(err) s.Require().NoError(err)
alice3Identity := fmt.Sprintf("0x%x", crypto.FromECDSAPub(alice3IdentityPK))
// Add alice2 bundle // Add alice2 bundle
response, err := s.services[aliceUser].encryptionServices[0].ProcessPublicBundle(aliceKey, alice2Bundle) response, err := s.services[aliceUser].services[0].ProcessPublicBundle(aliceKey, alice2Bundle)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice2"}, response[0]) s.Require().Equal(multidevice.IdentityAndIDPair{alice2Identity, "alice2"}, response[0])
// Add alice3 bundle // Add alice3 bundle
response, err = s.services[aliceUser].encryptionServices[0].ProcessPublicBundle(aliceKey, alice3Bundle) response, err = s.services[aliceUser].services[0].ProcessPublicBundle(aliceKey, alice3Bundle)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(IdentityAndIDPair{alice3Identity, "alice3"}, response[0]) s.Require().Equal(multidevice.IdentityAndIDPair{alice3Identity, "alice3"}, response[0])
// No installation is enabled // No installation is enabled
alice1MergedBundle1, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey) alice1MergedBundle1, err := s.services[aliceUser].services[0].GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(1, len(alice1MergedBundle1.GetSignedPreKeys())) s.Require().Equal(1, len(alice1MergedBundle1.GetSignedPreKeys()))
s.Require().NotNil(alice1MergedBundle1.GetSignedPreKeys()["alice1"]) s.Require().NotNil(alice1MergedBundle1.GetSignedPreKeys()["alice1"])
// We enable the installations // We enable the installations
err = s.services[aliceUser].encryptionServices[0].EnableInstallation(&aliceKey.PublicKey, "alice2") err = s.services[aliceUser].services[0].EnableInstallation(&aliceKey.PublicKey, "alice2")
s.Require().NoError(err) s.Require().NoError(err)
err = s.services[aliceUser].encryptionServices[0].EnableInstallation(&aliceKey.PublicKey, "alice3") err = s.services[aliceUser].services[0].EnableInstallation(&aliceKey.PublicKey, "alice3")
s.Require().NoError(err) s.Require().NoError(err)
alice1MergedBundle2, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey) alice1MergedBundle2, err := s.services[aliceUser].services[0].GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We get back a bundle with all the installations // We get back a bundle with all the installations
@ -118,21 +141,21 @@ func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() {
s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice2"]) s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice2"])
s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice3"]) s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice3"])
response, err = s.services[aliceUser].encryptionServices[0].ProcessPublicBundle(aliceKey, alice1MergedBundle2) response, err = s.services[aliceUser].services[0].ProcessPublicBundle(aliceKey, alice1MergedBundle2)
s.Require().NoError(err) s.Require().NoError(err)
sort.Slice(response, func(i, j int) bool { sort.Slice(response, func(i, j int) bool {
return response[i][1] < response[j][1] return response[i][1] < response[j][1]
}) })
// We only get back installationIDs not equal to us // We only get back installationIDs not equal to us
s.Require().Equal(2, len(response)) s.Require().Equal(2, len(response))
s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice2"}, response[0]) s.Require().Equal(multidevice.IdentityAndIDPair{alice2Identity, "alice2"}, response[0])
s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice3"}, response[1]) s.Require().Equal(multidevice.IdentityAndIDPair{alice2Identity, "alice3"}, response[1])
// We disable the installations // We disable the installations
err = s.services[aliceUser].encryptionServices[0].DisableInstallation(&aliceKey.PublicKey, "alice2") err = s.services[aliceUser].services[0].DisableInstallation(&aliceKey.PublicKey, "alice2")
s.Require().NoError(err) s.Require().NoError(err)
alice1MergedBundle3, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey) alice1MergedBundle3, err := s.services[aliceUser].services[0].GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We get back a bundle with all the installations // We get back a bundle with all the installations
@ -146,23 +169,23 @@ func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundleOutOfOrder()
s.Require().NoError(err) s.Require().NoError(err)
// Alice1 creates a bundle // Alice1 creates a bundle
alice1Bundle, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey) alice1Bundle, err := s.services[aliceUser].services[0].GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// Alice2 Receives the bundle // Alice2 Receives the bundle
_, err = s.services[aliceUser].encryptionServices[1].ProcessPublicBundle(aliceKey, alice1Bundle) _, err = s.services[aliceUser].services[1].ProcessPublicBundle(aliceKey, alice1Bundle)
s.Require().NoError(err) s.Require().NoError(err)
// Alice2 Creates a Bundle // Alice2 Creates a Bundle
_, err = s.services[aliceUser].encryptionServices[1].CreateBundle(aliceKey) _, err = s.services[aliceUser].services[1].GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We enable the installation // We enable the installation
err = s.services[aliceUser].encryptionServices[1].EnableInstallation(&aliceKey.PublicKey, "alice1") err = s.services[aliceUser].services[1].EnableInstallation(&aliceKey.PublicKey, "alice1")
s.Require().NoError(err) s.Require().NoError(err)
// It should contain both bundles // It should contain both bundles
alice2MergedBundle1, err := s.services[aliceUser].encryptionServices[1].CreateBundle(aliceKey) alice2MergedBundle1, err := s.services[aliceUser].services[1].GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(alice2MergedBundle1.GetSignedPreKeys()["alice1"]) s.Require().NotNil(alice2MergedBundle1.GetSignedPreKeys()["alice1"])
@ -170,9 +193,9 @@ func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundleOutOfOrder()
} }
func pairDevices(s *serviceAndKey, target int) error { func pairDevices(s *serviceAndKey, target int) error {
device := s.encryptionServices[target] device := s.services[target]
for i := 0; i < len(s.encryptionServices); i++ { for i := 0; i < len(s.services); i++ {
b, err := s.encryptionServices[i].CreateBundle(s.key) b, err := s.services[i].GetBundle(s.key)
if err != nil { if err != nil {
return err return err
@ -183,7 +206,7 @@ func pairDevices(s *serviceAndKey, target int) error {
return err return err
} }
err = device.EnableInstallation(&s.key.PublicKey, s.encryptionServices[i].config.InstallationID) err = device.EnableInstallation(&s.key.PublicKey, s.services[i].encryption.config.InstallationID)
if err != nil { if err != nil {
return nil return nil
} }
@ -194,14 +217,14 @@ func pairDevices(s *serviceAndKey, target int) error {
func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevices() { func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevices() {
err := pairDevices(s.services[aliceUser], 0) err := pairDevices(s.services[aliceUser], 0)
s.Require().NoError(err) s.Require().NoError(err)
alice1 := s.services[aliceUser].encryptionServices[0] alice1 := s.services[aliceUser].services[0]
bob1 := s.services[bobUser].encryptionServices[0] bob1 := s.services[bobUser].services[0]
aliceKey := s.services[aliceUser].key aliceKey := s.services[aliceUser].key
bobKey := s.services[bobUser].key bobKey := s.services[bobUser].key
// Check bundle is ok // Check bundle is ok
// No installation is enabled // No installation is enabled
aliceBundle, err := alice1.CreateBundle(aliceKey) aliceBundle, err := alice1.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// Check all installations are correctly working, and that the oldest device is not there // Check all installations are correctly working, and that the oldest device is not there
@ -218,19 +241,20 @@ func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevices() {
s.Require().NoError(err) s.Require().NoError(err)
// Bob sends a message to alice // Bob sends a message to alice
payload, err := bob1.EncryptPayload(&aliceKey.PublicKey, bobKey, []byte("test")) msg, err := bob1.BuildDirectMessage(bobKey, &aliceKey.PublicKey, []byte("test"))
s.Require().NoError(err) s.Require().NoError(err)
payload := msg.Message.GetDirectMessage()
s.Require().Equal(3, len(payload)) s.Require().Equal(3, len(payload))
s.Require().NotNil(payload["alice1"]) s.Require().NotNil(payload["alice1"])
s.Require().NotNil(payload["alice3"]) s.Require().NotNil(payload["alice3"])
s.Require().NotNil(payload["alice4"]) s.Require().NotNil(payload["alice4"])
// We disable the last installation // We disable the last installation
err = s.services[aliceUser].encryptionServices[0].DisableInstallation(&aliceKey.PublicKey, "alice4") err = s.services[aliceUser].services[0].DisableInstallation(&aliceKey.PublicKey, "alice4")
s.Require().NoError(err) s.Require().NoError(err)
// We check the bundle is updated // We check the bundle is updated
aliceBundle, err = alice1.CreateBundle(aliceKey) aliceBundle, err = alice1.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// Check all installations are there // Check all installations are there
@ -247,8 +271,9 @@ func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevices() {
s.Require().NoError(err) s.Require().NoError(err)
// Bob sends a message to alice // Bob sends a message to alice
payload, err = bob1.EncryptPayload(&aliceKey.PublicKey, bobKey, []byte("test")) msg, err = bob1.BuildDirectMessage(bobKey, &aliceKey.PublicKey, []byte("test"))
s.Require().NoError(err) s.Require().NoError(err)
payload = msg.Message.GetDirectMessage()
s.Require().Equal(3, len(payload)) s.Require().Equal(3, len(payload))
s.Require().NotNil(payload["alice1"]) s.Require().NotNil(payload["alice1"])
s.Require().NotNil(payload["alice2"]) s.Require().NotNil(payload["alice2"])

View File

@ -13,6 +13,9 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/status-im/status-go/services/shhext/chat/multidevice"
"github.com/status-im/status-go/services/shhext/chat/protobuf"
"github.com/status-im/status-go/services/shhext/chat/sharedsecret"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
@ -27,8 +30,8 @@ func TestEncryptionServiceTestSuite(t *testing.T) {
type EncryptionServiceTestSuite struct { type EncryptionServiceTestSuite struct {
suite.Suite suite.Suite
alice *EncryptionService alice *ProtocolService
bob *EncryptionService bob *ProtocolService
aliceDBPath string aliceDBPath string
bobDBPath string bobDBPath string
} }
@ -56,21 +59,57 @@ func (s *EncryptionServiceTestSuite) initDatabases(baseConfig *EncryptionService
bobDBKey = "bob" bobDBKey = "bob"
) )
aliceMultideviceConfig := &multidevice.Config{
MaxInstallations: 3,
InstallationID: aliceInstallationID,
ProtocolVersion: 1,
}
alicePersistence, err := NewSQLLitePersistence(aliceDBPath, aliceDBKey) alicePersistence, err := NewSQLLitePersistence(aliceDBPath, aliceDBKey)
if err != nil { if err != nil {
panic(err) panic(err)
} }
baseConfig.InstallationID = aliceInstallationID
aliceEncryptionService := NewEncryptionService(alicePersistence, *baseConfig)
aliceSharedSecretService := sharedsecret.NewService(alicePersistence.GetSharedSecretStorage())
aliceMultideviceService := multidevice.New(aliceMultideviceConfig, alicePersistence.GetMultideviceStorage())
s.alice = NewProtocolService(
aliceEncryptionService,
aliceSharedSecretService,
aliceMultideviceService,
func(s []multidevice.IdentityAndIDPair) {},
func(s []*sharedsecret.Secret) {},
)
bobPersistence, err := NewSQLLitePersistence(bobDBPath, bobDBKey) bobPersistence, err := NewSQLLitePersistence(bobDBPath, bobDBKey)
if err != nil { if err != nil {
panic(err) panic(err)
} }
baseConfig.InstallationID = aliceInstallationID bobMultideviceConfig := &multidevice.Config{
s.alice = NewEncryptionService(alicePersistence, *baseConfig) MaxInstallations: 3,
InstallationID: bobInstallationID,
ProtocolVersion: 1,
}
bobMultideviceService := multidevice.New(bobMultideviceConfig, bobPersistence.GetMultideviceStorage())
bobSharedSecretService := sharedsecret.NewService(bobPersistence.GetSharedSecretStorage())
baseConfig.InstallationID = bobInstallationID baseConfig.InstallationID = bobInstallationID
s.bob = NewEncryptionService(bobPersistence, *baseConfig) bobEncryptionService := NewEncryptionService(bobPersistence, *baseConfig)
s.bob = NewProtocolService(
bobEncryptionService,
bobSharedSecretService,
bobMultideviceService,
func(s []multidevice.IdentityAndIDPair) {},
func(s []*sharedsecret.Secret) {},
)
} }
func (s *EncryptionServiceTestSuite) SetupTest() { func (s *EncryptionServiceTestSuite) SetupTest() {
@ -82,14 +121,14 @@ func (s *EncryptionServiceTestSuite) TearDownTest() {
os.Remove(s.bobDBPath) os.Remove(s.bobDBPath)
} }
func (s *EncryptionServiceTestSuite) TestCreateBundle() { func (s *EncryptionServiceTestSuite) TestGetBundle() {
aliceKey, err := crypto.GenerateKey() aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err) s.Require().NoError(err)
aliceBundle1, err := s.alice.CreateBundle(aliceKey) aliceBundle1, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
s.NotNil(aliceBundle1, "It creates a bundle") s.NotNil(aliceBundle1, "It creates a bundle")
aliceBundle2, err := s.alice.CreateBundle(aliceKey) aliceBundle2, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(aliceBundle1, aliceBundle2, "It returns the same bundle") s.Equal(aliceBundle1, aliceBundle2, "It returns the same bundle")
} }
@ -105,9 +144,11 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() {
aliceKey, err := crypto.GenerateKey() aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err) s.Require().NoError(err)
encryptionResponse1, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext) response1, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext)
s.Require().NoError(err) s.Require().NoError(err)
encryptionResponse1 := response1.Message.GetDirectMessage()
installationResponse1 := encryptionResponse1["none"] installationResponse1 := encryptionResponse1["none"]
// That's for any device // That's for any device
s.Require().NotNil(installationResponse1) s.Require().NotNil(installationResponse1)
@ -119,14 +160,16 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() {
s.NotEqual(cyphertext1, cleartext, "It encrypts the payload correctly") s.NotEqual(cyphertext1, cleartext, "It encrypts the payload correctly")
// On the receiver side, we should be able to decrypt using our private key and the ephemeral just sent // On the receiver side, we should be able to decrypt using our private key and the ephemeral just sent
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID) decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using DH") s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using DH")
// The next message will not be re-using the same key // The next message will not be re-using the same key
encryptionResponse2, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext) response2, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext)
s.Require().NoError(err) s.Require().NoError(err)
encryptionResponse2 := response2.Message.GetDirectMessage()
installationResponse2 := encryptionResponse2[aliceInstallationID] installationResponse2 := encryptionResponse2[aliceInstallationID]
cyphertext2 := installationResponse2.GetPayload() cyphertext2 := installationResponse2.GetPayload()
@ -134,7 +177,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() {
s.NotEqual(cyphertext1, cyphertext2, "It does not re-use the symmetric key") s.NotEqual(cyphertext1, cyphertext2, "It does not re-use the symmetric key")
s.NotEqual(ephemeralKey1, ephemeralKey2, "It does not re-use the ephemeral key") s.NotEqual(ephemeralKey1, ephemeralKey2, "It does not re-use the ephemeral key")
decryptedPayload2, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse2, defaultMessageID) decryptedPayload2, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response2.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext, decryptedPayload2, "It correctly decrypts the payload using DH") s.Equal(cleartext, decryptedPayload2, "It correctly decrypts the payload using DH")
} }
@ -150,7 +193,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadBundle() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -158,9 +201,11 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadBundle() {
s.Require().NoError(err) s.Require().NoError(err)
// We send a message using the bundle // We send a message using the bundle
encryptionResponse1, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext) response1, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext)
s.Require().NoError(err) s.Require().NoError(err)
encryptionResponse1 := response1.Message.GetDirectMessage()
installationResponse1 := encryptionResponse1[bobInstallationID] installationResponse1 := encryptionResponse1[bobInstallationID]
s.Require().NotNil(installationResponse1) s.Require().NotNil(installationResponse1)
@ -186,7 +231,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadBundle() {
s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain") s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain")
// Bob is able to decrypt it using the bundle // Bob is able to decrypt it using the bundle
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID) decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using X3DH") s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using X3DH")
} }
@ -209,7 +254,7 @@ func (s *EncryptionServiceTestSuite) TestConsequentMessagesBundle() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -217,12 +262,13 @@ func (s *EncryptionServiceTestSuite) TestConsequentMessagesBundle() {
s.Require().NoError(err) s.Require().NoError(err)
// We send a message using the bundle // We send a message using the bundle
_, err = s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext1) _, err = s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext1)
s.Require().NoError(err) s.Require().NoError(err)
// We send another message using the bundle // We send another message using the bundle
encryptionResponse, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext2) response, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext2)
s.Require().NoError(err) s.Require().NoError(err)
encryptionResponse := response.Message.GetDirectMessage()
installationResponse := encryptionResponse[bobInstallationID] installationResponse := encryptionResponse[bobInstallationID]
s.Require().NotNil(installationResponse) s.Require().NotNil(installationResponse)
@ -250,7 +296,7 @@ func (s *EncryptionServiceTestSuite) TestConsequentMessagesBundle() {
s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain") s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain")
// Bob is able to decrypt it using the bundle // Bob is able to decrypt it using the bundle
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID) decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH") s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
@ -274,11 +320,11 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey) aliceBundle, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -290,24 +336,25 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
s.Require().NoError(err) s.Require().NoError(err)
// Alice sends a message // Alice sends a message
encryptionResponse, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext1) response, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext1)
s.Require().NoError(err) s.Require().NoError(err)
// Bob receives the message // Bob receives the message
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID) _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
// Bob replies to the message // Bob replies to the message
encryptionResponse, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, cleartext1) response, err = s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, cleartext1)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives the message // Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, encryptionResponse, defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, response.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
// We send another message using the bundle // We send another message using the bundle
encryptionResponse, err = s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, cleartext2) response, err = s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext2)
s.Require().NoError(err) s.Require().NoError(err)
encryptionResponse := response.Message.GetDirectMessage()
installationResponse := encryptionResponse[bobInstallationID] installationResponse := encryptionResponse[bobInstallationID]
s.Require().NotNil(installationResponse) s.Require().NotNil(installationResponse)
@ -333,7 +380,7 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
s.Equal(uint32(1), drHeader.GetPn(), "It adds the correct length of the message chain") s.Equal(uint32(1), drHeader.GetPn(), "It adds the correct length of the message chain")
// Bob is able to decrypt it using the bundle // Bob is able to decrypt it using the bundle
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID) decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH") s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
@ -354,7 +401,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -362,7 +409,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey) aliceBundle, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add alice bundle // We add alice bundle
@ -371,30 +418,30 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
// Bob sends a message // Bob sends a message
for i := 0; i < s.alice.config.MaxSkip; i++ { for i := 0; i < s.alice.encryption.config.MaxSkip; i++ {
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) _, err = s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText)
s.Require().NoError(err) s.Require().NoError(err)
} }
// Bob sends a message // Bob sends a message
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) bobMessage1, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives the message // Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
// Bob sends a message // Bob sends a message
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) _, err = s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText)
s.Require().NoError(err) s.Require().NoError(err)
// Bob sends a message // Bob sends a message
bobMessage2, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) bobMessage2, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives the message, we should have maxSkip + 1 keys in the db, but // Alice receives the message, we should have maxSkip + 1 keys in the db, but
// we should not throw an error // we should not throw an error
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
} }
@ -409,7 +456,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -417,7 +464,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey) aliceBundle, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add alice bundle // We add alice bundle
@ -426,17 +473,17 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() {
// Bob sends a message // Bob sends a message
for i := 0; i < s.alice.config.MaxSkip+1; i++ { for i := 0; i < s.alice.encryption.config.MaxSkip+1; i++ {
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) _, err = s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText)
s.Require().NoError(err) s.Require().NoError(err)
} }
// Bob sends a message // Bob sends a message
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) bobMessage1, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives the message // Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, defaultMessageID)
s.Require().Equal(errors.New("can't skip current chain message keys: too many messages"), err) s.Require().Equal(errors.New("can't skip current chain message keys: too many messages"), err)
} }
@ -457,7 +504,7 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -465,7 +512,7 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey) aliceBundle, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add alice bundle // We add alice bundle
@ -474,27 +521,27 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
// We create just enough messages so that the first key should be deleted // We create just enough messages so that the first key should be deleted
nMessages := s.alice.config.MaxMessageKeysPerSession nMessages := s.alice.encryption.config.MaxMessageKeysPerSession
messages := make([]map[string]*DirectMessageProtocol, nMessages) messages := make([]*protobuf.ProtocolMessage, nMessages)
for i := 0; i < nMessages; i++ { for i := 0; i < nMessages; i++ {
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) m, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText)
s.Require().NoError(err) s.Require().NoError(err)
messages[i] = m messages[i] = m.Message
} }
// Another message to trigger the deletion // Another message to trigger the deletion
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) m, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText)
s.Require().NoError(err) s.Require().NoError(err)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m, defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, m.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
// We decrypt the first message, and it should fail // We decrypt the first message, and it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0], defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, messages[0], defaultMessageID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// We decrypt the second message, and it should be decrypted // We decrypt the second message, and it should be decrypted
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1], defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, messages[1], defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
} }
@ -514,7 +561,7 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -522,7 +569,7 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey) aliceBundle, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add alice bundle // We add alice bundle
@ -530,15 +577,15 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() {
s.Require().NoError(err) s.Require().NoError(err)
// We decrypt all messages but 1 & 2 // We decrypt all messages but 1 & 2
messages := make([]map[string]*DirectMessageProtocol, s.alice.config.MaxKeep) messages := make([]*protobuf.ProtocolMessage, s.alice.encryption.config.MaxKeep)
for i := 0; i < s.alice.config.MaxKeep; i++ { for i := 0; i < s.alice.encryption.config.MaxKeep; i++ {
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) m, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText)
messages[i] = m messages[i] = m.Message
s.Require().NoError(err) s.Require().NoError(err)
if i != 0 && i != 1 { if i != 0 && i != 1 {
messageID := []byte(fmt.Sprintf("%d", i)) messageID := []byte(fmt.Sprintf("%d", i))
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m, messageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, m.Message, messageID)
s.Require().NoError(err) s.Require().NoError(err)
err = s.alice.ConfirmMessagesProcessed([][]byte{messageID}) err = s.alice.ConfirmMessagesProcessed([][]byte{messageID})
s.Require().NoError(err) s.Require().NoError(err)
@ -547,11 +594,11 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() {
} }
// We decrypt the first message, and it should fail, as it should have been removed // We decrypt the first message, and it should fail, as it should have been removed
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0], defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, messages[0], defaultMessageID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// We decrypt the second message, and it should be decrypted // We decrypt the second message, and it should be decrypted
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1], defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, messages[1], defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
} }
@ -576,7 +623,7 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -584,7 +631,7 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey) aliceBundle, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add alice bundle // We add alice bundle
@ -592,44 +639,44 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() {
s.Require().NoError(err) s.Require().NoError(err)
// Alice sends a message // Alice sends a message
aliceMessage1, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, aliceText1) aliceMessage1, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, aliceText1)
s.Require().NoError(err) s.Require().NoError(err)
// Bob sends a message // Bob sends a message
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1) bobMessage1, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText1)
s.Require().NoError(err) s.Require().NoError(err)
// Bob receives the message // Bob receives the message
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage1, defaultMessageID) _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, aliceMessage1.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives the message // Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
// Bob replies to the message // Bob replies to the message
bobMessage2, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText2) bobMessage2, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText2)
s.Require().NoError(err) s.Require().NoError(err)
// Alice sends a message // Alice sends a message
aliceMessage2, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, aliceText2) aliceMessage2, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, aliceText2)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives the message // Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, defaultMessageID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
// Bob receives the message // Bob receives the message
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage2, defaultMessageID) _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, aliceMessage2.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
} }
func publisher( func publisher(
e *EncryptionService, e *ProtocolService,
privateKey *ecdsa.PrivateKey, privateKey *ecdsa.PrivateKey,
publicKey *ecdsa.PublicKey, publicKey *ecdsa.PublicKey,
errChan chan error, errChan chan error,
output chan map[string]*DirectMessageProtocol, output chan *protobuf.ProtocolMessage,
) { ) {
var wg sync.WaitGroup var wg sync.WaitGroup
@ -642,13 +689,13 @@ func publisher(
go func() { go func() {
defer wg.Done() defer wg.Done()
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
response, err := e.EncryptPayload(publicKey, privateKey, cleartext) response, err := e.BuildDirectMessage(privateKey, publicKey, cleartext)
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
} }
output <- response output <- response.Message
}() }()
} }
} }
@ -658,17 +705,16 @@ func publisher(
} }
func receiver( func receiver(
s *EncryptionService, s *ProtocolService,
privateKey *ecdsa.PrivateKey, privateKey *ecdsa.PrivateKey,
publicKey *ecdsa.PublicKey, publicKey *ecdsa.PublicKey,
installationID string,
errChan chan error, errChan chan error,
input chan map[string]*DirectMessageProtocol, input chan *protobuf.ProtocolMessage,
) { ) {
i := 0 i := 0
for payload := range input { for payload := range input {
actualCleartext, err := s.DecryptPayload(privateKey, publicKey, installationID, payload, defaultMessageID) actualCleartext, err := s.HandleMessage(privateKey, publicKey, payload, defaultMessageID)
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
@ -697,7 +743,7 @@ func (s *EncryptionServiceTestSuite) TestRandomised() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -705,15 +751,15 @@ func (s *EncryptionServiceTestSuite) TestRandomised() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey) aliceBundle, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add alice bundle // We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle) _, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err) s.Require().NoError(err)
aliceChan := make(chan map[string]*DirectMessageProtocol, 100) aliceChan := make(chan *protobuf.ProtocolMessage, 100)
bobChan := make(chan map[string]*DirectMessageProtocol, 100) bobChan := make(chan *protobuf.ProtocolMessage, 100)
alicePublisherErrChan := make(chan error, 1) alicePublisherErrChan := make(chan error, 1)
bobPublisherErrChan := make(chan error, 1) bobPublisherErrChan := make(chan error, 1)
@ -727,10 +773,10 @@ func (s *EncryptionServiceTestSuite) TestRandomised() {
go publisher(s.bob, bobKey, &aliceKey.PublicKey, bobPublisherErrChan, aliceChan) go publisher(s.bob, bobKey, &aliceKey.PublicKey, bobPublisherErrChan, aliceChan)
// Set up bob receiver // Set up bob receiver
go receiver(s.bob, bobKey, &aliceKey.PublicKey, aliceInstallationID, bobReceiverErrChan, bobChan) go receiver(s.bob, bobKey, &aliceKey.PublicKey, bobReceiverErrChan, bobChan)
// Set up alice receiver // Set up alice receiver
go receiver(s.alice, aliceKey, &bobKey.PublicKey, bobInstallationID, aliceReceiverErrChan, aliceChan) go receiver(s.alice, aliceKey, &bobKey.PublicKey, aliceReceiverErrChan, aliceChan)
aliceErr := <-alicePublisherErrChan aliceErr := <-alicePublisherErrChan
s.Require().NoError(aliceErr) s.Require().NoError(aliceErr)
@ -771,11 +817,11 @@ func (s *EncryptionServiceTestSuite) TestBundleNotExisting() {
s.Require().NoError(err) s.Require().NoError(err)
// Alice sends a message // Alice sends a message
aliceMessage, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, aliceText) aliceMessage, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, aliceText)
s.Require().NoError(err) s.Require().NoError(err)
// Bob receives the message, and returns a bundlenotfound error // Bob receives the message, and returns a bundlenotfound error
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage, defaultMessageID) _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, aliceMessage.Message, defaultMessageID)
s.Require().Error(err) s.Require().Error(err)
s.Equal(ErrSessionNotFound, err) s.Equal(ErrSessionNotFound, err)
} }
@ -804,11 +850,11 @@ func (s *EncryptionServiceTestSuite) TestDeviceNotIncluded() {
s.Require().NoError(err) s.Require().NoError(err)
// Alice sends a message // Alice sends a message
aliceMessage, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, []byte("does not matter")) aliceMessage, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, []byte("does not matter"))
s.Require().NoError(err) s.Require().NoError(err)
// Bob receives the message, and returns a bundlenotfound error // Bob receives the message, and returns a bundlenotfound error
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage, defaultMessageID) _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, aliceMessage.Message, defaultMessageID)
s.Require().Error(err) s.Require().Error(err)
s.Equal(ErrDeviceNotFound, err) s.Equal(ErrDeviceNotFound, err)
} }
@ -829,7 +875,7 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
s.Require().NoError(err) s.Require().NoError(err)
// Create bundles // Create bundles
bobBundle1, err := s.bob.CreateBundle(bobKey) bobBundle1, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(uint32(1), bobBundle1.GetSignedPreKeys()[bobInstallationID].GetVersion()) s.Require().Equal(uint32(1), bobBundle1.GetSignedPreKeys()[bobInstallationID].GetVersion())
@ -837,7 +883,7 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
time.Sleep(time.Duration(config.BundleRefreshInterval) * time.Millisecond) time.Sleep(time.Duration(config.BundleRefreshInterval) * time.Millisecond)
// Create bundles // Create bundles
bobBundle2, err := s.bob.CreateBundle(bobKey) bobBundle2, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(uint32(2), bobBundle2.GetSignedPreKeys()[bobInstallationID].GetVersion()) s.Require().Equal(uint32(2), bobBundle2.GetSignedPreKeys()[bobInstallationID].GetVersion())
@ -846,8 +892,9 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
s.Require().NoError(err) s.Require().NoError(err)
// Alice sends a message // Alice sends a message
encryptionResponse1, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, []byte("anything")) response1, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, []byte("anything"))
s.Require().NoError(err) s.Require().NoError(err)
encryptionResponse1 := response1.Message.GetDirectMessage()
installationResponse1 := encryptionResponse1[bobInstallationID] installationResponse1 := encryptionResponse1[bobInstallationID]
s.Require().NotNil(installationResponse1) s.Require().NotNil(installationResponse1)
@ -859,7 +906,7 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
s.Equal(bobBundle1.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader1.GetId()) s.Equal(bobBundle1.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader1.GetId())
// Bob decrypts the message // Bob decrypts the message
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID) _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
// We add the second bob bundle // We add the second bob bundle
@ -867,8 +914,9 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
s.Require().NoError(err) s.Require().NoError(err)
// Alice sends a message // Alice sends a message
encryptionResponse2, err := s.alice.EncryptPayload(&bobKey.PublicKey, aliceKey, []byte("anything")) response2, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, []byte("anything"))
s.Require().NoError(err) s.Require().NoError(err)
encryptionResponse2 := response2.Message.GetDirectMessage()
installationResponse2 := encryptionResponse2[bobInstallationID] installationResponse2 := encryptionResponse2[bobInstallationID]
s.Require().NotNil(installationResponse2) s.Require().NotNil(installationResponse2)
@ -880,7 +928,7 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
s.Equal(bobBundle2.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId()) s.Equal(bobBundle2.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId())
// Bob decrypts the message // Bob decrypts the message
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse2, defaultMessageID) _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response2.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
} }
@ -894,7 +942,7 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey) bobBundle, err := s.bob.GetBundle(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add bob bundle // We add bob bundle
@ -902,7 +950,7 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() {
s.Require().NoError(err) s.Require().NoError(err)
// Create a bundle // Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey) aliceBundle, err := s.alice.GetBundle(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
// We add alice bundle // We add alice bundle
@ -910,16 +958,16 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() {
s.Require().NoError(err) s.Require().NoError(err)
// Bob sends a message // Bob sends a message
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1) bobMessage1, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText1)
s.Require().NoError(err) s.Require().NoError(err)
bobMessage1ID := []byte("bob-message-1-id") bobMessage1ID := []byte("bob-message-1-id")
// Alice receives the message once // Alice receives the message once
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, bobMessage1ID)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives the message twice // Alice receives the message twice
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, bobMessage1ID)
s.Require().NoError(err) s.Require().NoError(err)
// Alice confirms the message // Alice confirms the message
@ -927,33 +975,33 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() {
s.Require().NoError(err) s.Require().NoError(err)
// Alice decrypts it again, it should fail // Alice decrypts it again, it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage1.Message, bobMessage1ID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// Bob sends a message // Bob sends a message
bobMessage2, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1) bobMessage2, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText1)
s.Require().NoError(err) s.Require().NoError(err)
bobMessage2ID := []byte("bob-message-2-id") bobMessage2ID := []byte("bob-message-2-id")
// Bob sends a message // Bob sends a message
bobMessage3, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1) bobMessage3, err := s.bob.BuildDirectMessage(bobKey, &aliceKey.PublicKey, bobText1)
s.Require().NoError(err) s.Require().NoError(err)
bobMessage3ID := []byte("bob-message-3-id") bobMessage3ID := []byte("bob-message-3-id")
// Alice receives message 3 once // Alice receives message 3 once
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage3.Message, bobMessage3ID)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives message 3 twice // Alice receives message 3 twice
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage3.Message, bobMessage3ID)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives message 2 once // Alice receives message 2 once
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, bobMessage2ID)
s.Require().NoError(err) s.Require().NoError(err)
// Alice receives message 2 twice // Alice receives message 2 twice
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, bobMessage2ID)
s.Require().NoError(err) s.Require().NoError(err)
// Alice confirms the messages // Alice confirms the messages
@ -961,10 +1009,10 @@ func (s *EncryptionServiceTestSuite) TestMessageConfirmation() {
s.Require().NoError(err) s.Require().NoError(err)
// Alice decrypts it again, it should fail // Alice decrypts it again, it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage3.Message, bobMessage3ID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// Alice decrypts it again, it should fail // Alice decrypts it again, it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID) _, err = s.alice.HandleMessage(aliceKey, &bobKey.PublicKey, bobMessage2.Message, bobMessage2ID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
} }

View File

@ -0,0 +1,12 @@
package multidevice
type Persistence interface {
// GetActiveInstallations returns the active installations for a given identity.
GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error)
// EnableInstallation enables the installation.
EnableInstallation(identity []byte, installationID string) error
// DisableInstallation disable the installation.
DisableInstallation(identity []byte, installationID string) error
// AddInstallations adds the installations for a given identity, maintaining the enabled flag
AddInstallations(identity []byte, timestamp int64, installations []*Installation, defaultEnabled bool) error
}

View File

@ -0,0 +1,93 @@
package multidevice
import (
"crypto/ecdsa"
"fmt"
"github.com/ethereum/go-ethereum/crypto"
"github.com/status-im/status-go/services/shhext/chat/protobuf"
)
type Installation struct {
ID string
Version uint32
}
type Config struct {
MaxInstallations int
ProtocolVersion uint32
InstallationID string
}
func New(config *Config, persistence Persistence) *Service {
return &Service{
config: config,
persistence: persistence,
}
}
type Service struct {
persistence Persistence
config *Config
}
type IdentityAndIDPair [2]string
func (s *Service) GetActiveInstallations(identity *ecdsa.PublicKey) ([]*Installation, error) {
identityC := crypto.CompressPubkey(identity)
return s.persistence.GetActiveInstallations(s.config.MaxInstallations, identityC)
}
func (s *Service) GetOurActiveInstallations(identity *ecdsa.PublicKey) ([]*Installation, error) {
identityC := crypto.CompressPubkey(identity)
installations, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations-1, identityC)
if err != nil {
return nil, err
}
// Move to layer above
installations = append(installations, &Installation{
ID: s.config.InstallationID,
Version: s.config.ProtocolVersion,
})
return installations, nil
}
func (s *Service) EnableInstallation(identity *ecdsa.PublicKey, installationID string) error {
identityC := crypto.CompressPubkey(identity)
return s.persistence.EnableInstallation(identityC, installationID)
}
func (s *Service) DisableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error {
myIdentityKeyC := crypto.CompressPubkey(myIdentityKey)
return s.persistence.DisableInstallation(myIdentityKeyC, installationID)
}
// ProcessPublicBundle persists a bundle and returns a list of tuples identity/installationID
func (s *Service) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, theirIdentity *ecdsa.PublicKey, b *protobuf.Bundle) ([]IdentityAndIDPair, error) {
signedPreKeys := b.GetSignedPreKeys()
var response []IdentityAndIDPair
var installations []*Installation
myIdentityStr := fmt.Sprintf("0x%x", crypto.FromECDSAPub(&myIdentityKey.PublicKey))
theirIdentityStr := fmt.Sprintf("0x%x", crypto.FromECDSAPub(theirIdentity))
// Any device from other peers will be considered enabled, ours needs to
// be explicitly enabled
fromOurIdentity := theirIdentityStr != myIdentityStr
for installationID, signedPreKey := range signedPreKeys {
if installationID != s.config.InstallationID {
installations = append(installations, &Installation{
ID: installationID,
Version: signedPreKey.GetProtocolVersion(),
})
response = append(response, IdentityAndIDPair{theirIdentityStr, installationID})
}
}
if err := s.persistence.AddInstallations(b.GetIdentity(), b.GetTimestamp(), installations, fromOurIdentity); err != nil {
return nil, err
}
return response, nil
}

View File

@ -0,0 +1,168 @@
package multidevice
import (
"database/sql"
)
// SQLLitePersistence represents a persistence service tied to an SQLite database
type SQLLitePersistence struct {
db *sql.DB
}
// NewSQLLitePersistence creates a new SQLLitePersistence instance, given a path and a key
func NewSQLLitePersistence(db *sql.DB) *SQLLitePersistence {
return &SQLLitePersistence{db: db}
}
// GetActiveInstallations returns the active installations for a given identity
func (s *SQLLitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error) {
stmt, err := s.db.Prepare(`SELECT installation_id, version
FROM installations
WHERE enabled = 1 AND identity = ?
ORDER BY timestamp DESC
LIMIT ?`)
if err != nil {
return nil, err
}
var installations []*Installation
rows, err := stmt.Query(identity, maxInstallations)
if err != nil {
return nil, err
}
for rows.Next() {
var installationID string
var version uint32
err = rows.Scan(
&installationID,
&version,
)
if err != nil {
return nil, err
}
installations = append(installations, &Installation{
ID: installationID,
Version: version,
})
}
return installations, nil
}
// AddInstallations adds the installations for a given identity, maintaining the enabled flag
func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64, installations []*Installation, defaultEnabled bool) error {
tx, err := s.db.Begin()
if err != nil {
return nil
}
for _, installation := range installations {
stmt, err := tx.Prepare(`SELECT enabled, version
FROM installations
WHERE identity = ? AND installation_id = ?
LIMIT 1`)
if err != nil {
return err
}
defer stmt.Close()
var oldEnabled bool
// We don't override version once we saw one
var oldVersion uint32
latestVersion := installation.Version
err = stmt.QueryRow(identity, installation.ID).Scan(&oldEnabled, &oldVersion)
if err != nil && err != sql.ErrNoRows {
return err
}
// We update timestamp if present without changing enabled, only if this is a new bundle
// and we set the version to the latest we ever saw
if err != sql.ErrNoRows {
if oldVersion > installation.Version {
latestVersion = oldVersion
}
stmt, err = tx.Prepare(`UPDATE installations
SET timestamp = ?, enabled = ?, version = ?
WHERE identity = ?
AND installation_id = ?
AND timestamp < ?`)
if err != nil {
return err
}
_, err = stmt.Exec(
timestamp,
oldEnabled,
latestVersion,
identity,
installation.ID,
timestamp,
)
if err != nil {
return err
}
defer stmt.Close()
} else {
stmt, err = tx.Prepare(`INSERT INTO installations(identity, installation_id, timestamp, enabled, version)
VALUES (?, ?, ?, ?, ?)`)
if err != nil {
return err
}
_, err = stmt.Exec(
identity,
installation.ID,
timestamp,
defaultEnabled,
latestVersion,
)
if err != nil {
return err
}
defer stmt.Close()
}
}
if err := tx.Commit(); err != nil {
_ = tx.Rollback()
return err
}
return nil
}
// EnableInstallation enables the installation
func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID string) error {
stmt, err := s.db.Prepare(`UPDATE installations
SET enabled = 1
WHERE identity = ? AND installation_id = ?`)
if err != nil {
return err
}
_, err = stmt.Exec(identity, installationID)
return err
}
// DisableInstallation disable the installation
func (s *SQLLitePersistence) DisableInstallation(identity []byte, installationID string) error {
stmt, err := s.db.Prepare(`UPDATE installations
SET enabled = 0
WHERE identity = ? AND installation_id = ?`)
if err != nil {
return err
}
_, err = stmt.Exec(identity, installationID)
return err
}

View File

@ -0,0 +1,241 @@
package multidevice
import (
"database/sql"
"os"
"testing"
appDB "github.com/status-im/status-go/services/shhext/chat/db"
"github.com/stretchr/testify/suite"
)
const (
dbPath = "/tmp/status-key-store.db"
)
func TestSQLLitePersistenceTestSuite(t *testing.T) {
suite.Run(t, new(SQLLitePersistenceTestSuite))
}
type SQLLitePersistenceTestSuite struct {
suite.Suite
// nolint: structcheck, megacheck
db *sql.DB
service Persistence
}
func (s *SQLLitePersistenceTestSuite) SetupTest() {
os.Remove(dbPath)
db, err := appDB.Open(dbPath, "", 0)
s.Require().NoError(err)
s.service = NewSQLLitePersistence(db)
}
func (s *SQLLitePersistenceTestSuite) TestAddInstallations() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
enabledInstallations, err := s.service.GetActiveInstallations(5, identity)
s.Require().NoError(err)
s.Require().Equal(installations, enabledInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestAddInstallationVersions() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
enabledInstallations, err := s.service.GetActiveInstallations(5, identity)
s.Require().NoError(err)
s.Require().Equal(installations, enabledInstallations)
installationsWithDowngradedVersion := []*Installation{
{ID: "alice-1", Version: 0},
}
err = s.service.AddInstallations(
identity,
3,
installationsWithDowngradedVersion,
true,
)
s.Require().NoError(err)
enabledInstallations, err = s.service.GetActiveInstallations(5, identity)
s.Require().NoError(err)
s.Require().Equal(installations, enabledInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestAddInstallationsLimit() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
installations = []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-3", Version: 3},
}
err = s.service.AddInstallations(
identity,
2,
installations,
true,
)
s.Require().NoError(err)
installations = []*Installation{
{ID: "alice-2", Version: 2},
{ID: "alice-3", Version: 3},
{ID: "alice-4", Version: 4},
}
err = s.service.AddInstallations(
identity,
3,
installations,
true,
)
s.Require().NoError(err)
enabledInstallations, err := s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
s.Require().Equal(installations, enabledInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestAddInstallationsDisabled() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
false,
)
s.Require().NoError(err)
actualInstallations, err := s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
s.Require().Nil(actualInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestDisableInstallation() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
err = s.service.DisableInstallation(identity, "alice-1")
s.Require().NoError(err)
// We add the installations again
installations = []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err = s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
actualInstallations, err := s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
expected := []*Installation{{ID: "alice-2", Version: 2}}
s.Require().Equal(expected, actualInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestEnableInstallation() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
err = s.service.DisableInstallation(identity, "alice-1")
s.Require().NoError(err)
actualInstallations, err := s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
expected := []*Installation{{ID: "alice-2", Version: 2}}
s.Require().Equal(expected, actualInstallations)
err = s.service.EnableInstallation(identity, "alice-1")
s.Require().NoError(err)
actualInstallations, err = s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
expected = []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
s.Require().Equal(expected, actualInstallations)
}

View File

@ -4,13 +4,10 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
dr "github.com/status-im/doubleratchet" dr "github.com/status-im/doubleratchet"
"github.com/status-im/status-go/services/shhext/chat/multidevice"
"github.com/status-im/status-go/services/shhext/chat/protobuf"
) )
type Installation struct {
ID string
Version uint32
}
// RatchetInfo holds the current ratchet state // RatchetInfo holds the current ratchet state
type RatchetInfo struct { type RatchetInfo struct {
ID []byte ID []byte
@ -30,17 +27,17 @@ type PersistenceService interface {
// GetSessionStorage returns the associated double ratchet SessionStorage object. // GetSessionStorage returns the associated double ratchet SessionStorage object.
GetSessionStorage() dr.SessionStorage GetSessionStorage() dr.SessionStorage
// GetPublicBundle retrieves an existing Bundle for the specified public key & installationIDs. // GetPublicBundle retrieves an existing Bundle for the specified public key & installations
GetPublicBundle(*ecdsa.PublicKey, []*Installation) (*Bundle, error) GetPublicBundle(*ecdsa.PublicKey, []*multidevice.Installation) (*protobuf.Bundle, error)
// AddPublicBundle persists a specified Bundle // AddPublicBundle persists a specified Bundle
AddPublicBundle(*Bundle) error AddPublicBundle(*protobuf.Bundle) error
// GetAnyPrivateBundle retrieves any bundle for our identity & installationIDs // GetAnyPrivateBundle retrieves any bundle for our identity & installations
GetAnyPrivateBundle([]byte, []*Installation) (*BundleContainer, error) GetAnyPrivateBundle([]byte, []*multidevice.Installation) (*protobuf.BundleContainer, error)
// GetPrivateKeyBundle retrieves a BundleContainer with the specified signed prekey. // GetPrivateKeyBundle retrieves a BundleContainer with the specified signed prekey.
GetPrivateKeyBundle([]byte) ([]byte, error) GetPrivateKeyBundle([]byte) ([]byte, error)
// AddPrivateBundle persists a BundleContainer. // AddPrivateBundle persists a BundleContainer.
AddPrivateBundle(*BundleContainer) error AddPrivateBundle(*protobuf.BundleContainer) error
// MarkBundleExpired marks a private bundle as expired, not to be used for encryption anymore. // MarkBundleExpired marks a private bundle as expired, not to be used for encryption anymore.
MarkBundleExpired([]byte) error MarkBundleExpired([]byte) error
@ -53,13 +50,4 @@ type PersistenceService interface {
// RatchetInfoConfirmed clears the ephemeral key in the RatchetInfo // RatchetInfoConfirmed clears the ephemeral key in the RatchetInfo
// associated with the specified bundle ID and interlocutor identity public key. // associated with the specified bundle ID and interlocutor identity public key.
RatchetInfoConfirmed([]byte, []byte, string) error RatchetInfoConfirmed([]byte, []byte, string) error
// GetActiveInstallations returns the active installations for a given identity.
GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error)
// AddInstallations adds the installations for a given identity.
AddInstallations(identity []byte, timestamp int64, installations []*Installation, enabled bool) error
// EnableInstallation enables the installation.
EnableInstallation(identity []byte, installationID string) error
// DisableInstallation disable the installation.
DisableInstallation(identity []byte, installationID string) error
} }

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// source: encryption.proto // source: encryption.proto
package chat package protobuf
import ( import (
fmt "fmt" fmt "fmt"
@ -491,56 +491,56 @@ func (m *ProtocolMessage) GetPublicMessage() []byte {
} }
func init() { func init() {
proto.RegisterType((*SignedPreKey)(nil), "chat.SignedPreKey") proto.RegisterType((*SignedPreKey)(nil), "protobuf.SignedPreKey")
proto.RegisterType((*Bundle)(nil), "chat.Bundle") proto.RegisterType((*Bundle)(nil), "protobuf.Bundle")
proto.RegisterMapType((map[string]*SignedPreKey)(nil), "chat.Bundle.SignedPreKeysEntry") proto.RegisterMapType((map[string]*SignedPreKey)(nil), "protobuf.Bundle.SignedPreKeysEntry")
proto.RegisterType((*BundleContainer)(nil), "chat.BundleContainer") proto.RegisterType((*BundleContainer)(nil), "protobuf.BundleContainer")
proto.RegisterType((*DRHeader)(nil), "chat.DRHeader") proto.RegisterType((*DRHeader)(nil), "protobuf.DRHeader")
proto.RegisterType((*DHHeader)(nil), "chat.DHHeader") proto.RegisterType((*DHHeader)(nil), "protobuf.DHHeader")
proto.RegisterType((*X3DHHeader)(nil), "chat.X3DHHeader") proto.RegisterType((*X3DHHeader)(nil), "protobuf.X3DHHeader")
proto.RegisterType((*DirectMessageProtocol)(nil), "chat.DirectMessageProtocol") proto.RegisterType((*DirectMessageProtocol)(nil), "protobuf.DirectMessageProtocol")
proto.RegisterType((*ProtocolMessage)(nil), "chat.ProtocolMessage") proto.RegisterType((*ProtocolMessage)(nil), "protobuf.ProtocolMessage")
proto.RegisterMapType((map[string]*DirectMessageProtocol)(nil), "chat.ProtocolMessage.DirectMessageEntry") proto.RegisterMapType((map[string]*DirectMessageProtocol)(nil), "protobuf.ProtocolMessage.DirectMessageEntry")
} }
func init() { proto.RegisterFile("encryption.proto", fileDescriptor_8293a649ce9418c6) } func init() { proto.RegisterFile("encryption.proto", fileDescriptor_8293a649ce9418c6) }
var fileDescriptor_8293a649ce9418c6 = []byte{ var fileDescriptor_8293a649ce9418c6 = []byte{
// 562 bytes of a gzipped FileDescriptorProto // 566 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x54, 0x61, 0x8b, 0xd3, 0x4c, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x53, 0xdd, 0x6a, 0xdb, 0x4c,
0x10, 0x26, 0x49, 0xef, 0xda, 0x4e, 0xd3, 0xb4, 0xec, 0xcb, 0x2b, 0xa1, 0x1e, 0x58, 0xc2, 0xa9, 0x10, 0x45, 0x52, 0xe2, 0x9f, 0xb1, 0xfc, 0xc3, 0x7e, 0x5f, 0x83, 0x30, 0x81, 0x1a, 0xb5, 0xa5,
0x11, 0xa1, 0x70, 0xad, 0x1f, 0xc4, 0x8f, 0x5a, 0xb1, 0x9e, 0x88, 0xc7, 0x2a, 0xe2, 0x17, 0x09, 0x6e, 0x09, 0x2e, 0xd8, 0x0d, 0x94, 0x5e, 0xb6, 0x2e, 0xb8, 0x09, 0x85, 0xb0, 0x81, 0x92, 0x3b,
0xdb, 0x66, 0xbd, 0x5b, 0x4c, 0x93, 0xb0, 0xbb, 0x2d, 0xe4, 0xcf, 0xf9, 0x57, 0xfc, 0x29, 0x4a, 0xb1, 0xb6, 0x36, 0xe9, 0x52, 0x79, 0x25, 0x76, 0xd7, 0x06, 0x3d, 0x41, 0xdf, 0xad, 0x2f, 0xd3,
0x76, 0xb3, 0xed, 0xb6, 0x77, 0x07, 0x7e, 0xeb, 0xcc, 0x3c, 0xfb, 0xcc, 0x33, 0xcf, 0x74, 0x02, 0x57, 0x28, 0x5a, 0x69, 0xad, 0xb5, 0x9d, 0x5c, 0xf4, 0xca, 0x9e, 0xb3, 0x73, 0xce, 0xcc, 0x9c,
0x43, 0x9a, 0xaf, 0x78, 0x55, 0x4a, 0x56, 0xe4, 0x93, 0x92, 0x17, 0xb2, 0x40, 0xad, 0xd5, 0x0d, 0xd1, 0xc0, 0x80, 0xf2, 0x95, 0xc8, 0x33, 0xc5, 0x52, 0x3e, 0xc9, 0x44, 0xaa, 0x52, 0xd4, 0xd2,
0x91, 0x51, 0x05, 0xfe, 0x67, 0x76, 0x9d, 0xd3, 0xf4, 0x8a, 0xd3, 0x0f, 0xb4, 0x42, 0xe7, 0x10, 0x3f, 0xcb, 0xcd, 0x7d, 0x98, 0x83, 0x7f, 0xcb, 0x1e, 0x38, 0x8d, 0x6f, 0x04, 0xbd, 0xa6, 0x39,
0x08, 0x15, 0x27, 0x25, 0xa7, 0xc9, 0x4f, 0x5a, 0x85, 0xce, 0xd8, 0x89, 0x7d, 0xec, 0x0b, 0x1b, 0x7a, 0x09, 0x3d, 0xa9, 0xe3, 0x28, 0x13, 0x34, 0xfa, 0x49, 0xf3, 0xc0, 0x19, 0x39, 0x63, 0x1f,
0x15, 0x42, 0x7b, 0x4b, 0xb9, 0x60, 0x45, 0x1e, 0xba, 0x63, 0x27, 0xee, 0x63, 0x13, 0xa2, 0x67, 0xfb, 0xd2, 0xce, 0x0a, 0xa0, 0xb9, 0xa5, 0x42, 0xb2, 0x94, 0x07, 0xee, 0xc8, 0x19, 0x77, 0xb1,
0x30, 0x54, 0xf4, 0xab, 0x22, 0x4b, 0x0c, 0xc4, 0x53, 0x90, 0x81, 0xc9, 0x7f, 0xd5, 0xe9, 0xe8, 0x09, 0xd1, 0x1b, 0x18, 0x68, 0xed, 0x55, 0x9a, 0x44, 0x26, 0xc5, 0xd3, 0x29, 0x7d, 0x83, 0x7f,
0x8f, 0x03, 0xa7, 0xaf, 0x37, 0x79, 0x9a, 0x51, 0x34, 0x82, 0x0e, 0x4b, 0x69, 0x2e, 0x99, 0x34, 0x2f, 0xe1, 0xf0, 0x97, 0x0b, 0x8d, 0x4f, 0x1b, 0x1e, 0x27, 0x14, 0x0d, 0xa1, 0xc5, 0x62, 0xca,
0xfd, 0x76, 0x31, 0x7a, 0x07, 0x83, 0x43, 0x45, 0x22, 0x74, 0xc7, 0x5e, 0xdc, 0x9b, 0x3e, 0x9a, 0x15, 0x53, 0xa6, 0xde, 0x2e, 0x46, 0xd7, 0xd0, 0xdf, 0xef, 0x48, 0x06, 0xee, 0xc8, 0x1b, 0x77,
0xd4, 0x13, 0x4c, 0x34, 0xc5, 0xc4, 0x9e, 0x42, 0xbc, 0xcd, 0x25, 0xaf, 0x70, 0xdf, 0xd6, 0x2c, 0xa6, 0x2f, 0x26, 0x66, 0x8a, 0x49, 0x29, 0x33, 0xb1, 0x27, 0x91, 0x5f, 0xb8, 0x12, 0x39, 0xee,
0xd0, 0x19, 0x74, 0xeb, 0x04, 0x91, 0x1b, 0x4e, 0xc3, 0x96, 0xea, 0xb2, 0x4f, 0xd4, 0x55, 0xc9, 0xda, 0x7d, 0x4b, 0x74, 0x0e, 0xed, 0x02, 0x20, 0x6a, 0x23, 0x68, 0x70, 0xa2, 0x2b, 0xd5, 0x40,
0xd6, 0x54, 0x48, 0xb2, 0x2e, 0xc3, 0x93, 0xb1, 0x13, 0x7b, 0x78, 0x9f, 0x18, 0x7d, 0x01, 0x74, 0xf1, 0xaa, 0xd8, 0x9a, 0x4a, 0x45, 0xd6, 0x59, 0x70, 0x3a, 0x72, 0xc6, 0x1e, 0xae, 0x81, 0xe1,
0xbb, 0x01, 0x1a, 0x82, 0x67, 0x1c, 0xea, 0xe2, 0xfa, 0x27, 0x8a, 0xe1, 0x64, 0x4b, 0xb2, 0x0d, 0x1d, 0xa0, 0xe3, 0x02, 0x68, 0x00, 0x9e, 0x71, 0xa9, 0x8d, 0x8b, 0xbf, 0xe8, 0x02, 0x4e, 0xb7,
0x55, 0xb6, 0xf4, 0xa6, 0x48, 0x4b, 0xb4, 0x9f, 0x62, 0x0d, 0x78, 0xe5, 0xbe, 0x74, 0x22, 0x0e, 0x24, 0xd9, 0x50, 0x6d, 0x4d, 0x67, 0x7a, 0x56, 0xb7, 0x69, 0xd3, 0x71, 0x99, 0xf4, 0xd1, 0xfd,
0x03, 0xad, 0xfe, 0x4d, 0x91, 0x4b, 0xc2, 0x72, 0xca, 0xd1, 0x39, 0x9c, 0x2e, 0x55, 0x4a, 0xb1, 0xe0, 0x84, 0x5b, 0xe8, 0x97, 0x13, 0x7c, 0x4e, 0xb9, 0x22, 0x8c, 0x53, 0x81, 0xc6, 0xd0, 0x58,
0xf6, 0xa6, 0xbe, 0x3d, 0x24, 0x6e, 0x6a, 0x68, 0x06, 0x0f, 0x4a, 0xce, 0xb6, 0x44, 0xd2, 0xe4, 0x6a, 0x48, 0x2b, 0x77, 0xa6, 0x83, 0xc3, 0x61, 0x71, 0xf5, 0x8e, 0x66, 0x70, 0x96, 0x09, 0xb6,
0x68, 0x5b, 0xae, 0x9a, 0xeb, 0xbf, 0xa6, 0x6a, 0x37, 0xbe, 0x6c, 0x75, 0xbc, 0x61, 0x2b, 0xba, 0x25, 0x8a, 0x46, 0x07, 0x9b, 0x73, 0xf5, 0x7c, 0xff, 0x55, 0xaf, 0x76, 0xf1, 0xab, 0x93, 0x96,
0x84, 0xce, 0x1c, 0x2f, 0x28, 0x49, 0x29, 0xb7, 0xf5, 0xfb, 0x5a, 0xbf, 0x0f, 0x8e, 0x59, 0xa9, 0x37, 0x38, 0x09, 0xaf, 0xa0, 0x35, 0xc7, 0x0b, 0x4a, 0x62, 0x2a, 0xec, 0x39, 0xfc, 0x72, 0x0e,
0x93, 0xa3, 0x00, 0xdc, 0xd2, 0xac, 0xcf, 0x2d, 0x55, 0xcc, 0xd2, 0xc6, 0x3a, 0x97, 0xa5, 0xd1, 0x1f, 0x1c, 0xb3, 0x5e, 0x87, 0xa3, 0x1e, 0xb8, 0x99, 0x59, 0xa5, 0x9b, 0xe9, 0x98, 0xc5, 0x95,
0x19, 0x74, 0xe6, 0x8b, 0xfb, 0xb8, 0xa2, 0x17, 0x00, 0xdf, 0x66, 0xf7, 0xd7, 0x8f, 0xd9, 0x1a, 0x85, 0x2e, 0x8b, 0xc3, 0x73, 0x68, 0xcd, 0x17, 0x4f, 0x69, 0x85, 0xef, 0x01, 0xee, 0x66, 0x4f,
0x7d, 0xbf, 0x1c, 0xf8, 0x7f, 0xce, 0x38, 0x5d, 0xc9, 0x8f, 0x54, 0x08, 0x72, 0x4d, 0xaf, 0x9a, 0xbf, 0x1f, 0xaa, 0x55, 0xfd, 0xfd, 0x76, 0xe0, 0xd9, 0x9c, 0x09, 0xba, 0x52, 0xdf, 0xa8, 0x94,
0xbf, 0x0d, 0xba, 0x80, 0x5e, 0xcd, 0x97, 0xdc, 0x28, 0xc2, 0xc6, 0x9f, 0xa1, 0xf6, 0x67, 0xdf, 0xe4, 0x81, 0xde, 0x54, 0x9f, 0x10, 0xba, 0x84, 0x4e, 0xa1, 0x17, 0xfd, 0xd0, 0x82, 0x95, 0x47,
0x08, 0xdb, 0x4d, 0x9f, 0x43, 0x77, 0x8e, 0xcd, 0x03, 0xbd, 0x92, 0x40, 0x3f, 0x30, 0x1e, 0xe0, 0xff, 0xd7, 0x1e, 0xd5, 0xc5, 0xb0, 0x5d, 0xf8, 0x1d, 0xb4, 0xe7, 0xd8, 0x90, 0xca, 0xf5, 0xa0,
0xbd, 0x1b, 0x35, 0x78, 0xc7, 0x4e, 0x0f, 0xc0, 0x8b, 0x1d, 0xd8, 0x30, 0x87, 0xd0, 0x2e, 0x49, 0x9a, 0x64, 0xbc, 0xc0, 0xb5, 0x2b, 0x05, 0x61, 0x57, 0x85, 0x1e, 0x11, 0x16, 0x3b, 0x82, 0xa9,
0x95, 0x15, 0x24, 0x55, 0xfe, 0xf8, 0xd8, 0x84, 0xd1, 0x6f, 0x17, 0x06, 0x46, 0x73, 0x33, 0xc2, 0x10, 0x40, 0x33, 0x23, 0x79, 0x92, 0x92, 0x58, 0x7b, 0xe5, 0x63, 0x13, 0x86, 0x7f, 0x5c, 0xe8,
0x3f, 0x6e, 0xf5, 0x29, 0x0c, 0x58, 0x2e, 0x24, 0xc9, 0x32, 0x52, 0xdf, 0x69, 0xc2, 0x52, 0xa5, 0x9b, 0xfe, 0xab, 0x71, 0xfe, 0x61, 0xcb, 0xaf, 0xa1, 0xcf, 0xb8, 0x54, 0x24, 0x49, 0x48, 0x71,
0xb9, 0x8b, 0x03, 0x3b, 0xfd, 0x3e, 0x45, 0x4f, 0xa0, 0xad, 0x9f, 0x88, 0xd0, 0x53, 0xa7, 0x70, 0xc7, 0x11, 0x8b, 0x75, 0xff, 0x6d, 0xdc, 0xb3, 0xe1, 0xaf, 0x31, 0x7a, 0x0b, 0xcd, 0x92, 0x22,
0xc8, 0x67, 0x8a, 0xe8, 0x13, 0x04, 0xa9, 0xb2, 0x32, 0x59, 0x6b, 0x21, 0x21, 0x55, 0xf0, 0x58, 0x03, 0x4f, 0x9f, 0xc9, 0xb1, 0xa6, 0x49, 0x40, 0xb7, 0xd0, 0x8b, 0xb5, 0xbd, 0xd1, 0xba, 0x6c,
0xc3, 0x8f, 0x54, 0x4e, 0x0e, 0x6c, 0x6f, 0x4e, 0x28, 0xb5, 0x73, 0xe8, 0x31, 0x04, 0xe5, 0x66, 0x28, 0xa0, 0x9a, 0x72, 0x51, 0x53, 0x0e, 0x3a, 0x9e, 0xec, 0xad, 0xa3, 0x3a, 0xb1, 0xd8, 0xc6,
0x99, 0xb1, 0xd5, 0x8e, 0xf0, 0x87, 0x1a, 0xbe, 0xaf, 0xb3, 0x0d, 0x6c, 0xf4, 0x1d, 0xd0, 0x6d, 0xd0, 0x2b, 0xe8, 0x65, 0x9b, 0x65, 0xc2, 0x56, 0x3b, 0xd1, 0x7b, 0x6d, 0x44, 0xb7, 0x44, 0xab,
0xae, 0x3b, 0xae, 0xe5, 0xe2, 0xf0, 0x5a, 0x1e, 0x36, 0x6e, 0xdf, 0xb5, 0x7d, 0xeb, 0x6c, 0x96, 0xb4, 0x21, 0x01, 0x74, 0xac, 0xf5, 0xc8, 0x35, 0x5d, 0xee, 0x5f, 0xd3, 0x73, 0xcb, 0xfd, 0xc7,
0xa7, 0xea, 0x4b, 0x32, 0xfb, 0x1b, 0x00, 0x00, 0xff, 0xff, 0x9e, 0x75, 0x6d, 0x59, 0xd4, 0x04, 0xbe, 0x0c, 0xeb, 0xac, 0x96, 0x0d, 0x9d, 0x3a, 0xfb, 0x1b, 0x00, 0x00, 0xff, 0xff, 0xcd, 0x2d,
0x00, 0x00, 0x0e, 0xc8, 0x00, 0x05, 0x00, 0x00,
} }

View File

@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package chat; package protobuf;
message SignedPreKey { message SignedPreKey {
bytes signed_pre_key = 1; bytes signed_pre_key = 1;

View File

@ -5,37 +5,48 @@ import (
"errors" "errors"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/status-im/status-go/services/shhext/chat/topic" "github.com/status-im/status-go/services/shhext/chat/multidevice"
"github.com/status-im/status-go/services/shhext/chat/protobuf"
"github.com/status-im/status-go/services/shhext/chat/sharedsecret"
) )
const protocolCurrentVersion = 1 const ProtocolVersion = 1
const topicNegotiationVersion = 1 const sharedSecretNegotiationVersion = 1
const partitionedTopicMinVersion = 1
type ProtocolService struct { type ProtocolService struct {
log log.Logger log log.Logger
encryption *EncryptionService encryption *EncryptionService
topic *topic.Service secret *sharedsecret.Service
addedBundlesHandler func([]IdentityAndIDPair) multidevice *multidevice.Service
onNewTopicHandler func([]*topic.Secret) addedBundlesHandler func([]multidevice.IdentityAndIDPair)
onNewSharedSecretHandler func([]*sharedsecret.Secret)
Enabled bool Enabled bool
} }
var ErrNotProtocolMessage = errors.New("Not a protocol message") var ErrNotProtocolMessage = errors.New("Not a protocol message")
// NewProtocolService creates a new ProtocolService instance // NewProtocolService creates a new ProtocolService instance
func NewProtocolService(encryption *EncryptionService, topic *topic.Service, addedBundlesHandler func([]IdentityAndIDPair), onNewTopicHandler func([]*topic.Secret)) *ProtocolService { func NewProtocolService(encryption *EncryptionService, secret *sharedsecret.Service, multidevice *multidevice.Service, addedBundlesHandler func([]multidevice.IdentityAndIDPair), onNewSharedSecretHandler func([]*sharedsecret.Secret)) *ProtocolService {
return &ProtocolService{ return &ProtocolService{
log: log.New("package", "status-go/services/sshext.chat"), log: log.New("package", "status-go/services/sshext.chat"),
encryption: encryption, encryption: encryption,
topic: topic, secret: secret,
multidevice: multidevice,
addedBundlesHandler: addedBundlesHandler, addedBundlesHandler: addedBundlesHandler,
onNewTopicHandler: onNewTopicHandler, onNewSharedSecretHandler: onNewSharedSecretHandler,
} }
} }
func (p *ProtocolService) addBundle(myIdentityKey *ecdsa.PrivateKey, msg *ProtocolMessage, sendSingle bool) (*ProtocolMessage, error) { func (p *ProtocolService) addBundle(myIdentityKey *ecdsa.PrivateKey, msg *protobuf.ProtocolMessage, sendSingle bool) (*protobuf.ProtocolMessage, error) {
// Get a bundle // Get a bundle
bundle, err := p.encryption.CreateBundle(myIdentityKey) installations, err := p.multidevice.GetOurActiveInstallations(&myIdentityKey.PublicKey)
if err != nil {
return nil, err
}
bundle, err := p.encryption.CreateBundle(myIdentityKey, installations)
if err != nil { if err != nil {
p.log.Error("encryption-service", "error creating bundle", err) p.log.Error("encryption-service", "error creating bundle", err)
return nil, err return nil, err
@ -46,16 +57,16 @@ func (p *ProtocolService) addBundle(myIdentityKey *ecdsa.PrivateKey, msg *Protoc
// an issue anymore // an issue anymore
msg.Bundle = bundle msg.Bundle = bundle
} else { } else {
msg.Bundles = []*Bundle{bundle} msg.Bundles = []*protobuf.Bundle{bundle}
} }
return msg, nil return msg, nil
} }
// BuildPublicMessage marshals a public chat message given the user identity private key and a payload // BuildPublicMessage marshals a public chat message given the user identity private key and a payload
func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte) (*ProtocolMessage, error) { func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte) (*protobuf.ProtocolMessage, error) {
// Build message not encrypted // Build message not encrypted
protocolMessage := &ProtocolMessage{ protocolMessage := &protobuf.ProtocolMessage{
InstallationId: p.encryption.config.InstallationID, InstallationId: p.encryption.config.InstallationID,
PublicMessage: payload, PublicMessage: payload,
} }
@ -63,100 +74,153 @@ func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, pa
return p.addBundle(myIdentityKey, protocolMessage, false) return p.addBundle(myIdentityKey, protocolMessage, false)
} }
type ProtocolMessageSpec struct {
Message *protobuf.ProtocolMessage
// Installations is the targeted devices
Installations []*multidevice.Installation
// SharedSecret is a shared secret established among the installations
SharedSecret []byte
}
func (p *ProtocolMessageSpec) MinVersion() uint32 {
var version uint32
for _, installation := range p.Installations {
if installation.Version < version {
version = installation.Version
}
}
return version
}
func (p *ProtocolMessageSpec) PartitionedTopic() bool {
return p.MinVersion() >= partitionedTopicMinVersion
}
// BuildDirectMessage returns a 1:1 chat message and optionally a negotiated topic given the user identity private key, the recipient's public key, and a payload // BuildDirectMessage returns a 1:1 chat message and optionally a negotiated topic given the user identity private key, the recipient's public key, and a payload
func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey, payload []byte) (*ProtocolMessage, []byte, error) { func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey, payload []byte) (*ProtocolMessageSpec, error) {
activeInstallations, err := p.multidevice.GetActiveInstallations(publicKey)
if err != nil {
return nil, err
}
// Encrypt payload // Encrypt payload
encryptionResponse, err := p.encryption.EncryptPayload(publicKey, myIdentityKey, payload) encryptionResponse, installations, err := p.encryption.EncryptPayload(publicKey, myIdentityKey, activeInstallations, payload)
if err != nil { if err != nil {
p.log.Error("encryption-service", "error encrypting payload", err) p.log.Error("encryption-service", "error encrypting payload", err)
return nil, nil, err return nil, err
} }
// Build message // Build message
protocolMessage := &ProtocolMessage{ protocolMessage := &protobuf.ProtocolMessage{
InstallationId: p.encryption.config.InstallationID, InstallationId: p.encryption.config.InstallationID,
DirectMessage: encryptionResponse, DirectMessage: encryptionResponse,
} }
msg, err := p.addBundle(myIdentityKey, protocolMessage, true) msg, err := p.addBundle(myIdentityKey, protocolMessage, true)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
// Check who we are sending the message to, and see if we have a shared secret // Check who we are sending the message to, and see if we have a shared secret
// across devices // across devices
var installationIDs []string var installationIDs []string
var sharedSecret *topic.Secret var sharedSecret *sharedsecret.Secret
var agreed bool var agreed bool
for installationID := range protocolMessage.GetDirectMessage() { for installationID := range protocolMessage.GetDirectMessage() {
if installationID != noInstallationID { if installationID != noInstallationID {
installationIDs = append(installationIDs, installationID) installationIDs = append(installationIDs, installationID)
} }
} }
if len(installationIDs) != 0 {
sharedSecret, agreed, err = p.topic.Send(myIdentityKey, p.encryption.config.InstallationID, publicKey, installationIDs) sharedSecret, agreed, err = p.secret.Send(myIdentityKey, p.encryption.config.InstallationID, publicKey, installationIDs)
if err != nil { if err != nil {
return nil, nil, err return nil, err
}
} }
// Call handler // Call handler
if sharedSecret != nil { if sharedSecret != nil {
p.onNewTopicHandler([]*topic.Secret{sharedSecret}) p.onNewSharedSecretHandler([]*sharedsecret.Secret{sharedSecret})
}
response := &ProtocolMessageSpec{
Message: msg,
Installations: installations,
} }
if agreed { if agreed {
return msg, sharedSecret.Key, nil response.SharedSecret = sharedSecret.Key
} }
return msg, nil, nil return response, nil
} }
// BuildDHMessage builds a message with DH encryption so that it can be decrypted by any other device. // BuildDHMessage builds a message with DH encryption so that it can be decrypted by any other device.
func (p *ProtocolService) BuildDHMessage(myIdentityKey *ecdsa.PrivateKey, destination *ecdsa.PublicKey, payload []byte) (*ProtocolMessage, []byte, error) { func (p *ProtocolService) BuildDHMessage(myIdentityKey *ecdsa.PrivateKey, destination *ecdsa.PublicKey, payload []byte) (*ProtocolMessageSpec, error) {
// Encrypt payload // Encrypt payload
encryptionResponse, err := p.encryption.EncryptPayloadWithDH(destination, payload) encryptionResponse, err := p.encryption.EncryptPayloadWithDH(destination, payload)
if err != nil { if err != nil {
p.log.Error("encryption-service", "error encrypting payload", err) p.log.Error("encryption-service", "error encrypting payload", err)
return nil, nil, err return nil, err
} }
// Build message // Build message
protocolMessage := &ProtocolMessage{ protocolMessage := &protobuf.ProtocolMessage{
InstallationId: p.encryption.config.InstallationID, InstallationId: p.encryption.config.InstallationID,
DirectMessage: encryptionResponse, DirectMessage: encryptionResponse,
} }
msg, err := p.addBundle(myIdentityKey, protocolMessage, true) msg, err := p.addBundle(myIdentityKey, protocolMessage, true)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
return msg, nil, nil return &ProtocolMessageSpec{Message: msg}, nil
} }
// ProcessPublicBundle processes a received X3DH bundle. // ProcessPublicBundle processes a received X3DH bundle.
func (p *ProtocolService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, bundle *Bundle) ([]IdentityAndIDPair, error) { func (p *ProtocolService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, bundle *protobuf.Bundle) ([]multidevice.IdentityAndIDPair, error) {
return p.encryption.ProcessPublicBundle(myIdentityKey, bundle) if err := p.encryption.ProcessPublicBundle(myIdentityKey, bundle); err != nil {
return nil, err
}
theirIdentityKey, err := ExtractIdentity(bundle)
if err != nil {
return nil, err
}
return p.multidevice.ProcessPublicBundle(myIdentityKey, theirIdentityKey, bundle)
} }
// GetBundle retrieves or creates a X3DH bundle, given a private identity key. // GetBundle retrieves or creates a X3DH bundle, given a private identity key.
func (p *ProtocolService) GetBundle(myIdentityKey *ecdsa.PrivateKey) (*Bundle, error) { func (p *ProtocolService) GetBundle(myIdentityKey *ecdsa.PrivateKey) (*protobuf.Bundle, error) {
return p.encryption.CreateBundle(myIdentityKey) installations, err := p.multidevice.GetOurActiveInstallations(&myIdentityKey.PublicKey)
if err != nil {
return nil, err
}
return p.encryption.CreateBundle(myIdentityKey, installations)
} }
// EnableInstallation enables an installation for multi-device sync. // EnableInstallation enables an installation for multi-device sync.
func (p *ProtocolService) EnableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error { func (p *ProtocolService) EnableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error {
return p.encryption.EnableInstallation(myIdentityKey, installationID) return p.multidevice.EnableInstallation(myIdentityKey, installationID)
} }
// DisableInstallation disables an installation for multi-device sync. // DisableInstallation disables an installation for multi-device sync.
func (p *ProtocolService) DisableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error { func (p *ProtocolService) DisableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error {
return p.encryption.DisableInstallation(myIdentityKey, installationID) return p.multidevice.DisableInstallation(myIdentityKey, installationID)
} }
// GetPublicBundle retrieves a public bundle given an identity // GetPublicBundle retrieves a public bundle given an identity
func (p *ProtocolService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey) (*Bundle, error) { func (p *ProtocolService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey) (*protobuf.Bundle, error) {
return p.encryption.GetPublicBundle(theirIdentityKey) installations, err := p.multidevice.GetActiveInstallations(theirIdentityKey)
if err != nil {
return nil, err
}
return p.encryption.GetPublicBundle(theirIdentityKey, installations)
} }
// ConfirmMessagesProcessed confirms and deletes message keys for the given messages // ConfirmMessagesProcessed confirms and deletes message keys for the given messages
@ -165,7 +229,7 @@ func (p *ProtocolService) ConfirmMessagesProcessed(messageIDs [][]byte) error {
} }
// HandleMessage unmarshals a message and processes it, decrypting it if it is a 1:1 message. // HandleMessage unmarshals a message and processes it, decrypting it if it is a 1:1 message.
func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, protocolMessage *ProtocolMessage, messageID []byte) ([]byte, error) { func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, protocolMessage *protobuf.ProtocolMessage, messageID []byte) ([]byte, error) {
if p.encryption == nil { if p.encryption == nil {
return nil, errors.New("encryption service not initialized") return nil, errors.New("encryption service not initialized")
} }
@ -173,7 +237,7 @@ func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPu
// Process bundle, deprecated, here for backward compatibility // Process bundle, deprecated, here for backward compatibility
if bundle := protocolMessage.GetBundle(); bundle != nil { if bundle := protocolMessage.GetBundle(); bundle != nil {
// Should we stop processing if the bundle cannot be verified? // Should we stop processing if the bundle cannot be verified?
addedBundles, err := p.encryption.ProcessPublicBundle(myIdentityKey, bundle) addedBundles, err := p.ProcessPublicBundle(myIdentityKey, bundle)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -184,7 +248,7 @@ func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPu
// Process bundles // Process bundles
for _, bundle := range protocolMessage.GetBundles() { for _, bundle := range protocolMessage.GetBundles() {
// Should we stop processing if the bundle cannot be verified? // Should we stop processing if the bundle cannot be verified?
addedBundles, err := p.encryption.ProcessPublicBundle(myIdentityKey, bundle) addedBundles, err := p.ProcessPublicBundle(myIdentityKey, bundle)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -205,17 +269,20 @@ func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPu
return nil, err return nil, err
} }
var bundles []*protobuf.Bundle
p.log.Info("Checking version") p.log.Info("Checking version")
// Handle protocol negotiation for compatible clients // Handle protocol negotiation for compatible clients
version := getProtocolVersion(protocolMessage.GetBundles(), protocolMessage.GetInstallationId()) p.log.Info("bundle", "bundles", protocolMessage)
if version >= topicNegotiationVersion { bundles = append(protocolMessage.GetBundles(), protocolMessage.GetBundle())
version := getProtocolVersion(bundles, protocolMessage.GetInstallationId())
if version >= sharedSecretNegotiationVersion {
p.log.Info("Version greater than 1 negotianting") p.log.Info("Version greater than 1 negotianting")
sharedSecret, err := p.topic.Receive(myIdentityKey, theirPublicKey, protocolMessage.GetInstallationId()) sharedSecret, err := p.secret.Receive(myIdentityKey, theirPublicKey, protocolMessage.GetInstallationId())
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.onNewTopicHandler([]*topic.Secret{sharedSecret}) p.onNewSharedSecretHandler([]*sharedsecret.Secret{sharedSecret})
} }
return message, nil return message, nil
@ -225,12 +292,13 @@ func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPu
return nil, errors.New("no payload") return nil, errors.New("no payload")
} }
func getProtocolVersion(bundles []*Bundle, installationID string) uint32 { func getProtocolVersion(bundles []*protobuf.Bundle, installationID string) uint32 {
if installationID == "" { if installationID == "" {
return 0 return 0
} }
for _, bundle := range bundles { for _, bundle := range bundles {
if bundle != nil {
signedPreKeys := bundle.GetSignedPreKeys() signedPreKeys := bundle.GetSignedPreKeys()
if signedPreKeys == nil { if signedPreKeys == nil {
continue continue
@ -243,6 +311,7 @@ func getProtocolVersion(bundles []*Bundle, installationID string) uint32 {
return signedPreKey.GetProtocolVersion() return signedPreKey.GetProtocolVersion()
} }
}
return 0 return 0
} }

View File

@ -5,7 +5,8 @@ import (
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/status-im/status-go/services/shhext/chat/topic" "github.com/status-im/status-go/services/shhext/chat/multidevice"
"github.com/status-im/status-go/services/shhext/chat/sharedsecret"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
@ -38,21 +39,35 @@ func (s *ProtocolServiceTestSuite) SetupTest() {
panic(err) panic(err)
} }
addedBundlesHandler := func(addedBundles []IdentityAndIDPair) {} addedBundlesHandler := func(addedBundles []multidevice.IdentityAndIDPair) {}
onNewTopicHandler := func(topic []*topic.Secret) {} onNewSharedSecretHandler := func(secret []*sharedsecret.Secret) {}
aliceMultideviceConfig := &multidevice.Config{
MaxInstallations: 3,
InstallationID: "1",
ProtocolVersion: ProtocolVersion,
}
s.alice = NewProtocolService( s.alice = NewProtocolService(
NewEncryptionService(alicePersistence, DefaultEncryptionServiceConfig("1")), NewEncryptionService(alicePersistence, DefaultEncryptionServiceConfig("1")),
topic.NewService(alicePersistence.GetTopicStorage()), sharedsecret.NewService(alicePersistence.GetSharedSecretStorage()),
multidevice.New(aliceMultideviceConfig, alicePersistence.GetMultideviceStorage()),
addedBundlesHandler, addedBundlesHandler,
onNewTopicHandler, onNewSharedSecretHandler,
) )
bobMultideviceConfig := &multidevice.Config{
MaxInstallations: 3,
InstallationID: "2",
ProtocolVersion: ProtocolVersion,
}
s.bob = NewProtocolService( s.bob = NewProtocolService(
NewEncryptionService(bobPersistence, DefaultEncryptionServiceConfig("2")), NewEncryptionService(bobPersistence, DefaultEncryptionServiceConfig("2")),
topic.NewService(bobPersistence.GetTopicStorage()), sharedsecret.NewService(bobPersistence.GetSharedSecretStorage()),
multidevice.New(bobMultideviceConfig, bobPersistence.GetMultideviceStorage()),
addedBundlesHandler, addedBundlesHandler,
onNewTopicHandler, onNewSharedSecretHandler,
) )
} }
@ -79,9 +94,12 @@ func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() {
payload := []byte("test") payload := []byte("test")
msg, _, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, payload) msgSpec, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, payload)
s.NoError(err) s.NoError(err)
s.NotNil(msg, "It creates a message") s.NotNil(msgSpec, "It creates a message spec")
msg := msgSpec.Message
s.NotNil(msg, "It creates a messages")
s.NotNilf(msg.GetBundle(), "It adds a bundle to the message") s.NotNilf(msg.GetBundle(), "It adds a bundle to the message")
@ -95,6 +113,32 @@ func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() {
} }
func (s *ProtocolServiceTestSuite) TestBuildAndReadDirectMessage() { func (s *ProtocolServiceTestSuite) TestBuildAndReadDirectMessage() {
bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)
aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)
payload := []byte("test")
// Message is sent with DH
msgSpec, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, payload)
s.Require().NoError(err)
s.Require().NotNil(msgSpec)
msg := msgSpec.Message
s.Require().NotNil(msg)
// Bob is able to decrypt the message
unmarshaledMsg, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, msg, []byte("message-id"))
s.NoError(err)
s.NotNil(unmarshaledMsg)
recoveredPayload := []byte("test")
s.Equalf(payload, recoveredPayload, "It successfully unmarshal the decrypted message")
}
func (s *ProtocolServiceTestSuite) TestSecretNegotiation() {
var secretResponse []*sharedsecret.Secret
bobKey, err := crypto.GenerateKey() bobKey, err := crypto.GenerateKey()
s.NoError(err) s.NoError(err)
aliceKey, err := crypto.GenerateKey() aliceKey, err := crypto.GenerateKey()
@ -102,17 +146,26 @@ func (s *ProtocolServiceTestSuite) TestBuildAndReadDirectMessage() {
payload := []byte("test") payload := []byte("test")
// Message is sent with DH s.bob.onNewSharedSecretHandler = func(secret []*sharedsecret.Secret) {
marshaledMsg, _, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, payload) secretResponse = secret
}
s.NoError(err) msgSpec, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, payload)
s.NoError(err)
// Bob is able to decrypt the message s.NotNil(msgSpec, "It creates a message spec")
unmarshaledMsg, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, marshaledMsg, []byte("message-id"))
s.NoError(err) bundle := msgSpec.Message.GetBundle()
s.Require().NotNil(bundle)
s.NotNil(unmarshaledMsg)
signedPreKeys := bundle.GetSignedPreKeys()
recoveredPayload := []byte("test") s.Require().NotNil(signedPreKeys)
s.Equalf(payload, recoveredPayload, "It successfully unmarshal the decrypted message")
signedPreKey := signedPreKeys["1"]
s.Require().NotNil(signedPreKey)
s.Require().Equal(uint32(1), signedPreKey.GetProtocolVersion())
_, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, msgSpec.Message, []byte("message-id"))
s.NoError(err)
s.Require().NotNil(secretResponse)
} }

View File

@ -1,4 +1,4 @@
package topic package sharedsecret
import ( import (
"database/sql" "database/sql"
@ -30,20 +30,20 @@ func (s *SQLLitePersistence) Add(identity []byte, secret []byte, installationID
return err return err
} }
insertTopicStmt, err := tx.Prepare("INSERT INTO topics(identity, secret) VALUES (?, ?)") insertSecretStmt, err := tx.Prepare("INSERT INTO secrets(identity, secret) VALUES (?, ?)")
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
return err return err
} }
defer insertTopicStmt.Close() defer insertSecretStmt.Close()
_, err = insertTopicStmt.Exec(identity, secret) _, err = insertSecretStmt.Exec(identity, secret)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
return err return err
} }
insertInstallationIDStmt, err := tx.Prepare("INSERT INTO topic_installation_ids(id, identity_id) VALUES (?, ?)") insertInstallationIDStmt, err := tx.Prepare("INSERT INTO secret_installation_ids(id, identity_id) VALUES (?, ?)")
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
return err return err
@ -70,9 +70,9 @@ func (s *SQLLitePersistence) Get(identity []byte, installationIDs []string) (*Re
/* #nosec */ /* #nosec */
query := `SELECT secret, id query := `SELECT secret, id
FROM topics t FROM secrets t
JOIN JOIN
topic_installation_ids tid secret_installation_ids tid
ON t.identity = tid.identity_id ON t.identity = tid.identity_id
WHERE WHERE
t.identity = ? t.identity = ?
@ -101,7 +101,7 @@ func (s *SQLLitePersistence) Get(identity []byte, installationIDs []string) (*Re
func (s *SQLLitePersistence) All() ([][][]byte, error) { func (s *SQLLitePersistence) All() ([][][]byte, error) {
query := `SELECT identity, secret query := `SELECT identity, secret
FROM topics` FROM secrets`
var secrets [][][]byte var secrets [][][]byte

View File

@ -1,4 +1,4 @@
package topic package sharedsecret
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
@ -17,8 +17,8 @@ func NewService(persistence PersistenceService) *Service {
return &Service{persistence: persistence} return &Service{persistence: persistence}
} }
func (s *Service) setupTopic(myPrivateKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, installationID string) (*Secret, error) { func (s *Service) setup(myPrivateKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, installationID string) (*Secret, error) {
log.Info("Setup topic called for", "installationID", installationID) log.Info("Setup called for", "installationID", installationID)
sharedKey, err := ecies.ImportECDSA(myPrivateKey).GenerateShared( sharedKey, err := ecies.ImportECDSA(myPrivateKey).GenerateShared(
ecies.ImportECDSAPublic(theirPublicKey), ecies.ImportECDSAPublic(theirPublicKey),
sskLen, sskLen,
@ -38,16 +38,20 @@ func (s *Service) setupTopic(myPrivateKey *ecdsa.PrivateKey, theirPublicKey *ecd
// Receive will generate a shared secret for a given identity, and return it // Receive will generate a shared secret for a given identity, and return it
func (s *Service) Receive(myPrivateKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, installationID string) (*Secret, error) { func (s *Service) Receive(myPrivateKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, installationID string) (*Secret, error) {
return s.setupTopic(myPrivateKey, theirPublicKey, installationID) return s.setup(myPrivateKey, theirPublicKey, installationID)
} }
// Send returns a shared key and whether it has been acknowledged from all the installationIDs // Send returns a shared key and whether it has been acknowledged from all the installationIDs
func (s *Service) Send(myPrivateKey *ecdsa.PrivateKey, myInstallationID string, theirPublicKey *ecdsa.PublicKey, theirInstallationIDs []string) (*Secret, bool, error) { func (s *Service) Send(myPrivateKey *ecdsa.PrivateKey, myInstallationID string, theirPublicKey *ecdsa.PublicKey, theirInstallationIDs []string) (*Secret, bool, error) {
sharedKey, err := s.setupTopic(myPrivateKey, theirPublicKey, myInstallationID) secret, err := s.setup(myPrivateKey, theirPublicKey, myInstallationID)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
if len(theirInstallationIDs) == 0 {
return secret, false, nil
}
theirIdentity := crypto.CompressPubkey(theirPublicKey) theirIdentity := crypto.CompressPubkey(theirPublicKey)
response, err := s.persistence.Get(theirIdentity, theirInstallationIDs) response, err := s.persistence.Get(theirIdentity, theirInstallationIDs)
if err != nil { if err != nil {
@ -56,14 +60,11 @@ func (s *Service) Send(myPrivateKey *ecdsa.PrivateKey, myInstallationID string,
for _, installationID := range theirInstallationIDs { for _, installationID := range theirInstallationIDs {
if !response.installationIDs[installationID] { if !response.installationIDs[installationID] {
return sharedKey, false, nil return secret, false, nil
} }
} }
return &Secret{ return secret, true, nil
Key: response.secret,
Identity: theirPublicKey,
}, true, nil
} }
type Secret struct { type Secret struct {

View File

@ -1,4 +1,4 @@
package topic package sharedsecret
import ( import (
"io/ioutil" "io/ioutil"
@ -21,7 +21,7 @@ type ServiceTestSuite struct {
} }
func (s *ServiceTestSuite) SetupTest() { func (s *ServiceTestSuite) SetupTest() {
dbFile, err := ioutil.TempFile(os.TempDir(), "topic") dbFile, err := ioutil.TempFile(os.TempDir(), "sharedsecret")
s.Require().NoError(err) s.Require().NoError(err)
s.path = dbFile.Name() s.path = dbFile.Name()
@ -103,12 +103,12 @@ func (s *ServiceTestSuite) TestAll() {
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(sharedKey2, "it generates a shared key") s.Require().NotNil(sharedKey2, "it generates a shared key")
// All the topics are there // All the secrets are there
topics, err := s.service.All() secrets, err := s.service.All()
s.Require().NoError(err) s.Require().NoError(err)
expected := []*Secret{ expected := []*Secret{
sharedKey1, sharedKey1,
sharedKey2, sharedKey2,
} }
s.Require().Equal(expected, topics) s.Require().Equal(expected, secrets)
} }

View File

@ -13,7 +13,9 @@ import (
"github.com/status-im/migrate/v4/source/go_bindata" "github.com/status-im/migrate/v4/source/go_bindata"
ecrypto "github.com/status-im/status-go/services/shhext/chat/crypto" ecrypto "github.com/status-im/status-go/services/shhext/chat/crypto"
appDB "github.com/status-im/status-go/services/shhext/chat/db" appDB "github.com/status-im/status-go/services/shhext/chat/db"
"github.com/status-im/status-go/services/shhext/chat/topic" "github.com/status-im/status-go/services/shhext/chat/multidevice"
"github.com/status-im/status-go/services/shhext/chat/protobuf"
"github.com/status-im/status-go/services/shhext/chat/sharedsecret"
) )
// A safe max number of rows // A safe max number of rows
@ -24,7 +26,8 @@ type SQLLitePersistence struct {
db *sql.DB db *sql.DB
keysStorage dr.KeysStorage keysStorage dr.KeysStorage
sessionStorage dr.SessionStorage sessionStorage dr.SessionStorage
topicStorage topic.PersistenceService secretStorage sharedsecret.PersistenceService
multideviceStorage multidevice.Persistence
} }
// SQLLiteKeysStorage represents a keys persistence service tied to an SQLite database // SQLLiteKeysStorage represents a keys persistence service tied to an SQLite database
@ -49,7 +52,9 @@ func NewSQLLitePersistence(path string, key string) (*SQLLitePersistence, error)
s.sessionStorage = NewSQLLiteSessionStorage(s.db) s.sessionStorage = NewSQLLiteSessionStorage(s.db)
s.topicStorage = topic.NewSQLLitePersistence(s.db) s.secretStorage = sharedsecret.NewSQLLitePersistence(s.db)
s.multideviceStorage = multidevice.NewSQLLitePersistence(s.db)
return s, nil return s, nil
} }
@ -78,9 +83,14 @@ func (s *SQLLitePersistence) GetSessionStorage() dr.SessionStorage {
return s.sessionStorage return s.sessionStorage
} }
// GetTopicStorage returns the associated topicStorageObject // GetSharedSecretStorage returns the associated secretStorageObject
func (s *SQLLitePersistence) GetTopicStorage() topic.PersistenceService { func (s *SQLLitePersistence) GetSharedSecretStorage() sharedsecret.PersistenceService {
return s.topicStorage return s.secretStorage
}
// GetMultideviceStorage returns the associated multideviceStorage
func (s *SQLLitePersistence) GetMultideviceStorage() multidevice.Persistence {
return s.multideviceStorage
} }
// Open opens a file at the specified path // Open opens a file at the specified path
@ -96,7 +106,7 @@ func (s *SQLLitePersistence) Open(path string, key string) error {
} }
// AddPrivateBundle adds the specified BundleContainer to the database // AddPrivateBundle adds the specified BundleContainer to the database
func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error { func (s *SQLLitePersistence) AddPrivateBundle(bc *protobuf.BundleContainer) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return err return err
@ -150,7 +160,7 @@ func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error {
} }
// AddPublicBundle adds the specified Bundle to the database // AddPublicBundle adds the specified Bundle to the database
func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error { func (s *SQLLitePersistence) AddPublicBundle(b *protobuf.Bundle) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
@ -203,7 +213,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
} }
// GetAnyPrivateBundle retrieves any bundle from the database containing a private key // GetAnyPrivateBundle retrieves any bundle from the database containing a private key
func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installations []*Installation) (*BundleContainer, error) { func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installations []*multidevice.Installation) (*protobuf.BundleContainer, error) {
versions := make(map[string]uint32) versions := make(map[string]uint32)
/* #nosec */ /* #nosec */
@ -239,11 +249,11 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
defer rows.Close() defer rows.Close()
bundle := &Bundle{ bundle := &protobuf.Bundle{
SignedPreKeys: make(map[string]*SignedPreKey), SignedPreKeys: make(map[string]*protobuf.SignedPreKey),
} }
bundleContainer := &BundleContainer{ bundleContainer := &protobuf.BundleContainer{
Bundle: bundle, Bundle: bundle,
} }
@ -267,7 +277,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
bundle.Timestamp = timestamp bundle.Timestamp = timestamp
} }
bundle.SignedPreKeys[installationID] = &SignedPreKey{ bundle.SignedPreKeys[installationID] = &protobuf.SignedPreKey{
SignedPreKey: signedPreKey, SignedPreKey: signedPreKey,
Version: version, Version: version,
ProtocolVersion: versions[installationID], ProtocolVersion: versions[installationID],
@ -323,7 +333,7 @@ func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error {
} }
// GetPublicBundle retrieves an existing Bundle for the specified public key from the database // GetPublicBundle retrieves an existing Bundle for the specified public key from the database
func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, installations []*Installation) (*Bundle, error) { func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, installations []*multidevice.Installation) (*protobuf.Bundle, error) {
if len(installations) == 0 { if len(installations) == 0 {
return nil, nil return nil, nil
@ -360,9 +370,9 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install
defer rows.Close() defer rows.Close()
bundle := &Bundle{ bundle := &protobuf.Bundle{
Identity: identity, Identity: identity,
SignedPreKeys: make(map[string]*SignedPreKey), SignedPreKeys: make(map[string]*protobuf.SignedPreKey),
} }
for rows.Next() { for rows.Next() {
@ -379,7 +389,7 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install
return nil, err return nil, err
} }
bundle.SignedPreKeys[installationID] = &SignedPreKey{ bundle.SignedPreKeys[installationID] = &protobuf.SignedPreKey{
SignedPreKey: signedPreKey, SignedPreKey: signedPreKey,
Version: version, Version: version,
ProtocolVersion: versions[installationID], ProtocolVersion: versions[installationID],
@ -754,159 +764,6 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
} }
} }
// GetActiveInstallations returns the active installations for a given identity
func (s *SQLLitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error) {
stmt, err := s.db.Prepare(`SELECT installation_id, version
FROM installations
WHERE enabled = 1 AND identity = ?
ORDER BY timestamp DESC
LIMIT ?`)
if err != nil {
return nil, err
}
var installations []*Installation
rows, err := stmt.Query(identity, maxInstallations)
if err != nil {
return nil, err
}
for rows.Next() {
var installationID string
var version uint32
err = rows.Scan(
&installationID,
&version,
)
if err != nil {
return nil, err
}
installations = append(installations, &Installation{
ID: installationID,
Version: version,
})
}
return installations, nil
}
// AddInstallations adds the installations for a given identity, maintaining the enabled flag
func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64, installations []*Installation, defaultEnabled bool) error {
tx, err := s.db.Begin()
if err != nil {
return nil
}
for _, installation := range installations {
stmt, err := tx.Prepare(`SELECT enabled, version
FROM installations
WHERE identity = ? AND installation_id = ?
LIMIT 1`)
if err != nil {
return err
}
defer stmt.Close()
var oldEnabled bool
// We don't override version once we saw one
var oldVersion uint32
latestVersion := installation.Version
err = stmt.QueryRow(identity, installation.ID).Scan(&oldEnabled, &oldVersion)
if err != nil && err != sql.ErrNoRows {
return err
}
// We update timestamp if present without changing enabled, only if this is a new bundle
// and we set the version to the latest we ever saw
if err != sql.ErrNoRows {
if oldVersion > installation.Version {
latestVersion = oldVersion
}
stmt, err = tx.Prepare(`UPDATE installations
SET timestamp = ?, enabled = ?, version = ?
WHERE identity = ?
AND installation_id = ?
AND timestamp < ?`)
if err != nil {
return err
}
_, err = stmt.Exec(
timestamp,
oldEnabled,
latestVersion,
identity,
installation.ID,
timestamp,
)
if err != nil {
return err
}
defer stmt.Close()
} else {
stmt, err = tx.Prepare(`INSERT INTO installations(identity, installation_id, timestamp, enabled, version)
VALUES (?, ?, ?, ?, ?)`)
if err != nil {
return err
}
_, err = stmt.Exec(
identity,
installation.ID,
timestamp,
defaultEnabled,
latestVersion,
)
if err != nil {
return err
}
defer stmt.Close()
}
}
if err := tx.Commit(); err != nil {
_ = tx.Rollback()
return err
}
return nil
}
// EnableInstallation enables the installation
func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID string) error {
stmt, err := s.db.Prepare(`UPDATE installations
SET enabled = 1
WHERE identity = ? AND installation_id = ?`)
if err != nil {
return err
}
_, err = stmt.Exec(identity, installationID)
return err
}
// DisableInstallation disable the installation
func (s *SQLLitePersistence) DisableInstallation(identity []byte, installationID string) error {
stmt, err := s.db.Prepare(`UPDATE installations
SET enabled = 0
WHERE identity = ? AND installation_id = ?`)
if err != nil {
return err
}
_, err = stmt.Exec(identity, installationID)
return err
}
func toKey(a []byte) dr.Key { func toKey(a []byte) dr.Key {
var k [32]byte var k [32]byte
copy(k[:], a) copy(k[:], a)

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/status-im/status-go/services/shhext/chat/multidevice"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
@ -53,7 +54,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() {
s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualKey) s.Nil(actualKey)
anyPrivateBundle, err := s.service.GetAnyPrivateBundle([]byte("non-existing-id"), []*Installation{{ID: installationID, Version: 1}}) anyPrivateBundle, err := s.service.GetAnyPrivateBundle([]byte("non-existing-id"), []*multidevice.Installation{{ID: installationID, Version: 1}})
s.Require().NoError(err) s.Require().NoError(err)
s.Nil(anyPrivateBundle) s.Nil(anyPrivateBundle)
@ -70,7 +71,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() {
s.Equal(bundle.GetPrivateSignedPreKey(), actualKey, "It returns the same key") s.Equal(bundle.GetPrivateSignedPreKey(), actualKey, "It returns the same key")
identity := crypto.CompressPubkey(&key.PublicKey) identity := crypto.CompressPubkey(&key.PublicKey)
anyPrivateBundle, err = s.service.GetAnyPrivateBundle(identity, []*Installation{{ID: installationID, Version: 1}}) anyPrivateBundle, err = s.service.GetAnyPrivateBundle(identity, []*multidevice.Installation{{ID: installationID, Version: 1}})
s.Require().NoError(err) s.Require().NoError(err)
s.NotNil(anyPrivateBundle) s.NotNil(anyPrivateBundle)
s.Equal(bundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, anyPrivateBundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, "It returns the same bundle") s.Equal(bundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, anyPrivateBundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, "It returns the same bundle")
@ -80,7 +81,7 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {
key, err := crypto.GenerateKey() key, err := crypto.GenerateKey()
s.Require().NoError(err) s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}})
s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle) s.Nil(actualBundle)
@ -91,7 +92,7 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {
err = s.service.AddPublicBundle(bundle) err = s.service.AddPublicBundle(bundle)
s.Require().NoError(err) s.Require().NoError(err)
actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}})
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity") s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity")
s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys") s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys")
@ -101,7 +102,7 @@ func (s *SQLLitePersistenceTestSuite) TestUpdatedBundle() {
key, err := crypto.GenerateKey() key, err := crypto.GenerateKey()
s.Require().NoError(err) s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}})
s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle) s.Nil(actualBundle)
@ -123,7 +124,7 @@ func (s *SQLLitePersistenceTestSuite) TestUpdatedBundle() {
err = s.service.AddPublicBundle(bundle) err = s.service.AddPublicBundle(bundle)
s.Require().NoError(err) s.Require().NoError(err)
actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}})
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity") s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity")
s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys") s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys")
@ -133,7 +134,7 @@ func (s *SQLLitePersistenceTestSuite) TestOutOfOrderBundles() {
key, err := crypto.GenerateKey() key, err := crypto.GenerateKey()
s.Require().NoError(err) s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}})
s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle) s.Nil(actualBundle)
@ -160,7 +161,7 @@ func (s *SQLLitePersistenceTestSuite) TestOutOfOrderBundles() {
err = s.service.AddPublicBundle(bundle1) err = s.service.AddPublicBundle(bundle1)
s.Require().NoError(err) s.Require().NoError(err)
actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}})
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(bundle2.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity") s.Equal(bundle2.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity")
s.Equal(bundle2.GetSignedPreKeys()["1"].GetVersion(), uint32(1)) s.Equal(bundle2.GetSignedPreKeys()["1"].GetVersion(), uint32(1))
@ -171,7 +172,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() {
key, err := crypto.GenerateKey() key, err := crypto.GenerateKey()
s.Require().NoError(err) s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}})
s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle) s.Nil(actualBundle)
@ -197,7 +198,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() {
s.Require().NoError(err) s.Require().NoError(err)
// Returns the most recent bundle // Returns the most recent bundle
actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}})
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the identity") s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the identity")
@ -209,7 +210,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiDevicePublicBundle() {
key, err := crypto.GenerateKey() key, err := crypto.GenerateKey()
s.Require().NoError(err) s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*Installation{{ID: "1", Version: 1}}) actualBundle, err := s.service.GetPublicBundle(&key.PublicKey, []*multidevice.Installation{{ID: "1", Version: 1}})
s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle) s.Nil(actualBundle)
@ -234,7 +235,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiDevicePublicBundle() {
// Returns the most recent bundle // Returns the most recent bundle
actualBundle, err = s.service.GetPublicBundle(&key.PublicKey, actualBundle, err = s.service.GetPublicBundle(&key.PublicKey,
[]*Installation{ []*multidevice.Installation{
{ID: "1", Version: 1}, {ID: "1", Version: 1},
{ID: "2", Version: 1}, {ID: "2", Version: 1},
}) })
@ -347,211 +348,4 @@ func (s *SQLLitePersistenceTestSuite) TestRatchetInfoNoBundle() {
s.Nil(ratchetInfo, "It returns nil when no bundle is there") s.Nil(ratchetInfo, "It returns nil when no bundle is there")
} }
func (s *SQLLitePersistenceTestSuite) TestAddInstallations() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
enabledInstallations, err := s.service.GetActiveInstallations(5, identity)
s.Require().NoError(err)
s.Require().Equal(installations, enabledInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestAddInstallationVersions() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
enabledInstallations, err := s.service.GetActiveInstallations(5, identity)
s.Require().NoError(err)
s.Require().Equal(installations, enabledInstallations)
installationsWithDowngradedVersion := []*Installation{
{ID: "alice-1", Version: 0},
}
err = s.service.AddInstallations(
identity,
3,
installationsWithDowngradedVersion,
true,
)
s.Require().NoError(err)
enabledInstallations, err = s.service.GetActiveInstallations(5, identity)
s.Require().NoError(err)
s.Require().Equal(installations, enabledInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestAddInstallationsLimit() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
installations = []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-3", Version: 3},
}
err = s.service.AddInstallations(
identity,
2,
installations,
true,
)
s.Require().NoError(err)
installations = []*Installation{
{ID: "alice-2", Version: 2},
{ID: "alice-3", Version: 3},
{ID: "alice-4", Version: 4},
}
err = s.service.AddInstallations(
identity,
3,
installations,
true,
)
s.Require().NoError(err)
enabledInstallations, err := s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
s.Require().Equal(installations, enabledInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestAddInstallationsDisabled() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
false,
)
s.Require().NoError(err)
actualInstallations, err := s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
s.Require().Nil(actualInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestDisableInstallation() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
err = s.service.DisableInstallation(identity, "alice-1")
s.Require().NoError(err)
// We add the installations again
installations = []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err = s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
actualInstallations, err := s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
expected := []*Installation{{ID: "alice-2", Version: 2}}
s.Require().Equal(expected, actualInstallations)
}
func (s *SQLLitePersistenceTestSuite) TestEnableInstallation() {
identity := []byte("alice")
installations := []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
err := s.service.AddInstallations(
identity,
1,
installations,
true,
)
s.Require().NoError(err)
err = s.service.DisableInstallation(identity, "alice-1")
s.Require().NoError(err)
actualInstallations, err := s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
expected := []*Installation{{ID: "alice-2", Version: 2}}
s.Require().Equal(expected, actualInstallations)
err = s.service.EnableInstallation(identity, "alice-1")
s.Require().NoError(err)
actualInstallations, err = s.service.GetActiveInstallations(3, identity)
s.Require().NoError(err)
expected = []*Installation{
{ID: "alice-1", Version: 1},
{ID: "alice-2", Version: 2},
}
s.Require().Equal(expected, actualInstallations)
}
// TODO: Add test for MarkBundleExpired // TODO: Add test for MarkBundleExpired

View File

@ -1,64 +0,0 @@
package chat
import (
"github.com/ethereum/go-ethereum/crypto"
whisper "github.com/status-im/whisper/whisperv6"
)
var discoveryTopic = "contact-discovery"
var discoveryTopicBytes = toTopic(discoveryTopic)
var topicSalt = []byte{0x01, 0x02, 0x03, 0x04}
func toTopic(s string) whisper.TopicType {
return whisper.BytesToTopic(crypto.Keccak256([]byte(s)))
}
func SharedSecretToTopic(secret []byte) whisper.TopicType {
return whisper.BytesToTopic(crypto.Keccak256(append(secret, topicSalt...)))
}
func defaultWhisperMessage() whisper.NewMessage {
msg := whisper.NewMessage{}
msg.TTL = 10
msg.PowTarget = 0.002
msg.PowTime = 1
return msg
}
func PublicMessageToWhisper(rpcMsg SendPublicMessageRPC, payload []byte) whisper.NewMessage {
msg := defaultWhisperMessage()
msg.Topic = toTopic(rpcMsg.Chat)
msg.Payload = payload
msg.Sig = rpcMsg.Sig
return msg
}
func DirectMessageToWhisper(rpcMsg SendDirectMessageRPC, payload []byte, sharedSecret []byte) whisper.NewMessage {
var topicBytes whisper.TopicType
msg := defaultWhisperMessage()
if rpcMsg.Chat == "" {
if sharedSecret != nil {
topicBytes = SharedSecretToTopic(sharedSecret)
} else {
topicBytes = discoveryTopicBytes
msg.PublicKey = rpcMsg.PubKey
}
} else {
topicBytes = toTopic(rpcMsg.Chat)
msg.PublicKey = rpcMsg.PubKey
}
msg.Topic = topicBytes
msg.Payload = payload
msg.Sig = rpcMsg.Sig
return msg
}

View File

@ -1,56 +0,0 @@
package chat
import (
whisper "github.com/status-im/whisper/whisperv6"
"testing"
"github.com/stretchr/testify/assert"
)
func TestPublicMessageToWhisper(t *testing.T) {
rpcMessage := SendPublicMessageRPC{
Chat: "test-chat",
Sig: "test",
}
payload := []byte("test")
whisperMessage := PublicMessageToWhisper(rpcMessage, payload)
assert.Equalf(t, uint32(10), whisperMessage.TTL, "It sets the TTL")
assert.Equalf(t, 0.002, whisperMessage.PowTarget, "It sets the pow target")
assert.Equalf(t, uint32(1), whisperMessage.PowTime, "It sets the pow time")
assert.Equalf(t, whisper.TopicType{0xa4, 0xab, 0xdf, 0x64}, whisperMessage.Topic, "It sets the topic")
}
func TestDirectMessageToWhisper(t *testing.T) {
rpcMessage := SendDirectMessageRPC{
PubKey: []byte("some pubkey"),
Sig: "test",
}
payload := []byte("test")
whisperMessage := DirectMessageToWhisper(rpcMessage, payload, nil)
assert.Equalf(t, uint32(10), whisperMessage.TTL, "It sets the TTL")
assert.Equalf(t, 0.002, whisperMessage.PowTarget, "It sets the pow target")
assert.Equalf(t, uint32(1), whisperMessage.PowTime, "It sets the pow time")
assert.Equalf(t, whisper.TopicType{0xf8, 0x94, 0x6a, 0xac}, whisperMessage.Topic, "It sets the discovery topic")
}
func TestDirectMessageToWhisperWithSharedSecret(t *testing.T) {
rpcMessage := SendDirectMessageRPC{
PubKey: []byte("some pubkey"),
Sig: "test",
}
payload := []byte("test")
secret := []byte("test-secret")
whisperMessage := DirectMessageToWhisper(rpcMessage, payload, secret)
assert.Equalf(t, uint32(10), whisperMessage.TTL, "It sets the TTL")
assert.Equalf(t, 0.002, whisperMessage.PowTarget, "It sets the pow target")
assert.Equalf(t, uint32(1), whisperMessage.PowTime, "It sets the pow time")
assert.Equalf(t, whisper.TopicType{0xd8, 0xa2, 0xf3, 0x64}, whisperMessage.Topic, "It sets the discovery topic")
}

View File

@ -2,16 +2,14 @@ package chat
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"encoding/base64"
"errors" "errors"
"fmt"
"sort" "sort"
"strconv" "strconv"
"time" "time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/golang/protobuf/proto" "github.com/status-im/status-go/services/shhext/chat/protobuf"
) )
const ( const (
@ -19,33 +17,7 @@ const (
sskLen = 16 sskLen = 16
) )
// ToBase64 returns a Base64 encoding representation of the protobuf Bundle message func buildSignatureMaterial(bundle *protobuf.Bundle) []byte {
func (bundle *Bundle) ToBase64() (string, error) {
marshaledMessage, err := proto.Marshal(bundle)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(marshaledMessage), nil
}
// FromBase64 unmarshals a Bundle from a Base64 encoding representation of the protobuf Bundle message
func FromBase64(str string) (*Bundle, error) {
bundle := &Bundle{}
decodedBundle, err := base64.StdEncoding.DecodeString(str)
if err != nil {
return nil, err
}
if err := proto.Unmarshal(decodedBundle, bundle); err != nil {
return nil, err
}
return bundle, nil
}
func buildSignatureMaterial(bundle *Bundle) []byte {
signedPreKeys := bundle.GetSignedPreKeys() signedPreKeys := bundle.GetSignedPreKeys()
timestamp := bundle.GetTimestamp() timestamp := bundle.GetTimestamp()
var keys []string var keys []string
@ -73,7 +45,7 @@ func buildSignatureMaterial(bundle *Bundle) []byte {
} }
func SignBundle(identity *ecdsa.PrivateKey, bundleContainer *BundleContainer) error { func SignBundle(identity *ecdsa.PrivateKey, bundleContainer *protobuf.BundleContainer) error {
signatureMaterial := buildSignatureMaterial(bundleContainer.GetBundle()) signatureMaterial := buildSignatureMaterial(bundleContainer.GetBundle())
signature, err := crypto.Sign(crypto.Keccak256(signatureMaterial), identity) signature, err := crypto.Sign(crypto.Keccak256(signatureMaterial), identity)
@ -85,7 +57,7 @@ func SignBundle(identity *ecdsa.PrivateKey, bundleContainer *BundleContainer) er
} }
// NewBundleContainer creates a new BundleContainer from an identity private key // NewBundleContainer creates a new BundleContainer from an identity private key
func NewBundleContainer(identity *ecdsa.PrivateKey, installationID string) (*BundleContainer, error) { func NewBundleContainer(identity *ecdsa.PrivateKey, installationID string) (*protobuf.BundleContainer, error) {
preKey, err := crypto.GenerateKey() preKey, err := crypto.GenerateKey()
if err != nil { if err != nil {
return nil, err return nil, err
@ -95,35 +67,35 @@ func NewBundleContainer(identity *ecdsa.PrivateKey, installationID string) (*Bun
compressedIdentityKey := crypto.CompressPubkey(&identity.PublicKey) compressedIdentityKey := crypto.CompressPubkey(&identity.PublicKey)
encodedPreKey := crypto.FromECDSA(preKey) encodedPreKey := crypto.FromECDSA(preKey)
signedPreKeys := make(map[string]*SignedPreKey) signedPreKeys := make(map[string]*protobuf.SignedPreKey)
signedPreKeys[installationID] = &SignedPreKey{ signedPreKeys[installationID] = &protobuf.SignedPreKey{
ProtocolVersion: protocolCurrentVersion, ProtocolVersion: ProtocolVersion,
SignedPreKey: compressedPreKey, SignedPreKey: compressedPreKey,
} }
bundle := Bundle{ bundle := protobuf.Bundle{
Timestamp: time.Now().UnixNano(), Timestamp: time.Now().UnixNano(),
Identity: compressedIdentityKey, Identity: compressedIdentityKey,
SignedPreKeys: signedPreKeys, SignedPreKeys: signedPreKeys,
} }
return &BundleContainer{ return &protobuf.BundleContainer{
Bundle: &bundle, Bundle: &bundle,
PrivateSignedPreKey: encodedPreKey, PrivateSignedPreKey: encodedPreKey,
}, nil }, nil
} }
// VerifyBundle checks that a bundle is valid // VerifyBundle checks that a bundle is valid
func VerifyBundle(bundle *Bundle) error { func VerifyBundle(bundle *protobuf.Bundle) error {
_, err := ExtractIdentity(bundle) _, err := ExtractIdentity(bundle)
return err return err
} }
// ExtractIdentity extracts the identity key from a given bundle // ExtractIdentity extracts the identity key from a given bundle
func ExtractIdentity(bundle *Bundle) (string, error) { func ExtractIdentity(bundle *protobuf.Bundle) (*ecdsa.PublicKey, error) {
bundleIdentityKey, err := crypto.DecompressPubkey(bundle.GetIdentity()) bundleIdentityKey, err := crypto.DecompressPubkey(bundle.GetIdentity())
if err != nil { if err != nil {
return "", err return nil, err
} }
signatureMaterial := buildSignatureMaterial(bundle) signatureMaterial := buildSignatureMaterial(bundle)
@ -133,14 +105,14 @@ func ExtractIdentity(bundle *Bundle) (string, error) {
bundle.GetSignature(), bundle.GetSignature(),
) )
if err != nil { if err != nil {
return "", err return nil, err
} }
if crypto.PubkeyToAddress(*recoveredKey) != crypto.PubkeyToAddress(*bundleIdentityKey) { if crypto.PubkeyToAddress(*recoveredKey) != crypto.PubkeyToAddress(*bundleIdentityKey) {
return "", errors.New("identity key and signature mismatch") return nil, errors.New("identity key and signature mismatch")
} }
return fmt.Sprintf("0x%x", crypto.FromECDSAPub(recoveredKey)), nil return recoveredKey, nil
} }
// PerformDH generates a shared key given a private and a public key // PerformDH generates a shared key given a private and a public key

View File

@ -6,6 +6,7 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/status-im/status-go/services/shhext/chat/protobuf"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -14,12 +15,11 @@ const (
aliceEphemeralKey = "11111111111111111111111111111111" aliceEphemeralKey = "11111111111111111111111111111111"
bobPrivateKey = "22222222222222222222222222222222" bobPrivateKey = "22222222222222222222222222222222"
bobSignedPreKey = "33333333333333333333333333333333" bobSignedPreKey = "33333333333333333333333333333333"
base64Bundle = "CiECkJmdu/QwNL/7HdU+rB60wzpOocT0i6WFz944MIQPBVUSKAoBMhIjCiECPHKt20/fCa+U8MlNf+kqOGp+cM+KHYWRY4a7JTXHsbEiQT9Wse3UkJgo6/1HzxQfHcZNBaMH0j+0eylfBf1ropsLZ7yZM98k/qDQ3ZW5uHXQ4zhY8E1Q7HDytqm62k5JIPYA"
) )
var sharedKey = []byte{0xa4, 0xe9, 0x23, 0xd0, 0xaf, 0x8f, 0xe7, 0x8a, 0x5, 0x63, 0x63, 0xbe, 0x20, 0xe7, 0x1c, 0xa, 0x58, 0xe5, 0x69, 0xea, 0x8f, 0xc1, 0xf7, 0x92, 0x89, 0xec, 0xa1, 0xd, 0x9f, 0x68, 0x13, 0x3a} var sharedKey = []byte{0xa4, 0xe9, 0x23, 0xd0, 0xaf, 0x8f, 0xe7, 0x8a, 0x5, 0x63, 0x63, 0xbe, 0x20, 0xe7, 0x1c, 0xa, 0x58, 0xe5, 0x69, 0xea, 0x8f, 0xc1, 0xf7, 0x92, 0x89, 0xec, 0xa1, 0xd, 0x9f, 0x68, 0x13, 0x3a}
func bobBundle() (*Bundle, error) { func bobBundle() (*protobuf.Bundle, error) {
privateKey, err := crypto.ToECDSA([]byte(bobPrivateKey)) privateKey, err := crypto.ToECDSA([]byte(bobPrivateKey))
if err != nil { if err != nil {
return nil, err return nil, err
@ -37,10 +37,10 @@ func bobBundle() (*Bundle, error) {
return nil, err return nil, err
} }
signedPreKeys := make(map[string]*SignedPreKey) signedPreKeys := make(map[string]*protobuf.SignedPreKey)
signedPreKeys[bobInstallationID] = &SignedPreKey{SignedPreKey: compressedPreKey} signedPreKeys[bobInstallationID] = &protobuf.SignedPreKey{SignedPreKey: compressedPreKey}
bundle := Bundle{ bundle := protobuf.Bundle{
Identity: crypto.CompressPubkey(&privateKey.PublicKey), Identity: crypto.CompressPubkey(&privateKey.PublicKey),
SignedPreKeys: signedPreKeys, SignedPreKeys: signedPreKeys,
Signature: signature, Signature: signature,
@ -76,8 +76,8 @@ func TestNewBundleContainer(t *testing.T) {
require.Equal( require.Equal(
t, t,
&privateKey.PublicKey, privateKey.PublicKey,
recoveredPublicKey, *recoveredPublicKey,
"The correct public key should be recovered", "The correct public key should be recovered",
) )
} }
@ -94,7 +94,7 @@ func TestSignBundle(t *testing.T) {
// We add a signed pre key // We add a signed pre key
signedPreKeys := bundle1.GetSignedPreKeys() signedPreKeys := bundle1.GetSignedPreKeys()
signedPreKeys["2"] = &SignedPreKey{SignedPreKey: []byte("key")} signedPreKeys["2"] = &protobuf.SignedPreKey{SignedPreKey: []byte("key")}
err = SignBundle(privateKey, bundleContainer1) err = SignBundle(privateKey, bundleContainer1)
require.NoError(t, err) require.NoError(t, err)
@ -115,40 +115,12 @@ func TestSignBundle(t *testing.T) {
require.Equal( require.Equal(
t, t,
&privateKey.PublicKey, privateKey.PublicKey,
recoveredPublicKey, *recoveredPublicKey,
"The correct public key should be recovered", "The correct public key should be recovered",
) )
} }
func TestToBase64(t *testing.T) {
bundle, err := bobBundle()
require.NoError(t, err, "Test bundle should be generated without errors")
actualBase64Bundle, err := bundle.ToBase64()
require.NoError(t, err, "No error should be reported")
require.Equal(
t,
base64Bundle,
actualBase64Bundle,
"The correct bundle should be generated",
)
}
func TestFromBase64(t *testing.T) {
expectedBundle, err := bobBundle()
require.NoError(t, err, "Test bundle should be generated without errors")
actualBundle, err := FromBase64(base64Bundle)
require.NoError(t, err, "Bundle should be unmarshaled without errors")
require.Equal(
t,
expectedBundle,
actualBundle,
"The correct bundle should be generated",
)
}
func TestExtractIdentity(t *testing.T) { func TestExtractIdentity(t *testing.T) {
privateKey, err := crypto.ToECDSA([]byte(alicePrivateKey)) privateKey, err := crypto.ToECDSA([]byte(alicePrivateKey))
require.NoError(t, err, "Private key should be generated without errors") require.NoError(t, err, "Private key should be generated without errors")
@ -168,8 +140,8 @@ func TestExtractIdentity(t *testing.T) {
require.Equal( require.Equal(
t, t,
"0x042ed557f5ad336b31a49857e4e9664954ac33385aa20a93e2d64bfe7f08f51277bcb27c1259f802a52ed3ea7ac939043f0cc864e27400294bf121f23877995852", privateKey.PublicKey,
recoveredPublicKey, *recoveredPublicKey,
"The correct public key should be recovered", "The correct public key should be recovered",
) )
} }

View File

@ -3,10 +3,11 @@ package filter
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/status-im/status-go/services/shhext/chat/topic" "github.com/status-im/status-go/services/shhext/chat/sharedsecret"
whisper "github.com/status-im/whisper/whisperv6" whisper "github.com/status-im/whisper/whisperv6"
"math/big" "math/big"
"sync" "sync"
@ -18,79 +19,262 @@ const (
// The number of partitions // The number of partitions
var nPartitions = big.NewInt(5000) var nPartitions = big.NewInt(5000)
var minPow = 0.0
func toTopic(s string) []byte {
return crypto.Keccak256([]byte(s))[:whisper.TopicLength]
}
func chatIDToPartitionedTopic(identity string) (string, error) {
partition := big.NewInt(0)
publicKeyBytes, err := hex.DecodeString(identity)
if err != nil {
return "", err
}
publicKey, err := crypto.UnmarshalPubkey(publicKeyBytes)
if err != nil {
return "", err
}
partition.Mod(publicKey.X, nPartitions)
return fmt.Sprintf("contact-discovery-%d", partition.Int64()), nil
}
type Filter struct { type Filter struct {
FilterID string FilterID string
Topic []byte Topic whisper.TopicType
SymKeyID string SymKeyID string
} }
type Chat struct { type Chat struct {
// ChatID is the identifier of the chat // ChatID is the identifier of the chat
ChatID string ChatID string `json:"chatId"`
// SymKeyID is the symmetric key id used for symmetric chats // SymKeyID is the symmetric key id used for symmetric chats
SymKeyID string SymKeyID string `json:"symKeyId"`
// OneToOne tells us if we need to use asymmetric encryption for this chat // OneToOne tells us if we need to use asymmetric encryption for this chat
OneToOne bool OneToOne bool `json:"oneToOne"`
// Listen is whether we are actually listening for messages on this chat, or the filter is only created in order to be able to post on the topic // Listen is whether we are actually listening for messages on this chat, or the filter is only created in order to be able to post on the topic
Listen bool Listen bool `json:"listen"`
// FilterID the whisper filter id generated // FilterID the whisper filter id generated
FilterID string FilterID string `json:"filterId"`
// Identity is the public key of the other recipient for non-public chats // Identity is the public key of the other recipient for non-public chats
Identity string Identity string `json:"identity"`
// Topic is the whisper topic // Topic is the whisper topic
Topic []byte Topic whisper.TopicType `json:"topic"`
} }
type Service struct { type Service struct {
keyID string
whisper *whisper.Whisper whisper *whisper.Whisper
topic *topic.Service secret *sharedsecret.Service
chats map[string]*Chat chats map[string]*Chat
mutex sync.Mutex mutex sync.Mutex
} }
func New(k string, w *whisper.Whisper, t *topic.Service) *Service { // New returns a new filter service
func New(w *whisper.Whisper, s *sharedsecret.Service) *Service {
return &Service{ return &Service{
keyID: k,
whisper: w, whisper: w,
topic: t, secret: s,
mutex: sync.Mutex{}, mutex: sync.Mutex{},
chats: make(map[string]*Chat), chats: make(map[string]*Chat),
} }
} }
// LoadDiscovery adds the discovery filter // LoadChat should return a list of newly chats loaded
func (s *Service) LoadDiscovery(myKey *ecdsa.PrivateKey) error { func (s *Service) Init(chats []*Chat) ([]*Chat, error) {
log.Debug("Initializing filter service", "chats", chats)
keyID := s.whisper.SelectedKeyPairID()
if keyID == "" {
return nil, errors.New("no key selected")
}
myKey, err := s.whisper.GetPrivateKey(keyID)
if err != nil {
return nil, err
}
// Add our own topic
log.Debug("Loading one to one chats")
identityStr := fmt.Sprintf("%x", crypto.FromECDSAPub(&myKey.PublicKey))
_, err = s.loadOneToOne(myKey, identityStr, true)
if err != nil {
log.Error("Error loading one to one chats", "err", err)
return nil, err
}
// Add discovery topic
log.Debug("Loading discovery topics")
err = s.loadDiscovery(myKey)
if err != nil {
return nil, err
}
// Add the various one to one and public chats
log.Debug("Loading chats")
for _, chat := range chats {
_, err = s.load(myKey, chat)
if err != nil {
return nil, err
}
}
// Add the negotiated secrets
log.Debug("Loading negotiated topics")
secrets, err := s.secret.All()
if err != nil {
return nil, err
}
for _, secret := range secrets {
if _, err := s.ProcessNegotiatedSecret(secret); err != nil {
return nil, err
}
}
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
discoveryChat := &Chat{ var allChats []*Chat
ChatID: discoveryTopic, for _, chat := range s.chats {
allChats = append(allChats, chat)
}
return allChats, nil
} }
discoveryResponse, err := s.AddAsymmetricFilter(myKey, discoveryChat.ChatID, true) // Stop removes all the filters
func (s *Service) Stop() error {
for _, chat := range s.chats {
if err := s.Remove(chat); err != nil {
return err
}
}
return nil
}
// Remove remove all the filters associated with a chat/identity
func (s *Service) Remove(chat *Chat) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.whisper.Unsubscribe(chat.FilterID); err != nil {
return err
}
if chat.SymKeyID != "" {
s.whisper.DeleteSymKey(chat.SymKeyID)
}
delete(s.chats, chat.ChatID)
return nil
}
// LoadPartitioned creates a filter for a partitioned topic
func (s *Service) LoadPartitioned(myKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, listen bool) (*Chat, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
chatID := PublicKeyToPartitionedTopic(theirPublicKey)
if _, ok := s.chats[chatID]; ok {
return s.chats[chatID], nil
}
// We set up a filter so we can publish, but we discard envelopes if listen is false
filter, err := s.addAsymmetricFilter(myKey, chatID, listen)
if err != nil {
return nil, err
}
chat := &Chat{
ChatID: chatID,
FilterID: filter.FilterID,
Topic: filter.Topic,
Listen: listen,
}
s.chats[chatID] = chat
return chat, nil
}
// Load creates filters for a given chat, and returns all the created filters
func (s *Service) Load(chat *Chat) ([]*Chat, error) {
keyID := s.whisper.SelectedKeyPairID()
if keyID == "" {
return nil, errors.New("no key selected")
}
myKey, err := s.whisper.GetPrivateKey(keyID)
if err != nil {
return nil, err
}
return s.load(myKey, chat)
}
// Get returns a negotiated filter given an identity
func (s *Service) GetNegotiated(identity *ecdsa.PublicKey) *Chat {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.chats[negotiatedID(identity)]
}
// GetByID returns a filter by chatID
func (s *Service) GetByID(chatID string) *Chat {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.chats[chatID]
}
// ProcessNegotiatedSecret adds a filter based on the agreed secret
func (s *Service) ProcessNegotiatedSecret(secret *sharedsecret.Secret) (*Chat, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
chatID := negotiatedID(secret.Identity)
// If we already have a filter do nothing
if _, ok := s.chats[chatID]; ok {
return s.chats[chatID], nil
}
keyString := fmt.Sprintf("%x", secret.Key)
filter, err := s.addSymmetric(keyString)
if err != nil {
return nil, err
}
identityStr := fmt.Sprintf("%x", crypto.FromECDSAPub(secret.Identity))
chat := &Chat{
ChatID: chatID,
Topic: filter.Topic,
SymKeyID: filter.SymKeyID,
FilterID: filter.FilterID,
Identity: identityStr,
Listen: true,
}
log.Info("PROCESSING SECRET", "chat-id", chatID, "topic", filter.Topic, "symKey", keyString)
s.chats[chat.ChatID] = chat
return chat, nil
}
// ToTopic converts a string to a whisper topic
func ToTopic(s string) []byte {
return crypto.Keccak256([]byte(s))[:whisper.TopicLength]
}
// PublicKeyToPartitionedTopic returns the associated partitioned topic string
// with the given public key
func PublicKeyToPartitionedTopic(publicKey *ecdsa.PublicKey) string {
partition := big.NewInt(0)
partition.Mod(publicKey.X, nPartitions)
return fmt.Sprintf("contact-discovery-%d", partition.Int64())
}
// PublicKeyToPartitionedTopicBytes returns the bytes of the partitioned topic
// associated with the given public key
func PublicKeyToPartitionedTopicBytes(publicKey *ecdsa.PublicKey) []byte {
return ToTopic(PublicKeyToPartitionedTopic(publicKey))
}
// loadDiscovery adds the discovery filter
func (s *Service) loadDiscovery(myKey *ecdsa.PrivateKey) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if _, ok := s.chats[discoveryTopic]; ok {
return nil
}
discoveryChat := &Chat{
ChatID: discoveryTopic,
Listen: true,
}
discoveryResponse, err := s.addAsymmetricFilter(myKey, discoveryChat.ChatID, true)
if err != nil { if err != nil {
return err return err
} }
@ -102,125 +286,93 @@ func (s *Service) LoadDiscovery(myKey *ecdsa.PrivateKey) error {
return nil return nil
} }
func (s *Service) Init(chats []*Chat) error { // loadPublic adds a filter for a public chat
log.Debug("Initializing filter service") func (s *Service) loadPublic(chat *Chat) error {
myKey, err := s.whisper.GetPrivateKey(s.keyID)
if err != nil {
return err
}
// Add our own topic
log.Debug("Loading one to one chats")
identityStr := fmt.Sprintf("%x", crypto.FromECDSAPub(&myKey.PublicKey))
err = s.LoadOneToOne(myKey, identityStr, true)
if err != nil {
log.Error("Error loading one to one chats", "err", err)
return err
}
// Add discovery topic
log.Debug("Loading discovery topics")
err = s.LoadDiscovery(myKey)
if err != nil {
return err
}
// Add the various one to one and public chats
log.Debug("Loading chats")
for _, chat := range chats {
err = s.Load(myKey, chat)
if err != nil {
return err
}
}
// Add the negotiated topics
log.Debug("Loading negotiated topics")
secrets, err := s.topic.All()
if err != nil {
return err
}
for _, secret := range secrets {
if err := s.ProcessNegotiatedSecret(secret); err != nil {
return err
}
}
return nil
}
func (s *Service) Stop() error {
for _, chat := range s.chats {
if err := s.Remove(chat); err != nil {
return err
}
}
return nil
}
func (s *Service) Remove(chat *Chat) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if err := s.whisper.Unsubscribe(chat.ChatID); err != nil { if _, ok := s.chats[chat.ChatID]; ok {
return err
}
if chat.SymKeyID != "" {
s.whisper.DeleteSymKey(chat.SymKeyID)
}
delete(s.chats, chat.ChatID)
return nil return nil
} }
// LoadOneToOne creates two filters for a given chat, one listening to the contact codes filterAndTopic, err := s.addSymmetric(chat.ChatID)
// and another on the partitioned topic. We pass a listen parameter to indicated whether
// we are listening to messages on the partitioned topic
func (s *Service) LoadOneToOne(myKey *ecdsa.PrivateKey, identity string, listen bool) error {
s.mutex.Lock()
defer s.mutex.Unlock()
contactCodeChatID := identity + "-contact-code"
contactCodeFilter, err := s.AddSymmetric(contactCodeChatID)
if err != nil { if err != nil {
return err return err
} }
s.chats[contactCodeChatID] = &Chat{ chat.FilterID = filterAndTopic.FilterID
ChatID: contactCodeChatID, chat.SymKeyID = filterAndTopic.SymKeyID
chat.Topic = filterAndTopic.Topic
chat.Listen = true
s.chats[chat.ChatID] = chat
return nil
}
// loadOneToOne creates two filters for a given chat, one listening to the contact codes
// and another on the partitioned topic, if listen is specified.
func (s *Service) loadOneToOne(myKey *ecdsa.PrivateKey, identity string, listen bool) ([]*Chat, error) {
var chats []*Chat
contactCodeChat, err := s.loadContactCode(identity)
if err != nil {
return nil, err
}
chats = append(chats, contactCodeChat)
if listen {
publicKeyBytes, err := hex.DecodeString(identity)
if err != nil {
return nil, err
}
publicKey, err := crypto.UnmarshalPubkey(publicKeyBytes)
if err != nil {
return nil, err
}
partitionedChat, err := s.LoadPartitioned(myKey, publicKey, listen)
if err != nil {
return nil, err
}
chats = append(chats, partitionedChat)
}
return chats, nil
}
// loadContactCode creates a filter for the topic are advertised for a given identity
func (s *Service) loadContactCode(identity string) (*Chat, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
chatID := "0x" + identity + "-contact-code"
if _, ok := s.chats[chatID]; ok {
return s.chats[chatID], nil
}
contactCodeFilter, err := s.addSymmetric(chatID)
if err != nil {
return nil, err
}
chat := &Chat{
ChatID: chatID,
FilterID: contactCodeFilter.FilterID, FilterID: contactCodeFilter.FilterID,
Topic: contactCodeFilter.Topic, Topic: contactCodeFilter.Topic,
SymKeyID: contactCodeFilter.SymKeyID, SymKeyID: contactCodeFilter.SymKeyID,
Identity: identity, Identity: identity,
Listen: true,
} }
partitionedTopicChatID, err := chatIDToPartitionedTopic(identity) s.chats[chatID] = chat
if err != nil { return chat, nil
return err
}
// We set up a filter so we can publish, but we discard envelopes if listen is false
partitionedTopicFilter, err := s.AddAsymmetricFilter(myKey, partitionedTopicChatID, listen)
if err != nil {
return err
}
s.chats[partitionedTopicChatID] = &Chat{
ChatID: partitionedTopicChatID,
FilterID: partitionedTopicFilter.FilterID,
Topic: partitionedTopicFilter.Topic,
Identity: identity,
Listen: listen,
} }
return nil // addSymmetric adds a symmetric key filter
} func (s *Service) addSymmetric(chatID string) (*Filter, error) {
func (s *Service) AddSymmetric(chatID string) (*Filter, error) {
var symKey []byte var symKey []byte
topic := toTopic(chatID) topic := ToTopic(chatID)
topics := [][]byte{topic} topics := [][]byte{topic}
symKeyID, err := s.whisper.AddSymKeyFromPassword(chatID) symKeyID, err := s.whisper.AddSymKeyFromPassword(chatID)
@ -235,7 +387,7 @@ func (s *Service) AddSymmetric(chatID string) (*Filter, error) {
f := &whisper.Filter{ f := &whisper.Filter{
KeySym: symKey, KeySym: symKey,
PoW: 0.002, PoW: minPow,
AllowP2P: true, AllowP2P: true,
Topics: topics, Topics: topics,
Messages: s.whisper.NewMessageStore(), Messages: s.whisper.NewMessageStore(),
@ -249,28 +401,29 @@ func (s *Service) AddSymmetric(chatID string) (*Filter, error) {
return &Filter{ return &Filter{
FilterID: id, FilterID: id,
SymKeyID: symKeyID, SymKeyID: symKeyID,
Topic: topic, Topic: whisper.BytesToTopic(topic),
}, nil }, nil
} }
func (s *Service) AddAsymmetricFilter(keyAsym *ecdsa.PrivateKey, chatID string, listen bool) (*Filter, error) { // addAsymmetricFilter adds a filter with our privatekey, and set minPow according to the listen parameter
func (s *Service) addAsymmetricFilter(keyAsym *ecdsa.PrivateKey, chatID string, listen bool) (*Filter, error) {
var err error var err error
var pow float64 var pow float64
if listen { if listen {
pow = 0.002 pow = minPow
} else { } else {
// Set high pow so we discard messages // Set high pow so we discard messages
pow = 1 pow = 1
} }
topic := toTopic(chatID) topic := ToTopic(chatID)
topics := [][]byte{topic} topics := [][]byte{topic}
f := &whisper.Filter{ f := &whisper.Filter{
KeyAsym: keyAsym, KeyAsym: keyAsym,
PoW: pow, PoW: pow,
AllowP2P: listen, AllowP2P: true,
Topics: topics, Topics: topics,
Messages: s.whisper.NewMessageStore(), Messages: s.whisper.NewMessageStore(),
} }
@ -280,73 +433,19 @@ func (s *Service) AddAsymmetricFilter(keyAsym *ecdsa.PrivateKey, chatID string,
return nil, err return nil, err
} }
return &Filter{FilterID: id, Topic: topic}, nil return &Filter{FilterID: id, Topic: whisper.BytesToTopic(topic)}, nil
}
func (s *Service) LoadPublic(chat *Chat) error {
s.mutex.Lock()
defer s.mutex.Unlock()
filterAndTopic, err := s.AddSymmetric(chat.ChatID)
if err != nil {
return err
}
// Add mutex
chat.FilterID = filterAndTopic.FilterID
chat.SymKeyID = filterAndTopic.SymKeyID
chat.Topic = filterAndTopic.Topic
s.chats[chat.ChatID] = chat
return nil
}
func (s *Service) Load(myKey *ecdsa.PrivateKey, chat *Chat) error {
var err error
log.Debug("Loading chat", "chatID", chat.ChatID)
// Check we haven't already loaded the chat
if _, ok := s.chats[chat.ChatID]; !ok {
if chat.OneToOne {
err = s.LoadOneToOne(myKey, chat.Identity, false)
} else {
err = s.LoadPublic(chat)
}
if err != nil {
return err
}
}
return nil
} }
func negotiatedID(identity *ecdsa.PublicKey) string { func negotiatedID(identity *ecdsa.PublicKey) string {
return fmt.Sprintf("%x-negotiated", crypto.FromECDSAPub(identity)) return fmt.Sprintf("0x%x-negotiated", crypto.FromECDSAPub(identity))
} }
func (s *Service) Get(identity *ecdsa.PublicKey) *Chat { func (s *Service) load(myKey *ecdsa.PrivateKey, chat *Chat) ([]*Chat, error) {
return s.chats[negotiatedID(identity)] log.Debug("Loading chat", "chatID", chat.ChatID)
if chat.OneToOne {
return s.loadOneToOne(myKey, chat.Identity, false)
} }
return []*Chat{chat}, s.loadPublic(chat)
func (s *Service) ProcessNegotiatedSecret(secret *topic.Secret) error {
s.mutex.Lock()
defer s.mutex.Unlock()
keyString := fmt.Sprintf("%x", secret.Key)
filter, err := s.AddSymmetric(keyString)
if err != nil {
return err
}
identityStr := fmt.Sprintf("0x%x", crypto.FromECDSAPub(secret.Identity))
chat := &Chat{
ChatID: negotiatedID(secret.Identity),
Topic: filter.Topic,
SymKeyID: filter.SymKeyID,
Identity: identityStr,
}
s.chats[chat.ChatID] = chat
return nil
} }

View File

@ -9,7 +9,7 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
appDB "github.com/status-im/status-go/services/shhext/chat/db" appDB "github.com/status-im/status-go/services/shhext/chat/db"
"github.com/status-im/status-go/services/shhext/chat/topic" "github.com/status-im/status-go/services/shhext/chat/sharedsecret"
whisper "github.com/status-im/whisper/whisperv6" whisper "github.com/status-im/whisper/whisperv6"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
@ -51,7 +51,7 @@ func (s *ServiceTestSuite) SetupTest() {
keyStrs := []string{"c6cbd7d76bc5baca530c875663711b947efa6a86a900a9e8645ce32e5821484e", "d51dd64ad19ea84968a308dca246012c00d2b2101d41bce740acd1c650acc509"} keyStrs := []string{"c6cbd7d76bc5baca530c875663711b947efa6a86a900a9e8645ce32e5821484e", "d51dd64ad19ea84968a308dca246012c00d2b2101d41bce740acd1c650acc509"}
keyTopics := []int{4490, 3991} keyTopics := []int{4490, 3991}
dbFile, err := ioutil.TempFile(os.TempDir(), "topic") dbFile, err := ioutil.TempFile(os.TempDir(), "filter")
s.Require().NoError(err) s.Require().NoError(err)
s.path = dbFile.Name() s.path = dbFile.Name()
@ -67,12 +67,12 @@ func (s *ServiceTestSuite) SetupTest() {
s.Require().NoError(err) s.Require().NoError(err)
// Build services // Build services
topicService := topic.NewService(topic.NewSQLLitePersistence(db)) sharedSecretService := sharedsecret.NewService(sharedsecret.NewSQLLitePersistence(db))
whisper := whisper.New(nil) whisper := whisper.New(nil)
keyID, err := whisper.AddKeyPair(s.keys[0].privateKey) _, err = whisper.AddKeyPair(s.keys[0].privateKey)
s.Require().NoError(err) s.Require().NoError(err)
s.service = New(keyID, whisper, topicService) s.service = New(whisper, sharedSecretService)
} }
func (s *ServiceTestSuite) TearDownTest() { func (s *ServiceTestSuite) TearDownTest() {
@ -82,21 +82,24 @@ func (s *ServiceTestSuite) TearDownTest() {
func (s *ServiceTestSuite) TestDiscoveryAndPartitionedTopic() { func (s *ServiceTestSuite) TestDiscoveryAndPartitionedTopic() {
chats := []*Chat{} chats := []*Chat{}
partitionedTopic := fmt.Sprintf("contact-discovery-%d", s.keys[0].partitionedTopic) partitionedTopic := fmt.Sprintf("contact-discovery-%d", s.keys[0].partitionedTopic)
contactCodeTopic := s.keys[0].PublicKeyString() + "-contact-code" contactCodeTopic := "0x" + s.keys[0].PublicKeyString() + "-contact-code"
err := s.service.Init(chats) _, err := s.service.Init(chats)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(3, len(s.service.chats), "It creates two filters") s.Require().Equal(3, len(s.service.chats), "It creates two filters")
discoveryFilter := s.service.chats[discoveryTopic] discoveryFilter := s.service.chats[discoveryTopic]
s.Require().NotNil(discoveryFilter, "It adds the discovery filter") s.Require().NotNil(discoveryFilter, "It adds the discovery filter")
s.Require().True(discoveryFilter.Listen)
contactCodeFilter := s.service.chats[contactCodeTopic] contactCodeFilter := s.service.chats[contactCodeTopic]
s.Require().NotNil(contactCodeFilter, "It adds the contact code filter") s.Require().NotNil(contactCodeFilter, "It adds the contact code filter")
s.Require().True(contactCodeFilter.Listen)
partitionedFilter := s.service.chats[partitionedTopic] partitionedFilter := s.service.chats[partitionedTopic]
s.Require().NotNil(partitionedFilter, "It adds the partitioned filter") s.Require().NotNil(partitionedFilter, "It adds the partitioned filter")
s.Require().True(partitionedFilter.Listen)
} }
func (s *ServiceTestSuite) TestPublicAndOneToOneChats() { func (s *ServiceTestSuite) TestPublicAndOneToOneChats() {
@ -110,46 +113,95 @@ func (s *ServiceTestSuite) TestPublicAndOneToOneChats() {
OneToOne: true, OneToOne: true,
}, },
} }
partitionedTopic := fmt.Sprintf("contact-discovery-%d", s.keys[1].partitionedTopic) contactCodeTopic := "0x" + s.keys[1].PublicKeyString() + "-contact-code"
contactCodeTopic := s.keys[1].PublicKeyString() + "-contact-code"
err := s.service.Init(chats) response, err := s.service.Init(chats)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(6, len(s.service.chats), "It creates two additional filters for the one to one and one for the public chat") actualChats := make(map[string]*Chat)
statusFilter := s.service.chats["status"] for _, chat := range response {
actualChats[chat.ChatID] = chat
}
s.Require().Equal(5, len(actualChats), "It creates two additional filters for the one to one and one for the public chat")
statusFilter := actualChats["status"]
s.Require().NotNil(statusFilter, "It creates a filter for the public chat") s.Require().NotNil(statusFilter, "It creates a filter for the public chat")
s.Require().NotNil(statusFilter.SymKeyID, "It returns a sym key id") s.Require().NotNil(statusFilter.SymKeyID, "It returns a sym key id")
s.Require().True(statusFilter.Listen)
contactCodeFilter := s.service.chats[contactCodeTopic] contactCodeFilter := actualChats[contactCodeTopic]
s.Require().NotNil(contactCodeFilter, "It adds the contact code filter") s.Require().NotNil(contactCodeFilter, "It adds the contact code filter")
s.Require().True(contactCodeFilter.Listen)
partitionedFilter := s.service.chats[partitionedTopic]
s.Require().NotNil(partitionedFilter, "It adds the partitioned filter")
} }
func (s *ServiceTestSuite) TestNegotiatedTopic() { func (s *ServiceTestSuite) TestNegotiatedTopic() {
chats := []*Chat{} chats := []*Chat{}
negotiatedTopic1 := s.keys[0].PublicKeyString() + "-negotiated" negotiatedTopic1 := "0x" + s.keys[0].PublicKeyString() + "-negotiated"
negotiatedTopic2 := s.keys[1].PublicKeyString() + "-negotiated" negotiatedTopic2 := "0x" + s.keys[1].PublicKeyString() + "-negotiated"
// We send a message to ourselves // We send a message to ourselves
_, _, err := s.service.topic.Send(s.keys[0].privateKey, "0-1", &s.keys[0].privateKey.PublicKey, []string{"0-2"}) _, _, err := s.service.secret.Send(s.keys[0].privateKey, "0-1", &s.keys[0].privateKey.PublicKey, []string{"0-2"})
s.Require().NoError(err) s.Require().NoError(err)
// We send a message to someone else // We send a message to someone else
_, _, err = s.service.topic.Send(s.keys[0].privateKey, "0-1", &s.keys[1].privateKey.PublicKey, []string{"0-2"}) _, _, err = s.service.secret.Send(s.keys[0].privateKey, "0-1", &s.keys[1].privateKey.PublicKey, []string{"0-2"})
s.Require().NoError(err) s.Require().NoError(err)
err = s.service.Init(chats) response, err := s.service.Init(chats)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(5, len(s.service.chats), "It creates two additional filters for the negotiated topics") actualChats := make(map[string]*Chat)
negotiatedFilter1 := s.service.chats[negotiatedTopic1] for _, chat := range response {
actualChats[chat.ChatID] = chat
}
s.Require().Equal(5, len(actualChats), "It creates two additional filters for the negotiated topics")
negotiatedFilter1 := actualChats[negotiatedTopic1]
s.Require().NotNil(negotiatedFilter1, "It adds the negotiated filter") s.Require().NotNil(negotiatedFilter1, "It adds the negotiated filter")
negotiatedFilter2 := s.service.chats[negotiatedTopic2] negotiatedFilter2 := actualChats[negotiatedTopic2]
s.Require().NotNil(negotiatedFilter2, "It adds the negotiated filter") s.Require().NotNil(negotiatedFilter2, "It adds the negotiated filter")
} }
func (s *ServiceTestSuite) TestLoadChat() {
chats := []*Chat{}
_, err := s.service.Init(chats)
s.Require().NoError(err)
// We add a public chat
response1, err := s.service.Load(&Chat{ChatID: "status"})
s.Require().NoError(err)
s.Require().Equal(1, len(response1))
s.Require().Equal("status", response1[0].ChatID)
s.Require().True(response1[0].Listen)
}
func (s *ServiceTestSuite) TestNoInstallationIDs() {
chats := []*Chat{}
negotiatedTopic1 := "0x" + s.keys[1].PublicKeyString() + "-negotiated"
// We send a message to someone else, but without any installation ID
_, _, err := s.service.secret.Send(s.keys[0].privateKey, "0-1", &s.keys[1].privateKey.PublicKey, []string{})
s.Require().NoError(err)
response, err := s.service.Init(chats)
s.Require().NoError(err)
actualChats := make(map[string]*Chat)
for _, chat := range response {
actualChats[chat.ChatID] = chat
}
s.Require().Equal(4, len(actualChats), "It creates two additional filters for the negotiated topics")
negotiatedFilter1 := actualChats[negotiatedTopic1]
s.Require().NotNil(negotiatedFilter1, "It adds the negotiated filter")
}

View File

@ -0,0 +1,446 @@
package publisher
import (
"context"
"crypto/ecdsa"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"github.com/golang/protobuf/proto"
"github.com/status-im/status-go/services/shhext/chat"
appDB "github.com/status-im/status-go/services/shhext/chat/db"
"github.com/status-im/status-go/services/shhext/chat/multidevice"
"github.com/status-im/status-go/services/shhext/chat/protobuf"
"github.com/status-im/status-go/services/shhext/chat/sharedsecret"
"github.com/status-im/status-go/services/shhext/dedup"
"github.com/status-im/status-go/services/shhext/filter"
"github.com/status-im/status-go/services/shhext/whisperutils"
"github.com/status-im/status-go/signal"
whisper "github.com/status-im/whisper/whisperv6"
"golang.org/x/crypto/sha3"
"os"
"path/filepath"
"time"
)
const (
tickerInterval = 120
maxInstallations = 3
)
var (
errProtocolNotInitialized = errors.New("procotol is not initialized")
// ErrPFSNotEnabled is returned when an endpoint PFS only is called but
// PFS is disabled
ErrPFSNotEnabled = errors.New("pfs not enabled")
)
//type Persistence interface {
//}
type Service struct {
whisper *whisper.Whisper
whisperAPI *whisper.PublicWhisperAPI
protocol *chat.ProtocolService
// persistence Persistence
log log.Logger
filter *filter.Service
config *Config
quit chan struct{}
ticker *time.Ticker
}
type Config struct {
PfsEnabled bool
DataDir string
InstallationID string
}
func New(config *Config, w *whisper.Whisper) *Service {
return &Service{
config: config,
whisper: w,
whisperAPI: whisper.NewPublicWhisperAPI(w),
log: log.New("package", "status-go/services/publisher.Service"),
}
}
// InitProtocolWithPassword creates an instance of ProtocolService given an address and password used to generate an encryption key.
func (s *Service) InitProtocolWithPassword(address string, password string) error {
digest := sha3.Sum256([]byte(password))
encKey := fmt.Sprintf("%x", digest)
return s.initProtocol(address, encKey, password)
}
// InitProtocolWithEncyptionKey creates an instance of ProtocolService given an address and encryption key.
func (s *Service) InitProtocolWithEncyptionKey(address string, encKey string) error {
return s.initProtocol(address, encKey, "")
}
func (s *Service) initProtocol(address, encKey, password string) error {
if !s.config.PfsEnabled {
return nil
}
if err := os.MkdirAll(filepath.Clean(s.config.DataDir), os.ModePerm); err != nil {
return err
}
v0Path := filepath.Join(s.config.DataDir, fmt.Sprintf("%x.db", address))
v1Path := filepath.Join(s.config.DataDir, fmt.Sprintf("%s.db", s.config.InstallationID))
v2Path := filepath.Join(s.config.DataDir, fmt.Sprintf("%s.v2.db", s.config.InstallationID))
v3Path := filepath.Join(s.config.DataDir, fmt.Sprintf("%s.v3.db", s.config.InstallationID))
v4Path := filepath.Join(s.config.DataDir, fmt.Sprintf("%s.v4.db", s.config.InstallationID))
if password != "" {
if err := appDB.MigrateDBFile(v0Path, v1Path, "ON", password); err != nil {
return err
}
if err := appDB.MigrateDBFile(v1Path, v2Path, password, encKey); err != nil {
// Remove db file as created with a blank password and never used,
// and there's no need to rekey in this case
os.Remove(v1Path)
os.Remove(v2Path)
}
}
if err := appDB.MigrateDBKeyKdfIterations(v2Path, v3Path, encKey); err != nil {
os.Remove(v2Path)
os.Remove(v3Path)
}
// Fix IOS not encrypting database
if err := appDB.EncryptDatabase(v3Path, v4Path, encKey); err != nil {
os.Remove(v3Path)
os.Remove(v4Path)
}
// Desktop was passing a network dependent directory, which meant that
// if running on testnet it would not access the right db. This copies
// the db from mainnet to the root location.
networkDependentPath := filepath.Join(s.config.DataDir, "ethereum", "mainnet_rpc", fmt.Sprintf("%s.v4.db", s.config.InstallationID))
if _, err := os.Stat(networkDependentPath); err == nil {
if err := os.Rename(networkDependentPath, v4Path); err != nil {
return err
}
} else if !os.IsNotExist(err) {
return err
}
persistence, err := chat.NewSQLLitePersistence(v4Path, encKey)
if err != nil {
return err
}
addedBundlesHandler := func(addedBundles []multidevice.IdentityAndIDPair) {
handler := SignalHandler{}
for _, bundle := range addedBundles {
handler.BundleAdded(bundle[0], bundle[1])
}
}
// Initialize sharedsecret
sharedSecretService := sharedsecret.NewService(persistence.GetSharedSecretStorage())
// Initialize filter
filterService := filter.New(s.whisper, sharedSecretService)
s.filter = filterService
// Initialize multidevice
multideviceConfig := &multidevice.Config{
InstallationID: s.config.InstallationID,
ProtocolVersion: chat.ProtocolVersion,
MaxInstallations: maxInstallations,
}
multideviceService := multidevice.New(multideviceConfig, persistence.GetMultideviceStorage())
s.protocol = chat.NewProtocolService(
chat.NewEncryptionService(
persistence,
chat.DefaultEncryptionServiceConfig(s.config.InstallationID)),
sharedSecretService,
multideviceService,
addedBundlesHandler,
s.onNewSharedSecretHandler)
return nil
}
func (s *Service) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, bundle *protobuf.Bundle) ([]multidevice.IdentityAndIDPair, error) {
if s.protocol == nil {
return nil, errProtocolNotInitialized
}
return s.protocol.ProcessPublicBundle(myIdentityKey, bundle)
}
func (s *Service) GetBundle(myIdentityKey *ecdsa.PrivateKey) (*protobuf.Bundle, error) {
if s.protocol == nil {
return nil, errProtocolNotInitialized
}
return s.protocol.GetBundle(myIdentityKey)
}
// EnableInstallation enables an installation for multi-device sync.
func (s *Service) EnableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error {
if s.protocol == nil {
return errProtocolNotInitialized
}
return s.protocol.EnableInstallation(myIdentityKey, installationID)
}
func (s *Service) GetPublicBundle(identityKey *ecdsa.PublicKey) (*protobuf.Bundle, error) {
if s.protocol == nil {
return nil, errProtocolNotInitialized
}
return s.protocol.GetPublicBundle(identityKey)
}
// DisableInstallation disables an installation for multi-device sync.
func (s *Service) DisableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error {
if s.protocol == nil {
return errProtocolNotInitialized
}
return s.protocol.DisableInstallation(myIdentityKey, installationID)
}
func (s *Service) Start() error {
s.startTicker()
return nil
}
func (s *Service) Stop() error {
if s.filter != nil {
if err := s.filter.Stop(); err != nil {
log.Error("Failed to stop filter service with error", "err", err)
}
}
return nil
}
func (s *Service) GetNegotiatedChat(identity *ecdsa.PublicKey) *filter.Chat {
return s.filter.GetNegotiated(identity)
}
func (s *Service) LoadFilters(chats []*filter.Chat) ([]*filter.Chat, error) {
return s.filter.Init(chats)
}
func (s *Service) LoadFilter(chat *filter.Chat) ([]*filter.Chat, error) {
return s.filter.Load(chat)
}
func (s *Service) RemoveFilter(chat *filter.Chat) error {
return s.filter.Remove(chat)
}
func (s *Service) onNewSharedSecretHandler(sharedSecrets []*sharedsecret.Secret) {
var filters []*signal.Filter
for _, sharedSecret := range sharedSecrets {
chat, err := s.filter.ProcessNegotiatedSecret(sharedSecret)
if err != nil {
log.Error("Failed to process negotiated secret", "err", err)
return
}
filter := &signal.Filter{
ChatID: chat.ChatID,
SymKeyID: chat.SymKeyID,
Listen: chat.Listen,
FilterID: chat.FilterID,
Identity: chat.Identity,
Topic: chat.Topic,
}
filters = append(filters, filter)
}
if len(filters) != 0 {
handler := SignalHandler{}
handler.WhisperFilterAdded(filters)
}
}
func (s *Service) ProcessMessage(dedupMessage dedup.DeduplicateMessage) error {
if !s.config.PfsEnabled {
return nil
}
msg := dedupMessage.Message
privateKeyID := s.whisper.SelectedKeyPairID()
if privateKeyID == "" {
return errors.New("no key selected")
}
privateKey, err := s.whisper.GetPrivateKey(privateKeyID)
if err != nil {
return err
}
publicKey, err := crypto.UnmarshalPubkey(msg.Sig)
if err != nil {
return err
}
// Unmarshal message
protocolMessage := &protobuf.ProtocolMessage{}
if err := proto.Unmarshal(msg.Payload, protocolMessage); err != nil {
s.log.Debug("Not a protocol message", "err", err)
return nil
}
response, err := s.protocol.HandleMessage(privateKey, publicKey, protocolMessage, dedupMessage.DedupID)
switch err {
case nil:
// Set the decrypted payload
msg.Payload = response
case chat.ErrDeviceNotFound:
// Notify that someone tried to contact us using an invalid bundle
if privateKey.PublicKey != *publicKey {
s.log.Warn("Device not found, sending signal", "err", err)
keyString := fmt.Sprintf("0x%x", crypto.FromECDSAPub(publicKey))
handler := SignalHandler{}
handler.DecryptMessageFailed(keyString)
}
default:
// Log and pass to the client, even if failed to decrypt
s.log.Error("Failed handling message with error", "err", err)
}
return nil
}
// SendDirectMessage sends a 1:1 chat message to the underlying transport
func (s *Service) SendDirectMessage(ctx context.Context, msg chat.SendDirectMessageRPC) (hexutil.Bytes, error) {
if !s.config.PfsEnabled {
return nil, ErrPFSNotEnabled
}
privateKey, err := s.whisper.GetPrivateKey(msg.Sig)
if err != nil {
return nil, err
}
publicKey, err := crypto.UnmarshalPubkey(msg.PubKey)
if err != nil {
return nil, err
}
var msgSpec *chat.ProtocolMessageSpec
if msg.DH {
s.log.Debug("Building dh message")
msgSpec, err = s.protocol.BuildDHMessage(privateKey, publicKey, msg.Payload)
} else {
s.log.Debug("Building direct message")
msgSpec, err = s.protocol.BuildDirectMessage(privateKey, publicKey, msg.Payload)
}
if err != nil {
return nil, err
}
whisperMessage, err := s.directMessageToWhisper(privateKey, publicKey, msg.PubKey, msg.Sig, msgSpec)
if err != nil {
s.log.Error("sshext-service", "error building whisper message", err)
return nil, err
}
return s.whisperAPI.Post(ctx, *whisperMessage)
}
func (s *Service) directMessageToWhisper(myPrivateKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, destination hexutil.Bytes, signature string, spec *chat.ProtocolMessageSpec) (*whisper.NewMessage, error) {
// marshal for sending to wire
marshaledMessage, err := proto.Marshal(spec.Message)
if err != nil {
s.log.Error("encryption-service", "error marshaling message", err)
return nil, err
}
whisperMessage := whisperutils.DefaultWhisperMessage()
whisperMessage.Payload = marshaledMessage
whisperMessage.Sig = signature
if spec.SharedSecret != nil {
chat := s.GetNegotiatedChat(theirPublicKey)
if chat != nil {
s.log.Debug("Sending on negotiated topic")
whisperMessage.SymKeyID = chat.SymKeyID
whisperMessage.Topic = chat.Topic
whisperMessage.PublicKey = nil
return &whisperMessage, nil
}
} else if spec.PartitionedTopic() {
s.log.Debug("Sending on partitioned topic")
// Create filter on demand
if _, err := s.filter.LoadPartitioned(myPrivateKey, theirPublicKey, false); err != nil {
return nil, err
}
t := filter.PublicKeyToPartitionedTopicBytes(theirPublicKey)
whisperMessage.Topic = whisper.BytesToTopic(t)
whisperMessage.PublicKey = destination
return &whisperMessage, nil
}
s.log.Debug("Sending on old discovery topic")
whisperMessage.Topic = whisperutils.DiscoveryTopicBytes
whisperMessage.PublicKey = destination
return &whisperMessage, nil
}
// SendPublicMessage sends a public chat message to the underlying transport
func (s *Service) SendPublicMessage(ctx context.Context, msg chat.SendPublicMessageRPC) (hexutil.Bytes, error) {
if !s.config.PfsEnabled {
return nil, ErrPFSNotEnabled
}
filter := s.filter.GetByID(msg.Chat)
if filter == nil {
return nil, errors.New("not subscribed to chat")
}
// Enrich with transport layer info
whisperMessage := whisperutils.DefaultWhisperMessage()
whisperMessage.Payload = msg.Payload
whisperMessage.Sig = msg.Sig
whisperMessage.Topic = whisperutils.ToTopic(msg.Chat)
whisperMessage.SymKeyID = filter.SymKeyID
// And dispatch
return s.whisperAPI.Post(ctx, whisperMessage)
}
func (s *Service) ConfirmMessagesProcessed(ids [][]byte) error {
return s.protocol.ConfirmMessagesProcessed(ids)
}
func (s *Service) startTicker() {
s.ticker = time.NewTicker(tickerInterval * time.Second)
s.quit = make(chan struct{})
go func() {
for {
select {
case <-s.ticker.C:
err := s.perform()
if err != nil {
s.log.Error("could not execute tick", "err", err)
}
case <-s.quit:
s.ticker.Stop()
return
}
}
}()
}
func (s *Service) perform() error {
return nil
}

View File

@ -0,0 +1,20 @@
package publisher
import (
"github.com/status-im/status-go/signal"
)
// SignalHandler sends signals on protocol events
type SignalHandler struct{}
func (h SignalHandler) DecryptMessageFailed(pubKey string) {
signal.SendDecryptMessageFailed(pubKey)
}
func (h SignalHandler) BundleAdded(identity string, installationID string) {
signal.SendBundleAdded(identity, installationID)
}
func (h SignalHandler) WhisperFilterAdded(filters []*signal.Filter) {
signal.SendWhisperFilterAdded(filters)
}

View File

@ -2,10 +2,6 @@ package shhext
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"errors"
"fmt"
"os"
"path/filepath"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -16,16 +12,12 @@ import (
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
"github.com/status-im/status-go/db" "github.com/status-im/status-go/db"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/services/shhext/chat"
appDB "github.com/status-im/status-go/services/shhext/chat/db"
"github.com/status-im/status-go/services/shhext/chat/topic"
"github.com/status-im/status-go/services/shhext/dedup" "github.com/status-im/status-go/services/shhext/dedup"
"github.com/status-im/status-go/services/shhext/filter" "github.com/status-im/status-go/services/shhext/filter"
"github.com/status-im/status-go/services/shhext/mailservers" "github.com/status-im/status-go/services/shhext/mailservers"
"github.com/status-im/status-go/signal" "github.com/status-im/status-go/services/shhext/publisher"
whisper "github.com/status-im/whisper/whisperv6" whisper "github.com/status-im/whisper/whisperv6"
"github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb"
"golang.org/x/crypto/sha3"
) )
const ( const (
@ -35,8 +27,6 @@ const (
defaultTimeoutWaitAdded = 5 * time.Second defaultTimeoutWaitAdded = 5 * time.Second
) )
var errProtocolNotInitialized = errors.New("protocol is not initialized")
// EnvelopeEventsHandler used for two different event types. // EnvelopeEventsHandler used for two different event types.
type EnvelopeEventsHandler interface { type EnvelopeEventsHandler interface {
EnvelopeSent(common.Hash) EnvelopeSent(common.Hash)
@ -47,6 +37,7 @@ type EnvelopeEventsHandler interface {
// Service is a service that provides some additional Whisper API. // Service is a service that provides some additional Whisper API.
type Service struct { type Service struct {
*publisher.Service
storage db.TransactionalStorage storage db.TransactionalStorage
w *whisper.Whisper w *whisper.Whisper
config params.ShhextConfig config params.ShhextConfig
@ -57,10 +48,6 @@ type Service struct {
server *p2p.Server server *p2p.Server
nodeID *ecdsa.PrivateKey nodeID *ecdsa.PrivateKey
deduplicator *dedup.Deduplicator deduplicator *dedup.Deduplicator
protocol *chat.ProtocolService
dataDir string
installationID string
pfsEnabled bool
peerStore *mailservers.PeerStore peerStore *mailservers.PeerStore
cache *mailservers.Cache cache *mailservers.Cache
connManager *mailservers.ConnectionManager connManager *mailservers.ConnectionManager
@ -71,7 +58,7 @@ type Service struct {
// Make sure that Service implements node.Service interface. // Make sure that Service implements node.Service interface.
var _ node.Service = (*Service)(nil) var _ node.Service = (*Service)(nil)
// New returns a new Service. dataDir is a folder path to a network-independent location // New returns a new Service.
func New(w *whisper.Whisper, handler EnvelopeEventsHandler, ldb *leveldb.DB, config params.ShhextConfig) *Service { func New(w *whisper.Whisper, handler EnvelopeEventsHandler, ldb *leveldb.DB, config params.ShhextConfig) *Service {
cache := mailservers.NewCache(ldb) cache := mailservers.NewCache(ldb)
ps := mailservers.NewPeerStore(cache) ps := mailservers.NewPeerStore(cache)
@ -88,7 +75,14 @@ func New(w *whisper.Whisper, handler EnvelopeEventsHandler, ldb *leveldb.DB, con
requestsRegistry: requestsRegistry, requestsRegistry: requestsRegistry,
} }
envelopesMonitor := NewEnvelopesMonitor(w, handler, config.MailServerConfirmations, ps, config.MaxMessageDeliveryAttempts) envelopesMonitor := NewEnvelopesMonitor(w, handler, config.MailServerConfirmations, ps, config.MaxMessageDeliveryAttempts)
publisherConfig := &publisher.Config{
PfsEnabled: config.PFSEnabled,
DataDir: config.BackupDisabledDataDir,
InstallationID: config.InstallationID,
}
publisherService := publisher.New(publisherConfig, w)
return &Service{ return &Service{
Service: publisherService,
storage: db.NewLevelDBStorage(ldb), storage: db.NewLevelDBStorage(ldb),
w: w, w: w,
config: config, config: config,
@ -97,9 +91,6 @@ func New(w *whisper.Whisper, handler EnvelopeEventsHandler, ldb *leveldb.DB, con
requestsRegistry: requestsRegistry, requestsRegistry: requestsRegistry,
historyUpdates: historyUpdates, historyUpdates: historyUpdates,
deduplicator: dedup.NewDeduplicator(w, ldb), deduplicator: dedup.NewDeduplicator(w, ldb),
dataDir: config.BackupDisabledDataDir,
installationID: config.InstallationID,
pfsEnabled: config.PFSEnabled,
peerStore: ps, peerStore: ps,
cache: cache, cache: cache,
} }
@ -121,138 +112,6 @@ func (s *Service) Protocols() []p2p.Protocol {
return []p2p.Protocol{} return []p2p.Protocol{}
} }
// InitProtocolWithPassword creates an instance of ProtocolService given an address and password used to generate an encryption key.
func (s *Service) InitProtocolWithPassword(address string, password string) error {
digest := sha3.Sum256([]byte(password))
encKey := fmt.Sprintf("%x", digest)
return s.initProtocol(address, encKey, password)
}
// InitProtocolWithEncyptionKey creates an instance of ProtocolService given an address and encryption key.
func (s *Service) InitProtocolWithEncyptionKey(address string, encKey string) error {
return s.initProtocol(address, encKey, "")
}
func (s *Service) initProtocol(address, encKey, password string) error {
if !s.pfsEnabled {
return nil
}
if err := os.MkdirAll(filepath.Clean(s.dataDir), os.ModePerm); err != nil {
return err
}
v0Path := filepath.Join(s.dataDir, fmt.Sprintf("%x.db", address))
v1Path := filepath.Join(s.dataDir, fmt.Sprintf("%s.db", s.installationID))
v2Path := filepath.Join(s.dataDir, fmt.Sprintf("%s.v2.db", s.installationID))
v3Path := filepath.Join(s.dataDir, fmt.Sprintf("%s.v3.db", s.installationID))
v4Path := filepath.Join(s.dataDir, fmt.Sprintf("%s.v4.db", s.installationID))
if password != "" {
if err := appDB.MigrateDBFile(v0Path, v1Path, "ON", password); err != nil {
return err
}
if err := appDB.MigrateDBFile(v1Path, v2Path, password, encKey); err != nil {
// Remove db file as created with a blank password and never used,
// and there's no need to rekey in this case
os.Remove(v1Path)
os.Remove(v2Path)
}
}
if err := appDB.MigrateDBKeyKdfIterations(v2Path, v3Path, encKey); err != nil {
os.Remove(v2Path)
os.Remove(v3Path)
}
// Fix IOS not encrypting database
if err := appDB.EncryptDatabase(v3Path, v4Path, encKey); err != nil {
os.Remove(v3Path)
os.Remove(v4Path)
}
// Desktop was passing a network dependent directory, which meant that
// if running on testnet it would not access the right db. This copies
// the db from mainnet to the root location.
networkDependentPath := filepath.Join(s.dataDir, "ethereum", "mainnet_rpc", fmt.Sprintf("%s.v4.db", s.installationID))
if _, err := os.Stat(networkDependentPath); err == nil {
if err := os.Rename(networkDependentPath, v4Path); err != nil {
return err
}
} else if !os.IsNotExist(err) {
return err
}
persistence, err := chat.NewSQLLitePersistence(v4Path, encKey)
if err != nil {
return err
}
addedBundlesHandler := func(addedBundles []chat.IdentityAndIDPair) {
handler := EnvelopeSignalHandler{}
for _, bundle := range addedBundles {
handler.BundleAdded(bundle[0], bundle[1])
}
}
// Initialize topics
topicService := topic.NewService(persistence.GetTopicStorage())
filterService := filter.New(s.config.AsymKeyID, s.w, topicService)
s.filter = filterService
s.protocol = chat.NewProtocolService(
chat.NewEncryptionService(
persistence,
chat.DefaultEncryptionServiceConfig(s.installationID)),
topicService,
addedBundlesHandler,
s.onNewTopicHandler)
return nil
}
func (s *Service) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey, bundle *chat.Bundle) ([]chat.IdentityAndIDPair, error) {
if s.protocol == nil {
return nil, errProtocolNotInitialized
}
return s.protocol.ProcessPublicBundle(myIdentityKey, bundle)
}
func (s *Service) GetBundle(myIdentityKey *ecdsa.PrivateKey) (*chat.Bundle, error) {
if s.protocol == nil {
return nil, errProtocolNotInitialized
}
return s.protocol.GetBundle(myIdentityKey)
}
// EnableInstallation enables an installation for multi-device sync.
func (s *Service) EnableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error {
if s.protocol == nil {
return errProtocolNotInitialized
}
return s.protocol.EnableInstallation(myIdentityKey, installationID)
}
func (s *Service) GetPublicBundle(identityKey *ecdsa.PublicKey) (*chat.Bundle, error) {
if s.protocol == nil {
return nil, errProtocolNotInitialized
}
return s.protocol.GetPublicBundle(identityKey)
}
// DisableInstallation disables an installation for multi-device sync.
func (s *Service) DisableInstallation(myIdentityKey *ecdsa.PublicKey, installationID string) error {
if s.protocol == nil {
return errProtocolNotInitialized
}
return s.protocol.DisableInstallation(myIdentityKey, installationID)
}
// APIs returns a list of new APIs. // APIs returns a list of new APIs.
func (s *Service) APIs() []rpc.API { func (s *Service) APIs() []rpc.API {
apis := []rpc.API{ apis := []rpc.API{
@ -293,7 +152,7 @@ func (s *Service) Start(server *p2p.Server) error {
s.mailMonitor.Start() s.mailMonitor.Start()
s.nodeID = server.PrivateKey s.nodeID = server.PrivateKey
s.server = server s.server = server
return nil return s.Service.Start()
} }
// Stop is run when a service is stopped. // Stop is run when a service is stopped.
@ -314,38 +173,5 @@ func (s *Service) Stop() error {
} }
} }
return nil return s.Service.Stop()
}
func (s *Service) GetNegotiatedChat(identity *ecdsa.PublicKey) *filter.Chat {
return s.filter.Get(identity)
}
func (s *Service) LoadFilters(chats []*filter.Chat) error {
return s.filter.Init(chats)
}
func (s *Service) RemoveFilter(chat *filter.Chat) {
// remove filter
}
func (s *Service) onNewTopicHandler(sharedSecrets []*topic.Secret) {
var filters []*signal.Filter
log.Info("NEW TOPIC HANDLER", "secrets", sharedSecrets)
for _, sharedSecret := range sharedSecrets {
err := s.filter.ProcessNegotiatedSecret(sharedSecret)
if err != nil {
log.Error("Failed to process negotiated secret", "err", err)
return
}
}
// TODO: send back chat filter
log.Info("FILTER IDS", "filter", filters)
if len(filters) != 0 {
log.Info("SENDING FILTERS")
handler := EnvelopeSignalHandler{}
handler.WhisperFilterAdded(filters)
}
} }

View File

@ -0,0 +1,23 @@
package whisperutils
import (
"github.com/ethereum/go-ethereum/crypto"
whisper "github.com/status-im/whisper/whisperv6"
)
var discoveryTopic = "contact-discovery"
var DiscoveryTopicBytes = ToTopic(discoveryTopic)
func ToTopic(s string) whisper.TopicType {
return whisper.BytesToTopic(crypto.Keccak256([]byte(s)))
}
func DefaultWhisperMessage() whisper.NewMessage {
msg := whisper.NewMessage{}
msg.TTL = 10
msg.PowTarget = 0.002
msg.PowTime = 1
return msg
}

View File

@ -63,10 +63,17 @@ type BundleAddedSignal struct {
} }
type Filter struct { type Filter struct {
Identity string `json:"identity"` // ChatID is the identifier of the chat
FilterID string `json:"filterId"`
SymKeyID string `json:"symKeyId"`
ChatID string `json:"chatId"` ChatID string `json:"chatId"`
// SymKeyID is the symmetric key id used for symmetric chats
SymKeyID string `json:"symKeyId"`
// OneToOne tells us if we need to use asymmetric encryption for this chat
Listen bool `json:"listen"`
// FilterID the whisper filter id generated
FilterID string `json:"filterId"`
// Identity is the public key of the other recipient for non-public chats
Identity string `json:"identity"`
// Topic is the whisper topic
Topic whisper.TopicType `json:"topic"` Topic whisper.TopicType `json:"topic"`
} }

View File

@ -0,0 +1,2 @@
DROP TABLE secret_installation_ids;
DROP TABLE secrets;

View File

@ -1,11 +1,11 @@
CREATE TABLE topics ( CREATE TABLE secrets (
identity BLOB NOT NULL PRIMARY KEY ON CONFLICT IGNORE, identity BLOB NOT NULL PRIMARY KEY ON CONFLICT IGNORE,
secret BLOB NOT NULL secret BLOB NOT NULL
); );
CREATE TABLE topic_installation_ids ( CREATE TABLE secret_installation_ids (
id TEXT NOT NULL, id TEXT NOT NULL,
identity_id BLOB NOT NULL, identity_id BLOB NOT NULL,
UNIQUE(id, identity_id) ON CONFLICT IGNORE, UNIQUE(id, identity_id) ON CONFLICT IGNORE,
FOREIGN KEY (identity_id) REFERENCES topics(identity) FOREIGN KEY (identity_id) REFERENCES secrets(identity)
); );

View File

@ -1,2 +0,0 @@
DROP TABLE topic_installation_ids;
DROP TABLE topics;

View File

@ -0,0 +1,263 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package whisperv6
import (
"crypto/ecdsa"
"fmt"
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
)
// Filter represents a Whisper message filter
type Filter struct {
Src *ecdsa.PublicKey // Sender of the message
KeyAsym *ecdsa.PrivateKey // Private Key of recipient
KeySym []byte // Key associated with the Topic
Topics [][]byte // Topics to filter messages with
PoW float64 // Proof of work as described in the Whisper spec
AllowP2P bool // Indicates whether this filter is interested in direct peer-to-peer messages
SymKeyHash common.Hash // The Keccak256Hash of the symmetric key, needed for optimization
id string // unique identifier
Messages map[common.Hash]*ReceivedMessage
mutex sync.RWMutex
}
// Filters represents a collection of filters
type Filters struct {
watchers map[string]*Filter
topicMatcher map[TopicType]map[*Filter]struct{} // map a topic to the filters that are interested in being notified when a message matches that topic
allTopicsMatcher map[*Filter]struct{} // list all the filters that will be notified of a new message, no matter what its topic is
whisper *Whisper
mutex sync.RWMutex
}
// NewFilters returns a newly created filter collection
func NewFilters(w *Whisper) *Filters {
return &Filters{
watchers: make(map[string]*Filter),
topicMatcher: make(map[TopicType]map[*Filter]struct{}),
allTopicsMatcher: make(map[*Filter]struct{}),
whisper: w,
}
}
// Install will add a new filter to the filter collection
func (fs *Filters) Install(watcher *Filter) (string, error) {
if watcher.KeySym != nil && watcher.KeyAsym != nil {
return "", fmt.Errorf("filters must choose between symmetric and asymmetric keys")
}
if watcher.Messages == nil {
watcher.Messages = make(map[common.Hash]*ReceivedMessage)
}
id, err := GenerateRandomID()
if err != nil {
return "", err
}
fs.mutex.Lock()
defer fs.mutex.Unlock()
if fs.watchers[id] != nil {
return "", fmt.Errorf("failed to generate unique ID")
}
if watcher.expectsSymmetricEncryption() {
watcher.SymKeyHash = crypto.Keccak256Hash(watcher.KeySym)
}
watcher.id = id
fs.watchers[id] = watcher
fs.addTopicMatcher(watcher)
return id, err
}
// Uninstall will remove a filter whose id has been specified from
// the filter collection
func (fs *Filters) Uninstall(id string) bool {
fs.mutex.Lock()
defer fs.mutex.Unlock()
if fs.watchers[id] != nil {
fs.removeFromTopicMatchers(fs.watchers[id])
delete(fs.watchers, id)
return true
}
return false
}
// addTopicMatcher adds a filter to the topic matchers.
// If the filter's Topics array is empty, it will be tried on every topic.
// Otherwise, it will be tried on the topics specified.
func (fs *Filters) addTopicMatcher(watcher *Filter) {
if len(watcher.Topics) == 0 {
fs.allTopicsMatcher[watcher] = struct{}{}
} else {
for _, t := range watcher.Topics {
topic := BytesToTopic(t)
if fs.topicMatcher[topic] == nil {
fs.topicMatcher[topic] = make(map[*Filter]struct{})
}
fs.topicMatcher[topic][watcher] = struct{}{}
}
}
}
// removeFromTopicMatchers removes a filter from the topic matchers
func (fs *Filters) removeFromTopicMatchers(watcher *Filter) {
delete(fs.allTopicsMatcher, watcher)
for _, topic := range watcher.Topics {
delete(fs.topicMatcher[BytesToTopic(topic)], watcher)
}
}
// getWatchersByTopic returns a slice containing the filters that
// match a specific topic
func (fs *Filters) getWatchersByTopic(topic TopicType) []*Filter {
res := make([]*Filter, 0, len(fs.allTopicsMatcher))
for watcher := range fs.allTopicsMatcher {
res = append(res, watcher)
}
for watcher := range fs.topicMatcher[topic] {
res = append(res, watcher)
}
return res
}
// Get returns a filter from the collection with a specific ID
func (fs *Filters) Get(id string) *Filter {
fs.mutex.RLock()
defer fs.mutex.RUnlock()
return fs.watchers[id]
}
// NotifyWatchers notifies any filter that has declared interest
// for the envelope's topic.
func (fs *Filters) NotifyWatchers(env *Envelope, p2pMessage bool) {
var msg *ReceivedMessage
fs.mutex.RLock()
defer fs.mutex.RUnlock()
candidates := fs.getWatchersByTopic(env.Topic)
for _, watcher := range candidates {
if p2pMessage && !watcher.AllowP2P {
log.Trace(fmt.Sprintf("msg [%x], filter [%s]: p2p messages are not allowed", env.Hash(), watcher.id))
continue
}
var match bool
if msg != nil {
match = watcher.MatchMessage(msg)
} else {
match = watcher.MatchEnvelope(env)
if match {
msg = env.Open(watcher)
if msg == nil {
log.Trace("processing message: failed to open", "message", env.Hash().Hex(), "filter", watcher.id)
}
} else {
log.Trace("processing message: does not match", "message", env.Hash().Hex(), "filter", watcher.id)
}
}
if match && msg != nil {
log.Trace("processing message: decrypted", "hash", env.Hash().Hex())
if watcher.Src == nil || IsPubKeyEqual(msg.Src, watcher.Src) {
watcher.Trigger(msg)
}
}
}
}
func (f *Filter) expectsAsymmetricEncryption() bool {
return f.KeyAsym != nil
}
func (f *Filter) expectsSymmetricEncryption() bool {
return f.KeySym != nil
}
// Trigger adds a yet-unknown message to the filter's list of
// received messages.
func (f *Filter) Trigger(msg *ReceivedMessage) {
f.mutex.Lock()
defer f.mutex.Unlock()
if _, exist := f.Messages[msg.EnvelopeHash]; !exist {
f.Messages[msg.EnvelopeHash] = msg
}
}
// Retrieve will return the list of all received messages associated
// to a filter.
func (f *Filter) Retrieve() (all []*ReceivedMessage) {
f.mutex.Lock()
defer f.mutex.Unlock()
all = make([]*ReceivedMessage, 0, len(f.Messages))
for _, msg := range f.Messages {
all = append(all, msg)
}
f.Messages = make(map[common.Hash]*ReceivedMessage) // delete old messages
return all
}
// MatchMessage checks if the filter matches an already decrypted
// message (i.e. a Message that has already been handled by
// MatchEnvelope when checked by a previous filter).
// Topics are not checked here, since this is done by topic matchers.
func (f *Filter) MatchMessage(msg *ReceivedMessage) bool {
if f.PoW > 0 && msg.PoW < f.PoW {
return false
}
if f.expectsAsymmetricEncryption() && msg.isAsymmetricEncryption() {
return IsPubKeyEqual(&f.KeyAsym.PublicKey, msg.Dst)
} else if f.expectsSymmetricEncryption() && msg.isSymmetricEncryption() {
return f.SymKeyHash == msg.SymKeyHash
}
return false
}
// MatchEnvelope checks if it's worth decrypting the message. If
// it returns `true`, client code is expected to attempt decrypting
// the message and subsequently call MatchMessage.
// Topics are not checked here, since this is done by topic matchers.
func (f *Filter) MatchEnvelope(envelope *Envelope) bool {
log.Trace("checking pow", "filter", f.PoW, "envelope", envelope.pow)
return f.PoW <= 0 || envelope.pow >= f.PoW
}
// IsPubKeyEqual checks that two public keys are equal
func IsPubKeyEqual(a, b *ecdsa.PublicKey) bool {
if !ValidatePublicKey(a) {
return false
} else if !ValidatePublicKey(b) {
return false
}
// the curve is always the same, just compare the points
return a.X.Cmp(b.X) == 0 && a.Y.Cmp(b.Y) == 0
}

View File

@ -197,6 +197,7 @@ func (fs *Filters) NotifyWatchers(env *Envelope, p2pMessage bool) {
fs.mutex.RLock() fs.mutex.RLock()
defer fs.mutex.RUnlock() defer fs.mutex.RUnlock()
log.Info("Got envelope for topic", "topic", env.Topic)
candidates := fs.getWatchersByTopic(env.Topic) candidates := fs.getWatchersByTopic(env.Topic)
for _, watcher := range candidates { for _, watcher := range candidates {
if p2pMessage && !watcher.AllowP2P { if p2pMessage && !watcher.AllowP2P {
@ -279,6 +280,7 @@ func (f *Filter) MatchMessage(msg *ReceivedMessage) bool {
// the message and subsequently call MatchMessage. // the message and subsequently call MatchMessage.
// Topics are not checked here, since this is done by topic matchers. // Topics are not checked here, since this is done by topic matchers.
func (f *Filter) MatchEnvelope(envelope *Envelope) bool { func (f *Filter) MatchEnvelope(envelope *Envelope) bool {
log.Trace("checking pow", "filter", f.PoW, "envelope", envelope.pow)
return f.PoW <= 0 || envelope.pow >= f.PoW return f.PoW <= 0 || envelope.pow >= f.PoW
} }