From a15a957a366b5eb09841205d418df7b900f6957c Mon Sep 17 00:00:00 2001 From: John Murret Date: Fri, 1 Mar 2024 05:42:42 -0700 Subject: [PATCH] NET-8056 - v2 DNS Testing Improvements (#20710) * NET-8056 - v2 DNS Testing Improvements * adding TestDNSServer_Lifecycle * add license headers to new files. --- agent/discovery/query_fetcher_v1.go | 104 +++++----- agent/discovery/query_fetcher_v1_ce.go | 2 +- agent/discovery/query_fetcher_v2.go | 21 ++- agent/dns/mock_DNSRouter.go | 18 +- agent/dns/parser.go | 2 +- agent/dns/parser_test.go | 141 ++++++++++++++ agent/dns/router.go | 11 +- agent/dns/router_test.go | 72 ++++++- agent/dns/server.go | 2 + agent/dns/server_test.go | 78 ++++++++ agent/dns_test.go | 250 ++++++++++++++++++++++++- 11 files changed, 634 insertions(+), 67 deletions(-) create mode 100644 agent/dns/parser_test.go create mode 100644 agent/dns/server_test.go diff --git a/agent/discovery/query_fetcher_v1.go b/agent/discovery/query_fetcher_v1.go index fc71ae60e9..87c50f93ad 100644 --- a/agent/discovery/query_fetcher_v1.go +++ b/agent/discovery/query_fetcher_v1.go @@ -39,21 +39,21 @@ var DNSCounters = []prometheus.CounterDefinition{ }, } -// v1DataFetcherDynamicConfig is used to store the dynamic configuration of the V1 data fetcher. -type v1DataFetcherDynamicConfig struct { +// V1DataFetcherDynamicConfig is used to store the dynamic configuration of the V1 data fetcher. +type V1DataFetcherDynamicConfig struct { // Default request tenancy - datacenter string + Datacenter string - segmentName string - nodeName string - nodePartition string + SegmentName string + NodeName string + NodePartition string // Catalog configuration - allowStale bool - maxStale time.Duration - useCache bool - cacheMaxAge time.Duration - onlyPassing bool + AllowStale bool + MaxStale time.Duration + UseCache bool + CacheMaxAge time.Duration + OnlyPassing bool } // V1DataFetcher is used to fetch data from the V1 catalog. @@ -93,19 +93,23 @@ func NewV1DataFetcher(config *config.RuntimeConfig, // LoadConfig loads the configuration for the V1 data fetcher. 
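+// It is called when the fetcher is constructed and again whenever the agent's
+// runtime configuration is reloaded, so lookups always see the latest DNS settings.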
func (f *V1DataFetcher) LoadConfig(config *config.RuntimeConfig) { - dynamicConfig := &v1DataFetcherDynamicConfig{ - allowStale: config.DNSAllowStale, - maxStale: config.DNSMaxStale, - useCache: config.DNSUseCache, - cacheMaxAge: config.DNSCacheMaxAge, - onlyPassing: config.DNSOnlyPassing, - datacenter: config.Datacenter, - segmentName: config.SegmentName, - nodeName: config.NodeName, + dynamicConfig := &V1DataFetcherDynamicConfig{ + AllowStale: config.DNSAllowStale, + MaxStale: config.DNSMaxStale, + UseCache: config.DNSUseCache, + CacheMaxAge: config.DNSCacheMaxAge, + OnlyPassing: config.DNSOnlyPassing, + Datacenter: config.Datacenter, + SegmentName: config.SegmentName, + NodeName: config.NodeName, } f.dynamicConfig.Store(dynamicConfig) } +func (f *V1DataFetcher) GetConfig() *V1DataFetcherDynamicConfig { + return f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig) +} + // FetchNodes fetches A/AAAA/CNAME func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) { if req.Tenancy.Namespace != "" && req.Tenancy.Namespace != acl.DefaultNamespaceName { @@ -113,7 +117,7 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e return nil, ErrNotFound } - cfg := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig) + cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig) // Make an RPC request args := &structs.NodeSpecificRequest{ Datacenter: req.Tenancy.Datacenter, @@ -121,7 +125,7 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e Node: req.Name, QueryOptions: structs.QueryOptions{ Token: ctx.Token, - AllowStale: cfg.allowStale, + AllowStale: cfg.AllowStale, }, EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy), } @@ -160,14 +164,14 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e // FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services func (f *V1DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) { f.logger.Trace(fmt.Sprintf("FetchEndpoints - req: %+v / lookupType: %+v", req, lookupType)) - cfg := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig) + cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig) return f.fetchService(ctx, req, cfg, lookupType) } // FetchVirtualIP fetches A/AAAA records for virtual IPs func (f *V1DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, error) { args := structs.ServiceSpecificRequest{ - // The datacenter of the request is not specified because cross-datacenter virtual IP + // The Datacenter of the request is not specified because cross-Datacenter virtual IP // queries are not supported. This guard rail is in place because virtual IPs are allocated // within a DC, therefore their uniqueness is not guaranteed globally. 
PeerName: req.Tenancy.Peer, @@ -200,16 +204,16 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, return nil, ErrNotSupported } - configCtx := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig) + configCtx := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig) targetIP := ip.String() var results []*Result args := structs.DCSpecificRequest{ - Datacenter: configCtx.datacenter, + Datacenter: configCtx.Datacenter, QueryOptions: structs.QueryOptions{ Token: reqCtx.Token, - AllowStale: configCtx.allowStale, + AllowStale: configCtx.AllowStale, }, } var out structs.IndexedNodes @@ -229,7 +233,7 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, Tenancy: ResultTenancy{ Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(), Partition: f.defaultEnterpriseMeta.PartitionOrDefault(), - Datacenter: configCtx.datacenter, + Datacenter: configCtx.Datacenter, }, }) return results, nil @@ -239,10 +243,10 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, // only look into the services if we didn't find a node sargs := structs.ServiceSpecificRequest{ - Datacenter: configCtx.datacenter, + Datacenter: configCtx.Datacenter, QueryOptions: structs.QueryOptions{ Token: reqCtx.Token, - AllowStale: configCtx.allowStale, + AllowStale: configCtx.AllowStale, }, ServiceAddress: targetIP, EnterpriseMeta: *f.defaultEnterpriseMeta.WithWildcardNamespace(), @@ -293,7 +297,7 @@ func (f *V1DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, // FetchPreparedQuery evaluates the results of a prepared query. // deprecated in V2 func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { - cfg := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig) + cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig) // Execute the prepared query. args := structs.PreparedQueryExecuteRequest{ @@ -301,8 +305,8 @@ func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*R QueryIDOrName: req.Name, QueryOptions: structs.QueryOptions{ Token: ctx.Token, - AllowStale: cfg.allowStale, - MaxAge: cfg.cacheMaxAge, + AllowStale: cfg.AllowStale, + MaxAge: cfg.CacheMaxAge, }, // Always pass the local agent through. In the DNS interface, there @@ -310,10 +314,10 @@ func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*R // send the local agent's data through to allow distance sorting // relative to ourself on the server side. Agent: structs.QuerySource{ - Datacenter: cfg.datacenter, - Segment: cfg.segmentName, - Node: cfg.nodeName, - NodePartition: cfg.nodePartition, + Datacenter: cfg.Datacenter, + Segment: cfg.SegmentName, + Node: cfg.NodeName, + NodePartition: cfg.NodePartition, }, Source: structs.QuerySource{ Ip: req.SourceIP.String(), @@ -367,11 +371,11 @@ func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*R // executePreparedQuery is used to execute a PreparedQuery against the Consul catalog. // If the config is set to UseCache, it will use agent cache. 
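+// If a stale result is older than MaxStale, the query is retried once against the
+// servers with AllowStale disabled (see the RPC label below).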
-func (f *V1DataFetcher) executePreparedQuery(cfg *v1DataFetcherDynamicConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { +func (f *V1DataFetcher) executePreparedQuery(cfg *V1DataFetcherDynamicConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { var out structs.PreparedQueryExecuteResponse RPC: - if cfg.useCache { + if cfg.UseCache { raw, m, err := f.getFromCacheFunc(context.TODO(), cachetype.PreparedQueryName, &args) if err != nil { return nil, err @@ -396,7 +400,7 @@ RPC: // Verify that request is not too stale, redo the request. if args.AllowStale { - if out.LastContact > cfg.maxStale { + if out.LastContact > cfg.MaxStale { args.AllowStale = false f.logger.Warn("Query results too stale, re-requesting") goto RPC @@ -489,10 +493,10 @@ func makeTaggedAddressesFromStrings(tagged map[string]string) map[string]*Tagged // fetchNode is used to look up a node in the Consul catalog within NodeServices. // If the config is set to UseCache, it will get the record from the agent cache. -func (f *V1DataFetcher) fetchNode(cfg *v1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { +func (f *V1DataFetcher) fetchNode(cfg *V1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { var out structs.IndexedNodeServices - useCache := cfg.useCache + useCache := cfg.UseCache RPC: if useCache { raw, _, err := f.getFromCacheFunc(context.TODO(), cachetype.NodeServicesName, args) @@ -513,7 +517,7 @@ RPC: // Verify that request is not too stale, redo the request if args.AllowStale { - if out.LastContact > cfg.maxStale { + if out.LastContact > cfg.MaxStale { args.AllowStale = false useCache = false f.logger.Warn("Query results too stale, re-requesting") @@ -527,7 +531,7 @@ RPC: } func (f *V1DataFetcher) fetchService(ctx Context, req *QueryPayload, - cfg *v1DataFetcherDynamicConfig, lookupType LookupType) ([]*Result, error) { + cfg *V1DataFetcherDynamicConfig, lookupType LookupType) ([]*Result, error) { f.logger.Trace("fetchService", "req", req) if req.Tenancy.SamenessGroup == "" { return f.fetchServiceBasedOnTenancy(ctx, req, cfg, lookupType) @@ -538,7 +542,7 @@ func (f *V1DataFetcher) fetchService(ctx Context, req *QueryPayload, // fetchServiceBasedOnTenancy is used to look up a service in the Consul catalog based on its tenancy or default tenancy. 
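+// Requests that carry a sameness group are rejected here; fetchService routes them
+// to fetchServiceFromSamenessGroup instead.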
func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayload,
-	cfg *v1DataFetcherDynamicConfig, lookupType LookupType) ([]*Result, error) {
+	cfg *V1DataFetcherDynamicConfig, lookupType LookupType) ([]*Result, error) {
 	f.logger.Trace(fmt.Sprintf("fetchServiceBasedOnTenancy - req: %+v", req))
 	if req.Tenancy.SamenessGroup != "" {
 		return nil, errors.New("sameness groups are not allowed for service lookups based on tenancy")
@@ -563,10 +567,10 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa
 		TagFilter: req.Tag != "",
 		QueryOptions: structs.QueryOptions{
 			Token: ctx.Token,
-			AllowStale: cfg.allowStale,
-			MaxAge: cfg.cacheMaxAge,
-			UseCache: cfg.useCache,
-			MaxStaleDuration: cfg.maxStale,
+			AllowStale: cfg.AllowStale,
+			MaxAge: cfg.CacheMaxAge,
+			UseCache: cfg.UseCache,
+			MaxStaleDuration: cfg.MaxStale,
 		},
 		EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy),
 	}
@@ -588,7 +592,7 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa
 	// We copy the slice to avoid modifying the result if it comes from the cache
 	nodes := make(structs.CheckServiceNodes, len(out.Nodes))
 	copy(nodes, out.Nodes)
-	out.Nodes = nodes.Filter(cfg.onlyPassing)
+	out.Nodes = nodes.Filter(cfg.OnlyPassing)
 	if err != nil {
 		return nil, fmt.Errorf("rpc request failed: %w", err)
 	}
diff --git a/agent/discovery/query_fetcher_v1_ce.go b/agent/discovery/query_fetcher_v1_ce.go
index 0260b7a24a..59d32e91e2 100644
--- a/agent/discovery/query_fetcher_v1_ce.go
+++ b/agent/discovery/query_fetcher_v1_ce.go
@@ -29,7 +29,7 @@ func queryTenancyToEntMeta(_ QueryTenancy) acl.EnterpriseMeta {
 }
 
 // fetchServiceFromSamenessGroup fetches a service from a sameness group.
-func (f *V1DataFetcher) fetchServiceFromSamenessGroup(ctx Context, req *QueryPayload, cfg *v1DataFetcherDynamicConfig, lookupType LookupType) ([]*Result, error) {
+func (f *V1DataFetcher) fetchServiceFromSamenessGroup(ctx Context, req *QueryPayload, cfg *V1DataFetcherDynamicConfig, lookupType LookupType) ([]*Result, error) {
 	f.logger.Trace(fmt.Sprintf("fetchServiceFromSamenessGroup - req: %+v", req))
 	if req.Tenancy.SamenessGroup == "" {
 		return nil, errors.New("sameness groups must be provided for service lookups")
diff --git a/agent/discovery/query_fetcher_v2.go b/agent/discovery/query_fetcher_v2.go
index 02e8fcaccc..ac474811fa 100644
--- a/agent/discovery/query_fetcher_v2.go
+++ b/agent/discovery/query_fetcher_v2.go
@@ -25,9 +25,9 @@ import (
 	"github.com/hashicorp/consul/proto-public/pbresource"
 )
 
-// v2DataFetcherDynamicConfig is used to store the dynamic configuration of the V2 data fetcher.
-type v2DataFetcherDynamicConfig struct {
-	onlyPassing bool
+// V2DataFetcherDynamicConfig is used to store the dynamic configuration of the V2 data fetcher.
+type V2DataFetcherDynamicConfig struct {
+	OnlyPassing bool
 }
 
 // V2DataFetcher is used to fetch data from the V2 catalog.
@@ -54,12 +54,17 @@ func NewV2DataFetcher(config *config.RuntimeConfig, client pbresource.ResourceSe
 
 // LoadConfig loads the configuration for the V2 data fetcher.
 func (f *V2DataFetcher) LoadConfig(config *config.RuntimeConfig) {
-	dynamicConfig := &v2DataFetcherDynamicConfig{
-		onlyPassing: config.DNSOnlyPassing,
+	dynamicConfig := &V2DataFetcherDynamicConfig{
+		OnlyPassing: config.DNSOnlyPassing,
 	}
 	f.dynamicConfig.Store(dynamicConfig)
 }
 
+// GetConfig returns the current configuration for the V2 data fetcher.
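+// The returned configuration is replaced wholesale by LoadConfig on config reloads,
+// so callers should treat it as read-only.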
+func (f *V2DataFetcher) GetConfig() *V2DataFetcherDynamicConfig {
+	return f.dynamicConfig.Load().(*V2DataFetcherDynamicConfig)
+}
+
 // FetchNodes fetches A/AAAA/CNAME
 func (f *V2DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) {
 	// TODO (v2-dns): NET-6623 - Implement FetchNodes
@@ -73,7 +78,7 @@ func (f *V2DataFetcher) FetchEndpoints(reqContext Context, req *QueryPayload, lo
 		return nil, ErrNotSupported
 	}
 
-	configCtx := f.dynamicConfig.Load().(*v2DataFetcherDynamicConfig)
+	configCtx := f.dynamicConfig.Load().(*V2DataFetcherDynamicConfig)
 
 	serviceEndpoints := pbcatalog.ServiceEndpoints{}
 	serviceEndpointsResource, err := f.fetchResource(reqContext, *req, pbcatalog.ServiceEndpointsType, &serviceEndpoints)
@@ -262,7 +267,7 @@ func (f *V2DataFetcher) addressFromWorkloadAddresses(addresses []*pbcatalog.Work
 
 // getEndpointWeight returns the weight of the endpoint and a boolean indicating if the endpoint should be included
 // based on it's health status.
-func getEndpointWeight(endpoint *pbcatalog.Endpoint, configCtx *v2DataFetcherDynamicConfig) (uint32, bool) {
+func getEndpointWeight(endpoint *pbcatalog.Endpoint, configCtx *V2DataFetcherDynamicConfig) (uint32, bool) {
 	health := endpoint.GetHealthStatus().Enum()
 	if health == nil {
 		return 0, false
@@ -277,7 +282,7 @@ func getEndpointWeight(endpoint *pbcatalog.Endpoint, configCtx *v2DataFetcherDyn
 	case pbcatalog.Health_HEALTH_CRITICAL:
 		return 0, false // always filtered out
 	case pbcatalog.Health_HEALTH_WARNING:
-		if configCtx.onlyPassing {
+		if configCtx.OnlyPassing {
 			return 0, false // filtered out
 		}
 		weight = endpoint.GetDns().GetWeights().GetWarning()
diff --git a/agent/dns/mock_DNSRouter.go b/agent/dns/mock_DNSRouter.go
index 788c894f58..9e90de771e 100644
--- a/agent/dns/mock_DNSRouter.go
+++ b/agent/dns/mock_DNSRouter.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v2.37.1. DO NOT EDIT.
+// Code generated by mockery v2.32.4. DO NOT EDIT.
 
 package dns
 
@@ -16,6 +16,22 @@ type MockDNSRouter struct {
 	mock.Mock
 }
 
+// GetConfig provides a mock function with given fields:
+func (_m *MockDNSRouter) GetConfig() *RouterDynamicConfig {
+	ret := _m.Called()
+
+	var r0 *RouterDynamicConfig
+	if rf, ok := ret.Get(0).(func() *RouterDynamicConfig); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*RouterDynamicConfig)
+		}
+	}
+
+	return r0
+}
+
 // HandleRequest provides a mock function with given fields: req, reqCtx, remoteAddress
 func (_m *MockDNSRouter) HandleRequest(req *miekgdns.Msg, reqCtx Context, remoteAddress net.Addr) *miekgdns.Msg {
 	ret := _m.Called(req, reqCtx, remoteAddress)
diff --git a/agent/dns/parser.go b/agent/dns/parser.go
index 1a0f0a601d..e39f91e0f9 100644
--- a/agent/dns/parser.go
+++ b/agent/dns/parser.go
@@ -58,7 +58,7 @@ func parseLabels(labels []string) (*parsedLabels, bool) {
 			return nil, false
 		}
 
-		// Validation e need to validate that this a valid DNS including sg
+		// Validate that this is a valid DNS query, including the sameness group (sg)
 		if result.SamenessGroup != "" && (result.Datacenter != "" || result.Peer != "") {
 			return nil, false
 		}
diff --git a/agent/dns/parser_test.go b/agent/dns/parser_test.go
new file mode 100644
index 0000000000..cd5beb117a
--- /dev/null
+++ b/agent/dns/parser_test.go
@@ -0,0 +1,141 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func Test_parseLabels(t *testing.T) { + type testCase struct { + name string + labels []string + expectedOK bool + expectedResult *parsedLabels + } + testCases := []testCase{ + { + name: "6 labels - with datacenter", + labels: []string{"test-ns", "ns", "test-ap", "ap", "test-dc", "dc"}, + expectedResult: &parsedLabels{ + Namespace: "test-ns", + Partition: "test-ap", + Datacenter: "test-dc", + }, + expectedOK: true, + }, + { + name: "6 labels - with cluster", + labels: []string{"test-ns", "ns", "test-ap", "ap", "test-cluster", "cluster"}, + expectedResult: &parsedLabels{ + Namespace: "test-ns", + Partition: "test-ap", + Datacenter: "test-cluster", + }, + expectedOK: true, + }, + { + name: "6 labels - with peer", + labels: []string{"test-ns", "ns", "test-ap", "ap", "test-peer", "peer"}, + expectedResult: &parsedLabels{ + Namespace: "test-ns", + Partition: "test-ap", + Peer: "test-peer", + }, + expectedOK: true, + }, + { + name: "6 labels - with sameness group", + labels: []string{"test-sg", "sg", "test-ap", "ap", "test-ns", "ns"}, + expectedResult: &parsedLabels{ + Namespace: "test-ns", + Partition: "test-ap", + SamenessGroup: "test-sg", + }, + expectedOK: true, + }, + { + name: "6 labels - invalid", + labels: []string{"test-ns", "not-ns", "test-ap", "ap", "test-dc", "dc"}, + expectedResult: nil, + expectedOK: false, + }, + { + name: "4 labels - namespace and datacenter", + labels: []string{"test-ns", "ns", "test-ap", "ap"}, + expectedResult: &parsedLabels{ + Namespace: "test-ns", + Partition: "test-ap", + }, + expectedOK: true, + }, + { + name: "4 labels - invalid", + labels: []string{"test-ns", "not-ns", "test-ap", "ap", "test-dc", "dc"}, + expectedResult: nil, + expectedOK: false, + }, + { + name: "2 labels - namespace and peer or datacenter", + labels: []string{"test-ns", "test-peer-or-dc"}, + expectedResult: &parsedLabels{ + Namespace: "test-ns", + PeerOrDatacenter: "test-peer-or-dc", + }, + expectedOK: true, + }, + { + name: "1 label - peer or datacenter", + labels: []string{"test-peer-or-dc"}, + expectedResult: &parsedLabels{ + PeerOrDatacenter: "test-peer-or-dc", + }, + expectedOK: true, + }, + { + name: "0 labels - returns empty result and true", + labels: []string{}, + expectedResult: &parsedLabels{}, + expectedOK: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, ok := parseLabels(tc.labels) + require.Equal(t, tc.expectedOK, ok) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_parsePort(t *testing.T) { + type testCase struct { + name string + labels []string + expectedResult string + } + testCases := []testCase{ + { + name: "given 3 labels where the second label is port, the first label is returned", + labels: []string{"port-name", "port", "target-name"}, + expectedResult: "port-name", + }, + { + name: "given 3 labels where the second label is not port, an empty string is returned", + labels: []string{"port-name", "not-port", "target-name"}, + expectedResult: "", + }, + { + name: "given anything but 3 labels, an empty string is returned", + labels: []string{"port-name", "something-else"}, + expectedResult: "", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expectedResult, parsePort(tc.labels)) + }) + } +} diff --git a/agent/dns/router.go b/agent/dns/router.go index 267c4bd6fe..40d366cff7 100644 --- a/agent/dns/router.go +++ 
b/agent/dns/router.go @@ -331,7 +331,7 @@ func getTTLForResult(name string, overrideTTL *uint32, query *discovery.Query, c switch query.QueryType { case discovery.QueryTypeService, discovery.QueryTypePreparedQuery: - ttl, ok := cfg.getTTLForService(name) + ttl, ok := cfg.GetTTLForService(name) if ok { return uint32(ttl / time.Second) } @@ -417,9 +417,14 @@ func (r *Router) ReloadConfig(newCfg *config.RuntimeConfig) error { return nil } -// getTTLForService Find the TTL for a given service. +// GetConfig returns the current router config +func (r *Router) GetConfig() *RouterDynamicConfig { + return r.dynamicConfig.Load().(*RouterDynamicConfig) +} + +// GetTTLForService Find the TTL for a given service. // return ttl, true if found, 0, false otherwise -func (cfg *RouterDynamicConfig) getTTLForService(service string) (time.Duration, bool) { +func (cfg *RouterDynamicConfig) GetTTLForService(service string) (time.Duration, bool) { if cfg.TTLStrict != nil { ttl, ok := cfg.TTLStrict[service] if ok { diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index c96cf752d1..5cc1050b17 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -6,6 +6,7 @@ package dns import ( "errors" "fmt" + "github.com/armon/go-radix" "net" "reflect" "testing" @@ -3155,7 +3156,7 @@ func TestRouterDynamicConfig_GetTTLForService(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - actual, ok := cfg.getTTLForService(tc.inputKey) + actual, ok := cfg.GetTTLForService(tc.inputKey) require.Equal(t, tc.shouldMatch, ok) require.Equal(t, tc.expectedDuration, actual) }) @@ -3462,3 +3463,72 @@ func TestDNS_syncExtra(t *testing.T) { func getUint32Ptr(i uint32) *uint32 { return &i } + +func TestRouter_ReloadConfig(t *testing.T) { + cdf := discovery.NewMockCatalogDataFetcher(t) + cfg := buildDNSConfig(nil, cdf, nil) + router, err := NewRouter(cfg) + require.NoError(t, err) + + router.recursor = newMockDnsRecursor(t) + + // Reload the config + newAgentConfig := &config.RuntimeConfig{ + DNSARecordLimit: 123, + DNSEnableTruncate: true, + DNSNodeTTL: 234, + DNSRecursorStrategy: "strategy-123", + DNSRecursorTimeout: 345, + DNSUDPAnswerLimit: 456, + DNSNodeMetaTXT: true, + DNSDisableCompression: true, + DNSSOA: config.RuntimeSOAConfig{ + Expire: 123, + Minttl: 234, + Refresh: 345, + Retry: 456, + }, + DNSServiceTTL: map[string]time.Duration{ + "wildcard-config-*": 123, + "strict-config": 234, + }, + DNSRecursors: []string{ + "8.8.8.8", + "2001:4860:4860::8888", + }, + } + + expectTTLRadix := radix.New() + expectTTLRadix.Insert("wildcard-config-", time.Duration(123)) + + expectedCfg := &RouterDynamicConfig{ + ARecordLimit: 123, + EnableTruncate: true, + NodeTTL: 234, + RecursorStrategy: "strategy-123", + RecursorTimeout: 345, + UDPAnswerLimit: 456, + NodeMetaTXT: true, + DisableCompression: true, + SOAConfig: SOAConfig{ + Expire: 123, + Minttl: 234, + Refresh: 345, + Retry: 456, + }, + TTLRadix: expectTTLRadix, + TTLStrict: map[string]time.Duration{ + "strict-config": 234, + }, + Recursors: []string{ + "8.8.8.8:53", + "[2001:4860:4860::8888]:53", + }, + } + err = router.ReloadConfig(newAgentConfig) + require.NoError(t, err) + savedCfg := router.dynamicConfig.Load().(*RouterDynamicConfig) + + // Ensure the new config is used + require.Equal(t, expectedCfg, savedCfg) +} diff --git a/agent/dns/server.go b/agent/dns/server.go index 74da3fa663..764fb15980 100644 --- a/agent/dns/server.go +++ b/agent/dns/server.go @@ -24,6 +24,7 @@ import ( type DNSRouter interface { HandleRequest(req 
*dns.Msg, reqCtx Context, remoteAddress net.Addr) *dns.Msg
 	ServeDNS(w dns.ResponseWriter, req *dns.Msg)
+	GetConfig() *RouterDynamicConfig
 	ReloadConfig(newCfg *config.RuntimeConfig) error
 }
@@ -93,6 +94,7 @@ func (d *Server) Shutdown() {
 			d.logger.Error("Error stopping DNS server", "error", err)
 		}
 	}
+	d.Router = nil
 }
 
 // GetAddr is a function to return the server address if is not nil.
diff --git a/agent/dns/server_test.go b/agent/dns/server_test.go
new file mode 100644
index 0000000000..7ede22efda
--- /dev/null
+++ b/agent/dns/server_test.go
@@ -0,0 +1,78 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package dns
+
+import (
+	"github.com/hashicorp/consul/agent/config"
+	"github.com/hashicorp/consul/sdk/testutil"
+	"github.com/stretchr/testify/require"
+	"testing"
+)
+
+// TestDNSServer_ReloadConfig tests that the ReloadConfig method calls the router's ReloadConfig method.
+func TestDNSServer_ReloadConfig(t *testing.T) {
+	srv, err := NewServer(Config{
+		AgentConfig: &config.RuntimeConfig{
+			DNSDomain:    "test-domain",
+			DNSAltDomain: "test-alt-domain",
+		},
+		Logger: testutil.Logger(t),
+	})
+	srv.Router = NewMockDNSRouter(t)
+	require.NoError(t, err)
+	cfg := &config.RuntimeConfig{
+		DNSARecordLimit:       123,
+		DNSEnableTruncate:     true,
+		DNSNodeTTL:            123,
+		DNSRecursorStrategy:   "test",
+		DNSRecursorTimeout:    123,
+		DNSUDPAnswerLimit:     123,
+		DNSNodeMetaTXT:        true,
+		DNSDisableCompression: true,
+		DNSSOA: config.RuntimeSOAConfig{
+			Expire:  123,
+			Refresh: 123,
+			Retry:   123,
+			Minttl:  123,
+		},
+	}
+	srv.Router.(*MockDNSRouter).On("ReloadConfig", cfg).Return(nil)
+	err = srv.ReloadConfig(cfg)
+	require.NoError(t, err)
+	require.True(t, srv.Router.(*MockDNSRouter).AssertExpectations(t))
+}
+
+// TestDNSServer_Lifecycle tests that the server can be started and shut down.
+func TestDNSServer_Lifecycle(t *testing.T) {
+	// Arrange
+	srv, err := NewServer(Config{
+		AgentConfig: &config.RuntimeConfig{
+			DNSDomain:    "test-domain",
+			DNSAltDomain: "test-alt-domain",
+		},
+		Logger: testutil.Logger(t),
+	})
+	defer srv.Shutdown()
+	require.NotNil(t, srv.Router)
+	require.NoError(t, err)
+	require.NotNil(t, srv)
+
+	ch := make(chan bool)
+	go func() {
+		err = srv.ListenAndServe("udp", "127.0.0.1:8500", func() {
+			ch <- true
+		})
+		require.NoError(t, err)
+	}()
+	started, ok := <-ch
+	require.True(t, ok)
+	require.True(t, started)
+	require.NotNil(t, srv.Handler)
+	require.NotNil(t, srv.Handler.(*Router))
+	require.NotNil(t, srv.PacketConn)
+
+	// Shutdown
+	srv.Shutdown()
+	require.Nil(t, srv.Router)
+}
diff --git a/agent/dns_test.go b/agent/dns_test.go
index 35e80a856a..b39026465f 100644
--- a/agent/dns_test.go
+++ b/agent/dns_test.go
@@ -16,6 +16,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"github.com/hashicorp/consul/agent/discovery"
 	"math"
 	"math/rand"
 	"net"
@@ -33,6 +34,7 @@ import (
 	"github.com/hashicorp/consul/acl"
 	"github.com/hashicorp/consul/agent/config"
 	"github.com/hashicorp/consul/agent/consul"
+	dnsConsul "github.com/hashicorp/consul/agent/dns"
 	"github.com/hashicorp/consul/agent/structs"
 	"github.com/hashicorp/consul/lib"
 	"github.com/hashicorp/consul/sdk/testutil/retry"
@@ -3305,6 +3307,8 @@ func TestDNS_Compression_Recurse(t *testing.T) {
 	}
 }
 
+// TestDNS_V1ConfigReload validates that the DNS configuration is saved to the
+// DNS server when v1 DNS is configured and reloadConfigInternal is called.
func TestDNS_V1ConfigReload(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -3417,10 +3421,252 @@ func TestDNS_V1ConfigReload(t *testing.T) { require.Equal(t, uint32(30), cfg.SOAConfig.Expire) require.Equal(t, uint32(40), cfg.SOAConfig.Minttl) } - } -// TODO (v2-dns) add a test for checking the V2 DNS Server reloads the config (NET-8056) +// TestDNS_V2ConfigReload_WithV1DataFetcher validates that the dns configuration is saved to the +// DNS server when v2 DNS is configured with V1 catalog and reload config internal is called. +func TestDNS_V2ConfigReload_WithV1DataFetcher(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + a := NewTestAgent(t, ` + experiments=["v2dns"] + recursors = ["8.8.8.8:53"] + dns_config = { + allow_stale = false + max_stale = "20s" + node_ttl = "10s" + service_ttl = { + "my_services*" = "5s" + "my_specific_service" = "30s" + } + enable_truncate = false + only_passing = false + recursor_strategy = "sequential" + recursor_timeout = "15s" + disable_compression = false + a_record_limit = 1 + enable_additional_node_meta_txt = false + soa = { + refresh = 1 + retry = 2 + expire = 3 + min_ttl = 4 + } + } + `) + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + for _, s := range a.dnsServers { + server, ok := s.(*dnsConsul.Server) + require.True(t, ok) + + cfg := server.Router.GetConfig() + require.Equal(t, []string{"8.8.8.8:53"}, cfg.Recursors) + require.Equal(t, structs.RecursorStrategy("sequential"), cfg.RecursorStrategy) + df := a.catalogDataFetcher.(*discovery.V1DataFetcher) + dfCfg := df.GetConfig() + + require.False(t, dfCfg.AllowStale) + require.Equal(t, 20*time.Second, dfCfg.MaxStale) + require.Equal(t, 10*time.Second, cfg.NodeTTL) + ttl, _ := cfg.GetTTLForService("my_services_1") + require.Equal(t, 5*time.Second, ttl) + ttl, _ = cfg.GetTTLForService("my_specific_service") + require.Equal(t, 30*time.Second, ttl) + require.False(t, cfg.EnableTruncate) + require.False(t, dfCfg.OnlyPassing) + require.Equal(t, 15*time.Second, cfg.RecursorTimeout) + require.False(t, cfg.DisableCompression) + require.Equal(t, 1, cfg.ARecordLimit) + require.False(t, cfg.NodeMetaTXT) + require.Equal(t, uint32(1), cfg.SOAConfig.Refresh) + require.Equal(t, uint32(2), cfg.SOAConfig.Retry) + require.Equal(t, uint32(3), cfg.SOAConfig.Expire) + require.Equal(t, uint32(4), cfg.SOAConfig.Minttl) + } + + newCfg := *a.Config + newCfg.DNSRecursors = []string{"1.1.1.1:53"} + newCfg.DNSAllowStale = true + newCfg.DNSMaxStale = 21 * time.Second + newCfg.DNSNodeTTL = 11 * time.Second + newCfg.DNSServiceTTL = map[string]time.Duration{ + "2_my_services*": 6 * time.Second, + "2_my_specific_service": 31 * time.Second, + } + newCfg.DNSEnableTruncate = true + newCfg.DNSOnlyPassing = true + newCfg.DNSRecursorStrategy = "random" + newCfg.DNSRecursorTimeout = 16 * time.Second + newCfg.DNSDisableCompression = true + newCfg.DNSARecordLimit = 2 + newCfg.DNSNodeMetaTXT = true + newCfg.DNSSOA.Refresh = 10 + newCfg.DNSSOA.Retry = 20 + newCfg.DNSSOA.Expire = 30 + newCfg.DNSSOA.Minttl = 40 + + err := a.reloadConfigInternal(&newCfg) + require.NoError(t, err) + + for _, s := range a.dnsServers { + server, ok := s.(*dnsConsul.Server) + require.True(t, ok) + + cfg := server.Router.GetConfig() + require.Equal(t, []string{"1.1.1.1:53"}, cfg.Recursors) + require.Equal(t, structs.RecursorStrategy("random"), cfg.RecursorStrategy) + df := a.catalogDataFetcher.(*discovery.V1DataFetcher) + dfCfg := df.GetConfig() + require.True(t, dfCfg.AllowStale) + 
require.Equal(t, 21*time.Second, dfCfg.MaxStale)
+		require.Equal(t, 11*time.Second, cfg.NodeTTL)
+		ttl, _ := cfg.GetTTLForService("my_services_1")
+		require.Equal(t, time.Duration(0), ttl)
+		ttl, _ = cfg.GetTTLForService("2_my_services_1")
+		require.Equal(t, 6*time.Second, ttl)
+		ttl, _ = cfg.GetTTLForService("my_specific_service")
+		require.Equal(t, time.Duration(0), ttl)
+		ttl, _ = cfg.GetTTLForService("2_my_specific_service")
+		require.Equal(t, 31*time.Second, ttl)
+		require.True(t, cfg.EnableTruncate)
+		require.True(t, dfCfg.OnlyPassing)
+		require.Equal(t, 16*time.Second, cfg.RecursorTimeout)
+		require.True(t, cfg.DisableCompression)
+		require.Equal(t, 2, cfg.ARecordLimit)
+		require.True(t, cfg.NodeMetaTXT)
+		require.Equal(t, uint32(10), cfg.SOAConfig.Refresh)
+		require.Equal(t, uint32(20), cfg.SOAConfig.Retry)
+		require.Equal(t, uint32(30), cfg.SOAConfig.Expire)
+		require.Equal(t, uint32(40), cfg.SOAConfig.Minttl)
+	}
+}
+
+// TestDNS_V2ConfigReload_WithV2DataFetcher validates that the DNS configuration is saved to the
+// DNS server when v2 DNS is configured with the V2 catalog and reloadConfigInternal is called.
+func TestDNS_V2ConfigReload_WithV2DataFetcher(t *testing.T) {
+	if testing.Short() {
+		t.Skip("too slow for testing.Short")
+	}
+
+	a := NewTestAgent(t, `
+	experiments=["v2dns", "resource-apis"]
+	recursors = ["8.8.8.8:53"]
+	dns_config = {
+		allow_stale = false
+		max_stale = "20s"
+		node_ttl = "10s"
+		service_ttl = {
+			"my_services*" = "5s"
+			"my_specific_service" = "30s"
+		}
+		enable_truncate = false
+		only_passing = false
+		recursor_strategy = "sequential"
+		recursor_timeout = "15s"
+		disable_compression = false
+		a_record_limit = 1
+		enable_additional_node_meta_txt = false
+		soa = {
+			refresh = 1
+			retry = 2
+			expire = 3
+			min_ttl = 4
+		}
+	}
+	`)
+	defer a.Shutdown()
+	// use WaitForRaftLeader with v2 resource apis
+	testrpc.WaitForRaftLeader(t, a.RPC, "dc1")
+
+	for _, s := range a.dnsServers {
+		server, ok := s.(*dnsConsul.Server)
+		require.True(t, ok)
+
+		cfg := server.Router.GetConfig()
+		require.Equal(t, []string{"8.8.8.8:53"}, cfg.Recursors)
+		require.Equal(t, structs.RecursorStrategy("sequential"), cfg.RecursorStrategy)
+		df := a.catalogDataFetcher.(*discovery.V2DataFetcher)
+		dfCfg := df.GetConfig()
+
+		//require.False(t, dfCfg.AllowStale)
+		//require.Equal(t, 20*time.Second, dfCfg.MaxStale)
+		require.Equal(t, 10*time.Second, cfg.NodeTTL)
+		ttl, _ := cfg.GetTTLForService("my_services_1")
+		require.Equal(t, 5*time.Second, ttl)
+		ttl, _ = cfg.GetTTLForService("my_specific_service")
+		require.Equal(t, 30*time.Second, ttl)
+		require.False(t, cfg.EnableTruncate)
+		require.False(t, dfCfg.OnlyPassing)
+		require.Equal(t, 15*time.Second, cfg.RecursorTimeout)
+		require.False(t, cfg.DisableCompression)
+		require.Equal(t, 1, cfg.ARecordLimit)
+		require.False(t, cfg.NodeMetaTXT)
+		require.Equal(t, uint32(1), cfg.SOAConfig.Refresh)
+		require.Equal(t, uint32(2), cfg.SOAConfig.Retry)
+		require.Equal(t, uint32(3), cfg.SOAConfig.Expire)
+		require.Equal(t, uint32(4), cfg.SOAConfig.Minttl)
+	}
+
+	newCfg := *a.Config
+	newCfg.DNSRecursors = []string{"1.1.1.1:53"}
+	newCfg.DNSAllowStale = true
+	newCfg.DNSMaxStale = 21 * time.Second
+	newCfg.DNSNodeTTL = 11 * time.Second
+	newCfg.DNSServiceTTL = map[string]time.Duration{
+		"2_my_services*":        6 * time.Second,
+		"2_my_specific_service": 31 * time.Second,
+	}
+	newCfg.DNSEnableTruncate = true
+	newCfg.DNSOnlyPassing = true
+	newCfg.DNSRecursorStrategy = "random"
+	newCfg.DNSRecursorTimeout = 16 * time.Second
+
newCfg.DNSDisableCompression = true + newCfg.DNSARecordLimit = 2 + newCfg.DNSNodeMetaTXT = true + newCfg.DNSSOA.Refresh = 10 + newCfg.DNSSOA.Retry = 20 + newCfg.DNSSOA.Expire = 30 + newCfg.DNSSOA.Minttl = 40 + + err := a.reloadConfigInternal(&newCfg) + require.NoError(t, err) + + for _, s := range a.dnsServers { + server, ok := s.(*dnsConsul.Server) + require.True(t, ok) + + cfg := server.Router.GetConfig() + require.Equal(t, []string{"1.1.1.1:53"}, cfg.Recursors) + require.Equal(t, structs.RecursorStrategy("random"), cfg.RecursorStrategy) + df := a.catalogDataFetcher.(*discovery.V2DataFetcher) + dfCfg := df.GetConfig() + //require.True(t, dfCfg.AllowStale) + //require.Equal(t, 21*time.Second, dfCfg.MaxStale) + require.Equal(t, 11*time.Second, cfg.NodeTTL) + ttl, _ := cfg.GetTTLForService("my_services_1") + require.Equal(t, time.Duration(0), ttl) + ttl, _ = cfg.GetTTLForService("2_my_services_1") + require.Equal(t, 6*time.Second, ttl) + ttl, _ = cfg.GetTTLForService("my_specific_service") + require.Equal(t, time.Duration(0), ttl) + ttl, _ = cfg.GetTTLForService("2_my_specific_service") + require.Equal(t, 31*time.Second, ttl) + require.True(t, cfg.EnableTruncate) + require.True(t, dfCfg.OnlyPassing) + require.Equal(t, 16*time.Second, cfg.RecursorTimeout) + require.True(t, cfg.DisableCompression) + require.Equal(t, 2, cfg.ARecordLimit) + require.True(t, cfg.NodeMetaTXT) + require.Equal(t, uint32(10), cfg.SOAConfig.Refresh) + require.Equal(t, uint32(20), cfg.SOAConfig.Retry) + require.Equal(t, uint32(30), cfg.SOAConfig.Expire) + require.Equal(t, uint32(40), cfg.SOAConfig.Minttl) + } +} func TestDNS_ReloadConfig_DuringQuery(t *testing.T) { if testing.Short() {