Get server cert, run validation on cert

This commit is contained in:
Samuel Hawksby-Robinson 2022-08-06 14:26:16 +01:00
parent 81f58dc869
commit f7cbe0b1e8
3 changed files with 69 additions and 0 deletions

View File

@ -3,10 +3,13 @@ package server
import (
"bytes"
"crypto/ecdsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"net/url"
)
@ -19,6 +22,7 @@ type PairingClient struct {
certPEM []byte
privateKey *ecdsa.PrivateKey
serverMode Mode
serverCert *x509.Certificate
}
func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig) (*PairingClient, error) {
@ -94,3 +98,55 @@ func (c *PairingClient) receiveAccountData() error {
return c.PayloadManager.Receive(payload)
}
func verifyCertSig(cert *x509.Certificate) (bool, error) {
var esig struct {
R, S *big.Int
}
if _, err := asn1.Unmarshal(cert.Signature, &esig); err != nil {
return false, err
}
hash := sha256.New()
hash.Write(cert.RawTBSCertificate)
verified := ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), esig.R, esig.S)
return verified, nil
}
func (c *PairingClient) getServerCert() error {
conf := &tls.Config{
InsecureSkipVerify: true, // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
}
conn, err := tls.Dial("tcp", c.baseAddress.Host, conf)
if err != nil {
return err
}
defer conn.Close()
certs := conn.ConnectionState().PeerCertificates
if len(certs) != 1 {
return fmt.Errorf("expected 1 TLS certificate, received '%d'", len(certs))
}
certKey, ok := certs[0].PublicKey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("unexpected public key type, expected ecdsa.PublicKey")
}
if certKey.Equal(c.privateKey) {
return fmt.Errorf("server certificate MUST match the given public key")
}
verified, err := verifyCertSig(certs[0])
if err != nil {
return err
}
if !verified {
return fmt.Errorf("server certificate signature MUST verify")
}
c.serverCert = certs[0]
return nil
}

View File

@ -117,6 +117,16 @@ func (s *Server) SetHandlers(handlers HandlerPatternMap) {
s.handlers = handlers
}
func (s *Server) AddHandlers(handlers HandlerPatternMap) {
if s.handlers == nil {
s.handlers = make(HandlerPatternMap)
}
for name := range handlers {
s.handlers[name] = handlers[name]
}
}
func (s *Server) MakeBaseURL() *url.URL {
return &url.URL{
Scheme: "https",

View File

@ -59,6 +59,9 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
c, err := NewPairingClient(ccp, nil)
s.Require().NoError(err)
err = c.getServerCert()
s.Require().NoError(err)
// Replace PairingClient.PayloadManager with a MockEncryptOnlyPayloadManager
c.PayloadManager, err = NewMockEncryptOnlyPayloadManager(s.EphemeralPK)
s.Require().NoError(err)