From 938d2315e0e37a47e738e4bc4a5eb5f3199bce04 Mon Sep 17 00:00:00 2001 From: John Murret Date: Wed, 17 Jan 2024 16:46:18 -0700 Subject: [PATCH] DNS v2 - add virtual ip questions (#20245) --- agent/agent.go | 4 +- agent/discovery/discovery.go | 17 +- agent/discovery/mock_CatalogDataFetcher.go | 190 ++++++++++++++++++++ agent/discovery/query_fetcher_v1.go | 51 +++++- agent/discovery/query_fetcher_v1_ce_test.go | 11 ++ agent/discovery/query_fetcher_v1_test.go | 105 +++++++++++ agent/discovery/query_fetcher_v2.go | 12 ++ agent/dns/query_locality.go | 42 +++++ agent/dns/query_locality_ce.go | 57 ++++++ agent/dns/query_locality_ce_test.go | 60 +++++++ agent/dns/query_locality_test.go | 74 ++++++++ agent/dns/router.go | 85 +++++---- agent/dns/router_query.go | 125 +++++++++++++ agent/dns/router_query_ce_test.go | 81 +++++++++ agent/dns/router_query_test.go | 37 ++++ agent/dns/router_test.go | 134 ++++++++++++-- agent/dns/server.go | 5 +- agent/dns_ce.go | 4 + agent/dns_ce_test.go | 50 ++++++ agent/dns_test.go | 74 +++++++- 20 files changed, 1157 insertions(+), 61 deletions(-) create mode 100644 agent/discovery/mock_CatalogDataFetcher.go create mode 100644 agent/discovery/query_fetcher_v1_ce_test.go create mode 100644 agent/discovery/query_fetcher_v1_test.go create mode 100644 agent/dns/query_locality.go create mode 100644 agent/dns/query_locality_ce.go create mode 100644 agent/dns/query_locality_ce_test.go create mode 100644 agent/dns/query_locality_test.go create mode 100644 agent/dns/router_query.go create mode 100644 agent/dns/router_query_ce_test.go create mode 100644 agent/dns/router_query_test.go diff --git a/agent/agent.go b/agent/agent.go index c70315b57f..5e7ef7e022 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1103,7 +1103,7 @@ func (a *Agent) listenAndServeV2DNS() error { if a.baseDeps.UseV2Resources() { a.catalogDataFetcher = discovery.NewV2DataFetcher(a.config) } else { - a.catalogDataFetcher = discovery.NewV1DataFetcher(a.config) + a.catalogDataFetcher = discovery.NewV1DataFetcher(a.config, a.RPC, a.logger.Named("catalog-data-fetcher")) } // Generate a Query Processor with the appropriate data fetcher @@ -1115,7 +1115,7 @@ func (a *Agent) listenAndServeV2DNS() error { // create server cfg := dns.Config{ AgentConfig: a.config, - EntMeta: a.AgentEnterpriseMeta(), // TODO (v2-dns): does this even work for v2 tenancy? + EntMeta: *a.AgentEnterpriseMeta(), Logger: a.logger, Processor: processor, TokenFunc: a.getTokenFunc(), diff --git a/agent/discovery/discovery.go b/agent/discovery/discovery.go index 921a9ef4c2..89950ac690 100644 --- a/agent/discovery/discovery.go +++ b/agent/discovery/discovery.go @@ -7,6 +7,7 @@ import ( "fmt" "net" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/structs" ) @@ -25,6 +26,7 @@ type QueryType string const ( QueryTypeConnect QueryType = "CONNECT" // deprecated: use for V1 only QueryTypeIngress QueryType = "INGRESS" // deprecated: use for V1 only + QueryTypeInvalid QueryType = "INVALID" QueryTypeNode QueryType = "NODE" QueryTypePreparedQuery QueryType = "PREPARED_QUERY" // deprecated: use for V1 only QueryTypeService QueryType = "SERVICE" @@ -32,6 +34,7 @@ const ( QueryTypeWorkload QueryType = "WORKLOAD" // V2-only ) +// Context is used to pass information about the request. type Context struct { Token string DefaultPartition string @@ -39,12 +42,12 @@ type Context struct { DefaultLocality *structs.Locality } +// QueryTenancy is used to filter catalog data based on tenancy. type QueryTenancy struct { - Partition string - Namespace string - SamenessGroup string - Peer string - Datacenter string + EnterpriseMeta acl.EnterpriseMeta + SamenessGroup string + Peer string + Datacenter string } // QueryPayload represents all information needed by the data backend @@ -87,6 +90,7 @@ type Result struct { Target string } +// LookupType is used by the CatalogDataFetcher to properly filter endpoints. type LookupType string const ( @@ -124,10 +128,12 @@ type CatalogDataFetcher interface { FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) } +// QueryProcessor is used to process a Discovery Query and return the results. type QueryProcessor struct { dataFetcher CatalogDataFetcher } +// NewQueryProcessor creates a new QueryProcessor. func NewQueryProcessor(dataFetcher CatalogDataFetcher) *QueryProcessor { return &QueryProcessor{ dataFetcher: dataFetcher, @@ -163,6 +169,7 @@ func (p *QueryProcessor) QueryByName(query *Query, ctx Context) ([]*Result, erro } } +// QueryByIP is used to look up a service or node from an IP address. func (p *QueryProcessor) QueryByIP(ip net.IP, ctx Context) ([]*Result, error) { return p.dataFetcher.FetchRecordsByIp(ctx, ip) } diff --git a/agent/discovery/mock_CatalogDataFetcher.go b/agent/discovery/mock_CatalogDataFetcher.go new file mode 100644 index 0000000000..5a035ecfe3 --- /dev/null +++ b/agent/discovery/mock_CatalogDataFetcher.go @@ -0,0 +1,190 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package discovery + +import ( + config "github.com/hashicorp/consul/agent/config" + mock "github.com/stretchr/testify/mock" + + net "net" +) + +// MockCatalogDataFetcher is an autogenerated mock type for the CatalogDataFetcher type +type MockCatalogDataFetcher struct { + mock.Mock +} + +// FetchEndpoints provides a mock function with given fields: ctx, req, lookupType +func (_m *MockCatalogDataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) { + ret := _m.Called(ctx, req, lookupType) + + var r0 []*Result + var r1 error + if rf, ok := ret.Get(0).(func(Context, *QueryPayload, LookupType) ([]*Result, error)); ok { + return rf(ctx, req, lookupType) + } + if rf, ok := ret.Get(0).(func(Context, *QueryPayload, LookupType) []*Result); ok { + r0 = rf(ctx, req, lookupType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*Result) + } + } + + if rf, ok := ret.Get(1).(func(Context, *QueryPayload, LookupType) error); ok { + r1 = rf(ctx, req, lookupType) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FetchNodes provides a mock function with given fields: ctx, req +func (_m *MockCatalogDataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) { + ret := _m.Called(ctx, req) + + var r0 []*Result + var r1 error + if rf, ok := ret.Get(0).(func(Context, *QueryPayload) ([]*Result, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(Context, *QueryPayload) []*Result); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*Result) + } + } + + if rf, ok := ret.Get(1).(func(Context, *QueryPayload) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FetchPreparedQuery provides a mock function with given fields: ctx, req +func (_m *MockCatalogDataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { + ret := _m.Called(ctx, req) + + var r0 []*Result + var r1 error + if rf, ok := ret.Get(0).(func(Context, *QueryPayload) ([]*Result, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(Context, *QueryPayload) []*Result); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*Result) + } + } + + if rf, ok := ret.Get(1).(func(Context, *QueryPayload) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FetchRecordsByIp provides a mock function with given fields: ctx, ip +func (_m *MockCatalogDataFetcher) FetchRecordsByIp(ctx Context, ip net.IP) ([]*Result, error) { + ret := _m.Called(ctx, ip) + + var r0 []*Result + var r1 error + if rf, ok := ret.Get(0).(func(Context, net.IP) ([]*Result, error)); ok { + return rf(ctx, ip) + } + if rf, ok := ret.Get(0).(func(Context, net.IP) []*Result); ok { + r0 = rf(ctx, ip) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*Result) + } + } + + if rf, ok := ret.Get(1).(func(Context, net.IP) error); ok { + r1 = rf(ctx, ip) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FetchVirtualIP provides a mock function with given fields: ctx, req +func (_m *MockCatalogDataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, error) { + ret := _m.Called(ctx, req) + + var r0 *Result + var r1 error + if rf, ok := ret.Get(0).(func(Context, *QueryPayload) (*Result, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(Context, *QueryPayload) *Result); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*Result) + } + } + + if rf, ok := ret.Get(1).(func(Context, *QueryPayload) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FetchWorkload provides a mock function with given fields: ctx, req +func (_m *MockCatalogDataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, error) { + ret := _m.Called(ctx, req) + + var r0 *Result + var r1 error + if rf, ok := ret.Get(0).(func(Context, *QueryPayload) (*Result, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(Context, *QueryPayload) *Result); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*Result) + } + } + + if rf, ok := ret.Get(1).(func(Context, *QueryPayload) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// LoadConfig provides a mock function with given fields: _a0 +func (_m *MockCatalogDataFetcher) LoadConfig(_a0 *config.RuntimeConfig) { + _m.Called(_a0) +} + +// NewMockCatalogDataFetcher creates a new instance of MockCatalogDataFetcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCatalogDataFetcher(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCatalogDataFetcher { + mock := &MockCatalogDataFetcher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/discovery/query_fetcher_v1.go b/agent/discovery/query_fetcher_v1.go index b6f885778e..b1a4553c33 100644 --- a/agent/discovery/query_fetcher_v1.go +++ b/agent/discovery/query_fetcher_v1.go @@ -4,11 +4,15 @@ package discovery import ( + "context" "net" "sync/atomic" "time" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/consul/agent/config" + "github.com/hashicorp/consul/agent/structs" ) const ( @@ -17,6 +21,7 @@ const ( maxRecurseRecords = 5 ) +// v1DataFetcherDynamicConfig is used to store the dynamic configuration of the V1 data fetcher. type v1DataFetcherDynamicConfig struct { allowStale bool maxStale time.Duration @@ -25,16 +30,27 @@ type v1DataFetcherDynamicConfig struct { onlyPassing bool } +// V1DataFetcher is used to fetch data from the V1 catalog. type V1DataFetcher struct { dynamicConfig atomic.Value + logger hclog.Logger + + rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error } -func NewV1DataFetcher(config *config.RuntimeConfig) *V1DataFetcher { - f := &V1DataFetcher{} +// NewV1DataFetcher creates a new V1 data fetcher. +func NewV1DataFetcher(config *config.RuntimeConfig, + rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error, + logger hclog.Logger) *V1DataFetcher { + f := &V1DataFetcher{ + rpcFunc: rpcFunc, + logger: logger, + } f.LoadConfig(config) return f } +// LoadConfig loads the configuration for the V1 data fetcher. func (f *V1DataFetcher) LoadConfig(config *config.RuntimeConfig) { dynamicConfig := &v1DataFetcherDynamicConfig{ allowStale: config.DNSAllowStale, @@ -48,26 +64,55 @@ func (f *V1DataFetcher) LoadConfig(config *config.RuntimeConfig) { // TODO (v2-dns): Implementation of the V1 data fetcher +// FetchNodes fetches A/AAAA/CNAME func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) { return nil, nil } +// FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services func (f *V1DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) { return nil, nil } +// FetchVirtualIP fetches A/AAAA records for virtual IPs func (f *V1DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, error) { - return nil, nil + args := structs.ServiceSpecificRequest{ + // The datacenter of the request is not specified because cross-datacenter virtual IP + // queries are not supported. This guard rail is in place because virtual IPs are allocated + // within a DC, therefore their uniqueness is not guaranteed globally. + PeerName: req.Tenancy.Peer, + ServiceName: req.Name, + EnterpriseMeta: req.Tenancy.EnterpriseMeta, + QueryOptions: structs.QueryOptions{ + Token: ctx.Token, + }, + } + + var out string + if err := f.rpcFunc(context.Background(), "Catalog.VirtualIPForService", &args, &out); err != nil { + return nil, err + } + + result := &Result{ + Address: out, + Type: ResultTypeVirtual, + } + return result, nil } +// FetchRecordsByIp is used for PTR requests to look up a service/node from an IP. func (f *V1DataFetcher) FetchRecordsByIp(ctx Context, ip net.IP) ([]*Result, error) { return nil, nil } +// FetchWorkload fetches a single Result associated with +// V2 Workload. V2-only. func (f *V1DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, error) { return nil, nil } +// FetchPreparedQuery evaluates the results of a prepared query. +// deprecated in V2 func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { return nil, nil } diff --git a/agent/discovery/query_fetcher_v1_ce_test.go b/agent/discovery/query_fetcher_v1_ce_test.go new file mode 100644 index 0000000000..7376e50560 --- /dev/null +++ b/agent/discovery/query_fetcher_v1_ce_test.go @@ -0,0 +1,11 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !consulent + +package discovery + +import "github.com/hashicorp/consul/acl" + +// defaultEntMeta is the default enterprise meta used for testing. +var defaultEntMeta = acl.EnterpriseMeta{} diff --git a/agent/discovery/query_fetcher_v1_test.go b/agent/discovery/query_fetcher_v1_test.go new file mode 100644 index 0000000000..23864b4d7e --- /dev/null +++ b/agent/discovery/query_fetcher_v1_test.go @@ -0,0 +1,105 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package discovery + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + cachetype "github.com/hashicorp/consul/agent/cache-types" + "github.com/hashicorp/consul/agent/config" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/sdk/testutil" +) + +// Test_FetchService tests the FetchService method in scenarios where the RPC +// call succeeds and fails. +func Test_FetchVirtualIP(t *testing.T) { + // set these to confirm that RPC call does not use them for this particular RPC + rc := &config.RuntimeConfig{ + DNSAllowStale: true, + DNSMaxStale: 100, + DNSUseCache: true, + DNSCacheMaxAge: 100, + } + tests := []struct { + name string + queryPayload *QueryPayload + context Context + expectedResult *Result + expectedErr error + }{ + { + name: "FetchVirtualIP returns result", + queryPayload: &QueryPayload{ + Name: "db", + Tenancy: QueryTenancy{ + Peer: "test-peer", + EnterpriseMeta: defaultEntMeta, + }, + }, + context: Context{ + Token: "test-token", + }, + expectedResult: &Result{ + Address: "192.168.10.10", + Type: ResultTypeVirtual, + }, + expectedErr: nil, + }, + { + name: "FetchVirtualIP returns error", + queryPayload: &QueryPayload{ + Name: "db", + Tenancy: QueryTenancy{ + Peer: "test-peer", + EnterpriseMeta: defaultEntMeta, + }, + }, + context: Context{ + Token: "test-token", + }, + expectedResult: nil, + expectedErr: errors.New("test-error"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := testutil.Logger(t) + mockRPC := cachetype.NewMockRPC(t) + mockRPC.On("RPC", mock.Anything, "Catalog.VirtualIPForService", mock.Anything, mock.Anything). + Return(tc.expectedErr). + Run(func(args mock.Arguments) { + req := args.Get(2).(*structs.ServiceSpecificRequest) + + // validate RPC options are not set from config for the VirtuaLIPForService RPC + require.False(t, req.AllowStale) + require.Equal(t, time.Duration(0), req.MaxStaleDuration) + require.False(t, req.UseCache) + require.Equal(t, time.Duration(0), req.MaxAge) + + // validate RPC options are set correctly from the queryPayload and context + require.Equal(t, tc.queryPayload.Tenancy.Peer, req.PeerName) + require.Equal(t, tc.queryPayload.Tenancy.EnterpriseMeta, req.EnterpriseMeta) + require.Equal(t, tc.context.Token, req.QueryOptions.Token) + + if tc.expectedErr == nil { + // set the out parameter to ensure that it is used to formulate the result.Address + reply := args.Get(3).(*string) + *reply = tc.expectedResult.Address + } + }) + df := NewV1DataFetcher(rc, mockRPC.RPC, logger) + + result, err := df.FetchVirtualIP(tc.context, tc.queryPayload) + require.Equal(t, tc.expectedErr, err) + require.Equal(t, tc.expectedResult, result) + }) + } +} diff --git a/agent/discovery/query_fetcher_v2.go b/agent/discovery/query_fetcher_v2.go index fd2ce87ed9..2cefc656af 100644 --- a/agent/discovery/query_fetcher_v2.go +++ b/agent/discovery/query_fetcher_v2.go @@ -10,20 +10,24 @@ import ( "github.com/hashicorp/consul/agent/config" ) +// v2DataFetcherDynamicConfig is used to store the dynamic configuration of the V2 data fetcher. type v2DataFetcherDynamicConfig struct { onlyPassing bool } +// V2DataFetcher is used to fetch data from the V2 catalog. type V2DataFetcher struct { dynamicConfig atomic.Value } +// NewV2DataFetcher creates a new V2 data fetcher. func NewV2DataFetcher(config *config.RuntimeConfig) *V2DataFetcher { f := &V2DataFetcher{} f.LoadConfig(config) return f } +// LoadConfig loads the configuration for the V2 data fetcher. func (f *V2DataFetcher) LoadConfig(config *config.RuntimeConfig) { dynamicConfig := &v2DataFetcherDynamicConfig{ onlyPassing: config.DNSOnlyPassing, @@ -33,26 +37,34 @@ func (f *V2DataFetcher) LoadConfig(config *config.RuntimeConfig) { // TODO (v2-dns): Implementation of the V2 data fetcher +// FetchNodes fetches A/AAAA/CNAME func (f *V2DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) { return nil, nil } +// FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services func (f *V2DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) { return nil, nil } +// FetchVirtualIP fetches A/AAAA records for virtual IPs func (f *V2DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, error) { return nil, nil } +// FetchRecordsByIp is used for PTR requests to look up a service/node from an IP. func (f *V2DataFetcher) FetchRecordsByIp(ctx Context, ip net.IP) ([]*Result, error) { return nil, nil } +// FetchWorkload is used to fetch a single workload from the V2 catalog. +// V2-only. func (f *V2DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, error) { return nil, nil } +// FetchPreparedQuery is used to fetch a prepared query from the V2 catalog. +// Deprecated in V2. func (f *V2DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { return nil, nil } diff --git a/agent/dns/query_locality.go b/agent/dns/query_locality.go new file mode 100644 index 0000000000..6ad2ea27bd --- /dev/null +++ b/agent/dns/query_locality.go @@ -0,0 +1,42 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import "github.com/hashicorp/consul/acl" + +// queryLocality is the locality parsed from a DNS query. +type queryLocality struct { + // datacenter is the datacenter parsed from a label that has an explicit datacenter part. + // Example query: .virtual..ns..ap..dc.consul + datacenter string + + // peer is the peer name parsed from a label that has explicit parts. + // Example query: .virtual..ns..peer..ap.consul + peer string + + // peerOrDatacenter is parsed from DNS queries where the datacenter and peer name are + // specified in the same query part. + // Example query: .virtual..consul + // + // Note that this field should only be a "peer" for virtual queries, since virtual IPs should + // not be shared between datacenters. In all other cases, it should be considered a DC. + peerOrDatacenter string + + acl.EnterpriseMeta +} + +// EffectiveDatacenter returns the datacenter parsed from a query, or a default +// value if none is specified. +func (l queryLocality) EffectiveDatacenter(defaultDC string) string { + // Prefer the value parsed from a query with explicit parts: .ns..ap..dc + if l.datacenter != "" { + return l.datacenter + } + // Fall back to the ambiguously parsed DC or Peer. + if l.peerOrDatacenter != "" { + return l.peerOrDatacenter + } + // If all are empty, use a default value. + return defaultDC +} diff --git a/agent/dns/query_locality_ce.go b/agent/dns/query_locality_ce.go new file mode 100644 index 0000000000..23a080c0df --- /dev/null +++ b/agent/dns/query_locality_ce.go @@ -0,0 +1,57 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !consulent + +package dns + +import ( + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/config" +) + +// ParseLocality can parse peer name or datacenter from a DNS query's labels. +// Peer name is parsed from the same query part that datacenter is, so given this ambiguity +// we parse a "peerOrDatacenter". The caller or RPC handler are responsible for disambiguating. +func ParseLocality(labels []string, defaultEnterpriseMeta acl.EnterpriseMeta, _ enterpriseDNSConfig) (queryLocality, bool) { + locality := queryLocality{ + EnterpriseMeta: defaultEnterpriseMeta, + } + + switch len(labels) { + case 2, 4: + // Support the following formats: + // - [..dc] + // - [..peer] + for i := 0; i < len(labels); i += 2 { + switch labels[i+1] { + case "dc": + locality.datacenter = labels[i] + case "peer": + locality.peer = labels[i] + default: + return queryLocality{}, false + } + } + // Return error when both datacenter and peer are specified. + if locality.datacenter != "" && locality.peer != "" { + return queryLocality{}, false + } + return locality, true + case 1: + return queryLocality{peerOrDatacenter: labels[0]}, true + + case 0: + return queryLocality{}, true + } + + return queryLocality{}, false +} + +// enterpriseDNSConfig is the configuration for enterprise DNS. +type enterpriseDNSConfig struct{} + +// getEnterpriseDNSConfig returns the enterprise DNS configuration. +func getEnterpriseDNSConfig(conf *config.RuntimeConfig) enterpriseDNSConfig { + return enterpriseDNSConfig{} +} diff --git a/agent/dns/query_locality_ce_test.go b/agent/dns/query_locality_ce_test.go new file mode 100644 index 0000000000..59d49ed336 --- /dev/null +++ b/agent/dns/query_locality_ce_test.go @@ -0,0 +1,60 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !consulent + +package dns + +import ( + "github.com/hashicorp/consul/acl" +) + +func getTestCases() []testCaseParseLocality { + testCases := []testCaseParseLocality{ + { + name: "test [..dc]", + labels: []string{"test-dc", "dc"}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{ + EnterpriseMeta: acl.EnterpriseMeta{}, + datacenter: "test-dc", + }, + expectedOK: true, + }, + { + name: "test [..peer]", + labels: []string{"test-peer", "peer"}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{ + EnterpriseMeta: acl.EnterpriseMeta{}, + peer: "test-peer", + }, + expectedOK: true, + }, + { + name: "test 1 label", + labels: []string{"test-peer"}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{ + EnterpriseMeta: acl.EnterpriseMeta{}, + peerOrDatacenter: "test-peer", + }, + expectedOK: true, + }, + { + name: "test 0 labels", + labels: []string{}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{}, + expectedOK: true, + }, + { + name: "test 3 labels returns not found", + labels: []string{"test-dc", "dc", "test-blah"}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{}, + expectedOK: false, + }, + } + return testCases +} diff --git a/agent/dns/query_locality_test.go b/agent/dns/query_locality_test.go new file mode 100644 index 0000000000..84008a7a6d --- /dev/null +++ b/agent/dns/query_locality_test.go @@ -0,0 +1,74 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import ( + "testing" + + "github.com/hashicorp/consul/acl" + "github.com/stretchr/testify/require" +) + +type testCaseParseLocality struct { + name string + labels []string + defaultMeta acl.EnterpriseMeta + enterpriseDNSConfig enterpriseDNSConfig + expectedResult queryLocality + expectedOK bool +} + +func Test_ParseLocality(t *testing.T) { + testCases := getTestCases() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualResult, actualOK := ParseLocality(tc.labels, tc.defaultMeta, tc.enterpriseDNSConfig) + require.Equal(t, tc.expectedOK, actualOK) + require.Equal(t, tc.expectedResult, actualResult) + + }) + } + +} + +func Test_EffectiveDatacenter(t *testing.T) { + type testCase struct { + name string + queryLocality queryLocality + defaultDC string + expected string + } + testCases := []testCase{ + { + name: "return datacenter first", + queryLocality: queryLocality{ + datacenter: "test-dc", + peerOrDatacenter: "test-peer", + }, + defaultDC: "default-dc", + expected: "test-dc", + }, + { + name: "return peerOrDatacenter second", + queryLocality: queryLocality{ + peerOrDatacenter: "test-peer", + }, + defaultDC: "default-dc", + expected: "test-peer", + }, + { + name: "return defaultDC as fallback", + queryLocality: queryLocality{}, + defaultDC: "default-dc", + expected: "default-dc", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := tc.queryLocality.EffectiveDatacenter(tc.defaultDC) + require.Equal(t, tc.expected, got) + }) + } +} diff --git a/agent/dns/router.go b/agent/dns/router.go index 08f5f8c814..bf6b29f077 100644 --- a/agent/dns/router.go +++ b/agent/dns/router.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/go-hclog" "github.com/miekg/dns" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/discovery" "github.com/hashicorp/consul/agent/structs" @@ -54,6 +55,8 @@ type RouterDynamicConfig struct { // TTLStrict sets TTLs to service by full name match. It Has higher priority than TTLRadix TTLStrict map[string]time.Duration UDPAnswerLimit int + + enterpriseDNSConfig } type SOAConfig struct { @@ -74,15 +77,15 @@ type DiscoveryQueryProcessor interface { // Router replaces miekg/dns.ServeMux with a simpler router that only checks for the 2-3 valid domains // that Consul supports and forwards to a single DiscoveryQueryProcessor handler. If there is no match, it will recurse. type Router struct { - processor DiscoveryQueryProcessor - domain string - altDomain string - logger hclog.Logger + processor DiscoveryQueryProcessor + domain string + altDomain string + datacenter string + logger hclog.Logger tokenFunc func() string - defaultNamespace string - defaultPartition string + defaultEntMeta acl.EnterpriseMeta // TODO (v2-dns): default locality for request context? @@ -99,16 +102,13 @@ func NewRouter(cfg Config) (*Router, error) { altDomain := dns.CanonicalName(cfg.AgentConfig.DNSAltDomain) // TODO (v2-dns): need to figure out tenancy information here in a way that work for V2 and V1 - router := &Router{ - processor: cfg.Processor, - domain: domain, - altDomain: altDomain, - logger: cfg.Logger.Named(logging.DNS), - tokenFunc: cfg.TokenFunc, - // TODO (v2-dns): see tenancy question above - //defaultPartition: ?, - //defaultNamespace: ?, + processor: cfg.Processor, + domain: domain, + altDomain: altDomain, + logger: cfg.Logger.Named(logging.DNS), + tokenFunc: cfg.TokenFunc, + defaultEntMeta: cfg.EntMeta, } if err := router.ReloadConfig(cfg.AgentConfig); err != nil { @@ -117,7 +117,7 @@ func NewRouter(cfg Config) (*Router, error) { return router, nil } -// HandleRequest is used to process and individual DNS request. It returns a message in success or fail cases. +// HandleRequest is used to process an individual DNS request. It returns a message in success or fail cases. func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg { cfg := r.dynamicConfig.Load().(*RouterDynamicConfig) @@ -138,21 +138,8 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAdd return createServerFailureResponse(req, cfg, false) } - var results []*discovery.Result - switch reqType { - case requestTypeName: - //query, err := r.buildQuery(req, reqCtx) - //results, err = r.processor.QueryByName(query, reqCtx) - // TODO (v2-dns): implement requestTypeName - // This will call discovery.QueryByName - r.logger.Error("requestTypeName not implemented") - case requestTypeIP: - // TODO (v2-dns): implement requestTypeIP - // This will call discovery.QueryByIP - r.logger.Error("requestTypeIP not implemented") - case requestTypeAddress: - results, err = buildAddressResults(req) - } + results, err := r.getQueryResults(req, reqCtx, reqType, cfg) + if err != nil && errors.Is(err, errNameNotFound) { r.logger.Error("name not found", "name", req.Question[0].Name) return createNameErrorResponse(req, cfg, responseDomain) @@ -172,6 +159,26 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAdd return resp } +// getQueryResults returns a discovery.Result from a DNS message. +func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, reqType requestType, cfg *RouterDynamicConfig) ([]*discovery.Result, error) { + switch reqType { + case requestTypeName: + query, err := buildQueryFromDNSMessage(req, r.domain, r.altDomain, cfg, r.defaultEntMeta) + if err != nil { + r.logger.Error("error building discovery query from DNS request", "error", err) + return nil, err + } + return r.processor.QueryByName(query, reqCtx) + case requestTypeIP: + // TODO (v2-dns): implement requestTypeIP + // This will call discovery.QueryByIP + return nil, errors.New("requestTypeIP not implemented") + case requestTypeAddress: + return buildAddressResults(req) + } + return nil, errors.New("invalid request type") +} + // ServeDNS implements the miekg/dns.Handler interface. // This is a standard DNS listener, so we inject a default request context based on the agent's config. func (r *Router) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { @@ -190,6 +197,7 @@ func (r *Router) ReloadConfig(newCfg *config.RuntimeConfig) error { return nil } +// defaultAgentDNSRequestContext returns a default request context based on the agent's config. func (r *Router) defaultAgentDNSRequestContext() discovery.Context { return discovery.Context{ Token: r.tokenFunc(), @@ -198,6 +206,7 @@ func (r *Router) defaultAgentDNSRequestContext() discovery.Context { } } +// validateAndNormalizeRequest validates the DNS request and normalizes the request name. func validateAndNormalizeRequest(req *dns.Msg) error { // like upstream miekg/dns, we require at least one question, // but we will only answer the first. @@ -255,6 +264,7 @@ func (r *Router) parseDomain(req *dns.Msg) (requestType, string, bool) { return "", "", true } +// serializeQueryResults converts a discovery.Result into a DNS message. func (r *Router) serializeQueryResults(req *dns.Msg, results []*discovery.Result, cfg *RouterDynamicConfig, responseDomain string) (*dns.Msg, error) { resp := new(dns.Msg) resp.SetReply(req) @@ -271,6 +281,7 @@ func (r *Router) serializeQueryResults(req *dns.Msg, results []*discovery.Result return resp, nil } +// stripSuffix strips off the suffixes that may have been added to the request name. func stripSuffix(target string) (string, bool) { enableFailover := false @@ -289,6 +300,7 @@ func stripSuffix(target string) (string, bool) { return target, enableFailover } +// isAddrSubdomain returns true if the domain is a valid addr subdomain. func isAddrSubdomain(domain string) bool { labels := dns.SplitDomainName(domain) @@ -316,6 +328,7 @@ func getDynamicRouterConfig(conf *config.RuntimeConfig) (*RouterDynamicConfig, e Refresh: conf.DNSSOA.Refresh, Retry: conf.DNSSOA.Retry, }, + enterpriseDNSConfig: getEnterpriseDNSConfig(conf), } // TODO (v2-dns): add service TTL recalculation @@ -324,10 +337,12 @@ func getDynamicRouterConfig(conf *config.RuntimeConfig) (*RouterDynamicConfig, e return cfg, nil } +// canRecurse returns true if the router can recurse on the request. func canRecurse(cfg *RouterDynamicConfig) bool { return len(cfg.Recursors) > 0 } +// createServerFailureResponse returns a SERVFAIL message. func createServerFailureResponse(req *dns.Msg, cfg *RouterDynamicConfig, recursionAvailable bool) *dns.Msg { // Return a SERVFAIL message m := &dns.Msg{} @@ -339,6 +354,7 @@ func createServerFailureResponse(req *dns.Msg, cfg *RouterDynamicConfig, recursi return m } +// createRefusedResponse returns a REFUSED message. func createRefusedResponse(req *dns.Msg) *dns.Msg { // Return a REFUSED message m := &dns.Msg{} @@ -346,6 +362,7 @@ func createRefusedResponse(req *dns.Msg) *dns.Msg { return m } +// createNameErrorResponse returns a NXDOMAIN message. func createNameErrorResponse(req *dns.Msg, cfg *RouterDynamicConfig, domain string) *dns.Msg { // Return a NXDOMAIN message m := &dns.Msg{} @@ -376,6 +393,7 @@ func createNameErrorResponse(req *dns.Msg, cfg *RouterDynamicConfig, domain stri return m } +// buildAddressResults returns a discovery.Result from a DNS request for addr. records. func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) { domain := dns.CanonicalName(req.Question[0].Name) labels := dns.SplitDomainName(domain) @@ -399,6 +417,7 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) { }, nil } +// buildQueryFromDNSMessage appends the discovery result to the dns message. func appendResultToDNSResponse(result *discovery.Result, req *dns.Msg, resp *dns.Msg, _ string, cfg *RouterDynamicConfig) { ip, ok := convertToIp(result) @@ -411,7 +430,7 @@ func appendResultToDNSResponse(result *discovery.Result, req *dns.Msg, resp *dns var ttl uint32 switch result.Type { - case discovery.ResultTypeNode: + case discovery.ResultTypeNode, discovery.ResultTypeVirtual: ttl = uint32(cfg.NodeTTL / time.Second) case discovery.ResultTypeService: // TODO (v2-dns): implement service TTL using the radix tree @@ -444,6 +463,7 @@ func appendResultToDNSResponse(result *discovery.Result, req *dns.Msg, resp *dns resp.Answer = append(resp.Answer, record) } +// convertToIp converts a discovery.Result to a net.IP. func convertToIp(result *discovery.Result) (net.IP, bool) { ip := net.ParseIP(result.Address) if ip == nil { @@ -452,6 +472,7 @@ func convertToIp(result *discovery.Result) (net.IP, bool) { return ip, true } +// n A or AAAA record for the given name and IP. // Note: we might want to pass in the Query Name here, which is used in addr. and virtual. queries // since there is only ever one result. Right now choosing to leave it off for simplification. func makeRecord(name string, ip net.IP, ttl uint32) (dns.RR, bool) { diff --git a/agent/dns/router_query.go b/agent/dns/router_query.go new file mode 100644 index 0000000000..9815e65aaa --- /dev/null +++ b/agent/dns/router_query.go @@ -0,0 +1,125 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import ( + "errors" + "strings" + + "github.com/miekg/dns" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/discovery" +) + +// buildQueryFromDNSMessage returns a discovery.Query from a DNS message. +func buildQueryFromDNSMessage(req *dns.Msg, domain, altDomain string, cfg *RouterDynamicConfig, defaultEntMeta acl.EnterpriseMeta) (*discovery.Query, error) { + queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain) + + locality, ok := ParseLocality(querySuffixes, defaultEntMeta, cfg.enterpriseDNSConfig) + if !ok { + return nil, errors.New("invalid locality") + } + + // TODO(v2-dns): This needs to be deprecated. + peerName := locality.peer + if peerName == "" { + // If the peer name was not explicitly defined, fall back to the ambiguously-parsed version. + peerName = locality.peerOrDatacenter + } + + return &discovery.Query{ + QueryType: queryType, + QueryPayload: discovery.QueryPayload{ + Name: queryParts[len(queryParts)-1], + Tenancy: discovery.QueryTenancy{ + EnterpriseMeta: locality.EnterpriseMeta, + // v2-dns: revisit if we need this after the rest of this works. + // SamenessGroup: "", + // The datacenter of the request is not specified because cross-datacenter virtual IP + // queries are not supported. This guard rail is in place because virtual IPs are allocated + // within a DC, therefore their uniqueness is not guaranteed globally. + Peer: peerName, + Datacenter: locality.datacenter, + }, + // TODO(v2-dns): what should these be? + //PortName: "", + //Tag: "", + //RemoteAddr: nil, + //DisableFailover: false, + }, + }, nil +} + +// getQueryTypePartsAndSuffixesFromDNSMessage returns the query type, the parts, and suffixes of the query name. +func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain string) (queryType discovery.QueryType, parts []string, suffixes []string) { + // Get the QName without the domain suffix + qName := trimDomainFromQuestionName(req.Question[0].Name, domain, altDomain) + + // Split into the label parts + labels := dns.SplitDomainName(qName) + + done := false + for i := len(labels) - 1; i >= 0 && !done; i-- { + queryType = getQueryTypeFromLabels(labels[i]) + switch queryType { + case discovery.QueryTypeInvalid: + // If we don't recognize the query type, we keep going until we find one we do. + case discovery.QueryTypeService, + discovery.QueryTypeConnect, discovery.QueryTypeVirtual, discovery.QueryTypeIngress, + discovery.QueryTypeNode, discovery.QueryTypePreparedQuery: + parts = labels[:i] + suffixes = labels[i+1:] + done = true + default: + // If this is a SRV query the "service" label is optional, we add it back to use the + // existing code-path. + if req.Question[0].Qtype == dns.TypeSRV && strings.HasPrefix(labels[i], "_") { + parts = labels[:i+1] + suffixes = labels[i+1:] + done = true + } + } + } + + return queryType, parts, suffixes +} + +// trimDomainFromQuestionName returns the question name without the domain suffix. +func trimDomainFromQuestionName(questionName, domain, altDomain string) string { + qName := strings.ToLower(dns.Fqdn(questionName)) + longer := domain + shorter := altDomain + + if len(shorter) > len(longer) { + longer, shorter = shorter, longer + } + + if strings.HasSuffix(qName, "."+strings.TrimLeft(longer, ".")) { + return strings.TrimSuffix(qName, longer) + } + return strings.TrimSuffix(qName, shorter) +} + +// getQueryTypeFromLabels returns the query type from the labels. +func getQueryTypeFromLabels(label string) discovery.QueryType { + switch label { + case "service": + return discovery.QueryTypeService + case "connect": + return discovery.QueryTypeConnect + case "virtual": + return discovery.QueryTypeVirtual + case "ingress": + return discovery.QueryTypeIngress + case "node": + return discovery.QueryTypeNode + case "query": + return discovery.QueryTypePreparedQuery + case "workload": + return discovery.QueryTypeWorkload + default: + return discovery.QueryTypeInvalid + } +} diff --git a/agent/dns/router_query_ce_test.go b/agent/dns/router_query_ce_test.go new file mode 100644 index 0000000000..13337dfbe0 --- /dev/null +++ b/agent/dns/router_query_ce_test.go @@ -0,0 +1,81 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !consulent + +package dns + +import ( + "github.com/miekg/dns" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/discovery" +) + +func getBuildQueryFromDNSMessageTestCases() []testCaseBuildQueryFromDNSMessage { + testCases := []testCaseBuildQueryFromDNSMessage{ + // virtual ip queries + { + name: "test A 'virtual.' query, ipv4 response", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "db.virtual.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + expectedQuery: &discovery.Query{ + QueryType: discovery.QueryTypeVirtual, + QueryPayload: discovery.QueryPayload{ + Name: "db", + PortName: "", + Tag: "", + Tenancy: discovery.QueryTenancy{ + EnterpriseMeta: acl.EnterpriseMeta{}, + SamenessGroup: "", + Peer: "consul", + Datacenter: "", + }, + DisableFailover: false, + }, + }, + }, + { + name: "test A 'virtual.' with peer query, ipv4 response", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "db.virtual.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + expectedQuery: &discovery.Query{ + QueryType: discovery.QueryTypeVirtual, + QueryPayload: discovery.QueryPayload{ + Name: "db", + PortName: "", + Tag: "", + Tenancy: discovery.QueryTenancy{ + EnterpriseMeta: acl.EnterpriseMeta{}, + SamenessGroup: "", + Peer: "consul", // this gets set in the query building after ParseLocality processes. + Datacenter: "", + }, + DisableFailover: false, + }, + }, + }, + } + + return testCases +} diff --git a/agent/dns/router_query_test.go b/agent/dns/router_query_test.go new file mode 100644 index 0000000000..14348bfb96 --- /dev/null +++ b/agent/dns/router_query_test.go @@ -0,0 +1,37 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import ( + "testing" + + "github.com/miekg/dns" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/discovery" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testCaseBuildQueryFromDNSMessage is a test case for the buildQueryFromDNSMessage function. +type testCaseBuildQueryFromDNSMessage struct { + name string + request *dns.Msg + requestContext *discovery.Context + expectedQuery *discovery.Query +} + +// Test_buildQueryFromDNSMessage tests the buildQueryFromDNSMessage function. +func Test_buildQueryFromDNSMessage(t *testing.T) { + + testCases := getBuildQueryFromDNSMessageTestCases() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + query, err := buildQueryFromDNSMessage(tc.request, "domain", "altDomain", &RouterDynamicConfig{}, acl.EnterpriseMeta{}) + require.NoError(t, err) + assert.Equal(t, tc.expectedQuery, query) + }) + } +} diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index 1735053d67..2f1dc133e6 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -10,8 +10,10 @@ import ( "github.com/hashicorp/go-hclog" "github.com/miekg/dns" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/discovery" ) @@ -25,20 +27,19 @@ import ( // 4. Something case insensitive func Test_HandleRequest(t *testing.T) { - type testCase struct { - name string - agentConfig *config.RuntimeConfig // This will override the default test Router Config - mockProcessorResponseByName []*discovery.Result // These will be fed to the mock processor to be returned in order - mockProcessorResponseByIP []*discovery.Result - mockProcessorError error - request *dns.Msg - requestContext *discovery.Context - remoteAddress net.Addr - response *dns.Msg + name string + agentConfig *config.RuntimeConfig // This will override the default test Router Config + configureDataFetcher func(fetcher discovery.CatalogDataFetcher) + mockProcessorError error + request *dns.Msg + requestContext *discovery.Context + remoteAddress net.Addr + response *dns.Msg } testCases := []testCase{ + // addr queries { name: "test A 'addr.' query, ipv4 response", request: &dns.Msg{ @@ -421,10 +422,114 @@ func Test_HandleRequest(t *testing.T) { }, }, }, + // virtual ip queries - we will test just the A record, since the + // AAAA and SRV records are handled the same way and the complete + // set of addr tests above cover the rest of the cases. + { + name: "test A 'virtual.' query, ipv4 response", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "c000020a.virtual.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + fetcher.(*discovery.MockCatalogDataFetcher).On("FetchVirtualIP", + mock.Anything, mock.Anything).Return(&discovery.Result{ + Address: "240.0.0.2", + Type: discovery.ResultTypeVirtual, + }, nil) + }, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "c000020a.virtual.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "c000020a.virtual.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("240.0.0.2"), + }, + }, + }, + }, + { + name: "test A 'virtual.' query, ipv6 response", + // Since we asked for an A record, the AAAA record that resolves from the address is attached as an extra + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "20010db800010002cafe000000001337.virtual.dc1.consul", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + fetcher.(*discovery.MockCatalogDataFetcher).On("FetchVirtualIP", + mock.Anything, mock.Anything).Return(&discovery.Result{ + Address: "2001:db8:1:2:cafe::1337", + Type: discovery.ResultTypeVirtual, + }, nil) + }, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "20010db800010002cafe000000001337.virtual.dc1.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + Extra: []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: "20010db800010002cafe000000001337.virtual.dc1.consul.", + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 123, + }, + AAAA: net.ParseIP("2001:db8:1:2:cafe::1337"), + }, + }, + }, + }, } run := func(t *testing.T, tc testCase) { - cfg := buildDNSConfig(tc.agentConfig, tc.mockProcessorResponseByName, tc.mockProcessorResponseByIP, tc.mockProcessorError) + cdf := &discovery.MockCatalogDataFetcher{} + if tc.configureDataFetcher != nil { + tc.configureDataFetcher(cdf) + } + cfg := buildDNSConfig(tc.agentConfig, cdf, tc.mockProcessorError) router, err := NewRouter(cfg) require.NoError(t, err) @@ -433,7 +538,6 @@ func Test_HandleRequest(t *testing.T) { if ctx == nil { ctx = &discovery.Context{} } - actual := router.HandleRequest(tc.request, *ctx, tc.remoteAddress) require.Equal(t, tc.response, actual) } @@ -446,7 +550,7 @@ func Test_HandleRequest(t *testing.T) { } -func buildDNSConfig(agentConfig *config.RuntimeConfig, _ []*discovery.Result, _ []*discovery.Result, _ error) Config { +func buildDNSConfig(agentConfig *config.RuntimeConfig, cdf discovery.CatalogDataFetcher, _ error) Config { cfg := Config{ AgentConfig: &config.RuntimeConfig{ DNSDomain: "consul", @@ -458,9 +562,9 @@ func buildDNSConfig(agentConfig *config.RuntimeConfig, _ []*discovery.Result, _ Minttl: 4, }, }, - EntMeta: nil, + EntMeta: acl.EnterpriseMeta{}, Logger: hclog.NewNullLogger(), - Processor: nil, // TODO (v2-dns): build this from a mock with the reponses loaded + Processor: discovery.NewQueryProcessor(cdf), TokenFunc: func() string { return "" }, } diff --git a/agent/dns/server.go b/agent/dns/server.go index 2559ff8c1c..55cc06326f 100644 --- a/agent/dns/server.go +++ b/agent/dns/server.go @@ -26,12 +26,13 @@ type Server struct { // Config represent all the DNS configuration required to construct a DNS server. type Config struct { AgentConfig *config.RuntimeConfig - EntMeta *acl.EnterpriseMeta + EntMeta acl.EnterpriseMeta Logger hclog.Logger Processor DiscoveryQueryProcessor TokenFunc func() string } +// NewServer creates a new DNS server. func NewServer(config Config) (*Server, error) { router, err := NewRouter(config) if err != nil { @@ -45,6 +46,7 @@ func NewServer(config Config) (*Server, error) { return srv, nil } +// ListenAndServe starts the DNS server. func (d *Server) ListenAndServe(network, addr string, notif func()) error { d.Server = &dns.Server{ Addr: addr, @@ -63,6 +65,7 @@ func (d *Server) ReloadConfig(newCfg *config.RuntimeConfig) error { return d.Router.ReloadConfig(newCfg) } +// Shutdown stops the DNS server. func (d *Server) Shutdown() { if d.Server != nil { d.logger.Info("Stopping server", diff --git a/agent/dns_ce.go b/agent/dns_ce.go index bc6216e30a..4eb74442fa 100644 --- a/agent/dns_ce.go +++ b/agent/dns_ce.go @@ -12,6 +12,10 @@ import ( "github.com/hashicorp/consul/agent/config" ) +// NOTE: these functions have also been copied to agent/dns package for dns v2. +// If you change these functions, please also change the ones in agent/dns as well. +// These v1 versions will soon be deprecated. + type enterpriseDNSConfig struct{} func getEnterpriseDNSConfig(conf *config.RuntimeConfig) enterpriseDNSConfig { diff --git a/agent/dns_ce_test.go b/agent/dns_ce_test.go index 976b7e7ebe..c01f62d938 100644 --- a/agent/dns_ce_test.go +++ b/agent/dns_ce_test.go @@ -129,3 +129,53 @@ func TestDNS_CE_PeeredServices(t *testing.T) { assertARec(t, q.Answer[0], "web-proxy.service.peer1.peer.consul.", "199.0.0.1") }) } + +func getTestCasesParseLocality() []testCaseParseLocality { + testCases := []testCaseParseLocality{ + { + name: "test [..dc]", + labels: []string{"test-dc", "dc"}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{ + EnterpriseMeta: acl.EnterpriseMeta{}, + datacenter: "test-dc", + }, + expectedOK: true, + }, + { + name: "test [..peer]", + labels: []string{"test-peer", "peer"}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{ + EnterpriseMeta: acl.EnterpriseMeta{}, + peer: "test-peer", + }, + expectedOK: true, + }, + { + name: "test 1 label", + labels: []string{"test-peer"}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{ + EnterpriseMeta: acl.EnterpriseMeta{}, + peerOrDatacenter: "test-peer", + }, + expectedOK: true, + }, + { + name: "test 0 labels", + labels: []string{}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{}, + expectedOK: true, + }, + { + name: "test 3 labels returns not found", + labels: []string{"test-dc", "dc", "test-blah"}, + enterpriseDNSConfig: enterpriseDNSConfig{}, + expectedResult: queryLocality{}, + expectedOK: false, + }, + } + return testCases +} diff --git a/agent/dns_test.go b/agent/dns_test.go index 083e4b9536..8f584172cd 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" @@ -119,7 +120,6 @@ func dnsTXT(src string, txt []string) *dns.TXT { func getVersionHCL(enableV2 bool) map[string]string { versions := map[string]string{ "DNS: v1 / Catalog: v1": "", - //"DNS: v2 / Catalog: v1": `experiments=["v2dns"]`, } if enableV2 { @@ -670,9 +670,9 @@ func TestDNS_VirtualIPLookup(t *testing.T) { t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { - a := StartTestAgent(t, TestAgent{HCL: experimentsHCL, Overrides: `peering = { test_allow_peer_registrations = true }`}) + a := StartTestAgent(t, TestAgent{HCL: experimentsHCL, Overrides: `peering = { test_allow_peer_registrations = true } log_level = "debug"`}) defer a.Shutdown() testrpc.WaitForLeader(t, a.RPC, "dc1") @@ -3841,3 +3841,71 @@ func TestPerfectlyRandomChoices(t *testing.T) { }) } } + +type testCaseParseLocality struct { + name string + labels []string + defaultEntMeta acl.EnterpriseMeta + enterpriseDNSConfig enterpriseDNSConfig + expectedResult queryLocality + expectedOK bool +} + +func Test_ParseLocality(t *testing.T) { + testCases := getTestCasesParseLocality() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + d := &DNSServer{ + defaultEnterpriseMeta: tc.defaultEntMeta, + } + actualResult, actualOK := d.parseLocality(tc.labels, &dnsConfig{ + enterpriseDNSConfig: tc.enterpriseDNSConfig, + }) + require.Equal(t, tc.expectedOK, actualOK) + require.Equal(t, tc.expectedResult, actualResult) + + }) + } + +} + +func Test_EffectiveDatacenter(t *testing.T) { + type testCase struct { + name string + queryLocality queryLocality + defaultDC string + expected string + } + testCases := []testCase{ + { + name: "return datacenter first", + queryLocality: queryLocality{ + datacenter: "test-dc", + peerOrDatacenter: "test-peer", + }, + defaultDC: "default-dc", + expected: "test-dc", + }, + { + name: "return PeerOrDatacenter second", + queryLocality: queryLocality{ + peerOrDatacenter: "test-peer", + }, + defaultDC: "default-dc", + expected: "test-peer", + }, + { + name: "return defaultDC as fallback", + queryLocality: queryLocality{}, + defaultDC: "default-dc", + expected: "default-dc", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := tc.queryLocality.effectiveDatacenter(tc.defaultDC) + require.Equal(t, tc.expected, got) + }) + } +}