status-go/vendor/github.com/libp2p/go-libp2p/p2p/net/swarm/swarm.go

830 lines
21 KiB
Go

package swarm
import (
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/metrics"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/transport"
"golang.org/x/exp/slices"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
)
const (
// defaultDialTimeout is the value NewSwarm assigns to Swarm.dialTimeout
// unless WithDialTimeout overrides it.
defaultDialTimeout = 15 * time.Second
// defaultDialTimeoutLocal is the maximum duration a Dial to a local network
// address is allowed to take.
// This includes the time between dialing the raw network connection,
// protocol selection as well as the handshake, if applicable.
defaultDialTimeoutLocal = 5 * time.Second
)
// log is the package logger, registered under the name "swarm2".
var log = logging.Logger("swarm2")
// ErrSwarmClosed is returned when one attempts to operate on a closed swarm.
var ErrSwarmClosed = errors.New("swarm closed")
// ErrAddrFiltered is returned when trying to register a connection to a
// filtered address. You shouldn't see this error unless some underlying
// transport is misbehaving.
var ErrAddrFiltered = errors.New("address filtered")
// ErrDialTimeout is returned when a dial times out due to the global timeout.
var ErrDialTimeout = errors.New("dial timed out")
// Option configures a Swarm during construction. NewSwarm applies the options
// in the order given and aborts with the first error one of them returns.
type Option func(*Swarm) error
// WithConnectionGater returns an Option that installs gater as the swarm's
// connection gater. When set, addConn asks it (InterceptUpgraded) whether an
// upgraded connection may be kept.
func WithConnectionGater(gater connmgr.ConnectionGater) Option {
	apply := func(sw *Swarm) error {
		sw.gater = gater
		return nil
	}
	return apply
}
// WithMultiaddrResolver returns an Option that replaces the swarm's multiaddr
// resolver (madns.DefaultResolver by default) with maResolver.
func WithMultiaddrResolver(maResolver *madns.Resolver) Option {
	apply := func(sw *Swarm) error {
		sw.maResolver = maResolver
		return nil
	}
	return apply
}
// WithMetrics returns an Option that stores reporter as the swarm's metrics
// reporter.
func WithMetrics(reporter metrics.Reporter) Option {
	return Option(func(sw *Swarm) error {
		sw.bwc = reporter
		return nil
	})
}
// WithMetricsTracer returns an Option that stores t as the swarm's
// MetricsTracer. NewSwarm also passes the tracer to the black hole detector.
func WithMetricsTracer(t MetricsTracer) Option {
	return Option(func(sw *Swarm) error {
		sw.metricsTracer = t
		return nil
	})
}
// WithDialTimeout returns an Option that overrides the swarm's dial timeout
// (defaultDialTimeout unless set).
func WithDialTimeout(t time.Duration) Option {
	setTimeout := func(sw *Swarm) error {
		sw.dialTimeout = t
		return nil
	}
	return setTimeout
}
// WithDialTimeoutLocal returns an Option that overrides the swarm's timeout
// for dials to local network addresses (defaultDialTimeoutLocal unless set).
func WithDialTimeoutLocal(t time.Duration) Option {
	setTimeout := func(sw *Swarm) error {
		sw.dialTimeoutLocal = t
		return nil
	}
	return setTimeout
}
// WithResourceManager returns an Option that installs m as the swarm's
// resource manager. Without it NewSwarm falls back to
// network.NullResourceManager.
func WithResourceManager(m network.ResourceManager) Option {
	return Option(func(sw *Swarm) error {
		sw.rcmgr = m
		return nil
	})
}
// WithDialRanker returns an Option that makes the swarm use d as its
// DialRanker. Applying the option fails if d is nil.
func WithDialRanker(d network.DialRanker) Option {
	return func(sw *Swarm) error {
		if d != nil {
			sw.dialRanker = d
			return nil
		}
		return errors.New("swarm: dial ranker cannot be nil")
	}
}
// WithUDPBlackHoleConfig returns an Option that sets the swarm's UDP black
// hole detection config.
// enabled turns the detection on or off,
// n is the size of the sliding window used to evaluate black hole state,
// min is the minimum number of successes out of n required to not block requests.
func WithUDPBlackHoleConfig(enabled bool, n, min int) Option {
	cfg := blackHoleConfig{Enabled: enabled, N: n, MinSuccesses: min}
	return func(sw *Swarm) error {
		sw.udpBlackHoleConfig = cfg
		return nil
	}
}
// WithIPv6BlackHoleConfig returns an Option that sets the swarm's IPv6 black
// hole detection config.
// enabled turns the detection on or off,
// n is the size of the sliding window used to evaluate black hole state,
// min is the minimum number of successes out of n required to not block requests.
func WithIPv6BlackHoleConfig(enabled bool, n, min int) Option {
	cfg := blackHoleConfig{Enabled: enabled, N: n, MinSuccesses: min}
	return func(sw *Swarm) error {
		sw.ipv6BlackHoleConfig = cfg
		return nil
	}
}
// Swarm is a connection muxer, allowing connections to other peers to
// be opened and closed, while still using the same Chan for all
// communication. The Chan sends/receives Messages, which note the
// destination or source Peer.
//
// A Swarm is built by NewSwarm and torn down by Close.
type Swarm struct {
// nextConnID is incremented atomically in addConn to give every Conn its id.
nextConnID atomic.Uint64
// nextStreamID is an atomic counter; presumably it numbers streams the way
// nextConnID numbers connections (its use is outside this section).
nextStreamID atomic.Uint64
// Close refcount. This allows us to fully wait for the swarm to be torn
// down before continuing.
refs sync.WaitGroup
// emitter is the event-bus emitter for event.EvtPeerConnectednessChanged
// obtained in NewSwarm; close() closes it.
emitter event.Emitter
// rcmgr is the resource manager; NewSwarm falls back to
// network.NullResourceManager when no option sets it.
rcmgr network.ResourceManager
// local is the ID of the local peer, returned by LocalPeer.
local peer.ID
// peers is the peerstore returned by Peerstore; addConn records remote
// public keys in it.
peers peerstore.Peerstore
// dialTimeout and dialTimeoutLocal default to defaultDialTimeout and
// defaultDialTimeoutLocal; WithDialTimeout / WithDialTimeoutLocal override them.
dialTimeout time.Duration
dialTimeoutLocal time.Duration
// conns holds the registered connections per peer, ordered oldest to
// newest. m is set to nil by close(), which addConn treats as "swarm closed".
conns struct {
sync.RWMutex
m map[peer.ID][]*Conn
}
// listeners tracks the swarm's transport listeners. m is set to nil by close().
// NOTE(review): ifaceListenAddres and cacheEOL look like a cached list of
// interface listen addresses and its expiry — confirm against the listen code.
listeners struct {
sync.RWMutex
ifaceListenAddres []ma.Multiaddr
cacheEOL time.Time
m map[transport.Listener]struct{}
}
// notifs is the set of Notifiees registered via Notify and removed via
// StopNotify; notifyAll iterates over it.
notifs struct {
sync.RWMutex
m map[network.Notifiee]struct{}
}
// directConnNotifs holds, per peer, the channels registered by
// waitForDirectConn; addConn closes them when a non-limited connection to
// that peer is added.
directConnNotifs struct {
sync.Mutex
m map[peer.ID][]chan struct{}
}
// transports holds the registered transports. m is set to nil by close().
// NOTE(review): the int key is presumably a multiaddr protocol code — confirm.
transports struct {
sync.RWMutex
m map[int]transport.Transport
}
// maResolver is the multiaddr DNS resolver; defaults to madns.DefaultResolver.
maResolver *madns.Resolver
// stream handlers
streamh atomic.Pointer[network.StreamHandler]
// dialing helpers
dsync *dialSync
backf DialBackoff
limiter *dialLimiter
// gater, when non-nil, is asked in addConn whether an upgraded connection
// may be kept.
gater connmgr.ConnectionGater
// closeOnce ensures the teardown in close() runs only once.
closeOnce sync.Once
ctx context.Context // is canceled when Close is called
ctxCancel context.CancelFunc
// bwc is the metrics reporter set via WithMetrics.
bwc metrics.Reporter
// metricsTracer is set via WithMetricsTracer and handed to the black hole
// detector in NewSwarm.
metricsTracer MetricsTracer
// dialRanker defaults to DefaultDialRanker; WithDialRanker overrides it.
dialRanker network.DialRanker
// udpBlackHoleConfig and ipv6BlackHoleConfig are the settings NewSwarm
// passes to newBlackHoleDetector to build bhd.
udpBlackHoleConfig blackHoleConfig
ipv6BlackHoleConfig blackHoleConfig
bhd *blackHoleDetector
// connectednessEventEmitter is told about new connections in addConn and is
// closed in close().
connectednessEventEmitter *connectednessEventEmitter
}
// NewSwarm constructs a Swarm.
//
// local is the ID of the local peer, peers the peerstore the swarm uses, and
// eventBus the bus on which event.EvtPeerConnectednessChanged is emitted.
// opts are applied in order on top of the defaults (dial timeouts, default
// multiaddr resolver, default dial ranker, black hole detection enabled).
//
// It returns an error if the emitter cannot be created or if any option
// fails. In the latter case everything acquired so far (the swarm context and
// both emitters) is released before returning.
func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts ...Option) (*Swarm, error) {
	emitter, err := eventBus.Emitter(new(event.EvtPeerConnectednessChanged))
	if err != nil {
		return nil, err
	}
	ctx, cancel := context.WithCancel(context.Background())
	s := &Swarm{
		local:            local,
		peers:            peers,
		emitter:          emitter,
		ctx:              ctx,
		ctxCancel:        cancel,
		dialTimeout:      defaultDialTimeout,
		dialTimeoutLocal: defaultDialTimeoutLocal,
		maResolver:       madns.DefaultResolver,
		dialRanker:       DefaultDialRanker,

		// A black hole is a binary property. On a network if UDP dials are blocked or there is
		// no IPv6 connectivity, all dials will fail. So a low success rate of 5 out 100 dials
		// is good enough.
		udpBlackHoleConfig:  blackHoleConfig{Enabled: true, N: 100, MinSuccesses: 5},
		ipv6BlackHoleConfig: blackHoleConfig{Enabled: true, N: 100, MinSuccesses: 5},
	}

	s.conns.m = make(map[peer.ID][]*Conn)
	s.listeners.m = make(map[transport.Listener]struct{})
	s.transports.m = make(map[int]transport.Transport)
	s.notifs.m = make(map[network.Notifiee]struct{})
	s.directConnNotifs.m = make(map[peer.ID][]chan struct{})
	s.connectednessEventEmitter = newConnectednessEventEmitter(s.Connectedness, emitter)

	for _, opt := range opts {
		if err := opt(s); err != nil {
			// The half-built swarm is discarded, so nobody will ever call
			// Close on it. Release what was acquired above, mirroring the
			// teardown in close(), to avoid leaking the context and emitters.
			cancel()
			s.connectednessEventEmitter.Close()
			if cerr := emitter.Close(); cerr != nil {
				log.Warnf("failed to close event emitter: %s", cerr)
			}
			return nil, err
		}
	}
	if s.rcmgr == nil {
		s.rcmgr = &network.NullResourceManager{}
	}

	s.dsync = newDialSync(s.dialWorkerLoop)
	s.limiter = newDialLimiter(s.dialAddr)
	s.backf.init(s.ctx)
	s.bhd = newBlackHoleDetector(s.udpBlackHoleConfig, s.ipv6BlackHoleConfig, s.metricsTracer)
	return s, nil
}
// Close shuts the swarm down. The teardown (close) runs at most once, however
// often Close is called; Close always returns nil.
func (s *Swarm) Close() error {
	s.closeOnce.Do(func() { s.close() })
	return nil
}
// Done exposes the swarm context's Done channel, which is closed once the
// swarm has been closed.
func (s *Swarm) Done() <-chan struct{} {
	closed := s.ctx.Done()
	return closed
}
// close performs the actual teardown; it is invoked exactly once via
// closeOnce from Close.
//
// Order of operations:
//  1. cancel the swarm context,
//  2. detach the listener and connection maps (setting them to nil stops new
//     listeners/connections from being registered),
//  3. close all listeners and connections in parallel and wait on refs,
//  4. close the connectedness emitter and the event emitter,
//  5. close every distinct transport that implements io.Closer.
func (s *Swarm) close() {
	s.ctxCancel()

	// Prevents new connections and/or listeners from being added to the swarm.
	s.listeners.Lock()
	listeners := s.listeners.m
	s.listeners.m = nil
	s.listeners.Unlock()

	s.conns.Lock()
	conns := s.conns.m
	s.conns.m = nil
	s.conns.Unlock()

	// Lots of goroutines but we might as well do this in parallel. We want to shut down as fast as
	// possible.
	s.refs.Add(len(listeners))
	for l := range listeners {
		go func(l transport.Listener) {
			defer s.refs.Done()
			// errors.Is also matches a wrapped ErrListenerClosed, which is
			// just as harmless as the bare sentinel.
			if err := l.Close(); err != nil && !errors.Is(err, transport.ErrListenerClosed) {
				log.Errorf("error when shutting down listener: %s", err)
			}
		}(l)
	}

	// No refs.Add here: addConn already took the swarm refs for each
	// connection when it was registered.
	for _, cs := range conns {
		for _, c := range cs {
			go func(c *Conn) {
				if err := c.Close(); err != nil {
					log.Errorf("error when shutting down connection: %s", err)
				}
			}(c)
		}
	}

	// Wait for everything to finish.
	s.refs.Wait()
	s.connectednessEventEmitter.Close()
	s.emitter.Close()

	// Now close out any transports (if necessary). Do this after closing
	// all connections/listeners.
	s.transports.Lock()
	transports := s.transports.m
	s.transports.m = nil
	s.transports.Unlock()

	// Dedup transports that may be listening on multiple protocols
	transportsToClose := make(map[transport.Transport]struct{}, len(transports))
	for _, t := range transports {
		transportsToClose[t] = struct{}{}
	}

	var wg sync.WaitGroup
	for t := range transportsToClose {
		closer, ok := t.(io.Closer)
		if !ok {
			continue
		}
		wg.Add(1)
		go func(c io.Closer) {
			defer wg.Done()
			if err := c.Close(); err != nil {
				log.Errorf("error when closing down transport %T: %s", c, err)
			}
		}(closer)
	}
	wg.Wait()
}
// addConn wraps the upgraded transport connection tc in a Conn, registers it
// with the swarm and fires the Connected notifications.
//
// dir is recorded as the connection's direction. The connection is rejected
// (and tc closed) with ErrGaterDisallowedConnection if the connection gater
// refuses it, and with ErrSwarmClosed if the swarm has already been closed.
// On success the returned Conn has been started.
func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) {
	var (
		p    = tc.RemotePeer()
		addr = tc.RemoteMultiaddr()
	)

	// create the Stat object, initializing with the underlying connection Stat if available
	var stat network.ConnStats
	if cs, ok := tc.(network.ConnStat); ok {
		stat = cs.Stat()
	}
	stat.Direction = dir
	stat.Opened = time.Now()
	isLimited := stat.Limited

	// Wrap and register the connection.
	c := &Conn{
		conn:  tc,
		swarm: s,
		stat:  stat,
		id:    s.nextConnID.Add(1),
	}

	// we ONLY check upgraded connections here so we can send them a Disconnect message.
	// If we do this in the Upgrader, we will not be able to do this.
	if s.gater != nil {
		if allow, _ := s.gater.InterceptUpgraded(c); !allow {
			// TODO Send disconnect with reason here
			err := tc.Close()
			if err != nil {
				log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p, addr, err)
			}
			return nil, ErrGaterDisallowedConnection
		}
	}

	// Add the public key.
	if pk := tc.RemotePublicKey(); pk != nil {
		s.peers.AddPubKey(p, pk)
	}

	// Clear any backoffs
	s.backf.Clear(p)

	// Finally, add the peer.
	s.conns.Lock()
	// Check if we're still online
	if s.conns.m == nil {
		s.conns.Unlock()
		// Report a failed close the same way the gater path above does
		// instead of dropping the error silently.
		if err := tc.Close(); err != nil {
			log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p, addr, err)
		}
		return nil, ErrSwarmClosed
	}

	c.streams.m = make(map[*Stream]struct{})
	s.conns.m[p] = append(s.conns.m[p], c)
	// Add two swarm refs:
	// * One will be decremented after the close notifications fire in Conn.doClose
	// * The other will be decremented when Conn.start exits.
	s.refs.Add(2)
	// Take the notification lock before releasing the conns lock to block
	// Disconnect notifications until after the Connect notifications done.
	// This lock also ensures that swarm.refs.Wait() exits after we have
	// enqueued the peer connectedness changed notification.
	// TODO: Fix this fragility by taking a swarm ref for dial worker loop
	c.notifyLk.Lock()
	s.conns.Unlock()

	s.connectednessEventEmitter.AddConn(p)

	if !isLimited {
		// Notify goroutines waiting for a direct connection
		//
		// Go routines interested in waiting for direct connection first acquire this lock
		// and then acquire s.conns.RLock. Do not acquire this lock before conns.Unlock to
		// prevent deadlock.
		s.directConnNotifs.Lock()
		for _, ch := range s.directConnNotifs.m[p] {
			close(ch)
		}
		delete(s.directConnNotifs.m, p)
		s.directConnNotifs.Unlock()
	}
	s.notifyAll(func(f network.Notifiee) {
		f.Connected(s, c)
	})
	c.notifyLk.Unlock()

	c.start()
	return c, nil
}
// Peerstore returns this swarm's internal Peerstore.
func (s *Swarm) Peerstore() peerstore.Peerstore {
	store := s.peers
	return store
}
// SetStreamHandler atomically installs handler as the handler for new
// streams, replacing any previous one.
func (s *Swarm) SetStreamHandler(handler network.StreamHandler) {
	h := handler
	s.streamh.Store(&h)
}
// StreamHandler returns the currently installed handler for new streams, or
// nil if none has been set.
func (s *Swarm) StreamHandler() network.StreamHandler {
	if hp := s.streamh.Load(); hp != nil {
		return *hp
	}
	return nil
}
// NewStream opens a stream to p on the best available connection, dialing
// the peer first when there is none (unless the context forbids dialing).
// Use network.WithAllowLimitedConn to open a stream over a limited (relayed)
// connection.
//
// Errors: network.ErrNoConn when there is no connection and dialing is
// disabled, "max dial attempts exceeded" after more than DialAttempts dials,
// or whatever dialing, waiting for a direct connection, or opening the stream
// reports.
func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error) {
	log.Debugf("[%s] opening stream to peer [%s]", s.local, p)

	// Algorithm:
	// 1. Find the best connection, otherwise, dial.
	// 2. If the best connection is limited, wait for a direct conn via conn
	//    reversal or hole punching.
	// 3. Try opening a stream.
	// 4. If the underlying connection is, in fact, closed, close the outer
	//    connection and try again. We do this in case we have a closed
	//    connection but don't notice it until we actually try to open a
	//    stream.
	//
	// TODO: Try all connections even if we get an error opening a stream on
	// a non-closed connection.
	dials := 0
	for {
		conn := s.bestConnToPeer(p)
		if conn == nil {
			if noDial, _ := network.GetNoDial(ctx); noDial {
				return nil, network.ErrNoConn
			}
			dials++
			if dials > DialAttempts {
				return nil, errors.New("max dial attempts exceeded")
			}
			dialed, err := s.dialPeer(ctx, p)
			if err != nil {
				return nil, err
			}
			conn = dialed
		}

		if allowLimited, _ := network.GetAllowLimitedConn(ctx); !allowLimited && conn.Stat().Limited {
			direct, err := s.waitForDirectConn(ctx, p)
			if err != nil {
				return nil, err
			}
			conn = direct
		}

		stream, err := conn.NewStream(ctx)
		if err == nil {
			return stream, nil
		}
		if !conn.conn.IsClosed() {
			return nil, err
		}
		// The connection turned out to be closed: go around again and pick
		// (or dial) another one.
	}
}
// waitForDirectConn blocks until a direct (non-limited) connection to p is
// available, e.g. one established through hole punching or connection
// reversal, and returns it.
//
// It returns immediately when the best connection is already direct, and
// with network.ErrNoConn when there is no connection at all. The wait is
// bounded by the DialPeer timeout derived from ctx; on expiry or cancellation
// ctx.Err() is returned. network.ErrLimitedConn is returned if, after being
// notified, the best connection is still limited.
func (s *Swarm) waitForDirectConn(ctx context.Context, p peer.ID) (*Conn, error) {
	s.directConnNotifs.Lock()
	current := s.bestConnToPeer(p)
	switch {
	case current == nil:
		s.directConnNotifs.Unlock()
		return nil, network.ErrNoConn
	case !current.Stat().Limited:
		s.directConnNotifs.Unlock()
		return current, nil
	}

	// Register for the notification addConn sends when the limited connection
	// is upgraded to a direct one by connection reversal or hole punching.
	notify := make(chan struct{})
	s.directConnNotifs.m[p] = append(s.directConnNotifs.m[p], notify)
	s.directConnNotifs.Unlock()

	// apply the DialPeer timeout
	ctx, cancel := context.WithTimeout(ctx, network.GetDialPeerTimeout(ctx))
	defer cancel()

	select {
	case <-notify:
		// No deregistration needed: the notifier clears the whole map entry.
		direct := s.bestConnToPeer(p)
		if direct == nil {
			return nil, network.ErrNoConn
		}
		if direct.Stat().Limited {
			return nil, network.ErrLimitedConn
		}
		return direct, nil
	case <-ctx.Done():
		// Take ourselves off the notification list again.
		s.directConnNotifs.Lock()
		remaining := slices.DeleteFunc(
			s.directConnNotifs.m[p],
			func(c chan struct{}) bool { return c == notify },
		)
		s.directConnNotifs.m[p] = remaining
		if len(s.directConnNotifs.m[p]) == 0 {
			delete(s.directConnNotifs.m, p)
		}
		s.directConnNotifs.Unlock()
		return nil, ctx.Err()
	}
}
// ConnsToPeer returns a snapshot of the connections currently registered for
// p. The result is a fresh slice (never nil).
func (s *Swarm) ConnsToPeer(p peer.ID) []network.Conn {
	// TODO: Consider sorting the connection list best to worst. Currently,
	// it's sorted oldest to newest.
	s.conns.RLock()
	defer s.conns.RUnlock()

	peerConns := s.conns.m[p]
	out := make([]network.Conn, 0, len(peerConns))
	for _, c := range peerConns {
		out = append(out, c)
	}
	return out
}
// isBetterConn reports whether a should be preferred over b. Criteria, in
// order: unlimited beats limited, direct beats proxied, more open streams
// beats fewer. On a full tie a wins.
func isBetterConn(a, b *Conn) bool {
	// If one is limited and not the other, prefer the unlimited connection.
	if aLim, bLim := a.Stat().Limited, b.Stat().Limited; aLim != bLim {
		return !aLim
	}

	// If one is direct and not the other, prefer the direct connection.
	if aDir, bDir := isDirectConn(a), isDirectConn(b); aDir != bDir {
		return aDir
	}

	// Otherwise, prefer the connection with more open streams.
	streamCount := func(c *Conn) int {
		c.streams.Lock()
		defer c.streams.Unlock()
		return len(c.streams.m)
	}
	aStreams := streamCount(a)
	bStreams := streamCount(b)
	if aStreams == bStreams {
		// finally, pick the last connection.
		return true
	}
	return aStreams > bStreams
}
// bestConnToPeer returns the best non-closed connection to p according to
// isBetterConn, or nil if there is none.
func (s *Swarm) bestConnToPeer(p peer.ID) *Conn {
	// TODO: Prefer some transports over others.
	// For now, prefers direct connections over Relayed connections.
	// For tie-breaking, select the newest non-closed connection with the most streams.
	s.conns.RLock()
	defer s.conns.RUnlock()

	var best *Conn
	for _, cand := range s.conns.m[p] {
		// Closed connections are skipped; they *will* be garbage collected
		// soon anyways.
		if !cand.conn.IsClosed() && (best == nil || isBetterConn(cand, best)) {
			best = cand
		}
	}
	return best
}
// bestAcceptableConnToPeer returns the best connection to p that is
// acceptable under ctx. With network.WithForceDirectDial set, the best
// connection is returned only if it is direct; otherwise nil.
func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) *Conn {
	best := s.bestConnToPeer(p)
	if force, _ := network.GetForceDirectDial(ctx); !force || isDirectConn(best) {
		return best
	}
	return nil
}
// isDirectConn reports whether c is non-nil and its transport is not a proxy.
func isDirectConn(c *Conn) bool {
	if c == nil {
		return false
	}
	return !c.conn.Transport().Proxy()
}
// Connectedness reports our "connectedness" state with the given peer,
// evaluated under the connections read lock.
//
// To check if we have an open connection, use `s.Connectedness(p) ==
// network.Connected`.
func (s *Swarm) Connectedness(p peer.ID) network.Connectedness {
	s.conns.RLock()
	state := s.connectednessUnlocked(p)
	s.conns.RUnlock()
	return state
}
// connectednessUnlocked computes the connectedness of p without taking the
// connections lock itself. It yields network.Connected if any open
// connection is not limited, network.Limited if all open connections are
// limited, and network.NotConnected if there is no open connection.
func (s *Swarm) connectednessUnlocked(p peer.ID) network.Connectedness {
	limitedSeen := false
	for _, c := range s.conns.m[p] {
		switch {
		case c.IsClosed():
			// These will be garbage collected soon
			continue
		case !c.Stat().Limited:
			return network.Connected
		default:
			limitedSeen = true
		}
	}
	if !limitedSeen {
		return network.NotConnected
	}
	return network.Limited
}
// Conns returns a snapshot of every connection registered with the swarm,
// across all peers.
func (s *Swarm) Conns() []network.Conn {
	s.conns.RLock()
	defer s.conns.RUnlock()

	all := make([]network.Conn, 0, len(s.conns.m))
	for _, peerConns := range s.conns.m {
		for i := range peerConns {
			all = append(all, peerConns[i])
		}
	}
	return all
}
// ClosePeer closes all connections to the given peer. A single connection is
// closed inline and its error returned as is; several connections are closed
// concurrently and their errors combined into one.
func (s *Swarm) ClosePeer(p peer.ID) error {
	conns := s.ConnsToPeer(p)
	if len(conns) == 0 {
		return nil
	}
	if len(conns) == 1 {
		return conns[0].Close()
	}

	results := make(chan error)
	for _, c := range conns {
		go func(c network.Conn) {
			results <- c.Close()
		}(c)
	}

	// Collect exactly one result per connection.
	var msgs []string
	for i := 0; i < len(conns); i++ {
		if err := <-results; err != nil {
			msgs = append(msgs, err.Error())
		}
	}
	if len(msgs) == 0 {
		return nil
	}
	return fmt.Errorf("when disconnecting from peer %s: %s", p, strings.Join(msgs, ", "))
}
// Peers returns a freshly allocated list of the peers that have an entry in
// the swarm's connection map.
func (s *Swarm) Peers() []peer.ID {
	s.conns.RLock()
	defer s.conns.RUnlock()

	ids := make([]peer.ID, 0, len(s.conns.m))
	for id := range s.conns.m {
		ids = append(ids, id)
	}
	return ids
}
// LocalPeer returns the ID of the local peer this swarm belongs to.
func (s *Swarm) LocalPeer() peer.ID {
	id := s.local
	return id
}
// Backoff returns a pointer to this swarm's DialBackoff (the swarm's own
// instance, not a copy).
func (s *Swarm) Backoff() *DialBackoff {
	backoff := &s.backf
	return backoff
}
// notifyAll invokes notify once for every registered Notifiee while holding
// the notifs read lock.
func (s *Swarm) notifyAll(notify func(network.Notifiee)) {
	registry := &s.notifs
	registry.RLock()
	for notifiee := range registry.m {
		notify(notifiee)
	}
	registry.RUnlock()
}
// Notify registers f to receive signals when events happen.
func (s *Swarm) Notify(f network.Notifiee) {
	s.notifs.Lock()
	defer s.notifs.Unlock()
	s.notifs.m[f] = struct{}{}
}
// StopNotify unregisters f so that it no longer receives signals.
func (s *Swarm) StopNotify(f network.Notifiee) {
	s.notifs.Lock()
	defer s.notifs.Unlock()
	delete(s.notifs.m, f)
}
// removeConn drops c from the connection list of its remote peer and deletes
// the peer's map entry once no connections remain.
func (s *Swarm) removeConn(c *Conn) {
	p := c.RemotePeer()

	s.conns.Lock()
	defer s.conns.Unlock()

	peerConns := s.conns.m[p]
	for idx := range peerConns {
		if peerConns[idx] != c {
			continue
		}
		// NOTE: We're intentionally preserving order.
		// This way, connections to a peer are always
		// sorted oldest to newest.
		last := len(peerConns) - 1
		copy(peerConns[idx:], peerConns[idx+1:])
		peerConns[last] = nil // drop the stale reference in the vacated slot
		s.conns.m[p] = peerConns[:last]
		break
	}
	if len(s.conns.m[p]) == 0 {
		delete(s.conns.m, p)
	}
}
// String returns a string representation of the swarm that includes the
// local peer ID.
func (s *Swarm) String() string {
	return fmt.Sprintf("<Swarm %s>", s.local)
}
// ResourceManager returns the resource manager this swarm was configured
// with.
func (s *Swarm) ResourceManager() network.ResourceManager {
	mgr := s.rcmgr
	return mgr
}
// Swarm is a Network.
// Compile-time checks that *Swarm implements network.Network and
// transport.TransportNetwork.
var _ network.Network = (*Swarm)(nil)
var _ transport.TransportNetwork = (*Swarm)(nil)
// connWithMetrics wraps a transport.CapableConn and reports connection
// lifecycle events (opened, handshake completed, closed) to a MetricsTracer.
type connWithMetrics struct {
transport.CapableConn
// opened is the reference time from which handshake latency and connection
// lifetime are measured (via time.Since).
opened time.Time
// dir is the connection direction reported to the tracer.
dir network.Direction
// metricsTracer receives the lifecycle events.
metricsTracer MetricsTracer
}
// wrapWithMetrics wraps capableConn in a connWithMetrics and immediately
// reports the opened connection to metricsTracer. opened is the time the
// connection was opened and dir its direction.
func wrapWithMetrics(capableConn transport.CapableConn, metricsTracer MetricsTracer, opened time.Time, dir network.Direction) connWithMetrics {
	wrapped := connWithMetrics{
		CapableConn:   capableConn,
		opened:        opened,
		dir:           dir,
		metricsTracer: metricsTracer,
	}
	metricsTracer.OpenedConnection(
		dir,
		capableConn.RemotePublicKey(),
		capableConn.ConnState(),
		capableConn.LocalMultiaddr(),
	)
	return wrapped
}
// completedHandshake reports to the tracer that the handshake has finished,
// together with the time elapsed since the connection was opened.
func (c connWithMetrics) completedHandshake() {
	elapsed := time.Since(c.opened)
	c.metricsTracer.CompletedHandshake(elapsed, c.ConnState(), c.LocalMultiaddr())
}
// Close reports the closed connection (direction and lifetime) to the tracer
// and then closes the wrapped connection, returning its error.
func (c connWithMetrics) Close() error {
	lifetime := time.Since(c.opened)
	c.metricsTracer.ClosedConnection(c.dir, lifetime, c.ConnState(), c.LocalMultiaddr())
	return c.CapableConn.Close()
}
// Stat returns the wrapped connection's stats when it implements
// network.ConnStat, and the zero network.ConnStats otherwise.
func (c connWithMetrics) Stat() network.ConnStats {
	statter, ok := c.CapableConn.(network.ConnStat)
	if !ok {
		return network.ConnStats{}
	}
	return statter.Stat()
}
// Compile-time check that connWithMetrics implements network.ConnStat.
var _ network.ConnStat = connWithMetrics{}