package relay import ( "context" "fmt" "io" "sync" "sync/atomic" "time" pb "github.com/libp2p/go-libp2p-circuit/pb" "github.com/libp2p/go-libp2p-core/helpers" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" pool "github.com/libp2p/go-buffer-pool" tptu "github.com/libp2p/go-libp2p-transport-upgrader" logging "github.com/ipfs/go-log" ma "github.com/multiformats/go-multiaddr" ) var log = logging.Logger("relay") const ProtoID = "/libp2p/circuit/relay/0.1.0" const maxMessageSize = 4096 var ( RelayAcceptTimeout = 10 * time.Second HopConnectTimeout = 30 * time.Second StopHandshakeTimeout = 1 * time.Minute HopStreamBufferSize = 4096 HopStreamLimit = 1 << 19 // 512K hops for 1M goroutines ) // Relay is the relay transport and service. type Relay struct { host host.Host upgrader *tptu.Upgrader ctx context.Context self peer.ID active bool hop bool discovery bool incoming chan *Conn relays map[peer.ID]struct{} mx sync.Mutex // atomic counters streamCount int32 liveHopCount int32 } // RelayOpts are options for configuring the relay transport. type RelayOpt int var ( // OptActive configures the relay transport to actively establish // outbound connections on behalf of clients. You probably don't want to // enable this unless you know what you're doing. OptActive = RelayOpt(0) // OptHop configures the relay transport to accept requests to relay // traffic on behalf of third-parties. Unless OptActive is specified, // this will only relay traffic between peers already connected to this // node. OptHop = RelayOpt(1) // OptDiscovery configures this relay transport to discover new relays // by probing every new peer. You almost _certainly_ don't want to // enable this. OptDiscovery = RelayOpt(2) ) type RelayError struct { Code pb.CircuitRelay_Status } func (e RelayError) Error() string { return fmt.Sprintf("error opening relay circuit: %s (%d)", pb.CircuitRelay_Status_name[int32(e.Code)], e.Code) } // NewRelay constructs a new relay. func NewRelay(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) (*Relay, error) { r := &Relay{ upgrader: upgrader, host: h, ctx: ctx, self: h.ID(), incoming: make(chan *Conn), relays: make(map[peer.ID]struct{}), } for _, opt := range opts { switch opt { case OptActive: r.active = true case OptHop: r.hop = true case OptDiscovery: r.discovery = true default: return nil, fmt.Errorf("unrecognized option: %d", opt) } } h.SetStreamHandler(ProtoID, r.handleNewStream) if r.discovery { h.Network().Notify(r.notifiee()) } return r, nil } // Increment the live hop count and increment the connection manager tags by 1 for the two // sides of the hop stream. This ensures that connections with many hop streams will be protected // from pruning, thus minimizing disruption from connection trimming in a relay node. func (r *Relay) addLiveHop(from, to peer.ID) { atomic.AddInt32(&r.liveHopCount, 1) r.host.ConnManager().UpsertTag(from, "relay-hop-stream", incrementTag) r.host.ConnManager().UpsertTag(to, "relay-hop-stream", incrementTag) } // Decrement the live hpo count and decrement the connection manager tags for the two sides // of the hop stream. func (r *Relay) rmLiveHop(from, to peer.ID) { atomic.AddInt32(&r.liveHopCount, -1) r.host.ConnManager().UpsertTag(from, "relay-hop-stream", decrementTag) r.host.ConnManager().UpsertTag(to, "relay-hop-stream", decrementTag) } func (r *Relay) GetActiveHops() int32 { return atomic.LoadInt32(&r.liveHopCount) } func (r *Relay) DialPeer(ctx context.Context, relay peer.AddrInfo, dest peer.AddrInfo) (*Conn, error) { log.Debugf("dialing peer %s through relay %s", dest.ID, relay.ID) if len(relay.Addrs) > 0 { r.host.Peerstore().AddAddrs(relay.ID, relay.Addrs, peerstore.TempAddrTTL) } s, err := r.host.NewStream(ctx, relay.ID, ProtoID) if err != nil { return nil, err } rd := newDelimitedReader(s, maxMessageSize) wr := newDelimitedWriter(s) defer rd.Close() var msg pb.CircuitRelay msg.Type = pb.CircuitRelay_HOP.Enum() msg.SrcPeer = peerInfoToPeer(r.host.Peerstore().PeerInfo(r.self)) msg.DstPeer = peerInfoToPeer(dest) err = wr.WriteMsg(&msg) if err != nil { s.Reset() return nil, err } msg.Reset() err = rd.ReadMsg(&msg) if err != nil { s.Reset() return nil, err } if msg.GetType() != pb.CircuitRelay_STATUS { s.Reset() return nil, fmt.Errorf("unexpected relay response; not a status message (%d)", msg.GetType()) } if msg.GetCode() != pb.CircuitRelay_SUCCESS { s.Reset() return nil, RelayError{msg.GetCode()} } return &Conn{stream: s, remote: dest, host: r.host}, nil } func (r *Relay) Matches(addr ma.Multiaddr) bool { // TODO: Look at the prefix transport as well. _, err := addr.ValueForProtocol(P_CIRCUIT) return err == nil } func (r *Relay) CanHop(ctx context.Context, id peer.ID) (bool, error) { s, err := r.host.NewStream(ctx, id, ProtoID) if err != nil { return false, err } rd := newDelimitedReader(s, maxMessageSize) wr := newDelimitedWriter(s) defer rd.Close() var msg pb.CircuitRelay msg.Type = pb.CircuitRelay_CAN_HOP.Enum() if err := wr.WriteMsg(&msg); err != nil { s.Reset() return false, err } msg.Reset() if err := rd.ReadMsg(&msg); err != nil { s.Reset() return false, err } if err := helpers.FullClose(s); err != nil { return false, err } if msg.GetType() != pb.CircuitRelay_STATUS { return false, fmt.Errorf("unexpected relay response; not a status message (%d)", msg.GetType()) } return msg.GetCode() == pb.CircuitRelay_SUCCESS, nil } func (r *Relay) handleNewStream(s network.Stream) { log.Infof("new relay stream from: %s", s.Conn().RemotePeer()) rd := newDelimitedReader(s, maxMessageSize) defer rd.Close() var msg pb.CircuitRelay err := rd.ReadMsg(&msg) if err != nil { r.handleError(s, pb.CircuitRelay_MALFORMED_MESSAGE) return } switch msg.GetType() { case pb.CircuitRelay_HOP: r.handleHopStream(s, &msg) case pb.CircuitRelay_STOP: r.handleStopStream(s, &msg) case pb.CircuitRelay_CAN_HOP: r.handleCanHop(s, &msg) default: log.Warningf("unexpected relay handshake: %d", msg.GetType()) r.handleError(s, pb.CircuitRelay_MALFORMED_MESSAGE) } } func (r *Relay) handleHopStream(s network.Stream, msg *pb.CircuitRelay) { if !r.hop { r.handleError(s, pb.CircuitRelay_HOP_CANT_SPEAK_RELAY) return } streamCount := atomic.AddInt32(&r.streamCount, 1) liveHopCount := atomic.LoadInt32(&r.liveHopCount) defer atomic.AddInt32(&r.streamCount, -1) if (streamCount + liveHopCount) > int32(HopStreamLimit) { log.Warning("hop stream limit exceeded; resetting stream") s.Reset() return } src, err := peerToPeerInfo(msg.GetSrcPeer()) if err != nil { r.handleError(s, pb.CircuitRelay_HOP_SRC_MULTIADDR_INVALID) return } if src.ID != s.Conn().RemotePeer() { r.handleError(s, pb.CircuitRelay_HOP_SRC_MULTIADDR_INVALID) return } dst, err := peerToPeerInfo(msg.GetDstPeer()) if err != nil { r.handleError(s, pb.CircuitRelay_HOP_DST_MULTIADDR_INVALID) return } if dst.ID == r.self { r.handleError(s, pb.CircuitRelay_HOP_CANT_RELAY_TO_SELF) return } // open stream ctx, cancel := context.WithTimeout(r.ctx, HopConnectTimeout) defer cancel() if !r.active { ctx = network.WithNoDial(ctx, "relay hop") } else if len(dst.Addrs) > 0 { r.host.Peerstore().AddAddrs(dst.ID, dst.Addrs, peerstore.TempAddrTTL) } bs, err := r.host.NewStream(ctx, dst.ID, ProtoID) if err != nil { log.Debugf("error opening relay stream to %s: %s", dst.ID.Pretty(), err.Error()) if err == network.ErrNoConn { r.handleError(s, pb.CircuitRelay_HOP_NO_CONN_TO_DST) } else { r.handleError(s, pb.CircuitRelay_HOP_CANT_DIAL_DST) } return } // stop handshake rd := newDelimitedReader(bs, maxMessageSize) wr := newDelimitedWriter(bs) defer rd.Close() // set handshake deadline bs.SetDeadline(time.Now().Add(StopHandshakeTimeout)) msg.Type = pb.CircuitRelay_STOP.Enum() err = wr.WriteMsg(msg) if err != nil { log.Debugf("error writing stop handshake: %s", err.Error()) bs.Reset() r.handleError(s, pb.CircuitRelay_HOP_CANT_OPEN_DST_STREAM) return } msg.Reset() err = rd.ReadMsg(msg) if err != nil { log.Debugf("error reading stop response: %s", err.Error()) bs.Reset() r.handleError(s, pb.CircuitRelay_HOP_CANT_OPEN_DST_STREAM) return } if msg.GetType() != pb.CircuitRelay_STATUS { log.Debugf("unexpected relay stop response: not a status message (%d)", msg.GetType()) bs.Reset() r.handleError(s, pb.CircuitRelay_HOP_CANT_OPEN_DST_STREAM) return } if msg.GetCode() != pb.CircuitRelay_SUCCESS { log.Debugf("relay stop failure: %d", msg.GetCode()) bs.Reset() r.handleError(s, msg.GetCode()) return } err = r.writeResponse(s, pb.CircuitRelay_SUCCESS) if err != nil { log.Debugf("error writing relay response: %s", err.Error()) bs.Reset() s.Reset() return } // relay connection log.Infof("relaying connection between %s and %s", src.ID.Pretty(), dst.ID.Pretty()) // reset deadline bs.SetDeadline(time.Time{}) r.addLiveHop(src.ID, dst.ID) goroutines := new(int32) *goroutines = 2 done := func() { if atomic.AddInt32(goroutines, -1) == 0 { r.rmLiveHop(src.ID, dst.ID) } } // Don't reset streams after finishing or the other side will get an // error, not an EOF. go func() { defer done() buf := pool.Get(HopStreamBufferSize) defer pool.Put(buf) count, err := io.CopyBuffer(s, bs, buf) if err != nil { log.Debugf("relay copy error: %s", err) // Reset both. s.Reset() bs.Reset() } else { // propagate the close s.Close() } log.Debugf("relayed %d bytes from %s to %s", count, dst.ID.Pretty(), src.ID.Pretty()) }() go func() { defer done() buf := pool.Get(HopStreamBufferSize) defer pool.Put(buf) count, err := io.CopyBuffer(bs, s, buf) if err != nil { log.Debugf("relay copy error: %s", err) // Reset both. bs.Reset() s.Reset() } else { // propagate the close bs.Close() } log.Debugf("relayed %d bytes from %s to %s", count, src.ID.Pretty(), dst.ID.Pretty()) }() } func (r *Relay) handleStopStream(s network.Stream, msg *pb.CircuitRelay) { src, err := peerToPeerInfo(msg.GetSrcPeer()) if err != nil { r.handleError(s, pb.CircuitRelay_STOP_SRC_MULTIADDR_INVALID) return } dst, err := peerToPeerInfo(msg.GetDstPeer()) if err != nil || dst.ID != r.self { r.handleError(s, pb.CircuitRelay_STOP_DST_MULTIADDR_INVALID) return } log.Infof("relay connection from: %s", src.ID) if len(src.Addrs) > 0 { r.host.Peerstore().AddAddrs(src.ID, src.Addrs, peerstore.TempAddrTTL) } select { case r.incoming <- &Conn{stream: s, remote: src, host: r.host}: case <-time.After(RelayAcceptTimeout): r.handleError(s, pb.CircuitRelay_STOP_RELAY_REFUSED) } } func (r *Relay) handleCanHop(s network.Stream, msg *pb.CircuitRelay) { var err error if r.hop { err = r.writeResponse(s, pb.CircuitRelay_SUCCESS) } else { err = r.writeResponse(s, pb.CircuitRelay_HOP_CANT_SPEAK_RELAY) } if err != nil { s.Reset() log.Debugf("error writing relay response: %s", err.Error()) } else { helpers.FullClose(s) } } func (r *Relay) handleError(s network.Stream, code pb.CircuitRelay_Status) { log.Warningf("relay error: %s (%d)", pb.CircuitRelay_Status_name[int32(code)], code) err := r.writeResponse(s, code) if err != nil { s.Reset() log.Debugf("error writing relay response: %s", err.Error()) } else { helpers.FullClose(s) } } func (r *Relay) writeResponse(s network.Stream, code pb.CircuitRelay_Status) error { wr := newDelimitedWriter(s) var msg pb.CircuitRelay msg.Type = pb.CircuitRelay_STATUS.Enum() msg.Code = code.Enum() return wr.WriteMsg(&msg) }