From f7cbe0b1e84a47ad850303dc24adcb0f3e954777 Mon Sep 17 00:00:00 2001 From: Samuel Hawksby-Robinson Date: Sat, 6 Aug 2022 14:26:16 +0100 Subject: [PATCH] Get server cert, run validation on cert --- server/client.go | 56 +++++++++++++++++++++++++++++++++++ server/server.go | 10 +++++++ server/server_pairing_test.go | 3 ++ 3 files changed, 69 insertions(+) diff --git a/server/client.go b/server/client.go index 7952f4fb4..3e8e6100c 100644 --- a/server/client.go +++ b/server/client.go @@ -3,10 +3,13 @@ package server import ( "bytes" "crypto/ecdsa" + "crypto/sha256" "crypto/tls" "crypto/x509" + "encoding/asn1" "fmt" "io/ioutil" + "math/big" "net/http" "net/url" ) @@ -19,6 +22,7 @@ type PairingClient struct { certPEM []byte privateKey *ecdsa.PrivateKey serverMode Mode + serverCert *x509.Certificate } func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig) (*PairingClient, error) { @@ -94,3 +98,55 @@ func (c *PairingClient) receiveAccountData() error { return c.PayloadManager.Receive(payload) } + +func verifyCertSig(cert *x509.Certificate) (bool, error) { + var esig struct { + R, S *big.Int + } + if _, err := asn1.Unmarshal(cert.Signature, &esig); err != nil { + return false, err + } + + hash := sha256.New() + hash.Write(cert.RawTBSCertificate) + + verified := ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), esig.R, esig.S) + return verified, nil +} + +func (c *PairingClient) getServerCert() error { + conf := &tls.Config{ + InsecureSkipVerify: true, // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason. + } + + conn, err := tls.Dial("tcp", c.baseAddress.Host, conf) + if err != nil { + return err + } + defer conn.Close() + + certs := conn.ConnectionState().PeerCertificates + if len(certs) != 1 { + return fmt.Errorf("expected 1 TLS certificate, received '%d'", len(certs)) + } + + certKey, ok := certs[0].PublicKey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("unexpected public key type, expected ecdsa.PublicKey") + } + + if certKey.Equal(c.privateKey) { + return fmt.Errorf("server certificate MUST match the given public key") + } + + verified, err := verifyCertSig(certs[0]) + if err != nil { + return err + } + if !verified { + return fmt.Errorf("server certificate signature MUST verify") + } + + c.serverCert = certs[0] + return nil +} diff --git a/server/server.go b/server/server.go index 1295b0a19..dafd013c0 100644 --- a/server/server.go +++ b/server/server.go @@ -117,6 +117,16 @@ func (s *Server) SetHandlers(handlers HandlerPatternMap) { s.handlers = handlers } +func (s *Server) AddHandlers(handlers HandlerPatternMap) { + if s.handlers == nil { + s.handlers = make(HandlerPatternMap) + } + + for name := range handlers { + s.handlers[name] = handlers[name] + } +} + func (s *Server) MakeBaseURL() *url.URL { return &url.URL{ Scheme: "https", diff --git a/server/server_pairing_test.go b/server/server_pairing_test.go index 5895dc46d..180381174 100644 --- a/server/server_pairing_test.go +++ b/server/server_pairing_test.go @@ -59,6 +59,9 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() { c, err := NewPairingClient(ccp, nil) s.Require().NoError(err) + err = c.getServerCert() + s.Require().NoError(err) + // Replace PairingClient.PayloadManager with a MockEncryptOnlyPayloadManager c.PayloadManager, err = NewMockEncryptOnlyPayloadManager(s.EphemeralPK) s.Require().NoError(err)