233 lines
5.4 KiB
Go
Raw Normal View History

2022-03-10 10:44:48 +01:00
package dnscache
import (
"context"
"net"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
// DNSResolver is the set of lookups a Resolver delegates to. *net.Resolver
// satisfies it; net.DefaultResolver is used when Resolver.Resolver is nil.
//
// The ctx passed to these methods is derived from context.Background (bounded
// by Resolver.Timeout when that is positive), not from the context of the
// caller of Resolver.LookupHost/LookupAddr.
type DNSResolver interface {
// LookupHost returns the addresses of host.
LookupHost(ctx context.Context, host string) (addrs []string, err error)
// LookupAddr performs a reverse lookup and returns the names mapping to addr.
LookupAddr(ctx context.Context, addr string) (names []string, err error)
}
// Resolver is a caching DNS resolver. Results of LookupHost and LookupAddr,
// including lookup errors, are kept in an in-memory cache and served from it
// on subsequent calls. Refresh and RefreshWithOptions re-resolve entries that
// have been used since the previous refresh.
//
// The zero value is ready to use: the cache is allocated lazily on first use.
// A Resolver contains a sync.Once and a sync.RWMutex and must not be copied
// after first use.
type Resolver struct {
// Timeout is the maximum duration of a single underlying DNS lookup. If it
// is zero or negative, the lookup itself is not bounded by a timeout
// (callers' contexts still bound how long each caller waits).
Timeout time.Duration
// Resolver is used to perform the actual DNS lookups. If nil,
// net.DefaultResolver is used instead.
Resolver DNSResolver
// once guards the lazy allocation of cache.
once sync.Once
// mu protects cache and the fields of the entries it holds.
mu sync.RWMutex
// cache maps a lookup key ("h"+host or "r"+addr) to its cached result.
cache map[string]*cacheEntry
// OnCacheMiss, if non-nil, is called when a lookup key is not in the cache,
// just before the DNS lookup is performed.
OnCacheMiss func()
}
// ResolverRefreshOptions configures Resolver.RefreshWithOptions.
type ResolverRefreshOptions struct {
// ClearUnused removes cache entries that have not been read since the
// previous refresh.
ClearUnused bool
// PersistOnFailure keeps an existing cache entry when its refresh lookup
// fails, instead of replacing it with the lookup error.
PersistOnFailure bool
}
// cacheEntry is one cached lookup result. All fields are guarded by
// Resolver.mu.
type cacheEntry struct {
// rrs holds the addresses (host lookup) or names (reverse lookup) from the
// last stored lookup.
rrs []string
// err is the error from the last stored lookup. It is cached and returned
// to callers just like a successful result.
err error
// used is set to true whenever the entry is read, and set to the stored
// value on every store (true for cache-miss lookups, false for refreshes).
// Refreshes only re-resolve entries whose used flag is true.
used bool
}
// LookupAddr performs a reverse lookup for addr and returns the names that map
// to it. Results, including lookup errors, are cached per address and served
// from the cache on later calls.
func (r *Resolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) {
	r.once.Do(r.init)
	key := "r" + addr
	names, err = r.lookup(ctx, key)
	return names, err
}
// LookupHost looks up host and returns its addresses. Results, including
// lookup errors, are cached per host and served from the cache on later calls.
func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
	r.once.Do(r.init)
	key := "h" + host
	addrs, err = r.lookup(ctx, key)
	return addrs, err
}
// refreshRecords re-resolves every cached entry that has been read at least
// once since the previous refresh. Entries are re-resolved with used=false, so
// an entry must be read again before a later refresh touches it.
//
// If clearUnused is true, entries that have not been read since the previous
// refresh are removed from the cache. If persistOnFailure is true, a failed
// lookup leaves the existing entry in place instead of replacing it with the
// lookup error.
func (r *Resolver) refreshRecords(clearUnused bool, persistOnFailure bool) {
	r.once.Do(r.init)

	// Partition the keys under the read lock only, so concurrent lookups are
	// not blocked while deciding what to refresh or drop.
	r.mu.RLock()
	update := make([]string, 0, len(r.cache))
	var del []string
	for key, entry := range r.cache {
		switch {
		case entry.used:
			update = append(update, key)
		case clearUnused:
			del = append(del, key)
		}
	}
	r.mu.RUnlock()

	if len(del) > 0 {
		r.mu.Lock()
		for _, key := range del {
			// The entry may have been read (or re-stored by a cache miss)
			// between releasing the read lock and acquiring the write lock.
			// Only drop it if it is still unused.
			if entry, ok := r.cache[key]; ok && !entry.used {
				delete(r.cache, key)
			}
		}
		r.mu.Unlock()
	}

	// Refresh is best-effort: update records failures in the cache itself
	// (or keeps the old entry when persistOnFailure is set), so its return
	// values are not needed here.
	for _, key := range update {
		r.update(context.Background(), key, false, persistOnFailure)
	}
}
// Refresh re-resolves every cached entry that has been used since the previous
// refresh. If clearUnused is true, entries that have not been used since then
// are removed. A failed lookup replaces the cached entry with the lookup
// error; RefreshWithOptions with PersistOnFailure set keeps the old entry
// instead.
func (r *Resolver) Refresh(clearUnused bool) {
	r.RefreshWithOptions(ResolverRefreshOptions{ClearUnused: clearUnused})
}
// RefreshWithOptions re-resolves every cached entry that has been used since
// the previous refresh. ClearUnused additionally removes entries that have not
// been used since then. PersistOnFailure keeps an existing entry when its
// lookup fails, instead of replacing it with the lookup error.
func (r *Resolver) RefreshWithOptions(options ResolverRefreshOptions) {
	clearUnused, persistOnFailure := options.ClearUnused, options.PersistOnFailure
	r.refreshRecords(clearUnused, persistOnFailure)
}
// init allocates the empty cache map. It is intended to be run through
// r.once.Do so that the zero Resolver becomes usable on first use.
func (r *Resolver) init() {
	r.cache = map[string]*cacheEntry{}
}
// lookupGroup merges concurrent lookups for the same cache key into a single
// DNS query. Its keys are the cache keys: a one-byte type prefix ('h' for host
// lookups, 'r' for reverse lookups) followed by the host or address.
//
// NOTE(review): the group is package-level, so identical keys are shared
// across all Resolver values, even ones configured with a different Resolver
// or Timeout. Confirm this is intended.
var lookupGroup singleflight.Group
// lookup returns the cached result for key when present. Otherwise it fires
// OnCacheMiss (if set) and performs a fresh lookup, storing the result with
// the used flag set.
func (r *Resolver) lookup(ctx context.Context, key string) (rrs []string, err error) {
	if cached, cachedErr, hit := r.load(key); hit {
		return cached, cachedErr
	}
	if hook := r.OnCacheMiss; hook != nil {
		hook()
	}
	return r.update(ctx, key, true, false)
}
// update resolves key through lookupGroup, which merges concurrent lookups of
// the same key, and stores the result in the cache with the given used flag.
//
// If ctx is done before the lookup finishes, ctx.Err() is returned and the
// cache is not modified. On context.DeadlineExceeded the in-flight call is
// also forgotten, so later callers start a fresh lookup.
//
// If the lookup was shared with concurrent callers and the key is already
// cached, the cached result is returned without storing again.
//
// If the lookup fails and persistOnFailure is true, an existing cache entry is
// returned unchanged. If there is no entry to fall back to, the lookup error
// is cached and returned like any other failure.
func (r *Resolver) update(ctx context.Context, key string, used bool, persistOnFailure bool) (rrs []string, err error) {
	c := lookupGroup.DoChan(key, r.lookupFunc(key))
	select {
	case <-ctx.Done():
		err = ctx.Err()
		if err == context.DeadlineExceeded {
			// If the DNS request timed out for some reason, force future
			// requests to start the DNS lookup again rather than waiting
			// for the current lookup to complete.
			lookupGroup.Forget(key)
		}
	case res := <-c:
		if res.Shared {
			// We had concurrent lookups; check whether the cache has already
			// been updated by one of the other callers.
			if cachedRRs, cachedErr, found := r.load(key); found {
				return cachedRRs, cachedErr
			}
		}
		err = res.Err
		if err == nil {
			// The functions built by lookupFunc always yield a []string.
			rrs, _ = res.Val.([]string)
		}
		if err != nil && persistOnFailure {
			// Keep serving the stale entry. The load results go into fresh
			// variables: assigning them to rrs/err directly would replace
			// the lookup error with nil when no entry exists. That would
			// cache and return (nil, nil) for a failed lookup.
			if staleRRs, staleErr, found := r.load(key); found {
				return staleRRs, staleErr
			}
		}
		r.mu.Lock()
		r.storeLocked(key, rrs, used, err)
		r.mu.Unlock()
	}
	return rrs, err
}
// lookupFunc returns the function that lookupGroup runs to resolve key. The
// first byte of key selects the lookup ('h' for LookupHost, 'r' for
// LookupAddr); the rest of key is the lookup subject. r.Resolver is used when
// set, otherwise net.DefaultResolver.
//
// The returned function runs under a context from getCtx rather than any
// caller's context, so one caller cancelling does not abort a lookup that
// other callers may be sharing. lookupFunc panics if key is empty or has an
// unknown prefix.
func (r *Resolver) lookupFunc(key string) func() (interface{}, error) {
	if key == "" {
		panic("lookupFunc with empty key")
	}
	resolver := r.Resolver
	if resolver == nil {
		resolver = net.DefaultResolver
	}
	var do func(context.Context, string) ([]string, error)
	switch key[0] {
	case 'h':
		do = resolver.LookupHost
	case 'r':
		do = resolver.LookupAddr
	default:
		panic("lookupFunc invalid key type: " + key)
	}
	subject := key[1:]
	return func() (interface{}, error) {
		ctx, cancel := r.getCtx()
		defer cancel()
		return do(ctx, subject)
	}
}
// getCtx returns the context for a single DNS lookup. It is
// context.Background bounded by r.Timeout when that is positive, and
// unbounded otherwise. The caller must call cancel once the lookup is done.
func (r *Resolver) getCtx() (ctx context.Context, cancel context.CancelFunc) {
	if r.Timeout <= 0 {
		return context.Background(), func() {}
	}
	return context.WithTimeout(context.Background(), r.Timeout)
}
// load returns the cached records and error for key, and whether key was
// present. A hit marks the entry as used so the next refresh re-resolves it.
// The write lock is taken only when the used flag actually has to change.
func (r *Resolver) load(key string) (rrs []string, err error, found bool) {
	r.mu.RLock()
	entry, ok := r.cache[key]
	var alreadyUsed bool
	if ok {
		rrs, err, alreadyUsed = entry.rrs, entry.err, entry.used
	}
	r.mu.RUnlock()

	if !ok {
		return nil, nil, false
	}
	if !alreadyUsed {
		r.mu.Lock()
		entry.used = true
		r.mu.Unlock()
	}
	return rrs, err, true
}
// storeLocked records rrs, err and used for key. It updates an existing entry
// in place (so pointers to it stay valid) or inserts a new one. The caller
// must hold r.mu for writing.
func (r *Resolver) storeLocked(key string, rrs []string, used bool, err error) {
	entry, ok := r.cache[key]
	if !ok {
		entry = &cacheEntry{}
		r.cache[key] = entry
	}
	entry.rrs, entry.err, entry.used = rrs, err, used
}