510 lines
15 KiB
Go
510 lines
15 KiB
Go
package autonatv2
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
"time"
|
|
|
|
pool "github.com/libp2p/go-buffer-pool"
|
|
"github.com/libp2p/go-libp2p/core/host"
|
|
"github.com/libp2p/go-libp2p/core/network"
|
|
"github.com/libp2p/go-libp2p/core/peer"
|
|
"github.com/libp2p/go-libp2p/core/peerstore"
|
|
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
|
|
"github.com/libp2p/go-msgio/pbio"
|
|
|
|
"math/rand"
|
|
|
|
ma "github.com/multiformats/go-multiaddr"
|
|
manet "github.com/multiformats/go-multiaddr/net"
|
|
)
|
|
|
|
var (
|
|
errResourceLimitExceeded = errors.New("resource limit exceeded")
|
|
errBadRequest = errors.New("bad request")
|
|
errDialDataRefused = errors.New("dial data refused")
|
|
)
|
|
|
|
type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool
|
|
|
|
type EventDialRequestCompleted struct {
|
|
Error error
|
|
ResponseStatus pb.DialResponse_ResponseStatus
|
|
DialStatus pb.DialStatus
|
|
DialDataRequired bool
|
|
DialedAddr ma.Multiaddr
|
|
}
|
|
|
|
// server implements the AutoNATv2 server.
|
|
// It can ask client to provide dial data before attempting the requested dial.
|
|
// It rate limits requests on a global level, per peer level and on whether the request requires dial data.
|
|
type server struct {
|
|
host host.Host
|
|
dialerHost host.Host
|
|
limiter *rateLimiter
|
|
|
|
// dialDataRequestPolicy is used to determine whether dialing the address requires receiving
|
|
// dial data. It is set to amplification attack prevention by default.
|
|
dialDataRequestPolicy dataRequestPolicyFunc
|
|
amplificatonAttackPreventionDialWait time.Duration
|
|
metricsTracer MetricsTracer
|
|
|
|
// for tests
|
|
now func() time.Time
|
|
allowPrivateAddrs bool
|
|
}
|
|
|
|
func newServer(host, dialer host.Host, s *autoNATSettings) *server {
|
|
return &server{
|
|
dialerHost: dialer,
|
|
host: host,
|
|
dialDataRequestPolicy: s.dataRequestPolicy,
|
|
amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait,
|
|
allowPrivateAddrs: s.allowPrivateAddrs,
|
|
limiter: &rateLimiter{
|
|
RPM: s.serverRPM,
|
|
PerPeerRPM: s.serverPerPeerRPM,
|
|
DialDataRPM: s.serverDialDataRPM,
|
|
now: s.now,
|
|
},
|
|
now: s.now,
|
|
metricsTracer: s.metricsTracer,
|
|
}
|
|
}
|
|
|
|
// Enable attaches the stream handler to the host.
|
|
func (as *server) Start() {
|
|
as.host.SetStreamHandler(DialProtocol, as.handleDialRequest)
|
|
}
|
|
|
|
func (as *server) Close() {
|
|
as.host.RemoveStreamHandler(DialProtocol)
|
|
as.dialerHost.Close()
|
|
as.limiter.Close()
|
|
}
|
|
|
|
// handleDialRequest is the dial-request protocol stream handler
|
|
func (as *server) handleDialRequest(s network.Stream) {
|
|
evt := as.serveDialRequest(s)
|
|
log.Debugf("completed dial-request from %s, response status: %s, dial status: %s, err: %s",
|
|
s.Conn().RemotePeer(), evt.ResponseStatus, evt.DialStatus, evt.Error)
|
|
if as.metricsTracer != nil {
|
|
as.metricsTracer.CompletedRequest(evt)
|
|
}
|
|
}
|
|
|
|
func (as *server) serveDialRequest(s network.Stream) EventDialRequestCompleted {
|
|
if err := s.Scope().SetService(ServiceName); err != nil {
|
|
s.Reset()
|
|
log.Debugf("failed to attach stream to %s service: %w", ServiceName, err)
|
|
return EventDialRequestCompleted{
|
|
Error: errors.New("failed to attach stream to autonat-v2"),
|
|
}
|
|
}
|
|
|
|
if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
|
|
s.Reset()
|
|
log.Debugf("failed to reserve memory for stream %s: %w", DialProtocol, err)
|
|
return EventDialRequestCompleted{Error: errResourceLimitExceeded}
|
|
}
|
|
defer s.Scope().ReleaseMemory(maxMsgSize)
|
|
|
|
deadline := as.now().Add(streamTimeout)
|
|
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
|
defer cancel()
|
|
s.SetDeadline(as.now().Add(streamTimeout))
|
|
defer s.Close()
|
|
|
|
p := s.Conn().RemotePeer()
|
|
|
|
var msg pb.Message
|
|
w := pbio.NewDelimitedWriter(s)
|
|
// Check for rate limit before parsing the request
|
|
if !as.limiter.Accept(p) {
|
|
msg = pb.Message{
|
|
Msg: &pb.Message_DialResponse{
|
|
DialResponse: &pb.DialResponse{
|
|
Status: pb.DialResponse_E_REQUEST_REJECTED,
|
|
},
|
|
},
|
|
}
|
|
if err := w.WriteMsg(&msg); err != nil {
|
|
s.Reset()
|
|
log.Debugf("failed to write request rejected response to %s: %s", p, err)
|
|
return EventDialRequestCompleted{
|
|
ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED,
|
|
Error: fmt.Errorf("write failed: %w", err),
|
|
}
|
|
}
|
|
log.Debugf("rejected request from %s: rate limit exceeded", p)
|
|
return EventDialRequestCompleted{ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED}
|
|
}
|
|
defer as.limiter.CompleteRequest(p)
|
|
|
|
r := pbio.NewDelimitedReader(s, maxMsgSize)
|
|
if err := r.ReadMsg(&msg); err != nil {
|
|
s.Reset()
|
|
log.Debugf("failed to read request from %s: %s", p, err)
|
|
return EventDialRequestCompleted{Error: fmt.Errorf("read failed: %w", err)}
|
|
}
|
|
if msg.GetDialRequest() == nil {
|
|
s.Reset()
|
|
log.Debugf("invalid message type from %s: %T expected: DialRequest", p, msg.Msg)
|
|
return EventDialRequestCompleted{Error: errBadRequest}
|
|
}
|
|
|
|
// parse peer's addresses
|
|
var dialAddr ma.Multiaddr
|
|
var addrIdx int
|
|
for i, ab := range msg.GetDialRequest().GetAddrs() {
|
|
if i >= maxPeerAddresses {
|
|
break
|
|
}
|
|
a, err := ma.NewMultiaddrBytes(ab)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if !as.allowPrivateAddrs && !manet.IsPublicAddr(a) {
|
|
continue
|
|
}
|
|
if !as.dialerHost.Network().CanDial(p, a) {
|
|
continue
|
|
}
|
|
dialAddr = a
|
|
addrIdx = i
|
|
break
|
|
}
|
|
// No dialable address
|
|
if dialAddr == nil {
|
|
msg = pb.Message{
|
|
Msg: &pb.Message_DialResponse{
|
|
DialResponse: &pb.DialResponse{
|
|
Status: pb.DialResponse_E_DIAL_REFUSED,
|
|
},
|
|
},
|
|
}
|
|
if err := w.WriteMsg(&msg); err != nil {
|
|
s.Reset()
|
|
log.Debugf("failed to write dial refused response to %s: %s", p, err)
|
|
return EventDialRequestCompleted{
|
|
ResponseStatus: pb.DialResponse_E_DIAL_REFUSED,
|
|
Error: fmt.Errorf("write failed: %w", err),
|
|
}
|
|
}
|
|
return EventDialRequestCompleted{
|
|
ResponseStatus: pb.DialResponse_E_DIAL_REFUSED,
|
|
}
|
|
}
|
|
|
|
nonce := msg.GetDialRequest().Nonce
|
|
|
|
isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr)
|
|
if isDialDataRequired && !as.limiter.AcceptDialDataRequest(p) {
|
|
msg = pb.Message{
|
|
Msg: &pb.Message_DialResponse{
|
|
DialResponse: &pb.DialResponse{
|
|
Status: pb.DialResponse_E_REQUEST_REJECTED,
|
|
},
|
|
},
|
|
}
|
|
if err := w.WriteMsg(&msg); err != nil {
|
|
s.Reset()
|
|
log.Debugf("failed to write request rejected response to %s: %s", p, err)
|
|
return EventDialRequestCompleted{
|
|
ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED,
|
|
Error: fmt.Errorf("write failed: %w", err),
|
|
DialDataRequired: true,
|
|
}
|
|
}
|
|
log.Debugf("rejected request from %s: rate limit exceeded", p)
|
|
return EventDialRequestCompleted{
|
|
ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED,
|
|
DialDataRequired: true,
|
|
}
|
|
}
|
|
|
|
if isDialDataRequired {
|
|
if err := getDialData(w, s, &msg, addrIdx); err != nil {
|
|
s.Reset()
|
|
log.Debugf("%s refused dial data request: %s", p, err)
|
|
return EventDialRequestCompleted{
|
|
Error: errDialDataRefused,
|
|
DialDataRequired: true,
|
|
DialedAddr: dialAddr,
|
|
}
|
|
}
|
|
// wait for a bit to prevent thundering herd style attacks on a victim
|
|
waitTime := time.Duration(rand.Intn(int(as.amplificatonAttackPreventionDialWait) + 1)) // the range is [0, n)
|
|
t := time.NewTimer(waitTime)
|
|
defer t.Stop()
|
|
select {
|
|
case <-ctx.Done():
|
|
s.Reset()
|
|
log.Debugf("rejecting request without dialing: %s %p ", p, ctx.Err())
|
|
return EventDialRequestCompleted{Error: ctx.Err(), DialDataRequired: true, DialedAddr: dialAddr}
|
|
case <-t.C:
|
|
}
|
|
}
|
|
|
|
dialStatus := as.dialBack(ctx, s.Conn().RemotePeer(), dialAddr, nonce)
|
|
msg = pb.Message{
|
|
Msg: &pb.Message_DialResponse{
|
|
DialResponse: &pb.DialResponse{
|
|
Status: pb.DialResponse_OK,
|
|
DialStatus: dialStatus,
|
|
AddrIdx: uint32(addrIdx),
|
|
},
|
|
},
|
|
}
|
|
if err := w.WriteMsg(&msg); err != nil {
|
|
s.Reset()
|
|
log.Debugf("failed to write response to %s: %s", p, err)
|
|
return EventDialRequestCompleted{
|
|
ResponseStatus: pb.DialResponse_OK,
|
|
DialStatus: dialStatus,
|
|
Error: fmt.Errorf("write failed: %w", err),
|
|
DialDataRequired: isDialDataRequired,
|
|
DialedAddr: dialAddr,
|
|
}
|
|
}
|
|
return EventDialRequestCompleted{
|
|
ResponseStatus: pb.DialResponse_OK,
|
|
DialStatus: dialStatus,
|
|
Error: nil,
|
|
DialDataRequired: isDialDataRequired,
|
|
DialedAddr: dialAddr,
|
|
}
|
|
}
|
|
|
|
// getDialData gets data from the client for dialing the address
|
|
func getDialData(w pbio.Writer, s network.Stream, msg *pb.Message, addrIdx int) error {
|
|
numBytes := minHandshakeSizeBytes + rand.Intn(maxHandshakeSizeBytes-minHandshakeSizeBytes)
|
|
*msg = pb.Message{
|
|
Msg: &pb.Message_DialDataRequest{
|
|
DialDataRequest: &pb.DialDataRequest{
|
|
AddrIdx: uint32(addrIdx),
|
|
NumBytes: uint64(numBytes),
|
|
},
|
|
},
|
|
}
|
|
if err := w.WriteMsg(msg); err != nil {
|
|
return fmt.Errorf("dial data write: %w", err)
|
|
}
|
|
// pbio.Reader that we used so far on this stream is buffered. But at this point
|
|
// there is nothing unread on the stream. So it is safe to use the raw stream to
|
|
// read, reducing allocations.
|
|
return readDialData(numBytes, s)
|
|
}
|
|
|
|
func readDialData(numBytes int, r io.Reader) error {
|
|
mr := &msgReader{R: r, Buf: pool.Get(maxMsgSize)}
|
|
defer pool.Put(mr.Buf)
|
|
for remain := numBytes; remain > 0; {
|
|
msg, err := mr.ReadMsg()
|
|
if err != nil {
|
|
return fmt.Errorf("dial data read: %w", err)
|
|
}
|
|
// protobuf format is:
|
|
// (oneof dialDataResponse:<fieldTag><len varint>)(dial data:<fieldTag><len varint><bytes>)
|
|
bytesLen := len(msg)
|
|
bytesLen -= 2 // fieldTag + varint first byte
|
|
if bytesLen > 127 {
|
|
bytesLen -= 1 // varint second byte
|
|
}
|
|
bytesLen -= 2 // second fieldTag + varint first byte
|
|
if bytesLen > 127 {
|
|
bytesLen -= 1 // varint second byte
|
|
}
|
|
if bytesLen > 0 {
|
|
remain -= bytesLen
|
|
}
|
|
// Check if the peer is not sending too little data forcing us to just do a lot of compute
|
|
if bytesLen < 100 && remain > 0 {
|
|
return fmt.Errorf("dial data msg too small: %d", bytesLen)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (as *server) dialBack(ctx context.Context, p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus {
|
|
ctx, cancel := context.WithTimeout(ctx, dialBackDialTimeout)
|
|
ctx = network.WithForceDirectDial(ctx, "autonatv2")
|
|
as.dialerHost.Peerstore().AddAddr(p, addr, peerstore.TempAddrTTL)
|
|
defer func() {
|
|
cancel()
|
|
as.dialerHost.Network().ClosePeer(p)
|
|
as.dialerHost.Peerstore().ClearAddrs(p)
|
|
as.dialerHost.Peerstore().RemovePeer(p)
|
|
}()
|
|
|
|
err := as.dialerHost.Connect(ctx, peer.AddrInfo{ID: p})
|
|
if err != nil {
|
|
return pb.DialStatus_E_DIAL_ERROR
|
|
}
|
|
|
|
s, err := as.dialerHost.NewStream(ctx, p, DialBackProtocol)
|
|
if err != nil {
|
|
return pb.DialStatus_E_DIAL_BACK_ERROR
|
|
}
|
|
|
|
defer s.Close()
|
|
s.SetDeadline(as.now().Add(dialBackStreamTimeout))
|
|
|
|
w := pbio.NewDelimitedWriter(s)
|
|
if err := w.WriteMsg(&pb.DialBack{Nonce: nonce}); err != nil {
|
|
s.Reset()
|
|
return pb.DialStatus_E_DIAL_BACK_ERROR
|
|
}
|
|
|
|
// Since the underlying connection is on a separate dialer, it'll be closed after this
|
|
// function returns. Connection close will drop all the queued writes. To ensure message
|
|
// delivery, do a CloseWrite and read a byte from the stream. The peer actually sends a
|
|
// response of type DialBackResponse but we only care about the fact that the DialBack
|
|
// message has reached the peer. So we ignore that message on the read side.
|
|
s.CloseWrite()
|
|
s.SetDeadline(as.now().Add(5 * time.Second)) // 5 is a magic number
|
|
b := make([]byte, 1) // Read 1 byte here because 0 len reads are free to return (0, nil) immediately
|
|
s.Read(b)
|
|
|
|
return pb.DialStatus_OK
|
|
}
|
|
|
|
// rateLimiter implements a sliding window rate limit of requests per minute. It allows 1 concurrent request
|
|
// per peer. It rate limits requests globally, at a peer level and depending on whether it requires dial data.
|
|
type rateLimiter struct {
|
|
// PerPeerRPM is the rate limit per peer
|
|
PerPeerRPM int
|
|
// RPM is the global rate limit
|
|
RPM int
|
|
// DialDataRPM is the rate limit for requests that require dial data
|
|
DialDataRPM int
|
|
|
|
mu sync.Mutex
|
|
closed bool
|
|
reqs []entry
|
|
peerReqs map[peer.ID][]time.Time
|
|
dialDataReqs []time.Time
|
|
// ongoingReqs tracks in progress requests. This is used to disallow multiple concurrent requests by the
|
|
// same peer
|
|
// TODO: Should we allow a few concurrent requests per peer?
|
|
ongoingReqs map[peer.ID]struct{}
|
|
|
|
now func() time.Time // for tests
|
|
}
|
|
|
|
type entry struct {
|
|
PeerID peer.ID
|
|
Time time.Time
|
|
}
|
|
|
|
func (r *rateLimiter) Accept(p peer.ID) bool {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if r.closed {
|
|
return false
|
|
}
|
|
if r.peerReqs == nil {
|
|
r.peerReqs = make(map[peer.ID][]time.Time)
|
|
r.ongoingReqs = make(map[peer.ID]struct{})
|
|
}
|
|
|
|
nw := r.now()
|
|
r.cleanup(nw)
|
|
|
|
if _, ok := r.ongoingReqs[p]; ok {
|
|
return false
|
|
}
|
|
if len(r.reqs) >= r.RPM || len(r.peerReqs[p]) >= r.PerPeerRPM {
|
|
return false
|
|
}
|
|
|
|
r.ongoingReqs[p] = struct{}{}
|
|
r.reqs = append(r.reqs, entry{PeerID: p, Time: nw})
|
|
r.peerReqs[p] = append(r.peerReqs[p], nw)
|
|
return true
|
|
}
|
|
|
|
func (r *rateLimiter) AcceptDialDataRequest(p peer.ID) bool {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if r.closed {
|
|
return false
|
|
}
|
|
if r.peerReqs == nil {
|
|
r.peerReqs = make(map[peer.ID][]time.Time)
|
|
r.ongoingReqs = make(map[peer.ID]struct{})
|
|
}
|
|
nw := r.now()
|
|
r.cleanup(nw)
|
|
if len(r.dialDataReqs) >= r.DialDataRPM {
|
|
return false
|
|
}
|
|
r.dialDataReqs = append(r.dialDataReqs, nw)
|
|
return true
|
|
}
|
|
|
|
// cleanup removes stale requests.
|
|
//
|
|
// This is fast enough in rate limited cases and the state is small enough to
|
|
// clean up quickly when blocking requests.
|
|
func (r *rateLimiter) cleanup(now time.Time) {
|
|
idx := len(r.reqs)
|
|
for i, e := range r.reqs {
|
|
if now.Sub(e.Time) >= time.Minute {
|
|
pi := len(r.peerReqs[e.PeerID])
|
|
for j, t := range r.peerReqs[e.PeerID] {
|
|
if now.Sub(t) < time.Minute {
|
|
pi = j
|
|
break
|
|
}
|
|
}
|
|
r.peerReqs[e.PeerID] = r.peerReqs[e.PeerID][pi:]
|
|
if len(r.peerReqs[e.PeerID]) == 0 {
|
|
delete(r.peerReqs, e.PeerID)
|
|
}
|
|
} else {
|
|
idx = i
|
|
break
|
|
}
|
|
}
|
|
r.reqs = r.reqs[idx:]
|
|
|
|
idx = len(r.dialDataReqs)
|
|
for i, t := range r.dialDataReqs {
|
|
if now.Sub(t) < time.Minute {
|
|
idx = i
|
|
break
|
|
}
|
|
}
|
|
r.dialDataReqs = r.dialDataReqs[idx:]
|
|
}
|
|
|
|
func (r *rateLimiter) CompleteRequest(p peer.ID) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
delete(r.ongoingReqs, p)
|
|
}
|
|
|
|
func (r *rateLimiter) Close() {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.closed = true
|
|
r.peerReqs = nil
|
|
r.ongoingReqs = nil
|
|
r.dialDataReqs = nil
|
|
}
|
|
|
|
// amplificationAttackPrevention is a dialDataRequestPolicy which requests data when the peer's observed
|
|
// IP address is different from the dial back IP address
|
|
func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool {
|
|
connIP, err := manet.ToIP(s.Conn().RemoteMultiaddr())
|
|
if err != nil {
|
|
return true
|
|
}
|
|
dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr
|
|
return !connIP.Equal(dialIP)
|
|
}
|