Merge pull request #1424 from libp2p/merge-quic

move go-libp2p-quic-transport here
Authored by Marten Seemann on 2022-04-22 17:59:23 +01:00, committed by GitHub
commit 5151d4b4fa
21 changed files with 2792 additions and 9 deletions


@ -6,6 +6,7 @@ import (
"crypto/rand"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
@ -13,7 +14,6 @@ import (
noise "github.com/libp2p/go-libp2p-noise"
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
quic "github.com/libp2p/go-libp2p-quic-transport"
rcmgr "github.com/libp2p/go-libp2p-resource-manager"
tls "github.com/libp2p/go-libp2p-tls"
yamux "github.com/libp2p/go-libp2p-yamux"
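The hunk above swaps the external go-libp2p-quic-transport import for the new in-tree package. For downstream code the move is purely an import-path change; a minimal sketch of the before/after (the package alias is an arbitrary choice, not something this PR mandates):

import (
	// before this PR (archived repository):
	// quic "github.com/libp2p/go-libp2p-quic-transport"

	// after this PR (in-tree package):
	quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
)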

go.mod (12 lines changed)

@ -6,12 +6,14 @@ require (
github.com/benbjohnson/clock v1.3.0
github.com/gogo/protobuf v1.3.2
github.com/golang/mock v1.6.0
github.com/google/gopacket v1.1.19
github.com/gorilla/websocket v1.5.0
github.com/hashicorp/golang-lru v0.5.4
github.com/ipfs/go-cid v0.1.0
github.com/ipfs/go-datastore v0.5.1
github.com/ipfs/go-ipfs-util v0.0.2
github.com/ipfs/go-log/v2 v2.5.1
github.com/klauspost/compress v1.15.1
github.com/libp2p/go-buffer-pool v0.0.2
github.com/libp2p/go-conn-security-multistream v0.3.0
github.com/libp2p/go-eventbus v0.2.1
@ -22,7 +24,6 @@ require (
github.com/libp2p/go-libp2p-nat v0.1.0
github.com/libp2p/go-libp2p-noise v0.4.0
github.com/libp2p/go-libp2p-peerstore v0.6.0
github.com/libp2p/go-libp2p-quic-transport v0.17.0
github.com/libp2p/go-libp2p-resource-manager v0.2.1
github.com/libp2p/go-libp2p-testing v0.9.2
github.com/libp2p/go-libp2p-tls v0.4.1
@ -34,8 +35,10 @@ require (
github.com/libp2p/go-reuseport-transport v0.1.0
github.com/libp2p/go-stream-muxer-multistream v0.4.0
github.com/libp2p/zeroconf/v2 v2.1.1
github.com/lucas-clemente/quic-go v0.27.0
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd
github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b
github.com/minio/sha256-simd v1.0.0
github.com/multiformats/go-multiaddr v0.5.0
github.com/multiformats/go-multiaddr-dns v0.3.1
github.com/multiformats/go-multiaddr-fmt v0.1.0
@ -47,6 +50,7 @@ require (
github.com/stretchr/testify v1.7.0
github.com/whyrusleeping/mdns v0.0.0-20190826153040-b9b60ed33aa9
github.com/whyrusleeping/multiaddr-filter v0.0.0-20160516205228-e903e4adabd7
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
)
@ -68,26 +72,24 @@ require (
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/gopacket v1.1.19 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/huin/goupnp v1.0.3 // indirect
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect
github.com/jbenet/goprocess v0.1.4 // indirect
github.com/klauspost/compress v1.15.1 // indirect
github.com/klauspost/cpuid/v2 v2.0.12 // indirect
github.com/koron/go-ssdp v0.0.2 // indirect
github.com/libp2p/go-cidranger v1.1.0 // indirect
github.com/libp2p/go-flow-metrics v0.0.3 // indirect
github.com/libp2p/go-libp2p-blankhost v0.3.0 // indirect
github.com/libp2p/go-libp2p-pnet v0.2.0 // indirect
github.com/libp2p/go-libp2p-quic-transport v0.17.0 // indirect
github.com/libp2p/go-libp2p-swarm v0.10.2 // indirect
github.com/libp2p/go-mplex v0.4.0 // indirect
github.com/libp2p/go-nat v0.1.0 // indirect
github.com/libp2p/go-openssl v0.0.7 // indirect
github.com/libp2p/go-tcp-transport v0.5.1 // indirect
github.com/libp2p/go-yamux/v3 v3.1.1 // indirect
github.com/lucas-clemente/quic-go v0.27.0 // indirect
github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect
github.com/marten-seemann/qtls-go1-17 v0.1.1 // indirect
github.com/marten-seemann/qtls-go1-18 v0.1.1 // indirect
@ -96,7 +98,6 @@ require (
github.com/miekg/dns v1.1.48 // indirect
github.com/mikioh/tcpopt v0.0.0-20190314235656-172688c1accc // indirect
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect
github.com/minio/sha256-simd v1.0.0 // indirect
github.com/mr-tron/base58 v1.2.0 // indirect
github.com/multiformats/go-base32 v0.0.4 // indirect
github.com/multiformats/go-base36 v0.1.0 // indirect
@ -117,7 +118,6 @@ require (
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.8.0 // indirect
go.uber.org/zap v1.21.0 // indirect
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 // indirect
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect


@ -8,6 +8,7 @@ import (
"testing"
"time"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
"github.com/libp2p/go-libp2p-core/peerstore"
@ -16,7 +17,6 @@ import (
csms "github.com/libp2p/go-conn-security-multistream"
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
quic "github.com/libp2p/go-libp2p-quic-transport"
tnet "github.com/libp2p/go-libp2p-testing/net"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
yamux "github.com/libp2p/go-libp2p-yamux"


@ -5,6 +5,7 @@ import (
"time"
"github.com/libp2p/go-libp2p/p2p/net/swarm"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
"github.com/libp2p/go-libp2p-core/connmgr"
@ -19,7 +20,6 @@ import (
csms "github.com/libp2p/go-conn-security-multistream"
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
quic "github.com/libp2p/go-libp2p-quic-transport"
tnet "github.com/libp2p/go-libp2p-testing/net"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
yamux "github.com/libp2p/go-libp2p-yamux"


@ -0,0 +1,71 @@
package main
import (
"context"
"crypto/rand"
"fmt"
"io/ioutil"
"log"
"os"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic"
ma "github.com/multiformats/go-multiaddr"
)
func main() {
if len(os.Args) != 3 {
fmt.Printf("Usage: %s <multiaddr> <peer id>", os.Args[0])
return
}
if err := run(os.Args[1], os.Args[2]); err != nil {
log.Fatal(err)
}
}
func run(raddr string, p string) error {
peerID, err := peer.Decode(p)
if err != nil {
return err
}
addr, err := ma.NewMultiaddr(raddr)
if err != nil {
return err
}
priv, _, err := ic.GenerateECDSAKeyPair(rand.Reader)
if err != nil {
return err
}
t, err := libp2pquic.NewTransport(priv, nil, nil, nil)
if err != nil {
return err
}
log.Printf("Dialing %s\n", addr.String())
conn, err := t.Dial(context.Background(), addr, peerID)
if err != nil {
return err
}
defer conn.Close()
str, err := conn.OpenStream(context.Background())
if err != nil {
return err
}
defer str.Close()
const msg = "Hello world!"
log.Printf("Sending: %s\n", msg)
if _, err := str.Write([]byte(msg)); err != nil {
return err
}
if err := str.CloseWrite(); err != nil {
return err
}
data, err := ioutil.ReadAll(str)
if err != nil {
return err
}
log.Printf("Received: %s\n", data)
return nil
}


@ -0,0 +1,79 @@
package main
import (
"crypto/rand"
"fmt"
"io/ioutil"
"log"
"os"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
tpt "github.com/libp2p/go-libp2p-core/transport"
libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic"
ma "github.com/multiformats/go-multiaddr"
)
func main() {
if len(os.Args) != 2 {
fmt.Printf("Usage: %s <port>", os.Args[0])
return
}
if err := run(os.Args[1]); err != nil {
log.Fatal(err)
}
}
func run(port string) error {
addr, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/udp/%s/quic", port))
if err != nil {
return err
}
priv, _, err := ic.GenerateECDSAKeyPair(rand.Reader)
if err != nil {
return err
}
peerID, err := peer.IDFromPrivateKey(priv)
if err != nil {
return err
}
t, err := libp2pquic.NewTransport(priv, nil, nil, nil)
if err != nil {
return err
}
ln, err := t.Listen(addr)
if err != nil {
return err
}
fmt.Printf("Listening. Now run: go run cmd/client/main.go %s %s\n", ln.Multiaddr(), peerID)
for {
conn, err := ln.Accept()
if err != nil {
return err
}
log.Printf("Accepted new connection from %s (%s)\n", conn.RemotePeer(), conn.RemoteMultiaddr())
go func() {
if err := handleConn(conn); err != nil {
log.Printf("handling conn failed: %s", err.Error())
}
}()
}
}
func handleConn(conn tpt.CapableConn) error {
str, err := conn.AcceptStream()
if err != nil {
return err
}
data, err := ioutil.ReadAll(str)
if err != nil {
return err
}
log.Printf("Received: %s\n", data)
if _, err := str.Write(data); err != nil {
return err
}
return str.Close()
}

p2p/transport/quic/conn.go (new file, 100 lines)

@ -0,0 +1,100 @@
package libp2pquic
import (
"context"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
tpt "github.com/libp2p/go-libp2p-core/transport"
"github.com/lucas-clemente/quic-go"
ma "github.com/multiformats/go-multiaddr"
)
type conn struct {
quicConn quic.Connection
pconn *reuseConn
transport *transport
scope network.ConnManagementScope
localPeer peer.ID
privKey ic.PrivKey
localMultiaddr ma.Multiaddr
remotePeerID peer.ID
remotePubKey ic.PubKey
remoteMultiaddr ma.Multiaddr
}
var _ tpt.CapableConn = &conn{}
// Close closes the connection.
// It must be called even if the peer closed the connection in order for
// garbage collection to properly work in this package.
func (c *conn) Close() error {
c.transport.removeConn(c.quicConn)
err := c.quicConn.CloseWithError(0, "")
c.pconn.DecreaseCount()
c.scope.Done()
return err
}
// IsClosed returns whether a connection is fully closed.
func (c *conn) IsClosed() bool {
return c.quicConn.Context().Err() != nil
}
func (c *conn) allowWindowIncrease(size uint64) bool {
return c.scope.ReserveMemory(int(size), network.ReservationPriorityMedium) == nil
}
// OpenStream creates a new stream.
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
qstr, err := c.quicConn.OpenStreamSync(ctx)
return &stream{Stream: qstr}, err
}
// AcceptStream accepts a stream opened by the other side.
func (c *conn) AcceptStream() (network.MuxedStream, error) {
qstr, err := c.quicConn.AcceptStream(context.Background())
return &stream{Stream: qstr}, err
}
// LocalPeer returns our peer ID
func (c *conn) LocalPeer() peer.ID {
return c.localPeer
}
// LocalPrivateKey returns our private key
func (c *conn) LocalPrivateKey() ic.PrivKey {
return c.privKey
}
// RemotePeer returns the peer ID of the remote peer.
func (c *conn) RemotePeer() peer.ID {
return c.remotePeerID
}
// RemotePublicKey returns the public key of the remote peer.
func (c *conn) RemotePublicKey() ic.PubKey {
return c.remotePubKey
}
// LocalMultiaddr returns the local Multiaddr associated with this connection.
func (c *conn) LocalMultiaddr() ma.Multiaddr {
return c.localMultiaddr
}
// RemoteMultiaddr returns the remote Multiaddr associated with this connection.
func (c *conn) RemoteMultiaddr() ma.Multiaddr {
return c.remoteMultiaddr
}
func (c *conn) Transport() tpt.Transport {
return c.transport
}
func (c *conn) Scope() network.ConnScope {
return c.scope
}
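The Close contract documented above ties into the socket-reuse machinery further down in this PR: each conn holds a reference on its shared reuseConn and a resource-manager scope, both released only in Close. A minimal caller-side sketch of what that implies (a hypothetical helper, not part of this PR; context, tpt, ma and peer are the packages already imported in this file):

// pingOnce dials, uses, and always closes the connection. Closing is
// required even if the peer already closed its side; otherwise the
// shared UDP socket's reference count never drops and the socket is
// never garbage collected.
func pingOnce(ctx context.Context, t tpt.Transport, raddr ma.Multiaddr, p peer.ID) error {
	conn, err := t.Dial(ctx, raddr, p)
	if err != nil {
		return err
	}
	defer conn.Close() // releases the reuseConn refcount and the rcmgr scope

	str, err := conn.OpenStream(ctx)
	if err != nil {
		return err
	}
	defer str.Close()
	_, err = str.Write([]byte("ping"))
	return err
}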


@ -0,0 +1,567 @@
package libp2pquic
import (
"bytes"
"context"
"crypto/rand"
"errors"
"fmt"
"io"
"io/ioutil"
mrand "math/rand"
"net"
"sync/atomic"
"testing"
"time"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
tpt "github.com/libp2p/go-libp2p-core/transport"
mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network"
quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
ma "github.com/multiformats/go-multiaddr"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
//go:generate sh -c "mockgen -package libp2pquic -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p-core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go"
func createPeer(t *testing.T) (peer.ID, ic.PrivKey) {
var priv ic.PrivKey
var err error
switch mrand.Int() % 4 {
case 0:
priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader)
case 1:
priv, _, err = ic.GenerateRSAKeyPair(2048, rand.Reader)
case 2:
priv, _, err = ic.GenerateEd25519Key(rand.Reader)
case 3:
priv, _, err = ic.GenerateSecp256k1Key(rand.Reader)
}
require.NoError(t, err)
id, err := peer.IDFromPrivateKey(priv)
require.NoError(t, err)
t.Logf("using a %s key: %s", priv.Type(), id.Pretty())
return id, priv
}
func runServer(t *testing.T, tr tpt.Transport, addr string) tpt.Listener {
t.Helper()
ln, err := tr.Listen(ma.StringCast(addr))
require.NoError(t, err)
return ln
}
func TestHandshake(t *testing.T) {
serverID, serverKey := createPeer(t)
clientID, clientKey := createPeer(t)
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
handshake := func(t *testing.T, ln tpt.Listener) {
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
serverConn, err := ln.Accept()
require.NoError(t, err)
defer serverConn.Close()
require.Equal(t, conn.LocalPeer(), clientID)
require.True(t, conn.LocalPrivateKey().Equals(clientKey), "local private key doesn't match")
require.Equal(t, conn.RemotePeer(), serverID)
require.True(t, conn.RemotePublicKey().Equals(serverKey.GetPublic()), "remote public key doesn't match")
require.Equal(t, serverConn.LocalPeer(), serverID)
require.True(t, serverConn.LocalPrivateKey().Equals(serverKey), "local private key doesn't match")
require.Equal(t, serverConn.RemotePeer(), clientID)
require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "remote public key doesn't match")
}
t.Run("on IPv4", func(t *testing.T) {
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()
handshake(t, ln)
})
t.Run("on IPv6", func(t *testing.T) {
ln := runServer(t, serverTransport, "/ip6/::1/udp/0/quic")
defer ln.Close()
handshake(t, ln)
})
}
func TestResourceManagerSuccess(t *testing.T) {
serverID, serverKey := createPeer(t)
clientID, clientKey := createPeer(t)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
serverRcmgr := mocknetwork.NewMockResourceManager(ctrl)
serverTransport, err := NewTransport(serverKey, nil, nil, serverRcmgr)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"))
require.NoError(t, err)
defer ln.Close()
clientRcmgr := mocknetwork.NewMockResourceManager(ctrl)
clientTransport, err := NewTransport(clientKey, nil, nil, clientRcmgr)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
connChan := make(chan tpt.CapableConn)
serverConnScope := mocknetwork.NewMockConnManagementScope(ctrl)
go func() {
serverRcmgr.EXPECT().OpenConnection(network.DirInbound, false).Return(serverConnScope, nil)
serverConnScope.EXPECT().SetPeer(clientID)
serverConn, err := ln.Accept()
require.NoError(t, err)
connChan <- serverConn
}()
connScope := mocknetwork.NewMockConnManagementScope(ctrl)
clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false).Return(connScope, nil)
connScope.EXPECT().SetPeer(serverID)
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
serverConn := <-connChan
t.Log("received conn")
connScope.EXPECT().Done().MinTimes(1) // for dialed connections, we might call Done multiple times
conn.Close()
serverConnScope.EXPECT().Done()
serverConn.Close()
}
func TestResourceManagerDialDenied(t *testing.T) {
_, clientKey := createPeer(t)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
clientTransport, err := NewTransport(clientKey, nil, nil, rcmgr)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
connScope := mocknetwork.NewMockConnManagementScope(ctrl)
rcmgr.EXPECT().OpenConnection(network.DirOutbound, false).Return(connScope, nil)
rerr := errors.New("nope")
p := peer.ID("server")
connScope.EXPECT().SetPeer(p).Return(rerr)
connScope.EXPECT().Done()
_, err = clientTransport.Dial(context.Background(), ma.StringCast("/ip4/127.0.0.1/udp/1234/quic"), p)
require.ErrorIs(t, err, rerr)
}
func TestResourceManagerAcceptDenied(t *testing.T) {
serverID, serverKey := createPeer(t)
clientID, clientKey := createPeer(t)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
clientRcmgr := mocknetwork.NewMockResourceManager(ctrl)
clientTransport, err := NewTransport(clientKey, nil, nil, clientRcmgr)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
serverRcmgr := mocknetwork.NewMockResourceManager(ctrl)
serverConnScope := mocknetwork.NewMockConnManagementScope(ctrl)
rerr := errors.New("denied")
gomock.InOrder(
serverRcmgr.EXPECT().OpenConnection(network.DirInbound, false).Return(serverConnScope, nil),
serverConnScope.EXPECT().SetPeer(clientID).Return(rerr),
serverConnScope.EXPECT().Done(),
)
serverTransport, err := NewTransport(serverKey, nil, nil, serverRcmgr)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"))
require.NoError(t, err)
defer ln.Close()
connChan := make(chan tpt.CapableConn)
go func() {
ln.Accept()
close(connChan)
}()
clientConnScope := mocknetwork.NewMockConnManagementScope(ctrl)
clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false).Return(clientConnScope, nil)
clientConnScope.EXPECT().SetPeer(serverID)
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
_, err = conn.AcceptStream()
require.Error(t, err)
select {
case <-connChan:
t.Fatal("didn't expect to accept a connection")
default:
}
}
func TestStreams(t *testing.T) {
serverID, serverKey := createPeer(t)
_, clientKey := createPeer(t)
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
serverConn, err := ln.Accept()
require.NoError(t, err)
defer serverConn.Close()
str, err := conn.OpenStream(context.Background())
require.NoError(t, err)
_, err = str.Write([]byte("foobar"))
require.NoError(t, err)
str.Close()
sstr, err := serverConn.AcceptStream()
require.NoError(t, err)
data, err := ioutil.ReadAll(sstr)
require.NoError(t, err)
require.Equal(t, data, []byte("foobar"))
}
func TestHandshakeFailPeerIDMismatch(t *testing.T) {
_, serverKey := createPeer(t)
_, clientKey := createPeer(t)
thirdPartyID, _ := createPeer(t)
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
require.NoError(t, err)
// dial, but expect the wrong peer ID
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID)
require.Error(t, err)
require.Contains(t, err.Error(), "CRYPTO_ERROR")
defer clientTransport.(io.Closer).Close()
acceptErr := make(chan error)
go func() {
_, err := ln.Accept()
acceptErr <- err
}()
select {
case <-acceptErr:
t.Fatal("didn't expect Accept to return before being closed")
case <-time.After(100 * time.Millisecond):
}
require.NoError(t, ln.Close())
require.Error(t, <-acceptErr)
}
func TestConnectionGating(t *testing.T) {
serverID, serverKey := createPeer(t)
_, clientKey := createPeer(t)
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
cg := NewMockConnectionGater(mockCtrl)
t.Run("accepted connections", func(t *testing.T) {
serverTransport, err := NewTransport(serverKey, nil, cg, nil)
defer serverTransport.(io.Closer).Close()
require.NoError(t, err)
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()
cg.EXPECT().InterceptAccept(gomock.Any())
accepted := make(chan struct{})
go func() {
defer close(accepted)
_, err := ln.Accept()
require.NoError(t, err)
}()
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
// make sure that the connection attempt fails
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
_, err = conn.AcceptStream()
require.Error(t, err)
require.Contains(t, err.Error(), "connection gated")
// now allow the address and make sure the connection goes through
cg.EXPECT().InterceptAccept(gomock.Any()).Return(true)
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
clientTransport.(*transport).clientConfig.HandshakeIdleTimeout = 2 * time.Second
conn, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
require.Eventually(t, func() bool {
select {
case <-accepted:
return true
default:
return false
}
}, time.Second, 10*time.Millisecond)
})
t.Run("secured connections", func(t *testing.T) {
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()
cg := NewMockConnectionGater(mockCtrl)
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any())
clientTransport, err := NewTransport(clientKey, nil, cg, nil)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
// make sure that the connection attempt fails
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.Error(t, err)
require.Contains(t, err.Error(), "connection gated")
// now allow the peer ID and make sure the connection goes through
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
clientTransport.(*transport).clientConfig.HandshakeIdleTimeout = 2 * time.Second
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
conn.Close()
})
}
func TestDialTwo(t *testing.T) {
serverID, serverKey := createPeer(t)
_, clientKey := createPeer(t)
serverID2, serverKey2 := createPeer(t)
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
ln1 := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln1.Close()
serverTransport2, err := NewTransport(serverKey2, nil, nil, nil)
require.NoError(t, err)
defer serverTransport2.(io.Closer).Close()
ln2 := runServer(t, serverTransport2, "/ip4/127.0.0.1/udp/0/quic")
defer ln2.Close()
data := bytes.Repeat([]byte{'a'}, 5*1<<20) // 5 MB
// wait for both servers to accept a connection
// then send some data
go func() {
serverConn1, err := ln1.Accept()
require.NoError(t, err)
serverConn2, err := ln2.Accept()
require.NoError(t, err)
for _, c := range []tpt.CapableConn{serverConn1, serverConn2} {
go func(conn tpt.CapableConn) {
str, err := conn.OpenStream(context.Background())
require.NoError(t, err)
defer str.Close()
_, err = str.Write(data)
require.NoError(t, err)
}(c)
}
}()
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID)
require.NoError(t, err)
defer c1.Close()
c2, err := clientTransport.Dial(context.Background(), ln2.Multiaddr(), serverID2)
require.NoError(t, err)
defer c2.Close()
done := make(chan struct{}, 2)
// receive the data on both connections at the same time
for _, c := range []tpt.CapableConn{c1, c2} {
go func(conn tpt.CapableConn) {
str, err := conn.AcceptStream()
require.NoError(t, err)
str.CloseWrite()
d, err := ioutil.ReadAll(str)
require.NoError(t, err)
require.Equal(t, d, data)
done <- struct{}{}
}(c)
}
for i := 0; i < 2; i++ {
require.Eventually(t, func() bool {
select {
case <-done:
return true
default:
return false
}
}, 15*time.Second, 50*time.Millisecond)
}
}
func TestStatelessReset(t *testing.T) {
origGarbageCollectInterval := garbageCollectInterval
origMaxUnusedDuration := maxUnusedDuration
garbageCollectInterval = 50 * time.Millisecond
maxUnusedDuration = 0
t.Cleanup(func() {
garbageCollectInterval = origGarbageCollectInterval
maxUnusedDuration = origMaxUnusedDuration
})
serverID, serverKey := createPeer(t)
_, clientKey := createPeer(t)
serverTransport, err := NewTransport(serverKey, nil, nil, nil)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic")
var drop uint32
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DropPacket: func(quicproxy.Direction, []byte) bool {
return atomic.LoadUint32(&drop) > 0
},
})
require.NoError(t, err)
defer proxy.Close()
// establish a connection
clientTransport, err := NewTransport(clientKey, nil, nil, nil)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr())
require.NoError(t, err)
conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID)
require.NoError(t, err)
connChan := make(chan tpt.CapableConn)
go func() {
conn, err := ln.Accept()
require.NoError(t, err)
str, err := conn.OpenStream(context.Background())
require.NoError(t, err)
str.Write([]byte("foobar"))
connChan <- conn
}()
str, err := conn.AcceptStream()
require.NoError(t, err)
_, err = str.Read(make([]byte, 6))
require.NoError(t, err)
// Stop forwarding packets and close the server.
// This prevents the CONNECTION_CLOSE from reaching the client.
atomic.StoreUint32(&drop, 1)
ln.Close()
(<-connChan).Close()
// require.NoError(t, ln.Close())
time.Sleep(2000 * time.Millisecond) // give the kernel some time to free the UDP port
ln = runServer(t, serverTransport, fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic", serverPort))
defer ln.Close()
// Now that the new server is up, re-enable packet forwarding.
atomic.StoreUint32(&drop, 0)
// Trigger something (not too small) to be sent, so that we receive the stateless reset.
// The new server doesn't have any state for the previously established connection.
// We expect it to send a stateless reset.
_, rerr := str.Write([]byte("Lorem ipsum dolor sit amet."))
if rerr == nil {
_, rerr = str.Read([]byte{0, 0})
}
require.Error(t, rerr)
require.Contains(t, rerr.Error(), "received a stateless reset")
}
func TestHolePunching(t *testing.T) {
serverID, serverKey := createPeer(t)
clientID, clientKey := createPeer(t)
t1, err := NewTransport(serverKey, nil, nil, nil)
require.NoError(t, err)
defer t1.(io.Closer).Close()
laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
require.NoError(t, err)
ln1, err := t1.Listen(laddr)
require.NoError(t, err)
done1 := make(chan struct{})
go func() {
defer close(done1)
_, err := ln1.Accept()
require.Error(t, err, "didn't expect to accept any connections")
}()
t2, err := NewTransport(clientKey, nil, nil, nil)
require.NoError(t, err)
defer t2.(io.Closer).Close()
ln2, err := t2.Listen(laddr)
require.NoError(t, err)
done2 := make(chan struct{})
go func() {
defer close(done2)
_, err := ln2.Accept()
require.Error(t, err, "didn't expect to accept any connections")
}()
connChan := make(chan tpt.CapableConn)
go func() {
conn, err := t2.Dial(
network.WithSimultaneousConnect(context.Background(), false, ""),
ln1.Multiaddr(),
serverID,
)
require.NoError(t, err)
connChan <- conn
}()
conn1, err := t1.Dial(
network.WithSimultaneousConnect(context.Background(), true, ""),
ln2.Multiaddr(),
clientID,
)
require.NoError(t, err)
defer conn1.Close()
require.Equal(t, conn1.RemotePeer(), clientID)
var conn2 tpt.CapableConn
require.Eventually(t, func() bool {
select {
case conn2 = <-connChan:
return true
default:
return false
}
}, 100*time.Millisecond, 10*time.Millisecond)
defer conn2.Close()
require.Equal(t, conn2.RemotePeer(), serverID)
ln1.Close()
ln2.Close()
<-done1
<-done2
}


@ -0,0 +1,160 @@
package libp2pquic
import (
"context"
"crypto/tls"
"net"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
tpt "github.com/libp2p/go-libp2p-core/transport"
p2ptls "github.com/libp2p/go-libp2p-tls"
"github.com/lucas-clemente/quic-go"
ma "github.com/multiformats/go-multiaddr"
)
var quicListen = quic.Listen // so we can mock it in tests
// A listener listens for QUIC connections.
type listener struct {
quicListener quic.Listener
conn *reuseConn
transport *transport
rcmgr network.ResourceManager
privKey ic.PrivKey
localPeer peer.ID
localMultiaddr ma.Multiaddr
}
var _ tpt.Listener = &listener{}
func newListener(rconn *reuseConn, t *transport, localPeer peer.ID, key ic.PrivKey, identity *p2ptls.Identity, rcmgr network.ResourceManager) (tpt.Listener, error) {
var tlsConf tls.Config
tlsConf.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
// return a tls.Config that verifies the peer's certificate chain.
// Note that since we have no way of associating an incoming QUIC connection with
// the peer ID calculated here, we don't actually receive the peer's public key
// from the key chan.
conf, _ := identity.ConfigForPeer("")
return conf, nil
}
ln, err := quicListen(rconn, &tlsConf, t.serverConfig)
if err != nil {
return nil, err
}
localMultiaddr, err := toQuicMultiaddr(ln.Addr())
if err != nil {
return nil, err
}
return &listener{
conn: rconn,
quicListener: ln,
transport: t,
rcmgr: rcmgr,
privKey: key,
localPeer: localPeer,
localMultiaddr: localMultiaddr,
}, nil
}
// Accept accepts new connections.
func (l *listener) Accept() (tpt.CapableConn, error) {
for {
qconn, err := l.quicListener.Accept(context.Background())
if err != nil {
return nil, err
}
c, err := l.setupConn(qconn)
if err != nil {
qconn.CloseWithError(0, err.Error())
continue
}
if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(c) && l.transport.gater.InterceptSecured(network.DirInbound, c.remotePeerID, c)) {
c.scope.Done()
qconn.CloseWithError(errorCodeConnectionGating, "connection gated")
continue
}
l.transport.addConn(qconn, c)
// If an active hole punch is waiting for this connection, hand it off
// to the dialing side instead of returning it from Accept.
key := holePunchKey{addr: qconn.RemoteAddr().String(), peer: c.remotePeerID}
var wasHolePunch bool
l.transport.holePunchingMx.Lock()
holePunch, ok := l.transport.holePunching[key]
if ok && !holePunch.fulfilled {
holePunch.connCh <- c
wasHolePunch = true
holePunch.fulfilled = true
}
l.transport.holePunchingMx.Unlock()
if wasHolePunch {
continue
}
return c, nil
}
}
func (l *listener) setupConn(qconn quic.Connection) (*conn, error) {
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false)
if err != nil {
log.Debugw("resource manager blocked incoming connection", "addr", qconn.RemoteAddr(), "error", err)
return nil, err
}
// The tls.Config used to establish this connection already verified the certificate chain.
// Since we don't have any way of knowing which tls.Config was used though,
// we have to re-determine the peer's identity here.
// Therefore, this is expected to never fail.
remotePubKey, err := p2ptls.PubKeyFromCertChain(qconn.ConnectionState().TLS.PeerCertificates)
if err != nil {
connScope.Done()
return nil, err
}
remotePeerID, err := peer.IDFromPublicKey(remotePubKey)
if err != nil {
connScope.Done()
return nil, err
}
if err := connScope.SetPeer(remotePeerID); err != nil {
log.Debugw("resource manager blocked incoming connection for peer", "peer", remotePeerID, "addr", qconn.RemoteAddr(), "error", err)
connScope.Done()
return nil, err
}
remoteMultiaddr, err := toQuicMultiaddr(qconn.RemoteAddr())
if err != nil {
connScope.Done()
return nil, err
}
l.conn.IncreaseCount()
return &conn{
quicConn: qconn,
pconn: l.conn,
transport: l.transport,
scope: connScope,
localPeer: l.localPeer,
localMultiaddr: l.localMultiaddr,
privKey: l.privKey,
remoteMultiaddr: remoteMultiaddr,
remotePeerID: remotePeerID,
remotePubKey: remotePubKey,
}, nil
}
// Close closes the listener.
func (l *listener) Close() error {
defer l.conn.DecreaseCount()
return l.quicListener.Close()
}
// Addr returns the address of this listener.
func (l *listener) Addr() net.Addr {
return l.quicListener.Addr()
}
// Multiaddr returns the multiaddress of this listener.
func (l *listener) Multiaddr() ma.Multiaddr {
return l.localMultiaddr
}
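Accept above consults the connection gater twice: InterceptAccept with only the connection's multiaddrs, then InterceptSecured once the remote peer ID is known. A minimal sketch of a gater exercising both hooks, admitting inbound connections only from loopback addresses (a hypothetical example; the method set matches the generated mock further down in this diff):

package gaterexample

import (
	"github.com/libp2p/go-libp2p-core/connmgr"
	"github.com/libp2p/go-libp2p-core/control"
	"github.com/libp2p/go-libp2p-core/network"
	"github.com/libp2p/go-libp2p-core/peer"
	ma "github.com/multiformats/go-multiaddr"
	manet "github.com/multiformats/go-multiaddr/net"
)

// loopbackGater only admits inbound connections whose remote address is
// a loopback address; it never gates by peer ID or on the dial path.
type loopbackGater struct{}

var _ connmgr.ConnectionGater = loopbackGater{}

func (loopbackGater) InterceptPeerDial(peer.ID) bool               { return true }
func (loopbackGater) InterceptAddrDial(peer.ID, ma.Multiaddr) bool { return true }

// InterceptAccept runs before the remote peer ID is known.
func (loopbackGater) InterceptAccept(c network.ConnMultiaddrs) bool {
	ip, err := manet.ToIP(c.RemoteMultiaddr())
	return err == nil && ip.IsLoopback()
}

// InterceptSecured runs once the handshake has established the remote peer ID.
func (loopbackGater) InterceptSecured(network.Direction, peer.ID, network.ConnMultiaddrs) bool {
	return true
}

func (loopbackGater) InterceptUpgraded(network.Conn) (bool, control.DisconnectReason) {
	return true, 0
}

Passing such a gater as the third argument to NewTransport, as the connection-gating test above does with its mock, wires it into this Accept loop.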


@ -0,0 +1,115 @@
package libp2pquic
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
tpt "github.com/libp2p/go-libp2p-core/transport"
"github.com/lucas-clemente/quic-go"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)
func newTransport(t *testing.T, rcmgr network.ResourceManager) tpt.Transport {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey))
require.NoError(t, err)
tr, err := NewTransport(key, nil, nil, rcmgr)
require.NoError(t, err)
return tr
}
// The conn passed to quic-go should be a conn that quic-go can type-assert
// to a quic.OOBCapablePacketConn. That way, it can use all kinds of optimizations.
func TestConnUsedForListening(t *testing.T) {
origQuicListen := quicListen
t.Cleanup(func() { quicListen = origQuicListen })
var conn net.PacketConn
quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) {
conn = c
return nil, errors.New("listen error")
}
localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
require.NoError(t, err)
tr := newTransport(t, nil)
defer tr.(io.Closer).Close()
_, err = tr.Listen(localAddr)
require.EqualError(t, err, "listen error")
require.NotNil(t, conn)
defer conn.Close()
_, ok := conn.(quic.OOBCapablePacketConn)
require.True(t, ok)
}
func TestListenAddr(t *testing.T) {
tr := newTransport(t, nil)
defer tr.(io.Closer).Close()
t.Run("for IPv4", func(t *testing.T) {
localAddr := ma.StringCast("/ip4/127.0.0.1/udp/0/quic")
ln, err := tr.Listen(localAddr)
require.NoError(t, err)
defer ln.Close()
port := ln.Addr().(*net.UDPAddr).Port
require.NotZero(t, port)
require.Equal(t, ln.Multiaddr().String(), fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic", port))
})
t.Run("for IPv6", func(t *testing.T) {
localAddr := ma.StringCast("/ip6/::/udp/0/quic")
ln, err := tr.Listen(localAddr)
require.NoError(t, err)
defer ln.Close()
port := ln.Addr().(*net.UDPAddr).Port
require.NotZero(t, port)
require.Equal(t, ln.Multiaddr().String(), fmt.Sprintf("/ip6/::/udp/%d/quic", port))
})
}
func TestAccepting(t *testing.T) {
tr := newTransport(t, nil)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"))
require.NoError(t, err)
done := make(chan struct{})
go func() {
ln.Accept()
close(done)
}()
time.Sleep(100 * time.Millisecond)
select {
case <-done:
t.Fatal("Accept didn't block")
default:
}
require.NoError(t, ln.Close())
select {
case <-done:
case <-time.After(100 * time.Millisecond):
t.Fatal("Accept didn't return after the listener was closed")
}
}
func TestAcceptAfterClose(t *testing.T) {
tr := newTransport(t, nil)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"))
require.NoError(t, err)
require.NoError(t, ln.Close())
_, err = ln.Accept()
require.Error(t, err)
}


@ -0,0 +1,109 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/libp2p/go-libp2p-core/connmgr (interfaces: ConnectionGater)
// Package libp2pquic is a generated GoMock package.
package libp2pquic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
control "github.com/libp2p/go-libp2p-core/control"
network "github.com/libp2p/go-libp2p-core/network"
peer "github.com/libp2p/go-libp2p-core/peer"
multiaddr "github.com/multiformats/go-multiaddr"
)
// MockConnectionGater is a mock of ConnectionGater interface.
type MockConnectionGater struct {
ctrl *gomock.Controller
recorder *MockConnectionGaterMockRecorder
}
// MockConnectionGaterMockRecorder is the mock recorder for MockConnectionGater.
type MockConnectionGaterMockRecorder struct {
mock *MockConnectionGater
}
// NewMockConnectionGater creates a new mock instance.
func NewMockConnectionGater(ctrl *gomock.Controller) *MockConnectionGater {
mock := &MockConnectionGater{ctrl: ctrl}
mock.recorder = &MockConnectionGaterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockConnectionGater) EXPECT() *MockConnectionGaterMockRecorder {
return m.recorder
}
// InterceptAccept mocks base method.
func (m *MockConnectionGater) InterceptAccept(arg0 network.ConnMultiaddrs) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptAccept", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// InterceptAccept indicates an expected call of InterceptAccept.
func (mr *MockConnectionGaterMockRecorder) InterceptAccept(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAccept", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAccept), arg0)
}
// InterceptAddrDial mocks base method.
func (m *MockConnectionGater) InterceptAddrDial(arg0 peer.ID, arg1 multiaddr.Multiaddr) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptAddrDial", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// InterceptAddrDial indicates an expected call of InterceptAddrDial.
func (mr *MockConnectionGaterMockRecorder) InterceptAddrDial(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAddrDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAddrDial), arg0, arg1)
}
// InterceptPeerDial mocks base method.
func (m *MockConnectionGater) InterceptPeerDial(arg0 peer.ID) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptPeerDial", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// InterceptPeerDial indicates an expected call of InterceptPeerDial.
func (mr *MockConnectionGaterMockRecorder) InterceptPeerDial(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptPeerDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptPeerDial), arg0)
}
// InterceptSecured mocks base method.
func (m *MockConnectionGater) InterceptSecured(arg0 network.Direction, arg1 peer.ID, arg2 network.ConnMultiaddrs) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptSecured", arg0, arg1, arg2)
ret0, _ := ret[0].(bool)
return ret0
}
// InterceptSecured indicates an expected call of InterceptSecured.
func (mr *MockConnectionGaterMockRecorder) InterceptSecured(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptSecured", reflect.TypeOf((*MockConnectionGater)(nil).InterceptSecured), arg0, arg1, arg2)
}
// InterceptUpgraded mocks base method.
func (m *MockConnectionGater) InterceptUpgraded(arg0 network.Conn) (bool, control.DisconnectReason) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptUpgraded", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(control.DisconnectReason)
return ret0, ret1
}
// InterceptUpgraded indicates an expected call of InterceptUpgraded.
func (mr *MockConnectionGaterMockRecorder) InterceptUpgraded(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptUpgraded", reflect.TypeOf((*MockConnectionGater)(nil).InterceptUpgraded), arg0)
}


@ -0,0 +1,30 @@
package libp2pquic
import (
"net"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
var quicMA ma.Multiaddr
func init() {
var err error
quicMA, err = ma.NewMultiaddr("/quic")
if err != nil {
panic(err)
}
}
func toQuicMultiaddr(na net.Addr) (ma.Multiaddr, error) {
udpMA, err := manet.FromNetAddr(na)
if err != nil {
return nil, err
}
return udpMA.Encapsulate(quicMA), nil
}
func fromQuicMultiaddr(addr ma.Multiaddr) (net.Addr, error) {
return manet.ToNetAddr(addr.Decapsulate(quicMA))
}


@ -0,0 +1,27 @@
package libp2pquic
import (
"net"
"testing"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)
func TestConvertToQuicMultiaddr(t *testing.T) {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 42), Port: 1337}
maddr, err := toQuicMultiaddr(addr)
require.NoError(t, err)
require.Equal(t, maddr.String(), "/ip4/192.168.0.42/udp/1337/quic")
}
func TestConvertFromQuicMultiaddr(t *testing.T) {
maddr, err := ma.NewMultiaddr("/ip4/192.168.0.42/udp/1337/quic")
require.NoError(t, err)
addr, err := fromQuicMultiaddr(maddr)
require.NoError(t, err)
udpAddr, ok := addr.(*net.UDPAddr)
require.True(t, ok)
require.Equal(t, udpAddr.IP, net.IPv4(192, 168, 0, 42))
require.Equal(t, udpAddr.Port, 1337)
}

p2p/transport/quic/reuse.go (new file, 231 lines)

@ -0,0 +1,231 @@
package libp2pquic
import (
"net"
"sync"
"time"
"github.com/google/gopacket/routing"
"github.com/libp2p/go-netroute"
)
// Constants. Defined as variables to simplify testing.
var (
garbageCollectInterval = 30 * time.Second
maxUnusedDuration = 10 * time.Second
)
type reuseConn struct {
*net.UDPConn
mutex sync.Mutex
refCount int
unusedSince time.Time
}
func newReuseConn(conn *net.UDPConn) *reuseConn {
return &reuseConn{UDPConn: conn}
}
func (c *reuseConn) IncreaseCount() {
c.mutex.Lock()
c.refCount++
c.unusedSince = time.Time{}
c.mutex.Unlock()
}
func (c *reuseConn) DecreaseCount() {
c.mutex.Lock()
c.refCount--
if c.refCount == 0 {
c.unusedSince = time.Now()
}
c.mutex.Unlock()
}
func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool {
c.mutex.Lock()
defer c.mutex.Unlock()
return !c.unusedSince.IsZero() && c.unusedSince.Add(maxUnusedDuration).Before(now)
}
type reuse struct {
mutex sync.Mutex
closeChan chan struct{}
gcStopChan chan struct{}
routes routing.Router
unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn
// global contains connections that are listening on 0.0.0.0 / ::
global map[int]*reuseConn
}
func newReuse() *reuse {
r := &reuse{
unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn),
closeChan: make(chan struct{}),
gcStopChan: make(chan struct{}),
}
go r.gc()
return r
}
func (r *reuse) gc() {
defer func() {
r.mutex.Lock()
for _, conn := range r.global {
conn.Close()
}
for _, conns := range r.unicast {
for _, conn := range conns {
conn.Close()
}
}
r.mutex.Unlock()
close(r.gcStopChan)
}()
ticker := time.NewTicker(garbageCollectInterval)
defer ticker.Stop()
for {
select {
case <-r.closeChan:
return
case now := <-ticker.C:
r.mutex.Lock()
for key, conn := range r.global {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(r.global, key)
}
}
for ukey, conns := range r.unicast {
for key, conn := range conns {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(conns, key)
}
}
if len(conns) == 0 {
delete(r.unicast, ukey)
// If that was the last unicast listener, we no longer need a routing table.
// Otherwise, our routes may have changed, so refresh them.
if len(r.unicast) == 0 {
r.routes = nil
} else {
// Ignore the error, there's nothing we can do about it.
r.routes, _ = netroute.New()
}
}
}
r.mutex.Unlock()
}
}
}
func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
var ip *net.IP
// Only bother looking up the source address if we actually _have_ non-0.0.0.0 listeners.
// Otherwise, save some time.
r.mutex.Lock()
router := r.routes
r.mutex.Unlock()
if router != nil {
_, _, src, err := router.Route(raddr.IP)
if err == nil && !src.IsUnspecified() {
ip = &src
}
}
r.mutex.Lock()
defer r.mutex.Unlock()
conn, err := r.dialLocked(network, ip)
if err != nil {
return nil, err
}
conn.IncreaseCount()
return conn, nil
}
func (r *reuse) dialLocked(network string, source *net.IP) (*reuseConn, error) {
if source != nil {
// We already have at least one suitable connection...
if conns, ok := r.unicast[source.String()]; ok {
// ... we don't care which port we're dialing from. Just use the first.
for _, c := range conns {
return c, nil
}
}
}
// Use a connection listening on 0.0.0.0 (or ::).
// Again, we don't care about the port number.
for _, conn := range r.global {
return conn, nil
}
// We don't have a connection that we can use for dialing.
// Dial a new connection from a random port.
var addr *net.UDPAddr
switch network {
case "udp4":
addr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
case "udp6":
addr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
}
conn, err := net.ListenUDP(network, addr)
if err != nil {
return nil, err
}
rconn := newReuseConn(conn)
r.global[conn.LocalAddr().(*net.UDPAddr).Port] = rconn
return rconn, nil
}
func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
conn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
}
localAddr := conn.LocalAddr().(*net.UDPAddr)
rconn := newReuseConn(conn)
rconn.IncreaseCount()
r.mutex.Lock()
defer r.mutex.Unlock()
// Deal with listen on a global address
if localAddr.IP.IsUnspecified() {
// The kernel already checked that laddr was not in use when we
// created the socket with ListenUDP, so we don't need to check again here.
r.global[localAddr.Port] = rconn
return rconn, err
}
// Deal with listen on a unicast address
if _, ok := r.unicast[localAddr.IP.String()]; !ok {
r.unicast[localAddr.IP.String()] = make(map[int]*reuseConn)
// Assume the system's routes may have changed if we're adding a new listener.
// Ignore the error, there's nothing we can do.
r.routes, _ = netroute.New()
}
// The kernel already checked that laddr was not in use when we
// created the socket with ListenUDP, so we don't need to check again here.
r.unicast[localAddr.IP.String()][localAddr.Port] = rconn
return rconn, err
}
func (r *reuse) Close() error {
close(r.closeChan)
<-r.gcStopChan
return nil
}
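reuse hands out reference-counted sockets: both Listen and Dial return a *reuseConn whose count has already been incremented, and the gc loop above only closes a socket once its count has sat at zero for maxUnusedDuration. A minimal sketch of the contract a caller has to uphold (a hypothetical helper; within this PR, the transport and listener play this role):

// dialOnce borrows a shared UDP socket for a single outgoing dial and
// returns it to the pool when done. Forgetting the DecreaseCount would
// keep the socket alive forever; calling it promptly lets the gc loop
// reclaim the socket after maxUnusedDuration of idleness.
func dialOnce(r *reuse, raddr *net.UDPAddr) error {
	rconn, err := r.Dial("udp4", raddr) // refcount is incremented for us
	if err != nil {
		return err
	}
	defer rconn.DecreaseCount() // hand the socket back to the pool
	// ... hand rconn to quic-go as its net.PacketConn here ...
	return nil
}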


@ -0,0 +1,175 @@
package libp2pquic
import (
"bytes"
"net"
"runtime/pprof"
"strings"
"testing"
"time"
"github.com/libp2p/go-netroute"
"github.com/stretchr/testify/require"
)
func (c *reuseConn) GetCount() int {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.refCount
}
func closeAllConns(reuse *reuse) {
reuse.mutex.Lock()
for _, conn := range reuse.global {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
for _, conns := range reuse.unicast {
for _, conn := range conns {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
}
reuse.mutex.Unlock()
}
func platformHasRoutingTables() bool {
_, err := netroute.New()
return err == nil
}
func isGarbageCollectorRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic.(*reuse).gc")
}
func cleanup(t *testing.T, reuse *reuse) {
t.Cleanup(func() {
closeAllConns(reuse)
reuse.Close()
require.False(t, isGarbageCollectorRunning(), "reuse gc still running")
})
}
func TestReuseListenOnAllIPv4(t *testing.T) {
reuse := newReuse()
require.Eventually(t, isGarbageCollectorRunning, 100*time.Millisecond, time.Millisecond, "expected garbage collector to be running")
cleanup(t, reuse)
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
require.NoError(t, err)
conn, err := reuse.Listen("udp4", addr)
require.NoError(t, err)
require.Equal(t, conn.GetCount(), 1)
}
func TestReuseListenOnAllIPv6(t *testing.T) {
reuse := newReuse()
require.Eventually(t, isGarbageCollectorRunning, 100*time.Millisecond, time.Millisecond, "expected garbage collector to be running")
cleanup(t, reuse)
addr, err := net.ResolveUDPAddr("udp6", "[::]:1234")
require.NoError(t, err)
conn, err := reuse.Listen("udp6", addr)
require.NoError(t, err)
require.Equal(t, conn.GetCount(), 1)
}
func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {
reuse := newReuse()
cleanup(t, reuse)
addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
conn, err := reuse.Dial("udp4", addr)
require.NoError(t, err)
require.Equal(t, conn.GetCount(), 1)
laddr := conn.LocalAddr().(*net.UDPAddr)
require.Equal(t, laddr.IP.String(), "0.0.0.0")
require.NotEqual(t, laddr.Port, 0)
}
func TestReuseConnectionWhenDialing(t *testing.T) {
reuse := newReuse()
cleanup(t, reuse)
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
require.NoError(t, err)
lconn, err := reuse.Listen("udp4", addr)
require.NoError(t, err)
require.Equal(t, lconn.GetCount(), 1)
// dial
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
conn, err := reuse.Dial("udp4", raddr)
require.NoError(t, err)
require.Equal(t, conn.GetCount(), 2)
}
func TestReuseListenOnSpecificInterface(t *testing.T) {
if !platformHasRoutingTables() {
t.Skip("this test only works on platforms that support routing tables")
}
reuse := newReuse()
cleanup(t, reuse)
router, err := netroute.New()
require.NoError(t, err)
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
_, _, ip, err := router.Route(raddr.IP)
require.NoError(t, err)
// listen
addr, err := net.ResolveUDPAddr("udp4", ip.String()+":0")
require.NoError(t, err)
lconn, err := reuse.Listen("udp4", addr)
require.NoError(t, err)
require.Equal(t, lconn.GetCount(), 1)
// dial
conn, err := reuse.Dial("udp4", raddr)
require.NoError(t, err)
require.Equal(t, conn.GetCount(), 1)
}
func TestReuseGarbageCollect(t *testing.T) {
maxUnusedDurationOrig := maxUnusedDuration
garbageCollectIntervalOrig := garbageCollectInterval
t.Cleanup(func() {
maxUnusedDuration = maxUnusedDurationOrig
garbageCollectInterval = garbageCollectIntervalOrig
})
garbageCollectInterval = 50 * time.Millisecond
maxUnusedDuration = 100 * time.Millisecond
reuse := newReuse()
cleanup(t, reuse)
numGlobals := func() int {
reuse.mutex.Lock()
defer reuse.mutex.Unlock()
return len(reuse.global)
}
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
require.NoError(t, err)
lconn, err := reuse.Listen("udp4", addr)
require.NoError(t, err)
require.Equal(t, lconn.GetCount(), 1)
closeTime := time.Now()
lconn.DecreaseCount()
for {
num := numGlobals()
if closeTime.Add(maxUnusedDuration).Before(time.Now()) {
break
}
require.Equal(t, num, 1)
time.Sleep(2 * time.Millisecond)
}
require.Eventually(t, func() bool { return numGlobals() == 0 }, 100*time.Millisecond, 5*time.Millisecond)
}


@ -0,0 +1,55 @@
package libp2pquic
import (
"errors"
"github.com/libp2p/go-libp2p-core/network"
"github.com/lucas-clemente/quic-go"
)
const (
reset quic.StreamErrorCode = 0
)
type stream struct {
quic.Stream
}
var _ network.MuxedStream = &stream{}
func (s *stream) Read(b []byte) (n int, err error) {
n, err = s.Stream.Read(b)
if err != nil && errors.Is(err, &quic.StreamError{}) {
err = network.ErrReset
}
return n, err
}
func (s *stream) Write(b []byte) (n int, err error) {
n, err = s.Stream.Write(b)
if err != nil && errors.Is(err, &quic.StreamError{}) {
err = network.ErrReset
}
return n, err
}
func (s *stream) Reset() error {
s.Stream.CancelRead(reset)
s.Stream.CancelWrite(reset)
return nil
}
func (s *stream) Close() error {
s.Stream.CancelRead(reset)
return s.Stream.Close()
}
func (s *stream) CloseRead() error {
s.Stream.CancelRead(reset)
return nil
}
func (s *stream) CloseWrite() error {
return s.Stream.Close()
}
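Note the asymmetry above: Close cancels the read side in addition to closing the write side, while CloseWrite only signals end-of-stream. A caller that still expects a reply therefore has to use CloseWrite, which is what the cmd/client example earlier in this diff does. A minimal sketch (a hypothetical helper; network and ioutil as imported elsewhere in this PR):

// request sends msg, signals end-of-stream, and reads the peer's reply.
// Using str.Close() instead of str.CloseWrite() here would cancel the
// read side and drop the response.
func request(str network.MuxedStream, msg []byte) ([]byte, error) {
	if _, err := str.Write(msg); err != nil {
		return nil, err
	}
	if err := str.CloseWrite(); err != nil {
		return nil, err
	}
	return ioutil.ReadAll(str)
}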


@ -0,0 +1,95 @@
package libp2pquic
import (
"bufio"
"fmt"
"io"
"os"
"time"
"github.com/klauspost/compress/zstd"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/qlog"
)
var tracer logging.Tracer
func init() {
tracers := []logging.Tracer{&metricsTracer{}}
if qlogDir := os.Getenv("QLOGDIR"); len(qlogDir) > 0 {
if qlogger := initQlogger(qlogDir); qlogger != nil {
tracers = append(tracers, qlogger)
}
}
tracer = logging.NewMultiplexedTracer(tracers...)
}
func initQlogger(qlogDir string) logging.Tracer {
return qlog.NewTracer(func(role logging.Perspective, connID []byte) io.WriteCloser {
// create the QLOGDIR, if it doesn't exist
if err := os.MkdirAll(qlogDir, 0777); err != nil {
log.Errorf("creating the QLOGDIR failed: %s", err)
return nil
}
return newQlogger(qlogDir, role, connID)
})
}
// The qlogger logs qlog events to a temporary file: .<name>.qlog.swp.
// When it is closed, it compresses the temporary file and saves it as <name>.qlog.zst.
// It is not possible to compress on the fly, as compression algorithms keep a lot of internal state,
// which can easily exhaust the host system's memory when running a few hundred QUIC connections in parallel.
type qlogger struct {
f *os.File // QLOGDIR/.log_xxx.qlog.swp
filename string // QLOGDIR/log_xxx.qlog.zst
*bufio.Writer // buffers writes to f
}
func newQlogger(qlogDir string, role logging.Perspective, connID []byte) io.WriteCloser {
t := time.Now().UTC().Format("2006-01-02T15-04-05.999999999UTC")
r := "server"
if role == logging.PerspectiveClient {
r = "client"
}
finalFilename := fmt.Sprintf("%s%clog_%s_%s_%x.qlog.zst", qlogDir, os.PathSeparator, t, r, connID)
filename := fmt.Sprintf("%s%c.log_%s_%s_%x.qlog.swp", qlogDir, os.PathSeparator, t, r, connID)
f, err := os.Create(filename)
if err != nil {
log.Errorf("unable to create qlog file %s: %s", filename, err)
return nil
}
return &qlogger{
f: f,
filename: finalFilename,
Writer: bufio.NewWriter(f),
}
}
func (l *qlogger) Close() error {
defer os.Remove(l.f.Name())
defer l.f.Close()
if err := l.Writer.Flush(); err != nil {
return err
}
if _, err := l.f.Seek(0, io.SeekStart); err != nil { // set the read position to the beginning of the file
return err
}
f, err := os.Create(l.filename)
if err != nil {
return err
}
defer f.Close()
buf := bufio.NewWriter(f)
c, err := zstd.NewWriter(buf, zstd.WithEncoderLevel(zstd.SpeedFastest), zstd.WithWindowSize(32*1024))
if err != nil {
return err
}
if _, err := io.Copy(c, l.f); err != nil {
return err
}
if err := c.Close(); err != nil {
return err
}
return buf.Flush()
}
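qlog output is opt-in via the QLOGDIR environment variable, read once in this package's init, so it has to be set before the process starts (e.g. QLOGDIR=/tmp/qlog). The files land compressed as <name>.qlog.zst; a minimal sketch for expanding one back to plain qlog with the same compress/zstd library used above (a hypothetical helper; file paths are placeholders, os/io/zstd as imported above):

// decompressQlog expands a .qlog.zst file written by the qlogger above
// into a plain .qlog file that qlog viewers can load directly.
func decompressQlog(src, dst string) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	dec, err := zstd.NewReader(in)
	if err != nil {
		return err
	}
	defer dec.Close()

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer out.Close()

	_, err = io.Copy(out, dec)
	return err
}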


@ -0,0 +1,387 @@
package libp2pquic
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/logging"
)
var (
bytesTransferred *prometheus.CounterVec
newConns *prometheus.CounterVec
closedConns *prometheus.CounterVec
sentPackets *prometheus.CounterVec
rcvdPackets *prometheus.CounterVec
bufferedPackets *prometheus.CounterVec
droppedPackets *prometheus.CounterVec
lostPackets *prometheus.CounterVec
connErrors *prometheus.CounterVec
)
type aggregatingCollector struct {
mutex sync.Mutex
conns map[string] /* conn ID */ *metricsConnTracer
rtts prometheus.Histogram
connDurations prometheus.Histogram
}
func newAggregatingCollector() *aggregatingCollector {
return &aggregatingCollector{
conns: make(map[string]*metricsConnTracer),
rtts: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "quic_smoothed_rtt",
Help: "Smoothed RTT",
Buckets: prometheus.ExponentialBuckets(0.001, 1.25, 40), // 1ms to ~6000ms
}),
connDurations: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "quic_connection_duration",
Help: "Connection Duration",
Buckets: prometheus.ExponentialBuckets(1, 1.5, 40), // 1s to ~12 weeks
}),
}
}
var _ prometheus.Collector = &aggregatingCollector{}
func (c *aggregatingCollector) Describe(descs chan<- *prometheus.Desc) {
descs <- c.rtts.Desc()
descs <- c.connDurations.Desc()
}
func (c *aggregatingCollector) Collect(metrics chan<- prometheus.Metric) {
now := time.Now()
c.mutex.Lock()
for _, conn := range c.conns {
if rtt, valid := conn.getSmoothedRTT(); valid {
c.rtts.Observe(rtt.Seconds())
}
c.connDurations.Observe(now.Sub(conn.startTime).Seconds())
}
c.mutex.Unlock()
metrics <- c.rtts
metrics <- c.connDurations
}
func (c *aggregatingCollector) AddConn(id string, t *metricsConnTracer) {
c.mutex.Lock()
c.conns[id] = t
c.mutex.Unlock()
}
func (c *aggregatingCollector) RemoveConn(id string) {
c.mutex.Lock()
delete(c.conns, id)
c.mutex.Unlock()
}
var collector *aggregatingCollector
func init() {
const (
direction = "direction"
encLevel = "encryption_level"
)
closedConns = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "quic_connections_closed_total",
Help: "closed QUIC connection",
},
[]string{direction},
)
prometheus.MustRegister(closedConns)
newConns = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "quic_connections_new_total",
Help: "new QUIC connection",
},
[]string{direction, "handshake_successful"},
)
prometheus.MustRegister(newConns)
bytesTransferred = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "quic_transferred_bytes",
Help: "QUIC bytes transferred",
},
[]string{direction}, // TODO: this is confusing. Other times, we use direction for the perspective
)
prometheus.MustRegister(bytesTransferred)
sentPackets = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "quic_packets_sent_total",
Help: "QUIC packets sent",
},
[]string{encLevel},
)
prometheus.MustRegister(sentPackets)
rcvdPackets = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "quic_packets_rcvd_total",
Help: "QUIC packets received",
},
[]string{encLevel},
)
prometheus.MustRegister(rcvdPackets)
bufferedPackets = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "quic_packets_buffered_total",
Help: "Buffered packets",
},
[]string{"packet_type"},
)
prometheus.MustRegister(bufferedPackets)
droppedPackets = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "quic_packets_dropped_total",
Help: "Dropped packets",
},
[]string{"packet_type", "reason"},
)
prometheus.MustRegister(droppedPackets)
connErrors = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "quic_connection_errors_total",
Help: "QUIC connection errors",
},
[]string{"side", "error_code"},
)
prometheus.MustRegister(connErrors)
lostPackets = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "quic_packets_lost_total",
Help: "QUIC lost received",
},
[]string{encLevel, "reason"},
)
prometheus.MustRegister(lostPackets)
collector = newAggregatingCollector()
prometheus.MustRegister(collector)
}
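
Everything above registers with Prometheus' default registry at package init time, so an application embedding this transport exposes the QUIC metrics simply by serving that registry. A minimal sketch, assuming the standard promhttp package (not part of this diff) and an arbitrary listen port:

package main

import (
	"log"
	"net/http"

	"github.com/prometheus/client_golang/prometheus/promhttp"
)

func main() {
	// promhttp.Handler serves everything registered with the default registry,
	// including the QUIC counters and the aggregatingCollector registered above.
	http.Handle("/metrics", promhttp.Handler())
	log.Fatal(http.ListenAndServe(":2112", nil))
}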
type metricsTracer struct{}
var _ logging.Tracer = &metricsTracer{}
func (m *metricsTracer) TracerForConnection(_ context.Context, p logging.Perspective, connID logging.ConnectionID) logging.ConnectionTracer {
return &metricsConnTracer{perspective: p, connID: connID}
}
func (m *metricsTracer) SentPacket(_ net.Addr, _ *logging.Header, size logging.ByteCount, _ []logging.Frame) {
bytesTransferred.WithLabelValues("sent").Add(float64(size))
}
func (m *metricsTracer) DroppedPacket(addr net.Addr, packetType logging.PacketType, count logging.ByteCount, reason logging.PacketDropReason) {
}
type metricsConnTracer struct {
perspective logging.Perspective
startTime time.Time
connID logging.ConnectionID
handshakeComplete bool
mutex sync.Mutex
numRTTMeasurements int
rtt time.Duration
}
var _ logging.ConnectionTracer = &metricsConnTracer{}
func (m *metricsConnTracer) getDirection() string {
if m.perspective == logging.PerspectiveClient {
return "outgoing"
}
return "incoming"
}
func (m *metricsConnTracer) getEncLevel(packetType logging.PacketType) string {
switch packetType {
case logging.PacketType0RTT:
return "0-RTT"
case logging.PacketTypeInitial:
return "Initial"
case logging.PacketTypeHandshake:
return "Handshake"
case logging.PacketTypeRetry:
return "Retry"
case logging.PacketType1RTT:
return "1-RTT"
default:
return "unknown"
}
}
func (m *metricsConnTracer) StartedConnection(net.Addr, net.Addr, logging.ConnectionID, logging.ConnectionID) {
m.startTime = time.Now()
collector.AddConn(m.connID.String(), m)
}
func (m *metricsConnTracer) NegotiatedVersion(chosen quic.VersionNumber, clientVersions []quic.VersionNumber, serverVersions []quic.VersionNumber) {
}
func (m *metricsConnTracer) ClosedConnection(e error) {
var (
applicationErr *quic.ApplicationError
transportErr *quic.TransportError
statelessResetErr *quic.StatelessResetError
vnErr *quic.VersionNegotiationError
idleTimeoutErr *quic.IdleTimeoutError
handshakeTimeoutErr *quic.HandshakeTimeoutError
remote bool
desc string
)
switch {
case errors.As(e, &applicationErr):
return
case errors.As(e, &transportErr):
remote = transportErr.Remote
desc = transportErr.ErrorCode.String()
case errors.As(e, &statelessResetErr):
remote = true
desc = "stateless_reset"
case errors.As(e, &vnErr):
desc = "version_negotiation"
case errors.As(e, &idleTimeoutErr):
desc = "idle_timeout"
case errors.As(e, &handshakeTimeoutErr):
desc = "handshake_timeout"
default:
desc = fmt.Sprintf("unknown error: %v", e)
}
side := "local"
if remote {
side = "remote"
}
connErrors.WithLabelValues(side, desc).Inc()
}
func (m *metricsConnTracer) SentTransportParameters(parameters *logging.TransportParameters) {}
func (m *metricsConnTracer) ReceivedTransportParameters(parameters *logging.TransportParameters) {}
func (m *metricsConnTracer) RestoredTransportParameters(parameters *logging.TransportParameters) {}
func (m *metricsConnTracer) SentPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, _ *logging.AckFrame, _ []logging.Frame) {
bytesTransferred.WithLabelValues("sent").Add(float64(size))
sentPackets.WithLabelValues(m.getEncLevel(logging.PacketTypeFromHeader(&hdr.Header))).Inc()
}
func (m *metricsConnTracer) ReceivedVersionNegotiationPacket(hdr *logging.Header, v []logging.VersionNumber) {
bytesTransferred.WithLabelValues("rcvd").Add(float64(hdr.ParsedLen() + logging.ByteCount(4*len(v))))
rcvdPackets.WithLabelValues("Version Negotiation").Inc()
}
func (m *metricsConnTracer) ReceivedRetry(*logging.Header) {
rcvdPackets.WithLabelValues("Retry").Inc()
}
func (m *metricsConnTracer) ReceivedPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, _ []logging.Frame) {
bytesTransferred.WithLabelValues("rcvd").Add(float64(size))
rcvdPackets.WithLabelValues(m.getEncLevel(logging.PacketTypeFromHeader(&hdr.Header))).Inc()
}
func (m *metricsConnTracer) BufferedPacket(packetType logging.PacketType) {
bufferedPackets.WithLabelValues(m.getEncLevel(packetType)).Inc()
}
func (m *metricsConnTracer) DroppedPacket(packetType logging.PacketType, size logging.ByteCount, r logging.PacketDropReason) {
bytesTransferred.WithLabelValues("rcvd").Add(float64(size))
var reason string
switch r {
case logging.PacketDropKeyUnavailable:
reason = "key_unavailable"
case logging.PacketDropUnknownConnectionID:
reason = "unknown_connection_id"
case logging.PacketDropHeaderParseError:
reason = "header_parse_error"
case logging.PacketDropPayloadDecryptError:
reason = "payload_decrypt_error"
case logging.PacketDropProtocolViolation:
reason = "protocol_violation"
case logging.PacketDropDOSPrevention:
reason = "dos_prevention"
case logging.PacketDropUnsupportedVersion:
reason = "unsupported_version"
case logging.PacketDropUnexpectedPacket:
reason = "unexpected_packet"
case logging.PacketDropUnexpectedSourceConnectionID:
reason = "unexpected_source_connection_id"
case logging.PacketDropUnexpectedVersion:
reason = "unexpected_version"
case logging.PacketDropDuplicate:
reason = "duplicate"
default:
reason = "unknown"
}
droppedPackets.WithLabelValues(m.getEncLevel(packetType), reason).Inc()
}
func (m *metricsConnTracer) UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) {
m.mutex.Lock()
m.rtt = rttStats.SmoothedRTT()
m.numRTTMeasurements++
m.mutex.Unlock()
}
func (m *metricsConnTracer) AcknowledgedPacket(logging.EncryptionLevel, logging.PacketNumber) {}
func (m *metricsConnTracer) LostPacket(level logging.EncryptionLevel, _ logging.PacketNumber, r logging.PacketLossReason) {
var reason string
switch r {
case logging.PacketLossReorderingThreshold:
reason = "reordering_threshold"
case logging.PacketLossTimeThreshold:
reason = "time_threshold"
default:
reason = "unknown"
}
lostPackets.WithLabelValues(level.String(), reason).Inc()
}
func (m *metricsConnTracer) UpdatedCongestionState(state logging.CongestionState) {}
func (m *metricsConnTracer) UpdatedPTOCount(value uint32) {}
func (m *metricsConnTracer) UpdatedKeyFromTLS(level logging.EncryptionLevel, perspective logging.Perspective) {
}
func (m *metricsConnTracer) UpdatedKey(generation logging.KeyPhase, remote bool) {}
func (m *metricsConnTracer) DroppedEncryptionLevel(level logging.EncryptionLevel) {
if level == logging.EncryptionHandshake {
m.handleHandshakeComplete()
}
}
func (m *metricsConnTracer) DroppedKey(generation logging.KeyPhase) {}
func (m *metricsConnTracer) SetLossTimer(timerType logging.TimerType, level logging.EncryptionLevel, time time.Time) {
}
func (m *metricsConnTracer) LossTimerExpired(timerType logging.TimerType, level logging.EncryptionLevel) {
}
func (m *metricsConnTracer) LossTimerCanceled() {}
func (m *metricsConnTracer) Close() {
if m.handshakeComplete {
closedConns.WithLabelValues(m.getDirection()).Inc()
} else {
newConns.WithLabelValues(m.getDirection(), "false").Inc()
}
collector.RemoveConn(m.connID.String())
}
func (m *metricsConnTracer) Debug(name, msg string) {}
func (m *metricsConnTracer) handleHandshakeComplete() {
m.handshakeComplete = true
newConns.WithLabelValues(m.getDirection(), "true").Inc()
}
func (m *metricsConnTracer) getSmoothedRTT() (rtt time.Duration, valid bool) {
m.mutex.Lock()
rtt = m.rtt
valid = m.numRTTMeasurements > 10
m.mutex.Unlock()
return
}
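
metricsTracer satisfies quic-go's logging.Tracer interface, and quic-go calls TracerForConnection once per connection, which is how each metricsConnTracer ends up in the aggregatingCollector. A minimal sketch of the wiring, assuming the quic-go v0.27 API used throughout this PR; the real transport below assigns its package-level tracer in NewTransport instead:

package libp2pquic

import (
	"github.com/lucas-clemente/quic-go"
	"github.com/lucas-clemente/quic-go/logging"
)

// tracedConfig is a hypothetical helper, not part of this PR: it shows how the
// metrics tracer plugs into a quic.Config via the logging.Tracer interface.
func tracedConfig() *quic.Config {
	var t logging.Tracer = &metricsTracer{}
	return &quic.Config{Tracer: t}
}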

View File

@ -0,0 +1,74 @@
package libp2pquic
import (
"bytes"
"io/ioutil"
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/klauspost/compress/zstd"
"github.com/lucas-clemente/quic-go/logging"
)
func createLogDir(t *testing.T) string {
dir, err := ioutil.TempDir("", "libp2p-quic-transport-test")
require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(dir) })
return dir
}
func getFile(t *testing.T, dir string) os.FileInfo {
files, err := ioutil.ReadDir(dir)
require.NoError(t, err)
require.Len(t, files, 1)
return files[0]
}
func TestSaveQlog(t *testing.T) {
qlogDir := createLogDir(t)
logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte{0xde, 0xad, 0xbe, 0xef})
file := getFile(t, qlogDir)
require.Equal(t, string(file.Name()[0]), ".")
require.Truef(t, strings.HasSuffix(file.Name(), ".qlog.swp"), "expected %s to have the .qlog.swp file ending", file.Name())
// close the logger. This should move the file.
require.NoError(t, logger.Close())
file = getFile(t, qlogDir)
require.NotEqual(t, string(file.Name()[0]), ".")
require.Truef(t, strings.HasSuffix(file.Name(), ".qlog.zst"), "expected %s to have the .qlog.zst file ending", file.Name())
require.Contains(t, file.Name(), "server")
require.Contains(t, file.Name(), "deadbeef")
}
func TestQlogBuffering(t *testing.T) {
qlogDir := createLogDir(t)
logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte("connid"))
initialSize := getFile(t, qlogDir).Size()
// Do a small write.
// Since the writer is buffered, this should not be written to disk yet.
logger.Write([]byte("foobar"))
require.Equal(t, getFile(t, qlogDir).Size(), initialSize)
// Close the logger. This should flush the buffer to disk.
require.NoError(t, logger.Close())
finalSize := getFile(t, qlogDir).Size()
t.Logf("initial log file size: %d, final log file size: %d\n", initialSize, finalSize)
require.Greater(t, finalSize, initialSize)
}
func TestQlogCompression(t *testing.T) {
qlogDir := createLogDir(t)
logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte("connid"))
logger.Write([]byte("foobar"))
require.NoError(t, logger.Close())
compressed, err := ioutil.ReadFile(qlogDir + "/" + getFile(t, qlogDir).Name())
require.NoError(t, err)
require.NotEqual(t, compressed, []byte("foobar")) // compare like types so the assertion is meaningful
c, err := zstd.NewReader(bytes.NewReader(compressed))
require.NoError(t, err)
data, err := ioutil.ReadAll(c)
require.NoError(t, err)
require.Equal(t, data, []byte("foobar"))
}

View File

@ -0,0 +1,409 @@
package libp2pquic
import (
"context"
"errors"
"fmt"
"io"
"math/rand"
"net"
"sync"
"time"
"golang.org/x/crypto/hkdf"
"github.com/libp2p/go-libp2p-core/connmgr"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/pnet"
tpt "github.com/libp2p/go-libp2p-core/transport"
p2ptls "github.com/libp2p/go-libp2p-tls"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
logging "github.com/ipfs/go-log/v2"
"github.com/lucas-clemente/quic-go"
"github.com/minio/sha256-simd"
)
var log = logging.Logger("quic-transport")
var ErrHolePunching = errors.New("hole punching attempted; no active dial")
var quicDialContext = quic.DialContext // so we can mock it in tests
var HolePunchTimeout = 5 * time.Second
var quicConfig = &quic.Config{
MaxIncomingStreams: 256,
MaxIncomingUniStreams: -1, // disable unidirectional streams
MaxStreamReceiveWindow: 10 * (1 << 20), // 10 MB
MaxConnectionReceiveWindow: 15 * (1 << 20), // 15 MB
AcceptToken: func(clientAddr net.Addr, _ *quic.Token) bool {
// TODO(#6): require source address validation when under load
return true
},
KeepAlive: true,
Versions: []quic.VersionNumber{quic.VersionDraft29, quic.Version1},
}
const statelessResetKeyInfo = "libp2p quic stateless reset key"
const errorCodeConnectionGating = 0x47415445 // GATE in ASCII
type connManager struct {
reuseUDP4 *reuse
reuseUDP6 *reuse
}
func newConnManager() (*connManager, error) {
reuseUDP4 := newReuse()
reuseUDP6 := newReuse()
return &connManager{
reuseUDP4: reuseUDP4,
reuseUDP6: reuseUDP6,
}, nil
}
func (c *connManager) getReuse(network string) (*reuse, error) {
switch network {
case "udp4":
return c.reuseUDP4, nil
case "udp6":
return c.reuseUDP6, nil
default:
return nil, errors.New("invalid network: must be either udp4 or udp6")
}
}
func (c *connManager) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return reuse.Listen(network, laddr)
}
func (c *connManager) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return reuse.Dial(network, raddr)
}
func (c *connManager) Close() error {
if err := c.reuseUDP6.Close(); err != nil {
return err
}
return c.reuseUDP4.Close()
}
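
connManager hands out ref-counted reuseConns from the per-address-family reuse pools (defined elsewhere in this PR); every successful Dial or Listen must eventually be balanced by DecreaseCount, which is exactly what the transport's Dial, holePunch and Listen below do on their error paths. A minimal, hypothetical usage sketch under those assumptions:

package libp2pquic

import "net"

// dialOnce is a hypothetical helper, not part of this PR: it borrows a packet
// conn from the reuse pool, writes one datagram, and releases its reference.
func dialOnce(raddr *net.UDPAddr, payload []byte) error {
	cm, err := newConnManager()
	if err != nil {
		return err
	}
	defer cm.Close()
	pconn, err := cm.Dial("udp4", raddr)
	if err != nil {
		return err
	}
	defer pconn.DecreaseCount() // hand the ref-counted socket back to the pool
	_, err = pconn.UDPConn.WriteToUDP(payload, raddr)
	return err
}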
// transport implements the tpt.Transport interface for QUIC connections.
type transport struct {
privKey ic.PrivKey
localPeer peer.ID
identity *p2ptls.Identity
connManager *connManager
serverConfig *quic.Config
clientConfig *quic.Config
gater connmgr.ConnectionGater
rcmgr network.ResourceManager
holePunchingMx sync.Mutex
holePunching map[holePunchKey]*activeHolePunch
connMx sync.Mutex
conns map[quic.Connection]*conn
}
var _ tpt.Transport = &transport{}
type holePunchKey struct {
addr string
peer peer.ID
}
type activeHolePunch struct {
connCh chan tpt.CapableConn
fulfilled bool
}
// NewTransport creates a new QUIC transport
func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Transport, error) {
if len(psk) > 0 {
log.Error("QUIC doesn't support private networks yet.")
return nil, errors.New("QUIC doesn't support private networks yet")
}
localPeer, err := peer.IDFromPrivateKey(key)
if err != nil {
return nil, err
}
identity, err := p2ptls.NewIdentity(key)
if err != nil {
return nil, err
}
connManager, err := newConnManager()
if err != nil {
return nil, err
}
if rcmgr == nil {
rcmgr = network.NullResourceManager
}
config := quicConfig.Clone()
keyBytes, err := key.Raw()
if err != nil {
return nil, err
}
keyReader := hkdf.New(sha256.New, keyBytes, nil, []byte(statelessResetKeyInfo))
config.StatelessResetKey = make([]byte, 32)
if _, err := io.ReadFull(keyReader, config.StatelessResetKey); err != nil {
return nil, err
}
config.Tracer = tracer
tr := &transport{
privKey: key,
localPeer: localPeer,
identity: identity,
connManager: connManager,
gater: gater,
rcmgr: rcmgr,
conns: make(map[quic.Connection]*conn),
holePunching: make(map[holePunchKey]*activeHolePunch),
}
config.AllowConnectionWindowIncrease = tr.allowWindowIncrease
tr.serverConfig = config
tr.clientConfig = config.Clone()
return tr, nil
}
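
NewTransport only requires the host's private key; the PSK must be empty (private networks are rejected), and both the gater and the resource manager may be nil, in which case a NullResourceManager is substituted. A minimal sketch mirroring the test helper at the end of this section, with the RSA key and the loopback address being arbitrary choices:

package libp2pquic

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"

	ic "github.com/libp2p/go-libp2p-core/crypto"
	tpt "github.com/libp2p/go-libp2p-core/transport"
	ma "github.com/multiformats/go-multiaddr"
)

// newExampleTransport is a hypothetical helper, not part of this PR: it builds
// a QUIC transport from a fresh RSA key and listens on an ephemeral UDP port.
func newExampleTransport() (tpt.Transport, tpt.Listener, error) {
	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}
	key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey))
	if err != nil {
		return nil, nil, err
	}
	tr, err := NewTransport(key, nil, nil, nil) // no PSK, no gater, default resource manager
	if err != nil {
		return nil, nil, err
	}
	addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
	if err != nil {
		return nil, nil, err
	}
	ln, err := tr.Listen(addr)
	if err != nil {
		return nil, nil, err
	}
	return tr, ln, nil
}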
// Dial dials a new QUIC connection
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
netw, host, err := manet.DialArgs(raddr)
if err != nil {
return nil, err
}
addr, err := net.ResolveUDPAddr(netw, host)
if err != nil {
return nil, err
}
remoteMultiaddr, err := toQuicMultiaddr(addr)
if err != nil {
return nil, err
}
tlsConf, keyCh := t.identity.ConfigForPeer(p)
if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient {
return t.holePunch(ctx, netw, addr, p)
}
scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false)
if err != nil {
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err)
return nil, err
}
if err := scope.SetPeer(p); err != nil {
log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err)
scope.Done()
return nil, err
}
pconn, err := t.connManager.Dial(netw, addr)
if err != nil {
return nil, err
}
qconn, err := quicDialContext(ctx, pconn, addr, host, tlsConf, t.clientConfig)
if err != nil {
scope.Done()
pconn.DecreaseCount()
return nil, err
}
// Should be ready by this point, don't block.
var remotePubKey ic.PubKey
select {
case remotePubKey = <-keyCh:
default:
}
if remotePubKey == nil {
pconn.DecreaseCount()
scope.Done()
return nil, errors.New("p2p/transport/quic BUG: expected remote pub key to be set")
}
localMultiaddr, err := toQuicMultiaddr(pconn.LocalAddr())
if err != nil {
qconn.CloseWithError(0, "")
return nil, err
}
c := &conn{
quicConn: qconn,
pconn: pconn,
transport: t,
scope: scope,
privKey: t.privKey,
localPeer: t.localPeer,
localMultiaddr: localMultiaddr,
remotePubKey: remotePubKey,
remotePeerID: p,
remoteMultiaddr: remoteMultiaddr,
}
if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, c) {
qconn.CloseWithError(errorCodeConnectionGating, "connection gated")
return nil, fmt.Errorf("secured connection gated")
}
t.addConn(qconn, c)
return c, nil
}
func (t *transport) addConn(conn quic.Connection, c *conn) {
t.connMx.Lock()
t.conns[conn] = c
t.connMx.Unlock()
}
func (t *transport) removeConn(conn quic.Connection) {
t.connMx.Lock()
delete(t.conns, conn)
t.connMx.Unlock()
}
func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDPAddr, p peer.ID) (tpt.CapableConn, error) {
pconn, err := t.connManager.Dial(network, addr)
if err != nil {
return nil, err
}
defer pconn.DecreaseCount()
ctx, cancel := context.WithTimeout(ctx, HolePunchTimeout)
defer cancel()
key := holePunchKey{addr: addr.String(), peer: p}
t.holePunchingMx.Lock()
if _, ok := t.holePunching[key]; ok {
t.holePunchingMx.Unlock()
return nil, fmt.Errorf("already punching hole for %s", addr)
}
connCh := make(chan tpt.CapableConn, 1)
t.holePunching[key] = &activeHolePunch{connCh: connCh}
t.holePunchingMx.Unlock()
var timer *time.Timer
defer func() {
if timer != nil {
timer.Stop()
}
}()
payload := make([]byte, 64)
var punchErr error
loop:
for i := 0; ; i++ {
if _, err := rand.Read(payload); err != nil {
punchErr = err
break
}
if _, err := pconn.UDPConn.WriteToUDP(payload, addr); err != nil {
punchErr = err
break
}
maxSleep := 10 * (i + 1) * (i + 1) // in ms
if maxSleep > 200 {
maxSleep = 200
}
d := 10*time.Millisecond + time.Duration(rand.Intn(maxSleep))*time.Millisecond
if timer == nil {
timer = time.NewTimer(d)
} else {
timer.Reset(d)
}
select {
case c := <-connCh:
t.holePunchingMx.Lock()
delete(t.holePunching, key)
t.holePunchingMx.Unlock()
return c, nil
case <-timer.C:
case <-ctx.Done():
punchErr = ErrHolePunching
break loop
}
}
// we only arrive here if punchErr != nil
t.holePunchingMx.Lock()
defer func() {
delete(t.holePunching, key)
t.holePunchingMx.Unlock()
}()
select {
case c := <-t.holePunching[key].connCh:
return c, nil
default:
return nil, punchErr
}
}
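
Each round of the loop sends a fresh 64-byte random payload to keep the NAT binding open, then sleeps for a fixed 10ms plus a random jitter whose upper bound grows quadratically as 10*(i+1)^2 ms, capped at 200ms, until the listener delivers a connection on connCh or the HolePunchTimeout (5 seconds by default) expires. A small illustrative sketch of that schedule, not part of this PR:

package main

import "fmt"

// Prints the jitter bound the hole-punching loop uses per attempt, on top of
// the fixed 10ms base delay: 10, 40, 90, 160, then capped at 200 milliseconds.
func main() {
	for i := 0; i < 6; i++ {
		maxSleep := 10 * (i + 1) * (i + 1) // in ms, same formula as the loop above
		if maxSleep > 200 {
			maxSleep = 200
		}
		fmt.Printf("attempt %d: jitter drawn from [0, %d) ms\n", i, maxSleep)
	}
}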
// Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic
var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC))
// CanDial determines if we can dial the given address
func (t *transport) CanDial(addr ma.Multiaddr) bool {
return dialMatcher.Matches(addr)
}
// Listen listens for new QUIC connections on the passed multiaddr.
func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) {
lnet, host, err := manet.DialArgs(addr)
if err != nil {
return nil, err
}
laddr, err := net.ResolveUDPAddr(lnet, host)
if err != nil {
return nil, err
}
conn, err := t.connManager.Listen(lnet, laddr)
if err != nil {
return nil, err
}
ln, err := newListener(conn, t, t.localPeer, t.privKey, t.identity, t.rcmgr)
if err != nil {
conn.DecreaseCount()
return nil, err
}
return ln, nil
}
func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool {
// If the QUIC connection tries to increase the window before we've inserted it
// into our connections map (which we do right after dialing / accepting it),
// we have no way to account for that memory. This should be very rare.
// Block this attempt. The connection can request more memory later.
t.connMx.Lock()
c, ok := t.conns[conn]
t.connMx.Unlock()
if !ok {
return false
}
return c.allowWindowIncrease(size)
}
// Proxy returns true if this transport proxies.
func (t *transport) Proxy() bool {
return false
}
// Protocols returns the set of protocols handled by this transport.
func (t *transport) Protocols() []int {
return []int{ma.P_QUIC}
}
func (t *transport) String() string {
return "QUIC"
}
func (t *transport) Close() error {
return t.connManager.Close()
}

View File

@ -0,0 +1,99 @@
package libp2pquic
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net"
"testing"
"github.com/stretchr/testify/require"
ic "github.com/libp2p/go-libp2p-core/crypto"
tpt "github.com/libp2p/go-libp2p-core/transport"
ma "github.com/multiformats/go-multiaddr"
"github.com/lucas-clemente/quic-go"
)
func getTransport(t *testing.T) tpt.Transport {
t.Helper()
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey))
require.NoError(t, err)
tr, err := NewTransport(key, nil, nil, nil)
require.NoError(t, err)
return tr
}
func TestQUICProtocol(t *testing.T) {
tr := getTransport(t)
defer tr.(io.Closer).Close()
protocols := tr.Protocols()
if len(protocols) != 1 {
t.Fatalf("expected to only support a single protocol, got %v", protocols)
}
if protocols[0] != ma.P_QUIC {
t.Fatalf("expected the supported protocol to be QUIC, got %d", protocols[0])
}
}
func TestCanDial(t *testing.T) {
tr := getTransport(t)
defer tr.(io.Closer).Close()
invalid := []string{
"/ip4/127.0.0.1/udp/1234",
"/ip4/5.5.5.5/tcp/1234",
"/dns/google.com/udp/443/quic",
}
valid := []string{
"/ip4/127.0.0.1/udp/1234/quic",
"/ip4/5.5.5.5/udp/0/quic",
}
for _, s := range invalid {
invalidAddr, err := ma.NewMultiaddr(s)
require.NoError(t, err)
if tr.CanDial(invalidAddr) {
t.Errorf("didn't expect to be able to dial a non-quic address (%s)", invalidAddr)
}
}
for _, s := range valid {
validAddr, err := ma.NewMultiaddr(s)
require.NoError(t, err)
if !tr.CanDial(validAddr) {
t.Errorf("expected to be able to dial QUIC address (%s)", validAddr)
}
}
}
// The connection passed to quic-go needs to be type-assertable to a net.UDPConn,
// in order to enable features like batch processing and ECN.
func TestConnectionPassedToQUIC(t *testing.T) {
tr := getTransport(t)
defer tr.(io.Closer).Close()
origQuicDialContext := quicDialContext
defer func() { quicDialContext = origQuicDialContext }()
var conn net.PacketConn
quicDialContext = func(_ context.Context, c net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.Connection, error) {
conn = c
return nil, errors.New("listen error")
}
remoteAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
require.NoError(t, err)
_, err = tr.Dial(context.Background(), remoteAddr, "remote peer id")
require.EqualError(t, err, "listen error")
require.NotNil(t, conn)
defer conn.Close()
if _, ok := conn.(quic.OOBCapablePacketConn); !ok {
t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn")
}
}