Merge pull request #315 from libp2p/fix/dial-worker
Refactor dial worker loop into an object and fix bug
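The per-peer dial loop that previously lived as a closure in swarm_dial.go is now a dialWorker object in its own file, keeping the same request/response protocol: callers send a dialRequest carrying a context and a response channel, and the worker coalesces concurrent requests for the same peer. A minimal sketch of that protocol, mirroring TestDialWorkerLoopBasic below (the swarm s and peer ID p are assumed to be set up elsewhere, with p's addresses already in s's peerstore):

	reqch := make(chan dialRequest)
	resch := make(chan dialResponse)
	worker := newDialWorker(s, p, reqch)
	go worker.loop()

	// each request carries its own context and response channel;
	// concurrent requests for the same peer share the underlying address dials
	reqch <- dialRequest{ctx: context.Background(), resch: resch}
	res := <-resch // res.conn on success, the accumulated *DialError in res.err on failure

	close(reqch) // closing the request channel shuts the worker down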
commit 0487a88370
@@ -74,8 +74,6 @@ func (ds *dialSync) getActiveDial(p peer.ID) (*activeDial, error) {
	if !ok {
		// This code intentionally uses the background context. Otherwise, if the first call
		// to Dial is canceled, subsequent dial calls will also be canceled.
		// XXX: this also breaks direct connection logic. We will need to pipe the
		// information through some other way.
		ctx, cancel := context.WithCancel(context.Background())
		actd = &activeDial{
			ctx: ctx,
@@ -0,0 +1,316 @@
package swarm

import (
	"context"
	"sync"

	"github.com/libp2p/go-libp2p-core/network"
	"github.com/libp2p/go-libp2p-core/peer"

	ma "github.com/multiformats/go-multiaddr"
	manet "github.com/multiformats/go-multiaddr/net"
)

// /////////////////////////////////////////////////////////////////////////////////
// lo and behold, The Dialer
// TODO explain how all this works
// ////////////////////////////////////////////////////////////////////////////////

type dialRequest struct {
	ctx   context.Context
	resch chan dialResponse
}

type dialResponse struct {
	conn *Conn
	err  error
}

type pendRequest struct {
	req   dialRequest               // the original request
	err   *DialError                // dial error accumulator
	addrs map[ma.Multiaddr]struct{} // pending addr dials
}

type addrDial struct {
	addr     ma.Multiaddr
	ctx      context.Context
	conn     *Conn
	err      error
	requests []int
	dialed   bool
}

type dialWorker struct {
	s        *Swarm
	peer     peer.ID
	reqch    <-chan dialRequest
	reqno    int
	requests map[int]*pendRequest
	pending  map[ma.Multiaddr]*addrDial
	resch    chan dialResult

	connected bool // true when a connection has been successfully established

	nextDial []ma.Multiaddr

	// ready when we have more addresses to dial (nextDial is not empty)
	triggerDial <-chan struct{}

	// for testing
	wg sync.WaitGroup
}

func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker {
	return &dialWorker{
		s:        s,
		peer:     p,
		reqch:    reqch,
		requests: make(map[int]*pendRequest),
		pending:  make(map[ma.Multiaddr]*addrDial),
		resch:    make(chan dialResult),
	}
}

func (w *dialWorker) loop() {
	w.wg.Add(1)
	defer w.wg.Done()
	defer w.s.limiter.clearAllPeerDials(w.peer)

	// used to signal readiness to dial and completion of the dial
	ready := make(chan struct{})
	close(ready)

loop:
	for {
		select {
		case req, ok := <-w.reqch:
			if !ok {
				return
			}

			c := w.s.bestAcceptableConnToPeer(req.ctx, w.peer)
			if c != nil {
				req.resch <- dialResponse{conn: c}
				continue loop
			}

			addrs, err := w.s.addrsForDial(req.ctx, w.peer)
			if err != nil {
				req.resch <- dialResponse{err: err}
				continue loop
			}

			// at this point, len(addrs) > 0 or else addrsForDial would have returned an error;
			// rank them to process in order
			addrs = w.rankAddrs(addrs)

			// create the pending request object
			pr := &pendRequest{
				req:   req,
				err:   &DialError{Peer: w.peer},
				addrs: make(map[ma.Multiaddr]struct{}),
			}
			for _, a := range addrs {
				pr.addrs[a] = struct{}{}
			}

			// check if any of the addrs has been successfully dialed and accumulate
			// errors from complete dials while collecting new addrs to dial/join
			var todial []ma.Multiaddr
			var tojoin []*addrDial

			for _, a := range addrs {
				ad, ok := w.pending[a]
				if !ok {
					todial = append(todial, a)
					continue
				}

				if ad.conn != nil {
					// dial to this addr was successful, complete the request
					req.resch <- dialResponse{conn: ad.conn}
					continue loop
				}

				if ad.err != nil {
					// dial to this addr errored, accumulate the error
					pr.err.recordErr(a, ad.err)
					delete(pr.addrs, a)
					continue
				}

				// dial is still pending, add to the join list
				tojoin = append(tojoin, ad)
			}

			if len(todial) == 0 && len(tojoin) == 0 {
				// all addrs relevant to this request have been dialed, so we must have errored
				req.resch <- dialResponse{err: pr.err}
				continue loop
			}

			// the request has some pending or new dials, track it and schedule new dials
			w.reqno++
			w.requests[w.reqno] = pr

			for _, ad := range tojoin {
				if !ad.dialed {
					if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect {
						if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect {
							ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason)
						}
					}
				}
				ad.requests = append(ad.requests, w.reqno)
			}

			if len(todial) > 0 {
				for _, a := range todial {
					w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}}
				}

				w.nextDial = append(w.nextDial, todial...)
				w.nextDial = w.rankAddrs(w.nextDial)

				// trigger a new dial now to account for the new addrs we added
				w.triggerDial = ready
			}

		case <-w.triggerDial:
			for _, addr := range w.nextDial {
				// spawn the dial
				ad := w.pending[addr]
				err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch)
				if err != nil {
					w.dispatchError(ad, err)
				}
			}

			w.nextDial = nil
			w.triggerDial = nil

		case res := <-w.resch:
			if res.Conn != nil {
				w.connected = true
			}

			ad := w.pending[res.Addr]

			if res.Conn != nil {
				// we got a connection, add it to the swarm
				conn, err := w.s.addConn(res.Conn, network.DirOutbound)
				if err != nil {
					// oops no, we failed to add it to the swarm
					res.Conn.Close()
					w.dispatchError(ad, err)
					continue loop
				}

				// dispatch to still pending requests
				for _, reqno := range ad.requests {
					pr, ok := w.requests[reqno]
					if !ok {
						// it has already dispatched a connection
						continue
					}

					pr.req.resch <- dialResponse{conn: conn}
					delete(w.requests, reqno)
				}

				ad.conn = conn
				ad.requests = nil

				continue loop
			}

			// it must be an error -- add backoff if applicable and dispatch
			if res.Err != context.Canceled && !w.connected {
				// we only add backoff if there has not been a successful connection
				// for consistency with the old dialer behavior.
				w.s.backf.AddBackoff(w.peer, res.Addr)
			}

			w.dispatchError(ad, res.Err)
		}
	}
}

// dispatches an error to a specific addr dial
func (w *dialWorker) dispatchError(ad *addrDial, err error) {
	ad.err = err
	for _, reqno := range ad.requests {
		pr, ok := w.requests[reqno]
		if !ok {
			// has already been dispatched
			continue
		}

		// accumulate the error
		pr.err.recordErr(ad.addr, err)

		delete(pr.addrs, ad.addr)
		if len(pr.addrs) == 0 {
			// all addrs have erred, dispatch dial error
			// but first do one last check in case an acceptable connection has landed from
			// a simultaneous dial that started later and added new acceptable addrs
			c := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer)
			if c != nil {
				pr.req.resch <- dialResponse{conn: c}
			} else {
				pr.req.resch <- dialResponse{err: pr.err}
			}
			delete(w.requests, reqno)
		}
	}

	ad.requests = nil

	// if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests.
	// this is necessary to support active listen scenarios, where a new dial comes in while
	// another dial is in progress, and needs to do a direct connection without inhibitions from
	// dial backoff.
	// it is also necessary to preserve consistent behaviour with the old dialer -- TestDialBackoff
	// regresses without this.
	if err == ErrDialBackoff {
		delete(w.pending, ad.addr)
	}
}

// ranks addresses in descending order of preference for dialing, with the following rules:
// NonRelay > Relay
// NonWS > WS
// Private > Public
// UDP > TCP
func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
	addrTier := func(a ma.Multiaddr) (tier int) {
		if isRelayAddr(a) {
			tier |= 0b1000
		}
		if isExpensiveAddr(a) {
			tier |= 0b0100
		}
		if !manet.IsPrivateAddr(a) {
			tier |= 0b0010
		}
		if isFdConsumingAddr(a) {
			tier |= 0b0001
		}

		return tier
	}

	tiers := make([][]ma.Multiaddr, 16)
	for _, a := range addrs {
		tier := addrTier(a)
		tiers[tier] = append(tiers[tier], a)
	}

	result := make([]ma.Multiaddr, 0, len(addrs))
	for _, tier := range tiers {
		result = append(result, tier...)
	}

	return result
}
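To make the ranking concrete: addrTier packs four flags into a 4-bit tier (relay = 0b1000, expensive = 0b0100, public = 0b0010, fd-consuming = 0b0001), and addresses are emitted lowest tier first. Illustrative tiers for a few hypothetical addresses (the transport classifications here are assumptions for the sake of the example):

	// /ip4/192.168.1.5/udp/4001/quic    -> 0b0000 private QUIC, dialed first
	// /ip4/1.2.3.4/udp/4001/quic        -> 0b0010 public QUIC
	// /ip4/1.2.3.4/tcp/4001             -> 0b0011 public + fd-consuming TCP
	// /ip4/1.2.3.4/tcp/4001/p2p-circuit -> 0b1011 relayed, dialed last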
@@ -0,0 +1,327 @@
package swarm

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	csms "github.com/libp2p/go-conn-security-multistream"
	"github.com/libp2p/go-libp2p-core/peerstore"
	"github.com/libp2p/go-libp2p-core/sec/insecure"
	"github.com/libp2p/go-libp2p-core/transport"
	"github.com/libp2p/go-libp2p-peerstore/pstoremem"
	quic "github.com/libp2p/go-libp2p-quic-transport"
	tnet "github.com/libp2p/go-libp2p-testing/net"
	tptu "github.com/libp2p/go-libp2p-transport-upgrader"
	yamux "github.com/libp2p/go-libp2p-yamux"
	msmux "github.com/libp2p/go-stream-muxer-multistream"
	tcp "github.com/libp2p/go-tcp-transport"
	ma "github.com/multiformats/go-multiaddr"
)

func makeSwarm(t *testing.T) *Swarm {
	p := tnet.RandPeerNetParamsOrFatal(t)

	ps, err := pstoremem.NewPeerstore()
	require.NoError(t, err)
	ps.AddPubKey(p.ID, p.PubKey)
	ps.AddPrivKey(p.ID, p.PrivKey)
	t.Cleanup(func() { ps.Close() })

	s, err := NewSwarm(p.ID, ps, WithDialTimeout(time.Second))
	require.NoError(t, err)

	upgrader := makeUpgrader(t, s)

	var tcpOpts []tcp.Option
	tcpOpts = append(tcpOpts, tcp.DisableReuseport())
	tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...)
	require.NoError(t, err)
	if err := s.AddTransport(tcpTransport); err != nil {
		t.Fatal(err)
	}
	if err := s.Listen(p.Addr); err != nil {
		t.Fatal(err)
	}

	quicTransport, err := quic.NewTransport(p.PrivKey, nil, nil, nil)
	if err != nil {
		t.Fatal(err)
	}
	if err := s.AddTransport(quicTransport); err != nil {
		t.Fatal(err)
	}
	if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")); err != nil {
		t.Fatal(err)
	}

	return s
}

func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader {
	id := n.LocalPeer()
	pk := n.Peerstore().PrivKey(id)
	secMuxer := new(csms.SSMuxer)
	secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, pk))

	stMuxer := msmux.NewBlankTransport()
	stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport)
	u, err := tptu.New(secMuxer, stMuxer)
	require.NoError(t, err)
	return u
}

func TestDialWorkerLoopBasic(t *testing.T) {
	s1 := makeSwarm(t)
	s2 := makeSwarm(t)
	defer s1.Close()
	defer s2.Close()

	s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	resch := make(chan dialResponse)
	worker := newDialWorker(s1, s2.LocalPeer(), reqch)
	go worker.loop()

	var conn *Conn
	reqch <- dialRequest{ctx: context.Background(), resch: resch}
	select {
	case res := <-resch:
		require.NoError(t, res.err)
		conn = res.conn
	case <-time.After(time.Minute):
		t.Fatal("dial didn't complete")
	}

	s, err := conn.NewStream(context.Background())
	require.NoError(t, err)
	s.Close()

	var conn2 *Conn
	reqch <- dialRequest{ctx: context.Background(), resch: resch}
	select {
	case res := <-resch:
		require.NoError(t, res.err)
		conn2 = res.conn
	case <-time.After(time.Minute):
		t.Fatal("dial didn't complete")
	}

	require.Equal(t, conn, conn2)

	close(reqch)
	worker.wg.Wait()
}

func TestDialWorkerLoopConcurrent(t *testing.T) {
	s1 := makeSwarm(t)
	s2 := makeSwarm(t)
	defer s1.Close()
	defer s2.Close()

	s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	worker := newDialWorker(s1, s2.LocalPeer(), reqch)
	go worker.loop()

	const dials = 100
	var wg sync.WaitGroup
	resch := make(chan dialResponse, dials)
	for i := 0; i < dials; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			reschgo := make(chan dialResponse, 1)
			reqch <- dialRequest{ctx: context.Background(), resch: reschgo}
			select {
			case res := <-reschgo:
				resch <- res
			case <-time.After(time.Minute):
				resch <- dialResponse{err: errors.New("timed out!")}
			}
		}()
	}
	wg.Wait()

	for i := 0; i < dials; i++ {
		res := <-resch
		require.NoError(t, res.err)
	}

	t.Log("all concurrent dials done")

	close(reqch)
	worker.wg.Wait()
}

func TestDialWorkerLoopFailure(t *testing.T) {
	s1 := makeSwarm(t)
	defer s1.Close()

	p2 := tnet.RandPeerNetParamsOrFatal(t)

	s1.Peerstore().AddAddrs(p2.ID, []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	resch := make(chan dialResponse)
	worker := newDialWorker(s1, p2.ID, reqch)
	go worker.loop()

	reqch <- dialRequest{ctx: context.Background(), resch: resch}
	select {
	case res := <-resch:
		require.Error(t, res.err)
	case <-time.After(time.Minute):
		t.Fatal("dial didn't complete")
	}

	close(reqch)
	worker.wg.Wait()
}

func TestDialWorkerLoopConcurrentFailure(t *testing.T) {
	s1 := makeSwarm(t)
	defer s1.Close()

	p2 := tnet.RandPeerNetParamsOrFatal(t)

	s1.Peerstore().AddAddrs(p2.ID, []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	worker := newDialWorker(s1, p2.ID, reqch)
	go worker.loop()

	const dials = 100
	var errTimeout = errors.New("timed out!")
	var wg sync.WaitGroup
	resch := make(chan dialResponse, dials)
	for i := 0; i < dials; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			reschgo := make(chan dialResponse, 1)
			reqch <- dialRequest{ctx: context.Background(), resch: reschgo}

			select {
			case res := <-reschgo:
				resch <- res
			case <-time.After(time.Minute):
				resch <- dialResponse{err: errTimeout}
			}
		}()
	}
	wg.Wait()

	for i := 0; i < dials; i++ {
		res := <-resch
		require.Error(t, res.err)
		if res.err == errTimeout {
			t.Fatal("dial response timed out")
		}
	}

	t.Log("all concurrent dials done")

	close(reqch)
	worker.wg.Wait()
}

func TestDialWorkerLoopConcurrentMix(t *testing.T) {
	s1 := makeSwarm(t)
	s2 := makeSwarm(t)
	defer s1.Close()
	defer s2.Close()

	s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL)
	s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	worker := newDialWorker(s1, s2.LocalPeer(), reqch)
	go worker.loop()

	const dials = 100
	var wg sync.WaitGroup
	resch := make(chan dialResponse, dials)
	for i := 0; i < dials; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			reschgo := make(chan dialResponse, 1)
			reqch <- dialRequest{ctx: context.Background(), resch: reschgo}
			select {
			case res := <-reschgo:
				resch <- res
			case <-time.After(time.Minute):
				resch <- dialResponse{err: errors.New("timed out!")}
			}
		}()
	}
	wg.Wait()

	for i := 0; i < dials; i++ {
		res := <-resch
		require.NoError(t, res.err)
	}

	t.Log("all concurrent dials done")

	close(reqch)
	worker.wg.Wait()
}

func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) {
	s1 := makeSwarm(t)
	defer s1.Close()

	p2 := tnet.RandPeerNetParamsOrFatal(t)

	var addrs []ma.Multiaddr
	for i := 0; i < 200; i++ {
		addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/11.0.0.%d/tcp/%d", i%256, 1234+i)))
	}
	s1.Peerstore().AddAddrs(p2.ID, addrs, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	worker := newDialWorker(s1, p2.ID, reqch)
	go worker.loop()

	const dials = 100
	var errTimeout = errors.New("timed out!")
	var wg sync.WaitGroup
	resch := make(chan dialResponse, dials)
	for i := 0; i < dials; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			reschgo := make(chan dialResponse, 1)
			reqch <- dialRequest{ctx: context.Background(), resch: reschgo}
			select {
			case res := <-reschgo:
				resch <- res
			case <-time.After(5 * time.Minute):
				resch <- dialResponse{err: errTimeout}
			}
		}()
	}
	wg.Wait()

	for i := 0; i < dials; i++ {
		res := <-resch
		require.Error(t, res.err)
		if res.err == errTimeout {
			t.Fatal("dial response timed out")
		}
	}

	t.Log("all concurrent dials done")

	close(reqch)
	worker.wg.Wait()
}
@@ -220,7 +220,7 @@ func (dl *dialLimiter) executeDial(j *dialJob) {
 	select {
 	case j.resp <- dialResult{Conn: con, Addr: j.addr, Err: err}:
 	case <-j.ctx.Done():
-		if err == nil {
+		if con != nil {
 			con.Close()
 		}
 	}
@@ -328,7 +328,7 @@ func TestStressLimiter(t *testing.T) {
 	for i := 0; i < 20; i++ {
 		select {
 		case <-success:
-		case <-time.After(time.Second * 5):
+		case <-time.After(time.Minute):
 			t.Fatal("expected a success within five seconds")
 		}
 	}
@@ -278,281 +278,10 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) {
	return nil, err
}

// /////////////////////////////////////////////////////////////////////////////////
// lo and behold, The Dialer
// TODO explain how all this works
// ////////////////////////////////////////////////////////////////////////////////

type dialRequest struct {
	ctx   context.Context
	resch chan dialResponse
}

type dialResponse struct {
	conn *Conn
	err  error
}

// dialWorkerLoop synchronizes and executes concurrent dials to a single peer
func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) {
	defer s.limiter.clearAllPeerDials(p)

	type pendRequest struct {
		req   dialRequest               // the original request
		err   *DialError                // dial error accumulator
		addrs map[ma.Multiaddr]struct{} // pending addr dials
	}

	type addrDial struct {
		addr     ma.Multiaddr
		ctx      context.Context
		conn     *Conn
		err      error
		requests []int
		dialed   bool
	}

	reqno := 0
	requests := make(map[int]*pendRequest)
	pending := make(map[ma.Multiaddr]*addrDial)

	dispatchError := func(ad *addrDial, err error) {
		ad.err = err
		for _, reqno := range ad.requests {
			pr, ok := requests[reqno]
			if !ok {
				// has already been dispatched
				continue
			}

			// accumulate the error
			pr.err.recordErr(ad.addr, err)

			delete(pr.addrs, ad.addr)
			if len(pr.addrs) == 0 {
				// all addrs have erred, dispatch dial error
				// but first do one last check in case an acceptable connection has landed from
				// a simultaneous dial that started later and added new acceptable addrs
				c := s.bestAcceptableConnToPeer(pr.req.ctx, p)
				if c != nil {
					pr.req.resch <- dialResponse{conn: c}
				} else {
					pr.req.resch <- dialResponse{err: pr.err}
				}
				delete(requests, reqno)
			}
		}

		ad.requests = nil

		// if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests.
		// this is necessary to support active listen scenarios, where a new dial comes in while
		// another dial is in progress, and needs to do a direct connection without inhibitions from
		// dial backoff.
		// it is also necessary to preserve consistent behaviour with the old dialer -- TestDialBackoff
		// regresses without this.
		if err == ErrDialBackoff {
			delete(pending, ad.addr)
		}
	}

	var triggerDial <-chan struct{}
	triggerNow := make(chan struct{})
	close(triggerNow)

	var nextDial []ma.Multiaddr
	active := 0
	done := false      // true when the request channel has been closed
	connected := false // true when a connection has been successfully established

	resch := make(chan dialResult)

loop:
	for {
		select {
		case req, ok := <-reqch:
			if !ok {
				// request channel has been closed, wait for pending dials to complete
				if active > 0 {
					done = true
					reqch = nil
					triggerDial = nil
					continue loop
				}

				// no active dials, we are done
				return
			}

			c := s.bestAcceptableConnToPeer(req.ctx, p)
			if c != nil {
				req.resch <- dialResponse{conn: c}
				continue loop
			}

			addrs, err := s.addrsForDial(req.ctx, p)
			if err != nil {
				req.resch <- dialResponse{err: err}
				continue loop
			}

			// at this point, len(addrs) > 0 or else addrsForDial would have returned an error;
			// rank them to process in order
			addrs = s.rankAddrs(addrs)

			// create the pending request object
			pr := &pendRequest{
				req:   req,
				err:   &DialError{Peer: p},
				addrs: make(map[ma.Multiaddr]struct{}),
			}
			for _, a := range addrs {
				pr.addrs[a] = struct{}{}
			}

			// check if any of the addrs has been successfully dialed and accumulate
			// errors from complete dials while collecting new addrs to dial/join
			var todial []ma.Multiaddr
			var tojoin []*addrDial

			for _, a := range addrs {
				ad, ok := pending[a]
				if !ok {
					todial = append(todial, a)
					continue
				}

				if ad.conn != nil {
					// dial to this addr was successful, complete the request
					req.resch <- dialResponse{conn: ad.conn}
					continue loop
				}

				if ad.err != nil {
					// dial to this addr errored, accumulate the error
					pr.err.recordErr(a, ad.err)
					delete(pr.addrs, a)
					continue
				}

				// dial is still pending, add to the join list
				tojoin = append(tojoin, ad)
			}

			if len(todial) == 0 && len(tojoin) == 0 {
				// all addrs relevant to this request have been dialed, so we must have errored
				req.resch <- dialResponse{err: pr.err}
				continue loop
			}

			// the request has some pending or new dials, track it and schedule new dials
			reqno++
			requests[reqno] = pr

			for _, ad := range tojoin {
				if !ad.dialed {
					if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect {
						if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect {
							ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason)
						}
					}
				}
				ad.requests = append(ad.requests, reqno)
			}

			if len(todial) > 0 {
				for _, a := range todial {
					pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{reqno}}
				}

				nextDial = append(nextDial, todial...)
				nextDial = s.rankAddrs(nextDial)

				// trigger a new dial now to account for the new addrs we added
				triggerDial = triggerNow
			}

		case <-triggerDial:
			for _, addr := range nextDial {
				// spawn the dial
				ad := pending[addr]
				err := s.dialNextAddr(ad.ctx, p, addr, resch)
				if err != nil {
					dispatchError(ad, err)
				} else {
					active++
				}
			}

			nextDial = nil
			triggerDial = nil

		case res := <-resch:
			active--

			if res.Conn != nil {
				connected = true
			}

			if done && active == 0 {
				if res.Conn != nil {
					// we got an actual connection, but the dial has been cancelled
					// Should we close it? I think not, we should just add it to the swarm
					_, err := s.addConn(res.Conn, network.DirOutbound)
					if err != nil {
						// well duh, now we have to close it
						res.Conn.Close()
					}
				}
				return
			}

			ad := pending[res.Addr]

			if res.Conn != nil {
				// we got a connection, add it to the swarm
				conn, err := s.addConn(res.Conn, network.DirOutbound)
				if err != nil {
					// oops no, we failed to add it to the swarm
					res.Conn.Close()
					dispatchError(ad, err)
					if active == 0 && len(nextDial) > 0 {
						triggerDial = triggerNow
					}
					continue loop
				}

				// dispatch to still pending requests
				for _, reqno := range ad.requests {
					pr, ok := requests[reqno]
					if !ok {
						// it has already dispatched a connection
						continue
					}

					pr.req.resch <- dialResponse{conn: conn}
					delete(requests, reqno)
				}

				ad.conn = conn
				ad.requests = nil

				continue loop
			}

			// it must be an error -- add backoff if applicable and dispatch
			if res.Err != context.Canceled && !connected {
				// we only add backoff if there has not been a successful connection
				// for consistency with the old dialer behavior.
				s.backf.AddBackoff(p, res.Addr)
			}

			dispatchError(ad, res.Err)
			if active == 0 && len(nextDial) > 0 {
				triggerDial = triggerNow
			}
		}
	}
	w := newDialWorker(s, p, reqch)
	w.loop()
}

func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) {
@@ -597,43 +326,6 @@ func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool {
	return !t.Proxy()
}

// ranks addresses in descending order of preference for dialing, with the following rules:
// NonRelay > Relay
// NonWS > WS
// Private > Public
// UDP > TCP
func (s *Swarm) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
	addrTier := func(a ma.Multiaddr) (tier int) {
		if isRelayAddr(a) {
			tier |= 0b1000
		}
		if isExpensiveAddr(a) {
			tier |= 0b0100
		}
		if !manet.IsPrivateAddr(a) {
			tier |= 0b0010
		}
		if isFdConsumingAddr(a) {
			tier |= 0b0001
		}

		return tier
	}

	tiers := make([][]ma.Multiaddr, 16)
	for _, a := range addrs {
		tier := addrTier(a)
		tiers[tier] = append(tiers[tier], a)
	}

	result := make([]ma.Multiaddr, 0, len(addrs))
	for _, tier := range tiers {
		result = append(result, tier...)
	}

	return result
}

// filterKnownUndialables takes a list of multiaddrs, and removes those
// that we definitely don't want to dial: addresses configured to be blocked,
// IPv6 link-local addresses, addresses without a dial-capable transport,