Added payload locking to prevent multiple requests for the pairing data

Signed-off-by: Samuel Hawksby-Robinson <samuel@samyoul.com>
This commit is contained in:
Samuel Hawksby-Robinson 2022-10-28 11:30:18 +01:00
parent 2a9ac92db9
commit f33c1cec38
4 changed files with 63 additions and 1 deletions

View File

@ -119,6 +119,8 @@ func (c *PairingClient) sendAccountData() error {
}
signal.SendLocalPairingEvent(Event{Type: EventTransferSuccess})
c.PayloadManager.LockPayload()
return nil
}

View File

@ -424,7 +424,6 @@ func handlePairingReceive(ps *PairingServer) http.HandlerFunc {
func handlePairingSend(ps *PairingServer) http.HandlerFunc {
signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess})
// TODO lock sending after one successful transfer, perhaps perform the lock on the PayloadManager level
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/octet-stream")
_, err := w.Write(ps.PayloadManager.ToSend())
@ -434,6 +433,8 @@ func handlePairingSend(ps *PairingServer) http.HandlerFunc {
return
}
signal.SendLocalPairingEvent(Event{Type: EventTransferSuccess})
ps.PayloadManager.LockPayload()
}
}

View File

@ -20,12 +20,26 @@ import (
// PayloadManager is the interface for PayloadManagers and wraps the basic functions for fulfilling payload management
type PayloadManager interface {
// Mount Loads the payload into the PayloadManager's state
Mount() error
// Receive stores data from an inbound source into the PayloadManager's state
Receive(data []byte) error
// ToSend returns an outbound safe (encrypted) payload
ToSend() []byte
// Received returns a decrypted and parsed payload from an inbound source
Received() []byte
// ResetPayload resets all payloads the PayloadManager has in its state
ResetPayload()
// EncryptPlain encrypts the given plaintext using internal key(s)
EncryptPlain(plaintext []byte) ([]byte, error)
// LockPayload prevents future excess to outbound safe and received data
LockPayload()
}
// PairingPayloadSourceConfig represents location and access data of the pairing payload
@ -135,6 +149,11 @@ func (ppm *PairingPayloadManager) ResetPayload() {
type EncryptionPayload struct {
plain []byte
encrypted []byte
locked bool
}
func (ep *EncryptionPayload) lock() {
ep.locked = true
}
// PayloadEncryptionManager is responsible for encrypting and decrypting payload data
@ -149,6 +168,8 @@ func NewPayloadEncryptionManager(aesKey []byte, logger *zap.Logger) (*PayloadEnc
return &PayloadEncryptionManager{logger.Named("PayloadEncryptionManager"), aesKey, new(EncryptionPayload), new(EncryptionPayload)}, nil
}
// EncryptPlain encrypts any given plain text using the internal AES key and returns the encrypted value
// This function is different to Encrypt as the internal EncryptionPayload.encrypted value is not set
func (pem *PayloadEncryptionManager) EncryptPlain(plaintext []byte) ([]byte, error) {
l := pem.logger.Named("EncryptPlain()")
l.Debug("fired")
@ -200,10 +221,16 @@ func (pem *PayloadEncryptionManager) Decrypt(data []byte) error {
}
func (pem *PayloadEncryptionManager) ToSend() []byte {
if pem.toSend.locked {
return nil
}
return pem.toSend.encrypted
}
func (pem *PayloadEncryptionManager) Received() []byte {
if pem.toSend.locked {
return nil
}
return pem.received.plain
}
@ -212,6 +239,14 @@ func (pem *PayloadEncryptionManager) ResetPayload() {
pem.received = new(EncryptionPayload)
}
func (pem *PayloadEncryptionManager) LockPayload() {
l := pem.logger.Named("LockPayload")
l.Debug("fired")
pem.toSend.lock()
pem.received.lock()
}
// PairingPayload represents the payload structure a PairingServer handles
type PairingPayload struct {
keys map[string][]byte

View File

@ -2,6 +2,7 @@ package server
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"fmt"
"io/ioutil"
@ -326,3 +327,26 @@ func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_StorePayloads() {
pms.Require().Exactly(expected.Timestamp, acc.Timestamp)
pms.Require().Len(acc.Images, 2)
}
func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_LockPayload() {
AESKey := make([]byte, 32)
_, err := rand.Read(AESKey)
pms.Require().NoError(err)
pm, err := NewMockEncryptOnlyPayloadManager(AESKey)
pms.Require().NoError(err)
err = pm.Mount()
pms.Require().NoError(err)
toSend := pm.ToSend()
pms.Len(toSend, 60)
toSend2 := pm.ToSend()
pms.Len(toSend2, 60)
pm.LockPayload()
toSend3 := pm.ToSend()
pms.Nil(toSend3)
}