diff --git a/server/certs.go b/server/certs.go index ff70ef67d..9e29ce976 100644 --- a/server/certs.go +++ b/server/certs.go @@ -8,10 +8,12 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/asn1" "encoding/pem" "fmt" "math/big" "net" + "net/url" "time" ) @@ -146,3 +148,78 @@ func ToECDSA(d []byte) *ecdsa.PrivateKey { k.PublicKey.X, k.PublicKey.Y = k.PublicKey.Curve.ScalarBaseMult(d) return k } + +// verifyCertPublicKey checks that the ecdsa.PublicKey using in a x509.Certificate matches a known ecdsa.PublicKey +func verifyCertPublicKey(cert *x509.Certificate, publicKey *ecdsa.PublicKey) error { + certKey, ok := cert.PublicKey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("unexpected public key type, expected ecdsa.PublicKey") + } + + if !certKey.Equal(publicKey) { + return fmt.Errorf("server certificate MUST match the given public key") + } + return nil +} + +// verifyCertSig checks that a x509.Certificate's Signature verifies against x509.Certificate's PublicKey +// If the x509.Certificate's PublicKey is not an ecdsa.PublicKey an error will be thrown +func verifyCertSig(cert *x509.Certificate) (bool, error) { + var esig struct { + R, S *big.Int + } + if _, err := asn1.Unmarshal(cert.Signature, &esig); err != nil { + return false, err + } + + hash := sha256.New() + hash.Write(cert.RawTBSCertificate) + + ecKey, ok := cert.PublicKey.(*ecdsa.PublicKey) + if !ok { + return false, fmt.Errorf("certificate public is not an ecdsa.PublicKey") + } + + verified := ecdsa.Verify(ecKey, hash.Sum(nil), esig.R, esig.S) + return verified, nil +} + +// verifyCert verifies an x509.Certificate against a known ecdsa.PublicKey +// combining the checks of verifyCertPublicKey and verifyCertSig. +// If an x509.Certificate fails to verify an error is also thrown +func verifyCert(cert *x509.Certificate, publicKey *ecdsa.PublicKey) error { + err := verifyCertPublicKey(cert, publicKey) + if err != nil { + return err + } + + verified, err := verifyCertSig(cert) + if err != nil { + return err + } + if !verified { + return fmt.Errorf("server certificate signature MUST verify") + } + return nil +} + +// getServerCert pings a given tls host, extracts and returns its x509.Certificate +// the function expects there to be only 1 certificate +func getServerCert(URL *url.URL) (*x509.Certificate, error) { + conf := &tls.Config{ + InsecureSkipVerify: true, // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason. + } + + conn, err := tls.Dial("tcp", URL.Host, conf) + if err != nil { + return nil, err + } + defer conn.Close() + + certs := conn.ConnectionState().PeerCertificates + if len(certs) != 1 { + return nil, fmt.Errorf("expected 1 TLS certificate, received '%d'", len(certs)) + } + + return certs[0], nil +} diff --git a/server/client.go b/server/client.go index a838ed5fc..dedfc43a7 100644 --- a/server/client.go +++ b/server/client.go @@ -3,14 +3,11 @@ package server import ( "bytes" "crypto/ecdsa" - "crypto/sha256" "crypto/tls" "crypto/x509" - "encoding/asn1" "encoding/pem" "fmt" "io/ioutil" - "math/big" "net/http" "net/url" ) @@ -109,65 +106,3 @@ func (c *PairingClient) receiveAccountData() error { return c.PayloadManager.Receive(payload) } - -func verifyPublicKey(cert *x509.Certificate, publicKey *ecdsa.PublicKey) error { - certKey, ok := cert.PublicKey.(*ecdsa.PublicKey) - if !ok { - return fmt.Errorf("unexpected public key type, expected ecdsa.PublicKey") - } - - if !certKey.Equal(publicKey) { - return fmt.Errorf("server certificate MUST match the given public key") - } - return nil -} - -func verifyCertSig(cert *x509.Certificate) (bool, error) { - var esig struct { - R, S *big.Int - } - if _, err := asn1.Unmarshal(cert.Signature, &esig); err != nil { - return false, err - } - - hash := sha256.New() - hash.Write(cert.RawTBSCertificate) - - verified := ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), esig.R, esig.S) - return verified, nil -} - -func verifyCert(cert *x509.Certificate, publicKey *ecdsa.PublicKey) error { - err := verifyPublicKey(cert, publicKey) - if err != nil { - return err - } - - verified, err := verifyCertSig(cert) - if err != nil { - return err - } - if !verified { - return fmt.Errorf("server certificate signature MUST verify") - } - return nil -} - -func getServerCert(URL *url.URL) (*x509.Certificate, error) { - conf := &tls.Config{ - InsecureSkipVerify: true, // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason. - } - - conn, err := tls.Dial("tcp", URL.Host, conf) - if err != nil { - return nil, err - } - defer conn.Close() - - certs := conn.ConnectionState().PeerCertificates - if len(certs) != 1 { - return nil, fmt.Errorf("expected 1 TLS certificate, received '%d'", len(certs)) - } - - return certs[0], nil -}