492 lines
12 KiB
Go
492 lines
12 KiB
Go
package relay
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
pb "github.com/libp2p/go-libp2p-circuit/pb"
|
|
|
|
"github.com/libp2p/go-libp2p-core/helpers"
|
|
"github.com/libp2p/go-libp2p-core/host"
|
|
"github.com/libp2p/go-libp2p-core/network"
|
|
"github.com/libp2p/go-libp2p-core/peer"
|
|
"github.com/libp2p/go-libp2p-core/peerstore"
|
|
|
|
pool "github.com/libp2p/go-buffer-pool"
|
|
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
|
|
|
|
logging "github.com/ipfs/go-log"
|
|
ma "github.com/multiformats/go-multiaddr"
|
|
)
|
|
|
|
// log is the package-level logger used by the relay code, created with the name "relay".
var log = logging.Logger("relay")
|
|
|
|
const (
	// ProtoID is the protocol identifier under which relay streams are
	// opened and handled.
	ProtoID = "/libp2p/circuit/relay/0.1.0"

	// maxMessageSize is the limit handed to newDelimitedReader for every
	// relay protocol message read by this package.
	maxMessageSize = 4096
)

// Tunables for the relay service. They are variables rather than constants so
// that callers can adjust them.
var (
	// RelayAcceptTimeout is how long an incoming relayed connection waits to
	// be accepted before the STOP request is refused.
	RelayAcceptTimeout = 10 * time.Second
	// HopConnectTimeout bounds opening the onward stream to the destination
	// when this node acts as a hop.
	HopConnectTimeout = 30 * time.Second
	// StopHandshakeTimeout is the deadline applied to the STOP handshake with
	// the destination peer.
	StopHandshakeTimeout = time.Minute

	// HopStreamBufferSize is the size of each pooled buffer used to copy
	// relayed traffic.
	HopStreamBufferSize = 4096
	// HopStreamLimit caps in-flight hop requests plus live hops.
	HopStreamLimit = 1 << 19 // 512K hops for 1M goroutines
)
|
|
|
|
// Relay is the relay transport and service.
|
|
type Relay struct {
|
|
host host.Host
|
|
upgrader *tptu.Upgrader
|
|
ctx context.Context
|
|
self peer.ID
|
|
|
|
active bool
|
|
hop bool
|
|
discovery bool
|
|
|
|
incoming chan *Conn
|
|
|
|
relays map[peer.ID]struct{}
|
|
mx sync.Mutex
|
|
|
|
// atomic counters
|
|
streamCount int32
|
|
liveHopCount int32
|
|
}
|
|
|
|
// RelayOpt is an option for configuring the relay transport. The recognised
// values are OptActive, OptHop and OptDiscovery.
type RelayOpt int

var (
	// OptActive makes the relay transport actively establish outbound
	// connections on behalf of clients. You probably don't want to enable
	// this unless you know what you're doing.
	OptActive RelayOpt = 0

	// OptHop makes the relay transport accept requests to relay traffic on
	// behalf of third parties. Unless OptActive is also given, only traffic
	// between peers already connected to this node is relayed.
	OptHop RelayOpt = 1

	// OptDiscovery makes the relay transport discover new relays by probing
	// every new peer. You almost _certainly_ don't want to enable this.
	OptDiscovery RelayOpt = 2
)
|
|
|
|
type RelayError struct {
|
|
Code pb.CircuitRelay_Status
|
|
}
|
|
|
|
func (e RelayError) Error() string {
|
|
return fmt.Sprintf("error opening relay circuit: %s (%d)", pb.CircuitRelay_Status_name[int32(e.Code)], e.Code)
|
|
}
|
|
|
|
// NewRelay constructs a new relay.
|
|
func NewRelay(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) (*Relay, error) {
|
|
r := &Relay{
|
|
upgrader: upgrader,
|
|
host: h,
|
|
ctx: ctx,
|
|
self: h.ID(),
|
|
incoming: make(chan *Conn),
|
|
relays: make(map[peer.ID]struct{}),
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
switch opt {
|
|
case OptActive:
|
|
r.active = true
|
|
case OptHop:
|
|
r.hop = true
|
|
case OptDiscovery:
|
|
r.discovery = true
|
|
default:
|
|
return nil, fmt.Errorf("unrecognized option: %d", opt)
|
|
}
|
|
}
|
|
|
|
h.SetStreamHandler(ProtoID, r.handleNewStream)
|
|
|
|
if r.discovery {
|
|
h.Network().Notify(r.notifiee())
|
|
}
|
|
|
|
return r, nil
|
|
}
|
|
|
|
// Increment the live hop count and increment the connection manager tags by 1 for the two
|
|
// sides of the hop stream. This ensures that connections with many hop streams will be protected
|
|
// from pruning, thus minimizing disruption from connection trimming in a relay node.
|
|
func (r *Relay) addLiveHop(from, to peer.ID) {
|
|
atomic.AddInt32(&r.liveHopCount, 1)
|
|
r.host.ConnManager().UpsertTag(from, "relay-hop-stream", incrementTag)
|
|
r.host.ConnManager().UpsertTag(to, "relay-hop-stream", incrementTag)
|
|
}
|
|
|
|
// Decrement the live hpo count and decrement the connection manager tags for the two sides
|
|
// of the hop stream.
|
|
func (r *Relay) rmLiveHop(from, to peer.ID) {
|
|
atomic.AddInt32(&r.liveHopCount, -1)
|
|
r.host.ConnManager().UpsertTag(from, "relay-hop-stream", decrementTag)
|
|
r.host.ConnManager().UpsertTag(to, "relay-hop-stream", decrementTag)
|
|
|
|
}
|
|
|
|
func (r *Relay) GetActiveHops() int32 {
|
|
return atomic.LoadInt32(&r.liveHopCount)
|
|
}
|
|
|
|
func (r *Relay) DialPeer(ctx context.Context, relay peer.AddrInfo, dest peer.AddrInfo) (*Conn, error) {
|
|
|
|
log.Debugf("dialing peer %s through relay %s", dest.ID, relay.ID)
|
|
|
|
if len(relay.Addrs) > 0 {
|
|
r.host.Peerstore().AddAddrs(relay.ID, relay.Addrs, peerstore.TempAddrTTL)
|
|
}
|
|
|
|
s, err := r.host.NewStream(ctx, relay.ID, ProtoID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rd := newDelimitedReader(s, maxMessageSize)
|
|
wr := newDelimitedWriter(s)
|
|
defer rd.Close()
|
|
|
|
var msg pb.CircuitRelay
|
|
|
|
msg.Type = pb.CircuitRelay_HOP.Enum()
|
|
msg.SrcPeer = peerInfoToPeer(r.host.Peerstore().PeerInfo(r.self))
|
|
msg.DstPeer = peerInfoToPeer(dest)
|
|
|
|
err = wr.WriteMsg(&msg)
|
|
if err != nil {
|
|
s.Reset()
|
|
return nil, err
|
|
}
|
|
|
|
msg.Reset()
|
|
|
|
err = rd.ReadMsg(&msg)
|
|
if err != nil {
|
|
s.Reset()
|
|
return nil, err
|
|
}
|
|
|
|
if msg.GetType() != pb.CircuitRelay_STATUS {
|
|
s.Reset()
|
|
return nil, fmt.Errorf("unexpected relay response; not a status message (%d)", msg.GetType())
|
|
}
|
|
|
|
if msg.GetCode() != pb.CircuitRelay_SUCCESS {
|
|
s.Reset()
|
|
return nil, RelayError{msg.GetCode()}
|
|
}
|
|
|
|
return &Conn{stream: s, remote: dest, host: r.host}, nil
|
|
}
|
|
|
|
func (r *Relay) Matches(addr ma.Multiaddr) bool {
|
|
// TODO: Look at the prefix transport as well.
|
|
_, err := addr.ValueForProtocol(P_CIRCUIT)
|
|
return err == nil
|
|
}
|
|
|
|
func (r *Relay) CanHop(ctx context.Context, id peer.ID) (bool, error) {
|
|
s, err := r.host.NewStream(ctx, id, ProtoID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
rd := newDelimitedReader(s, maxMessageSize)
|
|
wr := newDelimitedWriter(s)
|
|
defer rd.Close()
|
|
|
|
var msg pb.CircuitRelay
|
|
|
|
msg.Type = pb.CircuitRelay_CAN_HOP.Enum()
|
|
|
|
if err := wr.WriteMsg(&msg); err != nil {
|
|
s.Reset()
|
|
return false, err
|
|
}
|
|
|
|
msg.Reset()
|
|
|
|
if err := rd.ReadMsg(&msg); err != nil {
|
|
s.Reset()
|
|
return false, err
|
|
}
|
|
if err := helpers.FullClose(s); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if msg.GetType() != pb.CircuitRelay_STATUS {
|
|
return false, fmt.Errorf("unexpected relay response; not a status message (%d)", msg.GetType())
|
|
}
|
|
|
|
return msg.GetCode() == pb.CircuitRelay_SUCCESS, nil
|
|
}
|
|
|
|
func (r *Relay) handleNewStream(s network.Stream) {
|
|
log.Infof("new relay stream from: %s", s.Conn().RemotePeer())
|
|
|
|
rd := newDelimitedReader(s, maxMessageSize)
|
|
defer rd.Close()
|
|
|
|
var msg pb.CircuitRelay
|
|
|
|
err := rd.ReadMsg(&msg)
|
|
if err != nil {
|
|
r.handleError(s, pb.CircuitRelay_MALFORMED_MESSAGE)
|
|
return
|
|
}
|
|
|
|
switch msg.GetType() {
|
|
case pb.CircuitRelay_HOP:
|
|
r.handleHopStream(s, &msg)
|
|
case pb.CircuitRelay_STOP:
|
|
r.handleStopStream(s, &msg)
|
|
case pb.CircuitRelay_CAN_HOP:
|
|
r.handleCanHop(s, &msg)
|
|
default:
|
|
log.Warningf("unexpected relay handshake: %d", msg.GetType())
|
|
r.handleError(s, pb.CircuitRelay_MALFORMED_MESSAGE)
|
|
}
|
|
}
|
|
|
|
func (r *Relay) handleHopStream(s network.Stream, msg *pb.CircuitRelay) {
|
|
if !r.hop {
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_SPEAK_RELAY)
|
|
return
|
|
}
|
|
|
|
streamCount := atomic.AddInt32(&r.streamCount, 1)
|
|
liveHopCount := atomic.LoadInt32(&r.liveHopCount)
|
|
defer atomic.AddInt32(&r.streamCount, -1)
|
|
|
|
if (streamCount + liveHopCount) > int32(HopStreamLimit) {
|
|
log.Warning("hop stream limit exceeded; resetting stream")
|
|
s.Reset()
|
|
return
|
|
}
|
|
|
|
src, err := peerToPeerInfo(msg.GetSrcPeer())
|
|
if err != nil {
|
|
r.handleError(s, pb.CircuitRelay_HOP_SRC_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
if src.ID != s.Conn().RemotePeer() {
|
|
r.handleError(s, pb.CircuitRelay_HOP_SRC_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
dst, err := peerToPeerInfo(msg.GetDstPeer())
|
|
if err != nil {
|
|
r.handleError(s, pb.CircuitRelay_HOP_DST_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
if dst.ID == r.self {
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_RELAY_TO_SELF)
|
|
return
|
|
}
|
|
|
|
// open stream
|
|
ctx, cancel := context.WithTimeout(r.ctx, HopConnectTimeout)
|
|
defer cancel()
|
|
|
|
if !r.active {
|
|
ctx = network.WithNoDial(ctx, "relay hop")
|
|
} else if len(dst.Addrs) > 0 {
|
|
r.host.Peerstore().AddAddrs(dst.ID, dst.Addrs, peerstore.TempAddrTTL)
|
|
}
|
|
|
|
bs, err := r.host.NewStream(ctx, dst.ID, ProtoID)
|
|
if err != nil {
|
|
log.Debugf("error opening relay stream to %s: %s", dst.ID.Pretty(), err.Error())
|
|
if err == network.ErrNoConn {
|
|
r.handleError(s, pb.CircuitRelay_HOP_NO_CONN_TO_DST)
|
|
} else {
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_DIAL_DST)
|
|
}
|
|
return
|
|
}
|
|
|
|
// stop handshake
|
|
rd := newDelimitedReader(bs, maxMessageSize)
|
|
wr := newDelimitedWriter(bs)
|
|
defer rd.Close()
|
|
|
|
// set handshake deadline
|
|
bs.SetDeadline(time.Now().Add(StopHandshakeTimeout))
|
|
|
|
msg.Type = pb.CircuitRelay_STOP.Enum()
|
|
|
|
err = wr.WriteMsg(msg)
|
|
if err != nil {
|
|
log.Debugf("error writing stop handshake: %s", err.Error())
|
|
bs.Reset()
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_OPEN_DST_STREAM)
|
|
return
|
|
}
|
|
|
|
msg.Reset()
|
|
|
|
err = rd.ReadMsg(msg)
|
|
if err != nil {
|
|
log.Debugf("error reading stop response: %s", err.Error())
|
|
bs.Reset()
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_OPEN_DST_STREAM)
|
|
return
|
|
}
|
|
|
|
if msg.GetType() != pb.CircuitRelay_STATUS {
|
|
log.Debugf("unexpected relay stop response: not a status message (%d)", msg.GetType())
|
|
bs.Reset()
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_OPEN_DST_STREAM)
|
|
return
|
|
}
|
|
|
|
if msg.GetCode() != pb.CircuitRelay_SUCCESS {
|
|
log.Debugf("relay stop failure: %d", msg.GetCode())
|
|
bs.Reset()
|
|
r.handleError(s, msg.GetCode())
|
|
return
|
|
}
|
|
|
|
err = r.writeResponse(s, pb.CircuitRelay_SUCCESS)
|
|
if err != nil {
|
|
log.Debugf("error writing relay response: %s", err.Error())
|
|
bs.Reset()
|
|
s.Reset()
|
|
return
|
|
}
|
|
|
|
// relay connection
|
|
log.Infof("relaying connection between %s and %s", src.ID.Pretty(), dst.ID.Pretty())
|
|
|
|
// reset deadline
|
|
bs.SetDeadline(time.Time{})
|
|
|
|
r.addLiveHop(src.ID, dst.ID)
|
|
|
|
goroutines := new(int32)
|
|
*goroutines = 2
|
|
done := func() {
|
|
if atomic.AddInt32(goroutines, -1) == 0 {
|
|
r.rmLiveHop(src.ID, dst.ID)
|
|
}
|
|
}
|
|
|
|
// Don't reset streams after finishing or the other side will get an
|
|
// error, not an EOF.
|
|
go func() {
|
|
defer done()
|
|
|
|
buf := pool.Get(HopStreamBufferSize)
|
|
defer pool.Put(buf)
|
|
|
|
count, err := io.CopyBuffer(s, bs, buf)
|
|
if err != nil {
|
|
log.Debugf("relay copy error: %s", err)
|
|
// Reset both.
|
|
s.Reset()
|
|
bs.Reset()
|
|
} else {
|
|
// propagate the close
|
|
s.Close()
|
|
}
|
|
log.Debugf("relayed %d bytes from %s to %s", count, dst.ID.Pretty(), src.ID.Pretty())
|
|
}()
|
|
|
|
go func() {
|
|
defer done()
|
|
|
|
buf := pool.Get(HopStreamBufferSize)
|
|
defer pool.Put(buf)
|
|
|
|
count, err := io.CopyBuffer(bs, s, buf)
|
|
if err != nil {
|
|
log.Debugf("relay copy error: %s", err)
|
|
// Reset both.
|
|
bs.Reset()
|
|
s.Reset()
|
|
} else {
|
|
// propagate the close
|
|
bs.Close()
|
|
}
|
|
log.Debugf("relayed %d bytes from %s to %s", count, src.ID.Pretty(), dst.ID.Pretty())
|
|
}()
|
|
}
|
|
|
|
func (r *Relay) handleStopStream(s network.Stream, msg *pb.CircuitRelay) {
|
|
src, err := peerToPeerInfo(msg.GetSrcPeer())
|
|
if err != nil {
|
|
r.handleError(s, pb.CircuitRelay_STOP_SRC_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
dst, err := peerToPeerInfo(msg.GetDstPeer())
|
|
if err != nil || dst.ID != r.self {
|
|
r.handleError(s, pb.CircuitRelay_STOP_DST_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
log.Infof("relay connection from: %s", src.ID)
|
|
|
|
if len(src.Addrs) > 0 {
|
|
r.host.Peerstore().AddAddrs(src.ID, src.Addrs, peerstore.TempAddrTTL)
|
|
}
|
|
|
|
select {
|
|
case r.incoming <- &Conn{stream: s, remote: src, host: r.host}:
|
|
case <-time.After(RelayAcceptTimeout):
|
|
r.handleError(s, pb.CircuitRelay_STOP_RELAY_REFUSED)
|
|
}
|
|
}
|
|
|
|
func (r *Relay) handleCanHop(s network.Stream, msg *pb.CircuitRelay) {
|
|
var err error
|
|
|
|
if r.hop {
|
|
err = r.writeResponse(s, pb.CircuitRelay_SUCCESS)
|
|
} else {
|
|
err = r.writeResponse(s, pb.CircuitRelay_HOP_CANT_SPEAK_RELAY)
|
|
}
|
|
|
|
if err != nil {
|
|
s.Reset()
|
|
log.Debugf("error writing relay response: %s", err.Error())
|
|
} else {
|
|
helpers.FullClose(s)
|
|
}
|
|
}
|
|
|
|
func (r *Relay) handleError(s network.Stream, code pb.CircuitRelay_Status) {
|
|
log.Warningf("relay error: %s (%d)", pb.CircuitRelay_Status_name[int32(code)], code)
|
|
err := r.writeResponse(s, code)
|
|
if err != nil {
|
|
s.Reset()
|
|
log.Debugf("error writing relay response: %s", err.Error())
|
|
} else {
|
|
helpers.FullClose(s)
|
|
}
|
|
}
|
|
|
|
func (r *Relay) writeResponse(s network.Stream, code pb.CircuitRelay_Status) error {
|
|
wr := newDelimitedWriter(s)
|
|
|
|
var msg pb.CircuitRelay
|
|
msg.Type = pb.CircuitRelay_STATUS.Enum()
|
|
msg.Code = code.Enum()
|
|
|
|
return wr.WriteMsg(&msg)
|
|
}
|