diff --git a/p2p/protocol/holepunch/coordination.go b/p2p/protocol/holepunch/coordination.go deleted file mode 100644 index c3964485..00000000 --- a/p2p/protocol/holepunch/coordination.go +++ /dev/null @@ -1,511 +0,0 @@ -package holepunch - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "github.com/libp2p/go-libp2p-core/host" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" - pb "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb" - "github.com/libp2p/go-libp2p/p2p/protocol/identify" - - logging "github.com/ipfs/go-log/v2" - "github.com/libp2p/go-msgio/protoio" - ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" -) - -// Protocol is the libp2p protocol for Hole Punching. -const Protocol protocol.ID = "/libp2p/dcutr" - -// StreamTimeout is the timeout for the hole punch protocol stream. -var StreamTimeout = 1 * time.Minute - -// TODO Should we have options for these ? -const ( - ServiceName = "libp2p.holepunch" - - maxMsgSize = 4 * 1024 // 4K - dialTimeout = 5 * time.Second - maxRetries = 3 - retryWait = 2 * time.Second -) - -var ( - log = logging.Logger("p2p-holepunch") - // ErrHolePunchActive is returned from DirectConnect when another hole punching attempt is currently running - ErrHolePunchActive = errors.New("another hole punching attempt to this peer is active") - // ErrClosed is returned when the hole punching is closed - ErrClosed = errors.New("hole punching service closing") -) - -// The Service is used to make direct connections with a peer via hole-punching. -type Service struct { - ctx context.Context - ctxCancel context.CancelFunc - - ids identify.IDService - host host.Host - - tracer *tracer - - closeMx sync.RWMutex - closed bool - refCount sync.WaitGroup - - hasPublicAddrsChan chan struct{} // this chan is closed as soon as we have a public address - - // active hole punches for deduplicating - activeMx sync.Mutex - active map[peer.ID]struct{} -} - -type Option func(*Service) error - -// NewService creates a new service that can be used for hole punching -func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, error) { - if ids == nil { - return nil, errors.New("identify service can't be nil") - } - - ctx, cancel := context.WithCancel(context.Background()) - hs := &Service{ - ctx: ctx, - ctxCancel: cancel, - host: h, - ids: ids, - active: make(map[peer.ID]struct{}), - hasPublicAddrsChan: make(chan struct{}), - } - - for _, opt := range opts { - if err := opt(hs); err != nil { - cancel() - return nil, err - } - } - - hs.refCount.Add(1) - go hs.watchForPublicAddr() - - h.Network().Notify((*netNotifiee)(hs)) - return hs, nil -} - -func (hs *Service) watchForPublicAddr() { - defer hs.refCount.Done() - - log.Debug("waiting until we have at least one public address", "peer", hs.host.ID()) - - // TODO: We should have an event here that fires when identify discovers a new - // address (and when autonat confirms that address). - // As we currently don't have an event like this, just check our observed addresses - // regularly (exponential backoff starting at 250 ms, capped at 5s). - duration := 250 * time.Millisecond - const maxDuration = 5 * time.Second - t := time.NewTimer(duration) - defer t.Stop() - for { - if containsPublicAddr(hs.ids.OwnObservedAddrs()) { - log.Debug("Host now has a public address. Starting holepunch protocol.") - hs.host.SetStreamHandler(Protocol, hs.handleNewStream) - close(hs.hasPublicAddrsChan) - return - } - - select { - case <-hs.ctx.Done(): - return - case <-t.C: - duration *= 2 - if duration > maxDuration { - duration = maxDuration - } - t.Reset(duration) - } - } -} - -// Close closes the Hole Punch Service. -func (hs *Service) Close() error { - hs.closeMx.Lock() - hs.closed = true - hs.closeMx.Unlock() - hs.tracer.Close() - hs.host.RemoveStreamHandler(Protocol) - hs.ctxCancel() - hs.refCount.Wait() - return nil -} - -// initiateHolePunch opens a new hole punching coordination stream, -// exchanges the addresses and measures the RTT. -func (hs *Service) initiateHolePunch(rp peer.ID) ([]ma.Multiaddr, time.Duration, error) { - hpCtx := network.WithUseTransient(hs.ctx, "hole-punch") - sCtx := network.WithNoDial(hpCtx, "hole-punch") - - str, err := hs.host.NewStream(sCtx, rp, Protocol) - if err != nil { - return nil, 0, fmt.Errorf("failed to open hole-punching stream: %w", err) - } - defer str.Close() - - if err := str.Scope().SetService(ServiceName); err != nil { - log.Debugf("error attaching stream to holepunch service: %s", err) - str.Reset() - return nil, 0, err - } - - if err := str.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { - log.Debugf("error reserving memory for stream: %s, err") - str.Reset() - return nil, 0, err - } - defer str.Scope().ReleaseMemory(maxMsgSize) - - w := protoio.NewDelimitedWriter(str) - rd := protoio.NewDelimitedReader(str, maxMsgSize) - - str.SetDeadline(time.Now().Add(StreamTimeout)) - // send a CONNECT and start RTT measurement. - msg := &pb.HolePunch{ - Type: pb.HolePunch_CONNECT.Enum(), - ObsAddrs: addrsToBytes(removeRelayAddrs(hs.ids.OwnObservedAddrs())), - } - - start := time.Now() - if err := w.WriteMsg(msg); err != nil { - str.Reset() - return nil, 0, err - } - - // wait for a CONNECT message from the remote peer - msg.Reset() - if err := rd.ReadMsg(msg); err != nil { - str.Reset() - return nil, 0, fmt.Errorf("failed to read CONNECT message from remote peer: %w", err) - } - rtt := time.Since(start) - - if t := msg.GetType(); t != pb.HolePunch_CONNECT { - str.Reset() - return nil, 0, fmt.Errorf("expect CONNECT message, got %s", t) - } - addrs := removeRelayAddrs(addrsFromBytes(msg.ObsAddrs)) - if len(addrs) == 0 { - str.Reset() - } - - msg.Reset() - msg.Type = pb.HolePunch_SYNC.Enum() - if err := w.WriteMsg(msg); err != nil { - str.Reset() - return nil, 0, fmt.Errorf("failed to send SYNC message for hole punching: %w", err) - } - return addrs, rtt, nil -} - -func (hs *Service) beginDirectConnect(p peer.ID) error { - hs.closeMx.RLock() - defer hs.closeMx.RUnlock() - if hs.closed { - return ErrClosed - } - - hs.activeMx.Lock() - defer hs.activeMx.Unlock() - if _, ok := hs.active[p]; ok { - return ErrHolePunchActive - } - - hs.active[p] = struct{}{} - return nil -} - -// DirectConnect attempts to make a direct connection with a remote peer. -// It first attempts a direct dial (if we have a public address of that peer), and then -// coordinates a hole punch over the given relay connection. -func (hs *Service) DirectConnect(p peer.ID) error { - if err := hs.beginDirectConnect(p); err != nil { - return err - } - - defer func() { - hs.activeMx.Lock() - delete(hs.active, p) - hs.activeMx.Unlock() - }() - - return hs.directConnect(p) -} - -func (hs *Service) directConnect(rp peer.ID) error { - // short-circuit check to see if we already have a direct connection - for _, c := range hs.host.Network().ConnsToPeer(rp) { - if !isRelayAddress(c.RemoteMultiaddr()) { - return nil - } - } - - // short-circuit hole punching if a direct dial works. - // attempt a direct connection ONLY if we have a public address for the remote peer - for _, a := range hs.host.Peerstore().Addrs(rp) { - if manet.IsPublicAddr(a) && !isRelayAddress(a) { - forceDirectConnCtx := network.WithForceDirectDial(hs.ctx, "hole-punching") - dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) - - tstart := time.Now() - // This dials *all* public addresses from the peerstore. - err := hs.host.Connect(dialCtx, peer.AddrInfo{ID: rp}) - dt := time.Since(tstart) - cancel() - - if err != nil { - hs.tracer.DirectDialFailed(rp, dt, err) - break - } - hs.tracer.DirectDialSuccessful(rp, dt) - log.Debugw("direct connection to peer successful, no need for a hole punch", "peer", rp) - return nil - } - } - - log.Debugw("got inbound proxy conn", "peer", rp) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - select { - case <-hs.ctx.Done(): - return hs.ctx.Err() - case <-ctx.Done(): - log.Debug("didn't find any public host address") - return errors.New("can't initiate hole punch, as we don't have any public addresses") - case <-hs.hasPublicAddrsChan: - } - - // hole punch - for i := 0; i < maxRetries; i++ { - addrs, rtt, err := hs.initiateHolePunch(rp) - if err != nil { - log.Debugw("hole punching failed", "peer", rp, "error", err) - hs.tracer.ProtocolError(rp, err) - return err - } - synTime := rtt / 2 - log.Debugf("peer RTT is %s; starting hole punch in %s", rtt, synTime) - - // wait for sync to reach the other peer and then punch a hole for it in our NAT - // by attempting a connect to it. - timer := time.NewTimer(synTime) - select { - case start := <-timer.C: - pi := peer.AddrInfo{ - ID: rp, - Addrs: addrs, - } - hs.tracer.StartHolePunch(rp, addrs, rtt) - err := hs.holePunchConnect(pi, true) - dt := time.Since(start) - hs.tracer.EndHolePunch(rp, dt, err) - if err == nil { - log.Debugw("hole punching with successful", "peer", rp, "time", dt) - return nil - } - case <-hs.ctx.Done(): - timer.Stop() - return hs.ctx.Err() - } - } - return fmt.Errorf("all retries for hole punch with peer %s failed", rp) -} - -func (hs *Service) incomingHolePunch(s network.Stream) (rtt time.Duration, addrs []ma.Multiaddr, err error) { - // sanity check: a hole punch request should only come from peers behind a relay - if !isRelayAddress(s.Conn().RemoteMultiaddr()) { - return 0, nil, fmt.Errorf("received hole punch stream: %s", s.Conn().RemoteMultiaddr()) - } - ownAddrs := removeRelayAddrs(hs.ids.OwnObservedAddrs()) - // If we can't tell the peer where to dial us, there's no point in starting the hole punching. - if len(ownAddrs) == 0 { - return 0, nil, errors.New("rejecting hole punch request, as we don't have any public addresses") - } - - if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { - log.Debugf("error reserving memory for stream: %s, err") - return 0, nil, err - } - defer s.Scope().ReleaseMemory(maxMsgSize) - - wr := protoio.NewDelimitedWriter(s) - rd := protoio.NewDelimitedReader(s, maxMsgSize) - - // Read Connect message - msg := new(pb.HolePunch) - - s.SetDeadline(time.Now().Add(StreamTimeout)) - - if err := rd.ReadMsg(msg); err != nil { - return 0, nil, fmt.Errorf("failed to read message from initator: %w", err) - } - if t := msg.GetType(); t != pb.HolePunch_CONNECT { - return 0, nil, fmt.Errorf("expected CONNECT message from initiator but got %d", t) - } - obsDial := removeRelayAddrs(addrsFromBytes(msg.ObsAddrs)) - log.Debugw("received hole punch request", "peer", s.Conn().RemotePeer(), "addrs", obsDial) - if len(obsDial) == 0 { - return 0, nil, errors.New("expected CONNECT message to contain at least one address") - } - - // Write CONNECT message - msg.Reset() - msg.Type = pb.HolePunch_CONNECT.Enum() - msg.ObsAddrs = addrsToBytes(ownAddrs) - tstart := time.Now() - if err := wr.WriteMsg(msg); err != nil { - return 0, nil, fmt.Errorf("failed to write CONNECT message to initator: %w", err) - } - - // Read SYNC message - msg.Reset() - if err := rd.ReadMsg(msg); err != nil { - return 0, nil, fmt.Errorf("failed to read message from initator: %w", err) - } - if t := msg.GetType(); t != pb.HolePunch_SYNC { - return 0, nil, fmt.Errorf("expected SYNC message from initiator but got %d", t) - } - return time.Since(tstart), obsDial, nil -} - -func (hs *Service) handleNewStream(s network.Stream) { - // Check directionality of the underlying connection. - // Peer A receives an inbound connection from peer B. - // Peer A opens a new hole punch stream to peer B. - // Peer B receives this stream, calling this function. - // Peer B sees the underlying connection as an outbound connection. - if s.Conn().Stat().Direction == network.DirInbound { - s.Reset() - return - } - - if err := s.Scope().SetService(ServiceName); err != nil { - log.Debugf("error attaching stream to holepunch service: %s", err) - s.Reset() - return - } - - rp := s.Conn().RemotePeer() - rtt, addrs, err := hs.incomingHolePunch(s) - if err != nil { - hs.tracer.ProtocolError(rp, err) - log.Debugw("error handling holepunching stream from", rp, "error", err) - s.Reset() - return - } - s.Close() - - // Hole punch now by forcing a connect - pi := peer.AddrInfo{ - ID: rp, - Addrs: addrs, - } - hs.tracer.StartHolePunch(rp, addrs, rtt) - log.Debugw("starting hole punch", "peer", rp) - start := time.Now() - err = hs.holePunchConnect(pi, false) - dt := time.Since(start) - hs.tracer.EndHolePunch(rp, dt, err) -} - -func (hs *Service) holePunchConnect(pi peer.AddrInfo, isClient bool) error { - holePunchCtx := network.WithSimultaneousConnect(hs.ctx, isClient, "hole-punching") - forceDirectConnCtx := network.WithForceDirectDial(holePunchCtx, "hole-punching") - dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) - defer cancel() - - hs.tracer.HolePunchAttempt(pi.ID) - if err := hs.host.Connect(dialCtx, pi); err != nil { - log.Debugw("hole punch attempt with peer failed", "peer ID", pi.ID, "error", err) - return err - } - log.Debugw("hole punch successful", "peer", pi.ID) - return nil -} - -func containsPublicAddr(addrs []ma.Multiaddr) bool { - for _, addr := range addrs { - if isRelayAddress(addr) || !manet.IsPublicAddr(addr) { - continue - } - return true - } - return false -} - -func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - result := make([]ma.Multiaddr, 0, len(addrs)) - for _, addr := range addrs { - if !isRelayAddress(addr) { - result = append(result, addr) - } - } - return result -} - -func isRelayAddress(a ma.Multiaddr) bool { - _, err := a.ValueForProtocol(ma.P_CIRCUIT) - return err == nil -} - -func addrsToBytes(as []ma.Multiaddr) [][]byte { - bzs := make([][]byte, 0, len(as)) - for _, a := range as { - bzs = append(bzs, a.Bytes()) - } - return bzs -} - -func addrsFromBytes(bzs [][]byte) []ma.Multiaddr { - addrs := make([]ma.Multiaddr, 0, len(bzs)) - for _, bz := range bzs { - a, err := ma.NewMultiaddrBytes(bz) - if err == nil { - addrs = append(addrs, a) - } - } - return addrs -} - -type netNotifiee Service - -func (nn *netNotifiee) Connected(_ network.Network, conn network.Conn) { - hs := (*Service)(nn) - - // Hole punch if it's an inbound proxy connection. - // If we already have a direct connection with the remote peer, this will be a no-op. - if conn.Stat().Direction == network.DirInbound && isRelayAddress(conn.RemoteMultiaddr()) { - hs.refCount.Add(1) - go func() { - defer hs.refCount.Done() - - select { - // waiting for Identify here will allow us to access the peer's public and observed addresses - // that we can dial to for a hole punch. - case <-hs.ids.IdentifyWait(conn): - case <-hs.ctx.Done(): - return - } - - _ = hs.DirectConnect(conn.RemotePeer()) - }() - } -} - -func (nn *netNotifiee) Disconnected(_ network.Network, v network.Conn) {} -func (nn *netNotifiee) OpenedStream(n network.Network, v network.Stream) {} -func (nn *netNotifiee) ClosedStream(n network.Network, v network.Stream) {} -func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) {} -func (nn *netNotifiee) ListenClose(n network.Network, a ma.Multiaddr) {} diff --git a/p2p/protocol/holepunch/coordination_test.go b/p2p/protocol/holepunch/holepunch_test.go similarity index 100% rename from p2p/protocol/holepunch/coordination_test.go rename to p2p/protocol/holepunch/holepunch_test.go diff --git a/p2p/protocol/holepunch/holepuncher.go b/p2p/protocol/holepunch/holepuncher.go new file mode 100644 index 00000000..6309d6eb --- /dev/null +++ b/p2p/protocol/holepunch/holepuncher.go @@ -0,0 +1,289 @@ +package holepunch + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + pb "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb" + "github.com/libp2p/go-libp2p/p2p/protocol/identify" + + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + + "github.com/libp2p/go-msgio/protoio" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +// ErrHolePunchActive is returned from DirectConnect when another hole punching attempt is currently running +var ErrHolePunchActive = errors.New("another hole punching attempt to this peer is active") + +const ( + dialTimeout = 5 * time.Second + maxRetries = 3 +) + +// The holePuncher is run on the peer that's behind a NAT / Firewall. +// It observes new incoming connections via a relay that it has a reservation with, +// and initiates the DCUtR protocol with them. +// It then first tries to establish a direct connection, and if that fails, it +// initiates a hole punch. +type holePuncher struct { + ctx context.Context + ctxCancel context.CancelFunc + + host host.Host + refCount sync.WaitGroup + + ids identify.IDService + + hasPublicAddrsChan <-chan struct{} + + // active hole punches for deduplicating + activeMx sync.Mutex + active map[peer.ID]struct{} + + closeMx sync.RWMutex + closed bool + + tracer *tracer +} + +func newHolePuncher(h host.Host, ids identify.IDService, hasPublicAddrsChan <-chan struct{}, tracer *tracer) *holePuncher { + hp := &holePuncher{ + host: h, + ids: ids, + active: make(map[peer.ID]struct{}), + hasPublicAddrsChan: hasPublicAddrsChan, + tracer: tracer, + } + hp.ctx, hp.ctxCancel = context.WithCancel(context.Background()) + h.Network().Notify((*netNotifiee)(hp)) + return hp +} + +func (hp *holePuncher) beginDirectConnect(p peer.ID) error { + hp.closeMx.RLock() + defer hp.closeMx.RUnlock() + if hp.closed { + return ErrClosed + } + + hp.activeMx.Lock() + defer hp.activeMx.Unlock() + if _, ok := hp.active[p]; ok { + return ErrHolePunchActive + } + + hp.active[p] = struct{}{} + return nil +} + +// DirectConnect attempts to make a direct connection with a remote peer. +// It first attempts a direct dial (if we have a public address of that peer), and then +// coordinates a hole punch over the given relay connection. +func (hp *holePuncher) DirectConnect(p peer.ID) error { + if err := hp.beginDirectConnect(p); err != nil { + return err + } + + defer func() { + hp.activeMx.Lock() + delete(hp.active, p) + hp.activeMx.Unlock() + }() + + return hp.directConnect(p) +} + +func (hp *holePuncher) directConnect(rp peer.ID) error { + // short-circuit check to see if we already have a direct connection + for _, c := range hp.host.Network().ConnsToPeer(rp) { + if !isRelayAddress(c.RemoteMultiaddr()) { + return nil + } + } + + // short-circuit hole punching if a direct dial works. + // attempt a direct connection ONLY if we have a public address for the remote peer + for _, a := range hp.host.Peerstore().Addrs(rp) { + if manet.IsPublicAddr(a) && !isRelayAddress(a) { + forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching") + dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) + + tstart := time.Now() + // This dials *all* public addresses from the peerstore. + err := hp.host.Connect(dialCtx, peer.AddrInfo{ID: rp}) + dt := time.Since(tstart) + cancel() + + if err != nil { + hp.tracer.DirectDialFailed(rp, dt, err) + break + } + hp.tracer.DirectDialSuccessful(rp, dt) + log.Debugw("direct connection to peer successful, no need for a hole punch", "peer", rp) + return nil + } + } + + log.Debugw("got inbound proxy conn", "peer", rp) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + select { + case <-hp.ctx.Done(): + return hp.ctx.Err() + case <-ctx.Done(): + log.Debug("didn't find any public host address") + return errors.New("can't initiate hole punch, as we don't have any public addresses") + case <-hp.hasPublicAddrsChan: + } + + // hole punch + for i := 0; i < maxRetries; i++ { + addrs, rtt, err := hp.initiateHolePunch(rp) + if err != nil { + log.Debugw("hole punching failed", "peer", rp, "error", err) + hp.tracer.ProtocolError(rp, err) + return err + } + synTime := rtt / 2 + log.Debugf("peer RTT is %s; starting hole punch in %s", rtt, synTime) + + // wait for sync to reach the other peer and then punch a hole for it in our NAT + // by attempting a connect to it. + timer := time.NewTimer(synTime) + select { + case start := <-timer.C: + pi := peer.AddrInfo{ + ID: rp, + Addrs: addrs, + } + hp.tracer.StartHolePunch(rp, addrs, rtt) + hp.tracer.HolePunchAttempt(pi.ID) + err := holePunchConnect(hp.ctx, hp.host, pi, true) + dt := time.Since(start) + hp.tracer.EndHolePunch(rp, dt, err) + if err == nil { + log.Debugw("hole punching with successful", "peer", rp, "time", dt) + return nil + } + case <-hp.ctx.Done(): + timer.Stop() + return hp.ctx.Err() + } + } + return fmt.Errorf("all retries for hole punch with peer %s failed", rp) +} + +// initiateHolePunch opens a new hole punching coordination stream, +// exchanges the addresses and measures the RTT. +func (hp *holePuncher) initiateHolePunch(rp peer.ID) ([]ma.Multiaddr, time.Duration, error) { + hpCtx := network.WithUseTransient(hp.ctx, "hole-punch") + sCtx := network.WithNoDial(hpCtx, "hole-punch") + + str, err := hp.host.NewStream(sCtx, rp, Protocol) + if err != nil { + return nil, 0, fmt.Errorf("failed to open hole-punching stream: %w", err) + } + defer str.Close() + + if err := str.Scope().SetService(ServiceName); err != nil { + log.Debugf("error attaching stream to holepunch service: %s", err) + str.Reset() + return nil, 0, err + } + + if err := str.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + log.Debugf("error reserving memory for stream: %s, err") + str.Reset() + return nil, 0, err + } + defer str.Scope().ReleaseMemory(maxMsgSize) + + w := protoio.NewDelimitedWriter(str) + rd := protoio.NewDelimitedReader(str, maxMsgSize) + + str.SetDeadline(time.Now().Add(StreamTimeout)) + // send a CONNECT and start RTT measurement. + msg := &pb.HolePunch{ + Type: pb.HolePunch_CONNECT.Enum(), + ObsAddrs: addrsToBytes(removeRelayAddrs(hp.ids.OwnObservedAddrs())), + } + + start := time.Now() + if err := w.WriteMsg(msg); err != nil { + str.Reset() + return nil, 0, err + } + + // wait for a CONNECT message from the remote peer + msg.Reset() + if err := rd.ReadMsg(msg); err != nil { + str.Reset() + return nil, 0, fmt.Errorf("failed to read CONNECT message from remote peer: %w", err) + } + rtt := time.Since(start) + + if t := msg.GetType(); t != pb.HolePunch_CONNECT { + str.Reset() + return nil, 0, fmt.Errorf("expect CONNECT message, got %s", t) + } + addrs := removeRelayAddrs(addrsFromBytes(msg.ObsAddrs)) + if len(addrs) == 0 { + str.Reset() + } + + msg.Reset() + msg.Type = pb.HolePunch_SYNC.Enum() + if err := w.WriteMsg(msg); err != nil { + str.Reset() + return nil, 0, fmt.Errorf("failed to send SYNC message for hole punching: %w", err) + } + return addrs, rtt, nil +} + +func (hp *holePuncher) Close() error { + hp.closeMx.Lock() + hp.closed = true + hp.closeMx.Unlock() + hp.ctxCancel() + hp.refCount.Wait() + return nil +} + +type netNotifiee holePuncher + +func (nn *netNotifiee) Connected(_ network.Network, conn network.Conn) { + hs := (*holePuncher)(nn) + + // Hole punch if it's an inbound proxy connection. + // If we already have a direct connection with the remote peer, this will be a no-op. + if conn.Stat().Direction == network.DirInbound && isRelayAddress(conn.RemoteMultiaddr()) { + hs.refCount.Add(1) + go func() { + defer hs.refCount.Done() + + select { + // waiting for Identify here will allow us to access the peer's public and observed addresses + // that we can dial to for a hole punch. + case <-hs.ids.IdentifyWait(conn): + case <-hs.ctx.Done(): + return + } + + _ = hs.DirectConnect(conn.RemotePeer()) + }() + } +} + +func (nn *netNotifiee) Disconnected(_ network.Network, v network.Conn) {} +func (nn *netNotifiee) OpenedStream(n network.Network, v network.Stream) {} +func (nn *netNotifiee) ClosedStream(n network.Network, v network.Stream) {} +func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) {} +func (nn *netNotifiee) ListenClose(n network.Network, a ma.Multiaddr) {} diff --git a/p2p/protocol/holepunch/svc.go b/p2p/protocol/holepunch/svc.go new file mode 100644 index 00000000..ff7178c2 --- /dev/null +++ b/p2p/protocol/holepunch/svc.go @@ -0,0 +1,236 @@ +package holepunch + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + pb "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb" + "github.com/libp2p/go-libp2p/p2p/protocol/identify" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-msgio/protoio" + ma "github.com/multiformats/go-multiaddr" +) + +// Protocol is the libp2p protocol for Hole Punching. +const Protocol protocol.ID = "/libp2p/dcutr" + +var log = logging.Logger("p2p-holepunch") + +// StreamTimeout is the timeout for the hole punch protocol stream. +var StreamTimeout = 1 * time.Minute + +const ( + ServiceName = "libp2p.holepunch" + + maxMsgSize = 4 * 1024 // 4K +) + +// ErrClosed is returned when the hole punching is closed +var ErrClosed = errors.New("hole punching service closing") + +type Option func(*Service) error + +// The Service runs on every node that supports the DCUtR protocol. +type Service struct { + ctx context.Context + ctxCancel context.CancelFunc + + host host.Host + ids identify.IDService + holePuncher *holePuncher + + tracer *tracer + + refCount sync.WaitGroup + + hasPublicAddrsChan chan struct{} // this chan is closed as soon as we have a public address +} + +// NewService creates a new service that can be used for hole punching +// The Service runs on all hosts that support the DCUtR protocol, +// no matter if they are behind a NAT / firewall or not. +// The Service handles DCUtR streams (which are initiated from the node behind +// a NAT / Firewall once we establish a connection to them through a relay. +func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, error) { + if ids == nil { + return nil, errors.New("identify service can't be nil") + } + + ctx, cancel := context.WithCancel(context.Background()) + hs := &Service{ + ctx: ctx, + ctxCancel: cancel, + host: h, + ids: ids, + hasPublicAddrsChan: make(chan struct{}), + } + + for _, opt := range opts { + if err := opt(hs); err != nil { + cancel() + return nil, err + } + } + hs.holePuncher = newHolePuncher(h, ids, hs.hasPublicAddrsChan, hs.tracer) + + hs.refCount.Add(1) + go hs.watchForPublicAddr() + + return hs, nil +} + +func (s *Service) watchForPublicAddr() { + defer s.refCount.Done() + + log.Debug("waiting until we have at least one public address", "peer", s.host.ID()) + + // TODO: We should have an event here that fires when identify discovers a new + // address (and when autonat confirms that address). + // As we currently don't have an event like this, just check our observed addresses + // regularly (exponential backoff starting at 250 ms, capped at 5s). + duration := 250 * time.Millisecond + const maxDuration = 5 * time.Second + t := time.NewTimer(duration) + defer t.Stop() + for { + if containsPublicAddr(s.ids.OwnObservedAddrs()) { + log.Debug("Host now has a public address. Starting holepunch protocol.") + s.host.SetStreamHandler(Protocol, s.handleNewStream) + close(s.hasPublicAddrsChan) + return + } + + select { + case <-s.ctx.Done(): + return + case <-t.C: + duration *= 2 + if duration > maxDuration { + duration = maxDuration + } + t.Reset(duration) + } + } +} + +// Close closes the Hole Punch Service. +func (s *Service) Close() error { + err := s.holePuncher.Close() + s.tracer.Close() + s.host.RemoveStreamHandler(Protocol) + s.ctxCancel() + s.refCount.Wait() + return err +} + +func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, addrs []ma.Multiaddr, err error) { + // sanity check: a hole punch request should only come from peers behind a relay + if !isRelayAddress(str.Conn().RemoteMultiaddr()) { + return 0, nil, fmt.Errorf("received hole punch stream: %s", str.Conn().RemoteMultiaddr()) + } + ownAddrs := removeRelayAddrs(s.ids.OwnObservedAddrs()) + // If we can't tell the peer where to dial us, there's no point in starting the hole punching. + if len(ownAddrs) == 0 { + return 0, nil, errors.New("rejecting hole punch request, as we don't have any public addresses") + } + + if err := str.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + log.Debugf("error reserving memory for stream: %s, err") + return 0, nil, err + } + defer str.Scope().ReleaseMemory(maxMsgSize) + + wr := protoio.NewDelimitedWriter(str) + rd := protoio.NewDelimitedReader(str, maxMsgSize) + + // Read Connect message + msg := new(pb.HolePunch) + + str.SetDeadline(time.Now().Add(StreamTimeout)) + + if err := rd.ReadMsg(msg); err != nil { + return 0, nil, fmt.Errorf("failed to read message from initator: %w", err) + } + if t := msg.GetType(); t != pb.HolePunch_CONNECT { + return 0, nil, fmt.Errorf("expected CONNECT message from initiator but got %d", t) + } + obsDial := removeRelayAddrs(addrsFromBytes(msg.ObsAddrs)) + log.Debugw("received hole punch request", "peer", str.Conn().RemotePeer(), "addrs", obsDial) + if len(obsDial) == 0 { + return 0, nil, errors.New("expected CONNECT message to contain at least one address") + } + + // Write CONNECT message + msg.Reset() + msg.Type = pb.HolePunch_CONNECT.Enum() + msg.ObsAddrs = addrsToBytes(ownAddrs) + tstart := time.Now() + if err := wr.WriteMsg(msg); err != nil { + return 0, nil, fmt.Errorf("failed to write CONNECT message to initator: %w", err) + } + + // Read SYNC message + msg.Reset() + if err := rd.ReadMsg(msg); err != nil { + return 0, nil, fmt.Errorf("failed to read message from initator: %w", err) + } + if t := msg.GetType(); t != pb.HolePunch_SYNC { + return 0, nil, fmt.Errorf("expected SYNC message from initiator but got %d", t) + } + return time.Since(tstart), obsDial, nil +} + +func (s *Service) handleNewStream(str network.Stream) { + // Check directionality of the underlying connection. + // Peer A receives an inbound connection from peer B. + // Peer A opens a new hole punch stream to peer B. + // Peer B receives this stream, calling this function. + // Peer B sees the underlying connection as an outbound connection. + if str.Conn().Stat().Direction == network.DirInbound { + str.Reset() + return + } + + if err := str.Scope().SetService(ServiceName); err != nil { + log.Debugf("error attaching stream to holepunch service: %s", err) + str.Reset() + return + } + + rp := str.Conn().RemotePeer() + rtt, addrs, err := s.incomingHolePunch(str) + if err != nil { + s.tracer.ProtocolError(rp, err) + log.Debugw("error handling holepunching stream from", rp, "error", err) + str.Reset() + return + } + str.Close() + + // Hole punch now by forcing a connect + pi := peer.AddrInfo{ + ID: rp, + Addrs: addrs, + } + s.tracer.StartHolePunch(rp, addrs, rtt) + log.Debugw("starting hole punch", "peer", rp) + start := time.Now() + s.tracer.HolePunchAttempt(pi.ID) + err = holePunchConnect(s.ctx, s.host, pi, false) + dt := time.Since(start) + s.tracer.EndHolePunch(rp, dt, err) +} + +// DirectConnect is only exposed for testing purposes. +// TODO: find a solution for this. +func (s *Service) DirectConnect(p peer.ID) error { + return s.holePuncher.DirectConnect(p) +} diff --git a/p2p/protocol/holepunch/util.go b/p2p/protocol/holepunch/util.go new file mode 100644 index 00000000..af707757 --- /dev/null +++ b/p2p/protocol/holepunch/util.go @@ -0,0 +1,69 @@ +package holepunch + +import ( + "context" + + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +func containsPublicAddr(addrs []ma.Multiaddr) bool { + for _, addr := range addrs { + if isRelayAddress(addr) || !manet.IsPublicAddr(addr) { + continue + } + return true + } + return false +} + +func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { + result := make([]ma.Multiaddr, 0, len(addrs)) + for _, addr := range addrs { + if !isRelayAddress(addr) { + result = append(result, addr) + } + } + return result +} + +func isRelayAddress(a ma.Multiaddr) bool { + _, err := a.ValueForProtocol(ma.P_CIRCUIT) + return err == nil +} + +func addrsToBytes(as []ma.Multiaddr) [][]byte { + bzs := make([][]byte, 0, len(as)) + for _, a := range as { + bzs = append(bzs, a.Bytes()) + } + return bzs +} + +func addrsFromBytes(bzs [][]byte) []ma.Multiaddr { + addrs := make([]ma.Multiaddr, 0, len(bzs)) + for _, bz := range bzs { + a, err := ma.NewMultiaddrBytes(bz) + if err == nil { + addrs = append(addrs, a) + } + } + return addrs +} + +func holePunchConnect(ctx context.Context, host host.Host, pi peer.AddrInfo, isClient bool) error { + holePunchCtx := network.WithSimultaneousConnect(ctx, isClient, "hole-punching") + forceDirectConnCtx := network.WithForceDirectDial(holePunchCtx, "hole-punching") + dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) + defer cancel() + + if err := host.Connect(dialCtx, pi); err != nil { + log.Debugw("hole punch attempt with peer failed", "peer ID", pi.ID, "error", err) + return err + } + log.Debugw("hole punch successful", "peer", pi.ID) + return nil +}