websocket: set the HTTP host header in WSS(#1834)

* Send host header

Co-authored-by: Thibault Meunier <thibault@cloudflare.com>

* Add comment and use splithostport

* Return error

* Defer the close

Co-authored-by: Thibault Meunier <thibault@cloudflare.com>
This commit is contained in:
Marco Munizaga 2022-10-20 10:51:05 +01:00 committed by GitHub
parent 828486ea04
commit 3e156d0813
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 1 deletions

View File

@ -132,7 +132,7 @@ func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) {
type parsedWebsocketMultiaddr struct {
isWSS bool
// sni is the SNI value for the TLS handshake
// sni is the SNI value for the TLS handshake, and for setting HTTP Host header
sni *ma.Component
// the rest of the multiaddr before the /tls/sni/example.com/ws or /ws or /wss
restMultiaddr ma.Multiaddr

View File

@ -4,6 +4,7 @@ package websocket
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
@ -186,6 +187,17 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
copytlsClientConf := t.tlsClientConf.Clone()
copytlsClientConf.ServerName = sni
dialer.TLSClientConfig = copytlsClientConf
ipAddr := wsurl.Host
// Setting the NetDial because we already have the resolved IP address, so we don't want to do another resolution.
// We set the `.Host` to the sni field so that the host header gets properly set.
dialer.NetDial = func(network, address string) (net.Conn, error) {
tcpAddr, err := net.ResolveTCPAddr(network, ipAddr)
if err != nil {
return nil, err
}
return net.DialTCP("tcp", nil, tcpAddr)
}
wsurl.Host = sni + ":" + wsurl.Port()
} else {
dialer.TLSClientConfig = t.tlsClientConf
}

View File

@ -9,10 +9,13 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"net/http"
"strings"
"testing"
"time"
@ -218,6 +221,44 @@ func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
}
}
func TestHostHeaderWss(t *testing.T) {
server := &http.Server{}
l, err := net.Listen("tcp", ":0")
require.NoError(t, err)
defer server.Close()
errChan := make(chan error, 1)
go func() {
server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(errChan)
if !strings.Contains(r.Host, "example.com") {
errChan <- errors.New("Didn't see host header")
}
w.WriteHeader(http.StatusNotFound)
})
server.TLSConfig = getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(time.Hour))
server.ServeTLS(l, "", "")
}()
_, port, err := net.SplitHostPort(l.Addr().String())
require.NoError(t, err)
serverMA := ma.StringCast("/ip4/127.0.0.1/tcp/" + port + "/tls/sni/example.com/ws")
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
_, u := newSecureUpgrader(t)
tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig))
require.NoError(t, err)
masToDial, err := tpt.Resolve(context.Background(), serverMA)
require.NoError(t, err)
_, err = tpt.Dial(context.Background(), masToDial[0], test.RandPeerIDFatal(t))
require.Error(t, err)
err = <-errChan
require.NoError(t, err)
}
func TestDialWss(t *testing.T) {
serverMA, rid, errChan := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws"))
require.Contains(t, serverMA.String(), "tls")