2023-02-22 17:58:17 -04:00
|
|
|
package webtransport
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"sync"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/quic-go/quic-go"
|
|
|
|
"github.com/quic-go/quic-go/http3"
|
|
|
|
"github.com/quic-go/quic-go/quicvarint"
|
|
|
|
)
|
|
|
|
|
|
|
|
// session is the map value in the conns map
|
|
|
|
type session struct {
|
|
|
|
created chan struct{} // is closed once the session map has been initialized
|
|
|
|
counter int // how many streams are waiting for this session to be established
|
|
|
|
conn *Session
|
|
|
|
}
|
|
|
|
|
|
|
|
type sessionManager struct {
|
|
|
|
refCount sync.WaitGroup
|
|
|
|
ctx context.Context
|
|
|
|
ctxCancel context.CancelFunc
|
|
|
|
|
|
|
|
timeout time.Duration
|
|
|
|
|
|
|
|
mx sync.Mutex
|
2024-06-05 16:10:03 -04:00
|
|
|
conns map[quic.ConnectionTracingID]map[sessionID]*session
|
2023-02-22 17:58:17 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
func newSessionManager(timeout time.Duration) *sessionManager {
|
|
|
|
m := &sessionManager{
|
|
|
|
timeout: timeout,
|
2024-06-05 16:10:03 -04:00
|
|
|
conns: make(map[quic.ConnectionTracingID]map[sessionID]*session),
|
2023-02-22 17:58:17 -04:00
|
|
|
}
|
|
|
|
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
|
|
|
|
return m
|
|
|
|
}
|
|
|
|
|
|
|
|
// AddStream adds a new bidirectional stream to a WebTransport session.
|
|
|
|
// If the WebTransport session has not yet been established,
|
|
|
|
// it starts a new go routine and waits for establishment of the session.
|
|
|
|
// If that takes longer than timeout, the stream is reset.
|
2024-06-05 16:10:03 -04:00
|
|
|
func (m *sessionManager) AddStream(connTracingID quic.ConnectionTracingID, str quic.Stream, id sessionID) {
|
|
|
|
sess, isExisting := m.getOrCreateSession(connTracingID, id)
|
2023-02-22 17:58:17 -04:00
|
|
|
if isExisting {
|
|
|
|
sess.conn.addIncomingStream(str)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
m.refCount.Add(1)
|
|
|
|
go func() {
|
|
|
|
defer m.refCount.Done()
|
|
|
|
m.handleStream(str, sess)
|
|
|
|
|
|
|
|
m.mx.Lock()
|
|
|
|
defer m.mx.Unlock()
|
|
|
|
|
|
|
|
sess.counter--
|
|
|
|
// Once no more streams are waiting for this session to be established,
|
|
|
|
// and this session is still outstanding, delete it from the map.
|
|
|
|
if sess.counter == 0 && sess.conn == nil {
|
2024-06-05 16:10:03 -04:00
|
|
|
m.maybeDelete(connTracingID, id)
|
2023-02-22 17:58:17 -04:00
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|
|
|
|
|
2024-06-05 16:10:03 -04:00
|
|
|
func (m *sessionManager) maybeDelete(connTracingID quic.ConnectionTracingID, id sessionID) {
|
|
|
|
sessions, ok := m.conns[connTracingID]
|
2023-02-22 17:58:17 -04:00
|
|
|
if !ok { // should never happen
|
|
|
|
return
|
|
|
|
}
|
|
|
|
delete(sessions, id)
|
|
|
|
if len(sessions) == 0 {
|
2024-06-05 16:10:03 -04:00
|
|
|
delete(m.conns, connTracingID)
|
2023-02-22 17:58:17 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// AddUniStream adds a new unidirectional stream to a WebTransport session.
|
|
|
|
// If the WebTransport session has not yet been established,
|
|
|
|
// it starts a new go routine and waits for establishment of the session.
|
|
|
|
// If that takes longer than timeout, the stream is reset.
|
2024-06-05 16:10:03 -04:00
|
|
|
func (m *sessionManager) AddUniStream(connTracingID quic.ConnectionTracingID, str quic.ReceiveStream) {
|
2023-02-22 17:58:17 -04:00
|
|
|
idv, err := quicvarint.Read(quicvarint.NewReader(str))
|
|
|
|
if err != nil {
|
|
|
|
str.CancelRead(1337)
|
|
|
|
}
|
|
|
|
id := sessionID(idv)
|
|
|
|
|
2024-06-05 16:10:03 -04:00
|
|
|
sess, isExisting := m.getOrCreateSession(connTracingID, id)
|
2023-02-22 17:58:17 -04:00
|
|
|
if isExisting {
|
|
|
|
sess.conn.addIncomingUniStream(str)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
m.refCount.Add(1)
|
|
|
|
go func() {
|
|
|
|
defer m.refCount.Done()
|
|
|
|
m.handleUniStream(str, sess)
|
|
|
|
|
|
|
|
m.mx.Lock()
|
|
|
|
defer m.mx.Unlock()
|
|
|
|
|
|
|
|
sess.counter--
|
|
|
|
// Once no more streams are waiting for this session to be established,
|
|
|
|
// and this session is still outstanding, delete it from the map.
|
|
|
|
if sess.counter == 0 && sess.conn == nil {
|
2024-06-05 16:10:03 -04:00
|
|
|
m.maybeDelete(connTracingID, id)
|
2023-02-22 17:58:17 -04:00
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|
|
|
|
|
2024-06-05 16:10:03 -04:00
|
|
|
func (m *sessionManager) getOrCreateSession(connTracingID quic.ConnectionTracingID, id sessionID) (sess *session, existed bool) {
|
2023-02-22 17:58:17 -04:00
|
|
|
m.mx.Lock()
|
|
|
|
defer m.mx.Unlock()
|
|
|
|
|
2024-06-05 16:10:03 -04:00
|
|
|
sessions, ok := m.conns[connTracingID]
|
2023-02-22 17:58:17 -04:00
|
|
|
if !ok {
|
|
|
|
sessions = make(map[sessionID]*session)
|
2024-06-05 16:10:03 -04:00
|
|
|
m.conns[connTracingID] = sessions
|
2023-02-22 17:58:17 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
sess, ok = sessions[id]
|
|
|
|
if ok && sess.conn != nil {
|
|
|
|
return sess, true
|
|
|
|
}
|
|
|
|
if !ok {
|
|
|
|
sess = &session{created: make(chan struct{})}
|
|
|
|
sessions[id] = sess
|
|
|
|
}
|
|
|
|
sess.counter++
|
|
|
|
return sess, false
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *sessionManager) handleStream(str quic.Stream, sess *session) {
|
|
|
|
t := time.NewTimer(m.timeout)
|
|
|
|
defer t.Stop()
|
|
|
|
|
|
|
|
// When multiple streams are waiting for the same session to be established,
|
|
|
|
// the timeout is calculated for every stream separately.
|
|
|
|
select {
|
|
|
|
case <-sess.created:
|
|
|
|
sess.conn.addIncomingStream(str)
|
|
|
|
case <-t.C:
|
|
|
|
str.CancelRead(WebTransportBufferedStreamRejectedErrorCode)
|
|
|
|
str.CancelWrite(WebTransportBufferedStreamRejectedErrorCode)
|
|
|
|
case <-m.ctx.Done():
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *sessionManager) handleUniStream(str quic.ReceiveStream, sess *session) {
|
|
|
|
t := time.NewTimer(m.timeout)
|
|
|
|
defer t.Stop()
|
|
|
|
|
|
|
|
// When multiple streams are waiting for the same session to be established,
|
|
|
|
// the timeout is calculated for every stream separately.
|
|
|
|
select {
|
|
|
|
case <-sess.created:
|
|
|
|
sess.conn.addIncomingUniStream(str)
|
|
|
|
case <-t.C:
|
|
|
|
str.CancelRead(WebTransportBufferedStreamRejectedErrorCode)
|
|
|
|
case <-m.ctx.Done():
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// AddSession adds a new WebTransport session.
|
2024-06-05 16:10:03 -04:00
|
|
|
func (m *sessionManager) AddSession(qconn http3.Connection, id sessionID, requestStr http3.Stream) *Session {
|
2023-02-22 17:58:17 -04:00
|
|
|
conn := newSession(id, qconn, requestStr)
|
2024-06-05 16:10:03 -04:00
|
|
|
connTracingID := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
|
2023-02-22 17:58:17 -04:00
|
|
|
|
|
|
|
m.mx.Lock()
|
|
|
|
defer m.mx.Unlock()
|
|
|
|
|
2024-06-05 16:10:03 -04:00
|
|
|
sessions, ok := m.conns[connTracingID]
|
2023-02-22 17:58:17 -04:00
|
|
|
if !ok {
|
|
|
|
sessions = make(map[sessionID]*session)
|
2024-06-05 16:10:03 -04:00
|
|
|
m.conns[connTracingID] = sessions
|
2023-02-22 17:58:17 -04:00
|
|
|
}
|
|
|
|
if sess, ok := sessions[id]; ok {
|
|
|
|
// We might already have an entry of this session.
|
|
|
|
// This can happen when we receive a stream for this WebTransport session before we complete the HTTP request
|
|
|
|
// that establishes the session.
|
|
|
|
sess.conn = conn
|
|
|
|
close(sess.created)
|
|
|
|
return conn
|
|
|
|
}
|
|
|
|
c := make(chan struct{})
|
|
|
|
close(c)
|
|
|
|
sessions[id] = &session{created: c, conn: conn}
|
|
|
|
return conn
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *sessionManager) Close() {
|
|
|
|
m.ctxCancel()
|
|
|
|
m.refCount.Wait()
|
|
|
|
}
|