Merge pull request #1175 from libp2p/id-service-shutdown

don't use a context for closing the ObservedAddrManager
Marten Seemann, 2021-09-07 14:30:40 +01:00, committed by GitHub
commit 0797df7cbc
3 changed files with 46 additions and 42 deletions


@@ -87,8 +87,6 @@ type IDService struct {
 	ctx       context.Context
 	ctxCancel context.CancelFunc
 
-	// ensure we shutdown ONLY once
-	closeSync sync.Once
 
 	// track resources that need to be shut down before we shut down
 	refCount sync.WaitGroup
@@ -126,25 +124,23 @@ func NewIDService(h host.Host, opts ...Option) (*IDService, error) {
 		userAgent = cfg.userAgent
 	}
 
-	hostCtx, cancel := context.WithCancel(context.Background())
 	s := &IDService{
 		Host:      h,
 		UserAgent: userAgent,
 
-		ctx:       hostCtx,
-		ctxCancel: cancel,
-		conns:     make(map[network.Conn]chan struct{}),
+		conns: make(map[network.Conn]chan struct{}),
 
 		disableSignedPeerRecord: cfg.disableSignedPeerRecord,
 
 		addPeerHandlerCh: make(chan addPeerHandlerReq),
 		rmPeerHandlerCh:  make(chan rmPeerHandlerReq),
 	}
+	s.ctx, s.ctxCancel = context.WithCancel(context.Background())
 
 	// handle local protocol handler updates, and push deltas to peers.
 	var err error
 
-	observedAddrs, err := NewObservedAddrManager(hostCtx, h)
+	observedAddrs, err := NewObservedAddrManager(h)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create observed address manager: %s", err)
 	}
@@ -276,10 +272,8 @@ func (ids *IDService) loop() {
 }
 
 // Close shuts down the IDService
 func (ids *IDService) Close() error {
-	ids.closeSync.Do(func() {
-		ids.ctxCancel()
-		ids.refCount.Wait()
-	})
+	ids.ctxCancel()
+	ids.refCount.Wait()
 	return nil
 }
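The net effect of these id.go hunks: IDService stops borrowing a caller-supplied context and owns its shutdown end to end, the same shape the ObservedAddrManager adopts below. A minimal sketch of the idiom, using hypothetical names (service, newService, loop) rather than the real libp2p types:

package lifecycle

import (
	"context"
	"sync"
)

// service owns its lifecycle instead of borrowing a caller's context.
type service struct {
	ctx       context.Context
	ctxCancel context.CancelFunc
	refCount  sync.WaitGroup // counts background goroutines
}

func newService() *service {
	s := &service{}
	s.ctx, s.ctxCancel = context.WithCancel(context.Background())
	s.refCount.Add(1) // register the worker before spawning it
	go s.loop()
	return s
}

func (s *service) loop() {
	defer s.refCount.Done()
	<-s.ctx.Done() // a real loop would also select over work channels
}

// Close is the only way to stop the service; it returns only after
// the background goroutine has actually exited.
func (s *service) Close() error {
	s.ctxCancel()
	s.refCount.Wait()
	return nil
}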


@ -98,13 +98,19 @@ type newObservation struct {
type ObservedAddrManager struct {
host host.Host
closeOnce sync.Once
refCount sync.WaitGroup
ctx context.Context // the context is canceled when Close is called
ctxCancel context.CancelFunc
// latest observation from active connections
// we'll "re-observe" these when we gc
activeConnsMu sync.Mutex
// active connection -> most recent observation
activeConns map[network.Conn]ma.Multiaddr
mu sync.RWMutex
mu sync.RWMutex
closed bool
// local(internal) address -> list of observed(external) addresses
addrs map[string][]*observedAddr
ttl time.Duration
@@ -123,7 +129,7 @@ type ObservedAddrManager struct {
 
 // NewObservedAddrManager returns a new address manager using
 // peerstore.OwnObservedAddressTTL as the TTL.
-func NewObservedAddrManager(ctx context.Context, host host.Host) (*ObservedAddrManager, error) {
+func NewObservedAddrManager(host host.Host) (*ObservedAddrManager, error) {
 	oas := &ObservedAddrManager{
 		addrs: make(map[string][]*observedAddr),
 		ttl:   peerstore.OwnObservedAddrTTL,
@@ -133,6 +139,7 @@ func NewObservedAddrManager(ctx context.Context, host host.Host) (*ObservedAddrManager, error) {
 		// refresh every ttl/2 so we don't forget observations from connected peers
 		refreshTimer: time.NewTimer(peerstore.OwnObservedAddrTTL / 2),
 	}
+	oas.ctx, oas.ctxCancel = context.WithCancel(context.Background())
 
 	reachabilitySub, err := host.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged))
 	if err != nil {
@@ -147,7 +154,8 @@ func NewObservedAddrManager(ctx context.Context, host host.Host) (*ObservedAddrManager, error) {
 	oas.emitNATDeviceTypeChanged = emitter
 
 	oas.host.Network().Notify((*obsAddrNotifiee)(oas))
-	go oas.worker(ctx)
+	oas.refCount.Add(1)
+	go oas.worker()
 	return oas, nil
 }
 
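Note that oas.refCount.Add(1) runs before go oas.worker(), not inside the worker itself. If the Add happened in the goroutine, a Close racing with construction could observe the WaitGroup at zero and return while the worker was still starting. For contrast, a sketch of that broken variant in the hypothetical service type from above (newServiceRacy is illustrative only):

// WRONG: the worker registers itself, so a concurrent Close can slip
// past refCount.Wait() before the goroutine has run at all.
func newServiceRacy() *service {
	s := &service{}
	s.ctx, s.ctxCancel = context.WithCancel(context.Background())
	go func() {
		s.refCount.Add(1) // too late: Close may already have Wait()ed
		defer s.refCount.Done()
		<-s.ctx.Done()
	}()
	return s
}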
@@ -239,22 +247,12 @@ func (oas *ObservedAddrManager) Record(conn network.Conn, observed ma.Multiaddr)
 	}
 }
 
-func (oas *ObservedAddrManager) teardown() {
-	oas.host.Network().StopNotify((*obsAddrNotifiee)(oas))
-	oas.reachabilitySub.Close()
-
-	oas.mu.Lock()
-	oas.refreshTimer.Stop()
-	oas.mu.Unlock()
-}
-
-func (oas *ObservedAddrManager) worker(ctx context.Context) {
-	defer oas.teardown()
+func (oas *ObservedAddrManager) worker() {
+	defer oas.refCount.Done()
 
 	ticker := time.NewTicker(GCInterval)
 	defer ticker.Stop()
 
-	hostClosing := oas.host.Network().Process().Closing()
 	subChan := oas.reachabilitySub.Out()
 	for {
 		select {
@@ -265,17 +263,13 @@ func (oas *ObservedAddrManager) worker(ctx context.Context) {
 			}
 			ev := evt.(event.EvtLocalReachabilityChanged)
 			oas.reachability = ev.Reachability
-
 		case obs := <-oas.wch:
 			oas.maybeRecordObservation(obs.conn, obs.observed)
-
 		case <-ticker.C:
 			oas.gc()
 		case <-oas.refreshTimer.C:
 			oas.refresh()
-		case <-hostClosing:
-			return
-		case <-ctx.Done():
+		case <-oas.ctx.Done():
 			return
 		}
 	}
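With teardown gone and hostClosing removed, the worker has exactly one exit path: the context that Close cancels. Previously it could also exit via the host's goprocess, which made shutdown ordering implicit. Extending the earlier sketch (time import assumed; the ticker stands in for GCInterval and the refresh timer):

func (s *service) worker() {
	defer s.refCount.Done()

	ticker := time.NewTicker(time.Minute) // stand-in for GCInterval
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// periodic work: gc stale observations, refresh, etc.
		case <-s.ctx.Done():
			// Close() canceled the context; returning here unblocks
			// the refCount.Wait() inside Close.
			return
		}
	}
}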
@@ -534,6 +528,22 @@ func (oas *ObservedAddrManager) emitSpecificNATType(addrs []*observedAddr, proto
 	return false, 0
 }
 
+func (oas *ObservedAddrManager) Close() error {
+	oas.closeOnce.Do(func() {
+		oas.ctxCancel()
+
+		oas.mu.Lock()
+		oas.closed = true
+		oas.refreshTimer.Stop()
+		oas.mu.Unlock()
+
+		oas.refCount.Wait()
+		oas.reachabilitySub.Close()
+		oas.host.Network().StopNotify((*obsAddrNotifiee)(oas))
+	})
+	return nil
+}
+
 // observerGroup is a function that determines what part of
 // a multiaddr counts as a different observer. for example,
 // two ipfs nodes at the same IP/TCP transport would get
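The ordering inside Close is deliberate: cancel the context first, mark the manager closed and stop the refresh timer under the lock, wait for the worker to exit, and only then close the event-bus subscription and stop network notifications, so the worker is guaranteed to be gone before its input sources are torn down. Callers now shut the manager down explicitly; a hedged usage sketch (runWithObservedAddrs is illustrative, and the import paths are assumed from this era of go-libp2p):

import (
	"github.com/libp2p/go-libp2p-core/host"
	"github.com/libp2p/go-libp2p/p2p/protocol/identify"
)

func runWithObservedAddrs(h host.Host) error {
	oas, err := identify.NewObservedAddrManager(h)
	if err != nil {
		return err
	}
	defer oas.Close() // closeOnce makes repeated Close calls safe

	// ... read oas.Addrs(), adjust oas.SetTTL(...), etc.
	return nil
}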
@@ -554,6 +564,9 @@ func observerGroup(m ma.Multiaddr) string {
 func (oas *ObservedAddrManager) SetTTL(ttl time.Duration) {
 	oas.mu.Lock()
 	defer oas.mu.Unlock()
+	if oas.closed {
+		return
+	}
 	oas.ttl = ttl
 	// refresh every ttl/2 so we don't forget observations from connected peers
 	oas.refreshTimer.Reset(ttl / 2)
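Without the closed check, SetTTL could race with Close: Close stops the refresh timer, and a later Reset would re-arm it even though the worker that drains refreshTimer.C has already exited. Guarding on the same mutex that Close takes makes a post-shutdown SetTTL a harmless no-op. Restated in the sketch's vocabulary (the mu, closed, ttl, and refreshTimer fields mirror the real struct and are assumed on the hypothetical service):

func (s *service) SetTTL(ttl time.Duration) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return // Close already stopped refreshTimer; don't re-arm it
	}
	s.ttl = ttl
	// refresh every ttl/2 so we don't forget observations from peers
	s.refreshTimer.Reset(ttl / 2)
}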


@@ -85,18 +85,11 @@ func (h *harness) observeInbound(observed ma.Multiaddr, observer peer.ID) network.Conn {
 func newHarness(ctx context.Context, t *testing.T) harness {
 	mn := mocknet.New(ctx)
 	sk, err := p2putil.RandTestBogusPrivateKey()
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	h, err := mn.AddPeer(sk, ma.StringCast("/ip4/127.0.0.1/tcp/10086"))
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	oas, err := identify.NewObservedAddrManager(ctx, h)
+	require.NoError(t, err)
+	h, err := mn.AddPeer(sk, ma.StringCast("/ip4/127.0.0.1/tcp/10086"))
+	require.NoError(t, err)
+	oas, err := identify.NewObservedAddrManager(h)
 	require.NoError(t, err)
-
 	return harness{
 		oas:     oas,
 		mocknet: mn,
@@ -142,6 +135,7 @@ func TestObsAddrSet(t *testing.T) {
 	defer cancel()
 
 	harness := newHarness(ctx, t)
+	defer harness.oas.Close()
 
 	if !addrsMatch(harness.oas.Addrs(), nil) {
 		t.Error("addrs should be empty")
@@ -243,6 +237,7 @@ func TestObservedAddrFiltering(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	harness := newHarness(ctx, t)
+	defer harness.oas.Close()
 	require.Empty(t, harness.oas.Addrs())
 
 	// IP4/TCP
@@ -344,6 +339,7 @@ func TestEmitNATDeviceTypeSymmetric(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	harness := newHarness(ctx, t)
+	defer harness.oas.Close()
 	require.Empty(t, harness.oas.Addrs())
 	emitter, err := harness.host.EventBus().Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful)
 	require.NoError(t, err)
@@ -390,6 +386,7 @@ func TestEmitNATDeviceTypeCone(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	harness := newHarness(ctx, t)
+	defer harness.oas.Close()
 	require.Empty(t, harness.oas.Addrs())
 	emitter, err := harness.host.EventBus().Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful)
 	require.NoError(t, err)
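Because the constructor no longer takes a context, canceling the test context no longer tears the manager down; every test therefore pairs newHarness with an explicit defer harness.oas.Close(). A hedged template for further tests in this style (TestSomething is a placeholder name):

func TestSomething(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	harness := newHarness(ctx, t)
	defer harness.oas.Close() // explicit now; canceling ctx is not enough

	require.Empty(t, harness.oas.Addrs())
	// ... drive observations and assert on harness.oas.Addrs()
}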