package webtransport import ( "context" "sync" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/quicvarint" ) // session is the map value in the conns map type session struct { created chan struct{} // is closed once the session map has been initialized counter int // how many streams are waiting for this session to be established conn *Session } type sessionManager struct { refCount sync.WaitGroup ctx context.Context ctxCancel context.CancelFunc timeout time.Duration mx sync.Mutex conns map[quic.ConnectionTracingID]map[sessionID]*session } func newSessionManager(timeout time.Duration) *sessionManager { m := &sessionManager{ timeout: timeout, conns: make(map[quic.ConnectionTracingID]map[sessionID]*session), } m.ctx, m.ctxCancel = context.WithCancel(context.Background()) return m } // AddStream adds a new bidirectional stream to a WebTransport session. // If the WebTransport session has not yet been established, // it starts a new go routine and waits for establishment of the session. // If that takes longer than timeout, the stream is reset. func (m *sessionManager) AddStream(connTracingID quic.ConnectionTracingID, str quic.Stream, id sessionID) { sess, isExisting := m.getOrCreateSession(connTracingID, id) if isExisting { sess.conn.addIncomingStream(str) return } m.refCount.Add(1) go func() { defer m.refCount.Done() m.handleStream(str, sess) m.mx.Lock() defer m.mx.Unlock() sess.counter-- // Once no more streams are waiting for this session to be established, // and this session is still outstanding, delete it from the map. if sess.counter == 0 && sess.conn == nil { m.maybeDelete(connTracingID, id) } }() } func (m *sessionManager) maybeDelete(connTracingID quic.ConnectionTracingID, id sessionID) { sessions, ok := m.conns[connTracingID] if !ok { // should never happen return } delete(sessions, id) if len(sessions) == 0 { delete(m.conns, connTracingID) } } // AddUniStream adds a new unidirectional stream to a WebTransport session. // If the WebTransport session has not yet been established, // it starts a new go routine and waits for establishment of the session. // If that takes longer than timeout, the stream is reset. func (m *sessionManager) AddUniStream(connTracingID quic.ConnectionTracingID, str quic.ReceiveStream) { idv, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { str.CancelRead(1337) } id := sessionID(idv) sess, isExisting := m.getOrCreateSession(connTracingID, id) if isExisting { sess.conn.addIncomingUniStream(str) return } m.refCount.Add(1) go func() { defer m.refCount.Done() m.handleUniStream(str, sess) m.mx.Lock() defer m.mx.Unlock() sess.counter-- // Once no more streams are waiting for this session to be established, // and this session is still outstanding, delete it from the map. if sess.counter == 0 && sess.conn == nil { m.maybeDelete(connTracingID, id) } }() } func (m *sessionManager) getOrCreateSession(connTracingID quic.ConnectionTracingID, id sessionID) (sess *session, existed bool) { m.mx.Lock() defer m.mx.Unlock() sessions, ok := m.conns[connTracingID] if !ok { sessions = make(map[sessionID]*session) m.conns[connTracingID] = sessions } sess, ok = sessions[id] if ok && sess.conn != nil { return sess, true } if !ok { sess = &session{created: make(chan struct{})} sessions[id] = sess } sess.counter++ return sess, false } func (m *sessionManager) handleStream(str quic.Stream, sess *session) { t := time.NewTimer(m.timeout) defer t.Stop() // When multiple streams are waiting for the same session to be established, // the timeout is calculated for every stream separately. select { case <-sess.created: sess.conn.addIncomingStream(str) case <-t.C: str.CancelRead(WebTransportBufferedStreamRejectedErrorCode) str.CancelWrite(WebTransportBufferedStreamRejectedErrorCode) case <-m.ctx.Done(): } } func (m *sessionManager) handleUniStream(str quic.ReceiveStream, sess *session) { t := time.NewTimer(m.timeout) defer t.Stop() // When multiple streams are waiting for the same session to be established, // the timeout is calculated for every stream separately. select { case <-sess.created: sess.conn.addIncomingUniStream(str) case <-t.C: str.CancelRead(WebTransportBufferedStreamRejectedErrorCode) case <-m.ctx.Done(): } } // AddSession adds a new WebTransport session. func (m *sessionManager) AddSession(qconn http3.Connection, id sessionID, requestStr http3.Stream) *Session { conn := newSession(id, qconn, requestStr) connTracingID := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) m.mx.Lock() defer m.mx.Unlock() sessions, ok := m.conns[connTracingID] if !ok { sessions = make(map[sessionID]*session) m.conns[connTracingID] = sessions } if sess, ok := sessions[id]; ok { // We might already have an entry of this session. // This can happen when we receive a stream for this WebTransport session before we complete the HTTP request // that establishes the session. sess.conn = conn close(sess.created) return conn } c := make(chan struct{}) close(c) sessions[id] = &session{created: c, conn: conn} return conn } func (m *sessionManager) Close() { m.ctxCancel() m.refCount.Wait() }