mirror of https://github.com/vacp2p/nim-quic.git
Introduce state pattern for Streams
Streams can now be in one of two states: open or closed. Separates a Stream from its ngtcp2 implementation specifics. Breaks the circular dependency between Connection and Stream, allowing them to be defined in separate modules.
This commit is contained in:
parent
4c005eee37
commit
a50a987542
|
@ -1,4 +1,5 @@
|
|||
import pkg/chronos
|
||||
import ./stream
|
||||
import ./asyncloop
|
||||
import ./ngtcp2
|
||||
import ./connectionid
|
||||
|
|
|
@ -13,9 +13,6 @@ export receive, send
|
|||
export isHandshakeCompleted
|
||||
export handshake
|
||||
export ids
|
||||
export Stream
|
||||
export openStream
|
||||
export close
|
||||
export read, write
|
||||
export incomingStream
|
||||
export destroy
|
||||
|
|
|
@ -2,6 +2,7 @@ import std/sequtils
|
|||
import pkg/chronos
|
||||
import pkg/ngtcp2
|
||||
import ../datagram
|
||||
import ../stream
|
||||
import ../openarray
|
||||
import ../congestion
|
||||
import ../timeout
|
||||
|
@ -23,10 +24,6 @@ type
|
|||
timeout*: Timeout
|
||||
onNewId*: proc(id: ConnectionId)
|
||||
onRemoveId*: proc(id: ConnectionId)
|
||||
Stream* = ref object
|
||||
id*: int64
|
||||
connection*: Ngtcp2Connection
|
||||
incoming*: AsyncQueue[seq[byte]]
|
||||
|
||||
proc destroy(connection: var Ngtcp2ConnectionObj) =
|
||||
if connection.conn != nil:
|
||||
|
@ -60,14 +57,6 @@ proc ids*(connection: Ngtcp2Connection): seq[ConnectionId] =
|
|||
discard ngtcp2_conn_get_scid(connection.conn, scids.toPtr)
|
||||
scids.mapIt(ConnectionId(it.data[0..<it.datalen]))
|
||||
|
||||
proc newStream*(connection: Ngtcp2Connection, id: int64): Stream =
|
||||
let incoming = newAsyncQueue[seq[byte]]()
|
||||
let stream = Stream(connection: connection, id: id, incoming: incoming)
|
||||
let conn = connection.conn
|
||||
let userdata = unsafeAddr stream[]
|
||||
checkResult ngtcp2_conn_set_stream_user_data(conn, stream.id, userdata)
|
||||
stream
|
||||
|
||||
proc updateTimeout*(connection: Ngtcp2Connection) =
|
||||
let expiry = ngtcp2_conn_get_expiry(connection.conn)
|
||||
if expiry != uint64.high:
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
import pkg/chronos
|
||||
import ../../stream
|
||||
import ./state
|
||||
|
||||
type
|
||||
ClosedStream* = ref object of StreamState
|
||||
ClosedStreamError* = object of StreamError
|
||||
|
||||
proc read(stream: Stream, state: ClosedStream): Future[seq[byte]] {.async.} =
|
||||
raise newException(ClosedStreamError, "stream is closed")
|
||||
|
||||
proc write(stream: Stream, state: ClosedStream, bytes: seq[byte]) {.async.} =
|
||||
raise newException(ClosedStreamError, "stream is closed")
|
||||
|
||||
proc close(stream: Stream, state: ClosedStream) {.async.} =
|
||||
discard
|
||||
|
||||
proc destroy(state: ClosedStream) =
|
||||
discard
|
||||
|
||||
proc newClosedState*(): ClosedStream =
|
||||
newState[ClosedStream]()
|
|
@ -0,0 +1,57 @@
|
|||
import pkg/chronos
|
||||
import pkg/ngtcp2
|
||||
import ../../stream
|
||||
import ../../openarray
|
||||
import ../connection
|
||||
import ../errors
|
||||
import ../pointers
|
||||
import ./sending
|
||||
import ./state
|
||||
import ./closedstate
|
||||
|
||||
type
|
||||
OpenStream* = ref object of StreamState
|
||||
id: int64
|
||||
connection: Ngtcp2Connection
|
||||
incoming: AsyncQueue[seq[byte]]
|
||||
|
||||
proc read(stream: Stream, state: OpenStream): Future[seq[byte]] {.async.} =
|
||||
result = await state.incoming.get()
|
||||
|
||||
proc write(stream: Stream, state: OpenStream, bytes: seq[byte]) {.async.} =
|
||||
let connection = state.connection
|
||||
let streamId = state.id
|
||||
var messagePtr = bytes.toUnsafePtr
|
||||
var messageLen = bytes.len.uint
|
||||
var done = false
|
||||
while not done:
|
||||
let written = await send(connection, streamId, messagePtr, messageLen)
|
||||
messagePtr = messagePtr + written
|
||||
messageLen = messageLen - written.uint
|
||||
done = messageLen == 0
|
||||
|
||||
proc close(stream: Stream, state: OpenStream) {.async.} =
|
||||
checkResult ngtcp2_conn_shutdown_stream(state.connection.conn, state.id, 0)
|
||||
stream.switch(newClosedState())
|
||||
|
||||
proc destroy(state: OpenStream) =
|
||||
let conn = state.connection.conn
|
||||
let id = state.id
|
||||
checkResult ngtcp2_conn_set_stream_user_data(conn, id, nil)
|
||||
|
||||
proc setUserData(state: OpenStream) =
|
||||
let conn = state.connection.conn
|
||||
let id = state.id
|
||||
let userdata = unsafeAddr state[]
|
||||
checkResult ngtcp2_conn_set_stream_user_data(conn, id, userdata)
|
||||
|
||||
proc newOpenState*(connection: Ngtcp2Connection, id: int64): OpenStream =
|
||||
let state = newState[OpenStream]()
|
||||
state.connection = connection
|
||||
state.id = id
|
||||
state.incoming = newAsyncQueue[seq[byte]]()
|
||||
state.setUserData()
|
||||
state
|
||||
|
||||
proc receive*(state: OpenStream, bytes: seq[byte]) =
|
||||
state.incoming.putNoWait(bytes)
|
|
@ -0,0 +1,44 @@
|
|||
import pkg/chronos
|
||||
import pkg/ngtcp2
|
||||
import ../../datagram
|
||||
import ../../congestion
|
||||
import ../connection
|
||||
import ../path
|
||||
import ../errors
|
||||
import ../timestamp
|
||||
|
||||
proc trySend(connection: Ngtcp2Connection,
|
||||
streamId: int64,
|
||||
messagePtr: ptr byte,
|
||||
messageLen: uint,
|
||||
written: var int): Datagram =
|
||||
var packetInfo: ngtcp2_pkt_info
|
||||
let length = ngtcp2_conn_write_stream(
|
||||
connection.conn,
|
||||
connection.path.toPathPtr,
|
||||
addr packetInfo,
|
||||
addr connection.buffer[0],
|
||||
connection.buffer.len.uint,
|
||||
addr written,
|
||||
0,
|
||||
streamId,
|
||||
messagePtr,
|
||||
messageLen,
|
||||
now()
|
||||
)
|
||||
checkResult length.cint
|
||||
let data = connection.buffer[0..<length]
|
||||
let ecn = ECN(packetInfo.ecn)
|
||||
Datagram(data: data, ecn: ecn)
|
||||
|
||||
proc send*(connection: Ngtcp2Connection,
|
||||
streamId: int64,
|
||||
messagePtr: ptr byte,
|
||||
messageLen: uint): Future[int] {.async.} =
|
||||
var datagram = trySend(connection, streamId, messagePtr, messageLen, result)
|
||||
while datagram.data.len == 0:
|
||||
connection.flowing.clear()
|
||||
await connection.flowing.wait()
|
||||
datagram = trySend(connection, streamId, messagePtr, messageLen, result)
|
||||
await connection.outgoing.put(datagram)
|
||||
connection.updateTimeout()
|
|
@ -0,0 +1,14 @@
|
|||
import pkg/chronos
|
||||
import ../../stream
|
||||
|
||||
template newState*[T: StreamState](): T =
|
||||
var state = T()
|
||||
state.read = proc(stream: Stream): Future[seq[byte]] =
|
||||
read(stream, state)
|
||||
state.write = proc(stream: Stream, bytes: seq[byte]): Future[void] =
|
||||
write(stream, state, bytes)
|
||||
state.close = proc(stream: Stream): Future[void] =
|
||||
close(stream, state)
|
||||
state.destroy = proc() =
|
||||
destroy(state)
|
||||
state
|
|
@ -1,76 +1,22 @@
|
|||
import pkg/chronos
|
||||
import pkg/ngtcp2
|
||||
import ../openarray
|
||||
import ../datagram
|
||||
import ../congestion
|
||||
import ../stream
|
||||
import ./connection
|
||||
import ./errors
|
||||
import ./path
|
||||
import ./pointers
|
||||
import ./timestamp
|
||||
import ./stream/openstate
|
||||
|
||||
proc newStream*(connection: Ngtcp2Connection, id: int64): Stream =
|
||||
newStream(id, newOpenState(connection, id))
|
||||
|
||||
proc openStream*(connection: Ngtcp2Connection): Stream =
|
||||
var id: int64
|
||||
checkResult ngtcp2_conn_open_uni_stream(connection.conn, addr id, nil)
|
||||
var stream = newStream(connection, id)
|
||||
checkResult ngtcp2_conn_set_stream_user_data(connection.conn, id, addr result)
|
||||
stream
|
||||
|
||||
proc close*(stream: Stream) =
|
||||
checkResult ngtcp2_conn_shutdown_stream(stream.connection.conn, stream.id, 0)
|
||||
|
||||
proc trySend(stream: Stream,
|
||||
messagePtr: ptr byte,
|
||||
messageLen: uint,
|
||||
written: var int): Datagram =
|
||||
let connection = stream.connection
|
||||
var packetInfo: ngtcp2_pkt_info
|
||||
let length = ngtcp2_conn_write_stream(
|
||||
connection.conn,
|
||||
connection.path.toPathPtr,
|
||||
addr packetInfo,
|
||||
addr connection.buffer[0],
|
||||
connection.buffer.len.uint,
|
||||
addr written,
|
||||
0,
|
||||
stream.id,
|
||||
messagePtr,
|
||||
messageLen,
|
||||
now()
|
||||
)
|
||||
checkResult length.cint
|
||||
let data = connection.buffer[0..<length]
|
||||
let ecn = ECN(packetInfo.ecn)
|
||||
Datagram(data: data, ecn: ecn)
|
||||
|
||||
proc send(stream: Stream,
|
||||
messagePtr: ptr byte,
|
||||
messageLen: uint): Future[int] {.async.} =
|
||||
let connection = stream.connection
|
||||
var datagram = stream.trySend(messagePtr, messageLen, result)
|
||||
while datagram.data.len == 0:
|
||||
connection.flowing.clear()
|
||||
await connection.flowing.wait()
|
||||
datagram = stream.trySend(messagePtr, messageLen, result)
|
||||
await connection.outgoing.put(datagram)
|
||||
connection.updateTimeout()
|
||||
|
||||
proc write*(stream: Stream, message: seq[byte]) {.async.} =
|
||||
var messagePtr = message.toUnsafePtr
|
||||
var messageLen = message.len.uint
|
||||
var done = false
|
||||
while not done:
|
||||
let written = await stream.send(messagePtr, messageLen)
|
||||
messagePtr = messagePtr + written
|
||||
messageLen = messageLen - written.uint
|
||||
done = messageLen == 0
|
||||
newStream(connection, id)
|
||||
|
||||
proc incomingStream*(connection: Ngtcp2Connection): Future[Stream] {.async.} =
|
||||
result = await connection.incoming.get()
|
||||
|
||||
proc read*(stream: Stream): Future[seq[byte]] {.async.} =
|
||||
result = await stream.incoming.get()
|
||||
|
||||
proc onStreamOpen(conn: ptr ngtcp2_conn,
|
||||
stream_id: int64,
|
||||
user_data: pointer): cint {.cdecl.} =
|
||||
|
@ -85,10 +31,10 @@ proc onReceiveStreamData(connection: ptr ngtcp2_conn,
|
|||
datalen: uint,
|
||||
user_data: pointer,
|
||||
stream_user_data: pointer): cint{.cdecl.} =
|
||||
let stream = cast[Stream](stream_user_data)
|
||||
let state = cast[OpenStream](stream_user_data)
|
||||
var bytes = newSeqUninitialized[byte](datalen)
|
||||
copyMem(bytes.toUnsafePtr, data, datalen)
|
||||
stream.incoming.putNoWait(bytes)
|
||||
state.receive(bytes)
|
||||
checkResult:
|
||||
connection.ngtcp2_conn_extend_max_stream_offset(stream_id, datalen)
|
||||
connection.ngtcp2_conn_extend_max_offset(datalen)
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import pkg/chronos
|
||||
|
||||
type
|
||||
Stream* = ref object
|
||||
id: int64
|
||||
state: StreamState
|
||||
StreamState* = ref object of RootObj
|
||||
read*: proc(stream: Stream): Future[seq[byte]] {.gcsafe.}
|
||||
write*: proc(stream: Stream, bytes: seq[byte]): Future[void] {.gcsafe.}
|
||||
close*: proc(stream: Stream): Future[void] {.gcsafe.}
|
||||
destroy*: proc() {.gcsafe.}
|
||||
StreamError* = object of IOError
|
||||
|
||||
proc newStream*(id: int64, state: StreamState): Stream =
|
||||
Stream(state: state)
|
||||
|
||||
proc switch*(stream: Stream, newState: StreamState) =
|
||||
stream.state.destroy()
|
||||
stream.state = newState
|
||||
|
||||
proc id*(stream: Stream): int64 =
|
||||
stream.id
|
||||
|
||||
proc read*(stream: Stream): Future[seq[byte]] {.async.} =
|
||||
result = await stream.state.read(stream)
|
||||
|
||||
proc write*(stream: Stream, bytes: seq[byte]) {.async.} =
|
||||
await stream.state.write(stream, bytes)
|
||||
|
||||
proc close*(stream: Stream) {.async.} =
|
||||
await stream.state.close(stream)
|
|
@ -37,8 +37,8 @@ suite "api":
|
|||
let stream1 = await outgoing.openStream()
|
||||
let stream2 = await incoming.openStream()
|
||||
|
||||
stream1.close()
|
||||
stream2.close()
|
||||
await stream1.close()
|
||||
await stream2.close()
|
||||
|
||||
await outgoing.drop()
|
||||
await incoming.drop()
|
||||
|
@ -75,11 +75,11 @@ suite "api":
|
|||
defer: await incoming.drop()
|
||||
|
||||
let outgoingStream = await outgoing.openStream()
|
||||
defer: outgoingStream.close()
|
||||
defer: await outgoingStream.close()
|
||||
|
||||
await outgoingStream.write(message)
|
||||
|
||||
let incomingStream = await incoming.incomingStream()
|
||||
defer: incomingStream.close()
|
||||
defer: await incomingStream.close()
|
||||
|
||||
check (await incomingStream.read()) == message
|
||||
|
|
|
@ -3,6 +3,7 @@ import std/sequtils
|
|||
import pkg/chronos
|
||||
import pkg/quic/ngtcp2
|
||||
import pkg/quic/datagram
|
||||
import pkg/quic/stream
|
||||
import ../helpers/asynctest
|
||||
import ../helpers/simulation
|
||||
import ../helpers/addresses
|
||||
|
@ -29,7 +30,7 @@ suite "ngtcp2 streams":
|
|||
let (client, server) = await performHandshake()
|
||||
let stream = client.openStream()
|
||||
|
||||
stream.close()
|
||||
await stream.close()
|
||||
|
||||
client.destroy()
|
||||
server.destroy()
|
||||
|
@ -56,10 +57,13 @@ suite "ngtcp2 streams":
|
|||
client.destroy()
|
||||
server.destroy()
|
||||
|
||||
asynctest "raises when stream could not be written to":
|
||||
asynctest "raises when reading from or writing to closed stream":
|
||||
let (client, server) = await performHandshake()
|
||||
let stream = client.openStream()
|
||||
stream.close()
|
||||
await stream.close()
|
||||
|
||||
expect IOError:
|
||||
discard await stream.read()
|
||||
|
||||
expect IOError:
|
||||
await stream.write(@[1'u8, 2'u8, 3'u8])
|
||||
|
|
Loading…
Reference in New Issue