Integrated server side only tls, public key and aes key connection string

This commit is contained in:
Samuel Hawksby-Robinson 2022-08-07 23:14:33 +01:00
parent f7cbe0b1e8
commit b1def931eb
8 changed files with 135 additions and 145 deletions

View File

@ -7,6 +7,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/asn1" "encoding/asn1"
"encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
@ -20,22 +21,31 @@ type PairingClient struct {
baseAddress *url.URL baseAddress *url.URL
certPEM []byte certPEM []byte
privateKey *ecdsa.PrivateKey serverPK *ecdsa.PublicKey
serverMode Mode serverMode Mode
serverCert *x509.Certificate serverCert *x509.Certificate
} }
func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig) (*PairingClient, error) { func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig) (*PairingClient, error) {
u, certPem, err := c.Generate() u, err := c.URL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
serverCert, err := getServerCert(u)
if err != nil {
return nil, err
}
err = verifyCert(serverCert, c.publicKey)
if err != nil {
return nil, err
}
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: serverCert.Raw})
rootCAs, err := x509.SystemCertPool() rootCAs, err := x509.SystemCertPool()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if ok := rootCAs.AppendCertsFromPEM(certPem); !ok { if ok := rootCAs.AppendCertsFromPEM(certPem); !ok {
return nil, fmt.Errorf("failed to append certPem to rootCAs") return nil, fmt.Errorf("failed to append certPem to rootCAs")
} }
@ -48,7 +58,7 @@ func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig)
}, },
} }
pm, err := NewPairingPayloadManager(c.privateKey, config) pm, err := NewPairingPayloadManager(c.aesKey, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -57,7 +67,8 @@ func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig)
Client: &http.Client{Transport: tr}, Client: &http.Client{Transport: tr},
baseAddress: u, baseAddress: u,
certPEM: certPem, certPEM: certPem,
privateKey: c.privateKey, serverCert: serverCert,
serverPK: c.publicKey,
serverMode: c.serverMode, serverMode: c.serverMode,
PayloadManager: pm, PayloadManager: pm,
}, nil }, nil
@ -99,6 +110,18 @@ func (c *PairingClient) receiveAccountData() error {
return c.PayloadManager.Receive(payload) return c.PayloadManager.Receive(payload)
} }
func verifyPublicKey(cert *x509.Certificate, publicKey *ecdsa.PublicKey) error {
certKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("unexpected public key type, expected ecdsa.PublicKey")
}
if !certKey.Equal(publicKey) {
return fmt.Errorf("server certificate MUST match the given public key")
}
return nil
}
func verifyCertSig(cert *x509.Certificate) (bool, error) { func verifyCertSig(cert *x509.Certificate) (bool, error) {
var esig struct { var esig struct {
R, S *big.Int R, S *big.Int
@ -114,39 +137,37 @@ func verifyCertSig(cert *x509.Certificate) (bool, error) {
return verified, nil return verified, nil
} }
func (c *PairingClient) getServerCert() error { func verifyCert(cert *x509.Certificate, publicKey *ecdsa.PublicKey) error {
conf := &tls.Config{ err := verifyPublicKey(cert, publicKey)
InsecureSkipVerify: true, // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
}
conn, err := tls.Dial("tcp", c.baseAddress.Host, conf)
if err != nil { if err != nil {
return err return err
} }
defer conn.Close()
certs := conn.ConnectionState().PeerCertificates verified, err := verifyCertSig(cert)
if len(certs) != 1 {
return fmt.Errorf("expected 1 TLS certificate, received '%d'", len(certs))
}
certKey, ok := certs[0].PublicKey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("unexpected public key type, expected ecdsa.PublicKey")
}
if certKey.Equal(c.privateKey) {
return fmt.Errorf("server certificate MUST match the given public key")
}
verified, err := verifyCertSig(certs[0])
if err != nil { if err != nil {
return err return err
} }
if !verified { if !verified {
return fmt.Errorf("server certificate signature MUST verify") return fmt.Errorf("server certificate signature MUST verify")
} }
c.serverCert = certs[0]
return nil return nil
} }
func getServerCert(URL *url.URL) (*x509.Certificate, error) {
conf := &tls.Config{
InsecureSkipVerify: true, // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
}
conn, err := tls.Dial("tcp", URL.Host, conf)
if err != nil {
return nil, err
}
defer conn.Close()
certs := conn.ConnectionState().PeerCertificates
if len(certs) != 1 {
return nil, fmt.Errorf("expected 1 TLS certificate, received '%d'", len(certs))
}
return certs[0], nil
}

View File

@ -19,6 +19,7 @@ const (
X = "7744735542292224619198421067303535767629647588258222392379329927711683109548" X = "7744735542292224619198421067303535767629647588258222392379329927711683109548"
Y = "6855516769916529066379811647277920115118980625614889267697023742462401590771" Y = "6855516769916529066379811647277920115118980625614889267697023742462401590771"
D = "38564357061962143106230288374146033267100509055924181407058066820384455255240" D = "38564357061962143106230288374146033267100509055924181407058066820384455255240"
AES = "BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6"
DB58 = "6jpbvo2ucrtrnpXXF4DQYuysh697isH9ppd2aT8uSRDh" DB58 = "6jpbvo2ucrtrnpXXF4DQYuysh697isH9ppd2aT8uSRDh"
SN = "91849736469742262272885892667727604096707836853856473239722372976236128900962" SN = "91849736469742262272885892667727604096707836853856473239722372976236128900962"
CertTime = "eQUriVtGtkWhPJFeLZjF" CertTime = "eQUriVtGtkWhPJFeLZjF"
@ -28,6 +29,7 @@ type TestKeyComponents struct {
X *big.Int X *big.Int
Y *big.Int Y *big.Int
D *big.Int D *big.Int
AES []byte
DBytes []byte DBytes []byte
PK *ecdsa.PrivateKey PK *ecdsa.PrivateKey
} }
@ -44,6 +46,9 @@ func (tk *TestKeyComponents) SetupKeyComponents(t *testing.T) {
tk.D, ok = new(big.Int).SetString(D, 10) tk.D, ok = new(big.Int).SetString(D, 10)
require.True(t, ok) require.True(t, ok)
tk.AES = base58.Decode(AES)
require.Len(t, tk.AES, 32)
tk.DBytes = base58.Decode(DB58) tk.DBytes = base58.Decode(DB58)
require.Exactly(t, tk.D.Bytes(), tk.DBytes) require.Exactly(t, tk.D.Bytes(), tk.DBytes)
@ -75,34 +80,36 @@ func (tcc *TestCertComponents) SetupCertComponents(t *testing.T) {
} }
type TestPairingServerComponents struct { type TestPairingServerComponents struct {
EphemeralPK *ecdsa.PrivateKey EphemeralPK *ecdsa.PrivateKey
OutboundIP net.IP EphemeralAES []byte
CertTime time.Time OutboundIP net.IP
Cert tls.Certificate Cert tls.Certificate
PS *PairingServer PS *PairingServer
} }
func (tpsc *TestPairingServerComponents) SetupPairingServerComponents(t *testing.T) { func (tpsc *TestPairingServerComponents) SetupPairingServerComponents(t *testing.T) {
var err error var err error
// Get 3 key components for tls.cert generation // Get 4 key components for tls.cert generation
// 1) Ephemeral private key // 1) Ephemeral private key
tpsc.EphemeralPK, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) tpsc.EphemeralPK, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err) require.NoError(t, err)
// 2) Device outbound IP address // 2) AES encryption key
tpsc.EphemeralAES, err = makeEncryptionKey(tpsc.EphemeralPK)
require.NoError(t, err)
// 3) Device outbound IP address
tpsc.OutboundIP, err = GetOutboundIP() tpsc.OutboundIP, err = GetOutboundIP()
require.NoError(t, err) require.NoError(t, err)
// 3) NotBefore time
tpsc.CertTime = time.Now()
// Generate tls.Certificate and Server // Generate tls.Certificate and Server
tpsc.Cert, _, err = GenerateCertFromKey(tpsc.EphemeralPK, tpsc.CertTime, tpsc.OutboundIP.String()) tpsc.Cert, _, err = GenerateCertFromKey(tpsc.EphemeralPK, time.Now(), tpsc.OutboundIP.String())
require.NoError(t, err) require.NoError(t, err)
tpsc.PS, err = NewPairingServer(&Config{ tpsc.PS, err = NewPairingServer(&Config{
PK: tpsc.EphemeralPK, PK: &tpsc.EphemeralPK.PublicKey,
EK: tpsc.EphemeralAES,
Cert: &tpsc.Cert, Cert: &tpsc.Cert,
Hostname: tpsc.OutboundIP.String()}) Hostname: tpsc.OutboundIP.String()})
require.NoError(t, err) require.NoError(t, err)
@ -112,8 +119,8 @@ type MockEncryptOnlyPayloadManager struct {
pem *PayloadEncryptionManager pem *PayloadEncryptionManager
} }
func NewMockEncryptOnlyPayloadManager(pk *ecdsa.PrivateKey) (*MockEncryptOnlyPayloadManager, error) { func NewMockEncryptOnlyPayloadManager(aesKey []byte) (*MockEncryptOnlyPayloadManager, error) {
pem, err := NewPayloadEncryptionManager(pk) pem, err := NewPayloadEncryptionManager(aesKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -2,13 +2,12 @@ package server
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"encoding/asn1" "crypto/elliptic"
"fmt" "fmt"
"math/big" "math/big"
"net" "net"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/btcsuite/btcutil/base58" "github.com/btcsuite/btcutil/base58"
) )
@ -29,18 +28,18 @@ type ConnectionParams struct {
version ConnectionParamVersion version ConnectionParamVersion
netIP net.IP netIP net.IP
port int port int
privateKey *ecdsa.PrivateKey publicKey *ecdsa.PublicKey
notBefore time.Time aesKey []byte
serverMode Mode serverMode Mode
} }
func NewConnectionParams(netIP net.IP, port int, privateKey *ecdsa.PrivateKey, notBefore time.Time, mode Mode) *ConnectionParams { func NewConnectionParams(netIP net.IP, port int, publicKey *ecdsa.PublicKey, aesKey []byte, mode Mode) *ConnectionParams {
cp := new(ConnectionParams) cp := new(ConnectionParams)
cp.version = Version1 cp.version = Version1
cp.netIP = netIP cp.netIP = netIP
cp.port = port cp.port = port
cp.privateKey = privateKey cp.publicKey = publicKey
cp.notBefore = notBefore cp.aesKey = aesKey
cp.serverMode = mode cp.serverMode = mode
return cp return cp
} }
@ -48,28 +47,24 @@ func NewConnectionParams(netIP net.IP, port int, privateKey *ecdsa.PrivateKey, n
// ToString generates a string required for generating a secure connection to another Status device. // ToString generates a string required for generating a secure connection to another Status device.
// //
// The returned string will look like below: // The returned string will look like below:
// - "2:4FHRnp:H6G:6jpbvo2ucrtrnpXXF4DQYuysh697isH9ppd2aT8uSRDh:eQUriVtGtkWhPJFeLZjF:2" // - "2:4FHRnp:H6G:uqnnMwVUfJc2Fkcaojet8F1ufKC3hZdGEt47joyBx9yd:BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6:2"
// //
// Format bytes encoded into a base58 string, delimited by ":" // Format bytes encoded into a base58 string, delimited by ":"
// - version // - version
// - net.IP // - net.IP
// - port // - port
// - ecdsa.PrivateKey D field // - ecdsa CompressedPublicKey
// - asn1.Marshal time.Time // - AES encryption key
// - server mode // - server mode
func (cp *ConnectionParams) ToString() (string, error) { func (cp *ConnectionParams) ToString() string {
v := base58.Encode(new(big.Int).SetInt64(int64(cp.version)).Bytes()) v := base58.Encode(new(big.Int).SetInt64(int64(cp.version)).Bytes())
ip := base58.Encode(cp.netIP) ip := base58.Encode(cp.netIP)
p := base58.Encode(new(big.Int).SetInt64(int64(cp.port)).Bytes()) p := base58.Encode(new(big.Int).SetInt64(int64(cp.port)).Bytes())
k := base58.Encode(cp.privateKey.D.Bytes()) k := base58.Encode(elliptic.MarshalCompressed(cp.publicKey.Curve, cp.publicKey.X, cp.publicKey.Y))
tb, err := asn1.Marshal(cp.notBefore.UTC()) ek := base58.Encode(cp.aesKey)
if err != nil {
return "", err
}
t := base58.Encode(tb)
m := base58.Encode(new(big.Int).SetInt64(int64(cp.serverMode)).Bytes()) m := base58.Encode(new(big.Int).SetInt64(int64(cp.serverMode)).Bytes())
return fmt.Sprintf("%s:%s:%s:%s:%s:%s", v, ip, p, k, t, m), nil return fmt.Sprintf("%s:%s:%s:%s:%s:%s", v, ip, p, k, ek, m)
} }
// FromString parses a connection params string required for to securely connect to another Status device. // FromString parses a connection params string required for to securely connect to another Status device.
@ -85,14 +80,10 @@ func (cp *ConnectionParams) FromString(s string) error {
cp.version = ConnectionParamVersion(new(big.Int).SetBytes(base58.Decode(sData[0])).Int64()) cp.version = ConnectionParamVersion(new(big.Int).SetBytes(base58.Decode(sData[0])).Int64())
cp.netIP = base58.Decode(sData[1]) cp.netIP = base58.Decode(sData[1])
cp.port = int(new(big.Int).SetBytes(base58.Decode(sData[2])).Int64()) cp.port = int(new(big.Int).SetBytes(base58.Decode(sData[2])).Int64())
cp.privateKey = ToECDSA(base58.Decode(sData[3])) cp.publicKey = new(ecdsa.PublicKey)
cp.publicKey.X, cp.publicKey.Y = elliptic.UnmarshalCompressed(elliptic.P256(), base58.Decode(sData[3]))
t := time.Time{} cp.publicKey.Curve = elliptic.P256()
_, err := asn1.Unmarshal(base58.Decode(sData[4]), &t) cp.aesKey = base58.Decode(sData[4])
if err != nil {
return err
}
cp.notBefore = t
cp.serverMode = Mode(new(big.Int).SetBytes(base58.Decode(sData[5])).Int64()) cp.serverMode = Mode(new(big.Int).SetBytes(base58.Decode(sData[5])).Int64())
return cp.validate() return cp.validate()
@ -114,12 +105,12 @@ func (cp *ConnectionParams) validate() error {
return err return err
} }
err = cp.validatePrivateKey() err = cp.validatePublicKey()
if err != nil { if err != nil {
return err return err
} }
err = cp.validateNotBefore() err = cp.validateAESKey()
if err != nil { if err != nil {
return err return err
} }
@ -151,22 +142,22 @@ func (cp *ConnectionParams) validatePort() error {
return fmt.Errorf("port '%d' outside of bounds of 1 - 65535", cp.port) return fmt.Errorf("port '%d' outside of bounds of 1 - 65535", cp.port)
} }
func (cp *ConnectionParams) validatePrivateKey() error { func (cp *ConnectionParams) validatePublicKey() error {
switch { switch {
case cp.privateKey.D == nil, cp.privateKey.D.Cmp(big.NewInt(0)) == 0: case cp.publicKey.Curve == nil, cp.publicKey.Curve != elliptic.P256():
return fmt.Errorf("private key D not set") return fmt.Errorf("public key Curve not `elliptic.P256`")
case cp.privateKey.PublicKey.X == nil, cp.privateKey.PublicKey.X.Cmp(big.NewInt(0)) == 0: case cp.publicKey.X == nil, cp.publicKey.X.Cmp(big.NewInt(0)) == 0:
return fmt.Errorf("public key X not set") return fmt.Errorf("public key X not set")
case cp.privateKey.PublicKey.Y == nil, cp.privateKey.PublicKey.Y.Cmp(big.NewInt(0)) == 0: case cp.publicKey.Y == nil, cp.publicKey.Y.Cmp(big.NewInt(0)) == 0:
return fmt.Errorf("public key Y not set") return fmt.Errorf("public key Y not set")
default: default:
return nil return nil
} }
} }
func (cp *ConnectionParams) validateNotBefore() error { func (cp *ConnectionParams) validateAESKey() error {
if cp.notBefore.IsZero() { if len(cp.aesKey) != 32 {
return fmt.Errorf("notBefore time is zero") return fmt.Errorf("AES key invalid length, expect length 32, received length '%d'", len(cp.aesKey))
} }
return nil return nil
} }
@ -180,22 +171,15 @@ func (cp *ConnectionParams) validateServerMode() error {
} }
} }
// Generate returns a *url.URL and encoded pem.Block generated from ConnectionParams set fields func (cp *ConnectionParams) URL() (*url.URL, error) {
func (cp *ConnectionParams) Generate() (*url.URL, []byte, error) {
err := cp.validate() err := cp.validate()
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
u := &url.URL{ u := &url.URL{
Scheme: "https", Scheme: "https",
Host: fmt.Sprintf("%s:%d", cp.netIP, cp.port), Host: fmt.Sprintf("%s:%d", cp.netIP, cp.port),
} }
return u, nil
_, pem, err := GenerateCertFromKey(cp.privateKey, cp.notBefore, cp.netIP.String())
if err != nil {
return nil, nil, err
}
return u, pem, nil
} }

View File

@ -1,16 +1,13 @@
package server package server
import ( import (
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
var ( var (
connectionString = "2:4FHRnp:Q4:6jpbvo2ucrtrnpXXF4DQYuysh697isH9ppd2aT8uSRDh:eQUriVtGtkWhPJFeLZjF:3" connectionString = "2:4FHRnp:Q4:uqnnMwVUfJc2Fkcaojet8F1ufKC3hZdGEt47joyBx9yd:BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6:3"
) )
func TestConnectionParamsSuite(t *testing.T) { func TestConnectionParamsSuite(t *testing.T) {
@ -37,7 +34,8 @@ func (s *ConnectionParamsSuite) SetupSuite() {
s.server = &PairingServer{ s.server = &PairingServer{
Server: bs, Server: bs,
pk: s.PK, pk: &s.PK.PublicKey,
ek: s.AES,
mode: Sending, mode: Sending,
} }
} }
@ -46,9 +44,7 @@ func (s *ConnectionParamsSuite) TestConnectionParams_ToString() {
cp, err := s.server.MakeConnectionParams() cp, err := s.server.MakeConnectionParams()
s.Require().NoError(err) s.Require().NoError(err)
cps, err := cp.ToString() cps := cp.ToString()
s.Require().NoError(err)
s.Require().Equal(connectionString, cps) s.Require().Equal(connectionString, cps)
} }
@ -59,27 +55,13 @@ func (s *ConnectionParamsSuite) TestConnectionParams_Generate() {
s.Require().Exactly(Sending, cp.serverMode) s.Require().Exactly(Sending, cp.serverMode)
u, c, err := cp.Generate() u, err := cp.URL()
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal("https://127.0.0.1:1337", u.String()) s.Require().Equal("https://127.0.0.1:1337", u.String())
s.Require().Equal(defaultIP.String(), u.Hostname()) s.Require().Equal(defaultIP.String(), u.Hostname())
s.Require().Equal("1337", u.Port()) s.Require().Equal("1337", u.Port())
// Parse cert PEM into x509 cert s.Require().True(cp.publicKey.Equal(&s.PK.PublicKey))
block, _ := pem.Decode(c) s.Require().Equal(s.AES, cp.aesKey)
s.Require().NotNil(block)
cert, err := x509.ParseCertificate(block.Bytes)
s.Require().NoError(err)
// Compare cert values
cl := s.server.cert.Leaf
s.Require().NotEqual(cl.Signature, cert.Signature)
s.Require().Zero(cl.PublicKey.(*ecdsa.PublicKey).X.Cmp(cert.PublicKey.(*ecdsa.PublicKey).X))
s.Require().Zero(cl.PublicKey.(*ecdsa.PublicKey).Y.Cmp(cert.PublicKey.(*ecdsa.PublicKey).Y))
s.Require().Equal(cl.Version, cert.Version)
s.Require().Zero(cl.SerialNumber.Cmp(cert.SerialNumber))
s.Require().Exactly(cl.NotBefore, cert.NotBefore)
s.Require().Exactly(cl.NotAfter, cert.NotAfter)
s.Require().Exactly(cl.IPAddresses, cert.IPAddresses)
} }

View File

@ -63,8 +63,7 @@ func (s *GetOutboundIPSuite) TestGetOutboundIPWithFullServerE2e() {
cp, err := s.PS.MakeConnectionParams() cp, err := s.PS.MakeConnectionParams()
s.Require().NoError(err) s.Require().NoError(err)
qr, err := cp.ToString() qr := cp.ToString()
s.Require().NoError(err)
// Client reads QR code and parses the connection string // Client reads QR code and parses the connection string
ccp := new(ConnectionParams) ccp := new(ConnectionParams)

View File

@ -1,7 +1,6 @@
package server package server
import ( import (
"crypto/ecdsa"
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -44,8 +43,8 @@ type PairingPayloadManager struct {
} }
// NewPairingPayloadManager generates a new and initialised PairingPayloadManager // NewPairingPayloadManager generates a new and initialised PairingPayloadManager
func NewPairingPayloadManager(pk *ecdsa.PrivateKey, config *PairingPayloadManagerConfig) (*PairingPayloadManager, error) { func NewPairingPayloadManager(aesKey []byte, config *PairingPayloadManagerConfig) (*PairingPayloadManager, error) {
pem, err := NewPayloadEncryptionManager(pk) pem, err := NewPayloadEncryptionManager(aesKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -120,13 +119,8 @@ type PayloadEncryptionManager struct {
received *EncryptionPayload received *EncryptionPayload
} }
func NewPayloadEncryptionManager(pk *ecdsa.PrivateKey) (*PayloadEncryptionManager, error) { func NewPayloadEncryptionManager(aesKey []byte) (*PayloadEncryptionManager, error) {
ek, err := makeEncryptionKey(pk) return &PayloadEncryptionManager{aesKey, new(EncryptionPayload), new(EncryptionPayload)}, nil
if err != nil {
return nil, err
}
return &PayloadEncryptionManager{ek, new(EncryptionPayload), new(EncryptionPayload)}, nil
} }
func (pem *PayloadEncryptionManager) Encrypt(data []byte) error { func (pem *PayloadEncryptionManager) Encrypt(data []byte) error {

View File

@ -11,13 +11,15 @@ type PairingServer struct {
Server Server
PayloadManager PayloadManager
pk *ecdsa.PrivateKey pk *ecdsa.PublicKey
ek []byte
mode Mode mode Mode
} }
type Config struct { type Config struct {
// Connection fields // Connection fields
PK *ecdsa.PrivateKey PK *ecdsa.PublicKey
EK []byte
Cert *tls.Certificate Cert *tls.Certificate
Hostname string Hostname string
Mode Mode Mode Mode
@ -28,7 +30,7 @@ type Config struct {
// NewPairingServer returns a *PairingServer init from the given *Config // NewPairingServer returns a *PairingServer init from the given *Config
func NewPairingServer(config *Config) (*PairingServer, error) { func NewPairingServer(config *Config) (*PairingServer, error) {
pm, err := NewPairingPayloadManager(config.PK, config.PairingPayloadManagerConfig) pm, err := NewPairingPayloadManager(config.EK, config.PairingPayloadManagerConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -38,21 +40,13 @@ func NewPairingServer(config *Config) (*PairingServer, error) {
config.Hostname, config.Hostname,
), ),
pk: config.PK, pk: config.PK,
ek: config.EK,
mode: config.Mode, mode: config.Mode,
PayloadManager: pm}, nil PayloadManager: pm}, nil
} }
// MakeConnectionParams generates a *ConnectionParams based on the Server's current state // MakeConnectionParams generates a *ConnectionParams based on the Server's current state
func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) { func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) {
switch {
case s.cert == nil:
return nil, fmt.Errorf("server has no cert set")
case s.cert.Leaf == nil:
return nil, fmt.Errorf("server cert has no Leaf set")
case s.cert.Leaf.NotBefore.IsZero():
return nil, fmt.Errorf("server cert Leaf has a zero value NotBefore")
}
netIP := net.ParseIP(s.hostname) netIP := net.ParseIP(s.hostname)
if netIP == nil { if netIP == nil {
return nil, fmt.Errorf("invalid ip address given '%s'", s.hostname) return nil, fmt.Errorf("invalid ip address given '%s'", s.hostname)
@ -67,7 +61,7 @@ func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) {
return nil, fmt.Errorf("port is 0, listener is not yet set") return nil, fmt.Errorf("port is 0, listener is not yet set")
} }
return NewConnectionParams(netIP, s.port, s.pk, s.cert.Leaf.NotBefore, s.mode), nil return NewConnectionParams(netIP, s.port, s.pk, s.ek, s.mode), nil
} }
func (s *PairingServer) StartPairing() error { func (s *PairingServer) StartPairing() error {

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"crypto/ecdsa"
"testing" "testing"
"time" "time"
@ -22,7 +23,7 @@ func (s *PairingServerSuite) SetupSuite() {
func (s *PairingServerSuite) TestPairingServer_StartPairing() { func (s *PairingServerSuite) TestPairingServer_StartPairing() {
// Replace PairingServer.PayloadManager with a MockEncryptOnlyPayloadManager // Replace PairingServer.PayloadManager with a MockEncryptOnlyPayloadManager
pm, err := NewMockEncryptOnlyPayloadManager(s.EphemeralPK) pm, err := NewMockEncryptOnlyPayloadManager(s.EphemeralAES)
s.Require().NoError(err) s.Require().NoError(err)
s.PS.PayloadManager = pm s.PS.PayloadManager = pm
@ -48,8 +49,7 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
cp, err := s.PS.MakeConnectionParams() cp, err := s.PS.MakeConnectionParams()
s.Require().NoError(err) s.Require().NoError(err)
qr, err := cp.ToString() qr := cp.ToString()
s.Require().NoError(err)
// Client reads QR code and parses the connection string // Client reads QR code and parses the connection string
ccp := new(ConnectionParams) ccp := new(ConnectionParams)
@ -59,11 +59,20 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
c, err := NewPairingClient(ccp, nil) c, err := NewPairingClient(ccp, nil)
s.Require().NoError(err) s.Require().NoError(err)
err = c.getServerCert() // Compare cert values
s.Require().NoError(err) cert := c.serverCert
cl := s.PS.cert.Leaf
s.Require().Equal(cl.Signature, cert.Signature)
s.Require().Zero(cl.PublicKey.(*ecdsa.PublicKey).X.Cmp(cert.PublicKey.(*ecdsa.PublicKey).X))
s.Require().Zero(cl.PublicKey.(*ecdsa.PublicKey).Y.Cmp(cert.PublicKey.(*ecdsa.PublicKey).Y))
s.Require().Equal(cl.Version, cert.Version)
s.Require().Zero(cl.SerialNumber.Cmp(cert.SerialNumber))
s.Require().Exactly(cl.NotBefore, cert.NotBefore)
s.Require().Exactly(cl.NotAfter, cert.NotAfter)
s.Require().Exactly(cl.IPAddresses, cert.IPAddresses)
// Replace PairingClient.PayloadManager with a MockEncryptOnlyPayloadManager // Replace PairingClient.PayloadManager with a MockEncryptOnlyPayloadManager
c.PayloadManager, err = NewMockEncryptOnlyPayloadManager(s.EphemeralPK) c.PayloadManager, err = NewMockEncryptOnlyPayloadManager(s.EphemeralAES)
s.Require().NoError(err) s.Require().NoError(err)
if m == Receiving { if m == Receiving {