Get server cert, run validation on cert
This commit is contained in:
parent
81f58dc869
commit
f7cbe0b1e8
|
@ -3,10 +3,13 @@ package server
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
@ -19,6 +22,7 @@ type PairingClient struct {
|
|||
certPEM []byte
|
||||
privateKey *ecdsa.PrivateKey
|
||||
serverMode Mode
|
||||
serverCert *x509.Certificate
|
||||
}
|
||||
|
||||
func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig) (*PairingClient, error) {
|
||||
|
@ -94,3 +98,55 @@ func (c *PairingClient) receiveAccountData() error {
|
|||
|
||||
return c.PayloadManager.Receive(payload)
|
||||
}
|
||||
|
||||
func verifyCertSig(cert *x509.Certificate) (bool, error) {
|
||||
var esig struct {
|
||||
R, S *big.Int
|
||||
}
|
||||
if _, err := asn1.Unmarshal(cert.Signature, &esig); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
hash := sha256.New()
|
||||
hash.Write(cert.RawTBSCertificate)
|
||||
|
||||
verified := ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), esig.R, esig.S)
|
||||
return verified, nil
|
||||
}
|
||||
|
||||
func (c *PairingClient) getServerCert() error {
|
||||
conf := &tls.Config{
|
||||
InsecureSkipVerify: true, // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", c.baseAddress.Host, conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
certs := conn.ConnectionState().PeerCertificates
|
||||
if len(certs) != 1 {
|
||||
return fmt.Errorf("expected 1 TLS certificate, received '%d'", len(certs))
|
||||
}
|
||||
|
||||
certKey, ok := certs[0].PublicKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected public key type, expected ecdsa.PublicKey")
|
||||
}
|
||||
|
||||
if certKey.Equal(c.privateKey) {
|
||||
return fmt.Errorf("server certificate MUST match the given public key")
|
||||
}
|
||||
|
||||
verified, err := verifyCertSig(certs[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !verified {
|
||||
return fmt.Errorf("server certificate signature MUST verify")
|
||||
}
|
||||
|
||||
c.serverCert = certs[0]
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -117,6 +117,16 @@ func (s *Server) SetHandlers(handlers HandlerPatternMap) {
|
|||
s.handlers = handlers
|
||||
}
|
||||
|
||||
func (s *Server) AddHandlers(handlers HandlerPatternMap) {
|
||||
if s.handlers == nil {
|
||||
s.handlers = make(HandlerPatternMap)
|
||||
}
|
||||
|
||||
for name := range handlers {
|
||||
s.handlers[name] = handlers[name]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) MakeBaseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: "https",
|
||||
|
|
|
@ -59,6 +59,9 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
|
|||
c, err := NewPairingClient(ccp, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = c.getServerCert()
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Replace PairingClient.PayloadManager with a MockEncryptOnlyPayloadManager
|
||||
c.PayloadManager, err = NewMockEncryptOnlyPayloadManager(s.EphemeralPK)
|
||||
s.Require().NoError(err)
|
||||
|
|
Loading…
Reference in New Issue