fix(yamux): doesn't work in a Relayv2 connection (#979)

Co-authored-by: Ludovic Chenut <ludovic@status.im>
This commit is contained in:
diegomrsantos 2023-11-21 16:03:29 +01:00 committed by GitHub
parent fb05f5ae22
commit 1f4b090227
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 316 additions and 302 deletions

View File

@ -186,6 +186,7 @@ proc remoteClosed(channel: YamuxChannel) {.async.} =
method closeImpl*(channel: YamuxChannel) {.async, gcsafe.} = method closeImpl*(channel: YamuxChannel) {.async, gcsafe.} =
if not channel.closedLocally: if not channel.closedLocally:
channel.closedLocally = true channel.closedLocally = true
channel.isEof = true
if channel.isReset == false and channel.sendQueue.len == 0: if channel.isReset == false and channel.sendQueue.len == 0:
await channel.conn.write(YamuxHeader.data(channel.id, 0, {Fin})) await channel.conn.write(YamuxHeader.data(channel.id, 0, {Fin}))
@ -249,6 +250,7 @@ method readOnce*(
await channel.closedRemotely or channel.receivedData.wait() await channel.closedRemotely or channel.receivedData.wait()
if channel.closedRemotely.done() and channel.recvQueue.len == 0: if channel.closedRemotely.done() and channel.recvQueue.len == 0:
channel.returnedEof = true channel.returnedEof = true
channel.isEof = true
return 0 return 0
let toRead = min(channel.recvQueue.len, nbytes) let toRead = min(channel.recvQueue.len, nbytes)
@ -454,6 +456,7 @@ method handle*(m: Yamux) {.async, gcsafe.} =
if header.streamId in m.flushed: if header.streamId in m.flushed:
m.flushed.del(header.streamId) m.flushed.del(header.streamId)
if header.streamId mod 2 == m.currentId mod 2: if header.streamId mod 2 == m.currentId mod 2:
debug "Peer used our reserved stream id, skipping", id=header.streamId, currentId=m.currentId, peerId=m.connection.peerId
raise newException(YamuxError, "Peer used our reserved stream id") raise newException(YamuxError, "Peer used our reserved stream id")
let newStream = m.createStream(header.streamId, false) let newStream = m.createStream(header.streamId, false)
if m.channels.len >= m.maxChannCount: if m.channels.len >= m.maxChannCount:

View File

@ -47,6 +47,7 @@ proc new*(
limitDuration: uint32, limitDuration: uint32,
limitData: uint64): T = limitData: uint64): T =
let rc = T(conn: conn, limitDuration: limitDuration, limitData: limitData) let rc = T(conn: conn, limitDuration: limitDuration, limitData: limitData)
rc.dir = conn.dir
rc.initStream() rc.initStream()
if limitDuration > 0: if limitDuration > 0:
proc checkDurationConnection() {.async.} = proc checkDurationConnection() {.async.} =

View File

@ -19,14 +19,22 @@ import ./helpers
import std/times import std/times
import stew/byteutils import stew/byteutils
proc createSwitch(r: Relay): Switch = proc createSwitch(r: Relay = nil, useYamux: bool = false): Switch =
result = SwitchBuilder.new() var builder = SwitchBuilder.new()
.withRng(newRng()) .withRng(newRng())
.withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ]) .withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ])
.withTcpTransport() .withTcpTransport()
.withMplex()
if useYamux:
builder = builder.withYamux()
else:
builder = builder.withMplex()
if r != nil:
builder = builder.withCircuitRelay(r)
return builder
.withNoise() .withNoise()
.withCircuitRelay(r)
.build() .build()
suite "Circuit Relay V2": suite "Circuit Relay V2":
@ -122,7 +130,8 @@ suite "Circuit Relay V2":
expect(ReservationError): expect(ReservationError):
discard await cl1.reserve(src2.peerInfo.peerId, addrs) discard await cl1.reserve(src2.peerInfo.peerId, addrs)
suite "Connection": for (useYamux, muxName) in [(false, "Mplex"), (true, "Yamux")]:
suite "Circuit Relay V2 Connection using " & muxName:
asyncTeardown: asyncTeardown:
checkTrackers() checkTrackers()
var var
@ -149,9 +158,9 @@ suite "Circuit Relay V2":
ldata = 16384 ldata = 16384
srcCl = RelayClient.new() srcCl = RelayClient.new()
dstCl = RelayClient.new() dstCl = RelayClient.new()
src = createSwitch(srcCl) src = createSwitch(srcCl, useYamux)
dst = createSwitch(dstCl) dst = createSwitch(dstCl, useYamux)
rel = newStandardSwitch() rel = createSwitch(nil, useYamux)
asyncTest "Connection succeed": asyncTest "Connection succeed":
proto.handler = proc(conn: Connection, proto: string) {.async.} = proto.handler = proc(conn: Connection, proto: string) {.async.} =
@ -322,7 +331,7 @@ take to the ship.""")
raise newException(CatchableError, "Should not be here") raise newException(CatchableError, "Should not be here")
let let
rel2Cl = RelayClient.new(canHop = true) rel2Cl = RelayClient.new(canHop = true)
rel2 = createSwitch(rel2Cl) rel2 = createSwitch(rel2Cl, useYamux)
rv2 = Relay.new() rv2 = Relay.new()
rv2.setup(rel) rv2.setup(rel)
rel.mount(rv2) rel.mount(rv2)
@ -347,6 +356,7 @@ take to the ship.""")
expect(DialFailedError): expect(DialFailedError):
conn = await src.dial(dst.peerInfo.peerId, addrs, customProtoCodec) conn = await src.dial(dst.peerInfo.peerId, addrs, customProtoCodec)
if not conn.isNil():
await allFutures(conn.close()) await allFutures(conn.close())
await allFutures(src.stop(), dst.stop(), rel.stop(), rel2.stop()) await allFutures(src.stop(), dst.stop(), rel.stop(), rel2.stop())
@ -381,9 +391,9 @@ take to the ship.""")
clientA = RelayClient.new(canHop = true) clientA = RelayClient.new(canHop = true)
clientB = RelayClient.new(canHop = true) clientB = RelayClient.new(canHop = true)
clientC = RelayClient.new(canHop = true) clientC = RelayClient.new(canHop = true)
switchA = createSwitch(clientA) switchA = createSwitch(clientA, useYamux)
switchB = createSwitch(clientB) switchB = createSwitch(clientB, useYamux)
switchC = createSwitch(clientC) switchC = createSwitch(clientC, useYamux)
switchA.mount(protoBCA) switchA.mount(protoBCA)
switchB.mount(protoCAB) switchB.mount(protoCAB)