Integrated server side only tls, public key and aes key connection string
This commit is contained in:
parent
f7cbe0b1e8
commit
b1def931eb
|
@ -7,6 +7,7 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
|
@ -20,22 +21,31 @@ type PairingClient struct {
|
|||
|
||||
baseAddress *url.URL
|
||||
certPEM []byte
|
||||
privateKey *ecdsa.PrivateKey
|
||||
serverPK *ecdsa.PublicKey
|
||||
serverMode Mode
|
||||
serverCert *x509.Certificate
|
||||
}
|
||||
|
||||
func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig) (*PairingClient, error) {
|
||||
u, certPem, err := c.Generate()
|
||||
u, err := c.URL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverCert, err := getServerCert(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = verifyCert(serverCert, c.publicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: serverCert.Raw})
|
||||
|
||||
rootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok := rootCAs.AppendCertsFromPEM(certPem); !ok {
|
||||
return nil, fmt.Errorf("failed to append certPem to rootCAs")
|
||||
}
|
||||
|
@ -48,7 +58,7 @@ func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig)
|
|||
},
|
||||
}
|
||||
|
||||
pm, err := NewPairingPayloadManager(c.privateKey, config)
|
||||
pm, err := NewPairingPayloadManager(c.aesKey, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -57,7 +67,8 @@ func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig)
|
|||
Client: &http.Client{Transport: tr},
|
||||
baseAddress: u,
|
||||
certPEM: certPem,
|
||||
privateKey: c.privateKey,
|
||||
serverCert: serverCert,
|
||||
serverPK: c.publicKey,
|
||||
serverMode: c.serverMode,
|
||||
PayloadManager: pm,
|
||||
}, nil
|
||||
|
@ -99,6 +110,18 @@ func (c *PairingClient) receiveAccountData() error {
|
|||
return c.PayloadManager.Receive(payload)
|
||||
}
|
||||
|
||||
func verifyPublicKey(cert *x509.Certificate, publicKey *ecdsa.PublicKey) error {
|
||||
certKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected public key type, expected ecdsa.PublicKey")
|
||||
}
|
||||
|
||||
if !certKey.Equal(publicKey) {
|
||||
return fmt.Errorf("server certificate MUST match the given public key")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyCertSig(cert *x509.Certificate) (bool, error) {
|
||||
var esig struct {
|
||||
R, S *big.Int
|
||||
|
@ -114,39 +137,37 @@ func verifyCertSig(cert *x509.Certificate) (bool, error) {
|
|||
return verified, nil
|
||||
}
|
||||
|
||||
func (c *PairingClient) getServerCert() error {
|
||||
conf := &tls.Config{
|
||||
InsecureSkipVerify: true, // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", c.baseAddress.Host, conf)
|
||||
func verifyCert(cert *x509.Certificate, publicKey *ecdsa.PublicKey) error {
|
||||
err := verifyPublicKey(cert, publicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
certs := conn.ConnectionState().PeerCertificates
|
||||
if len(certs) != 1 {
|
||||
return fmt.Errorf("expected 1 TLS certificate, received '%d'", len(certs))
|
||||
}
|
||||
|
||||
certKey, ok := certs[0].PublicKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected public key type, expected ecdsa.PublicKey")
|
||||
}
|
||||
|
||||
if certKey.Equal(c.privateKey) {
|
||||
return fmt.Errorf("server certificate MUST match the given public key")
|
||||
}
|
||||
|
||||
verified, err := verifyCertSig(certs[0])
|
||||
verified, err := verifyCertSig(cert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !verified {
|
||||
return fmt.Errorf("server certificate signature MUST verify")
|
||||
}
|
||||
|
||||
c.serverCert = certs[0]
|
||||
return nil
|
||||
}
|
||||
|
||||
func getServerCert(URL *url.URL) (*x509.Certificate, error) {
|
||||
conf := &tls.Config{
|
||||
InsecureSkipVerify: true, // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", URL.Host, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
certs := conn.ConnectionState().PeerCertificates
|
||||
if len(certs) != 1 {
|
||||
return nil, fmt.Errorf("expected 1 TLS certificate, received '%d'", len(certs))
|
||||
}
|
||||
|
||||
return certs[0], nil
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ const (
|
|||
X = "7744735542292224619198421067303535767629647588258222392379329927711683109548"
|
||||
Y = "6855516769916529066379811647277920115118980625614889267697023742462401590771"
|
||||
D = "38564357061962143106230288374146033267100509055924181407058066820384455255240"
|
||||
AES = "BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6"
|
||||
DB58 = "6jpbvo2ucrtrnpXXF4DQYuysh697isH9ppd2aT8uSRDh"
|
||||
SN = "91849736469742262272885892667727604096707836853856473239722372976236128900962"
|
||||
CertTime = "eQUriVtGtkWhPJFeLZjF"
|
||||
|
@ -28,6 +29,7 @@ type TestKeyComponents struct {
|
|||
X *big.Int
|
||||
Y *big.Int
|
||||
D *big.Int
|
||||
AES []byte
|
||||
DBytes []byte
|
||||
PK *ecdsa.PrivateKey
|
||||
}
|
||||
|
@ -44,6 +46,9 @@ func (tk *TestKeyComponents) SetupKeyComponents(t *testing.T) {
|
|||
tk.D, ok = new(big.Int).SetString(D, 10)
|
||||
require.True(t, ok)
|
||||
|
||||
tk.AES = base58.Decode(AES)
|
||||
require.Len(t, tk.AES, 32)
|
||||
|
||||
tk.DBytes = base58.Decode(DB58)
|
||||
require.Exactly(t, tk.D.Bytes(), tk.DBytes)
|
||||
|
||||
|
@ -75,34 +80,36 @@ func (tcc *TestCertComponents) SetupCertComponents(t *testing.T) {
|
|||
}
|
||||
|
||||
type TestPairingServerComponents struct {
|
||||
EphemeralPK *ecdsa.PrivateKey
|
||||
OutboundIP net.IP
|
||||
CertTime time.Time
|
||||
Cert tls.Certificate
|
||||
PS *PairingServer
|
||||
EphemeralPK *ecdsa.PrivateKey
|
||||
EphemeralAES []byte
|
||||
OutboundIP net.IP
|
||||
Cert tls.Certificate
|
||||
PS *PairingServer
|
||||
}
|
||||
|
||||
func (tpsc *TestPairingServerComponents) SetupPairingServerComponents(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// Get 3 key components for tls.cert generation
|
||||
// Get 4 key components for tls.cert generation
|
||||
// 1) Ephemeral private key
|
||||
tpsc.EphemeralPK, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2) Device outbound IP address
|
||||
// 2) AES encryption key
|
||||
tpsc.EphemeralAES, err = makeEncryptionKey(tpsc.EphemeralPK)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3) Device outbound IP address
|
||||
tpsc.OutboundIP, err = GetOutboundIP()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3) NotBefore time
|
||||
tpsc.CertTime = time.Now()
|
||||
|
||||
// Generate tls.Certificate and Server
|
||||
tpsc.Cert, _, err = GenerateCertFromKey(tpsc.EphemeralPK, tpsc.CertTime, tpsc.OutboundIP.String())
|
||||
tpsc.Cert, _, err = GenerateCertFromKey(tpsc.EphemeralPK, time.Now(), tpsc.OutboundIP.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
tpsc.PS, err = NewPairingServer(&Config{
|
||||
PK: tpsc.EphemeralPK,
|
||||
PK: &tpsc.EphemeralPK.PublicKey,
|
||||
EK: tpsc.EphemeralAES,
|
||||
Cert: &tpsc.Cert,
|
||||
Hostname: tpsc.OutboundIP.String()})
|
||||
require.NoError(t, err)
|
||||
|
@ -112,8 +119,8 @@ type MockEncryptOnlyPayloadManager struct {
|
|||
pem *PayloadEncryptionManager
|
||||
}
|
||||
|
||||
func NewMockEncryptOnlyPayloadManager(pk *ecdsa.PrivateKey) (*MockEncryptOnlyPayloadManager, error) {
|
||||
pem, err := NewPayloadEncryptionManager(pk)
|
||||
func NewMockEncryptOnlyPayloadManager(aesKey []byte) (*MockEncryptOnlyPayloadManager, error) {
|
||||
pem, err := NewPayloadEncryptionManager(aesKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -2,13 +2,12 @@ package server
|
|||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"encoding/asn1"
|
||||
"crypto/elliptic"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcutil/base58"
|
||||
)
|
||||
|
@ -29,18 +28,18 @@ type ConnectionParams struct {
|
|||
version ConnectionParamVersion
|
||||
netIP net.IP
|
||||
port int
|
||||
privateKey *ecdsa.PrivateKey
|
||||
notBefore time.Time
|
||||
publicKey *ecdsa.PublicKey
|
||||
aesKey []byte
|
||||
serverMode Mode
|
||||
}
|
||||
|
||||
func NewConnectionParams(netIP net.IP, port int, privateKey *ecdsa.PrivateKey, notBefore time.Time, mode Mode) *ConnectionParams {
|
||||
func NewConnectionParams(netIP net.IP, port int, publicKey *ecdsa.PublicKey, aesKey []byte, mode Mode) *ConnectionParams {
|
||||
cp := new(ConnectionParams)
|
||||
cp.version = Version1
|
||||
cp.netIP = netIP
|
||||
cp.port = port
|
||||
cp.privateKey = privateKey
|
||||
cp.notBefore = notBefore
|
||||
cp.publicKey = publicKey
|
||||
cp.aesKey = aesKey
|
||||
cp.serverMode = mode
|
||||
return cp
|
||||
}
|
||||
|
@ -48,28 +47,24 @@ func NewConnectionParams(netIP net.IP, port int, privateKey *ecdsa.PrivateKey, n
|
|||
// ToString generates a string required for generating a secure connection to another Status device.
|
||||
//
|
||||
// The returned string will look like below:
|
||||
// - "2:4FHRnp:H6G:6jpbvo2ucrtrnpXXF4DQYuysh697isH9ppd2aT8uSRDh:eQUriVtGtkWhPJFeLZjF:2"
|
||||
// - "2:4FHRnp:H6G:uqnnMwVUfJc2Fkcaojet8F1ufKC3hZdGEt47joyBx9yd:BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6:2"
|
||||
//
|
||||
// Format bytes encoded into a base58 string, delimited by ":"
|
||||
// - version
|
||||
// - net.IP
|
||||
// - port
|
||||
// - ecdsa.PrivateKey D field
|
||||
// - asn1.Marshal time.Time
|
||||
// - ecdsa CompressedPublicKey
|
||||
// - AES encryption key
|
||||
// - server mode
|
||||
func (cp *ConnectionParams) ToString() (string, error) {
|
||||
func (cp *ConnectionParams) ToString() string {
|
||||
v := base58.Encode(new(big.Int).SetInt64(int64(cp.version)).Bytes())
|
||||
ip := base58.Encode(cp.netIP)
|
||||
p := base58.Encode(new(big.Int).SetInt64(int64(cp.port)).Bytes())
|
||||
k := base58.Encode(cp.privateKey.D.Bytes())
|
||||
tb, err := asn1.Marshal(cp.notBefore.UTC())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
t := base58.Encode(tb)
|
||||
k := base58.Encode(elliptic.MarshalCompressed(cp.publicKey.Curve, cp.publicKey.X, cp.publicKey.Y))
|
||||
ek := base58.Encode(cp.aesKey)
|
||||
m := base58.Encode(new(big.Int).SetInt64(int64(cp.serverMode)).Bytes())
|
||||
|
||||
return fmt.Sprintf("%s:%s:%s:%s:%s:%s", v, ip, p, k, t, m), nil
|
||||
return fmt.Sprintf("%s:%s:%s:%s:%s:%s", v, ip, p, k, ek, m)
|
||||
}
|
||||
|
||||
// FromString parses a connection params string required for to securely connect to another Status device.
|
||||
|
@ -85,14 +80,10 @@ func (cp *ConnectionParams) FromString(s string) error {
|
|||
cp.version = ConnectionParamVersion(new(big.Int).SetBytes(base58.Decode(sData[0])).Int64())
|
||||
cp.netIP = base58.Decode(sData[1])
|
||||
cp.port = int(new(big.Int).SetBytes(base58.Decode(sData[2])).Int64())
|
||||
cp.privateKey = ToECDSA(base58.Decode(sData[3]))
|
||||
|
||||
t := time.Time{}
|
||||
_, err := asn1.Unmarshal(base58.Decode(sData[4]), &t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cp.notBefore = t
|
||||
cp.publicKey = new(ecdsa.PublicKey)
|
||||
cp.publicKey.X, cp.publicKey.Y = elliptic.UnmarshalCompressed(elliptic.P256(), base58.Decode(sData[3]))
|
||||
cp.publicKey.Curve = elliptic.P256()
|
||||
cp.aesKey = base58.Decode(sData[4])
|
||||
cp.serverMode = Mode(new(big.Int).SetBytes(base58.Decode(sData[5])).Int64())
|
||||
|
||||
return cp.validate()
|
||||
|
@ -114,12 +105,12 @@ func (cp *ConnectionParams) validate() error {
|
|||
return err
|
||||
}
|
||||
|
||||
err = cp.validatePrivateKey()
|
||||
err = cp.validatePublicKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = cp.validateNotBefore()
|
||||
err = cp.validateAESKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -151,22 +142,22 @@ func (cp *ConnectionParams) validatePort() error {
|
|||
return fmt.Errorf("port '%d' outside of bounds of 1 - 65535", cp.port)
|
||||
}
|
||||
|
||||
func (cp *ConnectionParams) validatePrivateKey() error {
|
||||
func (cp *ConnectionParams) validatePublicKey() error {
|
||||
switch {
|
||||
case cp.privateKey.D == nil, cp.privateKey.D.Cmp(big.NewInt(0)) == 0:
|
||||
return fmt.Errorf("private key D not set")
|
||||
case cp.privateKey.PublicKey.X == nil, cp.privateKey.PublicKey.X.Cmp(big.NewInt(0)) == 0:
|
||||
case cp.publicKey.Curve == nil, cp.publicKey.Curve != elliptic.P256():
|
||||
return fmt.Errorf("public key Curve not `elliptic.P256`")
|
||||
case cp.publicKey.X == nil, cp.publicKey.X.Cmp(big.NewInt(0)) == 0:
|
||||
return fmt.Errorf("public key X not set")
|
||||
case cp.privateKey.PublicKey.Y == nil, cp.privateKey.PublicKey.Y.Cmp(big.NewInt(0)) == 0:
|
||||
case cp.publicKey.Y == nil, cp.publicKey.Y.Cmp(big.NewInt(0)) == 0:
|
||||
return fmt.Errorf("public key Y not set")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (cp *ConnectionParams) validateNotBefore() error {
|
||||
if cp.notBefore.IsZero() {
|
||||
return fmt.Errorf("notBefore time is zero")
|
||||
func (cp *ConnectionParams) validateAESKey() error {
|
||||
if len(cp.aesKey) != 32 {
|
||||
return fmt.Errorf("AES key invalid length, expect length 32, received length '%d'", len(cp.aesKey))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -180,22 +171,15 @@ func (cp *ConnectionParams) validateServerMode() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Generate returns a *url.URL and encoded pem.Block generated from ConnectionParams set fields
|
||||
func (cp *ConnectionParams) Generate() (*url.URL, []byte, error) {
|
||||
func (cp *ConnectionParams) URL() (*url.URL, error) {
|
||||
err := cp.validate()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: fmt.Sprintf("%s:%d", cp.netIP, cp.port),
|
||||
}
|
||||
|
||||
_, pem, err := GenerateCertFromKey(cp.privateKey, cp.notBefore, cp.netIP.String())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return u, pem, nil
|
||||
return u, nil
|
||||
}
|
||||
|
|
|
@ -1,16 +1,13 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
var (
|
||||
connectionString = "2:4FHRnp:Q4:6jpbvo2ucrtrnpXXF4DQYuysh697isH9ppd2aT8uSRDh:eQUriVtGtkWhPJFeLZjF:3"
|
||||
connectionString = "2:4FHRnp:Q4:uqnnMwVUfJc2Fkcaojet8F1ufKC3hZdGEt47joyBx9yd:BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6:3"
|
||||
)
|
||||
|
||||
func TestConnectionParamsSuite(t *testing.T) {
|
||||
|
@ -37,7 +34,8 @@ func (s *ConnectionParamsSuite) SetupSuite() {
|
|||
|
||||
s.server = &PairingServer{
|
||||
Server: bs,
|
||||
pk: s.PK,
|
||||
pk: &s.PK.PublicKey,
|
||||
ek: s.AES,
|
||||
mode: Sending,
|
||||
}
|
||||
}
|
||||
|
@ -46,9 +44,7 @@ func (s *ConnectionParamsSuite) TestConnectionParams_ToString() {
|
|||
cp, err := s.server.MakeConnectionParams()
|
||||
s.Require().NoError(err)
|
||||
|
||||
cps, err := cp.ToString()
|
||||
s.Require().NoError(err)
|
||||
|
||||
cps := cp.ToString()
|
||||
s.Require().Equal(connectionString, cps)
|
||||
}
|
||||
|
||||
|
@ -59,27 +55,13 @@ func (s *ConnectionParamsSuite) TestConnectionParams_Generate() {
|
|||
|
||||
s.Require().Exactly(Sending, cp.serverMode)
|
||||
|
||||
u, c, err := cp.Generate()
|
||||
u, err := cp.URL()
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().Equal("https://127.0.0.1:1337", u.String())
|
||||
s.Require().Equal(defaultIP.String(), u.Hostname())
|
||||
s.Require().Equal("1337", u.Port())
|
||||
|
||||
// Parse cert PEM into x509 cert
|
||||
block, _ := pem.Decode(c)
|
||||
s.Require().NotNil(block)
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Compare cert values
|
||||
cl := s.server.cert.Leaf
|
||||
s.Require().NotEqual(cl.Signature, cert.Signature)
|
||||
s.Require().Zero(cl.PublicKey.(*ecdsa.PublicKey).X.Cmp(cert.PublicKey.(*ecdsa.PublicKey).X))
|
||||
s.Require().Zero(cl.PublicKey.(*ecdsa.PublicKey).Y.Cmp(cert.PublicKey.(*ecdsa.PublicKey).Y))
|
||||
s.Require().Equal(cl.Version, cert.Version)
|
||||
s.Require().Zero(cl.SerialNumber.Cmp(cert.SerialNumber))
|
||||
s.Require().Exactly(cl.NotBefore, cert.NotBefore)
|
||||
s.Require().Exactly(cl.NotAfter, cert.NotAfter)
|
||||
s.Require().Exactly(cl.IPAddresses, cert.IPAddresses)
|
||||
s.Require().True(cp.publicKey.Equal(&s.PK.PublicKey))
|
||||
s.Require().Equal(s.AES, cp.aesKey)
|
||||
}
|
||||
|
|
|
@ -63,8 +63,7 @@ func (s *GetOutboundIPSuite) TestGetOutboundIPWithFullServerE2e() {
|
|||
cp, err := s.PS.MakeConnectionParams()
|
||||
s.Require().NoError(err)
|
||||
|
||||
qr, err := cp.ToString()
|
||||
s.Require().NoError(err)
|
||||
qr := cp.ToString()
|
||||
|
||||
// Client reads QR code and parses the connection string
|
||||
ccp := new(ConnectionParams)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
@ -44,8 +43,8 @@ type PairingPayloadManager struct {
|
|||
}
|
||||
|
||||
// NewPairingPayloadManager generates a new and initialised PairingPayloadManager
|
||||
func NewPairingPayloadManager(pk *ecdsa.PrivateKey, config *PairingPayloadManagerConfig) (*PairingPayloadManager, error) {
|
||||
pem, err := NewPayloadEncryptionManager(pk)
|
||||
func NewPairingPayloadManager(aesKey []byte, config *PairingPayloadManagerConfig) (*PairingPayloadManager, error) {
|
||||
pem, err := NewPayloadEncryptionManager(aesKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -120,13 +119,8 @@ type PayloadEncryptionManager struct {
|
|||
received *EncryptionPayload
|
||||
}
|
||||
|
||||
func NewPayloadEncryptionManager(pk *ecdsa.PrivateKey) (*PayloadEncryptionManager, error) {
|
||||
ek, err := makeEncryptionKey(pk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PayloadEncryptionManager{ek, new(EncryptionPayload), new(EncryptionPayload)}, nil
|
||||
func NewPayloadEncryptionManager(aesKey []byte) (*PayloadEncryptionManager, error) {
|
||||
return &PayloadEncryptionManager{aesKey, new(EncryptionPayload), new(EncryptionPayload)}, nil
|
||||
}
|
||||
|
||||
func (pem *PayloadEncryptionManager) Encrypt(data []byte) error {
|
||||
|
|
|
@ -11,13 +11,15 @@ type PairingServer struct {
|
|||
Server
|
||||
PayloadManager
|
||||
|
||||
pk *ecdsa.PrivateKey
|
||||
pk *ecdsa.PublicKey
|
||||
ek []byte
|
||||
mode Mode
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// Connection fields
|
||||
PK *ecdsa.PrivateKey
|
||||
PK *ecdsa.PublicKey
|
||||
EK []byte
|
||||
Cert *tls.Certificate
|
||||
Hostname string
|
||||
Mode Mode
|
||||
|
@ -28,7 +30,7 @@ type Config struct {
|
|||
|
||||
// NewPairingServer returns a *PairingServer init from the given *Config
|
||||
func NewPairingServer(config *Config) (*PairingServer, error) {
|
||||
pm, err := NewPairingPayloadManager(config.PK, config.PairingPayloadManagerConfig)
|
||||
pm, err := NewPairingPayloadManager(config.EK, config.PairingPayloadManagerConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -38,21 +40,13 @@ func NewPairingServer(config *Config) (*PairingServer, error) {
|
|||
config.Hostname,
|
||||
),
|
||||
pk: config.PK,
|
||||
ek: config.EK,
|
||||
mode: config.Mode,
|
||||
PayloadManager: pm}, nil
|
||||
}
|
||||
|
||||
// MakeConnectionParams generates a *ConnectionParams based on the Server's current state
|
||||
func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) {
|
||||
switch {
|
||||
case s.cert == nil:
|
||||
return nil, fmt.Errorf("server has no cert set")
|
||||
case s.cert.Leaf == nil:
|
||||
return nil, fmt.Errorf("server cert has no Leaf set")
|
||||
case s.cert.Leaf.NotBefore.IsZero():
|
||||
return nil, fmt.Errorf("server cert Leaf has a zero value NotBefore")
|
||||
}
|
||||
|
||||
netIP := net.ParseIP(s.hostname)
|
||||
if netIP == nil {
|
||||
return nil, fmt.Errorf("invalid ip address given '%s'", s.hostname)
|
||||
|
@ -67,7 +61,7 @@ func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) {
|
|||
return nil, fmt.Errorf("port is 0, listener is not yet set")
|
||||
}
|
||||
|
||||
return NewConnectionParams(netIP, s.port, s.pk, s.cert.Leaf.NotBefore, s.mode), nil
|
||||
return NewConnectionParams(netIP, s.port, s.pk, s.ek, s.mode), nil
|
||||
}
|
||||
|
||||
func (s *PairingServer) StartPairing() error {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -22,7 +23,7 @@ func (s *PairingServerSuite) SetupSuite() {
|
|||
|
||||
func (s *PairingServerSuite) TestPairingServer_StartPairing() {
|
||||
// Replace PairingServer.PayloadManager with a MockEncryptOnlyPayloadManager
|
||||
pm, err := NewMockEncryptOnlyPayloadManager(s.EphemeralPK)
|
||||
pm, err := NewMockEncryptOnlyPayloadManager(s.EphemeralAES)
|
||||
s.Require().NoError(err)
|
||||
s.PS.PayloadManager = pm
|
||||
|
||||
|
@ -48,8 +49,7 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
|
|||
cp, err := s.PS.MakeConnectionParams()
|
||||
s.Require().NoError(err)
|
||||
|
||||
qr, err := cp.ToString()
|
||||
s.Require().NoError(err)
|
||||
qr := cp.ToString()
|
||||
|
||||
// Client reads QR code and parses the connection string
|
||||
ccp := new(ConnectionParams)
|
||||
|
@ -59,11 +59,20 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
|
|||
c, err := NewPairingClient(ccp, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = c.getServerCert()
|
||||
s.Require().NoError(err)
|
||||
// Compare cert values
|
||||
cert := c.serverCert
|
||||
cl := s.PS.cert.Leaf
|
||||
s.Require().Equal(cl.Signature, cert.Signature)
|
||||
s.Require().Zero(cl.PublicKey.(*ecdsa.PublicKey).X.Cmp(cert.PublicKey.(*ecdsa.PublicKey).X))
|
||||
s.Require().Zero(cl.PublicKey.(*ecdsa.PublicKey).Y.Cmp(cert.PublicKey.(*ecdsa.PublicKey).Y))
|
||||
s.Require().Equal(cl.Version, cert.Version)
|
||||
s.Require().Zero(cl.SerialNumber.Cmp(cert.SerialNumber))
|
||||
s.Require().Exactly(cl.NotBefore, cert.NotBefore)
|
||||
s.Require().Exactly(cl.NotAfter, cert.NotAfter)
|
||||
s.Require().Exactly(cl.IPAddresses, cert.IPAddresses)
|
||||
|
||||
// Replace PairingClient.PayloadManager with a MockEncryptOnlyPayloadManager
|
||||
c.PayloadManager, err = NewMockEncryptOnlyPayloadManager(s.EphemeralPK)
|
||||
c.PayloadManager, err = NewMockEncryptOnlyPayloadManager(s.EphemeralAES)
|
||||
s.Require().NoError(err)
|
||||
|
||||
if m == Receiving {
|
||||
|
|
Loading…
Reference in New Issue