228 lines
5.9 KiB
Go
228 lines
5.9 KiB
Go
package webtransport
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
"github.com/quic-go/quic-go/http3"
|
|
"github.com/quic-go/quic-go/quicvarint"
|
|
)
|
|
|
|
// Protocol identifiers used during the WebTransport handshake (see Upgrade)
// and when classifying incoming streams (see the hijackers installed by init).
const (
	// webTransportDraftOfferHeaderKey is the request header that Upgrade
	// requires to be present exactly once with the value "1".
	webTransportDraftOfferHeaderKey = "Sec-Webtransport-Http3-Draft02"
	// webTransportDraftHeaderKey is the response header Upgrade sets, with
	// webTransportDraftHeaderValue as its value.
	webTransportDraftHeaderKey   = "Sec-Webtransport-Http3-Draft"
	webTransportDraftHeaderValue = "draft02"

	// webTransportFrameType is the frame type that marks a bidirectional
	// stream as a WebTransport stream; the session ID follows it as a varint.
	webTransportFrameType = 0x41
	// webTransportUniStreamType is the unidirectional stream type claimed by
	// the UniStreamHijacker.
	webTransportUniStreamType = 0x54
)
|
|
|
|
// Server is a WebTransport server layered on top of an HTTP/3 server (H3).
//
// Initialization happens at most once, on the first call to Serve,
// ServeQUICConn, ListenAndServe or ListenAndServeTLS. StreamReorderingTimeout
// is only read at that point; CheckOrigin is defaulted at that point if nil
// and is then consulted on every Upgrade.
type Server struct {
	// H3 is the underlying HTTP/3 server. Initialization modifies it: it sets
	// AdditionalSettings[settingsEnableWebtransport] to 1, enables datagrams,
	// and installs StreamHijacker and UniStreamHijacker. Initialization fails
	// if StreamHijacker has already been set by the caller.
	H3 http3.Server

	// StreamReorderingTimeout is the time an incoming WebTransport stream that cannot be associated
	// with a session is buffered.
	// This can happen if the CONNECT request (that creates a new session) is reordered, and arrives
	// after the first WebTransport stream(s) for that session.
	// Defaults to 5 seconds (used when the field is zero).
	StreamReorderingTimeout time.Duration

	// CheckOrigin is used to validate the request origin, thereby preventing cross-site request forgery.
	// CheckOrigin returns true if the request Origin header is acceptable.
	// If unset, a safe default is used: If the Origin header is set, it is checked that it
	// matches the request's Host header.
	CheckOrigin func(r *http.Request) bool

	ctx       context.Context // is canceled when Close is called
	ctxCancel context.CancelFunc
	// refCount is waited on by Close before it returns. NOTE(review): the
	// matching Add/Done calls are not in this view; presumably they track
	// goroutines started elsewhere in the package.
	refCount sync.WaitGroup

	// initOnce guards init; initErr holds its result. Close also consumes
	// initOnce so that no initialization can happen after Close.
	initOnce sync.Once
	initErr  error

	// conns manages sessions and buffers streams that arrive before their
	// session is known. It is nil until init has run.
	conns *sessionManager
}
|
|
|
|
func (s *Server) initialize() error {
|
|
s.initOnce.Do(func() {
|
|
s.initErr = s.init()
|
|
})
|
|
return s.initErr
|
|
}
|
|
|
|
func (s *Server) init() error {
|
|
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
|
|
timeout := s.StreamReorderingTimeout
|
|
if timeout == 0 {
|
|
timeout = 5 * time.Second
|
|
}
|
|
s.conns = newSessionManager(timeout)
|
|
if s.CheckOrigin == nil {
|
|
s.CheckOrigin = checkSameOrigin
|
|
}
|
|
|
|
// configure the http3.Server
|
|
if s.H3.AdditionalSettings == nil {
|
|
s.H3.AdditionalSettings = make(map[uint64]uint64)
|
|
}
|
|
s.H3.AdditionalSettings[settingsEnableWebtransport] = 1
|
|
s.H3.EnableDatagrams = true
|
|
if s.H3.StreamHijacker != nil {
|
|
return errors.New("StreamHijacker already set")
|
|
}
|
|
s.H3.StreamHijacker = func(ft http3.FrameType, qconn quic.Connection, str quic.Stream, err error) (bool /* hijacked */, error) {
|
|
if isWebTransportError(err) {
|
|
return true, nil
|
|
}
|
|
if ft != webTransportFrameType {
|
|
return false, nil
|
|
}
|
|
// Reading the varint might block if the peer sends really small frames, but this is fine.
|
|
// This function is called from the HTTP/3 request handler, which runs in its own Go routine.
|
|
id, err := quicvarint.Read(quicvarint.NewReader(str))
|
|
if err != nil {
|
|
if isWebTransportError(err) {
|
|
return true, nil
|
|
}
|
|
return false, err
|
|
}
|
|
s.conns.AddStream(qconn, str, sessionID(id))
|
|
return true, nil
|
|
}
|
|
s.H3.UniStreamHijacker = func(st http3.StreamType, qconn quic.Connection, str quic.ReceiveStream, err error) (hijacked bool) {
|
|
if st != webTransportUniStreamType && !isWebTransportError(err) {
|
|
return false
|
|
}
|
|
s.conns.AddUniStream(qconn, str)
|
|
return true
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) Serve(conn net.PacketConn) error {
|
|
if err := s.initialize(); err != nil {
|
|
return err
|
|
}
|
|
return s.H3.Serve(conn)
|
|
}
|
|
|
|
// ServeQUICConn serves a single QUIC connection.
|
|
func (s *Server) ServeQUICConn(conn quic.Connection) error {
|
|
if err := s.initialize(); err != nil {
|
|
return err
|
|
}
|
|
return s.H3.ServeQUICConn(conn)
|
|
}
|
|
|
|
func (s *Server) ListenAndServe() error {
|
|
if err := s.initialize(); err != nil {
|
|
return err
|
|
}
|
|
return s.H3.ListenAndServe()
|
|
}
|
|
|
|
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
|
if err := s.initialize(); err != nil {
|
|
return err
|
|
}
|
|
return s.H3.ListenAndServeTLS(certFile, keyFile)
|
|
}
|
|
|
|
func (s *Server) Close() error {
|
|
// Make sure that ctxCancel is defined.
|
|
// This is expected to be uncommon.
|
|
// It only happens if the server is closed without Serve / ListenAndServe having been called.
|
|
s.initOnce.Do(func() {})
|
|
|
|
if s.ctxCancel != nil {
|
|
s.ctxCancel()
|
|
}
|
|
if s.conns != nil {
|
|
s.conns.Close()
|
|
}
|
|
err := s.H3.Close()
|
|
s.refCount.Wait()
|
|
return err
|
|
}
|
|
|
|
func (s *Server) Upgrade(w http.ResponseWriter, r *http.Request) (*Session, error) {
|
|
if r.Method != http.MethodConnect {
|
|
return nil, fmt.Errorf("expected CONNECT request, got %s", r.Method)
|
|
}
|
|
if r.Proto != protocolHeader {
|
|
return nil, fmt.Errorf("unexpected protocol: %s", r.Proto)
|
|
}
|
|
if v, ok := r.Header[webTransportDraftOfferHeaderKey]; !ok || len(v) != 1 || v[0] != "1" {
|
|
return nil, fmt.Errorf("missing or invalid %s header", webTransportDraftOfferHeaderKey)
|
|
}
|
|
if !s.CheckOrigin(r) {
|
|
return nil, errors.New("webtransport: request origin not allowed")
|
|
}
|
|
w.Header().Add(webTransportDraftHeaderKey, webTransportDraftHeaderValue)
|
|
w.WriteHeader(http.StatusOK)
|
|
w.(http.Flusher).Flush()
|
|
|
|
httpStreamer, ok := r.Body.(http3.HTTPStreamer)
|
|
if !ok { // should never happen, unless quic-go changed the API
|
|
return nil, errors.New("failed to take over HTTP stream")
|
|
}
|
|
str := httpStreamer.HTTPStream()
|
|
sID := sessionID(str.StreamID())
|
|
|
|
hijacker, ok := w.(http3.Hijacker)
|
|
if !ok { // should never happen, unless quic-go changed the API
|
|
return nil, errors.New("failed to hijack")
|
|
}
|
|
return s.conns.AddSession(
|
|
hijacker.StreamCreator(),
|
|
sID,
|
|
r.Body.(http3.HTTPStreamer).HTTPStream(),
|
|
), nil
|
|
}
|
|
|
|
// copied from https://github.com/gorilla/websocket
|
|
func checkSameOrigin(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
return true
|
|
}
|
|
u, err := url.Parse(origin)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return equalASCIIFold(u.Host, r.Host)
|
|
}
|
|
|
|
// equalASCIIFold reports whether s and t are equal under ASCII-only case
// folding. Only 'A'..'Z' are mapped to lower case; every other rune,
// including non-ASCII letters, must match exactly. Both strings are decoded
// rune by rune as UTF-8.
//
// Adapted from https://github.com/gorilla/websocket.
func equalASCIIFold(s, t string) bool {
	toLowerASCII := func(c rune) rune {
		if c >= 'A' && c <= 'Z' {
			return c + ('a' - 'A')
		}
		return c
	}
	for len(s) > 0 && len(t) > 0 {
		a, na := utf8.DecodeRuneInString(s)
		b, nb := utf8.DecodeRuneInString(t)
		s, t = s[na:], t[nb:]
		if toLowerASCII(a) != toLowerASCII(b) {
			return false
		}
	}
	// At least one side is exhausted; they match only if both are.
	return s == "" && t == ""
}
|