simplify the DialSync code

It's easier to reason about this code if activeDial doesn't contain a pointer
back to DialSync (which already has a map of activeDials). It also allows us to
remove the memory footprint of the activeDial struct, so this should be
(slightly) more efficient.
This commit is contained in:
Marten Seemann 2021-08-23 17:38:57 +01:00
parent bf044ffcb7
commit 0e0111c6f5
1 changed files with 20 additions and 25 deletions

View File

@ -22,35 +22,26 @@ func newDialSync(worker dialWorkerFunc) *DialSync {
// DialSync is a dial synchronization helper that ensures that at most one dial // DialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time. // to any given peer is active at any given time.
type DialSync struct { type DialSync struct {
mutex sync.Mutex
dials map[peer.ID]*activeDial dials map[peer.ID]*activeDial
dialsLk sync.Mutex
dialWorker dialWorkerFunc dialWorker dialWorkerFunc
} }
type activeDial struct { type activeDial struct {
id peer.ID
refCnt int refCnt int
ctx context.Context ctx context.Context
cancel func() cancel func()
reqch chan dialRequest reqch chan dialRequest
ds *DialSync
} }
func (ad *activeDial) decref() { func (ad *activeDial) close() {
ad.ds.dialsLk.Lock()
ad.refCnt--
if ad.refCnt == 0 {
ad.cancel() ad.cancel()
close(ad.reqch) close(ad.reqch)
delete(ad.ds.dials, ad.id)
}
ad.ds.dialsLk.Unlock()
} }
func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) { func (ad *activeDial) dial(ctx context.Context) (*Conn, error) {
dialCtx := ad.ctx dialCtx := ad.ctx
if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect { if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
@ -76,8 +67,8 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
} }
func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) { func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) {
ds.dialsLk.Lock() ds.mutex.Lock()
defer ds.dialsLk.Unlock() defer ds.mutex.Unlock()
actd, ok := ds.dials[p] actd, ok := ds.dials[p]
if !ok { if !ok {
@ -85,21 +76,17 @@ func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) {
// to Dial is canceled, subsequent dial calls will also be canceled. // to Dial is canceled, subsequent dial calls will also be canceled.
// XXX: this also breaks direct connection logic. We will need to pipe the // XXX: this also breaks direct connection logic. We will need to pipe the
// information through some other way. // information through some other way.
adctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
actd = &activeDial{ actd = &activeDial{
id: p, ctx: ctx,
ctx: adctx,
cancel: cancel, cancel: cancel,
reqch: make(chan dialRequest), reqch: make(chan dialRequest),
ds: ds,
} }
go ds.dialWorker(p, actd.reqch) go ds.dialWorker(p, actd.reqch)
ds.dials[p] = actd ds.dials[p] = actd
} }
// increase ref count before dropping mutex
// increase ref count before dropping dialsLk
actd.refCnt++ actd.refCnt++
return actd, nil return actd, nil
} }
@ -111,6 +98,14 @@ func (ds *DialSync) Dial(ctx context.Context, p peer.ID) (*Conn, error) {
return nil, err return nil, err
} }
defer ad.decref() defer func() {
return ad.dial(ctx, p) ds.mutex.Lock()
defer ds.mutex.Unlock()
ad.refCnt--
if ad.refCnt == 0 {
ad.close()
delete(ds.dials, p)
}
}()
return ad.dial(ctx)
} }