2024-06-05 16:10:03 -04:00

197 lines
5.3 KiB
Go

package webtransport
import (
"context"
"sync"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/quicvarint"
)
// session is the value type stored in sessionManager.conns for one session ID.
// An entry can exist before the session itself does: a stream that arrives
// ahead of session establishment creates a placeholder (conn == nil) and
// waits on created.
type session struct {
// created is closed by AddSession once conn has been set.
created chan struct{}
// counter is the number of streams currently waiting for this session to be established.
counter int
// conn is the established session; it is nil while the session is still pending.
conn *Session
}
// sessionManager tracks the WebTransport sessions of every QUIC connection and
// holds back streams that arrive before their session has been established.
type sessionManager struct {
// refCount counts the goroutines started by AddStream and AddUniStream; Close waits on it.
refCount sync.WaitGroup
// ctx is canceled by Close, which releases goroutines still waiting for a session.
ctx context.Context
ctxCancel context.CancelFunc
// timeout is how long a single stream waits for its session before it is rejected.
timeout time.Duration
// mx is held for every access to conns and to the counter of its entries.
mx sync.Mutex
// conns maps a connection's tracing ID to that connection's sessions, keyed by session ID.
conns map[quic.ConnectionTracingID]map[sessionID]*session
}
// newSessionManager returns a sessionManager with an empty session table.
// timeout is the per-stream limit that handleStream and handleUniStream apply
// while waiting for a session to be established. The manager's context is
// derived from context.Background and is only canceled through ctxCancel.
func newSessionManager(timeout time.Duration) *sessionManager {
	ctx, cancel := context.WithCancel(context.Background())
	return &sessionManager{
		timeout:   timeout,
		ctx:       ctx,
		ctxCancel: cancel,
		conns:     make(map[quic.ConnectionTracingID]map[sessionID]*session),
	}
}
// AddStream hands a bidirectional stream to the WebTransport session identified
// by id on the connection identified by connTracingID.
//
// If the session is already established, the stream is delivered right away.
// Otherwise a goroutine is started that waits for the session to be
// established; if that takes longer than the manager's timeout, the stream is
// reset.
func (m *sessionManager) AddStream(connTracingID quic.ConnectionTracingID, str quic.Stream, id sessionID) {
	s, established := m.getOrCreateSession(connTracingID, id)
	if established {
		s.conn.addIncomingStream(str)
		return
	}

	m.refCount.Add(1)
	go func() {
		defer m.refCount.Done()

		m.handleStream(str, s)

		m.mx.Lock()
		defer m.mx.Unlock()
		s.counter--
		if s.conn != nil || s.counter != 0 {
			// Either the session got established, or other streams are still waiting for it.
			return
		}
		// This was the last waiting stream and the session is still outstanding:
		// remove the placeholder from the map.
		m.maybeDelete(connTracingID, id)
	}()
}
// maybeDelete removes the entry for session id from the table of the given
// connection, and removes the connection's table as well once it is empty.
// It does not lock m.mx itself; the callers in this file hold it.
func (m *sessionManager) maybeDelete(connTracingID quic.ConnectionTracingID, id sessionID) {
	if perConn, found := m.conns[connTracingID]; found {
		delete(perConn, id)
		if len(perConn) == 0 {
			delete(m.conns, connTracingID)
		}
	}
	// A missing connection entry should never happen; there is nothing to do in that case.
}
// AddUniStream adds a new unidirectional stream to a WebTransport session.
// The session ID is read from the stream itself as a QUIC varint. If it cannot
// be read, reading from the stream is canceled and the stream is dropped.
//
// If the WebTransport session has not yet been established, a goroutine is
// started that waits for establishment of the session. If that takes longer
// than the manager's timeout, the stream is reset.
func (m *sessionManager) AddUniStream(connTracingID quic.ConnectionTracingID, str quic.ReceiveStream) {
	idv, err := quicvarint.Read(quicvarint.NewReader(str))
	if err != nil {
		// Without a session ID the stream cannot be attributed to any session.
		// Stop here: continuing would treat the zero value of idv as session ID 0
		// and hand the already canceled stream to that session.
		str.CancelRead(1337)
		return
	}
	id := sessionID(idv)

	sess, isExisting := m.getOrCreateSession(connTracingID, id)
	if isExisting {
		sess.conn.addIncomingUniStream(str)
		return
	}

	m.refCount.Add(1)
	go func() {
		defer m.refCount.Done()
		m.handleUniStream(str, sess)

		m.mx.Lock()
		defer m.mx.Unlock()

		sess.counter--
		// Once no more streams are waiting for this session to be established,
		// and this session is still outstanding, delete it from the map.
		if sess.counter == 0 && sess.conn == nil {
			m.maybeDelete(connTracingID, id)
		}
	}()
}
// getOrCreateSession looks up the entry for session id on the given connection,
// creating the connection's table and a placeholder entry when they are missing.
//
// existed is true only when the entry already holds an established session.
// In every other case the entry's waiter counter is incremented and existed is
// false; the caller is then expected to wait on the entry's created channel.
func (m *sessionManager) getOrCreateSession(connTracingID quic.ConnectionTracingID, id sessionID) (sess *session, existed bool) {
	m.mx.Lock()
	defer m.mx.Unlock()

	perConn, known := m.conns[connTracingID]
	if !known {
		perConn = make(map[sessionID]*session)
		m.conns[connTracingID] = perConn
	}

	entry, found := perConn[id]
	switch {
	case !found:
		// First stream for a session that is not established yet: register a placeholder.
		entry = &session{created: make(chan struct{})}
		perConn[id] = entry
	case entry.conn != nil:
		return entry, true
	}

	entry.counter++
	return entry, false
}
// handleStream blocks until one of three things happens:
//   - the session is established: the stream is handed to it,
//   - the timeout elapses: both directions of the stream are canceled with
//     WebTransportBufferedStreamRejectedErrorCode,
//   - the manager's context is canceled: the function returns without touching the stream.
func (m *sessionManager) handleStream(str quic.Stream, sess *session) {
	// The timer is per call, so streams waiting for the same session each get
	// their own full timeout.
	timer := time.NewTimer(m.timeout)
	defer timer.Stop()

	select {
	case <-m.ctx.Done():
	case <-timer.C:
		str.CancelRead(WebTransportBufferedStreamRejectedErrorCode)
		str.CancelWrite(WebTransportBufferedStreamRejectedErrorCode)
	case <-sess.created:
		sess.conn.addIncomingStream(str)
	}
}
// handleUniStream blocks until one of three things happens:
//   - the session is established: the stream is handed to it,
//   - the timeout elapses: reading from the stream is canceled with
//     WebTransportBufferedStreamRejectedErrorCode,
//   - the manager's context is canceled: the function returns without touching the stream.
func (m *sessionManager) handleUniStream(str quic.ReceiveStream, sess *session) {
	// The timer is per call, so streams waiting for the same session each get
	// their own full timeout.
	timer := time.NewTimer(m.timeout)
	defer timer.Stop()

	select {
	case <-m.ctx.Done():
	case <-timer.C:
		str.CancelRead(WebTransportBufferedStreamRejectedErrorCode)
	case <-sess.created:
		sess.conn.addIncomingUniStream(str)
	}
}
// AddSession creates the WebTransport session for id on qconn, registers it
// with the manager and returns it.
//
// An entry for the session may already be present: that happens when a stream
// for this session was received before the HTTP request establishing the
// session completed. Closing the entry's created channel releases the streams
// waiting on it.
func (m *sessionManager) AddSession(qconn http3.Connection, id sessionID, requestStr http3.Stream) *Session {
	conn := newSession(id, qconn, requestStr)
	connTracingID := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)

	m.mx.Lock()
	defer m.mx.Unlock()

	perConn, known := m.conns[connTracingID]
	if !known {
		perConn = make(map[sessionID]*session)
		m.conns[connTracingID] = perConn
	}

	entry, found := perConn[id]
	if !found {
		// No stream arrived early, so there is no placeholder yet.
		entry = &session{created: make(chan struct{})}
		perConn[id] = entry
	}
	// conn is set before the channel is closed, so anyone woken up by the close sees it.
	entry.conn = conn
	close(entry.created)
	return conn
}
// Close cancels the manager's context, which releases every goroutine still
// waiting for a session to be established, and then blocks until all
// goroutines started by AddStream and AddUniStream have returned.
func (m *sessionManager) Close() {
m.ctxCancel()
m.refCount.Wait()
}