396 lines
8.7 KiB
Go
396 lines
8.7 KiB
Go
package relay
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
"time"
|
|
|
|
pb "github.com/libp2p/go-libp2p-circuit/pb"
|
|
|
|
logging "github.com/ipfs/go-log"
|
|
host "github.com/libp2p/go-libp2p-host"
|
|
inet "github.com/libp2p/go-libp2p-net"
|
|
peer "github.com/libp2p/go-libp2p-peer"
|
|
pstore "github.com/libp2p/go-libp2p-peerstore"
|
|
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
|
|
ma "github.com/multiformats/go-multiaddr"
|
|
)
|
|
|
|
var log = logging.Logger("relay")
|
|
|
|
const ProtoID = "/libp2p/circuit/relay/0.1.0"
|
|
|
|
const maxMessageSize = 4096
|
|
|
|
var RelayAcceptTimeout = time.Minute
|
|
var HopConnectTimeout = 10 * time.Second
|
|
|
|
type Relay struct {
|
|
host host.Host
|
|
upgrader *tptu.Upgrader
|
|
ctx context.Context
|
|
self peer.ID
|
|
|
|
active bool
|
|
hop bool
|
|
|
|
incoming chan *Conn
|
|
|
|
relays map[peer.ID]struct{}
|
|
mx sync.Mutex
|
|
}
|
|
|
|
type RelayOpt int
|
|
|
|
var (
|
|
OptActive = RelayOpt(0)
|
|
OptHop = RelayOpt(1)
|
|
)
|
|
|
|
type RelayError struct {
|
|
Code pb.CircuitRelay_Status
|
|
}
|
|
|
|
func (e RelayError) Error() string {
|
|
return fmt.Sprintf("error opening relay circuit: %s (%d)", pb.CircuitRelay_Status_name[int32(e.Code)], e.Code)
|
|
}
|
|
|
|
func NewRelay(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) (*Relay, error) {
|
|
r := &Relay{
|
|
upgrader: upgrader,
|
|
host: h,
|
|
ctx: ctx,
|
|
self: h.ID(),
|
|
incoming: make(chan *Conn),
|
|
relays: make(map[peer.ID]struct{}),
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
switch opt {
|
|
case OptActive:
|
|
r.active = true
|
|
case OptHop:
|
|
r.hop = true
|
|
default:
|
|
return nil, fmt.Errorf("unrecognized option: %d", opt)
|
|
}
|
|
}
|
|
|
|
h.SetStreamHandler(ProtoID, r.handleNewStream)
|
|
h.Network().Notify(r.Notifiee())
|
|
|
|
return r, nil
|
|
}
|
|
|
|
func (r *Relay) DialPeer(ctx context.Context, relay pstore.PeerInfo, dest pstore.PeerInfo) (*Conn, error) {
|
|
|
|
log.Debugf("dialing peer %s through relay %s", dest.ID, relay.ID)
|
|
|
|
if len(relay.Addrs) > 0 {
|
|
r.host.Peerstore().AddAddrs(relay.ID, relay.Addrs, pstore.TempAddrTTL)
|
|
}
|
|
|
|
s, err := r.host.NewStream(ctx, relay.ID, ProtoID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rd := newDelimitedReader(s, maxMessageSize)
|
|
wr := newDelimitedWriter(s)
|
|
|
|
var msg pb.CircuitRelay
|
|
|
|
msg.Type = pb.CircuitRelay_HOP.Enum()
|
|
msg.SrcPeer = peerInfoToPeer(r.host.Peerstore().PeerInfo(r.self))
|
|
msg.DstPeer = peerInfoToPeer(dest)
|
|
|
|
err = wr.WriteMsg(&msg)
|
|
if err != nil {
|
|
s.Reset()
|
|
return nil, err
|
|
}
|
|
|
|
msg.Reset()
|
|
|
|
err = rd.ReadMsg(&msg)
|
|
if err != nil {
|
|
s.Reset()
|
|
return nil, err
|
|
}
|
|
|
|
if msg.GetType() != pb.CircuitRelay_STATUS {
|
|
s.Reset()
|
|
return nil, fmt.Errorf("unexpected relay response; not a status message (%d)", msg.GetType())
|
|
}
|
|
|
|
if msg.GetCode() != pb.CircuitRelay_SUCCESS {
|
|
s.Reset()
|
|
return nil, RelayError{msg.GetCode()}
|
|
}
|
|
|
|
return &Conn{Stream: s, remote: dest}, nil
|
|
}
|
|
|
|
func (r *Relay) Matches(addr ma.Multiaddr) bool {
|
|
// TODO: Look at the prefix transport as well.
|
|
_, err := addr.ValueForProtocol(P_CIRCUIT)
|
|
return err == nil
|
|
}
|
|
|
|
func (r *Relay) CanHop(ctx context.Context, id peer.ID) (bool, error) {
|
|
s, err := r.host.NewStream(ctx, id, ProtoID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
rd := newDelimitedReader(s, maxMessageSize)
|
|
wr := newDelimitedWriter(s)
|
|
|
|
var msg pb.CircuitRelay
|
|
|
|
msg.Type = pb.CircuitRelay_CAN_HOP.Enum()
|
|
|
|
if err := wr.WriteMsg(&msg); err != nil {
|
|
s.Reset()
|
|
return false, err
|
|
}
|
|
|
|
msg.Reset()
|
|
|
|
if err := rd.ReadMsg(&msg); err != nil {
|
|
s.Reset()
|
|
return false, err
|
|
}
|
|
if err := inet.FullClose(s); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if msg.GetType() != pb.CircuitRelay_STATUS {
|
|
return false, fmt.Errorf("unexpected relay response; not a status message (%d)", msg.GetType())
|
|
}
|
|
|
|
return msg.GetCode() == pb.CircuitRelay_SUCCESS, nil
|
|
}
|
|
|
|
func (r *Relay) handleNewStream(s inet.Stream) {
|
|
log.Infof("new relay stream from: %s", s.Conn().RemotePeer())
|
|
|
|
rd := newDelimitedReader(s, maxMessageSize)
|
|
|
|
var msg pb.CircuitRelay
|
|
|
|
err := rd.ReadMsg(&msg)
|
|
if err != nil {
|
|
r.handleError(s, pb.CircuitRelay_MALFORMED_MESSAGE)
|
|
return
|
|
}
|
|
|
|
switch msg.GetType() {
|
|
case pb.CircuitRelay_HOP:
|
|
r.handleHopStream(s, &msg)
|
|
case pb.CircuitRelay_STOP:
|
|
r.handleStopStream(s, &msg)
|
|
case pb.CircuitRelay_CAN_HOP:
|
|
r.handleCanHop(s, &msg)
|
|
default:
|
|
log.Warningf("unexpected relay handshake: %d", msg.GetType())
|
|
r.handleError(s, pb.CircuitRelay_MALFORMED_MESSAGE)
|
|
}
|
|
}
|
|
|
|
func (r *Relay) handleHopStream(s inet.Stream, msg *pb.CircuitRelay) {
|
|
if !r.hop {
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_SPEAK_RELAY)
|
|
return
|
|
}
|
|
|
|
src, err := peerToPeerInfo(msg.GetSrcPeer())
|
|
if err != nil {
|
|
r.handleError(s, pb.CircuitRelay_HOP_SRC_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
if src.ID != s.Conn().RemotePeer() {
|
|
r.handleError(s, pb.CircuitRelay_HOP_SRC_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
dst, err := peerToPeerInfo(msg.GetDstPeer())
|
|
if err != nil {
|
|
r.handleError(s, pb.CircuitRelay_HOP_DST_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
if dst.ID == r.self {
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_RELAY_TO_SELF)
|
|
return
|
|
}
|
|
|
|
// open stream
|
|
ctp := r.host.Network().ConnsToPeer(dst.ID)
|
|
|
|
if len(ctp) == 0 && !r.active {
|
|
r.handleError(s, pb.CircuitRelay_HOP_NO_CONN_TO_DST)
|
|
return
|
|
}
|
|
|
|
if len(dst.Addrs) > 0 {
|
|
r.host.Peerstore().AddAddrs(dst.ID, dst.Addrs, pstore.TempAddrTTL)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(r.ctx, HopConnectTimeout)
|
|
defer cancel()
|
|
|
|
bs, err := r.host.NewStream(ctx, dst.ID, ProtoID)
|
|
if err != nil {
|
|
log.Debugf("error opening relay stream to %s: %s", dst.ID.Pretty(), err.Error())
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_DIAL_DST)
|
|
return
|
|
}
|
|
|
|
// stop handshake
|
|
rd := newDelimitedReader(bs, maxMessageSize)
|
|
wr := newDelimitedWriter(bs)
|
|
|
|
msg.Type = pb.CircuitRelay_STOP.Enum()
|
|
|
|
err = wr.WriteMsg(msg)
|
|
if err != nil {
|
|
log.Debugf("error writing stop handshake: %s", err.Error())
|
|
bs.Reset()
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_OPEN_DST_STREAM)
|
|
return
|
|
}
|
|
|
|
msg.Reset()
|
|
|
|
err = rd.ReadMsg(msg)
|
|
if err != nil {
|
|
log.Debugf("error reading stop response: %s", err.Error())
|
|
bs.Reset()
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_OPEN_DST_STREAM)
|
|
return
|
|
}
|
|
|
|
if msg.GetType() != pb.CircuitRelay_STATUS {
|
|
log.Debugf("unexpected relay stop response: not a status message (%d)", msg.GetType())
|
|
bs.Reset()
|
|
r.handleError(s, pb.CircuitRelay_HOP_CANT_OPEN_DST_STREAM)
|
|
return
|
|
}
|
|
|
|
if msg.GetCode() != pb.CircuitRelay_SUCCESS {
|
|
log.Debugf("relay stop failure: %d", msg.GetCode())
|
|
bs.Reset()
|
|
r.handleError(s, msg.GetCode())
|
|
return
|
|
}
|
|
|
|
err = r.writeResponse(s, pb.CircuitRelay_SUCCESS)
|
|
if err != nil {
|
|
log.Debugf("error writing relay response: %s", err.Error())
|
|
bs.Reset()
|
|
s.Reset()
|
|
return
|
|
}
|
|
|
|
// relay connection
|
|
log.Infof("relaying connection between %s and %s", src.ID.Pretty(), dst.ID.Pretty())
|
|
|
|
// Don't reset streams after finishing or the other side will get an
|
|
// error, not an EOF.
|
|
go func() {
|
|
count, err := io.Copy(s, bs)
|
|
if err != nil {
|
|
log.Debugf("relay copy error: %s", err)
|
|
// Reset both.
|
|
s.Reset()
|
|
bs.Reset()
|
|
} else {
|
|
// propagate the close
|
|
s.Close()
|
|
}
|
|
log.Debugf("relayed %d bytes from %s to %s", count, dst.ID.Pretty(), src.ID.Pretty())
|
|
}()
|
|
|
|
go func() {
|
|
count, err := io.Copy(bs, s)
|
|
if err != nil {
|
|
log.Debugf("relay copy error: %s", err)
|
|
// Reset both.
|
|
bs.Reset()
|
|
s.Reset()
|
|
} else {
|
|
// propagate the close
|
|
bs.Close()
|
|
}
|
|
log.Debugf("relayed %d bytes from %s to %s", count, src.ID.Pretty(), dst.ID.Pretty())
|
|
}()
|
|
}
|
|
|
|
func (r *Relay) handleStopStream(s inet.Stream, msg *pb.CircuitRelay) {
|
|
src, err := peerToPeerInfo(msg.GetSrcPeer())
|
|
if err != nil {
|
|
r.handleError(s, pb.CircuitRelay_STOP_SRC_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
dst, err := peerToPeerInfo(msg.GetDstPeer())
|
|
if err != nil || dst.ID != r.self {
|
|
r.handleError(s, pb.CircuitRelay_STOP_DST_MULTIADDR_INVALID)
|
|
return
|
|
}
|
|
|
|
log.Infof("relay connection from: %s", src.ID)
|
|
|
|
if len(src.Addrs) > 0 {
|
|
r.host.Peerstore().AddAddrs(src.ID, src.Addrs, pstore.TempAddrTTL)
|
|
}
|
|
|
|
select {
|
|
case r.incoming <- &Conn{Stream: s, remote: src}:
|
|
case <-time.After(RelayAcceptTimeout):
|
|
r.handleError(s, pb.CircuitRelay_STOP_RELAY_REFUSED)
|
|
}
|
|
}
|
|
|
|
func (r *Relay) handleCanHop(s inet.Stream, msg *pb.CircuitRelay) {
|
|
var err error
|
|
|
|
if r.hop {
|
|
err = r.writeResponse(s, pb.CircuitRelay_SUCCESS)
|
|
} else {
|
|
err = r.writeResponse(s, pb.CircuitRelay_HOP_CANT_SPEAK_RELAY)
|
|
}
|
|
|
|
if err != nil {
|
|
s.Reset()
|
|
log.Debugf("error writing relay response: %s", err.Error())
|
|
} else {
|
|
inet.FullClose(s)
|
|
}
|
|
}
|
|
|
|
func (r *Relay) handleError(s inet.Stream, code pb.CircuitRelay_Status) {
|
|
log.Warningf("relay error: %s (%d)", pb.CircuitRelay_Status_name[int32(code)], code)
|
|
err := r.writeResponse(s, code)
|
|
if err != nil {
|
|
s.Reset()
|
|
log.Debugf("error writing relay response: %s", err.Error())
|
|
} else {
|
|
inet.FullClose(s)
|
|
}
|
|
}
|
|
|
|
func (r *Relay) writeResponse(s inet.Stream, code pb.CircuitRelay_Status) error {
|
|
wr := newDelimitedWriter(s)
|
|
|
|
var msg pb.CircuitRelay
|
|
msg.Type = pb.CircuitRelay_STATUS.Enum()
|
|
msg.Code = code.Enum()
|
|
|
|
return wr.WriteMsg(&msg)
|
|
}
|