diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 6127d340..0725fa70 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -132,7 +132,7 @@ func NewHost(ctx context.Context, net inet.Network, opts *HostOpts) (*BasicHost, h.ids = opts.IdentifyService } else { // we can't set this as a default above because it depends on the *BasicHost. - h.ids = identify.NewIDService(h) + h.ids = identify.NewIDService(ctx, h) } if uint64(opts.NegotiationTimeout) != 0 { diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 01fc7a7b..f2e118b4 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -53,15 +53,16 @@ type IDService struct { // our own observed addresses. // TODO: instead of expiring, remove these when we disconnect - observedAddrs ObservedAddrSet + observedAddrs *ObservedAddrSet } // NewIDService constructs a new *IDService and activates it by // attaching its stream handler to the given host.Host. -func NewIDService(h host.Host) *IDService { +func NewIDService(ctx context.Context, h host.Host) *IDService { s := &IDService{ - Host: h, - currid: make(map[inet.Conn]chan struct{}), + Host: h, + currid: make(map[inet.Conn]chan struct{}), + observedAddrs: NewObservedAddrSet(ctx), } h.SetStreamHandler(ID, s.requestHandler) h.SetStreamHandler(IDPush, s.pushHandler) diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index 621ce1aa..fd56f28c 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -26,8 +26,8 @@ func subtestIDService(t *testing.T) { h1p := h1.ID() h2p := h2.ID() - ids1 := identify.NewIDService(h1) - ids2 := identify.NewIDService(h2) + ids1 := identify.NewIDService(ctx, h1) + ids2 := identify.NewIDService(ctx, h2) testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) // nothing testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) // nothing diff --git a/p2p/protocol/identify/obsaddr.go b/p2p/protocol/identify/obsaddr.go index af153edc..b568ba50 100644 --- a/p2p/protocol/identify/obsaddr.go +++ b/p2p/protocol/identify/obsaddr.go @@ -1,6 +1,7 @@ package identify import ( + "context" "sync" "time" @@ -11,6 +12,8 @@ import ( const ActivationThresh = 4 +var GCInterval = 10 * time.Minute + type observation struct { seenTime time.Time connDirection net.Direction @@ -42,23 +45,40 @@ func (oa *ObservedAddr) activated(ttl time.Duration) bool { return len(oa.SeenBy) >= ActivationThresh } +type newObservation struct { + observed, local, observer ma.Multiaddr + direction net.Direction +} + // ObservedAddrSet keeps track of a set of ObservedAddrs // the zero-value is ready to be used. type ObservedAddrSet struct { - sync.Mutex // guards whole datastruct. + sync.RWMutex // guards whole datastruct. // local(internal) address -> list of observed(external) addresses addrs map[string][]*ObservedAddr ttl time.Duration + + // this is the worker channel + wch chan newObservation +} + +func NewObservedAddrSet(ctx context.Context) *ObservedAddrSet { + oas := &ObservedAddrSet{ + addrs: make(map[string][]*ObservedAddr), + ttl: pstore.OwnObservedAddrTTL, + wch: make(chan newObservation, 16), + } + go oas.worker(ctx) + return oas } // AddrsFor return all activated observed addresses associated with the given // (resolved) listen address. func (oas *ObservedAddrSet) AddrsFor(addr ma.Multiaddr) (addrs []ma.Multiaddr) { - oas.Lock() - defer oas.Unlock() + oas.RLock() + defer oas.RUnlock() - // for zero-value. if len(oas.addrs) == 0 { return nil } @@ -70,54 +90,86 @@ func (oas *ObservedAddrSet) AddrsFor(addr ma.Multiaddr) (addrs []ma.Multiaddr) { } now := time.Now() - filteredAddrs := make([]*ObservedAddr, 0, len(observedAddrs)) for _, a := range observedAddrs { - // leave only alive observed addresses - if now.Sub(a.LastSeen) <= oas.ttl { - filteredAddrs = append(filteredAddrs, a) - if a.activated(oas.ttl) { - addrs = append(addrs, a.Addr) - } + if now.Sub(a.LastSeen) <= oas.ttl && a.activated(oas.ttl) { + addrs = append(addrs, a.Addr) } } - if len(filteredAddrs) > 0 { - oas.addrs[key] = filteredAddrs - } else { - delete(oas.addrs, key) - } return addrs } // Addrs return all activated observed addresses func (oas *ObservedAddrSet) Addrs() (addrs []ma.Multiaddr) { - oas.Lock() - defer oas.Unlock() + oas.RLock() + defer oas.RUnlock() - // for zero-value. if len(oas.addrs) == 0 { return nil } now := time.Now() - for local, observedAddrs := range oas.addrs { - filteredAddrs := make([]*ObservedAddr, 0, len(observedAddrs)) + for _, observedAddrs := range oas.addrs { for _, a := range observedAddrs { - // leave only alive observed addresses - if now.Sub(a.LastSeen) <= oas.ttl { - filteredAddrs = append(filteredAddrs, a) - if a.activated(oas.ttl) { - addrs = append(addrs, a.Addr) - } + if now.Sub(a.LastSeen) <= oas.ttl && a.activated(oas.ttl) { + addrs = append(addrs, a.Addr) } } - oas.addrs[local] = filteredAddrs } return addrs } func (oas *ObservedAddrSet) Add(observed, local, observer ma.Multiaddr, direction net.Direction) { + select { + case oas.wch <- newObservation{observed: observed, local: local, observer: observer, direction: direction}: + default: + log.Debugf("dropping address observation of %s; buffer full", observed) + } +} + +func (oas *ObservedAddrSet) worker(ctx context.Context) { + ticker := time.NewTicker(GCInterval) + defer ticker.Stop() + + for { + select { + case obs := <-oas.wch: + oas.doAdd(obs.observed, obs.local, obs.observer, obs.direction) + + case <-ticker.C: + oas.gc() + + case <-ctx.Done(): + return + } + } +} + +func (oas *ObservedAddrSet) gc() { + oas.Lock() + defer oas.Unlock() + + now := time.Now() + for local, observedAddrs := range oas.addrs { + // TODO we can do this without allocating by compacting the array in place + filteredAddrs := make([]*ObservedAddr, 0, len(observedAddrs)) + for _, a := range observedAddrs { + // leave only alive observed addresses + if now.Sub(a.LastSeen) <= oas.ttl { + filteredAddrs = append(filteredAddrs, a) + } + } + if len(filteredAddrs) > 0 { + oas.addrs[local] = filteredAddrs + } else { + delete(oas.addrs, local) + } + } +} + +func (oas *ObservedAddrSet) doAdd(observed, local, observer ma.Multiaddr, + direction net.Direction) { now := time.Now() observerString := observerGroup(observer) @@ -130,12 +182,6 @@ func (oas *ObservedAddrSet) Add(observed, local, observer ma.Multiaddr, oas.Lock() defer oas.Unlock() - // for zero-value. - if oas.addrs == nil { - oas.addrs = make(map[string][]*ObservedAddr) - oas.ttl = pstore.OwnObservedAddrTTL - } - observedAddrs := oas.addrs[localString] // check if observed address seen yet, if so, update it for i, previousObserved := range observedAddrs { @@ -178,11 +224,7 @@ func (oas *ObservedAddrSet) SetTTL(ttl time.Duration) { } func (oas *ObservedAddrSet) TTL() time.Duration { - oas.Lock() - defer oas.Unlock() - // for zero-value. - if oas.addrs == nil { - oas.ttl = pstore.OwnObservedAddrTTL - } + oas.RLock() + defer oas.RUnlock() return oas.ttl } diff --git a/p2p/protocol/identify/obsaddr_test.go b/p2p/protocol/identify/obsaddr_test.go index 12bdb54a..440b009b 100644 --- a/p2p/protocol/identify/obsaddr_test.go +++ b/p2p/protocol/identify/obsaddr_test.go @@ -1,6 +1,7 @@ package identify import ( + "context" "sync" "testing" "time" @@ -52,7 +53,9 @@ func TestObsAddrSet(t *testing.T) { b4 := m("/ip4/1.2.3.9/tcp/1237") b5 := m("/ip4/1.2.3.10/tcp/1237") - oas := &ObservedAddrSet{} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + oas := NewObservedAddrSet(ctx) if !addrsMarch(oas.Addrs(), nil) { t.Error("addrs should be empty") @@ -63,6 +66,7 @@ func TestObsAddrSet(t *testing.T) { dummyDirection := net.DirOutbound oas.Add(observed, dummyLocal, observer, dummyDirection) + time.Sleep(1 * time.Millisecond) // let the worker run } add(oas, a1, a4) @@ -131,13 +135,17 @@ func TestAddAddrsProfile(b *testing.T) { } return m } - oas := &ObservedAddrSet{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + oas := NewObservedAddrSet(ctx) add := func(oas *ObservedAddrSet, observed, observer ma.Multiaddr) { dummyLocal := m("/ip4/127.0.0.1/tcp/10086") dummyDirection := net.DirOutbound oas.Add(observed, dummyLocal, observer, dummyDirection) + time.Sleep(1 * time.Millisecond) // let the worker run } a1 := m("/ip4/1.2.3.4/tcp/1231")