From 53881b2b0bda80cd2b40ed7472293c10300117b6 Mon Sep 17 00:00:00 2001 From: vyzo Date: Mon, 21 Feb 2022 12:56:04 +0200 Subject: [PATCH] refactor dialWorkerLoop into an object --- p2p/net/swarm/dial_worker.go | 351 +++++++++++++++++++++++++++++++++++ p2p/net/swarm/swarm_dial.go | 312 +------------------------------ 2 files changed, 353 insertions(+), 310 deletions(-) create mode 100644 p2p/net/swarm/dial_worker.go diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go new file mode 100644 index 00000000..7bf0d1cc --- /dev/null +++ b/p2p/net/swarm/dial_worker.go @@ -0,0 +1,351 @@ +package swarm + +import ( + "context" + "sync" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +// ///////////////////////////////////////////////////////////////////////////////// +// lo and behold, The Dialer +// TODO explain how all this works +// //////////////////////////////////////////////////////////////////////////////// + +type dialRequest struct { + ctx context.Context + resch chan dialResponse +} + +type dialResponse struct { + conn *Conn + err error +} + +type pendRequest struct { + req dialRequest // the original request + err *DialError // dial error accumulator + addrs map[ma.Multiaddr]struct{} // pending addr dials +} + +type addrDial struct { + addr ma.Multiaddr + ctx context.Context + conn *Conn + err error + requests []int + dialed bool +} + +type dialWorker struct { + s *Swarm + peer peer.ID + reqch <-chan dialRequest + reqno int + requests map[int]*pendRequest + pending map[ma.Multiaddr]*addrDial + resch chan dialResult + + active int + done bool // true when the request channel has been closed + connected bool // true when a connection has been successfully established + + nextDial []ma.Multiaddr + triggerDial <-chan struct{} + + // for testing + wg sync.WaitGroup + eval <-chan func() +} + +func 
newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker { + return &dialWorker{ + s: s, + peer: p, + reqch: reqch, + requests: make(map[int]*pendRequest), + pending: make(map[ma.Multiaddr]*addrDial), + resch: make(chan dialResult), + } +} + +func (w *dialWorker) loop() { + w.wg.Add(1) + defer w.wg.Done() + defer w.s.limiter.clearAllPeerDials(w.peer) + + triggerNow := make(chan struct{}) + close(triggerNow) + +loop: + for { + select { + case req, ok := <-w.reqch: + if !ok { + // request channel has been closed, wait for pending dials to complete + if w.active > 0 { + w.done = true + w.reqch = nil + w.triggerDial = nil + continue loop + } + + // no active dials, we are done + return + } + + c := w.s.bestAcceptableConnToPeer(req.ctx, w.peer) + if c != nil { + req.resch <- dialResponse{conn: c} + continue loop + } + + addrs, err := w.s.addrsForDial(req.ctx, w.peer) + if err != nil { + req.resch <- dialResponse{err: err} + continue loop + } + + // at this point, len(addrs) > 0 or else it would be an error from addrsForDial + // rank them to process in order + addrs = w.rankAddrs(addrs) + + // create the pending request object + pr := &pendRequest{ + req: req, + err: &DialError{Peer: w.peer}, + addrs: make(map[ma.Multiaddr]struct{}), + } + for _, a := range addrs { + pr.addrs[a] = struct{}{} + } + + // check if any of the addrs has been successfully dialed and accumulate + // errors from complete dials while collecting new addrs to dial/join + var todial []ma.Multiaddr + var tojoin []*addrDial + + for _, a := range addrs { + ad, ok := w.pending[a] + if !ok { + todial = append(todial, a) + continue + } + + if ad.conn != nil { + // dial to this addr was successful, complete the request + req.resch <- dialResponse{conn: ad.conn} + continue loop + } + + if ad.err != nil { + // dial to this addr errored, accumulate the error + pr.err.recordErr(a, ad.err) + delete(pr.addrs, a) + continue + } + + // dial is still pending, add to the join list + tojoin = 
append(tojoin, ad) + } + + if len(todial) == 0 && len(tojoin) == 0 { + // all request applicable addrs have been dialed, we must have errored + req.resch <- dialResponse{err: pr.err} + continue loop + } + + // the request has some pending or new dials, track it and schedule new dials + w.reqno++ + w.requests[w.reqno] = pr + + for _, ad := range tojoin { + if !ad.dialed { + if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { + if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { + ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + } + } + } + ad.requests = append(ad.requests, w.reqno) + } + + if len(todial) > 0 { + for _, a := range todial { + w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} + } + + w.nextDial = append(w.nextDial, todial...) + w.nextDial = w.rankAddrs(w.nextDial) + + // trigger a new dial now to account for the new addrs we added + w.triggerDial = triggerNow + } + + case <-w.triggerDial: + for _, addr := range w.nextDial { + // spawn the dial + ad := w.pending[addr] + err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) + if err != nil { + w.dispatchError(ad, err) + } else { + w.active++ + } + } + + w.nextDial = nil + w.triggerDial = nil + + case res := <-w.resch: + w.active-- + + if res.Conn != nil { + w.connected = true + } + + if w.done && w.active == 0 { + if res.Conn != nil { + // we got an actual connection, but the dial has been cancelled + // Should we close it? 
I think not, we should just add it to the swarm + _, err := w.s.addConn(res.Conn, network.DirOutbound) + if err != nil { + // well duh, now we have to close it + res.Conn.Close() + } + } + return + } + + ad := w.pending[res.Addr] + + if res.Conn != nil { + // we got a connection, add it to the swarm + conn, err := w.s.addConn(res.Conn, network.DirOutbound) + if err != nil { + // oops no, we failed to add it to the swarm + res.Conn.Close() + w.dispatchError(ad, err) + if w.active == 0 && len(w.nextDial) > 0 { + w.triggerDial = triggerNow + } + continue loop + } + + // dispatch to still pending requests + for _, reqno := range ad.requests { + pr, ok := w.requests[reqno] + if !ok { + // it has already dispatched a connection + continue + } + + pr.req.resch <- dialResponse{conn: conn} + delete(w.requests, reqno) + } + + ad.conn = conn + ad.requests = nil + + continue loop + } + + // it must be an error -- add backoff if applicable and dispatch + if res.Err != context.Canceled && !w.connected { + // we only add backoff if there has not been a successful connection + // for consistency with the old dialer behavior. 
+ w.s.backf.AddBackoff(w.peer, res.Addr) + } + + w.dispatchError(ad, res.Err) + if w.active == 0 && len(w.nextDial) > 0 { + w.triggerDial = triggerNow + } + + case f := <-w.eval: + f() + } + } +} + +// dispatches an error to a specific addr dial +func (w *dialWorker) dispatchError(ad *addrDial, err error) { + ad.err = err + for _, reqno := range ad.requests { + pr, ok := w.requests[reqno] + if !ok { + // has already been dispatched + continue + } + + // accumulate the error + pr.err.recordErr(ad.addr, err) + + delete(pr.addrs, ad.addr) + if len(pr.addrs) == 0 { + // all addrs have erred, dispatch dial error + // but first do one last check in case an acceptable connection has landed from + // a simultaneous dial that started later and added new acceptable addrs + c := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer) + if c != nil { + pr.req.resch <- dialResponse{conn: c} + } else { + pr.req.resch <- dialResponse{err: pr.err} + } + delete(w.requests, reqno) + } + } + + ad.requests = nil + + // if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests. + // this is necessary to support active listen scenarios, where a new dial comes in while + // another dial is in progress, and needs to do a direct connection without inhibitions from + // dial backoff. + // it is also necessary to preserve consistent behaviour with the old dialer -- TestDialBackoff + // regresses without this. 
+ if err == ErrDialBackoff { + delete(w.pending, ad.addr) + } +} + +// ranks addresses in descending order of preference for dialing, with the following rules: +// NonRelay > Relay +// NonWS > WS +// Private > Public +// UDP > TCP +func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { + addrTier := func(a ma.Multiaddr) (tier int) { + if isRelayAddr(a) { + tier |= 0b1000 + } + if isExpensiveAddr(a) { + tier |= 0b0100 + } + if !manet.IsPrivateAddr(a) { + tier |= 0b0010 + } + if isFdConsumingAddr(a) { + tier |= 0b0001 + } + + return tier + } + + tiers := make([][]ma.Multiaddr, 16) + for _, a := range addrs { + tier := addrTier(a) + tiers[tier] = append(tiers[tier], a) + } + + result := make([]ma.Multiaddr, 0, len(addrs)) + for _, tier := range tiers { + result = append(result, tier...) + } + + return result +} diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index c92970cd..ca11d732 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -278,281 +278,10 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { return nil, err } -// ///////////////////////////////////////////////////////////////////////////////// -// lo and behold, The Dialer -// TODO explain how all this works -// //////////////////////////////////////////////////////////////////////////////// - -type dialRequest struct { - ctx context.Context - resch chan dialResponse -} - -type dialResponse struct { - conn *Conn - err error -} - // dialWorkerLoop synchronizes and executes concurrent dials to a single peer func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) { - defer s.limiter.clearAllPeerDials(p) - - type pendRequest struct { - req dialRequest // the original request - err *DialError // dial error accumulator - addrs map[ma.Multiaddr]struct{} // pending addr dials - } - - type addrDial struct { - addr ma.Multiaddr - ctx context.Context - conn *Conn - err error - requests []int - dialed bool - } - - reqno 
:= 0 - requests := make(map[int]*pendRequest) - pending := make(map[ma.Multiaddr]*addrDial) - - dispatchError := func(ad *addrDial, err error) { - ad.err = err - for _, reqno := range ad.requests { - pr, ok := requests[reqno] - if !ok { - // has already been dispatched - continue - } - - // accumulate the error - pr.err.recordErr(ad.addr, err) - - delete(pr.addrs, ad.addr) - if len(pr.addrs) == 0 { - // all addrs have erred, dispatch dial error - // but first do a last one check in case an acceptable connection has landed from - // a simultaneous dial that started later and added new acceptable addrs - c := s.bestAcceptableConnToPeer(pr.req.ctx, p) - if c != nil { - pr.req.resch <- dialResponse{conn: c} - } else { - pr.req.resch <- dialResponse{err: pr.err} - } - delete(requests, reqno) - } - } - - ad.requests = nil - - // if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests. - // this is necessary to support active listen scenarios, where a new dial comes in while - // another dial is in progress, and needs to do a direct connection without inhibitions from - // dial backoff. - // it is also necessary to preserve consisent behaviour with the old dialer -- TestDialBackoff - // regresses without this. 
- if err == ErrDialBackoff { - delete(pending, ad.addr) - } - } - - var triggerDial <-chan struct{} - triggerNow := make(chan struct{}) - close(triggerNow) - - var nextDial []ma.Multiaddr - active := 0 - done := false // true when the request channel has been closed - connected := false // true when a connection has been successfully established - - resch := make(chan dialResult) - -loop: - for { - select { - case req, ok := <-reqch: - if !ok { - // request channel has been closed, wait for pending dials to complete - if active > 0 { - done = true - reqch = nil - triggerDial = nil - continue loop - } - - // no active dials, we are done - return - } - - c := s.bestAcceptableConnToPeer(req.ctx, p) - if c != nil { - req.resch <- dialResponse{conn: c} - continue loop - } - - addrs, err := s.addrsForDial(req.ctx, p) - if err != nil { - req.resch <- dialResponse{err: err} - continue loop - } - - // at this point, len(addrs) > 0 or else it would be error from addrsForDial - // ranke them to process in order - addrs = s.rankAddrs(addrs) - - // create the pending request object - pr := &pendRequest{ - req: req, - err: &DialError{Peer: p}, - addrs: make(map[ma.Multiaddr]struct{}), - } - for _, a := range addrs { - pr.addrs[a] = struct{}{} - } - - // check if any of the addrs has been successfully dialed and accumulate - // errors from complete dials while collecting new addrs to dial/join - var todial []ma.Multiaddr - var tojoin []*addrDial - - for _, a := range addrs { - ad, ok := pending[a] - if !ok { - todial = append(todial, a) - continue - } - - if ad.conn != nil { - // dial to this addr was successful, complete the request - req.resch <- dialResponse{conn: ad.conn} - continue loop - } - - if ad.err != nil { - // dial to this addr errored, accumulate the error - pr.err.recordErr(a, ad.err) - delete(pr.addrs, a) - continue - } - - // dial is still pending, add to the join list - tojoin = append(tojoin, ad) - } - - if len(todial) == 0 && len(tojoin) == 0 { - // all 
request applicable addrs have been dialed, we must have errored - req.resch <- dialResponse{err: pr.err} - continue loop - } - - // the request has some pending or new dials, track it and schedule new dials - reqno++ - requests[reqno] = pr - - for _, ad := range tojoin { - if !ad.dialed { - if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { - if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { - ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) - } - } - } - ad.requests = append(ad.requests, reqno) - } - - if len(todial) > 0 { - for _, a := range todial { - pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{reqno}} - } - - nextDial = append(nextDial, todial...) - nextDial = s.rankAddrs(nextDial) - - // trigger a new dial now to account for the new addrs we added - triggerDial = triggerNow - } - - case <-triggerDial: - for _, addr := range nextDial { - // spawn the dial - ad := pending[addr] - err := s.dialNextAddr(ad.ctx, p, addr, resch) - if err != nil { - dispatchError(ad, err) - } else { - active++ - } - } - - nextDial = nil - triggerDial = nil - - case res := <-resch: - active-- - - if res.Conn != nil { - connected = true - } - - if done && active == 0 { - if res.Conn != nil { - // we got an actual connection, but the dial has been cancelled - // Should we close it? 
I think not, we should just add it to the swarm - _, err := s.addConn(res.Conn, network.DirOutbound) - if err != nil { - // well duh, now we have to close it - res.Conn.Close() - } - } - return - } - - ad := pending[res.Addr] - - if res.Conn != nil { - // we got a connection, add it to the swarm - conn, err := s.addConn(res.Conn, network.DirOutbound) - if err != nil { - // oops no, we failed to add it to the swarm - res.Conn.Close() - dispatchError(ad, err) - if active == 0 && len(nextDial) > 0 { - triggerDial = triggerNow - } - continue loop - } - - // dispatch to still pending requests - for _, reqno := range ad.requests { - pr, ok := requests[reqno] - if !ok { - // it has already dispatched a connection - continue - } - - pr.req.resch <- dialResponse{conn: conn} - delete(requests, reqno) - } - - ad.conn = conn - ad.requests = nil - - continue loop - } - - // it must be an error -- add backoff if applicable and dispatch - if res.Err != context.Canceled && !connected { - // we only add backoff if there has not been a successful connection - // for consistency with the old dialer behavior. 
- s.backf.AddBackoff(p, res.Addr) - } - - dispatchError(ad, res.Err) - if active == 0 && len(nextDial) > 0 { - triggerDial = triggerNow - } - } - } + w := newDialWorker(s, p, reqch) + w.loop() } func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) { @@ -597,43 +326,6 @@ func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { return !t.Proxy() } -// ranks addresses in descending order of preference for dialing, with the following rules: -// NonRelay > Relay -// NonWS > WS -// Private > Public -// UDP > TCP -func (s *Swarm) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - addrTier := func(a ma.Multiaddr) (tier int) { - if isRelayAddr(a) { - tier |= 0b1000 - } - if isExpensiveAddr(a) { - tier |= 0b0100 - } - if !manet.IsPrivateAddr(a) { - tier |= 0b0010 - } - if isFdConsumingAddr(a) { - tier |= 0b0001 - } - - return tier - } - - tiers := make([][]ma.Multiaddr, 16) - for _, a := range addrs { - tier := addrTier(a) - tiers[tier] = append(tiers[tier], a) - } - - result := make([]ma.Multiaddr, 0, len(addrs)) - for _, tier := range tiers { - result = append(result, tier...) - } - - return result -} - // filterKnownUndialables takes a list of multiaddrs, and removes those // that we definitely don't want to dial: addresses configured to be blocked, // IPv6 link-local addresses, addresses without a dial-capable transport,