go-libp2p/p2p/transport/webtransport/crypto_test.go

201 lines
6.3 KiB
Go
Raw Normal View History

package libp2pwebtransport
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io"
"math/big"
mrand "math/rand"
"testing"
"time"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require"
)
// sha256Multihash computes the SHA-256 digest of b, wraps it as a SHA2_256
// multihash and returns the decoded form. Any encode or decode failure fails
// the test immediately.
func sha256Multihash(t *testing.T, b []byte) multihash.DecodedMultihash {
	t.Helper()

	digest := sha256.Sum256(b)
	encoded, err := multihash.Encode(digest[:], multihash.SHA2_256)
	require.NoError(t, err)

	decoded, err := multihash.Decode(encoded)
	require.NoError(t, err)
	return *decoded
}
// generateCertWithKey creates a self-signed CA certificate for key with the
// validity window [start, end] and returns the parsed certificate.
//
// key must implement crypto.Signer (x509.CreateCertificate requires this of
// the signing key); otherwise the test fails. Any error while creating or
// parsing the certificate also fails the test immediately.
func generateCertWithKey(t *testing.T, key crypto.PrivateKey, start, end time.Time) *x509.Certificate {
	t.Helper()

	// Fail the test with a clear message instead of panicking on a bad key.
	signer, ok := key.(crypto.Signer)
	require.True(t, ok, "key of type %T does not implement crypto.Signer", key)

	certTempl := &x509.Certificate{
		// Int63 is never negative. Converting a Uint64 to int64 and negating
		// negative values is not safe: negating math.MinInt64 overflows and
		// stays negative, and CreateCertificate rejects negative serials.
		SerialNumber:          big.NewInt(mrand.Int63()),
		Subject:               pkix.Name{},
		NotBefore:             start,
		NotAfter:              end,
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}
	// The template doubles as the parent, which makes the cert self-signed.
	caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, signer.Public(), signer)
	require.NoError(t, err)
	ca, err := x509.ParseCertificate(caBytes)
	require.NoError(t, err)
	return ca
}
// TestCertificateVerification exercises verifyRawCerts with one certificate
// that must be accepted and a series of inputs that must be rejected. Each
// rejection case checks that the returned error contains the expected text.
func TestCertificateVerification(t *testing.T) {
	const day = 24 * time.Hour

	now := time.Now()
	ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)
	rsaKey, err := rsa.GenerateKey(rand.Reader, 1024)
	require.NoError(t, err)

	// requireErrContains asserts that err is non-nil and mentions want.
	requireErrContains := func(t *testing.T, err error, want string) {
		t.Helper()
		require.Error(t, err)
		require.Contains(t, err.Error(), want)
	}

	t.Run("accepting a valid cert", func(t *testing.T) {
		validCert := generateCertWithKey(t, ecdsaKey, now, now.Add(14*day))
		hashes := []multihash.DecodedMultihash{sha256Multihash(t, validCert.Raw)}
		require.NoError(t, verifyRawCerts([][]byte{validCert.Raw}, hashes))
	})

	// Certificates that parse and match their hash, but are expected to be
	// rejected with the given error text.
	type badCertCase struct {
		name   string
		cert   *x509.Certificate
		errStr string
	}
	badCerts := []badCertCase{
		{
			name:   "validitity period too long",
			cert:   generateCertWithKey(t, ecdsaKey, now, now.Add(15*day)),
			errStr: "cert must not be valid for longer than 14 days",
		},
		{
			name:   "uses RSA key",
			cert:   generateCertWithKey(t, rsaKey, now, now.Add(14*day)),
			errStr: "RSA",
		},
		{
			name:   "expired certificate",
			cert:   generateCertWithKey(t, ecdsaKey, now.Add(-14*day), now),
			errStr: "cert not valid",
		},
		{
			name:   "not yet valid",
			cert:   generateCertWithKey(t, ecdsaKey, now.Add(time.Hour), now.Add(time.Hour+14*day)),
			errStr: "cert not valid",
		},
	}
	for i := range badCerts {
		c := badCerts[i]
		t.Run(fmt.Sprintf("rejecting invalid certificates: %s", c.name), func(t *testing.T) {
			raw := c.cert.Raw
			err := verifyRawCerts([][]byte{raw}, []multihash.DecodedMultihash{sha256Multihash(t, raw)})
			requireErrContains(t, err, c.errStr)
		})
	}

	// Inputs that are wrong at the raw certificate/hash level.
	type badInputCase struct {
		name   string
		certs  [][]byte
		hashes []multihash.DecodedMultihash
		errStr string
	}
	badInputs := []badInputCase{
		{
			name:   "no certificates",
			hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))},
			errStr: "no cert",
		},
		{
			name:   "certificate not parseable",
			certs:  [][]byte{[]byte("foobar")},
			hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))},
			errStr: "x509: malformed certificate",
		},
		{
			name:   "hash mismatch",
			certs:  [][]byte{generateCertWithKey(t, ecdsaKey, now, now.Add(15*day)).Raw},
			hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))},
			errStr: "cert hash not found",
		},
	}
	for i := range badInputs {
		c := badInputs[i]
		t.Run(fmt.Sprintf("rejecting invalid certificates: %s", c.name), func(t *testing.T) {
			requireErrContains(t, verifyRawCerts(c.certs, c.hashes), c.errStr)
		})
	}
}
// TestDeterministicCertHashes checks that generateCert is deterministic: two
// calls with the same identity key and validity window must yield the same
// signature, the same raw certificate bytes and the same certificate private
// key.
func TestDeterministicCertHashes(t *testing.T) {
	// Repeat many times so that non-determinism in signing has a chance to
	// show up.
	const iterations = 1000
	for n := 0; n < iterations; n++ {
		t.Run(fmt.Sprintf("Run=%d", n), func(t *testing.T) {
			var seed [32]byte
			notBefore := time.Time{}
			notAfter := notBefore.Add(14 * 24 * time.Hour)

			// An all-zero seed gives the same identity key on every run.
			identity, _, err := ic.GenerateEd25519Key(bytes.NewReader(seed[:]))
			require.NoError(t, err)

			firstCert, firstKey, err := generateCert(identity, notBefore, notAfter)
			require.NoError(t, err)
			firstKeyDER, err := x509.MarshalECPrivateKey(firstKey)
			require.NoError(t, err)

			secondCert, secondKey, err := generateCert(identity, notBefore, notAfter)
			require.NoError(t, err)
			require.Equal(t, secondCert.Signature, firstCert.Signature)
			require.Equal(t, secondCert.Raw, firstCert.Raw)

			secondKeyDER, err := x509.MarshalECPrivateKey(secondKey)
			require.NoError(t, err)
			require.Equal(t, firstKeyDER, secondKeyDER)
		})
	}
}
// TestDeterministicSig tests that our hack around making ECDSA signatures
// deterministic works. If this fails, this means we need to try another
// strategy to make deterministic signatures or try something else entirely.
// See deterministicReader for more context.
//
// Each run builds two readers from the same seed and, for each, reads a
// 1024-byte message, generates a P-256 key and signs the message, using the
// reader as the randomness source throughout. The two signatures and the two
// marshaled private keys must match.
func TestDeterministicSig(t *testing.T) {
	// Run this test 1000 times since we want to make sure the signatures are deterministic
	runs := 1000
	for i := 0; i < runs; i++ {
		t.Run(fmt.Sprintf("Run=%d", i), func(t *testing.T) {
			zeroSeed := [32]byte{}

			deterministicHKDFReader := newDeterministicReader(zeroSeed[:], nil, deterministicCertInfo)
			b := [1024]byte{}
			// Fail loudly if the reader cannot supply the whole message
			// instead of silently signing a partially filled buffer.
			_, err := io.ReadFull(deterministicHKDFReader, b[:])
			require.NoError(t, err)
			caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), deterministicHKDFReader)
			require.NoError(t, err)
			sig, err := caPrivateKey.Sign(deterministicHKDFReader, b[:], crypto.SHA256)
			require.NoError(t, err)

			// Second pass: a fresh reader built from the same inputs.
			deterministicHKDFReader = newDeterministicReader(zeroSeed[:], nil, deterministicCertInfo)
			b2 := [1024]byte{}
			_, err = io.ReadFull(deterministicHKDFReader, b2[:])
			require.NoError(t, err)
			caPrivateKey2, err := ecdsa.GenerateKey(elliptic.P256(), deterministicHKDFReader)
			require.NoError(t, err)
			sig2, err := caPrivateKey2.Sign(deterministicHKDFReader, b2[:], crypto.SHA256)
			require.NoError(t, err)

			keyBytes, err := x509.MarshalECPrivateKey(caPrivateKey)
			require.NoError(t, err)
			keyBytes2, err := x509.MarshalECPrivateKey(caPrivateKey2)
			require.NoError(t, err)

			require.Equal(t, sig, sig2)
			require.Equal(t, keyBytes, keyBytes2)
		})
	}
}