Introduce state pattern for Streams

Streams can now be in one of two states: open or closed.
Separates a Stream from its ngtcp2 implementation specifics.
Breaks the circular dependency between Connection and Stream,
allowing them to be defined in separate modules.
This commit is contained in:
Mark Spanbroek 2020-12-14 09:21:45 +01:00 committed by markspanbroek
parent 4c005eee37
commit a50a987542
11 changed files with 189 additions and 84 deletions

View File

@ -1,4 +1,5 @@
import pkg/chronos
import ./stream
import ./asyncloop
import ./ngtcp2
import ./connectionid

View File

@ -13,9 +13,6 @@ export receive, send
export isHandshakeCompleted
export handshake
export ids
export Stream
export openStream
export close
export read, write
export incomingStream
export destroy

View File

@ -2,6 +2,7 @@ import std/sequtils
import pkg/chronos
import pkg/ngtcp2
import ../datagram
import ../stream
import ../openarray
import ../congestion
import ../timeout
@ -23,10 +24,6 @@ type
timeout*: Timeout
onNewId*: proc(id: ConnectionId)
onRemoveId*: proc(id: ConnectionId)
Stream* = ref object
id*: int64
connection*: Ngtcp2Connection
incoming*: AsyncQueue[seq[byte]]
proc destroy(connection: var Ngtcp2ConnectionObj) =
if connection.conn != nil:
@ -60,14 +57,6 @@ proc ids*(connection: Ngtcp2Connection): seq[ConnectionId] =
discard ngtcp2_conn_get_scid(connection.conn, scids.toPtr)
scids.mapIt(ConnectionId(it.data[0..<it.datalen]))
proc newStream*(connection: Ngtcp2Connection, id: int64): Stream =
let incoming = newAsyncQueue[seq[byte]]()
let stream = Stream(connection: connection, id: id, incoming: incoming)
let conn = connection.conn
let userdata = unsafeAddr stream[]
checkResult ngtcp2_conn_set_stream_user_data(conn, stream.id, userdata)
stream
proc updateTimeout*(connection: Ngtcp2Connection) =
let expiry = ngtcp2_conn_get_expiry(connection.conn)
if expiry != uint64.high:

View File

@ -0,0 +1,22 @@
import pkg/chronos
import ../../stream
import ./state
type
ClosedStream* = ref object of StreamState
ClosedStreamError* = object of StreamError
proc read(stream: Stream, state: ClosedStream): Future[seq[byte]] {.async.} =
raise newException(ClosedStreamError, "stream is closed")
proc write(stream: Stream, state: ClosedStream, bytes: seq[byte]) {.async.} =
raise newException(ClosedStreamError, "stream is closed")
proc close(stream: Stream, state: ClosedStream) {.async.} =
discard
proc destroy(state: ClosedStream) =
discard
proc newClosedState*(): ClosedStream =
newState[ClosedStream]()

View File

@ -0,0 +1,57 @@
import pkg/chronos
import pkg/ngtcp2
import ../../stream
import ../../openarray
import ../connection
import ../errors
import ../pointers
import ./sending
import ./state
import ./closedstate
type
OpenStream* = ref object of StreamState
id: int64
connection: Ngtcp2Connection
incoming: AsyncQueue[seq[byte]]
proc read(stream: Stream, state: OpenStream): Future[seq[byte]] {.async.} =
result = await state.incoming.get()
proc write(stream: Stream, state: OpenStream, bytes: seq[byte]) {.async.} =
let connection = state.connection
let streamId = state.id
var messagePtr = bytes.toUnsafePtr
var messageLen = bytes.len.uint
var done = false
while not done:
let written = await send(connection, streamId, messagePtr, messageLen)
messagePtr = messagePtr + written
messageLen = messageLen - written.uint
done = messageLen == 0
proc close(stream: Stream, state: OpenStream) {.async.} =
checkResult ngtcp2_conn_shutdown_stream(state.connection.conn, state.id, 0)
stream.switch(newClosedState())
proc destroy(state: OpenStream) =
let conn = state.connection.conn
let id = state.id
checkResult ngtcp2_conn_set_stream_user_data(conn, id, nil)
proc setUserData(state: OpenStream) =
let conn = state.connection.conn
let id = state.id
let userdata = unsafeAddr state[]
checkResult ngtcp2_conn_set_stream_user_data(conn, id, userdata)
proc newOpenState*(connection: Ngtcp2Connection, id: int64): OpenStream =
let state = newState[OpenStream]()
state.connection = connection
state.id = id
state.incoming = newAsyncQueue[seq[byte]]()
state.setUserData()
state
proc receive*(state: OpenStream, bytes: seq[byte]) =
state.incoming.putNoWait(bytes)

View File

@ -0,0 +1,44 @@
import pkg/chronos
import pkg/ngtcp2
import ../../datagram
import ../../congestion
import ../connection
import ../path
import ../errors
import ../timestamp
proc trySend(connection: Ngtcp2Connection,
streamId: int64,
messagePtr: ptr byte,
messageLen: uint,
written: var int): Datagram =
var packetInfo: ngtcp2_pkt_info
let length = ngtcp2_conn_write_stream(
connection.conn,
connection.path.toPathPtr,
addr packetInfo,
addr connection.buffer[0],
connection.buffer.len.uint,
addr written,
0,
streamId,
messagePtr,
messageLen,
now()
)
checkResult length.cint
let data = connection.buffer[0..<length]
let ecn = ECN(packetInfo.ecn)
Datagram(data: data, ecn: ecn)
proc send*(connection: Ngtcp2Connection,
streamId: int64,
messagePtr: ptr byte,
messageLen: uint): Future[int] {.async.} =
var datagram = trySend(connection, streamId, messagePtr, messageLen, result)
while datagram.data.len == 0:
connection.flowing.clear()
await connection.flowing.wait()
datagram = trySend(connection, streamId, messagePtr, messageLen, result)
await connection.outgoing.put(datagram)
connection.updateTimeout()

View File

@ -0,0 +1,14 @@
import pkg/chronos
import ../../stream
template newState*[T: StreamState](): T =
var state = T()
state.read = proc(stream: Stream): Future[seq[byte]] =
read(stream, state)
state.write = proc(stream: Stream, bytes: seq[byte]): Future[void] =
write(stream, state, bytes)
state.close = proc(stream: Stream): Future[void] =
close(stream, state)
state.destroy = proc() =
destroy(state)
state

View File

@ -1,76 +1,22 @@
import pkg/chronos
import pkg/ngtcp2
import ../openarray
import ../datagram
import ../congestion
import ../stream
import ./connection
import ./errors
import ./path
import ./pointers
import ./timestamp
import ./stream/openstate
proc newStream*(connection: Ngtcp2Connection, id: int64): Stream =
newStream(id, newOpenState(connection, id))
proc openStream*(connection: Ngtcp2Connection): Stream =
var id: int64
checkResult ngtcp2_conn_open_uni_stream(connection.conn, addr id, nil)
var stream = newStream(connection, id)
checkResult ngtcp2_conn_set_stream_user_data(connection.conn, id, addr result)
stream
proc close*(stream: Stream) =
checkResult ngtcp2_conn_shutdown_stream(stream.connection.conn, stream.id, 0)
proc trySend(stream: Stream,
messagePtr: ptr byte,
messageLen: uint,
written: var int): Datagram =
let connection = stream.connection
var packetInfo: ngtcp2_pkt_info
let length = ngtcp2_conn_write_stream(
connection.conn,
connection.path.toPathPtr,
addr packetInfo,
addr connection.buffer[0],
connection.buffer.len.uint,
addr written,
0,
stream.id,
messagePtr,
messageLen,
now()
)
checkResult length.cint
let data = connection.buffer[0..<length]
let ecn = ECN(packetInfo.ecn)
Datagram(data: data, ecn: ecn)
proc send(stream: Stream,
messagePtr: ptr byte,
messageLen: uint): Future[int] {.async.} =
let connection = stream.connection
var datagram = stream.trySend(messagePtr, messageLen, result)
while datagram.data.len == 0:
connection.flowing.clear()
await connection.flowing.wait()
datagram = stream.trySend(messagePtr, messageLen, result)
await connection.outgoing.put(datagram)
connection.updateTimeout()
proc write*(stream: Stream, message: seq[byte]) {.async.} =
var messagePtr = message.toUnsafePtr
var messageLen = message.len.uint
var done = false
while not done:
let written = await stream.send(messagePtr, messageLen)
messagePtr = messagePtr + written
messageLen = messageLen - written.uint
done = messageLen == 0
newStream(connection, id)
proc incomingStream*(connection: Ngtcp2Connection): Future[Stream] {.async.} =
result = await connection.incoming.get()
proc read*(stream: Stream): Future[seq[byte]] {.async.} =
result = await stream.incoming.get()
proc onStreamOpen(conn: ptr ngtcp2_conn,
stream_id: int64,
user_data: pointer): cint {.cdecl.} =
@ -85,10 +31,10 @@ proc onReceiveStreamData(connection: ptr ngtcp2_conn,
datalen: uint,
user_data: pointer,
stream_user_data: pointer): cint{.cdecl.} =
let stream = cast[Stream](stream_user_data)
let state = cast[OpenStream](stream_user_data)
var bytes = newSeqUninitialized[byte](datalen)
copyMem(bytes.toUnsafePtr, data, datalen)
stream.incoming.putNoWait(bytes)
state.receive(bytes)
checkResult:
connection.ngtcp2_conn_extend_max_stream_offset(stream_id, datalen)
connection.ngtcp2_conn_extend_max_offset(datalen)

31
quic/stream.nim Normal file
View File

@ -0,0 +1,31 @@
import pkg/chronos
type
Stream* = ref object
id: int64
state: StreamState
StreamState* = ref object of RootObj
read*: proc(stream: Stream): Future[seq[byte]] {.gcsafe.}
write*: proc(stream: Stream, bytes: seq[byte]): Future[void] {.gcsafe.}
close*: proc(stream: Stream): Future[void] {.gcsafe.}
destroy*: proc() {.gcsafe.}
StreamError* = object of IOError
proc newStream*(id: int64, state: StreamState): Stream =
Stream(state: state)
proc switch*(stream: Stream, newState: StreamState) =
stream.state.destroy()
stream.state = newState
proc id*(stream: Stream): int64 =
stream.id
proc read*(stream: Stream): Future[seq[byte]] {.async.} =
result = await stream.state.read(stream)
proc write*(stream: Stream, bytes: seq[byte]) {.async.} =
await stream.state.write(stream, bytes)
proc close*(stream: Stream) {.async.} =
await stream.state.close(stream)

View File

@ -37,8 +37,8 @@ suite "api":
let stream1 = await outgoing.openStream()
let stream2 = await incoming.openStream()
stream1.close()
stream2.close()
await stream1.close()
await stream2.close()
await outgoing.drop()
await incoming.drop()
@ -75,11 +75,11 @@ suite "api":
defer: await incoming.drop()
let outgoingStream = await outgoing.openStream()
defer: outgoingStream.close()
defer: await outgoingStream.close()
await outgoingStream.write(message)
let incomingStream = await incoming.incomingStream()
defer: incomingStream.close()
defer: await incomingStream.close()
check (await incomingStream.read()) == message

View File

@ -3,6 +3,7 @@ import std/sequtils
import pkg/chronos
import pkg/quic/ngtcp2
import pkg/quic/datagram
import pkg/quic/stream
import ../helpers/asynctest
import ../helpers/simulation
import ../helpers/addresses
@ -29,7 +30,7 @@ suite "ngtcp2 streams":
let (client, server) = await performHandshake()
let stream = client.openStream()
stream.close()
await stream.close()
client.destroy()
server.destroy()
@ -56,10 +57,13 @@ suite "ngtcp2 streams":
client.destroy()
server.destroy()
asynctest "raises when stream could not be written to":
asynctest "raises when reading from or writing to closed stream":
let (client, server) = await performHandshake()
let stream = client.openStream()
stream.close()
await stream.close()
expect IOError:
discard await stream.read()
expect IOError:
await stream.write(@[1'u8, 2'u8, 3'u8])