doubleratchet/keys_storage.go

165 lines
4.2 KiB
Go

package doubleratchet
import (
"bytes"
"sort"
)
// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
//
// Message keys are addressed by a key and a message number, and are also
// tagged with a session ID and a sequence number so that they can be pruned
// per session (see DeleteOldMks and TruncateMks).
type KeysStorage interface {
	// Get returns a message key by the given key and message number.
	// ok reports whether a message key was found.
	Get(k Key, msgNum uint) (mk Key, ok bool, err error)

	// Put saves the given mk under the specified key and msgNum, tagged with
	// sessionID and keySeqNum.
	Put(sessionID []byte, k Key, msgNum uint, mk Key, keySeqNum uint) error

	// DeleteMk ensures there's no message key under the specified key and msgNum.
	DeleteMk(k Key, msgNum uint) error

	// DeleteOldMks deletes old message keys for a session. The in-memory
	// implementation removes the session's keys whose sequence number is
	// less than or equal to deleteUntilSeqKey.
	DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error

	// TruncateMks truncates the number of message keys stored for a session
	// to maxKeys. The in-memory implementation drops the keys with the
	// lowest sequence numbers first.
	TruncateMks(sessionID []byte, maxKeys int) error

	// Count returns number of message keys stored under the specified key.
	Count(k Key) (uint, error)

	// All returns all the stored message keys, indexed by key and then by
	// message number.
	All() (map[Key]map[uint]Key, error)
}
// KeysStorageInMemory is an in-memory message keys storage.
//
// Its zero value is ready to use: reads tolerate a nil map and Put allocates
// it on first write. It holds no mutex, so concurrent use must be
// synchronized by the caller.
type KeysStorageInMemory struct {
	// keys maps a key to a message number to the stored message key and its
	// session metadata.
	keys map[Key]map[uint]InMemoryKey
}
// Get returns a message key by the given key and message number.
//
// The boolean result is false when nothing is stored for (pubKey, msgNum);
// the error is always nil for the in-memory storage.
func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
	// Reading from a nil map yields the zero value, so an unused storage or
	// an unknown pubKey simply makes the lookup miss.
	if entry, found := s.keys[pubKey][msgNum]; found {
		return entry.messageKey, true, nil
	}
	return Key{}, false, nil
}
// InMemoryKey is a message key held by KeysStorageInMemory together with the
// metadata used to prune it.
type InMemoryKey struct {
	// messageKey is the stored message key, as returned by Get and All.
	messageKey Key
	// seqNum is the sequence number given to Put; DeleteOldMks and
	// TruncateMks use it to decide which keys are old.
	seqNum uint
	// sessionID is the session the key belongs to; it is compared with
	// bytes.Equal when pruning.
	sessionID []byte
}
// Put saves the given mk under the specified key and msgNum, tagged with
// sessionID and seqNum so it can later be pruned per session.
//
// A key already stored at (pubKey, msgNum) is overwritten. The error is
// always nil for the in-memory storage.
func (s *KeysStorageInMemory) Put(sessionID []byte, pubKey Key, msgNum uint, mk Key, seqNum uint) error {
	if s.keys == nil {
		s.keys = make(map[Key]map[uint]InMemoryKey)
	}
	msgs := s.keys[pubKey]
	if msgs == nil {
		msgs = make(map[uint]InMemoryKey)
		s.keys[pubKey] = msgs
	}
	msgs[msgNum] = InMemoryKey{
		// Copy the session ID: keeping the caller's slice would let a later
		// mutation of it silently move this key into a different session.
		sessionID:  append([]byte(nil), sessionID...),
		messageKey: mk,
		seqNum:     seqNum,
	}
	return nil
}
// DeleteMk ensures there's no message key under the specified key and msgNum.
//
// Removing the last message key of pubKey also drops pubKey's (now empty)
// map. Deleting something that is not stored is a no-op. The error is always
// nil for the in-memory storage.
func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
	msgs, found := s.keys[pubKey] // a nil s.keys just reports !found
	if !found {
		return nil
	}
	if _, found = msgs[msgNum]; !found {
		return nil
	}

	delete(msgs, msgNum)
	if len(msgs) == 0 {
		delete(s.keys, pubKey)
	}
	return nil
}
// TruncateMks truncates the number of message keys stored for sessionID to at
// most maxKeys, deleting the keys with the lowest sequence numbers first.
//
// Keys of other sessions are untouched. A negative maxKeys is treated as 0.
// If several keys share the sequence number at the cut-off, which of them is
// removed is unspecified; exactly maxKeys keys of the session remain. A key
// whose map becomes empty is dropped, as in DeleteMk. The error is always nil
// for the in-memory storage.
func (s *KeysStorageInMemory) TruncateMks(sessionID []byte, maxKeys int) error {
	if maxKeys < 0 {
		// Without this, len-maxKeys below would exceed len and slice into
		// garbage capacity (or panic).
		maxKeys = 0
	}

	// position identifies one stored message key of the session.
	type position struct {
		pubKey Key
		msgNum uint
		seqNum uint
	}

	var candidates []position
	for pubKey, msgs := range s.keys {
		for msgNum, entry := range msgs {
			if bytes.Equal(entry.sessionID, sessionID) {
				candidates = append(candidates, position{pubKey, msgNum, entry.seqNum})
			}
		}
	}

	excess := len(candidates) - maxKeys
	if excess <= 0 {
		return nil
	}

	// Oldest first. Deleting exactly `excess` positions (rather than every
	// key whose seqNum is among the lowest values) keeps the count right even
	// when sequence numbers repeat.
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].seqNum < candidates[j].seqNum
	})
	for _, p := range candidates[:excess] {
		msgs := s.keys[p.pubKey]
		delete(msgs, p.msgNum)
		if len(msgs) == 0 {
			delete(s.keys, p.pubKey)
		}
	}
	return nil
}
// DeleteOldMks deletes old message keys for a session: every key of sessionID
// whose sequence number is less than or equal to deleteUntilSeqKey.
//
// Keys of other sessions are untouched. As in DeleteMk, a key whose map ends
// up empty is dropped entirely so that All does not report it. The error is
// always nil for the in-memory storage.
func (s *KeysStorageInMemory) DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error {
	for pubKey, msgs := range s.keys {
		for msgNum, entry := range msgs {
			if entry.seqNum <= deleteUntilSeqKey && bytes.Equal(entry.sessionID, sessionID) {
				// Deleting during range is permitted by the Go spec.
				delete(msgs, msgNum)
			}
		}
		if len(msgs) == 0 {
			delete(s.keys, pubKey)
		}
	}
	return nil
}
// Count returns number of message keys stored under the specified key.
//
// An unknown pubKey, or a storage that was never written to, counts as 0.
// The error is always nil for the in-memory storage.
func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
	msgs := s.keys[pubKey] // nil (length 0) when absent, even if s.keys is nil
	total := uint(len(msgs))
	return total, nil
}
// All returns all the keys as freshly allocated maps indexed by key and then
// by message number.
//
// The result is never nil (it is empty for an unused storage). Every stored
// key gets an inner map, even if that map is empty. The error is always nil
// for the in-memory storage.
func (s *KeysStorageInMemory) All() (map[Key]map[uint]Key, error) {
	snapshot := make(map[Key]map[uint]Key, len(s.keys))
	for pubKey, entries := range s.keys {
		inner := make(map[uint]Key, len(entries))
		for msgNum, entry := range entries {
			inner[msgNum] = entry.messageKey
		}
		snapshot[pubKey] = inner
	}
	return snapshot, nil
}