websocket: set the HTTP host header in WSS(#1834)
* Send host header Co-authored-by: Thibault Meunier <thibault@cloudflare.com> * Add comment and use splithostport * Return error * Defer the close Co-authored-by: Thibault Meunier <thibault@cloudflare.com>
This commit is contained in:
parent
828486ea04
commit
3e156d0813
|
@ -132,7 +132,7 @@ func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) {
|
|||
|
||||
type parsedWebsocketMultiaddr struct {
|
||||
isWSS bool
|
||||
// sni is the SNI value for the TLS handshake
|
||||
// sni is the SNI value for the TLS handshake, and for setting HTTP Host header
|
||||
sni *ma.Component
|
||||
// the rest of the multiaddr before the /tls/sni/example.com/ws or /ws or /wss
|
||||
restMultiaddr ma.Multiaddr
|
||||
|
|
|
@ -4,6 +4,7 @@ package websocket
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -186,6 +187,17 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
|
|||
copytlsClientConf := t.tlsClientConf.Clone()
|
||||
copytlsClientConf.ServerName = sni
|
||||
dialer.TLSClientConfig = copytlsClientConf
|
||||
ipAddr := wsurl.Host
|
||||
// Setting the NetDial because we already have the resolved IP address, so we don't want to do another resolution.
|
||||
// We set the `.Host` to the sni field so that the host header gets properly set.
|
||||
dialer.NetDial = func(network, address string) (net.Conn, error) {
|
||||
tcpAddr, err := net.ResolveTCPAddr(network, ipAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.DialTCP("tcp", nil, tcpAddr)
|
||||
}
|
||||
wsurl.Host = sni + ":" + wsurl.Port()
|
||||
} else {
|
||||
dialer.TLSClientConfig = t.tlsClientConf
|
||||
}
|
||||
|
|
|
@ -9,10 +9,13 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -218,6 +221,44 @@ func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHostHeaderWss(t *testing.T) {
|
||||
server := &http.Server{}
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer close(errChan)
|
||||
if !strings.Contains(r.Host, "example.com") {
|
||||
errChan <- errors.New("Didn't see host header")
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
server.TLSConfig = getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(time.Hour))
|
||||
server.ServeTLS(l, "", "")
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
serverMA := ma.StringCast("/ip4/127.0.0.1/tcp/" + port + "/tls/sni/example.com/ws")
|
||||
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
|
||||
_, u := newSecureUpgrader(t)
|
||||
tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig))
|
||||
require.NoError(t, err)
|
||||
|
||||
masToDial, err := tpt.Resolve(context.Background(), serverMA)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tpt.Dial(context.Background(), masToDial[0], test.RandPeerIDFatal(t))
|
||||
require.Error(t, err)
|
||||
|
||||
err = <-errChan
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDialWss(t *testing.T) {
|
||||
serverMA, rid, errChan := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws"))
|
||||
require.Contains(t, serverMA.String(), "tls")
|
||||
|
|
Loading…
Reference in New Issue