status-go/vendor/github.com/libp2p/go-libp2p-swarm/dial_sync.go

110 lines
2.5 KiB
Go
Raw Normal View History

package swarm
import (
"context"
"sync"
2021-10-19 13:43:41 +00:00
"github.com/libp2p/go-libp2p-core/network"
2019-06-09 07:24:20 +00:00
"github.com/libp2p/go-libp2p-core/peer"
)
2022-04-01 16:16:46 +00:00
// dialWorkerFunc is the function dialSync runs in a new goroutine the first
// time a dial to a peer is requested. It is handed the peer being dialed and
// the channel on which dial requests for that peer arrive; dialSync closes
// that channel once no caller is waiting on the dial any more.
type dialWorkerFunc func(p peer.ID, reqch <-chan dialRequest)
2022-04-01 16:16:46 +00:00
// newDialSync constructs a new dialSync
func newDialSync(worker dialWorkerFunc) *dialSync {
return &dialSync{
2021-10-19 13:43:41 +00:00
dials: make(map[peer.ID]*activeDial),
dialWorker: worker,
}
}
2022-04-01 16:16:46 +00:00
// dialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
2022-04-01 16:16:46 +00:00
type dialSync struct {
mutex sync.Mutex
2021-10-19 13:43:41 +00:00
dials map[peer.ID]*activeDial
dialWorker dialWorkerFunc
}
type activeDial struct {
2021-10-19 13:43:41 +00:00
refCnt int
ctx context.Context
cancel func()
2021-10-19 13:43:41 +00:00
reqch chan dialRequest
}
2022-04-01 16:16:46 +00:00
// close cancels the dial's context and then closes the request channel,
// telling the dial worker that no further requests will be sent. It must be
// called at most once: closing an already-closed channel panics.
func (ad *activeDial) close() {
	cancelDial, requests := ad.cancel, ad.reqch
	cancelDial()
	close(requests)
}
2022-04-01 16:16:46 +00:00
func (ad *activeDial) dial(ctx context.Context) (*Conn, error) {
2021-10-19 13:43:41 +00:00
dialCtx := ad.ctx
2021-10-19 13:43:41 +00:00
if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
dialCtx = network.WithForceDirectDial(dialCtx, reason)
}
2022-04-01 16:16:46 +00:00
if simConnect, isClient, reason := network.GetSimultaneousConnect(ctx); simConnect {
dialCtx = network.WithSimultaneousConnect(dialCtx, isClient, reason)
2021-10-19 13:43:41 +00:00
}
resch := make(chan dialResponse, 1)
select {
case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}:
case <-ctx.Done():
return nil, ctx.Err()
}
2021-10-19 13:43:41 +00:00
select {
case res := <-resch:
return res.conn, res.err
case <-ctx.Done():
return nil, ctx.Err()
}
}
2022-04-01 16:16:46 +00:00
func (ds *dialSync) getActiveDial(p peer.ID) (*activeDial, error) {
ds.mutex.Lock()
defer ds.mutex.Unlock()
actd, ok := ds.dials[p]
if !ok {
2021-10-19 13:43:41 +00:00
// This code intentionally uses the background context. Otherwise, if the first call
// to Dial is canceled, subsequent dial calls will also be canceled.
2022-04-01 16:16:46 +00:00
ctx, cancel := context.WithCancel(context.Background())
actd = &activeDial{
2022-04-01 16:16:46 +00:00
ctx: ctx,
cancel: cancel,
2021-10-19 13:43:41 +00:00
reqch: make(chan dialRequest),
}
2022-04-01 16:16:46 +00:00
go ds.dialWorker(p, actd.reqch)
2021-10-19 13:43:41 +00:00
ds.dials[p] = actd
}
2022-04-01 16:16:46 +00:00
// increase ref count before dropping mutex
2021-10-19 13:43:41 +00:00
actd.refCnt++
return actd, nil
}
2022-04-01 16:16:46 +00:00
// Dial initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
2022-04-01 16:16:46 +00:00
func (ds *dialSync) Dial(ctx context.Context, p peer.ID) (*Conn, error) {
2021-10-19 13:43:41 +00:00
ad, err := ds.getActiveDial(p)
if err != nil {
return nil, err
}
2021-10-19 13:43:41 +00:00
2022-04-01 16:16:46 +00:00
defer func() {
ds.mutex.Lock()
defer ds.mutex.Unlock()
ad.refCnt--
if ad.refCnt == 0 {
ad.close()
delete(ds.dials, p)
}
}()
return ad.dial(ctx)
}