Merge pull request #1424 from libp2p/merge-quic
move go-libp2p-quic-transport here
This commit is contained in:
commit
5151d4b4fa
|
@ -6,6 +6,7 @@ import (
|
|||
"crypto/rand"
|
||||
|
||||
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
|
||||
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
|
||||
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
|
||||
|
||||
|
@ -13,7 +14,6 @@ import (
|
|||
|
||||
noise "github.com/libp2p/go-libp2p-noise"
|
||||
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
|
||||
quic "github.com/libp2p/go-libp2p-quic-transport"
|
||||
rcmgr "github.com/libp2p/go-libp2p-resource-manager"
|
||||
tls "github.com/libp2p/go-libp2p-tls"
|
||||
yamux "github.com/libp2p/go-libp2p-yamux"
|
||||
|
|
12
go.mod
12
go.mod
|
@ -6,12 +6,14 @@ require (
|
|||
github.com/benbjohnson/clock v1.3.0
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/hashicorp/golang-lru v0.5.4
|
||||
github.com/ipfs/go-cid v0.1.0
|
||||
github.com/ipfs/go-datastore v0.5.1
|
||||
github.com/ipfs/go-ipfs-util v0.0.2
|
||||
github.com/ipfs/go-log/v2 v2.5.1
|
||||
github.com/klauspost/compress v1.15.1
|
||||
github.com/libp2p/go-buffer-pool v0.0.2
|
||||
github.com/libp2p/go-conn-security-multistream v0.3.0
|
||||
github.com/libp2p/go-eventbus v0.2.1
|
||||
|
@ -22,7 +24,6 @@ require (
|
|||
github.com/libp2p/go-libp2p-nat v0.1.0
|
||||
github.com/libp2p/go-libp2p-noise v0.4.0
|
||||
github.com/libp2p/go-libp2p-peerstore v0.6.0
|
||||
github.com/libp2p/go-libp2p-quic-transport v0.17.0
|
||||
github.com/libp2p/go-libp2p-resource-manager v0.2.1
|
||||
github.com/libp2p/go-libp2p-testing v0.9.2
|
||||
github.com/libp2p/go-libp2p-tls v0.4.1
|
||||
|
@ -34,8 +35,10 @@ require (
|
|||
github.com/libp2p/go-reuseport-transport v0.1.0
|
||||
github.com/libp2p/go-stream-muxer-multistream v0.4.0
|
||||
github.com/libp2p/zeroconf/v2 v2.1.1
|
||||
github.com/lucas-clemente/quic-go v0.27.0
|
||||
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd
|
||||
github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b
|
||||
github.com/minio/sha256-simd v1.0.0
|
||||
github.com/multiformats/go-multiaddr v0.5.0
|
||||
github.com/multiformats/go-multiaddr-dns v0.3.1
|
||||
github.com/multiformats/go-multiaddr-fmt v0.1.0
|
||||
|
@ -47,6 +50,7 @@ require (
|
|||
github.com/stretchr/testify v1.7.0
|
||||
github.com/whyrusleeping/mdns v0.0.0-20190826153040-b9b60ed33aa9
|
||||
github.com/whyrusleeping/multiaddr-filter v0.0.0-20160516205228-e903e4adabd7
|
||||
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
|
||||
)
|
||||
|
||||
|
@ -68,26 +72,24 @@ require (
|
|||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
|
||||
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
||||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/google/gopacket v1.1.19 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/huin/goupnp v1.0.3 // indirect
|
||||
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
|
||||
github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect
|
||||
github.com/jbenet/goprocess v0.1.4 // indirect
|
||||
github.com/klauspost/compress v1.15.1 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.0.12 // indirect
|
||||
github.com/koron/go-ssdp v0.0.2 // indirect
|
||||
github.com/libp2p/go-cidranger v1.1.0 // indirect
|
||||
github.com/libp2p/go-flow-metrics v0.0.3 // indirect
|
||||
github.com/libp2p/go-libp2p-blankhost v0.3.0 // indirect
|
||||
github.com/libp2p/go-libp2p-pnet v0.2.0 // indirect
|
||||
github.com/libp2p/go-libp2p-quic-transport v0.17.0 // indirect
|
||||
github.com/libp2p/go-libp2p-swarm v0.10.2 // indirect
|
||||
github.com/libp2p/go-mplex v0.4.0 // indirect
|
||||
github.com/libp2p/go-nat v0.1.0 // indirect
|
||||
github.com/libp2p/go-openssl v0.0.7 // indirect
|
||||
github.com/libp2p/go-tcp-transport v0.5.1 // indirect
|
||||
github.com/libp2p/go-yamux/v3 v3.1.1 // indirect
|
||||
github.com/lucas-clemente/quic-go v0.27.0 // indirect
|
||||
github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect
|
||||
github.com/marten-seemann/qtls-go1-17 v0.1.1 // indirect
|
||||
github.com/marten-seemann/qtls-go1-18 v0.1.1 // indirect
|
||||
|
@ -96,7 +98,6 @@ require (
|
|||
github.com/miekg/dns v1.1.48 // indirect
|
||||
github.com/mikioh/tcpopt v0.0.0-20190314235656-172688c1accc // indirect
|
||||
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect
|
||||
github.com/minio/sha256-simd v1.0.0 // indirect
|
||||
github.com/mr-tron/base58 v1.2.0 // indirect
|
||||
github.com/multiformats/go-base32 v0.0.4 // indirect
|
||||
github.com/multiformats/go-base36 v0.1.0 // indirect
|
||||
|
@ -117,7 +118,6 @@ require (
|
|||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.8.0 // indirect
|
||||
go.uber.org/zap v1.21.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
|
||||
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
|
||||
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 // indirect
|
||||
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/peerstore"
|
||||
|
@ -16,7 +17,6 @@ import (
|
|||
|
||||
csms "github.com/libp2p/go-conn-security-multistream"
|
||||
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
|
||||
quic "github.com/libp2p/go-libp2p-quic-transport"
|
||||
tnet "github.com/libp2p/go-libp2p-testing/net"
|
||||
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
|
||||
yamux "github.com/libp2p/go-libp2p-yamux"
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/p2p/net/swarm"
|
||||
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/connmgr"
|
||||
|
@ -19,7 +20,6 @@ import (
|
|||
|
||||
csms "github.com/libp2p/go-conn-security-multistream"
|
||||
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
|
||||
quic "github.com/libp2p/go-libp2p-quic-transport"
|
||||
tnet "github.com/libp2p/go-libp2p-testing/net"
|
||||
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
|
||||
yamux "github.com/libp2p/go-libp2p-yamux"
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
ic "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 3 {
|
||||
fmt.Printf("Usage: %s <multiaddr> <peer id>", os.Args[0])
|
||||
return
|
||||
}
|
||||
if err := run(os.Args[1], os.Args[2]); err != nil {
|
||||
log.Fatalf(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func run(raddr string, p string) error {
|
||||
peerID, err := peer.Decode(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr, err := ma.NewMultiaddr(raddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
priv, _, err := ic.GenerateECDSAKeyPair(rand.Reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t, err := libp2pquic.NewTransport(priv, nil, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Dialing %s\n", addr.String())
|
||||
conn, err := t.Dial(context.Background(), addr, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
str, err := conn.OpenStream(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer str.Close()
|
||||
const msg = "Hello world!"
|
||||
log.Printf("Sending: %s\n", msg)
|
||||
if _, err := str.Write([]byte(msg)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := str.CloseWrite(); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := ioutil.ReadAll(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Received: %s\n", data)
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
ic "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
tpt "github.com/libp2p/go-libp2p-core/transport"
|
||||
libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 2 {
|
||||
fmt.Printf("Usage: %s <port>", os.Args[0])
|
||||
return
|
||||
}
|
||||
if err := run(os.Args[1]); err != nil {
|
||||
log.Fatalf(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func run(port string) error {
|
||||
addr, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/udp/%s/quic", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
priv, _, err := ic.GenerateECDSAKeyPair(rand.Reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peerID, err := peer.IDFromPrivateKey(priv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t, err := libp2pquic.NewTransport(priv, nil, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln, err := t.Listen(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Listening. Now run: go run cmd/client/main.go %s %s\n", ln.Multiaddr(), peerID)
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Accepted new connection from %s (%s)\n", conn.RemotePeer(), conn.RemoteMultiaddr())
|
||||
go func() {
|
||||
if err := handleConn(conn); err != nil {
|
||||
log.Printf("handling conn failed: %s", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func handleConn(conn tpt.CapableConn) error {
|
||||
str, err := conn.AcceptStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := ioutil.ReadAll(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Received: %s\n", data)
|
||||
if _, err := str.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return str.Close()
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
ic "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
tpt "github.com/libp2p/go-libp2p-core/transport"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
quicConn quic.Connection
|
||||
pconn *reuseConn
|
||||
transport *transport
|
||||
scope network.ConnManagementScope
|
||||
|
||||
localPeer peer.ID
|
||||
privKey ic.PrivKey
|
||||
localMultiaddr ma.Multiaddr
|
||||
|
||||
remotePeerID peer.ID
|
||||
remotePubKey ic.PubKey
|
||||
remoteMultiaddr ma.Multiaddr
|
||||
}
|
||||
|
||||
var _ tpt.CapableConn = &conn{}
|
||||
|
||||
// Close closes the connection.
|
||||
// It must be called even if the peer closed the connection in order for
|
||||
// garbage collection to properly work in this package.
|
||||
func (c *conn) Close() error {
|
||||
c.transport.removeConn(c.quicConn)
|
||||
err := c.quicConn.CloseWithError(0, "")
|
||||
c.pconn.DecreaseCount()
|
||||
c.scope.Done()
|
||||
return err
|
||||
}
|
||||
|
||||
// IsClosed returns whether a connection is fully closed.
|
||||
func (c *conn) IsClosed() bool {
|
||||
return c.quicConn.Context().Err() != nil
|
||||
}
|
||||
|
||||
func (c *conn) allowWindowIncrease(size uint64) bool {
|
||||
return c.scope.ReserveMemory(int(size), network.ReservationPriorityMedium) == nil
|
||||
}
|
||||
|
||||
// OpenStream creates a new stream.
|
||||
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
|
||||
qstr, err := c.quicConn.OpenStreamSync(ctx)
|
||||
return &stream{Stream: qstr}, err
|
||||
}
|
||||
|
||||
// AcceptStream accepts a stream opened by the other side.
|
||||
func (c *conn) AcceptStream() (network.MuxedStream, error) {
|
||||
qstr, err := c.quicConn.AcceptStream(context.Background())
|
||||
return &stream{Stream: qstr}, err
|
||||
}
|
||||
|
||||
// LocalPeer returns our peer ID
|
||||
func (c *conn) LocalPeer() peer.ID {
|
||||
return c.localPeer
|
||||
}
|
||||
|
||||
// LocalPrivateKey returns our private key
|
||||
func (c *conn) LocalPrivateKey() ic.PrivKey {
|
||||
return c.privKey
|
||||
}
|
||||
|
||||
// RemotePeer returns the peer ID of the remote peer.
|
||||
func (c *conn) RemotePeer() peer.ID {
|
||||
return c.remotePeerID
|
||||
}
|
||||
|
||||
// RemotePublicKey returns the public key of the remote peer.
|
||||
func (c *conn) RemotePublicKey() ic.PubKey {
|
||||
return c.remotePubKey
|
||||
}
|
||||
|
||||
// LocalMultiaddr returns the local Multiaddr associated
|
||||
func (c *conn) LocalMultiaddr() ma.Multiaddr {
|
||||
return c.localMultiaddr
|
||||
}
|
||||
|
||||
// RemoteMultiaddr returns the remote Multiaddr associated
|
||||
func (c *conn) RemoteMultiaddr() ma.Multiaddr {
|
||||
return c.remoteMultiaddr
|
||||
}
|
||||
|
||||
func (c *conn) Transport() tpt.Transport {
|
||||
return c.transport
|
||||
}
|
||||
|
||||
func (c *conn) Scope() network.ConnScope {
|
||||
return c.scope
|
||||
}
|
|
@ -0,0 +1,567 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ic "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
tpt "github.com/libp2p/go-libp2p-core/transport"
|
||||
|
||||
mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network"
|
||||
|
||||
quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//go:generate sh -c "mockgen -package libp2pquic -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p-core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go"
|
||||
|
||||
func createPeer(t *testing.T) (peer.ID, ic.PrivKey) {
|
||||
var priv ic.PrivKey
|
||||
var err error
|
||||
switch mrand.Int() % 4 {
|
||||
case 0:
|
||||
priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader)
|
||||
case 1:
|
||||
priv, _, err = ic.GenerateRSAKeyPair(2048, rand.Reader)
|
||||
case 2:
|
||||
priv, _, err = ic.GenerateEd25519Key(rand.Reader)
|
||||
case 3:
|
||||
priv, _, err = ic.GenerateSecp256k1Key(rand.Reader)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
id, err := peer.IDFromPrivateKey(priv)
|
||||
require.NoError(t, err)
|
||||
t.Logf("using a %s key: %s", priv.Type(), id.Pretty())
|
||||
return id, priv
|
||||
}
|
||||
|
||||
func runServer(t *testing.T, tr tpt.Transport, addr string) tpt.Listener {
|
||||
t.Helper()
|
||||
ln, err := tr.Listen(ma.StringCast(addr))
|
||||
require.NoError(t, err)
|
||||
return ln
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
serverID, serverKey := createPeer(t)
|
||||
clientID, clientKey := createPeer(t)
|
||||
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
|
||||
handshake := func(t *testing.T, ln tpt.Listener) {
|
||||
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
serverConn, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
defer serverConn.Close()
|
||||
|
||||
require.Equal(t, conn.LocalPeer(), clientID)
|
||||
require.True(t, conn.LocalPrivateKey().Equals(clientKey), "local private key doesn't match")
|
||||
require.Equal(t, conn.RemotePeer(), serverID)
|
||||
require.True(t, conn.RemotePublicKey().Equals(serverKey.GetPublic()), "remote public key doesn't match")
|
||||
|
||||
require.Equal(t, serverConn.LocalPeer(), serverID)
|
||||
require.True(t, serverConn.LocalPrivateKey().Equals(serverKey), "local private key doesn't match")
|
||||
require.Equal(t, serverConn.RemotePeer(), clientID)
|
||||
require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "remote public key doesn't match")
|
||||
}
|
||||
|
||||
t.Run("on IPv4", func(t *testing.T) {
|
||||
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
|
||||
defer ln.Close()
|
||||
handshake(t, ln)
|
||||
})
|
||||
|
||||
t.Run("on IPv6", func(t *testing.T) {
|
||||
ln := runServer(t, serverTransport, "/ip6/::1/udp/0/quic")
|
||||
defer ln.Close()
|
||||
handshake(t, ln)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResourceManagerSuccess(t *testing.T) {
|
||||
serverID, serverKey := createPeer(t)
|
||||
clientID, clientKey := createPeer(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
serverRcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||||
serverTransport, err := NewTransport(serverKey, nil, nil, serverRcmgr)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
clientRcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||||
clientTransport, err := NewTransport(clientKey, nil, nil, clientRcmgr)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
|
||||
connChan := make(chan tpt.CapableConn)
|
||||
serverConnScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||||
go func() {
|
||||
serverRcmgr.EXPECT().OpenConnection(network.DirInbound, false).Return(serverConnScope, nil)
|
||||
serverConnScope.EXPECT().SetPeer(clientID)
|
||||
serverConn, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
connChan <- serverConn
|
||||
}()
|
||||
|
||||
connScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||||
clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false).Return(connScope, nil)
|
||||
connScope.EXPECT().SetPeer(serverID)
|
||||
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
serverConn := <-connChan
|
||||
t.Log("received conn")
|
||||
connScope.EXPECT().Done().MinTimes(1) // for dialed connections, we might call Done multiple times
|
||||
conn.Close()
|
||||
serverConnScope.EXPECT().Done()
|
||||
serverConn.Close()
|
||||
}
|
||||
|
||||
func TestResourceManagerDialDenied(t *testing.T) {
|
||||
_, clientKey := createPeer(t)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||||
clientTransport, err := NewTransport(clientKey, nil, nil, rcmgr)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
|
||||
connScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||||
rcmgr.EXPECT().OpenConnection(network.DirOutbound, false).Return(connScope, nil)
|
||||
rerr := errors.New("nope")
|
||||
p := peer.ID("server")
|
||||
connScope.EXPECT().SetPeer(p).Return(rerr)
|
||||
connScope.EXPECT().Done()
|
||||
|
||||
_, err = clientTransport.Dial(context.Background(), ma.StringCast("/ip4/127.0.0.1/udp/1234/quic"), p)
|
||||
require.ErrorIs(t, err, rerr)
|
||||
}
|
||||
|
||||
func TestResourceManagerAcceptDenied(t *testing.T) {
|
||||
serverID, serverKey := createPeer(t)
|
||||
clientID, clientKey := createPeer(t)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
clientRcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||||
clientTransport, err := NewTransport(clientKey, nil, nil, clientRcmgr)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
|
||||
serverRcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||||
serverConnScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||||
rerr := errors.New("denied")
|
||||
gomock.InOrder(
|
||||
serverRcmgr.EXPECT().OpenConnection(network.DirInbound, false).Return(serverConnScope, nil),
|
||||
serverConnScope.EXPECT().SetPeer(clientID).Return(rerr),
|
||||
serverConnScope.EXPECT().Done(),
|
||||
)
|
||||
serverTransport, err := NewTransport(serverKey, nil, nil, serverRcmgr)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
connChan := make(chan tpt.CapableConn)
|
||||
go func() {
|
||||
ln.Accept()
|
||||
close(connChan)
|
||||
}()
|
||||
|
||||
clientConnScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||||
clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false).Return(clientConnScope, nil)
|
||||
clientConnScope.EXPECT().SetPeer(serverID)
|
||||
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.AcceptStream()
|
||||
require.Error(t, err)
|
||||
select {
|
||||
case <-connChan:
|
||||
t.Fatal("didn't expect to accept a connection")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreams(t *testing.T) {
|
||||
serverID, serverKey := createPeer(t)
|
||||
_, clientKey := createPeer(t)
|
||||
|
||||
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
|
||||
defer ln.Close()
|
||||
|
||||
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
serverConn, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
defer serverConn.Close()
|
||||
|
||||
str, err := conn.OpenStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
str.Close()
|
||||
sstr, err := serverConn.AcceptStream()
|
||||
require.NoError(t, err)
|
||||
data, err := ioutil.ReadAll(sstr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data, []byte("foobar"))
|
||||
}
|
||||
|
||||
func TestHandshakeFailPeerIDMismatch(t *testing.T) {
|
||||
_, serverKey := createPeer(t)
|
||||
_, clientKey := createPeer(t)
|
||||
thirdPartyID, _ := createPeer(t)
|
||||
|
||||
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
|
||||
|
||||
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
// dial, but expect the wrong peer ID
|
||||
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "CRYPTO_ERROR")
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
|
||||
acceptErr := make(chan error)
|
||||
go func() {
|
||||
_, err := ln.Accept()
|
||||
acceptErr <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-acceptErr:
|
||||
t.Fatal("didn't expect Accept to return before being closed")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
require.NoError(t, ln.Close())
|
||||
require.Error(t, <-acceptErr)
|
||||
}
|
||||
|
||||
func TestConnectionGating(t *testing.T) {
|
||||
serverID, serverKey := createPeer(t)
|
||||
_, clientKey := createPeer(t)
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
cg := NewMockConnectionGater(mockCtrl)
|
||||
|
||||
t.Run("accepted connections", func(t *testing.T) {
|
||||
serverTransport, err := NewTransport(serverKey, nil, cg, nil)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
require.NoError(t, err)
|
||||
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
|
||||
defer ln.Close()
|
||||
|
||||
cg.EXPECT().InterceptAccept(gomock.Any())
|
||||
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
defer close(accepted)
|
||||
_, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
// make sure that connection attempts fails
|
||||
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.AcceptStream()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "connection gated")
|
||||
|
||||
// now allow the address and make sure the connection goes through
|
||||
cg.EXPECT().InterceptAccept(gomock.Any()).Return(true)
|
||||
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||||
clientTransport.(*transport).clientConfig.HandshakeIdleTimeout = 2 * time.Second
|
||||
conn, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-accepted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("secured connections", func(t *testing.T) {
|
||||
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
|
||||
defer ln.Close()
|
||||
|
||||
cg := NewMockConnectionGater(mockCtrl)
|
||||
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
|
||||
clientTransport, err := NewTransport(clientKey, nil, cg, nil)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
|
||||
// make sure that connection attempts fails
|
||||
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "connection gated")
|
||||
|
||||
// now allow the peerId and make sure the connection goes through
|
||||
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||||
clientTransport.(*transport).clientConfig.HandshakeIdleTimeout = 2 * time.Second
|
||||
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
conn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func TestDialTwo(t *testing.T) {
|
||||
serverID, serverKey := createPeer(t)
|
||||
_, clientKey := createPeer(t)
|
||||
serverID2, serverKey2 := createPeer(t)
|
||||
|
||||
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
ln1 := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
|
||||
defer ln1.Close()
|
||||
serverTransport2, err := NewTransport(serverKey2, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport2.(io.Closer).Close()
|
||||
ln2 := runServer(t, serverTransport2, "/ip4/127.0.0.1/udp/0/quic")
|
||||
defer ln2.Close()
|
||||
|
||||
data := bytes.Repeat([]byte{'a'}, 5*1<<20) // 5 MB
|
||||
// wait for both servers to accept a connection
|
||||
// then send some data
|
||||
go func() {
|
||||
serverConn1, err := ln1.Accept()
|
||||
require.NoError(t, err)
|
||||
serverConn2, err := ln2.Accept()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, c := range []tpt.CapableConn{serverConn1, serverConn2} {
|
||||
go func(conn tpt.CapableConn) {
|
||||
str, err := conn.OpenStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer str.Close()
|
||||
_, err = str.Write(data)
|
||||
require.NoError(t, err)
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
defer c1.Close()
|
||||
c2, err := clientTransport.Dial(context.Background(), ln2.Multiaddr(), serverID2)
|
||||
require.NoError(t, err)
|
||||
defer c2.Close()
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
// receive the data on both connections at the same time
|
||||
for _, c := range []tpt.CapableConn{c1, c2} {
|
||||
go func(conn tpt.CapableConn) {
|
||||
str, err := conn.AcceptStream()
|
||||
require.NoError(t, err)
|
||||
str.CloseWrite()
|
||||
d, err := ioutil.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, d, data)
|
||||
done <- struct{}{}
|
||||
}(c)
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 15*time.Second, 50*time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatelessReset(t *testing.T) {
|
||||
origGarbageCollectInterval := garbageCollectInterval
|
||||
origMaxUnusedDuration := maxUnusedDuration
|
||||
|
||||
garbageCollectInterval = 50 * time.Millisecond
|
||||
maxUnusedDuration = 0
|
||||
|
||||
t.Cleanup(func() {
|
||||
garbageCollectInterval = origGarbageCollectInterval
|
||||
maxUnusedDuration = origMaxUnusedDuration
|
||||
})
|
||||
|
||||
serverID, serverKey := createPeer(t)
|
||||
_, clientKey := createPeer(t)
|
||||
|
||||
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
|
||||
|
||||
var drop uint32
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DropPacket: func(quicproxy.Direction, []byte) bool {
|
||||
return atomic.LoadUint32(&drop) > 0
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
// establish a connection
|
||||
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID)
|
||||
require.NoError(t, err)
|
||||
connChan := make(chan tpt.CapableConn)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
str, err := conn.OpenStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
str.Write([]byte("foobar"))
|
||||
connChan <- conn
|
||||
}()
|
||||
|
||||
str, err := conn.AcceptStream()
|
||||
require.NoError(t, err)
|
||||
_, err = str.Read(make([]byte, 6))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Stop forwarding packets and close the server.
|
||||
// This prevents the CONNECTION_CLOSE from reaching the client.
|
||||
atomic.StoreUint32(&drop, 1)
|
||||
ln.Close()
|
||||
(<-connChan).Close()
|
||||
// require.NoError(t, ln.Close())
|
||||
time.Sleep(2000 * time.Millisecond) // give the kernel some time to free the UDP port
|
||||
ln = runServer(t, serverTransport, fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic", serverPort))
|
||||
defer ln.Close()
|
||||
// Now that the new server is up, re-enable packet forwarding.
|
||||
atomic.StoreUint32(&drop, 0)
|
||||
|
||||
// Trigger something (not too small) to be sent, so that we receive the stateless reset.
|
||||
// The new server doesn't have any state for the previously established connection.
|
||||
// We expect it to send a stateless reset.
|
||||
_, rerr := str.Write([]byte("Lorem ipsum dolor sit amet."))
|
||||
if rerr == nil {
|
||||
_, rerr = str.Read([]byte{0, 0})
|
||||
}
|
||||
require.Error(t, rerr)
|
||||
require.Contains(t, rerr.Error(), "received a stateless reset")
|
||||
}
|
||||
|
||||
func TestHolePunching(t *testing.T) {
|
||||
serverID, serverKey := createPeer(t)
|
||||
clientID, clientKey := createPeer(t)
|
||||
|
||||
t1, err := NewTransport(serverKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer t1.(io.Closer).Close()
|
||||
laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
|
||||
require.NoError(t, err)
|
||||
ln1, err := t1.Listen(laddr)
|
||||
require.NoError(t, err)
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done1)
|
||||
_, err := ln1.Accept()
|
||||
require.Error(t, err, "didn't expect to accept any connections")
|
||||
}()
|
||||
|
||||
t2, err := NewTransport(clientKey, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer t2.(io.Closer).Close()
|
||||
ln2, err := t2.Listen(laddr)
|
||||
require.NoError(t, err)
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done2)
|
||||
_, err := ln2.Accept()
|
||||
require.Error(t, err, "didn't expect to accept any connections")
|
||||
}()
|
||||
connChan := make(chan tpt.CapableConn)
|
||||
go func() {
|
||||
conn, err := t2.Dial(
|
||||
network.WithSimultaneousConnect(context.Background(), false, ""),
|
||||
ln1.Multiaddr(),
|
||||
serverID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
connChan <- conn
|
||||
}()
|
||||
conn1, err := t1.Dial(
|
||||
network.WithSimultaneousConnect(context.Background(), true, ""),
|
||||
ln2.Multiaddr(),
|
||||
clientID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
require.Equal(t, conn1.RemotePeer(), clientID)
|
||||
var conn2 tpt.CapableConn
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case conn2 = <-connChan:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 100*time.Millisecond, 10*time.Millisecond)
|
||||
defer conn2.Close()
|
||||
require.Equal(t, conn2.RemotePeer(), serverID)
|
||||
ln1.Close()
|
||||
ln2.Close()
|
||||
<-done1
|
||||
<-done2
|
||||
}
|
|
@ -0,0 +1,160 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
ic "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
tpt "github.com/libp2p/go-libp2p-core/transport"
|
||||
|
||||
p2ptls "github.com/libp2p/go-libp2p-tls"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
var quicListen = quic.Listen // so we can mock it in tests
|
||||
|
||||
// A listener listens for QUIC connections.
|
||||
type listener struct {
|
||||
quicListener quic.Listener
|
||||
conn *reuseConn
|
||||
transport *transport
|
||||
rcmgr network.ResourceManager
|
||||
privKey ic.PrivKey
|
||||
localPeer peer.ID
|
||||
localMultiaddr ma.Multiaddr
|
||||
}
|
||||
|
||||
var _ tpt.Listener = &listener{}
|
||||
|
||||
func newListener(rconn *reuseConn, t *transport, localPeer peer.ID, key ic.PrivKey, identity *p2ptls.Identity, rcmgr network.ResourceManager) (tpt.Listener, error) {
|
||||
var tlsConf tls.Config
|
||||
tlsConf.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
// return a tls.Config that verifies the peer's certificate chain.
|
||||
// Note that since we have no way of associating an incoming QUIC connection with
|
||||
// the peer ID calculated here, we don't actually receive the peer's public key
|
||||
// from the key chan.
|
||||
conf, _ := identity.ConfigForPeer("")
|
||||
return conf, nil
|
||||
}
|
||||
ln, err := quicListen(rconn, &tlsConf, t.serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
localMultiaddr, err := toQuicMultiaddr(ln.Addr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &listener{
|
||||
conn: rconn,
|
||||
quicListener: ln,
|
||||
transport: t,
|
||||
rcmgr: rcmgr,
|
||||
privKey: key,
|
||||
localPeer: localPeer,
|
||||
localMultiaddr: localMultiaddr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Accept accepts new connections.
|
||||
func (l *listener) Accept() (tpt.CapableConn, error) {
|
||||
for {
|
||||
qconn, err := l.quicListener.Accept(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := l.setupConn(qconn)
|
||||
if err != nil {
|
||||
qconn.CloseWithError(0, err.Error())
|
||||
continue
|
||||
}
|
||||
if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(c) && l.transport.gater.InterceptSecured(network.DirInbound, c.remotePeerID, c)) {
|
||||
c.scope.Done()
|
||||
qconn.CloseWithError(errorCodeConnectionGating, "connection gated")
|
||||
continue
|
||||
}
|
||||
l.transport.addConn(qconn, c)
|
||||
|
||||
// return through active hole punching if any
|
||||
key := holePunchKey{addr: qconn.RemoteAddr().String(), peer: c.remotePeerID}
|
||||
var wasHolePunch bool
|
||||
l.transport.holePunchingMx.Lock()
|
||||
holePunch, ok := l.transport.holePunching[key]
|
||||
if ok && !holePunch.fulfilled {
|
||||
holePunch.connCh <- c
|
||||
wasHolePunch = true
|
||||
holePunch.fulfilled = true
|
||||
}
|
||||
l.transport.holePunchingMx.Unlock()
|
||||
if wasHolePunch {
|
||||
continue
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *listener) setupConn(qconn quic.Connection) (*conn, error) {
|
||||
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false)
|
||||
if err != nil {
|
||||
log.Debugw("resource manager blocked incoming connection", "addr", qconn.RemoteAddr(), "error", err)
|
||||
return nil, err
|
||||
}
|
||||
// The tls.Config used to establish this connection already verified the certificate chain.
|
||||
// Since we don't have any way of knowing which tls.Config was used though,
|
||||
// we have to re-determine the peer's identity here.
|
||||
// Therefore, this is expected to never fail.
|
||||
remotePubKey, err := p2ptls.PubKeyFromCertChain(qconn.ConnectionState().TLS.PeerCertificates)
|
||||
if err != nil {
|
||||
connScope.Done()
|
||||
return nil, err
|
||||
}
|
||||
remotePeerID, err := peer.IDFromPublicKey(remotePubKey)
|
||||
if err != nil {
|
||||
connScope.Done()
|
||||
return nil, err
|
||||
}
|
||||
if err := connScope.SetPeer(remotePeerID); err != nil {
|
||||
log.Debugw("resource manager blocked incoming connection for peer", "peer", remotePeerID, "addr", qconn.RemoteAddr(), "error", err)
|
||||
connScope.Done()
|
||||
return nil, err
|
||||
}
|
||||
remoteMultiaddr, err := toQuicMultiaddr(qconn.RemoteAddr())
|
||||
if err != nil {
|
||||
connScope.Done()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.conn.IncreaseCount()
|
||||
return &conn{
|
||||
quicConn: qconn,
|
||||
pconn: l.conn,
|
||||
transport: l.transport,
|
||||
scope: connScope,
|
||||
localPeer: l.localPeer,
|
||||
localMultiaddr: l.localMultiaddr,
|
||||
privKey: l.privKey,
|
||||
remoteMultiaddr: remoteMultiaddr,
|
||||
remotePeerID: remotePeerID,
|
||||
remotePubKey: remotePubKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
func (l *listener) Close() error {
|
||||
defer l.conn.DecreaseCount()
|
||||
return l.quicListener.Close()
|
||||
}
|
||||
|
||||
// Addr returns the address of this listener.
|
||||
func (l *listener) Addr() net.Addr {
|
||||
return l.quicListener.Addr()
|
||||
}
|
||||
|
||||
// Multiaddr returns the multiaddress of this listener.
|
||||
func (l *listener) Multiaddr() ma.Multiaddr {
|
||||
return l.localMultiaddr
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ic "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
tpt "github.com/libp2p/go-libp2p-core/transport"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTransport(t *testing.T, rcmgr network.ResourceManager) tpt.Transport {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey))
|
||||
require.NoError(t, err)
|
||||
tr, err := NewTransport(key, nil, nil, rcmgr)
|
||||
require.NoError(t, err)
|
||||
return tr
|
||||
}
|
||||
|
||||
// The conn passed to quic-go should be a conn that quic-go can be
|
||||
// type-asserted to a UDPConn. That way, it can use all kinds of optimizations.
|
||||
func TestConnUsedForListening(t *testing.T) {
|
||||
origQuicListen := quicListen
|
||||
t.Cleanup(func() { quicListen = origQuicListen })
|
||||
|
||||
var conn net.PacketConn
|
||||
quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) {
|
||||
conn = c
|
||||
return nil, errors.New("listen error")
|
||||
}
|
||||
localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
|
||||
require.NoError(t, err)
|
||||
|
||||
tr := newTransport(t, nil)
|
||||
defer tr.(io.Closer).Close()
|
||||
_, err = tr.Listen(localAddr)
|
||||
require.EqualError(t, err, "listen error")
|
||||
require.NotNil(t, conn)
|
||||
defer conn.Close()
|
||||
_, ok := conn.(quic.OOBCapablePacketConn)
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestListenAddr(t *testing.T) {
|
||||
tr := newTransport(t, nil)
|
||||
defer tr.(io.Closer).Close()
|
||||
|
||||
t.Run("for IPv4", func(t *testing.T) {
|
||||
localAddr := ma.StringCast("/ip4/127.0.0.1/udp/0/quic")
|
||||
ln, err := tr.Listen(localAddr)
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
port := ln.Addr().(*net.UDPAddr).Port
|
||||
require.NotZero(t, port)
|
||||
require.Equal(t, ln.Multiaddr().String(), fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic", port))
|
||||
})
|
||||
|
||||
t.Run("for IPv6", func(t *testing.T) {
|
||||
localAddr := ma.StringCast("/ip6/::/udp/0/quic")
|
||||
ln, err := tr.Listen(localAddr)
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
port := ln.Addr().(*net.UDPAddr).Port
|
||||
require.NotZero(t, port)
|
||||
require.Equal(t, ln.Multiaddr().String(), fmt.Sprintf("/ip6/::/udp/%d/quic", port))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccepting(t *testing.T) {
|
||||
tr := newTransport(t, nil)
|
||||
defer tr.(io.Closer).Close()
|
||||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"))
|
||||
require.NoError(t, err)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
ln.Accept()
|
||||
close(done)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("Accept didn't block")
|
||||
default:
|
||||
}
|
||||
require.NoError(t, ln.Close())
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Accept didn't return after the listener was closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcceptAfterClose(t *testing.T) {
|
||||
tr := newTransport(t, nil)
|
||||
defer tr.(io.Closer).Close()
|
||||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, ln.Close())
|
||||
_, err = ln.Accept()
|
||||
require.Error(t, err)
|
||||
}
|
|
@ -0,0 +1,109 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/libp2p/go-libp2p-core/connmgr (interfaces: ConnectionGater)
|
||||
|
||||
// Package libp2pquic is a generated GoMock package.
|
||||
package libp2pquic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
control "github.com/libp2p/go-libp2p-core/control"
|
||||
network "github.com/libp2p/go-libp2p-core/network"
|
||||
peer "github.com/libp2p/go-libp2p-core/peer"
|
||||
multiaddr "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
// MockConnectionGater is a mock of ConnectionGater interface.
|
||||
type MockConnectionGater struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockConnectionGaterMockRecorder
|
||||
}
|
||||
|
||||
// MockConnectionGaterMockRecorder is the mock recorder for MockConnectionGater.
|
||||
type MockConnectionGaterMockRecorder struct {
|
||||
mock *MockConnectionGater
|
||||
}
|
||||
|
||||
// NewMockConnectionGater creates a new mock instance.
|
||||
func NewMockConnectionGater(ctrl *gomock.Controller) *MockConnectionGater {
|
||||
mock := &MockConnectionGater{ctrl: ctrl}
|
||||
mock.recorder = &MockConnectionGaterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockConnectionGater) EXPECT() *MockConnectionGaterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// InterceptAccept mocks base method.
|
||||
func (m *MockConnectionGater) InterceptAccept(arg0 network.ConnMultiaddrs) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InterceptAccept", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// InterceptAccept indicates an expected call of InterceptAccept.
|
||||
func (mr *MockConnectionGaterMockRecorder) InterceptAccept(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAccept", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAccept), arg0)
|
||||
}
|
||||
|
||||
// InterceptAddrDial mocks base method.
|
||||
func (m *MockConnectionGater) InterceptAddrDial(arg0 peer.ID, arg1 multiaddr.Multiaddr) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InterceptAddrDial", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// InterceptAddrDial indicates an expected call of InterceptAddrDial.
|
||||
func (mr *MockConnectionGaterMockRecorder) InterceptAddrDial(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAddrDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAddrDial), arg0, arg1)
|
||||
}
|
||||
|
||||
// InterceptPeerDial mocks base method.
|
||||
func (m *MockConnectionGater) InterceptPeerDial(arg0 peer.ID) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InterceptPeerDial", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// InterceptPeerDial indicates an expected call of InterceptPeerDial.
|
||||
func (mr *MockConnectionGaterMockRecorder) InterceptPeerDial(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptPeerDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptPeerDial), arg0)
|
||||
}
|
||||
|
||||
// InterceptSecured mocks base method.
|
||||
func (m *MockConnectionGater) InterceptSecured(arg0 network.Direction, arg1 peer.ID, arg2 network.ConnMultiaddrs) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InterceptSecured", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// InterceptSecured indicates an expected call of InterceptSecured.
|
||||
func (mr *MockConnectionGaterMockRecorder) InterceptSecured(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptSecured", reflect.TypeOf((*MockConnectionGater)(nil).InterceptSecured), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// InterceptUpgraded mocks base method.
|
||||
func (m *MockConnectionGater) InterceptUpgraded(arg0 network.Conn) (bool, control.DisconnectReason) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InterceptUpgraded", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(control.DisconnectReason)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InterceptUpgraded indicates an expected call of InterceptUpgraded.
|
||||
func (mr *MockConnectionGaterMockRecorder) InterceptUpgraded(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptUpgraded", reflect.TypeOf((*MockConnectionGater)(nil).InterceptUpgraded), arg0)
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
)
|
||||
|
||||
var quicMA ma.Multiaddr
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
quicMA, err = ma.NewMultiaddr("/quic")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func toQuicMultiaddr(na net.Addr) (ma.Multiaddr, error) {
|
||||
udpMA, err := manet.FromNetAddr(na)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return udpMA.Encapsulate(quicMA), nil
|
||||
}
|
||||
|
||||
func fromQuicMultiaddr(addr ma.Multiaddr) (net.Addr, error) {
|
||||
return manet.ToNetAddr(addr.Decapsulate(quicMA))
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConvertToQuicMultiaddr(t *testing.T) {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 42), Port: 1337}
|
||||
maddr, err := toQuicMultiaddr(addr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, maddr.String(), "/ip4/192.168.0.42/udp/1337/quic")
|
||||
}
|
||||
|
||||
func TestConvertFromQuicMultiaddr(t *testing.T) {
|
||||
maddr, err := ma.NewMultiaddr("/ip4/192.168.0.42/udp/1337/quic")
|
||||
require.NoError(t, err)
|
||||
addr, err := fromQuicMultiaddr(maddr)
|
||||
require.NoError(t, err)
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, udpAddr.IP, net.IPv4(192, 168, 0, 42))
|
||||
require.Equal(t, udpAddr.Port, 1337)
|
||||
}
|
|
@ -0,0 +1,231 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/routing"
|
||||
"github.com/libp2p/go-netroute"
|
||||
)
|
||||
|
||||
// Constant. Defined as variables to simplify testing.
|
||||
var (
|
||||
garbageCollectInterval = 30 * time.Second
|
||||
maxUnusedDuration = 10 * time.Second
|
||||
)
|
||||
|
||||
type reuseConn struct {
|
||||
*net.UDPConn
|
||||
|
||||
mutex sync.Mutex
|
||||
refCount int
|
||||
unusedSince time.Time
|
||||
}
|
||||
|
||||
func newReuseConn(conn *net.UDPConn) *reuseConn {
|
||||
return &reuseConn{UDPConn: conn}
|
||||
}
|
||||
|
||||
func (c *reuseConn) IncreaseCount() {
|
||||
c.mutex.Lock()
|
||||
c.refCount++
|
||||
c.unusedSince = time.Time{}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *reuseConn) DecreaseCount() {
|
||||
c.mutex.Lock()
|
||||
c.refCount--
|
||||
if c.refCount == 0 {
|
||||
c.unusedSince = time.Now()
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return !c.unusedSince.IsZero() && c.unusedSince.Add(maxUnusedDuration).Before(now)
|
||||
}
|
||||
|
||||
type reuse struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
closeChan chan struct{}
|
||||
gcStopChan chan struct{}
|
||||
|
||||
routes routing.Router
|
||||
unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn
|
||||
// global contains connections that are listening on 0.0.0.0 / ::
|
||||
global map[int]*reuseConn
|
||||
}
|
||||
|
||||
func newReuse() *reuse {
|
||||
r := &reuse{
|
||||
unicast: make(map[string]map[int]*reuseConn),
|
||||
global: make(map[int]*reuseConn),
|
||||
closeChan: make(chan struct{}),
|
||||
gcStopChan: make(chan struct{}),
|
||||
}
|
||||
go r.gc()
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *reuse) gc() {
|
||||
defer func() {
|
||||
r.mutex.Lock()
|
||||
for _, conn := range r.global {
|
||||
conn.Close()
|
||||
}
|
||||
for _, conns := range r.unicast {
|
||||
for _, conn := range conns {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
r.mutex.Unlock()
|
||||
close(r.gcStopChan)
|
||||
}()
|
||||
ticker := time.NewTicker(garbageCollectInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.closeChan:
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
r.mutex.Lock()
|
||||
for key, conn := range r.global {
|
||||
if conn.ShouldGarbageCollect(now) {
|
||||
conn.Close()
|
||||
delete(r.global, key)
|
||||
}
|
||||
}
|
||||
for ukey, conns := range r.unicast {
|
||||
for key, conn := range conns {
|
||||
if conn.ShouldGarbageCollect(now) {
|
||||
conn.Close()
|
||||
delete(conns, key)
|
||||
}
|
||||
}
|
||||
if len(conns) == 0 {
|
||||
delete(r.unicast, ukey)
|
||||
// If we've dropped all connections with a unicast binding,
|
||||
// assume our routes may have changed.
|
||||
if len(r.unicast) == 0 {
|
||||
r.routes = nil
|
||||
} else {
|
||||
// Ignore the error, there's nothing we can do about
|
||||
// it.
|
||||
r.routes, _ = netroute.New()
|
||||
}
|
||||
}
|
||||
}
|
||||
r.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
|
||||
var ip *net.IP
|
||||
|
||||
// Only bother looking up the source address if we actually _have_ non 0.0.0.0 listeners.
|
||||
// Otherwise, save some time.
|
||||
|
||||
r.mutex.Lock()
|
||||
router := r.routes
|
||||
r.mutex.Unlock()
|
||||
|
||||
if router != nil {
|
||||
_, _, src, err := router.Route(raddr.IP)
|
||||
if err == nil && !src.IsUnspecified() {
|
||||
ip = &src
|
||||
}
|
||||
}
|
||||
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
conn, err := r.dialLocked(network, ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.IncreaseCount()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (r *reuse) dialLocked(network string, source *net.IP) (*reuseConn, error) {
|
||||
if source != nil {
|
||||
// We already have at least one suitable connection...
|
||||
if conns, ok := r.unicast[source.String()]; ok {
|
||||
// ... we don't care which port we're dialing from. Just use the first.
|
||||
for _, c := range conns {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use a connection listening on 0.0.0.0 (or ::).
|
||||
// Again, we don't care about the port number.
|
||||
for _, conn := range r.global {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// We don't have a connection that we can use for dialing.
|
||||
// Dial a new connection from a random port.
|
||||
var addr *net.UDPAddr
|
||||
switch network {
|
||||
case "udp4":
|
||||
addr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
|
||||
case "udp6":
|
||||
addr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
|
||||
}
|
||||
conn, err := net.ListenUDP(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rconn := newReuseConn(conn)
|
||||
r.global[conn.LocalAddr().(*net.UDPAddr).Port] = rconn
|
||||
return rconn, nil
|
||||
}
|
||||
|
||||
func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
|
||||
conn, err := net.ListenUDP(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
rconn := newReuseConn(conn)
|
||||
rconn.IncreaseCount()
|
||||
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
// Deal with listen on a global address
|
||||
if localAddr.IP.IsUnspecified() {
|
||||
// The kernel already checked that the laddr is not already listen
|
||||
// so we need not check here (when we create ListenUDP).
|
||||
r.global[localAddr.Port] = rconn
|
||||
return rconn, err
|
||||
}
|
||||
|
||||
// Deal with listen on a unicast address
|
||||
if _, ok := r.unicast[localAddr.IP.String()]; !ok {
|
||||
r.unicast[localAddr.IP.String()] = make(map[int]*reuseConn)
|
||||
// Assume the system's routes may have changed if we're adding a new listener.
|
||||
// Ignore the error, there's nothing we can do.
|
||||
r.routes, _ = netroute.New()
|
||||
}
|
||||
|
||||
// The kernel already checked that the laddr is not already listen
|
||||
// so we need not check here (when we create ListenUDP).
|
||||
r.unicast[localAddr.IP.String()][localAddr.Port] = rconn
|
||||
return rconn, err
|
||||
}
|
||||
|
||||
func (r *reuse) Close() error {
|
||||
close(r.closeChan)
|
||||
<-r.gcStopChan
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,175 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-netroute"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func (c *reuseConn) GetCount() int {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return c.refCount
|
||||
}
|
||||
|
||||
func closeAllConns(reuse *reuse) {
|
||||
reuse.mutex.Lock()
|
||||
for _, conn := range reuse.global {
|
||||
for conn.GetCount() > 0 {
|
||||
conn.DecreaseCount()
|
||||
}
|
||||
}
|
||||
for _, conns := range reuse.unicast {
|
||||
for _, conn := range conns {
|
||||
for conn.GetCount() > 0 {
|
||||
conn.DecreaseCount()
|
||||
}
|
||||
}
|
||||
}
|
||||
reuse.mutex.Unlock()
|
||||
}
|
||||
|
||||
func platformHasRoutingTables() bool {
|
||||
_, err := netroute.New()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func isGarbageCollectorRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "quic.(*reuse).gc")
|
||||
}
|
||||
|
||||
func cleanup(t *testing.T, reuse *reuse) {
|
||||
t.Cleanup(func() {
|
||||
closeAllConns(reuse)
|
||||
reuse.Close()
|
||||
require.False(t, isGarbageCollectorRunning(), "reuse gc still running")
|
||||
})
|
||||
}
|
||||
|
||||
func TestReuseListenOnAllIPv4(t *testing.T) {
|
||||
reuse := newReuse()
|
||||
require.Eventually(t, isGarbageCollectorRunning, 100*time.Millisecond, time.Millisecond, "expected garbage collector to be running")
|
||||
cleanup(t, reuse)
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
|
||||
require.NoError(t, err)
|
||||
conn, err := reuse.Listen("udp4", addr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn.GetCount(), 1)
|
||||
}
|
||||
|
||||
func TestReuseListenOnAllIPv6(t *testing.T) {
|
||||
reuse := newReuse()
|
||||
require.Eventually(t, isGarbageCollectorRunning, 100*time.Millisecond, time.Millisecond, "expected garbage collector to be running")
|
||||
cleanup(t, reuse)
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp6", "[::]:1234")
|
||||
require.NoError(t, err)
|
||||
conn, err := reuse.Listen("udp6", addr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn.GetCount(), 1)
|
||||
}
|
||||
|
||||
func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {
|
||||
reuse := newReuse()
|
||||
cleanup(t, reuse)
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
|
||||
require.NoError(t, err)
|
||||
conn, err := reuse.Dial("udp4", addr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn.GetCount(), 1)
|
||||
laddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
require.Equal(t, laddr.IP.String(), "0.0.0.0")
|
||||
require.NotEqual(t, laddr.Port, 0)
|
||||
}
|
||||
|
||||
func TestReuseConnectionWhenDialing(t *testing.T) {
|
||||
reuse := newReuse()
|
||||
cleanup(t, reuse)
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
|
||||
require.NoError(t, err)
|
||||
lconn, err := reuse.Listen("udp4", addr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lconn.GetCount(), 1)
|
||||
// dial
|
||||
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
|
||||
require.NoError(t, err)
|
||||
conn, err := reuse.Dial("udp4", raddr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn.GetCount(), 2)
|
||||
}
|
||||
|
||||
func TestReuseListenOnSpecificInterface(t *testing.T) {
|
||||
if platformHasRoutingTables() {
|
||||
t.Skip("this test only works on platforms that support routing tables")
|
||||
}
|
||||
reuse := newReuse()
|
||||
cleanup(t, reuse)
|
||||
|
||||
router, err := netroute.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
|
||||
require.NoError(t, err)
|
||||
_, _, ip, err := router.Route(raddr.IP)
|
||||
require.NoError(t, err)
|
||||
// listen
|
||||
addr, err := net.ResolveUDPAddr("udp4", ip.String()+":0")
|
||||
require.NoError(t, err)
|
||||
lconn, err := reuse.Listen("udp4", addr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lconn.GetCount(), 1)
|
||||
// dial
|
||||
conn, err := reuse.Dial("udp4", raddr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn.GetCount(), 1)
|
||||
}
|
||||
|
||||
func TestReuseGarbageCollect(t *testing.T) {
|
||||
maxUnusedDurationOrig := maxUnusedDuration
|
||||
garbageCollectIntervalOrig := garbageCollectInterval
|
||||
t.Cleanup(func() {
|
||||
maxUnusedDuration = maxUnusedDurationOrig
|
||||
garbageCollectInterval = garbageCollectIntervalOrig
|
||||
})
|
||||
garbageCollectInterval = 50 * time.Millisecond
|
||||
maxUnusedDuration = 100 * time.Millisecond
|
||||
|
||||
reuse := newReuse()
|
||||
cleanup(t, reuse)
|
||||
|
||||
numGlobals := func() int {
|
||||
reuse.mutex.Lock()
|
||||
defer reuse.mutex.Unlock()
|
||||
return len(reuse.global)
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
|
||||
require.NoError(t, err)
|
||||
lconn, err := reuse.Listen("udp4", addr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lconn.GetCount(), 1)
|
||||
|
||||
closeTime := time.Now()
|
||||
lconn.DecreaseCount()
|
||||
|
||||
for {
|
||||
num := numGlobals()
|
||||
if closeTime.Add(maxUnusedDuration).Before(time.Now()) {
|
||||
break
|
||||
}
|
||||
require.Equal(t, num, 1)
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
require.Eventually(t, func() bool { return numGlobals() == 0 }, 100*time.Millisecond, 5*time.Millisecond)
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
const (
|
||||
reset quic.StreamErrorCode = 0
|
||||
)
|
||||
|
||||
type stream struct {
|
||||
quic.Stream
|
||||
}
|
||||
|
||||
var _ network.MuxedStream = &stream{}
|
||||
|
||||
func (s *stream) Read(b []byte) (n int, err error) {
|
||||
n, err = s.Stream.Read(b)
|
||||
if err != nil && errors.Is(err, &quic.StreamError{}) {
|
||||
err = network.ErrReset
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *stream) Write(b []byte) (n int, err error) {
|
||||
n, err = s.Stream.Write(b)
|
||||
if err != nil && errors.Is(err, &quic.StreamError{}) {
|
||||
err = network.ErrReset
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *stream) Reset() error {
|
||||
s.Stream.CancelRead(reset)
|
||||
s.Stream.CancelWrite(reset)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) Close() error {
|
||||
s.Stream.CancelRead(reset)
|
||||
return s.Stream.Close()
|
||||
}
|
||||
|
||||
func (s *stream) CloseRead() error {
|
||||
s.Stream.CancelRead(reset)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) CloseWrite() error {
|
||||
return s.Stream.Close()
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/logging"
|
||||
"github.com/lucas-clemente/quic-go/qlog"
|
||||
)
|
||||
|
||||
var tracer logging.Tracer
|
||||
|
||||
func init() {
|
||||
tracers := []logging.Tracer{&metricsTracer{}}
|
||||
if qlogDir := os.Getenv("QLOGDIR"); len(qlogDir) > 0 {
|
||||
if qlogger := initQlogger(qlogDir); qlogger != nil {
|
||||
tracers = append(tracers, qlogger)
|
||||
}
|
||||
}
|
||||
tracer = logging.NewMultiplexedTracer(tracers...)
|
||||
}
|
||||
|
||||
func initQlogger(qlogDir string) logging.Tracer {
|
||||
return qlog.NewTracer(func(role logging.Perspective, connID []byte) io.WriteCloser {
|
||||
// create the QLOGDIR, if it doesn't exist
|
||||
if err := os.MkdirAll(qlogDir, 0777); err != nil {
|
||||
log.Errorf("creating the QLOGDIR failed: %s", err)
|
||||
return nil
|
||||
}
|
||||
return newQlogger(qlogDir, role, connID)
|
||||
})
|
||||
}
|
||||
|
||||
// The qlogger logs qlog events to a temporary file: .<name>.qlog.swp.
|
||||
// When it is closed, it compresses the temporary file and saves it as <name>.qlog.zst.
|
||||
// It is not possible to compress on the fly, as compression algorithms keep a lot of internal state,
|
||||
// which can easily exhaust the host system's memory when running a few hundred QUIC connections in parallel.
|
||||
type qlogger struct {
|
||||
f *os.File // QLOGDIR/.log_xxx.qlog.swp
|
||||
filename string // QLOGDIR/log_xxx.qlog.zst
|
||||
*bufio.Writer // buffering the f
|
||||
}
|
||||
|
||||
func newQlogger(qlogDir string, role logging.Perspective, connID []byte) io.WriteCloser {
|
||||
t := time.Now().UTC().Format("2006-01-02T15-04-05.999999999UTC")
|
||||
r := "server"
|
||||
if role == logging.PerspectiveClient {
|
||||
r = "client"
|
||||
}
|
||||
finalFilename := fmt.Sprintf("%s%clog_%s_%s_%x.qlog.zst", qlogDir, os.PathSeparator, t, r, connID)
|
||||
filename := fmt.Sprintf("%s%c.log_%s_%s_%x.qlog.swp", qlogDir, os.PathSeparator, t, r, connID)
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
log.Errorf("unable to create qlog file %s: %s", filename, err)
|
||||
return nil
|
||||
}
|
||||
return &qlogger{
|
||||
f: f,
|
||||
filename: finalFilename,
|
||||
Writer: bufio.NewWriter(f),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *qlogger) Close() error {
|
||||
defer os.Remove(l.f.Name())
|
||||
defer l.f.Close()
|
||||
if err := l.Writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := l.f.Seek(0, io.SeekStart); err != nil { // set the read position to the beginning of the file
|
||||
return err
|
||||
}
|
||||
f, err := os.Create(l.filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
buf := bufio.NewWriter(f)
|
||||
c, err := zstd.NewWriter(buf, zstd.WithEncoderLevel(zstd.SpeedFastest), zstd.WithWindowSize(32*1024))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(c, l.f); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return buf.Flush()
|
||||
}
|
|
@ -0,0 +1,387 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/logging"
|
||||
)
|
||||
|
||||
var (
|
||||
bytesTransferred *prometheus.CounterVec
|
||||
newConns *prometheus.CounterVec
|
||||
closedConns *prometheus.CounterVec
|
||||
sentPackets *prometheus.CounterVec
|
||||
rcvdPackets *prometheus.CounterVec
|
||||
bufferedPackets *prometheus.CounterVec
|
||||
droppedPackets *prometheus.CounterVec
|
||||
lostPackets *prometheus.CounterVec
|
||||
connErrors *prometheus.CounterVec
|
||||
)
|
||||
|
||||
type aggregatingCollector struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conns map[string] /* conn ID */ *metricsConnTracer
|
||||
rtts prometheus.Histogram
|
||||
connDurations prometheus.Histogram
|
||||
}
|
||||
|
||||
func newAggregatingCollector() *aggregatingCollector {
|
||||
return &aggregatingCollector{
|
||||
conns: make(map[string]*metricsConnTracer),
|
||||
rtts: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "quic_smoothed_rtt",
|
||||
Help: "Smoothed RTT",
|
||||
Buckets: prometheus.ExponentialBuckets(0.001, 1.25, 40), // 1ms to ~6000ms
|
||||
}),
|
||||
connDurations: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "quic_connection_duration",
|
||||
Help: "Connection Duration",
|
||||
Buckets: prometheus.ExponentialBuckets(1, 1.5, 40), // 1s to ~12 weeks
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
var _ prometheus.Collector = &aggregatingCollector{}
|
||||
|
||||
func (c *aggregatingCollector) Describe(descs chan<- *prometheus.Desc) {
|
||||
descs <- c.rtts.Desc()
|
||||
descs <- c.connDurations.Desc()
|
||||
}
|
||||
|
||||
func (c *aggregatingCollector) Collect(metrics chan<- prometheus.Metric) {
|
||||
now := time.Now()
|
||||
c.mutex.Lock()
|
||||
for _, conn := range c.conns {
|
||||
if rtt, valid := conn.getSmoothedRTT(); valid {
|
||||
c.rtts.Observe(rtt.Seconds())
|
||||
}
|
||||
c.connDurations.Observe(now.Sub(conn.startTime).Seconds())
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
metrics <- c.rtts
|
||||
metrics <- c.connDurations
|
||||
}
|
||||
|
||||
func (c *aggregatingCollector) AddConn(id string, t *metricsConnTracer) {
|
||||
c.mutex.Lock()
|
||||
c.conns[id] = t
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *aggregatingCollector) RemoveConn(id string) {
|
||||
c.mutex.Lock()
|
||||
delete(c.conns, id)
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
var collector *aggregatingCollector
|
||||
|
||||
func init() {
|
||||
const (
|
||||
direction = "direction"
|
||||
encLevel = "encryption_level"
|
||||
)
|
||||
|
||||
closedConns = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "quic_connections_closed_total",
|
||||
Help: "closed QUIC connection",
|
||||
},
|
||||
[]string{direction},
|
||||
)
|
||||
prometheus.MustRegister(closedConns)
|
||||
newConns = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "quic_connections_new_total",
|
||||
Help: "new QUIC connection",
|
||||
},
|
||||
[]string{direction, "handshake_successful"},
|
||||
)
|
||||
prometheus.MustRegister(newConns)
|
||||
bytesTransferred = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "quic_transferred_bytes",
|
||||
Help: "QUIC bytes transferred",
|
||||
},
|
||||
[]string{direction}, // TODO: this is confusing. Other times, we use direction for the perspective
|
||||
)
|
||||
prometheus.MustRegister(bytesTransferred)
|
||||
sentPackets = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "quic_packets_sent_total",
|
||||
Help: "QUIC packets sent",
|
||||
},
|
||||
[]string{encLevel},
|
||||
)
|
||||
prometheus.MustRegister(sentPackets)
|
||||
rcvdPackets = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "quic_packets_rcvd_total",
|
||||
Help: "QUIC packets received",
|
||||
},
|
||||
[]string{encLevel},
|
||||
)
|
||||
prometheus.MustRegister(rcvdPackets)
|
||||
bufferedPackets = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "quic_packets_buffered_total",
|
||||
Help: "Buffered packets",
|
||||
},
|
||||
[]string{"packet_type"},
|
||||
)
|
||||
prometheus.MustRegister(bufferedPackets)
|
||||
droppedPackets = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "quic_packets_dropped_total",
|
||||
Help: "Dropped packets",
|
||||
},
|
||||
[]string{"packet_type", "reason"},
|
||||
)
|
||||
prometheus.MustRegister(droppedPackets)
|
||||
connErrors = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "quic_connection_errors_total",
|
||||
Help: "QUIC connection errors",
|
||||
},
|
||||
[]string{"side", "error_code"},
|
||||
)
|
||||
prometheus.MustRegister(connErrors)
|
||||
lostPackets = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "quic_packets_lost_total",
|
||||
Help: "QUIC lost received",
|
||||
},
|
||||
[]string{encLevel, "reason"},
|
||||
)
|
||||
prometheus.MustRegister(lostPackets)
|
||||
collector = newAggregatingCollector()
|
||||
prometheus.MustRegister(collector)
|
||||
}
|
||||
|
||||
type metricsTracer struct{}
|
||||
|
||||
var _ logging.Tracer = &metricsTracer{}
|
||||
|
||||
func (m *metricsTracer) TracerForConnection(_ context.Context, p logging.Perspective, connID logging.ConnectionID) logging.ConnectionTracer {
|
||||
return &metricsConnTracer{perspective: p, connID: connID}
|
||||
}
|
||||
|
||||
func (m *metricsTracer) SentPacket(_ net.Addr, _ *logging.Header, size logging.ByteCount, _ []logging.Frame) {
|
||||
bytesTransferred.WithLabelValues("sent").Add(float64(size))
|
||||
}
|
||||
|
||||
func (m *metricsTracer) DroppedPacket(addr net.Addr, packetType logging.PacketType, count logging.ByteCount, reason logging.PacketDropReason) {
|
||||
}
|
||||
|
||||
type metricsConnTracer struct {
|
||||
perspective logging.Perspective
|
||||
startTime time.Time
|
||||
connID logging.ConnectionID
|
||||
handshakeComplete bool
|
||||
|
||||
mutex sync.Mutex
|
||||
numRTTMeasurements int
|
||||
rtt time.Duration
|
||||
}
|
||||
|
||||
var _ logging.ConnectionTracer = &metricsConnTracer{}
|
||||
|
||||
func (m *metricsConnTracer) getDirection() string {
|
||||
if m.perspective == logging.PerspectiveClient {
|
||||
return "outgoing"
|
||||
}
|
||||
return "incoming"
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) getEncLevel(packetType logging.PacketType) string {
|
||||
switch packetType {
|
||||
case logging.PacketType0RTT:
|
||||
return "0-RTT"
|
||||
case logging.PacketTypeInitial:
|
||||
return "Initial"
|
||||
case logging.PacketTypeHandshake:
|
||||
return "Handshake"
|
||||
case logging.PacketTypeRetry:
|
||||
return "Retry"
|
||||
case logging.PacketType1RTT:
|
||||
return "1-RTT"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) StartedConnection(net.Addr, net.Addr, logging.ConnectionID, logging.ConnectionID) {
|
||||
m.startTime = time.Now()
|
||||
collector.AddConn(m.connID.String(), m)
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) NegotiatedVersion(chosen quic.VersionNumber, clientVersions []quic.VersionNumber, serverVersions []quic.VersionNumber) {
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) ClosedConnection(e error) {
|
||||
var (
|
||||
applicationErr *quic.ApplicationError
|
||||
transportErr *quic.TransportError
|
||||
statelessResetErr *quic.StatelessResetError
|
||||
vnErr *quic.VersionNegotiationError
|
||||
idleTimeoutErr *quic.IdleTimeoutError
|
||||
handshakeTimeoutErr *quic.HandshakeTimeoutError
|
||||
remote bool
|
||||
desc string
|
||||
)
|
||||
|
||||
switch {
|
||||
case errors.As(e, &applicationErr):
|
||||
return
|
||||
case errors.As(e, &transportErr):
|
||||
remote = transportErr.Remote
|
||||
desc = transportErr.ErrorCode.String()
|
||||
case errors.As(e, &statelessResetErr):
|
||||
remote = true
|
||||
desc = "stateless_reset"
|
||||
case errors.As(e, &vnErr):
|
||||
desc = "version_negotiation"
|
||||
case errors.As(e, &idleTimeoutErr):
|
||||
desc = "idle_timeout"
|
||||
case errors.As(e, &handshakeTimeoutErr):
|
||||
desc = "handshake_timeout"
|
||||
default:
|
||||
desc = fmt.Sprintf("unknown error: %v", e)
|
||||
}
|
||||
|
||||
side := "local"
|
||||
if remote {
|
||||
side = "remote"
|
||||
}
|
||||
connErrors.WithLabelValues(side, desc).Inc()
|
||||
}
|
||||
func (m *metricsConnTracer) SentTransportParameters(parameters *logging.TransportParameters) {}
|
||||
func (m *metricsConnTracer) ReceivedTransportParameters(parameters *logging.TransportParameters) {}
|
||||
func (m *metricsConnTracer) RestoredTransportParameters(parameters *logging.TransportParameters) {}
|
||||
func (m *metricsConnTracer) SentPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, _ *logging.AckFrame, _ []logging.Frame) {
|
||||
bytesTransferred.WithLabelValues("sent").Add(float64(size))
|
||||
sentPackets.WithLabelValues(m.getEncLevel(logging.PacketTypeFromHeader(&hdr.Header))).Inc()
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) ReceivedVersionNegotiationPacket(hdr *logging.Header, v []logging.VersionNumber) {
|
||||
bytesTransferred.WithLabelValues("rcvd").Add(float64(hdr.ParsedLen() + logging.ByteCount(4*len(v))))
|
||||
rcvdPackets.WithLabelValues("Version Negotiation").Inc()
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) ReceivedRetry(*logging.Header) {
|
||||
rcvdPackets.WithLabelValues("Retry").Inc()
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) ReceivedPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, _ []logging.Frame) {
|
||||
bytesTransferred.WithLabelValues("rcvd").Add(float64(size))
|
||||
rcvdPackets.WithLabelValues(m.getEncLevel(logging.PacketTypeFromHeader(&hdr.Header))).Inc()
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) BufferedPacket(packetType logging.PacketType) {
|
||||
bufferedPackets.WithLabelValues(m.getEncLevel(packetType)).Inc()
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) DroppedPacket(packetType logging.PacketType, size logging.ByteCount, r logging.PacketDropReason) {
|
||||
bytesTransferred.WithLabelValues("rcvd").Add(float64(size))
|
||||
var reason string
|
||||
switch r {
|
||||
case logging.PacketDropKeyUnavailable:
|
||||
reason = "key_unavailable"
|
||||
case logging.PacketDropUnknownConnectionID:
|
||||
reason = "unknown_connection_id"
|
||||
case logging.PacketDropHeaderParseError:
|
||||
reason = "header_parse_error"
|
||||
case logging.PacketDropPayloadDecryptError:
|
||||
reason = "payload_decrypt_error"
|
||||
case logging.PacketDropProtocolViolation:
|
||||
reason = "protocol_violation"
|
||||
case logging.PacketDropDOSPrevention:
|
||||
reason = "dos_prevention"
|
||||
case logging.PacketDropUnsupportedVersion:
|
||||
reason = "unsupported_version"
|
||||
case logging.PacketDropUnexpectedPacket:
|
||||
reason = "unexpected_packet"
|
||||
case logging.PacketDropUnexpectedSourceConnectionID:
|
||||
reason = "unexpected_source_connection_id"
|
||||
case logging.PacketDropUnexpectedVersion:
|
||||
reason = "unexpected_version"
|
||||
case logging.PacketDropDuplicate:
|
||||
reason = "duplicate"
|
||||
default:
|
||||
reason = "unknown"
|
||||
}
|
||||
droppedPackets.WithLabelValues(m.getEncLevel(packetType), reason).Inc()
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) {
|
||||
m.mutex.Lock()
|
||||
m.rtt = rttStats.SmoothedRTT()
|
||||
m.numRTTMeasurements++
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) AcknowledgedPacket(logging.EncryptionLevel, logging.PacketNumber) {}
|
||||
|
||||
func (m *metricsConnTracer) LostPacket(level logging.EncryptionLevel, _ logging.PacketNumber, r logging.PacketLossReason) {
|
||||
var reason string
|
||||
switch r {
|
||||
case logging.PacketLossReorderingThreshold:
|
||||
reason = "reordering_threshold"
|
||||
case logging.PacketLossTimeThreshold:
|
||||
reason = "time_threshold"
|
||||
default:
|
||||
reason = "unknown"
|
||||
}
|
||||
lostPackets.WithLabelValues(level.String(), reason).Inc()
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) UpdatedCongestionState(state logging.CongestionState) {}
|
||||
func (m *metricsConnTracer) UpdatedPTOCount(value uint32) {}
|
||||
func (m *metricsConnTracer) UpdatedKeyFromTLS(level logging.EncryptionLevel, perspective logging.Perspective) {
|
||||
}
|
||||
func (m *metricsConnTracer) UpdatedKey(generation logging.KeyPhase, remote bool) {}
|
||||
func (m *metricsConnTracer) DroppedEncryptionLevel(level logging.EncryptionLevel) {
|
||||
if level == logging.EncryptionHandshake {
|
||||
m.handleHandshakeComplete()
|
||||
}
|
||||
}
|
||||
func (m *metricsConnTracer) DroppedKey(generation logging.KeyPhase) {}
|
||||
func (m *metricsConnTracer) SetLossTimer(timerType logging.TimerType, level logging.EncryptionLevel, time time.Time) {
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) LossTimerExpired(timerType logging.TimerType, level logging.EncryptionLevel) {
|
||||
}
|
||||
func (m *metricsConnTracer) LossTimerCanceled() {}
|
||||
|
||||
func (m *metricsConnTracer) Close() {
|
||||
if m.handshakeComplete {
|
||||
closedConns.WithLabelValues(m.getDirection()).Inc()
|
||||
} else {
|
||||
newConns.WithLabelValues(m.getDirection(), "false").Inc()
|
||||
}
|
||||
collector.RemoveConn(m.connID.String())
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) Debug(name, msg string) {}
|
||||
|
||||
func (m *metricsConnTracer) handleHandshakeComplete() {
|
||||
m.handshakeComplete = true
|
||||
newConns.WithLabelValues(m.getDirection(), "true").Inc()
|
||||
}
|
||||
|
||||
func (m *metricsConnTracer) getSmoothedRTT() (rtt time.Duration, valid bool) {
|
||||
m.mutex.Lock()
|
||||
rtt = m.rtt
|
||||
valid = m.numRTTMeasurements > 10
|
||||
m.mutex.Unlock()
|
||||
return
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/logging"
|
||||
)
|
||||
|
||||
func createLogDir(t *testing.T) string {
|
||||
dir, err := ioutil.TempDir("", "libp2p-quic-transport-test")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { os.RemoveAll(dir) })
|
||||
return dir
|
||||
}
|
||||
|
||||
func getFile(t *testing.T, dir string) os.FileInfo {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 1)
|
||||
return files[0]
|
||||
}
|
||||
|
||||
func TestSaveQlog(t *testing.T) {
|
||||
qlogDir := createLogDir(t)
|
||||
logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte{0xde, 0xad, 0xbe, 0xef})
|
||||
file := getFile(t, qlogDir)
|
||||
require.Equal(t, string(file.Name()[0]), ".")
|
||||
require.Truef(t, strings.HasSuffix(file.Name(), ".qlog.swp"), "expected %s to have the .qlog.swp file ending", file.Name())
|
||||
// close the logger. This should move the file.
|
||||
require.NoError(t, logger.Close())
|
||||
file = getFile(t, qlogDir)
|
||||
require.NotEqual(t, string(file.Name()[0]), ".")
|
||||
require.Truef(t, strings.HasSuffix(file.Name(), ".qlog.zst"), "expected %s to have the .qlog.zst file ending", file.Name())
|
||||
require.Contains(t, file.Name(), "server")
|
||||
require.Contains(t, file.Name(), "deadbeef")
|
||||
}
|
||||
|
||||
func TestQlogBuffering(t *testing.T) {
|
||||
qlogDir := createLogDir(t)
|
||||
logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte("connid"))
|
||||
initialSize := getFile(t, qlogDir).Size()
|
||||
// Do a small write.
|
||||
// Since the writter is buffered, this should not be written to disk yet.
|
||||
logger.Write([]byte("foobar"))
|
||||
require.Equal(t, getFile(t, qlogDir).Size(), initialSize)
|
||||
// Close the logger. This should flush the buffer to disk.
|
||||
require.NoError(t, logger.Close())
|
||||
finalSize := getFile(t, qlogDir).Size()
|
||||
t.Logf("initial log file size: %d, final log file size: %d\n", initialSize, finalSize)
|
||||
require.Greater(t, finalSize, initialSize)
|
||||
}
|
||||
|
||||
func TestQlogCompression(t *testing.T) {
|
||||
qlogDir := createLogDir(t)
|
||||
logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte("connid"))
|
||||
logger.Write([]byte("foobar"))
|
||||
require.NoError(t, logger.Close())
|
||||
compressed, err := ioutil.ReadFile(qlogDir + "/" + getFile(t, qlogDir).Name())
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, compressed, "foobar")
|
||||
c, err := zstd.NewReader(bytes.NewReader(compressed))
|
||||
require.NoError(t, err)
|
||||
data, err := ioutil.ReadAll(c)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data, []byte("foobar"))
|
||||
}
|
|
@ -0,0 +1,409 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/connmgr"
|
||||
ic "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
"github.com/libp2p/go-libp2p-core/pnet"
|
||||
tpt "github.com/libp2p/go-libp2p-core/transport"
|
||||
|
||||
p2ptls "github.com/libp2p/go-libp2p-tls"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
mafmt "github.com/multiformats/go-multiaddr-fmt"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/minio/sha256-simd"
|
||||
)
|
||||
|
||||
var log = logging.Logger("quic-transport")
|
||||
|
||||
var ErrHolePunching = errors.New("hole punching attempted; no active dial")
|
||||
|
||||
var quicDialContext = quic.DialContext // so we can mock it in tests
|
||||
|
||||
var HolePunchTimeout = 5 * time.Second
|
||||
|
||||
var quicConfig = &quic.Config{
|
||||
MaxIncomingStreams: 256,
|
||||
MaxIncomingUniStreams: -1, // disable unidirectional streams
|
||||
MaxStreamReceiveWindow: 10 * (1 << 20), // 10 MB
|
||||
MaxConnectionReceiveWindow: 15 * (1 << 20), // 15 MB
|
||||
AcceptToken: func(clientAddr net.Addr, _ *quic.Token) bool {
|
||||
// TODO(#6): require source address validation when under load
|
||||
return true
|
||||
},
|
||||
KeepAlive: true,
|
||||
Versions: []quic.VersionNumber{quic.VersionDraft29, quic.Version1},
|
||||
}
|
||||
|
||||
const statelessResetKeyInfo = "libp2p quic stateless reset key"
|
||||
const errorCodeConnectionGating = 0x47415445 // GATE in ASCII
|
||||
|
||||
type connManager struct {
|
||||
reuseUDP4 *reuse
|
||||
reuseUDP6 *reuse
|
||||
}
|
||||
|
||||
func newConnManager() (*connManager, error) {
|
||||
reuseUDP4 := newReuse()
|
||||
reuseUDP6 := newReuse()
|
||||
|
||||
return &connManager{
|
||||
reuseUDP4: reuseUDP4,
|
||||
reuseUDP6: reuseUDP6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *connManager) getReuse(network string) (*reuse, error) {
|
||||
switch network {
|
||||
case "udp4":
|
||||
return c.reuseUDP4, nil
|
||||
case "udp6":
|
||||
return c.reuseUDP6, nil
|
||||
default:
|
||||
return nil, errors.New("invalid network: must be either udp4 or udp6")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connManager) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
|
||||
reuse, err := c.getReuse(network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reuse.Listen(network, laddr)
|
||||
}
|
||||
|
||||
func (c *connManager) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
|
||||
reuse, err := c.getReuse(network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reuse.Dial(network, raddr)
|
||||
}
|
||||
|
||||
func (c *connManager) Close() error {
|
||||
if err := c.reuseUDP6.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.reuseUDP4.Close()
|
||||
}
|
||||
|
||||
// The Transport implements the tpt.Transport interface for QUIC connections.
|
||||
type transport struct {
|
||||
privKey ic.PrivKey
|
||||
localPeer peer.ID
|
||||
identity *p2ptls.Identity
|
||||
connManager *connManager
|
||||
serverConfig *quic.Config
|
||||
clientConfig *quic.Config
|
||||
gater connmgr.ConnectionGater
|
||||
rcmgr network.ResourceManager
|
||||
|
||||
holePunchingMx sync.Mutex
|
||||
holePunching map[holePunchKey]*activeHolePunch
|
||||
|
||||
connMx sync.Mutex
|
||||
conns map[quic.Connection]*conn
|
||||
}
|
||||
|
||||
var _ tpt.Transport = &transport{}
|
||||
|
||||
type holePunchKey struct {
|
||||
addr string
|
||||
peer peer.ID
|
||||
}
|
||||
|
||||
type activeHolePunch struct {
|
||||
connCh chan tpt.CapableConn
|
||||
fulfilled bool
|
||||
}
|
||||
|
||||
// NewTransport creates a new QUIC transport
|
||||
func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Transport, error) {
|
||||
if len(psk) > 0 {
|
||||
log.Error("QUIC doesn't support private networks yet.")
|
||||
return nil, errors.New("QUIC doesn't support private networks yet")
|
||||
}
|
||||
localPeer, err := peer.IDFromPrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identity, err := p2ptls.NewIdentity(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connManager, err := newConnManager()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rcmgr == nil {
|
||||
rcmgr = network.NullResourceManager
|
||||
}
|
||||
config := quicConfig.Clone()
|
||||
keyBytes, err := key.Raw()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyReader := hkdf.New(sha256.New, keyBytes, nil, []byte(statelessResetKeyInfo))
|
||||
config.StatelessResetKey = make([]byte, 32)
|
||||
if _, err := io.ReadFull(keyReader, config.StatelessResetKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Tracer = tracer
|
||||
|
||||
tr := &transport{
|
||||
privKey: key,
|
||||
localPeer: localPeer,
|
||||
identity: identity,
|
||||
connManager: connManager,
|
||||
gater: gater,
|
||||
rcmgr: rcmgr,
|
||||
conns: make(map[quic.Connection]*conn),
|
||||
holePunching: make(map[holePunchKey]*activeHolePunch),
|
||||
}
|
||||
config.AllowConnectionWindowIncrease = tr.allowWindowIncrease
|
||||
tr.serverConfig = config
|
||||
tr.clientConfig = config.Clone()
|
||||
return tr, nil
|
||||
}
|
||||
|
||||
// Dial dials a new QUIC connection
|
||||
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
|
||||
netw, host, err := manet.DialArgs(raddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addr, err := net.ResolveUDPAddr(netw, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
remoteMultiaddr, err := toQuicMultiaddr(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConf, keyCh := t.identity.ConfigForPeer(p)
|
||||
if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient {
|
||||
return t.holePunch(ctx, netw, addr, p)
|
||||
}
|
||||
|
||||
scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false)
|
||||
if err != nil {
|
||||
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
if err := scope.SetPeer(p); err != nil {
|
||||
log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err)
|
||||
scope.Done()
|
||||
return nil, err
|
||||
}
|
||||
pconn, err := t.connManager.Dial(netw, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qconn, err := quicDialContext(ctx, pconn, addr, host, tlsConf, t.clientConfig)
|
||||
if err != nil {
|
||||
scope.Done()
|
||||
pconn.DecreaseCount()
|
||||
return nil, err
|
||||
}
|
||||
// Should be ready by this point, don't block.
|
||||
var remotePubKey ic.PubKey
|
||||
select {
|
||||
case remotePubKey = <-keyCh:
|
||||
default:
|
||||
}
|
||||
if remotePubKey == nil {
|
||||
pconn.DecreaseCount()
|
||||
scope.Done()
|
||||
return nil, errors.New("p2p/transport/quic BUG: expected remote pub key to be set")
|
||||
}
|
||||
|
||||
localMultiaddr, err := toQuicMultiaddr(pconn.LocalAddr())
|
||||
if err != nil {
|
||||
qconn.CloseWithError(0, "")
|
||||
return nil, err
|
||||
}
|
||||
c := &conn{
|
||||
quicConn: qconn,
|
||||
pconn: pconn,
|
||||
transport: t,
|
||||
scope: scope,
|
||||
privKey: t.privKey,
|
||||
localPeer: t.localPeer,
|
||||
localMultiaddr: localMultiaddr,
|
||||
remotePubKey: remotePubKey,
|
||||
remotePeerID: p,
|
||||
remoteMultiaddr: remoteMultiaddr,
|
||||
}
|
||||
if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, c) {
|
||||
qconn.CloseWithError(errorCodeConnectionGating, "connection gated")
|
||||
return nil, fmt.Errorf("secured connection gated")
|
||||
}
|
||||
t.addConn(qconn, c)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (t *transport) addConn(conn quic.Connection, c *conn) {
|
||||
t.connMx.Lock()
|
||||
t.conns[conn] = c
|
||||
t.connMx.Unlock()
|
||||
}
|
||||
|
||||
func (t *transport) removeConn(conn quic.Connection) {
|
||||
t.connMx.Lock()
|
||||
delete(t.conns, conn)
|
||||
t.connMx.Unlock()
|
||||
}
|
||||
|
||||
func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDPAddr, p peer.ID) (tpt.CapableConn, error) {
|
||||
pconn, err := t.connManager.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer pconn.DecreaseCount()
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, HolePunchTimeout)
|
||||
defer cancel()
|
||||
|
||||
key := holePunchKey{addr: addr.String(), peer: p}
|
||||
t.holePunchingMx.Lock()
|
||||
if _, ok := t.holePunching[key]; ok {
|
||||
t.holePunchingMx.Unlock()
|
||||
return nil, fmt.Errorf("already punching hole for %s", addr)
|
||||
}
|
||||
connCh := make(chan tpt.CapableConn, 1)
|
||||
t.holePunching[key] = &activeHolePunch{connCh: connCh}
|
||||
t.holePunchingMx.Unlock()
|
||||
|
||||
var timer *time.Timer
|
||||
defer func() {
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
}()
|
||||
|
||||
payload := make([]byte, 64)
|
||||
var punchErr error
|
||||
loop:
|
||||
for i := 0; ; i++ {
|
||||
if _, err := rand.Read(payload); err != nil {
|
||||
punchErr = err
|
||||
break
|
||||
}
|
||||
if _, err := pconn.UDPConn.WriteToUDP(payload, addr); err != nil {
|
||||
punchErr = err
|
||||
break
|
||||
}
|
||||
|
||||
maxSleep := 10 * (i + 1) * (i + 1) // in ms
|
||||
if maxSleep > 200 {
|
||||
maxSleep = 200
|
||||
}
|
||||
d := 10*time.Millisecond + time.Duration(rand.Intn(maxSleep))*time.Millisecond
|
||||
if timer == nil {
|
||||
timer = time.NewTimer(d)
|
||||
} else {
|
||||
timer.Reset(d)
|
||||
}
|
||||
select {
|
||||
case c := <-connCh:
|
||||
t.holePunchingMx.Lock()
|
||||
delete(t.holePunching, key)
|
||||
t.holePunchingMx.Unlock()
|
||||
return c, nil
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
punchErr = ErrHolePunching
|
||||
break loop
|
||||
}
|
||||
}
|
||||
// we only arrive here if punchErr != nil
|
||||
t.holePunchingMx.Lock()
|
||||
defer func() {
|
||||
delete(t.holePunching, key)
|
||||
t.holePunchingMx.Unlock()
|
||||
}()
|
||||
select {
|
||||
case c := <-t.holePunching[key].connCh:
|
||||
return c, nil
|
||||
default:
|
||||
return nil, punchErr
|
||||
}
|
||||
}
|
||||
|
||||
// Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic
|
||||
var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC))
|
||||
|
||||
// CanDial determines if we can dial to an address
|
||||
func (t *transport) CanDial(addr ma.Multiaddr) bool {
|
||||
return dialMatcher.Matches(addr)
|
||||
}
|
||||
|
||||
// Listen listens for new QUIC connections on the passed multiaddr.
|
||||
func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) {
|
||||
lnet, host, err := manet.DialArgs(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
laddr, err := net.ResolveUDPAddr(lnet, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := t.connManager.Listen(lnet, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ln, err := newListener(conn, t, t.localPeer, t.privKey, t.identity, t.rcmgr)
|
||||
if err != nil {
|
||||
conn.DecreaseCount()
|
||||
return nil, err
|
||||
}
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool {
|
||||
// If the QUIC connection tries to increase the window before we've inserted it
|
||||
// into our connections map (which we do right after dialing / accepting it),
|
||||
// we have no way to account for that memory. This should be very rare.
|
||||
// Block this attempt. The connection can request more memory later.
|
||||
t.connMx.Lock()
|
||||
c, ok := t.conns[conn]
|
||||
t.connMx.Unlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return c.allowWindowIncrease(size)
|
||||
}
|
||||
|
||||
// Proxy returns true if this transport proxies.
|
||||
func (t *transport) Proxy() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Protocols returns the set of protocols handled by this transport.
|
||||
func (t *transport) Protocols() []int {
|
||||
return []int{ma.P_QUIC}
|
||||
}
|
||||
|
||||
func (t *transport) String() string {
|
||||
return "QUIC"
|
||||
}
|
||||
|
||||
func (t *transport) Close() error {
|
||||
return t.connManager.Close()
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
package libp2pquic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
ic "github.com/libp2p/go-libp2p-core/crypto"
|
||||
tpt "github.com/libp2p/go-libp2p-core/transport"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
func getTransport(t *testing.T) tpt.Transport {
|
||||
t.Helper()
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey))
|
||||
require.NoError(t, err)
|
||||
tr, err := NewTransport(key, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
return tr
|
||||
}
|
||||
|
||||
func TestQUICProtocol(t *testing.T) {
|
||||
tr := getTransport(t)
|
||||
defer tr.(io.Closer).Close()
|
||||
|
||||
protocols := tr.Protocols()
|
||||
if len(protocols) != 1 {
|
||||
t.Fatalf("expected to only support a single protocol, got %v", protocols)
|
||||
}
|
||||
if protocols[0] != ma.P_QUIC {
|
||||
t.Fatalf("expected the supported protocol to be QUIC, got %d", protocols[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanDial(t *testing.T) {
|
||||
tr := getTransport(t)
|
||||
defer tr.(io.Closer).Close()
|
||||
|
||||
invalid := []string{
|
||||
"/ip4/127.0.0.1/udp/1234",
|
||||
"/ip4/5.5.5.5/tcp/1234",
|
||||
"/dns/google.com/udp/443/quic",
|
||||
}
|
||||
valid := []string{
|
||||
"/ip4/127.0.0.1/udp/1234/quic",
|
||||
"/ip4/5.5.5.5/udp/0/quic",
|
||||
}
|
||||
for _, s := range invalid {
|
||||
invalidAddr, err := ma.NewMultiaddr(s)
|
||||
require.NoError(t, err)
|
||||
if tr.CanDial(invalidAddr) {
|
||||
t.Errorf("didn't expect to be able to dial a non-quic address (%s)", invalidAddr)
|
||||
}
|
||||
}
|
||||
for _, s := range valid {
|
||||
validAddr, err := ma.NewMultiaddr(s)
|
||||
require.NoError(t, err)
|
||||
if !tr.CanDial(validAddr) {
|
||||
t.Errorf("expected to be able to dial QUIC address (%s)", validAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The connection passed to quic-go needs to be type-assertable to a net.UDPConn,
|
||||
// in order to enable features like batch processing and ECN.
|
||||
func TestConnectionPassedToQUIC(t *testing.T) {
|
||||
tr := getTransport(t)
|
||||
defer tr.(io.Closer).Close()
|
||||
|
||||
origQuicDialContext := quicDialContext
|
||||
defer func() { quicDialContext = origQuicDialContext }()
|
||||
|
||||
var conn net.PacketConn
|
||||
quicDialContext = func(_ context.Context, c net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.Connection, error) {
|
||||
conn = c
|
||||
return nil, errors.New("listen error")
|
||||
}
|
||||
remoteAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
|
||||
require.NoError(t, err)
|
||||
_, err = tr.Dial(context.Background(), remoteAddr, "remote peer id")
|
||||
require.EqualError(t, err, "listen error")
|
||||
require.NotNil(t, conn)
|
||||
defer conn.Close()
|
||||
if _, ok := conn.(quic.OOBCapablePacketConn); !ok {
|
||||
t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue