feat(v2dns): catalog v2 workload query support (#20466)

This commit is contained in:
Dan Stough 2024-02-02 18:29:38 -05:00 committed by GitHub
parent deca6a49bd
commit 9602b43183
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 1434 additions and 570 deletions

View File

@ -1111,7 +1111,7 @@ func (a *Agent) listenAndServeV2DNS() error {
// Check the catalog version and decide which implementation of the data fetcher to implement
if a.baseDeps.UseV2Resources() {
a.catalogDataFetcher = discovery.NewV2DataFetcher(a.config)
a.catalogDataFetcher = discovery.NewV2DataFetcher(a.config, a.delegate.ResourceServiceClient(), a.logger.Named("catalog-data-fetcher"))
} else {
a.catalogDataFetcher = discovery.NewV1DataFetcher(a.config,
a.AgentEnterpriseMeta(),

View File

@ -7,14 +7,13 @@ import (
"fmt"
"net"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/structs"
)
var (
ErrNoData = fmt.Errorf("no data")
ErrECSNotGlobal = fmt.Errorf("ECS response is not global")
ErrNoData = fmt.Errorf("no data")
ErrNotFound = fmt.Errorf("not found")
ErrNotSupported = fmt.Errorf("not supported")
)
@ -65,18 +64,16 @@ const (
// Context is used to pass information about the request.
type Context struct {
Token string
DefaultPartition string
DefaultNamespace string
DefaultLocality *structs.Locality
Token string
}
// QueryTenancy is used to filter catalog data based on tenancy.
type QueryTenancy struct {
EnterpriseMeta acl.EnterpriseMeta
SamenessGroup string
Peer string
Datacenter string
Namespace string
Partition string
SamenessGroup string
Peer string
Datacenter string
}
// QueryPayload represents all information needed by the data backend
@ -89,7 +86,7 @@ type QueryPayload struct {
Tenancy QueryTenancy // tenancy includes any additional labels specified before the domain
// v2 fields only
DisableFailover bool
EnableFailover bool
}
// ResultType indicates the Consul resource that a discovery record represents.
@ -107,11 +104,12 @@ const (
// It is the responsibility of the DNS encoder to know what to do with
// each Result, based on the query type.
type Result struct {
Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes.
Weight uint32 // SRV queries
Port uint32 // SRV queries
Metadata map[string]string // Used to collect metadata into TXT Records
Type ResultType // Used to reconstruct the fqdn name of the resource
Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes.
Weight uint32 // SRV queries
PortName string // Used to generate a fgdn when a specifc port was queried
PortNumber uint32 // SRV queries
Metadata map[string]string // Used to collect metadata into TXT Records
Type ResultType // Used to reconstruct the fqdn name of the resource
// Used in SRV & PTR queries to point at an A/AAAA Record.
Target string
@ -121,9 +119,10 @@ type Result struct {
// ResultTenancy is used to reconstruct the fqdn name of the resource.
type ResultTenancy struct {
PeerName string
Datacenter string
EnterpriseMeta acl.EnterpriseMeta // TODO (v2-dns): need something that is compatible with the V2 catalog
Namespace string
Partition string
PeerName string
Datacenter string
}
// LookupType is used by the CatalogDataFetcher to properly filter endpoints.
@ -138,6 +137,8 @@ const (
// CatalogDataFetcher is an interface that abstracts data collection
// for Discovery queries. It is assumed that the instantiation also
// includes any agent configuration that influences catalog queries.
//
//go:generate mockery --name CatalogDataFetcher --inpackage
type CatalogDataFetcher interface {
// LoadConfig is used to hot-reload the data fetcher with new agent config.
LoadConfig(config *config.RuntimeConfig)
@ -162,6 +163,13 @@ type CatalogDataFetcher interface {
// FetchPreparedQuery evaluates the results of a prepared query.
// deprecated in V2
FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error)
// NormalizeRequest mutates the original request based on data fetcher configuration, like
// defaulting tenancy to the agent's partition.
NormalizeRequest(req *QueryPayload)
// ValidateRequest throws an error is any of the input fields are invalid for this data fetcher.
ValidateRequest(ctx Context, req *QueryPayload) error
}
// QueryProcessor is used to process a Discovery Query and return the results.
@ -178,6 +186,12 @@ func NewQueryProcessor(dataFetcher CatalogDataFetcher) *QueryProcessor {
// QueryByName is used to look up a service, node, workload, or prepared query.
func (p *QueryProcessor) QueryByName(query *Query, ctx Context) ([]*Result, error) {
if err := p.dataFetcher.ValidateRequest(ctx, &query.QueryPayload); err != nil {
return nil, err
}
p.dataFetcher.NormalizeRequest(&query.QueryPayload)
switch query.QueryType {
case QueryTypeNode:
return p.dataFetcher.FetchNodes(ctx, &query.QueryPayload)

View File

@ -0,0 +1,221 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"errors"
"net"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
var (
testContext = Context{
Token: "bar",
}
testErr = errors.New("test error")
testIP = net.ParseIP("1.2.3.4")
testPayload = QueryPayload{
Name: "foo",
}
testResult = &Result{
Address: "1.2.3.4",
Type: ResultTypeNode, // This isn't correct for some test cases, but we are only asserting the right data fetcher functions are called
Target: "foo",
}
)
func TestQueryByName(t *testing.T) {
type testCase struct {
name string
reqType QueryType
configureDataFetcher func(*testing.T, *MockCatalogDataFetcher)
expectedResults []*Result
expectedError error
}
run := func(t *testing.T, tc testCase) {
fetcher := NewMockCatalogDataFetcher(t)
tc.configureDataFetcher(t, fetcher)
qp := NewQueryProcessor(fetcher)
q := Query{
QueryType: tc.reqType,
QueryPayload: testPayload,
}
results, err := qp.QueryByName(&q, testContext)
if tc.expectedError != nil {
require.Error(t, err)
require.True(t, errors.Is(err, tc.expectedError))
require.Nil(t, results)
return
}
require.NoError(t, err)
require.Equal(t, tc.expectedResults, results)
}
testCases := []testCase{
{
name: "query node",
reqType: QueryTypeNode,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchNodes", mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query service",
reqType: QueryTypeService,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query connect",
reqType: QueryTypeConnect,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query ingress",
reqType: QueryTypeIngress,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query virtual ip",
reqType: QueryTypeVirtual,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchVirtualIP", mock.Anything, mock.Anything).Return(testResult, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query workload",
reqType: QueryTypeWorkload,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchWorkload", mock.Anything, mock.Anything).Return(testResult, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query prepared query",
reqType: QueryTypePreparedQuery,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchPreparedQuery", mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "returns error from validation",
reqType: QueryTypePreparedQuery,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(testErr)
},
expectedError: testErr,
},
{
name: "returns error from fetcher",
reqType: QueryTypePreparedQuery,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchPreparedQuery", mock.Anything, mock.Anything).Return(nil, testErr)
},
expectedError: testErr,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestQueryByIP(t *testing.T) {
type testCase struct {
name string
configureDataFetcher func(*testing.T, *MockCatalogDataFetcher)
expectedResults []*Result
expectedError error
}
run := func(t *testing.T, tc testCase) {
fetcher := NewMockCatalogDataFetcher(t)
tc.configureDataFetcher(t, fetcher)
qp := NewQueryProcessor(fetcher)
results, err := qp.QueryByIP(testIP, testContext)
if tc.expectedError != nil {
require.Error(t, err)
require.True(t, errors.Is(err, tc.expectedError))
require.Nil(t, results)
return
}
require.NoError(t, err)
require.Equal(t, tc.expectedResults, results)
}
testCases := []testCase{
{
name: "query by IP",
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("FetchRecordsByIp", mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "returns error from fetcher",
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("FetchRecordsByIp", mock.Anything, mock.Anything).Return(nil, testErr)
},
expectedError: testErr,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
// Code generated by mockery v2.37.1. DO NOT EDIT.
package discovery
@ -175,6 +175,25 @@ func (_m *MockCatalogDataFetcher) LoadConfig(_a0 *config.RuntimeConfig) {
_m.Called(_a0)
}
// NormalizeRequest provides a mock function with given fields: req
func (_m *MockCatalogDataFetcher) NormalizeRequest(req *QueryPayload) {
_m.Called(req)
}
// ValidateRequest provides a mock function with given fields: ctx, req
func (_m *MockCatalogDataFetcher) ValidateRequest(ctx Context, req *QueryPayload) error {
ret := _m.Called(ctx, req)
var r0 error
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) error); ok {
r0 = rf(ctx, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewMockCatalogDataFetcher creates a new instance of MockCatalogDataFetcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockCatalogDataFetcher(t interface {

View File

@ -12,15 +12,15 @@ import (
"time"
"github.com/armon/go-metrics"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
)
const (
@ -31,16 +31,14 @@ const (
// v1DataFetcherDynamicConfig is used to store the dynamic configuration of the V1 data fetcher.
type v1DataFetcherDynamicConfig struct {
// Default request tenancy
defaultEntMeta acl.EnterpriseMeta
datacenter string
datacenter string
// Catalog configuration
allowStale bool
maxStale time.Duration
useCache bool
cacheMaxAge time.Duration
onlyPassing bool
enterpriseDNSConfig EnterpriseDNSConfig
allowStale bool
maxStale time.Duration
useCache bool
cacheMaxAge time.Duration
onlyPassing bool
}
// V1DataFetcher is used to fetch data from the V1 catalog.
@ -82,15 +80,12 @@ func NewV1DataFetcher(config *config.RuntimeConfig,
// LoadConfig loads the configuration for the V1 data fetcher.
func (f *V1DataFetcher) LoadConfig(config *config.RuntimeConfig) {
dynamicConfig := &v1DataFetcherDynamicConfig{
allowStale: config.DNSAllowStale,
maxStale: config.DNSMaxStale,
useCache: config.DNSUseCache,
cacheMaxAge: config.DNSCacheMaxAge,
onlyPassing: config.DNSOnlyPassing,
enterpriseDNSConfig: GetEnterpriseDNSConfig(config),
datacenter: config.Datacenter,
// TODO (v2-dns): make this work
//defaultEntMeta: config.EnterpriseRuntimeConfig.DefaultEntMeta,
allowStale: config.DNSAllowStale,
maxStale: config.DNSMaxStale,
useCache: config.DNSUseCache,
cacheMaxAge: config.DNSCacheMaxAge,
onlyPassing: config.DNSOnlyPassing,
datacenter: config.Datacenter,
}
f.dynamicConfig.Store(dynamicConfig)
}
@ -107,7 +102,7 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e
Token: ctx.Token,
AllowStale: cfg.allowStale,
},
EnterpriseMeta: req.Tenancy.EnterpriseMeta,
EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy),
}
out, err := f.fetchNode(cfg, args)
if err != nil {
@ -128,8 +123,9 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e
Metadata: node.Meta,
Target: node.Node,
Tenancy: ResultTenancy{
EnterpriseMeta: cfg.defaultEntMeta,
Datacenter: cfg.datacenter,
// Namespace is not required because nodes are not namespaced
Partition: node.GetEnterpriseMeta().PartitionOrDefault(),
Datacenter: node.Datacenter,
},
})
@ -155,7 +151,7 @@ func (f *V1DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result,
// within a DC, therefore their uniqueness is not guaranteed globally.
PeerName: req.Tenancy.Peer,
ServiceName: req.Name,
EnterpriseMeta: req.Tenancy.EnterpriseMeta,
EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy),
QueryOptions: structs.QueryOptions{
Token: ctx.Token,
},
@ -176,6 +172,10 @@ func (f *V1DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result,
// FetchRecordsByIp is used for PTR requests to look up a service/node from an IP.
// The search is performed in the agent's partition and over all namespaces (or those allowed by the ACL token).
func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, error) {
if ip == nil {
return nil, ErrNotSupported
}
configCtx := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig)
targetIP := ip.String()
@ -200,8 +200,9 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result,
Type: ResultTypeNode,
Target: n.Node,
Tenancy: ResultTenancy{
EnterpriseMeta: f.defaultEnterpriseMeta,
Datacenter: configCtx.datacenter,
Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(),
Partition: f.defaultEnterpriseMeta.PartitionOrDefault(),
Datacenter: configCtx.datacenter,
},
})
return results, nil
@ -229,8 +230,9 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result,
Type: ResultTypeService,
Target: n.ServiceName,
Tenancy: ResultTenancy{
EnterpriseMeta: f.defaultEnterpriseMeta,
Datacenter: configCtx.datacenter,
Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(),
Partition: f.defaultEnterpriseMeta.PartitionOrDefault(),
Datacenter: configCtx.datacenter,
},
})
return results, nil
@ -257,6 +259,16 @@ func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*R
return nil, nil
}
func (f *V1DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error {
if req.EnableFailover {
return ErrNotSupported
}
if req.PortName != "" {
return ErrNotSupported
}
return validateEnterpriseTenancy(req.Tenancy)
}
// fetchNode is used to look up a node in the Consul catalog within NodeServices.
// If the config is set to UseCache, it will get the record from the agent cache.
func (f *V1DataFetcher) fetchNode(cfg *v1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
@ -336,7 +348,7 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa
UseCache: cfg.useCache,
MaxStaleDuration: cfg.maxStale,
},
EnterpriseMeta: req.Tenancy.EnterpriseMeta,
EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy),
}
out, _, err := f.rpcFuncForServiceNodes(context.TODO(), args)
@ -365,15 +377,16 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa
address, target, resultType := getAddressTargetAndResultType(node)
results = append(results, &Result{
Address: address,
Type: resultType,
Target: target,
Weight: uint32(findWeight(node)),
Port: uint32(f.translateServicePortFunc(node.Node.Datacenter, node.Service.Port, node.Service.TaggedAddresses)),
Metadata: node.Node.Meta,
Address: address,
Type: resultType,
Target: target,
Weight: uint32(findWeight(node)),
PortNumber: uint32(f.translateServicePortFunc(node.Node.Datacenter, node.Service.Port, node.Service.TaggedAddresses)),
Metadata: node.Node.Meta,
Tenancy: ResultTenancy{
EnterpriseMeta: cfg.defaultEntMeta,
Datacenter: cfg.datacenter,
Namespace: node.Service.NamespaceOrEmpty(),
Partition: node.Service.PartitionOrEmpty(),
Datacenter: node.Node.Datacenter,
},
})
}

View File

@ -8,8 +8,26 @@ package discovery
import (
"errors"
"fmt"
"github.com/hashicorp/consul/acl"
)
func (f *V1DataFetcher) NormalizeRequest(req *QueryPayload) {
// Nothing to do for CE
return
}
func validateEnterpriseTenancy(req QueryTenancy) error {
if req.Namespace != "" || req.Partition != "" {
return ErrNotSupported
}
return nil
}
func queryTenancyToEntMeta(_ QueryTenancy) acl.EnterpriseMeta {
return acl.EnterpriseMeta{}
}
// fetchServiceFromSamenessGroup fetches a service from a sameness group.
func (f *V1DataFetcher) fetchServiceFromSamenessGroup(ctx Context, req *QueryPayload, cfg *v1DataFetcherDynamicConfig) ([]*Result, error) {
f.logger.Debug(fmt.Sprintf("fetchServiceFromSamenessGroup - req: %+v", req))

View File

@ -5,7 +5,7 @@
package discovery
import "github.com/hashicorp/consul/acl"
// defaultEntMeta is the default enterprise meta used for testing.
var defaultEntMeta = acl.EnterpriseMeta{}
const (
defaultTestNamespace = ""
defaultTestPartition = ""
)

View File

@ -9,12 +9,11 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/structs"
@ -43,8 +42,9 @@ func Test_FetchVirtualIP(t *testing.T) {
queryPayload: &QueryPayload{
Name: "db",
Tenancy: QueryTenancy{
Peer: "test-peer",
EnterpriseMeta: defaultEntMeta,
Peer: "test-peer",
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
context: Context{
@ -61,9 +61,9 @@ func Test_FetchVirtualIP(t *testing.T) {
queryPayload: &QueryPayload{
Name: "db",
Tenancy: QueryTenancy{
Peer: "test-peer",
EnterpriseMeta: defaultEntMeta,
},
Peer: "test-peer",
Namespace: defaultTestNamespace,
Partition: defaultTestPartition},
},
context: Context{
Token: "test-token",
@ -90,7 +90,8 @@ func Test_FetchVirtualIP(t *testing.T) {
// validate RPC options are set correctly from the queryPayload and context
require.Equal(t, tc.queryPayload.Tenancy.Peer, req.PeerName)
require.Equal(t, tc.queryPayload.Tenancy.EnterpriseMeta, req.EnterpriseMeta)
require.Equal(t, tc.queryPayload.Tenancy.Namespace, req.EnterpriseMeta.NamespaceOrEmpty())
require.Equal(t, tc.queryPayload.Tenancy.Partition, req.EnterpriseMeta.PartitionOrEmpty())
require.Equal(t, tc.context.Token, req.QueryOptions.Token)
if tc.expectedErr == nil {
@ -143,7 +144,8 @@ func Test_FetchEndpoints(t *testing.T) {
queryPayload: &QueryPayload{
Name: "service-name",
Tenancy: QueryTenancy{
EnterpriseMeta: defaultEntMeta,
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) {
@ -151,12 +153,14 @@ func Test_FetchEndpoints(t *testing.T) {
Nodes: []structs.CheckServiceNode{
{
Node: &structs.Node{
Address: "node-address",
Node: "node-name",
Address: "node-address",
Node: "node-name",
Partition: defaultTestPartition,
},
Service: &structs.NodeService{
Address: "127.0.0.1",
Service: "service-name",
Address: "127.0.0.1",
Service: "service-name",
EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace),
},
},
},
@ -171,6 +175,10 @@ func Test_FetchEndpoints(t *testing.T) {
Target: "service-name",
Type: ResultTypeService,
Weight: 1,
Tenancy: ResultTenancy{
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
},
expectedErr: nil,
@ -180,7 +188,8 @@ func Test_FetchEndpoints(t *testing.T) {
queryPayload: &QueryPayload{
Name: "service-name",
Tenancy: QueryTenancy{
EnterpriseMeta: defaultEntMeta,
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) {
@ -188,12 +197,14 @@ func Test_FetchEndpoints(t *testing.T) {
Nodes: []structs.CheckServiceNode{
{
Node: &structs.Node{
Address: "node-address",
Node: "node-name",
Address: "node-address",
Node: "node-name",
Partition: defaultTestPartition,
},
Service: &structs.NodeService{
Address: "2001:db8:1:2:cafe::1337",
Service: "service-name",
Address: "2001:db8:1:2:cafe::1337",
Service: "service-name",
EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace),
},
},
},
@ -208,6 +219,10 @@ func Test_FetchEndpoints(t *testing.T) {
Target: "service-name",
Type: ResultTypeService,
Weight: 1,
Tenancy: ResultTenancy{
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
},
expectedErr: nil,
@ -217,7 +232,8 @@ func Test_FetchEndpoints(t *testing.T) {
queryPayload: &QueryPayload{
Name: "service-name",
Tenancy: QueryTenancy{
EnterpriseMeta: defaultEntMeta,
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) {
@ -225,12 +241,14 @@ func Test_FetchEndpoints(t *testing.T) {
Nodes: []structs.CheckServiceNode{
{
Node: &structs.Node{
Address: "node-address",
Node: "node-name",
Address: "node-address",
Node: "node-name",
Partition: defaultTestPartition,
},
Service: &structs.NodeService{
Address: "foo",
Service: "service-name",
Address: "foo",
Service: "service-name",
EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace),
},
},
},
@ -245,6 +263,10 @@ func Test_FetchEndpoints(t *testing.T) {
Target: "foo",
Type: ResultTypeNode,
Weight: 1,
Tenancy: ResultTenancy{
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
},
expectedErr: nil,
@ -254,7 +276,8 @@ func Test_FetchEndpoints(t *testing.T) {
queryPayload: &QueryPayload{
Name: "service-name",
Tenancy: QueryTenancy{
EnterpriseMeta: defaultEntMeta,
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) {
@ -262,12 +285,14 @@ func Test_FetchEndpoints(t *testing.T) {
Nodes: []structs.CheckServiceNode{
{
Node: &structs.Node{
Address: "node-address",
Node: "node-name",
Address: "node-address",
Node: "node-name",
Partition: defaultTestPartition,
},
Service: &structs.NodeService{
Address: "",
Service: "service-name",
Address: "",
Service: "service-name",
EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace),
},
},
},
@ -282,6 +307,10 @@ func Test_FetchEndpoints(t *testing.T) {
Target: "node-name",
Type: ResultTypeNode,
Weight: 1,
Tenancy: ResultTenancy{
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
},
expectedErr: nil,

View File

@ -4,10 +4,22 @@
package discovery
import (
"context"
"fmt"
"net"
"strings"
"sync/atomic"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)
// v2DataFetcherDynamicConfig is used to store the dynamic configuration of the V2 data fetcher.
@ -17,12 +29,22 @@ type v2DataFetcherDynamicConfig struct {
// V2DataFetcher is used to fetch data from the V2 catalog.
type V2DataFetcher struct {
client pbresource.ResourceServiceClient
logger hclog.Logger
// Requests inherit the partition of the agent unless otherwise specified.
defaultPartition string
dynamicConfig atomic.Value
}
// NewV2DataFetcher creates a new V2 data fetcher.
func NewV2DataFetcher(config *config.RuntimeConfig) *V2DataFetcher {
f := &V2DataFetcher{}
func NewV2DataFetcher(config *config.RuntimeConfig, client pbresource.ResourceServiceClient, logger hclog.Logger) *V2DataFetcher {
f := &V2DataFetcher{
client: client,
logger: logger,
defaultPartition: config.PartitionOrDefault(),
}
f.LoadConfig(config)
return f
}
@ -35,14 +57,13 @@ func (f *V2DataFetcher) LoadConfig(config *config.RuntimeConfig) {
f.dynamicConfig.Store(dynamicConfig)
}
// TODO (v2-dns): Implementation of the V2 data fetcher
// FetchNodes fetches A/AAAA/CNAME
func (f *V2DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) {
return nil, nil
}
// FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services
// TODO (v2-dns): Validate lookupType
func (f *V2DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) {
return nil, nil
}
@ -53,14 +74,81 @@ func (f *V2DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result,
}
// FetchRecordsByIp is used for PTR requests to look up a service/node from an IP.
// TODO (v2-dns): Validate non-nil IP
func (f *V2DataFetcher) FetchRecordsByIp(ctx Context, ip net.IP) ([]*Result, error) {
return nil, nil
}
// FetchWorkload is used to fetch a single workload from the V2 catalog.
// V2-only.
func (f *V2DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, error) {
return nil, nil
func (f *V2DataFetcher) FetchWorkload(reqContext Context, req *QueryPayload) (*Result, error) {
// Query the resource service for the workload by name and tenancy
resourceReq := pbresource.ReadRequest{
Id: &pbresource.ID{
Name: req.Name,
Type: pbcatalog.WorkloadType,
Tenancy: queryTenancyToResourceTenancy(req.Tenancy),
},
}
f.logger.Debug("fetching workload", "name", req.Name)
resourceCtx := metadata.AppendToOutgoingContext(context.Background(), "x-consul-token", reqContext.Token)
// If the workload is not found, return nil and an error equivalent to NXDOMAIN
response, err := f.client.Read(resourceCtx, &resourceReq)
switch {
case grpcNotFoundErr(err):
f.logger.Debug("workload not found", "name", req.Name)
return nil, ErrNotFound
case err != nil:
f.logger.Error("error fetching workload", "name", req.Name)
return nil, fmt.Errorf("error fetching workload: %w", err)
// default: fallthrough
}
workload := &pbcatalog.Workload{}
data := response.GetResource().GetData()
if err := data.UnmarshalTo(workload); err != nil {
f.logger.Error("error unmarshalling workload", "name", req.Name)
return nil, fmt.Errorf("error unmarshalling workload: %w", err)
}
// TODO: (v2-dns): we will need to intelligently return the right workload address based on either the translate
// address setting or the locality of the requester. Workloads must have at least one.
// We also need to make sure that we filter out unix sockets here.
address := workload.Addresses[0].GetHost()
if strings.HasPrefix(address, "unix://") {
f.logger.Error("unix sockets are currently unsupported in workload results", "name", req.Name)
return nil, ErrNotFound
}
tenancy := response.GetResource().GetId().GetTenancy()
result := &Result{
Address: address,
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: tenancy.GetNamespace(),
Partition: tenancy.GetPartition(),
},
Target: response.GetResource().GetId().GetName(),
}
if req.PortName == "" {
return result, nil
}
// If a port is specified, make sure the workload implements that port name.
for name, port := range workload.Ports {
if name == req.PortName {
result.PortName = req.PortName
result.PortNumber = port.Port
return result, nil
}
}
f.logger.Debug("could not find matching port for workload", "name", req.Name, "port", req.PortName)
// Return an ErrNotFound, which is equivalent to NXDOMAIN
return nil, ErrNotFound
}
// FetchPreparedQuery is used to fetch a prepared query from the V2 catalog.
@ -68,3 +156,45 @@ func (f *V2DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result,
func (f *V2DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) {
return nil, ErrNotSupported
}
func (f *V2DataFetcher) NormalizeRequest(req *QueryPayload) {
// If we do not have an explicit partition in the request, we use the agent's
if req.Tenancy.Partition == "" {
req.Tenancy.Partition = f.defaultPartition
}
}
// ValidateRequest throws an error is any of the deprecated V1 input fields are used in a QueryByName for this data fetcher.
func (f *V2DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error {
if req.Tag != "" {
return ErrNotSupported
}
if req.RemoteAddr != nil {
return ErrNotSupported
}
return nil
}
func queryTenancyToResourceTenancy(qTenancy QueryTenancy) *pbresource.Tenancy {
rTenancy := resource.DefaultNamespacedTenancy()
// If the request has any tenancy specified, it overrides the defaults.
if qTenancy.Namespace != "" {
rTenancy.Namespace = qTenancy.Namespace
}
// In the case of partition, we have the agent's partition as the fallback.
if qTenancy.Partition != "" {
rTenancy.Partition = qTenancy.Partition
}
return rTenancy
}
// grpcNotFoundErr returns true if the error is a gRPC status error with a code of NotFound.
func grpcNotFoundErr(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
return ok && s.Code() == codes.NotFound
}

View File

@ -0,0 +1,259 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"errors"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"
"github.com/hashicorp/consul/agent/config"
mockpbresource "github.com/hashicorp/consul/grpcmocks/proto-public/pbresource"
"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
"github.com/hashicorp/consul/sdk/testutil"
)
// Test_FetchService tests the FetchService method in scenarios where the RPC
// call succeeds and fails.
func Test_FetchWorkload(t *testing.T) {
rc := &config.RuntimeConfig{
DNSOnlyPassing: false,
}
unknownErr := errors.New("I don't feel so good")
tests := []struct {
name string
queryPayload *QueryPayload
context Context
configureMockClient func(mockClient *mockpbresource.ResourceServiceClient_Expecter)
expectedResult *Result
expectedErr error
}{
{
name: "FetchWorkload returns result",
queryPayload: &QueryPayload{
Name: "foo-1234",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: &Result{
Address: "1.2.3.4",
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Target: "foo-1234",
},
expectedErr: nil,
},
{
name: "FetchWorkload for non-existent workload",
queryPayload: &QueryPayload{
Name: "foo-1234",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
input := getTestWorkloadResponse(t, "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(nil, status.Error(codes.NotFound, "not found")).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, input.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: nil,
expectedErr: ErrNotFound,
},
{
name: "FetchWorkload encounters a resource client error",
queryPayload: &QueryPayload{
Name: "foo-1234",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
input := getTestWorkloadResponse(t, "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(nil, unknownErr).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, input.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: nil,
expectedErr: unknownErr,
},
{
name: "FetchWorkload with a matching port",
queryPayload: &QueryPayload{
Name: "foo-1234",
PortName: "api",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: &Result{
Address: "1.2.3.4",
Type: ResultTypeWorkload,
PortName: "api",
PortNumber: 5678,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Target: "foo-1234",
},
expectedErr: nil,
},
{
name: "FetchWorkload with a matching port",
queryPayload: &QueryPayload{
Name: "foo-1234",
PortName: "not-api",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: nil,
expectedErr: ErrNotFound,
},
{
name: "FetchWorkload returns result for non-default tenancy",
queryPayload: &QueryPayload{
Name: "foo-1234",
Tenancy: QueryTenancy{
Namespace: "test-namespace",
Partition: "test-partition",
},
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "test-namespace", "test-partition")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
require.Equal(t, result.GetResource().GetId().GetTenancy().GetNamespace(), req.Id.Tenancy.Namespace)
require.Equal(t, result.GetResource().GetId().GetTenancy().GetPartition(), req.Id.Tenancy.Partition)
})
},
expectedResult: &Result{
Address: "1.2.3.4",
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: "test-namespace",
Partition: "test-partition",
},
Target: "foo-1234",
},
expectedErr: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := testutil.Logger(t)
client := mockpbresource.NewResourceServiceClient(t)
mockClient := client.EXPECT()
tc.configureMockClient(mockClient)
df := NewV2DataFetcher(rc, client, logger)
result, err := df.FetchWorkload(tc.context, tc.queryPayload)
require.True(t, errors.Is(err, tc.expectedErr))
require.Equal(t, tc.expectedResult, result)
})
}
}
func getTestWorkloadResponse(t *testing.T, nsOverride string, partitionOverride string) *pbresource.ReadResponse {
workload := &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{
Host: "1.2.3.4",
Ports: []string{"api"},
},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"api": {
Port: 5678,
},
},
Identity: "test-identity",
}
data, err := anypb.New(workload)
require.NoError(t, err)
resp := &pbresource.ReadResponse{
Resource: &pbresource.Resource{
Id: &pbresource.ID{
Name: "foo-1234",
Type: pbcatalog.WorkloadType,
Tenancy: resource.DefaultNamespacedTenancy(), // TODO (v2-dns): tenancy
},
Data: data,
},
}
if nsOverride != "" {
resp.Resource.Id.Tenancy.Namespace = nsOverride
}
if partitionOverride != "" {
resp.Resource.Id.Tenancy.Partition = partitionOverride
}
return resp
}

View File

@ -1,61 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import "github.com/hashicorp/consul/acl"
// QueryLocality is the locality parsed from a DNS query.
type QueryLocality struct {
// Datacenter is the datacenter parsed from a label that has an explicit datacenter part.
// Example query: <service>.virtual.<namespace>.ns.<partition>.ap.<datacenter>.dc.consul
Datacenter string
// Peer is the peer name parsed from a label that has explicit parts.
// Example query: <service>.virtual.<namespace>.ns.<peer>.peer.<partition>.ap.consul
Peer string
// PeerOrDatacenter is parsed from DNS queries where the datacenter and peer name are
// specified in the same query part.
// Example query: <service>.virtual.<peerOrDatacenter>.consul
//
// Note that this field should only be a "peer" for virtual queries, since virtual IPs should
// not be shared between datacenters. In all other cases, it should be considered a DC.
PeerOrDatacenter string
acl.EnterpriseMeta
}
// EffectiveDatacenter returns the datacenter parsed from a query, or a default
// value if none is specified.
func (l QueryLocality) EffectiveDatacenter(defaultDC string) string {
// Prefer the value parsed from a query with explicit parts: <namespace>.ns.<partition>.ap.<datacenter>.dc
if l.Datacenter != "" {
return l.Datacenter
}
// Fall back to the ambiguously parsed DC or Peer.
if l.PeerOrDatacenter != "" {
return l.PeerOrDatacenter
}
// If all are empty, use a default value.
return defaultDC
}
// GetQueryTenancyBasedOnLocality returns a discovery.QueryTenancy from a DNS message.
func GetQueryTenancyBasedOnLocality(locality QueryLocality, defaultDatacenter string) (QueryTenancy, error) {
datacenter := locality.EffectiveDatacenter(defaultDatacenter)
// Only one of dc or peer can be used.
if locality.Peer != "" {
datacenter = ""
}
return QueryTenancy{
EnterpriseMeta: locality.EnterpriseMeta,
// The datacenter of the request is not specified because cross-datacenter virtual IP
// queries are not supported. This guard rail is in place because virtual IPs are allocated
// within a DC, therefore their uniqueness is not guaranteed globally.
Peer: locality.Peer,
Datacenter: datacenter,
SamenessGroup: "", // this should be nil since the single locality was directly used to configure tenancy.
}, nil
}

View File

@ -1,57 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package discovery
import (
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
)
// ParseLocality can parse peer name or datacenter from a DNS query's labels.
// Peer name is parsed from the same query part that datacenter is, so given this ambiguity
// we parse a "peerOrDatacenter". The caller or RPC handler are responsible for disambiguating.
func ParseLocality(labels []string, defaultEnterpriseMeta acl.EnterpriseMeta, _ EnterpriseDNSConfig) (QueryLocality, bool) {
locality := QueryLocality{
EnterpriseMeta: defaultEnterpriseMeta,
}
switch len(labels) {
case 2, 4:
// Support the following formats:
// - [.<datacenter>.dc]
// - [.<peer>.peer]
for i := 0; i < len(labels); i += 2 {
switch labels[i+1] {
case "dc":
locality.Datacenter = labels[i]
case "peer":
locality.Peer = labels[i]
default:
return QueryLocality{}, false
}
}
// Return error when both datacenter and peer are specified.
if locality.Datacenter != "" && locality.Peer != "" {
return QueryLocality{}, false
}
return locality, true
case 1:
return QueryLocality{PeerOrDatacenter: labels[0]}, true
case 0:
return QueryLocality{}, true
}
return QueryLocality{}, false
}
// EnterpriseDNSConfig is the configuration for enterprise DNS.
type EnterpriseDNSConfig struct{}
// GetEnterpriseDNSConfig returns the enterprise DNS configuration.
func GetEnterpriseDNSConfig(conf *config.RuntimeConfig) EnterpriseDNSConfig {
return EnterpriseDNSConfig{}
}

View File

@ -1,60 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package discovery
import (
"github.com/hashicorp/consul/acl"
)
func getTestCases() []testCaseParseLocality {
testCases := []testCaseParseLocality{
{
name: "test [.<datacenter>.dc]",
labels: []string{"test-dc", "dc"},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{
EnterpriseMeta: acl.EnterpriseMeta{},
Datacenter: "test-dc",
},
expectedOK: true,
},
{
name: "test [.<peer>.peer]",
labels: []string{"test-peer", "peer"},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{
EnterpriseMeta: acl.EnterpriseMeta{},
Peer: "test-peer",
},
expectedOK: true,
},
{
name: "test 1 label",
labels: []string{"test-peer"},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{
EnterpriseMeta: acl.EnterpriseMeta{},
PeerOrDatacenter: "test-peer",
},
expectedOK: true,
},
{
name: "test 0 labels",
labels: []string{},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{},
expectedOK: true,
},
{
name: "test 3 labels returns not found",
labels: []string{"test-dc", "dc", "test-blah"},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{},
expectedOK: false,
},
}
return testCases
}

View File

@ -1,73 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"testing"
"github.com/hashicorp/consul/acl"
"github.com/stretchr/testify/require"
)
type testCaseParseLocality struct {
name string
labels []string
defaultMeta acl.EnterpriseMeta
enterpriseDNSConfig EnterpriseDNSConfig
expectedResult QueryLocality
expectedOK bool
}
func Test_parseLocality(t *testing.T) {
testCases := getTestCases()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualResult, actualOK := ParseLocality(tc.labels, tc.defaultMeta, tc.enterpriseDNSConfig)
require.Equal(t, tc.expectedOK, actualOK)
require.Equal(t, tc.expectedResult, actualResult)
})
}
}
func Test_effectiveDatacenter(t *testing.T) {
type testCase struct {
name string
QueryLocality QueryLocality
defaultDC string
expected string
}
testCases := []testCase{
{
name: "return Datacenter first",
QueryLocality: QueryLocality{
Datacenter: "test-dc",
PeerOrDatacenter: "test-peer",
},
defaultDC: "default-dc",
expected: "test-dc",
},
{
name: "return PeerOrDatacenter second",
QueryLocality: QueryLocality{
PeerOrDatacenter: "test-peer",
},
defaultDC: "default-dc",
expected: "test-peer",
},
{
name: "return defaultDC as fallback",
QueryLocality: QueryLocality{},
defaultDC: "default-dc",
expected: "default-dc",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := tc.QueryLocality.EffectiveDatacenter(tc.defaultDC)
require.Equal(t, tc.expected, got)
})
}
}

View File

@ -4,8 +4,6 @@ package dns
import (
config "github.com/hashicorp/consul/agent/config"
discovery "github.com/hashicorp/consul/agent/discovery"
miekgdns "github.com/miekg/dns"
mock "github.com/stretchr/testify/mock"
@ -19,11 +17,11 @@ type MockDNSRouter struct {
}
// HandleRequest provides a mock function with given fields: req, reqCtx, remoteAddress
func (_m *MockDNSRouter) HandleRequest(req *miekgdns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *miekgdns.Msg {
func (_m *MockDNSRouter) HandleRequest(req *miekgdns.Msg, reqCtx Context, remoteAddress net.Addr) *miekgdns.Msg {
ret := _m.Called(req, reqCtx, remoteAddress)
var r0 *miekgdns.Msg
if rf, ok := ret.Get(0).(func(*miekgdns.Msg, discovery.Context, net.Addr) *miekgdns.Msg); ok {
if rf, ok := ret.Get(0).(func(*miekgdns.Msg, Context, net.Addr) *miekgdns.Msg); ok {
r0 = rf(req, reqCtx, remoteAddress)
} else {
if ret.Get(0) != nil {

89
agent/dns/parser.go Normal file
View File

@ -0,0 +1,89 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
// parsedLabels defines valid DNS labels that are possible for ALL DNS query in Consul. (v1 and v2, CE and ENT)
// It is the job of the parser to populate the struct, the routers to call the query processor,
// and the query processor to validate is the labels.
type parsedLabels struct {
Datacenter string
Namespace string
Partition string
Peer string
PeerOrDatacenter string // deprecated: use Datacenter or Peer
SamenessGroup string
}
// ParseLabels can parse a DNS query's labels and returns a parsedLabels.
// It also does light validation according to invariants across all possible DNS queries for all Consul versions
func parseLabels(labels []string) (*parsedLabels, bool) {
var result parsedLabels
switch len(labels) {
case 2, 4, 6:
// Supports the following formats:
// - [.<namespace>.ns][.<partition>.ap][.<datacenter>.dc]
// - <namespace>.<datacenter>
// - [.<namespace>.ns][.<partition>.ap][.<peer>.peer]
// - [.<samenessGroup>.sg][.<partition>.ap][.<namespace>.ns]
for i := 0; i < len(labels); i += 2 {
switch labels[i+1] {
case "ns":
result.Namespace = labels[i]
case "ap":
result.Partition = labels[i]
case "dc": // TODO (v2-dns): This should also include "cluster" for the new notation.
result.Datacenter = labels[i]
case "sg":
result.SamenessGroup = labels[i]
case "peer":
result.Peer = labels[i]
default:
// The only case in which labels[i+1] is allowed to be a value
// other than ns, ap, or dc is if n == 2 to support the format:
// <namespace>.<datacenter>.
if len(labels) == 2 {
result.PeerOrDatacenter = labels[1]
result.Namespace = labels[0]
return &result, true
}
return nil, false
}
}
// VALIDATIONS
// Return nil result and false boolean when both datacenter and peer are specified.
if result.Datacenter != "" && result.Peer != "" {
return nil, false
}
// Validation e need to validate that this a valid DNS including sg
if result.SamenessGroup != "" && (result.Datacenter != "" || result.Peer != "") {
return nil, false
}
return &result, true
case 1:
result.PeerOrDatacenter = labels[0]
return &result, true
case 0:
return &result, true
}
return &result, false
}
// parsePort looks through the query parts for a named port label.
// It assumes the only valid input format is["<portName>", "port", "<targetName>"].
// The other expected formats are ["<targetName>"] and ["<tag>", "<targetName>"].
// It is expected that the queryProcessor validates if the label is allowed for the query type.
func parsePort(parts []string) string {
// The minimum number of parts would be
if len(parts) != 3 || parts[1] != "port" {
return ""
}
return parts[0]
}

6
agent/dns/parser_test.go Normal file
View File

@ -0,0 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
// TODO (v2-dns): parser tests

View File

@ -14,14 +14,15 @@ import (
"time"
"github.com/armon/go-radix"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/dnsutil"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/logging"
)
@ -40,6 +41,7 @@ const (
var (
errInvalidQuestion = fmt.Errorf("invalid question")
errNameNotFound = fmt.Errorf("name not found")
errNotImplemented = fmt.Errorf("not implemented")
errRecursionFailed = fmt.Errorf("recursion failed")
trailingSpacesRE = regexp.MustCompile(" +$")
@ -47,6 +49,12 @@ var (
// TODO (v2-dns): metrics
// Context is used augment a DNS message with Consul-specific metadata.
type Context struct {
Token string
DefaultPartition string
}
// RouterDynamicConfig is the dynamic configuration that can be hot-reloaded
type RouterDynamicConfig struct {
ARecordLimit int
@ -64,26 +72,6 @@ type RouterDynamicConfig struct {
// TTLStrict sets TTLs to service by full name match. It Has higher priority than TTLRadix
TTLStrict map[string]time.Duration
UDPAnswerLimit int
discovery.EnterpriseDNSConfig
}
// GetTTLForService Find the TTL for a given service.
// return ttl, true if found, 0, false otherwise
func (cfg *RouterDynamicConfig) GetTTLForService(service string) (time.Duration, bool) {
if cfg.TTLStrict != nil {
ttl, ok := cfg.TTLStrict[service]
if ok {
return ttl, true
}
}
if cfg.TTLRadix != nil {
_, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service)
if ok {
return ttlRaw.(time.Duration), true
}
}
return 0, false
}
type SOAConfig struct {
@ -120,10 +108,6 @@ type Router struct {
tokenFunc func() string
defaultEntMeta acl.EnterpriseMeta
// TODO (v2-dns): default locality for request context?
// dynamicConfig stores the config as an atomic value (for hot-reloading).
// It is always of type *RouterDynamicConfig
dynamicConfig atomic.Value
@ -142,13 +126,12 @@ func NewRouter(cfg Config) (*Router, error) {
logger := cfg.Logger.Named(logging.DNS)
router := &Router{
processor: cfg.Processor,
recursor: newRecursor(logger),
domain: domain,
altDomain: altDomain,
logger: logger,
tokenFunc: cfg.TokenFunc,
defaultEntMeta: cfg.EntMeta,
processor: cfg.Processor,
recursor: newRecursor(logger),
domain: domain,
altDomain: altDomain,
logger: logger,
tokenFunc: cfg.TokenFunc,
}
if err := router.ReloadConfig(cfg.AgentConfig); err != nil {
@ -158,13 +141,13 @@ func NewRouter(cfg Config) (*Router, error) {
}
// HandleRequest is used to process an individual DNS request. It returns a message in success or fail cases.
func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg {
func (r *Router) HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.Addr) *dns.Msg {
return r.handleRequestRecursively(req, reqCtx, remoteAddress, maxRecursionLevelDefault)
}
// handleRequestRecursively is used to process an individual DNS request. It will recurse as needed
// a maximum number of times and returns a message in success or fail cases.
func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx discovery.Context,
func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context,
remoteAddress net.Addr, maxRecursionLevel int) *dns.Msg {
configCtx := r.dynamicConfig.Load().(*RouterDynamicConfig)
@ -204,14 +187,28 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx discovery.Context
}
reqType := parseRequestType(req)
results, query, err := r.getQueryResults(req, reqCtx, reqType, configCtx, qName)
results, query, err := r.getQueryResults(req, reqCtx, reqType, qName)
switch {
case errors.Is(err, errNameNotFound):
r.logger.Error("name not found", "name", qName)
ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
// TODO (v2-dns): there is another case here where the discovery service returns "name not found"
case errors.Is(err, errNotImplemented):
r.logger.Error("query not implemented", "name", qName, "type", dns.Type(req.Question[0].Qtype).String())
ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNotImplemented, ecsGlobal)
case errors.Is(err, discovery.ErrNotSupported):
r.logger.Debug("query name syntax not supported", "name", req.Question[0].Name)
ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
case errors.Is(err, discovery.ErrNotFound):
r.logger.Debug("query name not found", "name", req.Question[0].Name)
ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
case errors.Is(err, discovery.ErrNoData):
r.logger.Debug("no data available", "name", qName)
@ -249,13 +246,22 @@ func (r *Router) trimDomain(questionName string) string {
// getTTLForResult returns the TTL for a given result.
func getTTLForResult(name string, query *discovery.Query, cfg *RouterDynamicConfig) uint32 {
switch {
// In the case we are not making a discovery query, such as addr. or arpa. lookups,
// use the node TTL by convention
if query == nil {
return uint32(cfg.NodeTTL / time.Second)
}
switch query.QueryType {
// TODO (v2-dns): currently have to do this related to the results type being changed to node whe
// the v1 data fetcher encounters a blank service address and uses the node address instead.
// we will revisiting this when look at modifying the discovery result struct to
// possibly include additional metadata like the node address.
case query != nil && query.QueryType == discovery.QueryTypeService:
ttl, ok := cfg.GetTTLForService(name)
case discovery.QueryTypeWorkload:
// TODO (v2-dns): we need to discuss what we want to do for workload TTLs
return 0
case discovery.QueryTypeService:
ttl, ok := cfg.getTTLForService(name)
if ok {
return uint32(ttl / time.Second)
}
@ -266,9 +272,7 @@ func getTTLForResult(name string, query *discovery.Query, cfg *RouterDynamicConf
}
// getQueryResults returns a discovery.Result from a DNS message.
func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context,
reqType requestType, cfg *RouterDynamicConfig, qName string) ([]*discovery.Result, *discovery.Query, error) {
var query *discovery.Query
func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestType, qName string) ([]*discovery.Result, *discovery.Query, error) {
switch reqType {
case requestTypeConsul:
// This is a special case of discovery.QueryByName where we know that we need to query the consul service
@ -277,19 +281,26 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context,
QueryType: discovery.QueryTypeService,
QueryPayload: discovery.QueryPayload{
Name: structs.ConsulServiceName,
Tenancy: discovery.QueryTenancy{
// We specify the partition here so that in the case we are a client agent in a non-default partition.
// We don't want the query processors default partition to be used.
// This is a small hack because for V1 CE, this is not the correct default partition name, but we
// need to add something to disambiguate the empty field.
Partition: resource.DefaultPartitionName,
},
},
Limit: 3, // TODO (v2-dns): need to thread this through to the backend and make sure we shuffle the results
}
results, err := r.processor.QueryByName(query, reqCtx)
results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token})
return results, query, err
case requestTypeName:
query, err := buildQueryFromDNSMessage(req, r.domain, r.altDomain, cfg, r.defaultEntMeta, r.datacenter)
query, err := buildQueryFromDNSMessage(req, reqCtx, r.domain, r.altDomain)
if err != nil {
r.logger.Error("error building discovery query from DNS request", "error", err)
return nil, query, err
}
results, err := r.processor.QueryByName(query, reqCtx)
results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token})
if err != nil {
r.logger.Error("error processing discovery query", "error", err)
return nil, query, err
@ -301,17 +312,17 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context,
r.logger.Error("error building IP from DNS request", "name", qName)
return nil, nil, errNameNotFound
}
results, err := r.processor.QueryByIP(ip, reqCtx)
return results, query, err
results, err := r.processor.QueryByIP(ip, discovery.Context{Token: reqCtx.Token})
return results, nil, err
case requestTypeAddress:
results, err := buildAddressResults(req)
if err != nil {
r.logger.Error("error processing discovery query", "error", err)
return nil, query, err
return nil, nil, err
}
return results, query, nil
return results, nil, nil
}
return nil, query, errors.New("invalid request type")
return nil, nil, errors.New("invalid request type")
}
// ServeDNS implements the miekg/dns.Handler interface.
@ -332,6 +343,24 @@ func (r *Router) ReloadConfig(newCfg *config.RuntimeConfig) error {
return nil
}
// getTTLForService Find the TTL for a given service.
// return ttl, true if found, 0, false otherwise
func (cfg *RouterDynamicConfig) getTTLForService(service string) (time.Duration, bool) {
if cfg.TTLStrict != nil {
ttl, ok := cfg.TTLStrict[service]
if ok {
return ttl, true
}
}
if cfg.TTLRadix != nil {
_, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service)
if ok {
return ttlRaw.(time.Duration), true
}
}
return 0, false
}
// Request type is similar to miekg/dns.Type, but correlates to the different query processors we might need to invoke.
type requestType string
@ -391,7 +420,7 @@ func parseRequestType(req *dns.Msg) requestType {
}
// serializeQueryResults converts a discovery.Result into a DNS message.
func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx discovery.Context,
func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx Context,
query *discovery.Query, results []*discovery.Result, cfg *RouterDynamicConfig,
responseDomain string, remoteAddress net.Addr, maxRecursionLevel int) (*dns.Msg, error) {
resp := new(dns.Msg)
@ -430,7 +459,7 @@ func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx discovery.Context,
// appendResultsToDNSResponse builds dns message from the discovery results and
// appends them to the dns response.
func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx discovery.Context,
func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx Context,
query *discovery.Query, resp *dns.Msg, results []*discovery.Result, cfg *RouterDynamicConfig,
responseDomain string, remoteAddress net.Addr, maxRecursionLevel int) {
@ -487,19 +516,19 @@ func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx discovery.Conte
}
// defaultAgentDNSRequestContext returns a default request context based on the agent's config.
func (r *Router) defaultAgentDNSRequestContext() discovery.Context {
return discovery.Context{
func (r *Router) defaultAgentDNSRequestContext() Context {
return Context{
Token: r.tokenFunc(),
// TODO (v2-dns): tenancy information; maybe we choose not to specify and use the default
// attached to the Router (from the agent's config)
// We don't need to specify the agent's partition here because that will be handled further down the stack
// in the query processor.
}
}
// resolveCNAME is used to recursively resolve CNAME records
func (r *Router) resolveCNAME(cfg *RouterDynamicConfig, name string, reqCtx discovery.Context,
func (r *Router) resolveCNAME(cfg *RouterDynamicConfig, name string, reqCtx Context,
remoteAddress net.Addr, maxRecursionLevel int) []dns.RR {
// If the CNAME record points to a Consul address, resolve it internally
// Convert query to lowercase because DNS is case insensitive; d.domain and
// Convert query to lowercase because DNS is case-insensitive; d.domain and
// d.altDomain are already converted
if ln := strings.ToLower(name); strings.HasSuffix(ln, "."+r.domain) || strings.HasSuffix(ln, "."+r.altDomain) {
@ -609,7 +638,6 @@ func getDynamicRouterConfig(conf *config.RuntimeConfig) (*RouterDynamicConfig, e
Refresh: conf.DNSSOA.Refresh,
Retry: conf.DNSSOA.Retry,
},
EnterpriseDNSConfig: discovery.GetEnterpriseDNSConfig(conf),
}
if conf.DNSServiceTTL != nil {
@ -765,7 +793,7 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) {
}
// getAnswerAndExtra creates the dns answer and extra from discovery results.
func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, reqCtx discovery.Context,
func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, reqCtx Context,
query *discovery.Query, cfg *RouterDynamicConfig, domain string, remoteAddress net.Addr,
maxRecursionLevel int) (answer []dns.RR, extra []dns.RR, ns []dns.RR) {
target := newDNSAddress(result.Target)
@ -829,7 +857,7 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req
// getAnswerExtrasForAddressAndTarget creates the dns answer and extra from address and target dnsAddress pairs.
func (r *Router) getAnswerExtrasForAddressAndTarget(address *dnsAddress, target *dnsAddress, req *dns.Msg,
reqCtx discovery.Context, result *discovery.Result, ttl uint32, remoteAddress net.Addr,
reqCtx Context, result *discovery.Result, ttl uint32, remoteAddress net.Addr,
cfg *RouterDynamicConfig, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR) {
qName := req.Question[0].Name
reqType := parseRequestType(req)
@ -854,7 +882,7 @@ func (r *Router) getAnswerExtrasForAddressAndTarget(address *dnsAddress, target
// Target is FQDN that point to IP
case target.IsFQDN() && address.IsIP():
var a, e []dns.RR
if result.Type == discovery.ResultTypeNode {
if result.Type == discovery.ResultTypeNode || result.Type == discovery.ResultTypeWorkload {
// if it is a node record it means the service address pointed to a node
// and the node address was used. So we create an A record for the node address,
// as well as a CNAME for the service to node mapping.
@ -977,7 +1005,7 @@ func makeIPBasedRecord(name string, addr *dnsAddress, ttl uint32) dns.RR {
}
func (r *Router) makeRecordFromFQDN(fqdn string, result *discovery.Result,
req *dns.Msg, reqCtx discovery.Context, cfg *RouterDynamicConfig, ttl uint32,
req *dns.Msg, reqCtx Context, cfg *RouterDynamicConfig, ttl uint32,
remoteAddress net.Addr, maxRecursionLevel int) ([]dns.RR, []dns.RR) {
edns := req.IsEdns0() != nil
q := req.Question[0]
@ -1039,7 +1067,7 @@ func makeSRVRecord(name, target string, result *discovery.Result, ttl uint32) *d
},
Priority: 1,
Weight: uint16(result.Weight),
Port: uint16(result.Port),
Port: uint16(result.PortNumber),
Target: target,
}
}

View File

@ -28,6 +28,9 @@ func canonicalNameForResult(result *discovery.Result, domain string) string {
// Return a simpler format for non-peering nodes.
return fmt.Sprintf("%s.node.%s.%s", result.Target, result.Tenancy.Datacenter, domain)
case discovery.ResultTypeWorkload:
if result.PortName != "" {
return fmt.Sprintf("%s.port.%s.workload.%s", result.PortName, result.Target, domain)
}
return fmt.Sprintf("%s.workload.%s", result.Target, domain)
}
return ""

View File

@ -4,37 +4,39 @@
package dns
import (
"errors"
"strings"
"github.com/miekg/dns"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/discovery"
)
// buildQueryFromDNSMessage returns a discovery.Query from a DNS message.
func buildQueryFromDNSMessage(req *dns.Msg, domain, altDomain string,
cfg *RouterDynamicConfig, defaultEntMeta acl.EnterpriseMeta, defaultDatacenter string) (*discovery.Query, error) {
func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain string) (*discovery.Query, error) {
queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain)
queryTenancy, err := getQueryTenancy(queryType, querySuffixes, defaultEntMeta, cfg, defaultDatacenter)
queryTenancy, err := getQueryTenancy(reqCtx, queryType, querySuffixes)
if err != nil {
return nil, err
}
name, tag := getQueryNameAndTagFromParts(queryType, queryParts)
portName := parsePort(queryParts)
if queryType == discovery.QueryTypeWorkload && req.Question[0].Qtype == dns.TypeSRV {
// Currently we do not support SRV records for workloads
return nil, errNotImplemented
}
return &discovery.Query{
QueryType: queryType,
QueryPayload: discovery.QueryPayload{
Name: name,
Tenancy: queryTenancy,
Tag: tag,
// TODO (v2-dns): what should these be?
//PortName: "",
//RemoteAddr: nil,
//DisableFailover: false,
Name: name,
Tenancy: queryTenancy,
Tag: tag,
PortName: portName,
//RemoteAddr: nil, // TODO (v2-dns): Prepared Queries for V1 Catalog
},
}, nil
}
@ -64,30 +66,48 @@ func getQueryNameAndTagFromParts(queryType discovery.QueryType, queryParts []str
}
// getQueryTenancy returns a discovery.QueryTenancy from a DNS message.
func getQueryTenancy(queryType discovery.QueryType, querySuffixes []string,
defaultEntMeta acl.EnterpriseMeta, cfg *RouterDynamicConfig, defaultDatacenter string) (discovery.QueryTenancy, error) {
if queryType == discovery.QueryTypeService {
return getQueryTenancyForService(querySuffixes, defaultEntMeta, cfg, defaultDatacenter)
func getQueryTenancy(reqCtx Context, queryType discovery.QueryType, querySuffixes []string) (discovery.QueryTenancy, error) {
labels, ok := parseLabels(querySuffixes)
if !ok {
return discovery.QueryTenancy{}, errNameNotFound
}
locality, ok := discovery.ParseLocality(querySuffixes, defaultEntMeta, cfg.EnterpriseDNSConfig)
if !ok {
return discovery.QueryTenancy{}, errors.New("invalid locality")
// If we don't have an explicit partition in the request, try the first fallback
// which was supplied in the request context. The agent's partition will be used as the last fallback
// later in the query processor.
if labels.Partition == "" {
labels.Partition = reqCtx.DefaultPartition
}
// If we have a sameness group, we can return early without further data massage.
if labels.SamenessGroup != "" {
return discovery.QueryTenancy{
Namespace: labels.Namespace,
Partition: labels.Partition,
SamenessGroup: labels.SamenessGroup,
}, nil
}
if queryType == discovery.QueryTypeVirtual {
if locality.Peer == "" {
if labels.Peer == "" {
// If the peer name was not explicitly defined, fall back to the ambiguously-parsed version.
locality.Peer = locality.PeerOrDatacenter
labels.Peer = labels.PeerOrDatacenter
}
}
return discovery.GetQueryTenancyBasedOnLocality(locality, defaultDatacenter)
return discovery.QueryTenancy{
Namespace: labels.Namespace,
Partition: labels.Partition,
Peer: labels.Peer,
Datacenter: labels.Datacenter,
}, nil
}
// getQueryTypePartsAndSuffixesFromDNSMessage returns the query type, the parts, and suffixes of the query name.
func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain string) (queryType discovery.QueryType, parts []string, suffixes []string) {
// Get the QName without the domain suffix
// TODO (v2-dns): we will also need to handle the "failover" and "no-failover" suffixes here.
// They come AFTER the domain. See `stripSuffix` in router.go
qName := trimDomainFromQuestionName(req.Question[0].Name, domain, altDomain)
// Split into the label parts
@ -97,7 +117,7 @@ func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain
for i := len(labels) - 1; i >= 0 && !done; i-- {
queryType = getQueryTypeFromLabels(labels[i])
switch queryType {
case discovery.QueryTypeService,
case discovery.QueryTypeService, discovery.QueryTypeWorkload,
discovery.QueryTypeConnect, discovery.QueryTypeVirtual, discovery.QueryTypeIngress,
discovery.QueryTypeNode, discovery.QueryTypePreparedQuery:
parts = labels[:i]
@ -122,7 +142,7 @@ func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain
// trimDomainFromQuestionName returns the question name without the domain suffix.
func trimDomainFromQuestionName(questionName, domain, altDomain string) string {
qName := strings.ToLower(dns.Fqdn(questionName))
qName := dns.CanonicalName(questionName)
longer := domain
shorter := altDomain

View File

@ -1,24 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package dns
import (
"errors"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/discovery"
)
// getQueryTenancy returns a discovery.QueryTenancy from a DNS message.
func getQueryTenancyForService(querySuffixes []string,
defaultEntMeta acl.EnterpriseMeta, cfg *RouterDynamicConfig, defaultDatacenter string) (discovery.QueryTenancy, error) {
locality, ok := discovery.ParseLocality(querySuffixes, defaultEntMeta, cfg.EnterpriseDNSConfig)
if !ok {
return discovery.QueryTenancy{}, errors.New("invalid locality")
}
return discovery.GetQueryTenancyBasedOnLocality(locality, defaultDatacenter)
}

View File

@ -1,81 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package dns
import (
"github.com/miekg/dns"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/discovery"
)
func getBuildQueryFromDNSMessageTestCases() []testCaseBuildQueryFromDNSMessage {
testCases := []testCaseBuildQueryFromDNSMessage{
// virtual ip queries
{
name: "test A 'virtual.' query, ipv4 response",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
PortName: "",
Tag: "",
Tenancy: discovery.QueryTenancy{
EnterpriseMeta: acl.EnterpriseMeta{},
SamenessGroup: "",
Peer: "consul",
Datacenter: "",
},
DisableFailover: false,
},
},
},
{
name: "test A 'virtual.' with peer query, ipv4 response",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
PortName: "",
Tag: "",
Tenancy: discovery.QueryTenancy{
EnterpriseMeta: acl.EnterpriseMeta{},
SamenessGroup: "",
Peer: "consul", // this gets set in the query building after ParseLocality processes.
Datacenter: "",
},
DisableFailover: false,
},
},
},
}
return testCases
}

View File

@ -7,29 +7,206 @@ import (
"testing"
"github.com/miekg/dns"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/discovery"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/discovery"
)
// testCaseBuildQueryFromDNSMessage is a test case for the buildQueryFromDNSMessage function.
type testCaseBuildQueryFromDNSMessage struct {
name string
request *dns.Msg
requestContext *discovery.Context
requestContext *Context
expectedQuery *discovery.Query
}
// Test_buildQueryFromDNSMessage tests the buildQueryFromDNSMessage function.
func Test_buildQueryFromDNSMessage(t *testing.T) {
testCases := getBuildQueryFromDNSMessageTestCases()
testCases := []testCaseBuildQueryFromDNSMessage{
// virtual ip queries
{
name: "test A 'virtual.' query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tenancy: discovery.QueryTenancy{},
},
},
},
{
name: "test A 'virtual.' with kitchen sink labels",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.banana.ns.orange.ap.foo.peer.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tenancy: discovery.QueryTenancy{
Peer: "foo",
Namespace: "banana",
Partition: "orange",
},
},
},
},
{
name: "test A 'virtual.' with implicit peer",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.foo.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tenancy: discovery.QueryTenancy{
Peer: "foo",
},
},
},
},
{
name: "test A 'virtual.' with implicit peer and namespace query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.frontend.foo.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tenancy: discovery.QueryTenancy{
Namespace: "frontend",
Peer: "foo",
},
},
},
},
{
name: "test A 'workload.'",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.workload.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeWorkload,
QueryPayload: discovery.QueryPayload{
Name: "foo",
Tenancy: discovery.QueryTenancy{},
},
},
},
{
name: "test A 'workload.' with all possible labels",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "api.port.foo.workload.banana.ns.orange.ap.apple.peer.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeWorkload,
QueryPayload: discovery.QueryPayload{
Name: "foo",
PortName: "api",
Tenancy: discovery.QueryTenancy{
Namespace: "banana",
Partition: "orange",
Peer: "apple",
},
},
},
},
{
name: "test sameness group with all possible labels",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.apple.sg.banana.ns.orange.ap.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeService,
QueryPayload: discovery.QueryPayload{
Name: "foo",
Tenancy: discovery.QueryTenancy{
Namespace: "banana",
Partition: "orange",
SamenessGroup: "apple",
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
query, err := buildQueryFromDNSMessage(tc.request, "domain", "altDomain", &RouterDynamicConfig{}, acl.EnterpriseMeta{}, "defaultDatacenter")
context := tc.requestContext
if context == nil {
context = &Context{}
}
query, err := buildQueryFromDNSMessage(tc.request, *context, "consul.", ".")
require.NoError(t, err)
assert.Equal(t, tc.expectedQuery, query)
})

View File

@ -9,11 +9,12 @@ import (
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
@ -29,15 +30,16 @@ import (
// 4. Test the edns settings.
type HandleTestCase struct {
name string
agentConfig *config.RuntimeConfig // This will override the default test Router Config
configureDataFetcher func(fetcher discovery.CatalogDataFetcher)
configureRecursor func(recursor dnsRecursor)
mockProcessorError error
request *dns.Msg
requestContext *discovery.Context
remoteAddress net.Addr
response *dns.Msg
name string
agentConfig *config.RuntimeConfig // This will override the default test Router Config
configureDataFetcher func(fetcher discovery.CatalogDataFetcher)
validateAndNormalizeExpected bool
configureRecursor func(recursor dnsRecursor)
mockProcessorError error
request *dns.Msg
requestContext *Context
remoteAddress net.Addr
response *dns.Msg
}
func Test_HandleRequest(t *testing.T) {
@ -719,6 +721,7 @@ func Test_HandleRequest(t *testing.T) {
Type: discovery.ResultTypeVirtual,
}, nil)
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
@ -768,6 +771,7 @@ func Test_HandleRequest(t *testing.T) {
Type: discovery.ResultTypeVirtual,
}, nil)
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
@ -833,6 +837,7 @@ func Test_HandleRequest(t *testing.T) {
require.Equal(t, structs.ConsulServiceName, req.Name)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
@ -954,6 +959,7 @@ func Test_HandleRequest(t *testing.T) {
require.Equal(t, structs.ConsulServiceName, req.Name)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
@ -1275,6 +1281,7 @@ func Test_HandleRequest(t *testing.T) {
require.Equal(t, "foo", req.Name)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
@ -1309,12 +1316,201 @@ func Test_HandleRequest(t *testing.T) {
},
},
// TODO (v2-dns): add a test to make sure only 3 records are returned
// V2 Workload Lookup
{
name: "workload A query w/ port, returns A record",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := &discovery.Result{
Address: "1.2.3.4",
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{},
PortName: "api",
PortNumber: 5678,
Target: "foo",
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchWorkload", mock.Anything, mock.Anything).
Return(result, nil). //TODO
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Equal(t, "api", req.PortName)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "api.port.foo.workload.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "workload ANY query w/o port, returns A record",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.workload.consul.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := &discovery.Result{
Address: "1.2.3.4",
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{},
Target: "foo",
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchWorkload", mock.Anything, mock.Anything).
Return(result, nil). //TODO
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Empty(t, req.PortName)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.workload.consul.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.workload.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "workload AAAA query with namespace, partition, and cluster id; returns A record",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.workload.bar.ns.baz.ap.dc3.dc.consul.",
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := &discovery.Result{
Address: "1.2.3.4",
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: "bar",
Partition: "baz",
Datacenter: "dc3",
},
Target: "foo",
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchWorkload", mock.Anything, mock.Anything).
Return(result, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Empty(t, req.PortName)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.workload.bar.ns.baz.ap.dc3.dc.consul.",
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.workload.bar.ns.baz.ap.dc3.dc.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
}
testCases = append(testCases, getAdditionalTestCases(t)...)
//testCases = append(testCases, getAdditionalTestCases(t)...)
run := func(t *testing.T, tc HandleTestCase) {
cdf := discovery.NewMockCatalogDataFetcher(t)
if tc.validateAndNormalizeExpected {
cdf.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
cdf.On("NormalizeRequest", mock.Anything).Return()
}
if tc.configureDataFetcher != nil {
tc.configureDataFetcher(cdf)
}
@ -1331,7 +1527,7 @@ func Test_HandleRequest(t *testing.T) {
ctx := tc.requestContext
if ctx == nil {
ctx = &discovery.Context{}
ctx = &Context{}
}
actual := router.HandleRequest(tc.request, *ctx, tc.remoteAddress)
require.Equal(t, tc.response, actual)
@ -1391,7 +1587,7 @@ func TestRouterDynamicConfig_GetTTLForService(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual, ok := cfg.GetTTLForService(tc.inputKey)
actual, ok := cfg.getTTLForService(tc.inputKey)
require.Equal(t, tc.shouldMatch, ok)
require.Equal(t, tc.expectedDuration, actual)
})

View File

@ -7,12 +7,12 @@ import (
"fmt"
"net"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/logging"
)
@ -20,7 +20,7 @@ import (
//
//go:generate mockery --name DNSRouter --inpackage
type DNSRouter interface {
HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg
HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.Addr) *dns.Msg
ServeDNS(w dns.ResponseWriter, req *dns.Msg)
ReloadConfig(newCfg *config.RuntimeConfig) error
}

View File

@ -8,14 +8,14 @@ import (
"fmt"
"net"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/go-hclog"
agentdns "github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/proto-public/pbdns"
)
@ -73,7 +73,7 @@ func (s *ServerV2) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.Q
}
// TODO (v2-dns): parse token and other context metadata from the grpc request/metadata
reqCtx := discovery.Context{
reqCtx := agentdns.Context{
Token: s.TokenFunc(),
}