NET-10685 - Remove dns v2 code (#21598)

* NET-10685 - Remove dns v2 code

* adding missing erro

* add missing license info.
This commit is contained in:
John Murret 2024-08-13 16:53:48 -06:00 committed by GitHub
parent 89618f9e37
commit dcad90639f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 6087 additions and 19191 deletions

View File

@ -48,8 +48,6 @@ import (
"github.com/hashicorp/consul/agent/consul"
rpcRate "github.com/hashicorp/consul/agent/consul/rate"
"github.com/hashicorp/consul/agent/consul/servercert"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/dns"
external "github.com/hashicorp/consul/agent/grpc-external"
grpcDNS "github.com/hashicorp/consul/agent/grpc-external/services/dns"
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
@ -222,7 +220,7 @@ type notifier interface {
Notify(string) error
}
// dnsServer abstracts the V1 and V2 implementations of the DNS server.
// dnsServer abstracts the implementations of the DNS server.
type dnsServer interface {
GetAddr() string
ListenAndServe(string, string, func()) error
@ -354,10 +352,6 @@ type Agent struct {
// dnsServer provides the DNS API
dnsServers []dnsServer
// catalogDataFetcher is used as an interface to the catalog for service discovery
// (aka DNS). Only applicable to the V2 DNS server (agent/dns).
catalogDataFetcher discovery.CatalogDataFetcher
// apiServers listening for connections. If any of these server goroutines
// fail, the agent will be shutdown.
apiServers *apiServers
@ -879,14 +873,8 @@ func (a *Agent) Start(ctx context.Context) error {
}
// start DNS servers
if a.baseDeps.UseV1DNS() {
if err := a.listenAndServeV1DNS(); err != nil {
return err
}
} else {
if err := a.listenAndServeV2DNS(); err != nil {
return err
}
if err := a.listenAndServeDNS(); err != nil {
return err
}
// Configure the http connection limiter.
@ -1065,7 +1053,7 @@ func (a *Agent) listenAndServeGRPC(proxyTracker *proxytracker.ProxyTracker, serv
return nil
}
func (a *Agent) listenAndServeV1DNS() error {
func (a *Agent) listenAndServeDNS() error {
notif := make(chan net.Addr, len(a.config.DNSAddrs))
errCh := make(chan error, len(a.config.DNSAddrs))
for _, addr := range a.config.DNSAddrs {
@ -1117,92 +1105,6 @@ func (a *Agent) listenAndServeV1DNS() error {
return merr.ErrorOrNil()
}
func (a *Agent) listenAndServeV2DNS() error {
// Check the catalog version and decide which implementation of the data fetcher to implement
if a.baseDeps.UseV2Resources() {
a.catalogDataFetcher = discovery.NewV2DataFetcher(a.config, a.delegate.ResourceServiceClient(), a.logger.Named("catalog-data-fetcher"))
} else {
a.catalogDataFetcher = discovery.NewV1DataFetcher(a.config,
a.AgentEnterpriseMeta(),
a.cache.Get,
a.RPC,
a.rpcClientHealth.ServiceNodes,
a.rpcClientConfigEntry.GetSamenessGroup,
a.TranslateServicePort,
a.logger.Named("catalog-data-fetcher"))
}
// Generate a Query Processor with the appropriate data fetcher
processor := discovery.NewQueryProcessor(a.catalogDataFetcher)
notif := make(chan net.Addr, len(a.config.DNSAddrs))
errCh := make(chan error, len(a.config.DNSAddrs))
// create server
cfg := dns.Config{
AgentConfig: a.config,
EntMeta: *a.AgentEnterpriseMeta(),
Logger: a.logger,
Processor: processor,
TokenFunc: a.getTokenFunc(),
TranslateAddressFunc: a.TranslateAddress,
TranslateServiceAddressFunc: a.TranslateServiceAddress,
}
for _, addr := range a.config.DNSAddrs {
s, err := dns.NewServer(cfg)
if err != nil {
return err
}
a.dnsServers = append(a.dnsServers, s)
// start server
a.wgServers.Add(1)
go func(addr net.Addr) {
defer a.wgServers.Done()
err := s.ListenAndServe(addr.Network(), addr.String(), func() { notif <- addr })
if err != nil && !strings.Contains(err.Error(), "accept") {
errCh <- err
}
}(addr)
}
s, err := dns.NewServer(cfg)
if err != nil {
return fmt.Errorf("failed to create grpc dns server: %w", err)
}
// Create a v2 compatible grpc dns server
grpcDNS.NewServerV2(grpcDNS.ConfigV2{
Logger: a.logger.Named("grpc-api.dns"),
DNSRouter: s.Router,
TokenFunc: a.getTokenFunc(),
}).Register(a.externalGRPCServer)
a.dnsServers = append(a.dnsServers, s)
// wait for servers to be up
timeout := time.After(time.Second)
var merr *multierror.Error
for range a.config.DNSAddrs {
select {
case addr := <-notif:
a.logger.Info("Started DNS server",
"address", addr.String(),
"network", addr.Network(),
)
case err := <-errCh:
merr = multierror.Append(merr, err)
case <-timeout:
merr = multierror.Append(merr, fmt.Errorf("agent: timeout starting DNS servers"))
return merr.ErrorOrNil()
}
}
return merr.ErrorOrNil()
}
// startListeners will return a net.Listener for every address unless an
// error is encountered, in which case it will close all previously opened
// listeners and return the error.
@ -4414,10 +4316,6 @@ func (a *Agent) reloadConfigInternal(newCfg *config.RuntimeConfig) error {
return fmt.Errorf("Failed reloading dns config : %v", err)
}
}
// This field is only populated for the V2 DNS server
if a.catalogDataFetcher != nil {
a.catalogDataFetcher.LoadConfig(newCfg)
}
err := a.reloadEnterprise(newCfg)
if err != nil {

View File

@ -618,9 +618,6 @@ func TestBuilder_CheckExperimentsInSecondaryDatacenters(t *testing.T) {
"primary server v2catalog": {
hcl: primary + `experiments = ["resource-apis"]`,
},
"primary server v1dns": {
hcl: primary + `experiments = ["v1dns"]`,
},
"primary server v2tenancy": {
hcl: primary + `experiments = ["v2tenancy"]`,
},
@ -631,9 +628,6 @@ func TestBuilder_CheckExperimentsInSecondaryDatacenters(t *testing.T) {
hcl: secondary + `experiments = ["resource-apis"]`,
expectErr: true,
},
"secondary server v1dns": {
hcl: secondary + `experiments = ["v1dns"]`,
},
"secondary server v2tenancy": {
hcl: secondary + `experiments = ["v2tenancy"]`,
expectErr: true,

View File

@ -50,15 +50,6 @@ type Deps struct {
EnterpriseDeps
}
// UseV1DNS returns true if "v1dns" is present in the Experiments
// array of the agent config. It is ignored if the v2 resource APIs are enabled.
func (d Deps) UseV1DNS() bool {
if stringslice.Contains(d.Experiments, V1DNSExperimentName) && !d.UseV2Resources() {
return true
}
return false
}
// UseV2Resources returns true if "resource-apis" is present in the Experiments
// array of the agent config.
func (d Deps) UseV2Resources() bool {

View File

@ -133,7 +133,6 @@ const (
LeaderTransferMinVersion = "1.6.0"
CatalogResourceExperimentName = "resource-apis"
V1DNSExperimentName = "v1dns"
V2TenancyExperimentName = "v2tenancy"
HCPAllowV2ResourceAPIs = "hcp-v2-resource-apis"
)

View File

@ -1,250 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"fmt"
"net"
"github.com/hashicorp/consul/agent/config"
)
var (
ErrECSNotGlobal = fmt.Errorf("ECS response is not global")
ErrNoData = fmt.Errorf("no data")
ErrNotFound = fmt.Errorf("not found")
ErrNotSupported = fmt.Errorf("not supported")
ErrNoPathToDatacenter = fmt.Errorf("no path to datacenter")
)
// ECSNotGlobalError may be used to wrap an error or nil, to indicate that the
// EDNS client subnet source scope is not global.
type ECSNotGlobalError struct {
error
}
func (e ECSNotGlobalError) Error() string {
if e.error == nil {
return ""
}
return e.error.Error()
}
func (e ECSNotGlobalError) Is(other error) bool {
return other == ErrECSNotGlobal
}
func (e ECSNotGlobalError) Unwrap() error {
return e.error
}
// Query is used to request a name-based Service Discovery lookup.
type Query struct {
QueryType QueryType
QueryPayload QueryPayload
}
// QueryType is used to filter service endpoints.
// This is needed by the V1 catalog because of the
// overlapping lookups through the service endpoint.
type QueryType string
const (
QueryTypeConnect QueryType = "CONNECT" // deprecated: use for V1 only
QueryTypeIngress QueryType = "INGRESS" // deprecated: use for V1 only
QueryTypeInvalid QueryType = "INVALID"
QueryTypeNode QueryType = "NODE"
QueryTypePreparedQuery QueryType = "PREPARED_QUERY" // deprecated: use for V1 only
QueryTypeService QueryType = "SERVICE"
QueryTypeVirtual QueryType = "VIRTUAL"
QueryTypeWorkload QueryType = "WORKLOAD" // V2-only
)
// Context is used to pass information about the request.
type Context struct {
Token string
}
// QueryTenancy is used to filter catalog data based on tenancy.
type QueryTenancy struct {
Namespace string
Partition string
SamenessGroup string
Peer string
Datacenter string
}
// QueryPayload represents all information needed by the data backend
// to decide which records to include.
type QueryPayload struct {
Name string
PortName string // v1 - this could optionally be "connect" or "ingress"; v2 - this is the service port name
Tag string // deprecated: use for V1 only
SourceIP net.IP // deprecated: used for prepared queries
Tenancy QueryTenancy // tenancy includes any additional labels specified before the domain
Limit int // The maximum number of records to return
// v2 fields only
EnableFailover bool
}
// ResultType indicates the Consul resource that a discovery record represents.
// This is useful for things like adding TTLs for different objects in the DNS.
type ResultType string
const (
ResultTypeService ResultType = "SERVICE"
ResultTypeNode ResultType = "NODE"
ResultTypeVirtual ResultType = "VIRTUAL"
ResultTypeWorkload ResultType = "WORKLOAD"
)
// Result is a generic format of targets that could be returned in a query.
// It is the responsibility of the DNS encoder to know what to do with
// each Result, based on the query type.
type Result struct {
Service *Location // The name and address of the service.
Node *Location // The name and address of the node.
Metadata map[string]string // Used to collect metadata into TXT Records
Type ResultType // Used to reconstruct the fqdn name of the resource
DNS DNSConfig // Used for DNS-specific configuration for this result
// Ports include anything the node/service/workload implements. These are filtered if requested by the client.
// They are used in to generate the FQDN and SRV port numbers in V2 Catalog responses.
Ports []Port
Tenancy ResultTenancy
}
// TaggedAddress is used to represent a tagged address.
type TaggedAddress struct {
Name string
Address string
Port Port
}
// Location is used to represent a service, node, or workload.
type Location struct {
Name string
Address string
TaggedAddresses map[string]*TaggedAddress // Used to collect tagged addresses into A/AAAA Records
}
type DNSConfig struct {
TTL *uint32 // deprecated: use for V1 prepared queries only
Weight uint32 // SRV queries
}
type Port struct {
Name string
Number uint32
}
// ResultTenancy is used to reconstruct the fqdn name of the resource.
type ResultTenancy struct {
Namespace string
Partition string
PeerName string
Datacenter string
}
// LookupType is used by the CatalogDataFetcher to properly filter endpoints.
type LookupType string
const (
LookupTypeService LookupType = "SERVICE"
LookupTypeConnect LookupType = "CONNECT"
LookupTypeIngress LookupType = "INGRESS"
)
// CatalogDataFetcher is an interface that abstracts data collection
// for Discovery queries. It is assumed that the instantiation also
// includes any agent configuration that influences catalog queries.
//
//go:generate mockery --name CatalogDataFetcher --inpackage
type CatalogDataFetcher interface {
// LoadConfig is used to hot-reload the data fetcher with new agent config.
LoadConfig(config *config.RuntimeConfig)
// FetchNodes fetches A/AAAA/CNAME
FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error)
// FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services
FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error)
// FetchVirtualIP fetches A/AAAA records for virtual IPs
FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, error)
// FetchRecordsByIp is used for PTR requests
// to look up a service/node from an IP.
FetchRecordsByIp(ctx Context, ip net.IP) ([]*Result, error)
// FetchWorkload fetches a single Result associated with
// V2 Workload. V2-only.
FetchWorkload(ctx Context, req *QueryPayload) (*Result, error)
// FetchPreparedQuery evaluates the results of a prepared query.
// deprecated in V2
FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error)
// NormalizeRequest mutates the original request based on data fetcher configuration, like
// defaulting tenancy to the agent's partition.
NormalizeRequest(req *QueryPayload)
// ValidateRequest throws an error is any of the input fields are invalid for this data fetcher.
ValidateRequest(ctx Context, req *QueryPayload) error
}
// QueryProcessor is used to process a Discovery Query and return the results.
type QueryProcessor struct {
dataFetcher CatalogDataFetcher
}
// NewQueryProcessor creates a new QueryProcessor.
func NewQueryProcessor(dataFetcher CatalogDataFetcher) *QueryProcessor {
return &QueryProcessor{
dataFetcher: dataFetcher,
}
}
// QueryByName is used to look up a service, node, workload, or prepared query.
func (p *QueryProcessor) QueryByName(query *Query, ctx Context) ([]*Result, error) {
if err := p.dataFetcher.ValidateRequest(ctx, &query.QueryPayload); err != nil {
return nil, err
}
p.dataFetcher.NormalizeRequest(&query.QueryPayload)
switch query.QueryType {
case QueryTypeNode:
return p.dataFetcher.FetchNodes(ctx, &query.QueryPayload)
case QueryTypeService:
return p.dataFetcher.FetchEndpoints(ctx, &query.QueryPayload, LookupTypeService)
case QueryTypeConnect:
return p.dataFetcher.FetchEndpoints(ctx, &query.QueryPayload, LookupTypeConnect)
case QueryTypeIngress:
return p.dataFetcher.FetchEndpoints(ctx, &query.QueryPayload, LookupTypeIngress)
case QueryTypeVirtual:
result, err := p.dataFetcher.FetchVirtualIP(ctx, &query.QueryPayload)
if err != nil {
return nil, err
}
return []*Result{result}, nil
case QueryTypeWorkload:
result, err := p.dataFetcher.FetchWorkload(ctx, &query.QueryPayload)
if err != nil {
return nil, err
}
return []*Result{result}, nil
case QueryTypePreparedQuery:
return p.dataFetcher.FetchPreparedQuery(ctx, &query.QueryPayload)
default:
return nil, fmt.Errorf("unknown query type: %s", query.QueryType)
}
}
// QueryByIP is used to look up a service or node from an IP address.
func (p *QueryProcessor) QueryByIP(ip net.IP, reqCtx Context) ([]*Result, error) {
return p.dataFetcher.FetchRecordsByIp(reqCtx, ip)
}

View File

@ -1,221 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"errors"
"net"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
var (
testContext = Context{
Token: "bar",
}
testErr = errors.New("test error")
testIP = net.ParseIP("1.2.3.4")
testPayload = QueryPayload{
Name: "foo",
}
testResult = &Result{
Node: &Location{Address: "1.2.3.4"},
Type: ResultTypeNode, // This isn't correct for some test cases, but we are only asserting the right data fetcher functions are called
Service: &Location{Name: "foo"},
}
)
func TestQueryByName(t *testing.T) {
type testCase struct {
name string
reqType QueryType
configureDataFetcher func(*testing.T, *MockCatalogDataFetcher)
expectedResults []*Result
expectedError error
}
run := func(t *testing.T, tc testCase) {
fetcher := NewMockCatalogDataFetcher(t)
tc.configureDataFetcher(t, fetcher)
qp := NewQueryProcessor(fetcher)
q := Query{
QueryType: tc.reqType,
QueryPayload: testPayload,
}
results, err := qp.QueryByName(&q, testContext)
if tc.expectedError != nil {
require.Error(t, err)
require.True(t, errors.Is(err, tc.expectedError))
require.Nil(t, results)
return
}
require.NoError(t, err)
require.Equal(t, tc.expectedResults, results)
}
testCases := []testCase{
{
name: "query node",
reqType: QueryTypeNode,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchNodes", mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query service",
reqType: QueryTypeService,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query connect",
reqType: QueryTypeConnect,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query ingress",
reqType: QueryTypeIngress,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query virtual ip",
reqType: QueryTypeVirtual,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchVirtualIP", mock.Anything, mock.Anything).Return(testResult, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query workload",
reqType: QueryTypeWorkload,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchWorkload", mock.Anything, mock.Anything).Return(testResult, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "query prepared query",
reqType: QueryTypePreparedQuery,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchPreparedQuery", mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "returns error from validation",
reqType: QueryTypePreparedQuery,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(testErr)
},
expectedError: testErr,
},
{
name: "returns error from fetcher",
reqType: QueryTypePreparedQuery,
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
fetcher.On("NormalizeRequest", mock.Anything)
fetcher.On("FetchPreparedQuery", mock.Anything, mock.Anything).Return(nil, testErr)
},
expectedError: testErr,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestQueryByIP(t *testing.T) {
type testCase struct {
name string
configureDataFetcher func(*testing.T, *MockCatalogDataFetcher)
expectedResults []*Result
expectedError error
}
run := func(t *testing.T, tc testCase) {
fetcher := NewMockCatalogDataFetcher(t)
tc.configureDataFetcher(t, fetcher)
qp := NewQueryProcessor(fetcher)
results, err := qp.QueryByIP(testIP, testContext)
if tc.expectedError != nil {
require.Error(t, err)
require.True(t, errors.Is(err, tc.expectedError))
require.Nil(t, results)
return
}
require.NoError(t, err)
require.Equal(t, tc.expectedResults, results)
}
testCases := []testCase{
{
name: "query by IP",
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("FetchRecordsByIp", mock.Anything, mock.Anything).Return([]*Result{testResult}, nil)
},
expectedResults: []*Result{testResult},
},
{
name: "returns error from fetcher",
configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) {
fetcher.On("FetchRecordsByIp", mock.Anything, mock.Anything).Return(nil, testErr)
},
expectedError: testErr,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}

View File

@ -1,209 +0,0 @@
// Code generated by mockery v2.37.1. DO NOT EDIT.
package discovery
import (
config "github.com/hashicorp/consul/agent/config"
mock "github.com/stretchr/testify/mock"
net "net"
)
// MockCatalogDataFetcher is an autogenerated mock type for the CatalogDataFetcher type
type MockCatalogDataFetcher struct {
mock.Mock
}
// FetchEndpoints provides a mock function with given fields: ctx, req, lookupType
func (_m *MockCatalogDataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) {
ret := _m.Called(ctx, req, lookupType)
var r0 []*Result
var r1 error
if rf, ok := ret.Get(0).(func(Context, *QueryPayload, LookupType) ([]*Result, error)); ok {
return rf(ctx, req, lookupType)
}
if rf, ok := ret.Get(0).(func(Context, *QueryPayload, LookupType) []*Result); ok {
r0 = rf(ctx, req, lookupType)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*Result)
}
}
if rf, ok := ret.Get(1).(func(Context, *QueryPayload, LookupType) error); ok {
r1 = rf(ctx, req, lookupType)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FetchNodes provides a mock function with given fields: ctx, req
func (_m *MockCatalogDataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) {
ret := _m.Called(ctx, req)
var r0 []*Result
var r1 error
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) ([]*Result, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) []*Result); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*Result)
}
}
if rf, ok := ret.Get(1).(func(Context, *QueryPayload) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FetchPreparedQuery provides a mock function with given fields: ctx, req
func (_m *MockCatalogDataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) {
ret := _m.Called(ctx, req)
var r0 []*Result
var r1 error
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) ([]*Result, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) []*Result); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*Result)
}
}
if rf, ok := ret.Get(1).(func(Context, *QueryPayload) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FetchRecordsByIp provides a mock function with given fields: ctx, ip
func (_m *MockCatalogDataFetcher) FetchRecordsByIp(ctx Context, ip net.IP) ([]*Result, error) {
ret := _m.Called(ctx, ip)
var r0 []*Result
var r1 error
if rf, ok := ret.Get(0).(func(Context, net.IP) ([]*Result, error)); ok {
return rf(ctx, ip)
}
if rf, ok := ret.Get(0).(func(Context, net.IP) []*Result); ok {
r0 = rf(ctx, ip)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*Result)
}
}
if rf, ok := ret.Get(1).(func(Context, net.IP) error); ok {
r1 = rf(ctx, ip)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FetchVirtualIP provides a mock function with given fields: ctx, req
func (_m *MockCatalogDataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, error) {
ret := _m.Called(ctx, req)
var r0 *Result
var r1 error
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) (*Result, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) *Result); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*Result)
}
}
if rf, ok := ret.Get(1).(func(Context, *QueryPayload) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FetchWorkload provides a mock function with given fields: ctx, req
func (_m *MockCatalogDataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, error) {
ret := _m.Called(ctx, req)
var r0 *Result
var r1 error
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) (*Result, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) *Result); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*Result)
}
}
if rf, ok := ret.Get(1).(func(Context, *QueryPayload) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// LoadConfig provides a mock function with given fields: _a0
func (_m *MockCatalogDataFetcher) LoadConfig(_a0 *config.RuntimeConfig) {
_m.Called(_a0)
}
// NormalizeRequest provides a mock function with given fields: req
func (_m *MockCatalogDataFetcher) NormalizeRequest(req *QueryPayload) {
_m.Called(req)
}
// ValidateRequest provides a mock function with given fields: ctx, req
func (_m *MockCatalogDataFetcher) ValidateRequest(ctx Context, req *QueryPayload) error {
ret := _m.Called(ctx, req)
var r0 error
if rf, ok := ret.Get(0).(func(Context, *QueryPayload) error); ok {
r0 = rf(ctx, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewMockCatalogDataFetcher creates a new instance of MockCatalogDataFetcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockCatalogDataFetcher(t interface {
mock.TestingT
Cleanup(func())
}) *MockCatalogDataFetcher {
mock := &MockCatalogDataFetcher{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -1,649 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"context"
"fmt"
"net"
"strings"
"sync/atomic"
"time"
"github.com/armon/go-metrics"
"github.com/armon/go-metrics/prometheus"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
)
const (
// Increment a counter when requests staler than this are served
staleCounterThreshold = 5 * time.Second
)
// DNSCounters pre-registers the staleness metric.
// This value is used by both the V1 and V2 DNS (V1 Catalog-only) servers.
var DNSCounters = []prometheus.CounterDefinition{
{
Name: []string{"dns", "stale_queries"},
Help: "Increments when an agent serves a query within the allowed stale threshold.",
},
}
// V1DataFetcherDynamicConfig is used to store the dynamic configuration of the V1 data fetcher.
type V1DataFetcherDynamicConfig struct {
// Default request tenancy
Datacenter string
SegmentName string
NodeName string
NodePartition string
// Catalog configuration
AllowStale bool
MaxStale time.Duration
UseCache bool
CacheMaxAge time.Duration
OnlyPassing bool
}
// V1DataFetcher is used to fetch data from the V1 catalog.
type V1DataFetcher struct {
defaultEnterpriseMeta acl.EnterpriseMeta
dynamicConfig atomic.Value
logger hclog.Logger
getFromCacheFunc func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error)
rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error
rpcFuncForServiceNodes func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error)
rpcFuncForSamenessGroup func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error)
translateServicePortFunc func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int
}
// NewV1DataFetcher creates a new V1 data fetcher.
func NewV1DataFetcher(config *config.RuntimeConfig,
entMeta *acl.EnterpriseMeta,
getFromCacheFunc func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error),
rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error,
rpcFuncForServiceNodes func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error),
rpcFuncForSamenessGroup func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error),
translateServicePortFunc func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int,
logger hclog.Logger) *V1DataFetcher {
f := &V1DataFetcher{
defaultEnterpriseMeta: *entMeta,
getFromCacheFunc: getFromCacheFunc,
rpcFunc: rpcFunc,
rpcFuncForServiceNodes: rpcFuncForServiceNodes,
rpcFuncForSamenessGroup: rpcFuncForSamenessGroup,
translateServicePortFunc: translateServicePortFunc,
logger: logger,
}
f.LoadConfig(config)
return f
}
// LoadConfig loads the configuration for the V1 data fetcher.
func (f *V1DataFetcher) LoadConfig(config *config.RuntimeConfig) {
dynamicConfig := &V1DataFetcherDynamicConfig{
AllowStale: config.DNSAllowStale,
MaxStale: config.DNSMaxStale,
UseCache: config.DNSUseCache,
CacheMaxAge: config.DNSCacheMaxAge,
OnlyPassing: config.DNSOnlyPassing,
Datacenter: config.Datacenter,
SegmentName: config.SegmentName,
NodeName: config.NodeName,
}
f.dynamicConfig.Store(dynamicConfig)
}
func (f *V1DataFetcher) GetConfig() *V1DataFetcherDynamicConfig {
return f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig)
}
// FetchNodes fetches A/AAAA/CNAME
func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) {
if req.Tenancy.Namespace != "" && req.Tenancy.Namespace != acl.DefaultNamespaceName {
// Nodes are not namespaced, so this is a name error
return nil, ErrNotFound
}
cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig)
// If no datacenter is passed, default to our own
datacenter := cfg.Datacenter
if req.Tenancy.Datacenter != "" {
datacenter = req.Tenancy.Datacenter
}
// Make an RPC request
args := &structs.NodeSpecificRequest{
Datacenter: datacenter,
PeerName: req.Tenancy.Peer,
Node: req.Name,
QueryOptions: structs.QueryOptions{
Token: ctx.Token,
AllowStale: cfg.AllowStale,
},
EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy),
}
out, err := f.fetchNode(cfg, args)
if err != nil {
return nil, fmt.Errorf("failed rpc request: %w", err)
}
// If we have no out.NodeServices.Nodeaddress, return not found!
if out.NodeServices == nil {
return nil, ErrNotFound
}
results := make([]*Result, 0, 1)
n := out.NodeServices.Node
results = append(results, &Result{
Node: &Location{
Name: n.Node,
Address: n.Address,
TaggedAddresses: makeTaggedAddressesFromStrings(n.TaggedAddresses),
},
Type: ResultTypeNode,
Metadata: n.Meta,
Tenancy: ResultTenancy{
// Namespace is not required because nodes are not namespaced
Partition: n.GetEnterpriseMeta().PartitionOrDefault(),
Datacenter: n.Datacenter,
},
})
return results, nil
}
// FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services
func (f *V1DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) {
f.logger.Trace(fmt.Sprintf("FetchEndpoints - req: %+v / lookupType: %+v", req, lookupType))
cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig)
return f.fetchService(ctx, req, cfg, lookupType)
}
// FetchVirtualIP fetches A/AAAA records for virtual IPs
func (f *V1DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, error) {
args := structs.ServiceSpecificRequest{
// The Datacenter of the request is not specified because cross-Datacenter virtual IP
// queries are not supported. This guard rail is in place because virtual IPs are allocated
// within a DC, therefore their uniqueness is not guaranteed globally.
PeerName: req.Tenancy.Peer,
ServiceName: req.Name,
EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy),
QueryOptions: structs.QueryOptions{
Token: ctx.Token,
},
}
var out string
if err := f.rpcFunc(context.Background(), "Catalog.VirtualIPForService", &args, &out); err != nil {
return nil, err
}
result := &Result{
Service: &Location{
Name: req.Name,
Address: out,
},
Type: ResultTypeVirtual,
}
return result, nil
}
// FetchRecordsByIp is used for PTR requests to look up a service/node from an IP.
// The search is performed in the agent's partition and over all namespaces (or those allowed by the ACL token).
func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, error) {
if ip == nil {
return nil, ErrNotSupported
}
configCtx := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig)
targetIP := ip.String()
var results []*Result
args := structs.DCSpecificRequest{
Datacenter: configCtx.Datacenter,
QueryOptions: structs.QueryOptions{
Token: reqCtx.Token,
AllowStale: configCtx.AllowStale,
},
}
var out structs.IndexedNodes
// TODO: Replace ListNodes with an internal RPC that can do the filter
// server side to avoid transferring the entire node list.
if err := f.rpcFunc(context.Background(), "Catalog.ListNodes", &args, &out); err == nil {
for _, n := range out.Nodes {
if targetIP == n.Address {
results = append(results, &Result{
Node: &Location{
Name: n.Node,
Address: n.Address,
TaggedAddresses: makeTaggedAddressesFromStrings(n.TaggedAddresses),
},
Type: ResultTypeNode,
Tenancy: ResultTenancy{
Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(),
Partition: f.defaultEnterpriseMeta.PartitionOrDefault(),
Datacenter: configCtx.Datacenter,
},
})
return results, nil
}
}
}
// only look into the services if we didn't find a node
sargs := structs.ServiceSpecificRequest{
Datacenter: configCtx.Datacenter,
QueryOptions: structs.QueryOptions{
Token: reqCtx.Token,
AllowStale: configCtx.AllowStale,
},
ServiceAddress: targetIP,
EnterpriseMeta: *f.defaultEnterpriseMeta.WithWildcardNamespace(),
}
var sout structs.IndexedServiceNodes
if err := f.rpcFunc(context.Background(), "Catalog.ServiceNodes", &sargs, &sout); err == nil {
if len(sout.ServiceNodes) == 0 {
return nil, ErrNotFound
}
for _, n := range sout.ServiceNodes {
if n.ServiceAddress == targetIP {
results = append(results, &Result{
Service: &Location{
Name: n.ServiceName,
Address: n.ServiceAddress,
},
Type: ResultTypeService,
Node: &Location{
Name: n.Node,
Address: n.Address,
},
Tenancy: ResultTenancy{
Namespace: n.NamespaceOrEmpty(),
Partition: n.PartitionOrEmpty(),
Datacenter: n.Datacenter,
},
})
return results, nil
}
}
}
// nothing found locally, recurse
// TODO: (v2-dns) implement recursion (NET-7883)
//d.handleRecurse(resp, req)
return nil, fmt.Errorf("unhandled error in FetchRecordsByIp")
}
// FetchWorkload fetches a single Result associated with
// V2 Workload. V2-only.
func (f *V1DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, error) {
return nil, ErrNotSupported
}
// FetchPreparedQuery evaluates the results of a prepared query.
// deprecated in V2
func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) {
cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig)
// If no datacenter is passed, default to our own
datacenter := cfg.Datacenter
if req.Tenancy.Datacenter != "" {
datacenter = req.Tenancy.Datacenter
}
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter,
QueryIDOrName: req.Name,
QueryOptions: structs.QueryOptions{
Token: ctx.Token,
AllowStale: cfg.AllowStale,
MaxAge: cfg.CacheMaxAge,
},
// Always pass the local agent through. In the DNS interface, there
// is no provision for passing additional query parameters, so we
// send the local agent's data through to allow distance sorting
// relative to ourself on the server side.
Agent: structs.QuerySource{
Datacenter: cfg.Datacenter,
Segment: cfg.SegmentName,
Node: cfg.NodeName,
NodePartition: cfg.NodePartition,
},
Source: structs.QuerySource{
Ip: req.SourceIP.String(),
},
}
out, err := f.executePreparedQuery(cfg, args)
if err != nil {
// errors.Is() doesn't work with errors.New() so we need to check the error message.
if err.Error() == structs.ErrQueryNotFound.Error() {
err = ErrNotFound
}
return nil, ECSNotGlobalError{err}
}
// TODO (slackpad) - What's a safe limit we can set here? It seems like
// with dup filtering done at this level we need to get everything to
// match the previous behavior. We can optimize by pushing more filtering
// into the query execution, but for now I think we need to get the full
// response. We could also choose a large arbitrary number that will
// likely work in practice, like 10*maxUDPAnswerLimit which should help
// reduce bandwidth if there are thousands of nodes available.
// Determine the TTL. The parse should never fail since we vet it when
// the query is created, but we check anyway. If the query didn't
// specify a TTL then we will try to use the agent's service-specific
// TTL configs.
// Check is there is a TTL provided as part of the prepared query
var ttlOverride *uint32
if out.DNS.TTL != "" {
ttl, err := time.ParseDuration(out.DNS.TTL)
if err == nil {
ttlSec := uint32(ttl / time.Second)
ttlOverride = &ttlSec
} else {
f.logger.Warn("Failed to parse TTL for prepared query , ignoring",
"ttl", out.DNS.TTL,
"prepared_query", req.Name,
)
}
}
// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
return nil, ECSNotGlobalError{ErrNotFound}
}
// Perform a random shuffle
out.Nodes.Shuffle()
return f.buildResultsFromServiceNodes(out.Nodes, req, ttlOverride), ECSNotGlobalError{}
}
// executePreparedQuery is used to execute a PreparedQuery against the Consul catalog.
// If the config is set to UseCache, it will use agent cache.
func (f *V1DataFetcher) executePreparedQuery(cfg *V1DataFetcherDynamicConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
var out structs.PreparedQueryExecuteResponse
RPC:
if cfg.UseCache {
raw, m, err := f.getFromCacheFunc(context.TODO(), cachetype.PreparedQueryName, &args)
if err != nil {
return nil, err
}
reply, ok := raw.(*structs.PreparedQueryExecuteResponse)
if !ok {
// This should never happen, but we want to protect against panics
return nil, err
}
f.logger.Trace("cache results for prepared query",
"cache_hit", m.Hit,
"prepared_query", args.QueryIDOrName,
)
out = *reply
} else {
if err := f.rpcFunc(context.Background(), "PreparedQuery.Execute", &args, &out); err != nil {
return nil, err
}
}
// Verify that request is not too stale, redo the request.
if args.AllowStale {
if out.LastContact > cfg.MaxStale {
args.AllowStale = false
f.logger.Warn("Query results too stale, re-requesting")
goto RPC
} else if out.LastContact > staleCounterThreshold {
metrics.IncrCounter([]string{"dns", "stale_queries"}, 1)
}
}
return &out, nil
}
func (f *V1DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error {
if req.EnableFailover {
return ErrNotSupported
}
if req.PortName != "" {
return ErrNotSupported
}
return validateEnterpriseTenancy(req.Tenancy)
}
// buildResultsFromServiceNodes builds a list of results from a list of nodes.
func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServiceNode, req *QueryPayload, ttlOverride *uint32) []*Result {
// Convert the service endpoints to results up to the limit
limit := req.Limit
if len(nodes) < limit || limit == 0 {
limit = len(nodes)
}
results := make([]*Result, 0, limit)
for idx := 0; idx < limit; idx++ {
n := nodes[idx]
results = append(results, &Result{
Service: &Location{
Name: n.Service.Service,
Address: n.Service.Address,
TaggedAddresses: makeTaggedAddressesFromServiceAddresses(n.Service.TaggedAddresses),
},
Node: &Location{
Name: n.Node.Node,
Address: n.Node.Address,
TaggedAddresses: makeTaggedAddressesFromStrings(n.Node.TaggedAddresses),
},
Type: ResultTypeService,
DNS: DNSConfig{
TTL: ttlOverride,
Weight: uint32(findWeight(n)),
},
Ports: []Port{
{Number: uint32(f.translateServicePortFunc(n.Node.Datacenter, n.Service.Port, n.Service.TaggedAddresses))},
},
Metadata: n.Node.Meta,
Tenancy: ResultTenancy{
Namespace: n.Service.NamespaceOrEmpty(),
Partition: n.Service.PartitionOrEmpty(),
Datacenter: n.Node.Datacenter,
PeerName: n.Service.PeerName,
},
})
}
return results
}
// makeTaggedAddressesFromServiceAddresses is used to convert a map of service addresses to a map of Locations.
func makeTaggedAddressesFromServiceAddresses(tagged map[string]structs.ServiceAddress) map[string]*TaggedAddress {
taggedAddresses := make(map[string]*TaggedAddress)
for k, v := range tagged {
taggedAddresses[k] = &TaggedAddress{
Name: k,
Address: v.Address,
Port: Port{
Number: uint32(v.Port),
},
}
}
return taggedAddresses
}
// makeTaggedAddressesFromStrings is used to convert a map of strings to a map of Locations.
func makeTaggedAddressesFromStrings(tagged map[string]string) map[string]*TaggedAddress {
taggedAddresses := make(map[string]*TaggedAddress)
for k, v := range tagged {
taggedAddresses[k] = &TaggedAddress{
Name: k,
Address: v,
}
}
return taggedAddresses
}
// fetchNode is used to look up a node in the Consul catalog within NodeServices.
// If the config is set to UseCache, it will get the record from the agent cache.
func (f *V1DataFetcher) fetchNode(cfg *V1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
var out structs.IndexedNodeServices
useCache := cfg.UseCache
RPC:
if useCache {
raw, _, err := f.getFromCacheFunc(context.TODO(), cachetype.NodeServicesName, args)
if err != nil {
return nil, err
}
reply, ok := raw.(*structs.IndexedNodeServices)
if !ok {
// This should never happen, but we want to protect against panics
return nil, fmt.Errorf("internal error: response type not correct")
}
out = *reply
} else {
if err := f.rpcFunc(context.Background(), "Catalog.NodeServices", &args, &out); err != nil {
return nil, err
}
}
// Verify that request is not too stale, redo the request
if args.AllowStale {
if out.LastContact > cfg.MaxStale {
args.AllowStale = false
useCache = false
f.logger.Warn("Query results too stale, re-requesting")
goto RPC
} else if out.LastContact > staleCounterThreshold {
metrics.IncrCounter([]string{"dns", "stale_queries"}, 1)
}
}
return &out, nil
}
// fetchService is used to look up a service in the Consul catalog.
func (f *V1DataFetcher) fetchService(ctx Context, req *QueryPayload,
cfg *V1DataFetcherDynamicConfig, lookupType LookupType) ([]*Result, error) {
f.logger.Trace(fmt.Sprintf("fetchService - req: %+v", req))
// If no datacenter is passed, default to our own
datacenter := cfg.Datacenter
if req.Tenancy.Datacenter != "" {
datacenter = req.Tenancy.Datacenter
}
if req.Tenancy.Peer != "" {
datacenter = ""
}
serviceTags := []string{}
if req.Tag != "" {
serviceTags = []string{req.Tag}
}
healthFilterType := structs.HealthFilterExcludeCritical
if cfg.OnlyPassing {
healthFilterType = structs.HealthFilterIncludeOnlyPassing
}
args := structs.ServiceSpecificRequest{
PeerName: req.Tenancy.Peer,
SamenessGroup: req.Tenancy.SamenessGroup,
Connect: lookupType == LookupTypeConnect,
Ingress: lookupType == LookupTypeIngress,
Datacenter: datacenter,
ServiceName: req.Name,
ServiceTags: serviceTags,
TagFilter: req.Tag != "",
HealthFilterType: healthFilterType,
QueryOptions: structs.QueryOptions{
Token: ctx.Token,
AllowStale: cfg.AllowStale,
MaxAge: cfg.CacheMaxAge,
UseCache: cfg.UseCache,
MaxStaleDuration: cfg.MaxStale,
},
EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy),
}
out, _, err := f.rpcFuncForServiceNodes(context.TODO(), args)
if err != nil {
if strings.Contains(err.Error(), structs.ErrNoDCPath.Error()) {
return nil, ErrNoPathToDatacenter
}
return nil, fmt.Errorf("rpc request failed: %w", err)
}
// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
return nil, ErrNotFound
}
// Perform a random shuffle
out.Nodes.Shuffle()
return f.buildResultsFromServiceNodes(out.Nodes, req, nil), nil
}
// findWeight returns the weight of a service node.
func findWeight(node structs.CheckServiceNode) int {
// By default, when only_passing is false, warning and passing nodes are returned
// Those values will be used if using a client with support while server has no
// support for weights
weightPassing := 1
weightWarning := 1
if node.Service.Weights != nil {
weightPassing = node.Service.Weights.Passing
weightWarning = node.Service.Weights.Warning
}
serviceChecks := make(api.HealthChecks, 0, len(node.Checks))
for _, c := range node.Checks {
if c.ServiceName == node.Service.Service || c.ServiceName == "" {
healthCheck := &api.HealthCheck{
Node: c.Node,
CheckID: string(c.CheckID),
Name: c.Name,
Status: c.Status,
Notes: c.Notes,
Output: c.Output,
ServiceID: c.ServiceID,
ServiceName: c.ServiceName,
ServiceTags: c.ServiceTags,
}
serviceChecks = append(serviceChecks, healthCheck)
}
}
status := serviceChecks.AggregatedStatus()
switch status {
case api.HealthWarning:
return weightWarning
case api.HealthPassing:
return weightPassing
case api.HealthMaint:
// Not used in theory
return 0
case api.HealthCritical:
// Should not happen since already filtered
return 0
default:
// When non-standard status, return 1
return 1
}
}

View File

@ -1,30 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package discovery
import (
"github.com/hashicorp/consul/acl"
)
func (f *V1DataFetcher) NormalizeRequest(req *QueryPayload) {
// Nothing to do for CE
return
}
// validateEnterpriseTenancy validates the tenancy fields for an enterprise request to
// make sure that they are either set to an empty string or "default" to align with the behavior
// in CE.
func validateEnterpriseTenancy(req QueryTenancy) error {
if !(req.Namespace == acl.EmptyNamespaceName || req.Namespace == acl.DefaultNamespaceName) ||
!(req.Partition == acl.DefaultPartitionName || req.Partition == acl.NonEmptyDefaultPartitionName) {
return ErrNotSupported
}
return nil
}
func queryTenancyToEntMeta(_ QueryTenancy) acl.EnterpriseMeta {
return acl.EnterpriseMeta{}
}

View File

@ -1,64 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package discovery
import (
"github.com/stretchr/testify/require"
"testing"
)
const (
defaultTestNamespace = ""
defaultTestPartition = ""
)
func Test_validateEnterpriseTenancy(t *testing.T) {
testCases := []struct {
name string
req QueryTenancy
expected error
}{
{
name: "empty namespace and partition returns no error",
req: QueryTenancy{
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
expected: nil,
},
{
name: "namespace and partition set to 'default' returns no error",
req: QueryTenancy{
Namespace: "default",
Partition: "default",
},
expected: nil,
},
{
name: "namespace set to something other than empty string or `default` returns not supported error",
req: QueryTenancy{
Namespace: "namespace-1",
Partition: "default",
},
expected: ErrNotSupported,
},
{
name: "partition set to something other than empty string or `default` returns not supported error",
req: QueryTenancy{
Namespace: "default",
Partition: "partition-1",
},
expected: ErrNotSupported,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := validateEnterpriseTenancy(tc.req)
require.Equal(t, tc.expected, err)
})
}
}

View File

@ -1,207 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/testutil"
)
// Test_FetchVirtualIP tests the FetchVirtualIP method in scenarios where the RPC
// call succeeds and fails.
func Test_FetchVirtualIP(t *testing.T) {
// set these to confirm that RPC call does not use them for this particular RPC
rc := &config.RuntimeConfig{
DNSAllowStale: true,
DNSMaxStale: 100,
DNSUseCache: true,
DNSCacheMaxAge: 100,
}
tests := []struct {
name string
queryPayload *QueryPayload
context Context
expectedResult *Result
expectedErr error
}{
{
name: "FetchVirtualIP returns result",
queryPayload: &QueryPayload{
Name: "db",
Tenancy: QueryTenancy{
Peer: "test-peer",
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
},
context: Context{
Token: "test-token",
},
expectedResult: &Result{
Service: &Location{
Name: "db",
Address: "192.168.10.10",
},
Type: ResultTypeVirtual,
},
expectedErr: nil,
},
{
name: "FetchVirtualIP returns error",
queryPayload: &QueryPayload{
Name: "db",
Tenancy: QueryTenancy{
Peer: "test-peer",
Namespace: defaultTestNamespace,
Partition: defaultTestPartition},
},
context: Context{
Token: "test-token",
},
expectedResult: nil,
expectedErr: errors.New("test-error"),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := testutil.Logger(t)
mockRPC := cachetype.NewMockRPC(t)
mockRPC.On("RPC", mock.Anything, "Catalog.VirtualIPForService", mock.Anything, mock.Anything).
Return(tc.expectedErr).
Run(func(args mock.Arguments) {
req := args.Get(2).(*structs.ServiceSpecificRequest)
// validate RPC options are not set from config for the VirtuaLIPForService RPC
require.False(t, req.AllowStale)
require.Equal(t, time.Duration(0), req.MaxStaleDuration)
require.False(t, req.UseCache)
require.Equal(t, time.Duration(0), req.MaxAge)
// validate RPC options are set correctly from the queryPayload and context
require.Equal(t, tc.queryPayload.Tenancy.Peer, req.PeerName)
require.Equal(t, tc.queryPayload.Tenancy.Namespace, req.EnterpriseMeta.NamespaceOrEmpty())
require.Equal(t, tc.queryPayload.Tenancy.Partition, req.EnterpriseMeta.PartitionOrEmpty())
require.Equal(t, tc.context.Token, req.QueryOptions.Token)
if tc.expectedErr == nil {
// set the out parameter to ensure that it is used to formulate the result.Address
reply := args.Get(3).(*string)
*reply = tc.expectedResult.Service.Address
}
})
translateServicePortFunc := func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int { return 0 }
rpcFuncForServiceNodes := func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) {
return structs.IndexedCheckServiceNodes{}, cache.ResultMeta{}, nil
}
rpcFuncForSamenessGroup := func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error) {
return structs.SamenessGroupConfigEntry{}, cache.ResultMeta{}, nil
}
getFromCacheFunc := func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error) {
return nil, cache.ResultMeta{}, nil
}
df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), getFromCacheFunc, mockRPC.RPC, rpcFuncForServiceNodes, rpcFuncForSamenessGroup, translateServicePortFunc, logger)
result, err := df.FetchVirtualIP(tc.context, tc.queryPayload)
require.Equal(t, tc.expectedErr, err)
require.Equal(t, tc.expectedResult, result)
})
}
}
// Test_FetchEndpoints tests the FetchEndpoints method in scenarios where the RPC
// call succeeds and fails.
func Test_FetchEndpoints(t *testing.T) {
// set these to confirm that RPC call does not use them for this particular RPC
rc := &config.RuntimeConfig{
Datacenter: "dc2",
DNSAllowStale: true,
DNSMaxStale: 100,
DNSUseCache: true,
DNSCacheMaxAge: 100,
}
ctx := Context{
Token: "test-token",
}
expectedResults := []*Result{
{
Node: &Location{
Name: "node-name",
Address: "node-address",
TaggedAddresses: map[string]*TaggedAddress{},
},
Service: &Location{
Name: "service-name",
Address: "service-address",
TaggedAddresses: map[string]*TaggedAddress{},
},
Type: ResultTypeService,
DNS: DNSConfig{
Weight: 1,
},
Ports: []Port{
{
Number: 0,
},
},
Tenancy: ResultTenancy{
PeerName: "test-peer",
},
},
}
logger := testutil.Logger(t)
mockRPC := cachetype.NewMockRPC(t)
translateServicePortFunc := func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int { return 0 }
rpcFuncForSamenessGroup := func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error) {
return structs.SamenessGroupConfigEntry{}, cache.ResultMeta{}, nil
}
getFromCacheFunc := func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error) {
return nil, cache.ResultMeta{}, nil
}
rpcFuncForServiceNodes := func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) {
return structs.IndexedCheckServiceNodes{
Nodes: []structs.CheckServiceNode{
{
Node: &structs.Node{
Address: "node-address",
Node: "node-name",
},
Service: &structs.NodeService{
Address: "service-address",
Service: "service-name",
PeerName: "test-peer",
},
},
},
}, cache.ResultMeta{}, nil
}
queryPayload := &QueryPayload{
Name: "service-name",
Tenancy: QueryTenancy{
Peer: "test-peer",
Namespace: defaultTestNamespace,
Partition: defaultTestPartition,
},
}
df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), getFromCacheFunc, mockRPC.RPC, rpcFuncForServiceNodes, rpcFuncForSamenessGroup, translateServicePortFunc, logger)
results, err := df.FetchEndpoints(ctx, queryPayload, LookupTypeService)
require.NoError(t, err)
require.Equal(t, expectedResults, results)
}

View File

@ -1,365 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"context"
"fmt"
"math/rand"
"net"
"strings"
"sync/atomic"
"golang.org/x/exp/slices"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)
// V2DataFetcherDynamicConfig is used to store the dynamic configuration of the V2 data fetcher.
type V2DataFetcherDynamicConfig struct {
OnlyPassing bool
}
// V2DataFetcher is used to fetch data from the V2 catalog.
type V2DataFetcher struct {
client pbresource.ResourceServiceClient
logger hclog.Logger
// Requests inherit the partition of the agent unless otherwise specified.
defaultPartition string
dynamicConfig atomic.Value
}
// NewV2DataFetcher creates a new V2 data fetcher.
func NewV2DataFetcher(config *config.RuntimeConfig, client pbresource.ResourceServiceClient, logger hclog.Logger) *V2DataFetcher {
f := &V2DataFetcher{
client: client,
logger: logger,
defaultPartition: config.PartitionOrDefault(),
}
f.LoadConfig(config)
return f
}
// LoadConfig loads the configuration for the V2 data fetcher.
func (f *V2DataFetcher) LoadConfig(config *config.RuntimeConfig) {
dynamicConfig := &V2DataFetcherDynamicConfig{
OnlyPassing: config.DNSOnlyPassing,
}
f.dynamicConfig.Store(dynamicConfig)
}
// GetConfig loads the configuration for the V2 data fetcher.
func (f *V2DataFetcher) GetConfig() *V2DataFetcherDynamicConfig {
return f.dynamicConfig.Load().(*V2DataFetcherDynamicConfig)
}
// FetchNodes fetches A/AAAA/CNAME
func (f *V2DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) {
// TODO (v2-dns): NET-6623 - Implement FetchNodes
// Make sure that we validate that namespace is not provided here
return nil, fmt.Errorf("not implemented")
}
// FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services
func (f *V2DataFetcher) FetchEndpoints(reqContext Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) {
if lookupType != LookupTypeService {
return nil, ErrNotSupported
}
configCtx := f.dynamicConfig.Load().(*V2DataFetcherDynamicConfig)
serviceEndpoints := pbcatalog.ServiceEndpoints{}
serviceEndpointsResource, err := f.fetchResource(reqContext, *req, pbcatalog.ServiceEndpointsType, &serviceEndpoints)
if err != nil {
return nil, err
}
f.logger.Trace("shuffling endpoints", "name", req.Name, "endpoints", len(serviceEndpoints.Endpoints))
// Shuffle the endpoints slice
shuffleFunc := func(i, j int) {
serviceEndpoints.Endpoints[i], serviceEndpoints.Endpoints[j] = serviceEndpoints.Endpoints[j], serviceEndpoints.Endpoints[i]
}
rand.Shuffle(len(serviceEndpoints.Endpoints), shuffleFunc)
// Convert the service endpoints to results up to the limit
limit := req.Limit
if len(serviceEndpoints.Endpoints) < limit || limit == 0 {
limit = len(serviceEndpoints.Endpoints)
}
results := make([]*Result, 0, limit)
for _, endpoint := range serviceEndpoints.Endpoints[:limit] {
// First we check the endpoint first to make sure that the requested port is matched from the service.
// We error here because we expect all endpoints to have the same ports as the service.
ports := getResultPorts(req, endpoint.Ports) //assuming the logic changed in getResultPorts
if len(ports) == 0 {
f.logger.Debug("could not find matching port in endpoint", "name", req.Name, "port", req.PortName)
return nil, ErrNotFound
}
address, err := f.addressFromWorkloadAddresses(endpoint.Addresses, req.Name)
if err != nil {
return nil, err
}
weight, ok := getEndpointWeight(endpoint, configCtx)
if !ok {
f.logger.Debug("endpoint filtered out because of health status", "name", req.Name, "endpoint", endpoint.GetTargetRef().GetName())
continue
}
result := &Result{
Node: &Location{
Address: address,
Name: endpoint.GetTargetRef().GetName(),
},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: serviceEndpointsResource.GetId().GetTenancy().GetNamespace(),
Partition: serviceEndpointsResource.GetId().GetTenancy().GetPartition(),
},
DNS: DNSConfig{
Weight: weight,
},
Ports: ports,
}
results = append(results, result)
}
return results, nil
}
// FetchVirtualIP fetches A/AAAA records for virtual IPs
func (f *V2DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, error) {
// TODO (v2-dns): NET-6624 - Implement FetchVirtualIP
return nil, fmt.Errorf("not implemented")
}
// FetchRecordsByIp is used for PTR requests to look up a service/node from an IP.
func (f *V2DataFetcher) FetchRecordsByIp(ctx Context, ip net.IP) ([]*Result, error) {
// TODO (v2-dns): NET-6795 - Implement FetchRecordsByIp
// Validate non-nil IP
return nil, fmt.Errorf("not implemented")
}
// FetchWorkload is used to fetch a single workload from the V2 catalog.
// V2-only.
func (f *V2DataFetcher) FetchWorkload(reqContext Context, req *QueryPayload) (*Result, error) {
workload := pbcatalog.Workload{}
resourceObj, err := f.fetchResource(reqContext, *req, pbcatalog.WorkloadType, &workload)
if err != nil {
return nil, err
}
// First we check the endpoint first to make sure that the requested port is matched from the service.
// We error here because we expect all endpoints to have the same ports as the service.
ports := getResultPorts(req, workload.Ports) //assuming the logic changed in getResultPorts
if ports == nil || len(ports) == 0 {
f.logger.Debug("could not find matching port in endpoint", "name", req.Name, "port", req.PortName)
return nil, ErrNotFound
}
address, err := f.addressFromWorkloadAddresses(workload.Addresses, req.Name)
if err != nil {
return nil, err
}
tenancy := resourceObj.GetId().GetTenancy()
result := &Result{
Node: &Location{
Address: address,
Name: resourceObj.GetId().GetName(),
},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: tenancy.GetNamespace(),
Partition: tenancy.GetPartition(),
},
Ports: ports,
}
return result, nil
}
// FetchPreparedQuery is used to fetch a prepared query from the V2 catalog.
// Deprecated in V2.
func (f *V2DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) {
return nil, ErrNotSupported
}
func (f *V2DataFetcher) NormalizeRequest(req *QueryPayload) {
// If we do not have an explicit partition in the request, we use the agent's
if req.Tenancy.Partition == "" {
req.Tenancy.Partition = f.defaultPartition
}
}
// ValidateRequest throws an error is any of the deprecated V1 input fields are used in a QueryByName for this data fetcher.
func (f *V2DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error {
if req.Tag != "" {
return ErrNotSupported
}
if req.SourceIP != nil {
return ErrNotSupported
}
return nil
}
// fetchResource is used to read a single resource from the V2 catalog and cast into a concrete type.
func (f *V2DataFetcher) fetchResource(reqContext Context, req QueryPayload, kind *pbresource.Type, payload proto.Message) (*pbresource.Resource, error) {
// Query the resource service for the ServiceEndpoints by name and tenancy
resourceReq := pbresource.ReadRequest{
Id: &pbresource.ID{
Name: req.Name,
Type: kind,
Tenancy: queryTenancyToResourceTenancy(req.Tenancy),
},
}
f.logger.Trace("fetching "+kind.String(), "name", req.Name)
resourceCtx := metadata.AppendToOutgoingContext(context.Background(), "x-consul-token", reqContext.Token)
// If the service is not found, return nil and an error equivalent to NXDOMAIN
response, err := f.client.Read(resourceCtx, &resourceReq)
switch {
case grpcNotFoundErr(err):
f.logger.Debug(kind.String()+" not found", "name", req.Name)
return nil, ErrNotFound
case err != nil:
f.logger.Error("error fetching "+kind.String(), "name", req.Name)
return nil, fmt.Errorf("error fetching %s: %w", kind.String(), err)
// default: fallthrough
}
data := response.GetResource().GetData()
if err := data.UnmarshalTo(payload); err != nil {
f.logger.Error("error unmarshalling "+kind.String(), "name", req.Name)
return nil, fmt.Errorf("error unmarshalling %s: %w", kind.String(), err)
}
return response.GetResource(), nil
}
// addressFromWorkloadAddresses returns one address from the workload addresses.
func (f *V2DataFetcher) addressFromWorkloadAddresses(addresses []*pbcatalog.WorkloadAddress, name string) (string, error) {
// TODO: (v2-dns): we will need to intelligently return the right workload address based on either the translate
// address setting or the locality of the requester. Workloads must have at least one.
// We also need to make sure that we filter out unix sockets here.
address := addresses[0].GetHost()
if strings.HasPrefix(address, "unix://") {
f.logger.Error("unix sockets are currently unsupported in workload results", "name", name)
return "", ErrNotFound
}
return address, nil
}
// getEndpointWeight returns the weight of the endpoint and a boolean indicating if the endpoint should be included
// based on it's health status.
func getEndpointWeight(endpoint *pbcatalog.Endpoint, configCtx *V2DataFetcherDynamicConfig) (uint32, bool) {
health := endpoint.GetHealthStatus().Enum()
if health == nil {
return 0, false
}
// Filter based on health status and agent config
// This is also a good opportunity to see if SRV weights are set
var weight uint32
switch *health {
case pbcatalog.Health_HEALTH_PASSING:
weight = endpoint.GetDns().GetWeights().GetPassing()
case pbcatalog.Health_HEALTH_CRITICAL:
return 0, false // always filtered out
case pbcatalog.Health_HEALTH_WARNING:
if configCtx.OnlyPassing {
return 0, false // filtered out
}
weight = endpoint.GetDns().GetWeights().GetWarning()
default:
// Everything else can be filtered out
return 0, false
}
// Important! double-check the weight in the case DNS weights are not set
if weight == 0 {
weight = 1
}
return weight, true
}
// getResultPorts conditionally returns ports from a map based on a query. The results are sorted by name.
func getResultPorts(req *QueryPayload, workloadPorts map[string]*pbcatalog.WorkloadPort) []Port {
if len(workloadPorts) == 0 {
return nil
}
var ports []Port
if req.PortName != "" {
// Make sure the workload implements that port name.
if _, ok := workloadPorts[req.PortName]; !ok {
return nil
}
// In the case that the query asked for a specific port, we only return that port.
ports = []Port{
{
Name: req.PortName,
Number: workloadPorts[req.PortName].Port,
},
}
} else {
// If the client didn't specify a particular port, return all the workload ports.
for name, port := range workloadPorts {
ports = append(ports, Port{
Name: name,
Number: port.Port,
})
}
// Stable Sort
slices.SortStableFunc(ports, func(i, j Port) int {
if i.Name < j.Name {
return -1
} else if i.Name > j.Name {
return 1
}
return 0
})
}
return ports
}
// queryTenancyToResourceTenancy converts a QueryTenancy to a pbresource.Tenancy.
func queryTenancyToResourceTenancy(qTenancy QueryTenancy) *pbresource.Tenancy {
rTenancy := resource.DefaultNamespacedTenancy()
// If the request has any tenancy specified, it overrides the defaults.
if qTenancy.Namespace != "" {
rTenancy.Namespace = qTenancy.Namespace
}
// In the case of partition, we have the agent's partition as the fallback.
// That is handled in NormalizeRequest.
if qTenancy.Partition != "" {
rTenancy.Partition = qTenancy.Partition
}
return rTenancy
}
// grpcNotFoundErr returns true if the error is a gRPC status error with a code of NotFound.
func grpcNotFoundErr(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
return ok && s.Code() == codes.NotFound
}

View File

@ -1,859 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"
"github.com/hashicorp/consul/agent/config"
mockpbresource "github.com/hashicorp/consul/grpcmocks/proto-public/pbresource"
"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
"github.com/hashicorp/consul/sdk/testutil"
)
var (
unknownErr = errors.New("I don't feel so good")
)
// Test_FetchService tests the FetchService method in scenarios where the RPC
// call succeeds and fails.
func Test_FetchWorkload(t *testing.T) {
rc := &config.RuntimeConfig{
DNSOnlyPassing: false,
}
tests := []struct {
name string
queryPayload *QueryPayload
context Context
configureMockClient func(mockClient *mockpbresource.ResourceServiceClient_Expecter)
expectedResult *Result
expectedErr error
}{
{
name: "FetchWorkload returns result",
queryPayload: &QueryPayload{
Name: "foo-1234",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: &Result{
Node: &Location{Name: "foo-1234", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
},
expectedErr: nil,
},
{
name: "FetchWorkload for non-existent workload",
queryPayload: &QueryPayload{
Name: "foo-1234",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
input := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(nil, status.Error(codes.NotFound, "not found")).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, input.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: nil,
expectedErr: ErrNotFound,
},
{
name: "FetchWorkload encounters a resource client error",
queryPayload: &QueryPayload{
Name: "foo-1234",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
input := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(nil, unknownErr).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, input.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: nil,
expectedErr: unknownErr,
},
{
name: "FetchWorkload with a matching port",
queryPayload: &QueryPayload{
Name: "foo-1234",
PortName: "api",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: &Result{
Node: &Location{Name: "foo-1234", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
},
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
},
expectedErr: nil,
},
{
name: "FetchWorkload with a matching port",
queryPayload: &QueryPayload{
Name: "foo-1234",
PortName: "not-api",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: nil,
expectedErr: ErrNotFound,
},
{
name: "FetchWorkload returns result for non-default tenancy",
queryPayload: &QueryPayload{
Name: "foo-1234",
Tenancy: QueryTenancy{
Namespace: "test-namespace",
Partition: "test-partition",
},
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "foo-1234", "test-namespace", "test-partition")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
require.Equal(t, result.GetResource().GetId().GetTenancy().GetNamespace(), req.Id.Tenancy.Namespace)
require.Equal(t, result.GetResource().GetId().GetTenancy().GetPartition(), req.Id.Tenancy.Partition)
})
},
expectedResult: &Result{
Node: &Location{Name: "foo-1234", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
Tenancy: ResultTenancy{
Namespace: "test-namespace",
Partition: "test-partition",
},
},
expectedErr: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := testutil.Logger(t)
client := mockpbresource.NewResourceServiceClient(t)
mockClient := client.EXPECT()
tc.configureMockClient(mockClient)
df := NewV2DataFetcher(rc, client, logger)
result, err := df.FetchWorkload(tc.context, tc.queryPayload)
require.True(t, errors.Is(err, tc.expectedErr))
require.Equal(t, tc.expectedResult, result)
})
}
}
// Test_V2FetchEndpoints the FetchService method in scenarios where the RPC
// call succeeds and fails.
func Test_V2FetchEndpoints(t *testing.T) {
tests := []struct {
name string
queryPayload *QueryPayload
context Context
configureMockClient func(mockClient *mockpbresource.ResourceServiceClient_Expecter)
rc *config.RuntimeConfig
expectedResult []*Result
expectedErr error
verifyShuffle bool
}{
{
name: "FetchEndpoints returns result",
queryPayload: &QueryPayload{
Name: "consul",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
endpoints := []*pbcatalog.Endpoint{
makeEndpoint("consul-1", "1.2.3.4", pbcatalog.Health_HEALTH_PASSING, 0, 0),
}
serviceEndpoints := getTestEndpointsResponse(t, "", "", endpoints...)
mockClient.Read(mock.Anything, mock.Anything).
Return(serviceEndpoints, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, serviceEndpoints.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: []*Result{
{
Node: &Location{Name: "consul-1", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
},
},
{
name: "FetchEndpoints returns empty result with no endpoints",
queryPayload: &QueryPayload{
Name: "consul",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestEndpointsResponse(t, "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: []*Result{},
},
{
name: "FetchEndpoints returns a name error when the ServiceEndpoint does not exist",
queryPayload: &QueryPayload{
Name: "consul",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestEndpointsResponse(t, "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(nil, status.Error(codes.NotFound, "not found")).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedErr: ErrNotFound,
},
{
name: "FetchEndpoints encounters a resource client error",
queryPayload: &QueryPayload{
Name: "consul",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestEndpointsResponse(t, "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(nil, unknownErr).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedErr: unknownErr,
},
{
name: "FetchEndpoints always filters out critical endpoints; DNS weights applied correctly",
queryPayload: &QueryPayload{
Name: "consul",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
results := []*pbcatalog.Endpoint{
makeEndpoint("consul-1", "1.2.3.4", pbcatalog.Health_HEALTH_PASSING, 2, 3),
makeEndpoint("consul-2", "2.3.4.5", pbcatalog.Health_HEALTH_WARNING, 2, 3),
makeEndpoint("consul-3", "3.4.5.6", pbcatalog.Health_HEALTH_CRITICAL, 2, 3),
}
result := getTestEndpointsResponse(t, "", "", results...)
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: []*Result{
{
Node: &Location{Name: "consul-1", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 2,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
{
Node: &Location{Name: "consul-2", Address: "2.3.4.5"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 3,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
},
},
{
name: "FetchEndpoints filters out warning endpoints when DNSOnlyPassing is true",
queryPayload: &QueryPayload{
Name: "consul",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
results := []*pbcatalog.Endpoint{
makeEndpoint("consul-1", "1.2.3.4", pbcatalog.Health_HEALTH_PASSING, 2, 3),
makeEndpoint("consul-2", "2.3.4.5", pbcatalog.Health_HEALTH_WARNING, 2, 3),
makeEndpoint("consul-3", "3.4.5.6", pbcatalog.Health_HEALTH_CRITICAL, 2, 3),
}
result := getTestEndpointsResponse(t, "", "", results...)
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
rc: &config.RuntimeConfig{
DNSOnlyPassing: true,
},
expectedResult: []*Result{
{
Node: &Location{Name: "consul-1", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 2,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
},
},
{
name: "FetchEndpoints shuffles the results",
queryPayload: &QueryPayload{
Name: "consul",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
results := []*pbcatalog.Endpoint{
// use a set of 10 elements, the odds of getting the same result are 1 in 3628800
makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-2", "10.0.0.2", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-3", "10.0.0.3", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-4", "10.0.0.4", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-5", "10.0.0.5", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-6", "10.0.0.6", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-7", "10.0.0.7", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-8", "10.0.0.8", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-9", "10.0.0.9", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-10", "10.0.0.10", pbcatalog.Health_HEALTH_PASSING, 0, 0),
}
result := getTestEndpointsResponse(t, "", "", results...)
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: func() []*Result {
results := make([]*Result, 0, 10)
for i := 0; i < 10; i++ {
name := fmt.Sprintf("consul-%d", i+1)
address := fmt.Sprintf("10.0.0.%d", i+1)
result := &Result{
Node: &Location{Name: name, Address: address},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
DNS: DNSConfig{
Weight: 1,
},
}
results = append(results, result)
}
return results
}(),
verifyShuffle: true,
},
{
name: "FetchEndpoints returns only the specified limit",
queryPayload: &QueryPayload{
Name: "consul",
Limit: 1,
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
results := []*pbcatalog.Endpoint{
// intentionally all the same to make this easier to verify
makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0),
makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0),
}
result := getTestEndpointsResponse(t, "", "", results...)
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: []*Result{
{
Node: &Location{Name: "consul-1", Address: "10.0.0.1"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
},
},
{
name: "FetchEndpoints returns results with non-default tenancy",
queryPayload: &QueryPayload{
Name: "consul",
Tenancy: QueryTenancy{
Namespace: "test-namespace",
Partition: "test-partition",
},
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
results := []*pbcatalog.Endpoint{
// intentionally all the same to make this easier to verify
makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0),
}
result := getTestEndpointsResponse(t, "test-namespace", "test-partition", results...)
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
require.Equal(t, result.GetResource().GetId().GetTenancy().GetNamespace(), req.Id.Tenancy.Namespace)
require.Equal(t, result.GetResource().GetId().GetTenancy().GetPartition(), req.Id.Tenancy.Partition)
})
},
expectedResult: []*Result{
{
Node: &Location{Name: "consul-1", Address: "10.0.0.1"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: "test-namespace",
Partition: "test-partition",
},
DNS: DNSConfig{
Weight: 1,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
},
},
{
name: "FetchEndpoints returns only a specific port if is one requested",
queryPayload: &QueryPayload{
Name: "consul",
PortName: "api",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
endpoints := []*pbcatalog.Endpoint{
makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0),
}
serviceEndpoints := getTestEndpointsResponse(t, "", "", endpoints...)
mockClient.Read(mock.Anything, mock.Anything).
Return(serviceEndpoints, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, serviceEndpoints.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: []*Result{
{
Node: &Location{Name: "consul-1", Address: "10.0.0.1"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
// No mesh port this time
},
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
},
},
{
name: "FetchEndpoints returns a name error when a service doesn't implement the requested port",
queryPayload: &QueryPayload{
Name: "consul",
PortName: "banana",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
endpoints := []*pbcatalog.Endpoint{
makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0),
}
serviceEndpoints := getTestEndpointsResponse(t, "", "", endpoints...)
mockClient.Read(mock.Anything, mock.Anything).
Return(serviceEndpoints, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, serviceEndpoints.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedErr: ErrNotFound,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := testutil.Logger(t)
client := mockpbresource.NewResourceServiceClient(t)
mockClient := client.EXPECT()
tc.configureMockClient(mockClient)
if tc.rc == nil {
tc.rc = &config.RuntimeConfig{
DNSOnlyPassing: false,
}
}
df := NewV2DataFetcher(tc.rc, client, logger)
result, err := df.FetchEndpoints(tc.context, tc.queryPayload, LookupTypeService)
require.True(t, errors.Is(err, tc.expectedErr))
if tc.verifyShuffle {
require.NotEqualf(t, tc.expectedResult, result, "expected result to be shuffled. There is a small probability that it shuffled back to the original order. In that case, you may want to play the lottery.")
}
require.ElementsMatchf(t, tc.expectedResult, result, "elements of results should match")
})
}
}
func getTestWorkloadResponse(t *testing.T, name string, nsOverride string, partitionOverride string) *pbresource.ReadResponse {
workload := &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{
Host: "1.2.3.4",
Ports: []string{"api", "mesh"},
},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"api": {
Port: 5678,
},
"mesh": {
Port: 21000,
},
},
Identity: "test-identity",
}
data, err := anypb.New(workload)
require.NoError(t, err)
resp := &pbresource.ReadResponse{
Resource: &pbresource.Resource{
Id: &pbresource.ID{
Name: name,
Type: pbcatalog.WorkloadType,
Tenancy: resource.DefaultNamespacedTenancy(),
},
Data: data,
},
}
if nsOverride != "" {
resp.Resource.Id.Tenancy.Namespace = nsOverride
}
if partitionOverride != "" {
resp.Resource.Id.Tenancy.Partition = partitionOverride
}
return resp
}
func makeEndpoint(name string, address string, health pbcatalog.Health, weightPassing, weightWarning uint32) *pbcatalog.Endpoint {
endpoint := &pbcatalog.Endpoint{
Addresses: []*pbcatalog.WorkloadAddress{
{
Host: address,
Ports: []string{"api"},
},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"api": {
Port: 5678,
},
"mesh": {
Port: 21000,
},
},
HealthStatus: health,
TargetRef: &pbresource.ID{
Name: name,
},
}
if weightPassing > 0 || weightWarning > 0 {
endpoint.Dns = &pbcatalog.DNSPolicy{
Weights: &pbcatalog.Weights{
Passing: weightPassing,
Warning: weightWarning,
},
}
}
return endpoint
}
func getTestEndpointsResponse(t *testing.T, nsOverride string, partitionOverride string, endpoints ...*pbcatalog.Endpoint) *pbresource.ReadResponse {
serviceEndpoints := &pbcatalog.ServiceEndpoints{
Endpoints: endpoints,
}
data, err := anypb.New(serviceEndpoints)
require.NoError(t, err)
resp := &pbresource.ReadResponse{
Resource: &pbresource.Resource{
Id: &pbresource.ID{
Name: "consul",
Type: pbcatalog.ServiceType,
Tenancy: resource.DefaultNamespacedTenancy(),
},
Data: data,
},
}
if nsOverride != "" {
resp.Resource.Id.Tenancy.Namespace = nsOverride
}
if partitionOverride != "" {
resp.Resource.Id.Tenancy.Partition = partitionOverride
}
return resp
}

View File

@ -8,6 +8,7 @@ import (
"encoding/hex"
"errors"
"fmt"
agentdns "github.com/hashicorp/consul/agent/dns"
"math"
"net"
"regexp"
@ -61,7 +62,17 @@ type dnsSOAConfig struct {
Minttl uint32 // 0
}
type dnsConfig struct {
// dnsRequestConfig returns the DNS request configuration that encapsulates:
// - the DNS server configuration.
// - the token from the request, if available.
// - the enterprise meta from the request, if available.
type dnsRequestConfig struct {
*dnsServerConfig
token string
defaultEnterpriseMeta acl.EnterpriseMeta
}
type dnsServerConfig struct {
AllowStale bool
Datacenter string
EnableTruncate bool
@ -119,7 +130,7 @@ type DNSServer struct {
altDomain string
logger hclog.Logger
// config stores the config as an atomic value (for hot-reloading). It is always of type *dnsConfig
// config stores the config as an atomic value (for hot-reloading). It is always of type *dnsServerConfig
config atomic.Value
// recursorEnabled stores whever the recursor handler is enabled as an atomic flag.
@ -141,7 +152,7 @@ func NewDNSServer(a *Agent) (*DNSServer, error) {
defaultEnterpriseMeta: *a.AgentEnterpriseMeta(),
mux: dns.NewServeMux(),
}
cfg, err := GetDNSConfig(a.config)
cfg, err := getDNSServerConfig(a.config)
if err != nil {
return nil, err
}
@ -163,9 +174,9 @@ func NewDNSServer(a *Agent) (*DNSServer, error) {
return srv, nil
}
// GetDNSConfig takes global config and creates the config used by DNS server
func GetDNSConfig(conf *config.RuntimeConfig) (*dnsConfig, error) {
cfg := &dnsConfig{
// getDNSServerConfig takes global config and creates the config used by DNS server
func getDNSServerConfig(conf *config.RuntimeConfig) (*dnsServerConfig, error) {
cfg := &dnsServerConfig{
AllowStale: conf.DNSAllowStale,
ARecordLimit: conf.DNSARecordLimit,
Datacenter: conf.Datacenter,
@ -217,7 +228,7 @@ func GetDNSConfig(conf *config.RuntimeConfig) (*dnsConfig, error) {
// GetTTLForService Find the TTL for a given service.
// return ttl, true if found, 0, false otherwise
func (cfg *dnsConfig) GetTTLForService(service string) (time.Duration, bool) {
func (cfg *dnsServerConfig) GetTTLForService(service string) (time.Duration, bool) {
if cfg.TTLStrict != nil {
ttl, ok := cfg.TTLStrict[service]
if ok {
@ -269,7 +280,7 @@ func (d *DNSServer) GetAddr() string {
}
// toggleRecursorHandlerFromConfig enables or disables the recursor handler based on config idempotently
func (d *DNSServer) toggleRecursorHandlerFromConfig(cfg *dnsConfig) {
func (d *DNSServer) toggleRecursorHandlerFromConfig(cfg *dnsServerConfig) {
shouldEnable := len(cfg.Recursors) > 0
if shouldEnable && atomic.CompareAndSwapUint32(&d.recursorEnabled, 0, 1) {
@ -287,7 +298,7 @@ func (d *DNSServer) toggleRecursorHandlerFromConfig(cfg *dnsConfig) {
// ReloadConfig hot-reloads the server config with new parameters under config.RuntimeConfig.DNS*
func (d *DNSServer) ReloadConfig(newCfg *config.RuntimeConfig) error {
cfg, err := GetDNSConfig(newCfg)
cfg, err := getDNSServerConfig(newCfg)
if err != nil {
return err
}
@ -407,7 +418,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
)
}(time.Now())
cfg := d.config.Load().(*dnsConfig)
cfg := d.getRequestConfig(resp)
// Setup the message response
m := new(dns.Msg)
@ -430,7 +441,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
args := structs.DCSpecificRequest{
Datacenter: datacenter,
QueryOptions: structs.QueryOptions{
Token: d.coalesceDNSToken(),
Token: d.coalesceDNSToken(cfg.token),
AllowStale: cfg.AllowStale,
},
}
@ -463,11 +474,11 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
sargs := structs.ServiceSpecificRequest{
Datacenter: datacenter,
QueryOptions: structs.QueryOptions{
Token: d.coalesceDNSToken(),
Token: d.coalesceDNSToken(cfg.token),
AllowStale: cfg.AllowStale,
},
ServiceAddress: serviceAddress,
EnterpriseMeta: *d.defaultEnterpriseMeta.WithWildcardNamespace(),
EnterpriseMeta: *cfg.defaultEnterpriseMeta.WithWildcardNamespace(),
}
var sout structs.IndexedServiceNodes
@ -536,7 +547,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
network = "tcp"
}
cfg := d.config.Load().(*dnsConfig)
cfg := d.getRequestConfig(resp)
// Set up the message response
m := new(dns.Msg)
@ -565,7 +576,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m.SetRcode(req, dns.RcodeNotImplemented)
default:
err = d.dispatch(resp.RemoteAddr(), req, m, maxRecursionLevelDefault)
err = d.dispatch(resp.RemoteAddr(), req, m, cfg, maxRecursionLevelDefault)
rCode := rCodeFromError(err)
if rCode == dns.RcodeNameError || errors.Is(err, errNoData) {
d.addSOAToMessage(cfg, m, q.Name)
@ -583,7 +594,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
}
// Craft dns records for an SOA
func (d *DNSServer) makeSOARecord(cfg *dnsConfig, questionName string) *dns.SOA {
func (d *DNSServer) makeSOARecord(cfg *dnsRequestConfig, questionName string) *dns.SOA {
domain := d.domain
if d.altDomain != "" && strings.HasSuffix(questionName, "."+d.altDomain) {
domain = d.altDomain
@ -608,19 +619,19 @@ func (d *DNSServer) makeSOARecord(cfg *dnsConfig, questionName string) *dns.SOA
}
// addSOA is used to add an SOA record to a message for the given domain
func (d *DNSServer) addSOAToMessage(cfg *dnsConfig, msg *dns.Msg, questionName string) {
func (d *DNSServer) addSOAToMessage(cfg *dnsRequestConfig, msg *dns.Msg, questionName string) {
msg.Ns = append(msg.Ns, d.makeSOARecord(cfg, questionName))
}
// getNameserversAndNodeRecord returns the names and ip addresses of up to three random servers
// in the current cluster which serve as authoritative name servers for zone.
func (d *DNSServer) getNameserversAndNodeRecord(questionName string, cfg *dnsConfig, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) {
func (d *DNSServer) getNameserversAndNodeRecord(questionName string, cfg *dnsRequestConfig, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(cfg, serviceLookup{
Datacenter: d.agent.config.Datacenter,
Service: structs.ConsulServiceName,
Connect: false,
Ingress: false,
EnterpriseMeta: d.defaultEnterpriseMeta,
EnterpriseMeta: cfg.defaultEnterpriseMeta,
})
if err != nil {
d.logger.Warn("Unable to get list of servers", "error", err)
@ -658,7 +669,7 @@ func (d *DNSServer) getNameserversAndNodeRecord(questionName string, cfg *dnsCon
}
ns = append(ns, nsrr)
extra = append(extra, d.makeRecordFromNode(o.Node, dns.TypeANY, fqdn, cfg.NodeTTL, maxRecursionLevel)...)
extra = append(extra, d.makeRecordFromNode(o.Node, dns.TypeANY, fqdn, cfg, maxRecursionLevel)...)
// don't provide more than 3 servers
if len(ns) >= 3 {
@ -754,7 +765,7 @@ func (l queryLocality) effectiveDatacenter(defaultDC string) string {
// dispatch is used to parse a request and invoke the correct handler.
// parameter maxRecursionLevel will handle whether recursive call can be performed
func (d *DNSServer) dispatch(remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error {
func (d *DNSServer) dispatch(remoteAddr net.Addr, req, resp *dns.Msg, cfg *dnsRequestConfig, maxRecursionLevel int) error {
// Choose correct response domain
respDomain := d.getResponseDomain(req.Question[0].Name)
@ -765,8 +776,6 @@ func (d *DNSServer) dispatch(remoteAddr net.Addr, req, resp *dns.Msg, maxRecursi
// Split into the label parts
labels := dns.SplitDomainName(qName)
cfg := d.config.Load().(*dnsConfig)
var queryKind string
var queryParts []string
var querySuffixes []string
@ -899,7 +908,7 @@ func (d *DNSServer) dispatch(remoteAddr net.Addr, req, resp *dns.Msg, maxRecursi
ServiceName: queryParts[len(queryParts)-1],
EnterpriseMeta: locality.EnterpriseMeta,
QueryOptions: structs.QueryOptions{
Token: d.coalesceDNSToken(),
Token: d.coalesceDNSToken(cfg.token),
},
}
if args.PeerName == "" {
@ -1099,7 +1108,8 @@ func rCodeFromError(err error) int {
case errors.Is(err, errNameNotFound),
structs.IsErrNoDCPath(err),
structs.IsErrQueryNotFound(err),
structs.IsErrSamenessGroupMustBeDefaultForFailover(err):
structs.IsErrSamenessGroupMustBeDefaultForFailover(err),
structs.IsErrSamenessGroupNotFound(err):
return dns.RcodeNameError
default:
return dns.RcodeServerFailure
@ -1107,7 +1117,7 @@ func rCodeFromError(err error) int {
}
// handleNodeQuery is used to handle a node query
func (d *DNSServer) handleNodeQuery(cfg *dnsConfig, lookup nodeLookup, req, resp *dns.Msg) error {
func (d *DNSServer) handleNodeQuery(cfg *dnsRequestConfig, lookup nodeLookup, req, resp *dns.Msg) error {
// Only handle ANY, A, AAAA, and TXT type requests
qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
@ -1120,7 +1130,7 @@ func (d *DNSServer) handleNodeQuery(cfg *dnsConfig, lookup nodeLookup, req, resp
PeerName: lookup.PeerName,
Node: lookup.Node,
QueryOptions: structs.QueryOptions{
Token: d.coalesceDNSToken(),
Token: d.coalesceDNSToken(cfg.token),
AllowStale: cfg.AllowStale,
},
EnterpriseMeta: lookup.EnterpriseMeta,
@ -1146,7 +1156,7 @@ func (d *DNSServer) handleNodeQuery(cfg *dnsConfig, lookup nodeLookup, req, resp
q := req.Question[0]
// Only compute A and CNAME record if query is not TXT type
if qType != dns.TypeTXT {
records := d.makeRecordFromNode(n, q.Qtype, q.Name, cfg.NodeTTL, lookup.MaxRecursionLevel)
records := d.makeRecordFromNode(n, q.Qtype, q.Name, cfg, lookup.MaxRecursionLevel)
resp.Answer = append(resp.Answer, records...)
}
@ -1159,7 +1169,7 @@ func (d *DNSServer) handleNodeQuery(cfg *dnsConfig, lookup nodeLookup, req, resp
// lookupNode is used to look up a node in the Consul catalog within NodeServices.
// If the config is set to UseCache, it will get the record from the agent cache.
func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
func (d *DNSServer) lookupNode(cfg *dnsRequestConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
var out structs.IndexedNodeServices
useCache := cfg.UseCache
@ -1416,7 +1426,7 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
}
// trimDNSResponse will trim the response for UDP and TCP
func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) {
func (d *DNSServer) trimDNSResponse(cfg *dnsRequestConfig, network string, req, resp *dns.Msg) {
var trimmed bool
originalSize := resp.Len()
originalNumRecords := len(resp.Answer)
@ -1441,7 +1451,7 @@ func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *d
// lookupServiceNodes is used to look up a node in the Consul health catalog within ServiceNodes.
// If the config is set to UseCache, it will get the record from the agent cache.
func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, lookup serviceLookup) (structs.IndexedCheckServiceNodes, error) {
func (d *DNSServer) lookupServiceNodes(cfg *dnsRequestConfig, lookup serviceLookup) (structs.IndexedCheckServiceNodes, error) {
serviceTags := []string{}
if lookup.Tag != "" {
serviceTags = []string{lookup.Tag}
@ -1461,7 +1471,7 @@ func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, lookup serviceLookup) (st
TagFilter: lookup.Tag != "",
HealthFilterType: healthFilterType,
QueryOptions: structs.QueryOptions{
Token: d.coalesceDNSToken(),
Token: d.coalesceDNSToken(cfg.token),
AllowStale: cfg.AllowStale,
MaxAge: cfg.CacheMaxAge,
UseCache: cfg.UseCache,
@ -1479,7 +1489,7 @@ func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, lookup serviceLookup) (st
}
// handleServiceQuery is used to handle a service query
func (d *DNSServer) handleServiceQuery(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) error {
func (d *DNSServer) handleServiceQuery(cfg *dnsRequestConfig, lookup serviceLookup, req, resp *dns.Msg) error {
out, err := d.lookupServiceNodes(cfg, lookup)
if err != nil {
return fmt.Errorf("rpc request failed: %w", err)
@ -1528,13 +1538,13 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
}
// handlePreparedQuery is used to handle a prepared query.
func (d *DNSServer) handlePreparedQuery(cfg *dnsConfig, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error {
func (d *DNSServer) handlePreparedQuery(cfg *dnsRequestConfig, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error {
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter,
QueryIDOrName: query,
QueryOptions: structs.QueryOptions{
Token: d.coalesceDNSToken(),
Token: d.coalesceDNSToken(cfg.token),
AllowStale: cfg.AllowStale,
MaxAge: cfg.CacheMaxAge,
},
@ -1622,7 +1632,7 @@ func (d *DNSServer) handlePreparedQuery(cfg *dnsConfig, datacenter, query string
// lookupPreparedQuery is used to execute a PreparedQuery against the Consul catalog.
// If the config is set to UseCache, it will use agent cache.
func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
func (d *DNSServer) lookupPreparedQuery(cfg *dnsRequestConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
var out structs.PreparedQueryExecuteResponse
RPC:
@ -1664,7 +1674,7 @@ RPC:
}
// addServiceNodeRecordsToMessage is used to add the node records for a service lookup
func (d *DNSServer) addServiceNodeRecordsToMessage(cfg *dnsConfig, lookup serviceLookup, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
func (d *DNSServer) addServiceNodeRecordsToMessage(cfg *dnsRequestConfig, lookup serviceLookup, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
handled := make(map[string]struct{})
var answerCNAME []dns.RR = nil
@ -1803,7 +1813,8 @@ func makeARecord(qType uint16, ip net.IP, ttl time.Duration) dns.RR {
// Craft dns records for a node
// In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the node IP
// Otherwise it will return a IN A record
func (d *DNSServer) makeRecordFromNode(node *structs.Node, qType uint16, qName string, ttl time.Duration, maxRecursionLevel int) []dns.RR {
func (d *DNSServer) makeRecordFromNode(node *structs.Node, qType uint16, qName string, cfg *dnsRequestConfig, maxRecursionLevel int) []dns.RR {
ttl := cfg.NodeTTL
addrTranslate := dnsutil.TranslateAddressAcceptDomain
if qType == dns.TypeA {
addrTranslate |= dnsutil.TranslateAddressAcceptIPv4
@ -1830,7 +1841,7 @@ func (d *DNSServer) makeRecordFromNode(node *structs.Node, qType uint16, qName s
})
res = append(res,
d.resolveCNAME(d.config.Load().(*dnsConfig), dns.Fqdn(node.Address), maxRecursionLevel)...,
d.resolveCNAME(cfg, dns.Fqdn(node.Address), maxRecursionLevel)...,
)
return res
@ -1919,7 +1930,7 @@ func (d *DNSServer) makeRecordFromIP(lookup serviceLookup, addr net.IP, serviceN
// Craft dns records for an FQDN
// In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the IP
// Otherwise it will return a CNAME and a IN A record
func (d *DNSServer) makeRecordFromFQDN(lookup serviceLookup, fqdn string, serviceNode structs.CheckServiceNode, req *dns.Msg, ttl time.Duration, cfg *dnsConfig, maxRecursionLevel int) ([]dns.RR, []dns.RR) {
func (d *DNSServer) makeRecordFromFQDN(lookup serviceLookup, fqdn string, serviceNode structs.CheckServiceNode, req *dns.Msg, ttl time.Duration, cfg *dnsRequestConfig, maxRecursionLevel int) ([]dns.RR, []dns.RR) {
edns := req.IsEdns0() != nil
q := req.Question[0]
@ -1975,7 +1986,7 @@ MORE_REC:
}
// Craft dns records from a CheckServiceNode struct
func (d *DNSServer) makeNodeServiceRecords(lookup serviceLookup, node structs.CheckServiceNode, req *dns.Msg, ttl time.Duration, cfg *dnsConfig, maxRecursionLevel int) ([]dns.RR, []dns.RR) {
func (d *DNSServer) makeNodeServiceRecords(lookup serviceLookup, node structs.CheckServiceNode, req *dns.Msg, ttl time.Duration, cfg *dnsRequestConfig, maxRecursionLevel int) ([]dns.RR, []dns.RR) {
addrTranslate := dnsutil.TranslateAddressAcceptDomain
if req.Question[0].Qtype == dns.TypeA {
addrTranslate |= dnsutil.TranslateAddressAcceptIPv4
@ -2049,7 +2060,7 @@ func (d *DNSServer) makeTXTRecordFromNodeMeta(qName string, node *structs.Node,
}
// addServiceSRVRecordsToMessage is used to add the SRV records for a service lookup
func (d *DNSServer) addServiceSRVRecordsToMessage(cfg *dnsConfig, lookup serviceLookup, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
func (d *DNSServer) addServiceSRVRecordsToMessage(cfg *dnsRequestConfig, lookup serviceLookup, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
handled := make(map[string]struct{})
for _, node := range nodes {
@ -2080,7 +2091,7 @@ func (d *DNSServer) addServiceSRVRecordsToMessage(cfg *dnsConfig, lookup service
// handleRecurse is used to handle recursive DNS queries
func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
cfg := d.config.Load().(*dnsConfig)
cfg := d.getRequestConfig(resp)
q := req.Question[0]
network := "udp"
@ -2156,7 +2167,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
}
// resolveCNAME is used to recursively resolve CNAME records
func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel int) []dns.RR {
func (d *DNSServer) resolveCNAME(cfg *dnsRequestConfig, name string, maxRecursionLevel int) []dns.RR {
// If the CNAME record points to a Consul address, resolve it internally
// Convert query to lowercase because DNS is case insensitive; d.domain and
// d.altDomain are already converted
@ -2171,7 +2182,7 @@ func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel
req.SetQuestion(name, dns.TypeANY)
// TODO: handle error response
d.dispatch(nil, req, resp, maxRecursionLevel-1)
d.dispatch(nil, req, resp, cfg, maxRecursionLevel-1)
return resp.Answer
}
@ -2209,10 +2220,46 @@ func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel
return nil
}
func (d *DNSServer) coalesceDNSToken() string {
// coalesceDNSToken returns the ACL token to use for DNS queries.
// It returns the first token found in the following order:
// 1. The token from the request, if available.
// 2. The DNSToken from the agent.
// 3. The UserToken from the agent.
func (d *DNSServer) coalesceDNSToken(requestToken string) string {
// if the request token is set, which occurs when consul-dataplane forwards requests over gRPC, use it
if requestToken != "" {
return requestToken
}
if d.agent.tokens.DNSToken() != "" {
return d.agent.tokens.DNSToken()
} else {
return d.agent.tokens.UserToken()
}
return d.agent.tokens.UserToken()
}
// getRequestConfig returns the DNS request configuration that encapsulates:
// - the DNS server configuration.
// - the token from the request, if available.
// - the enterprise meta from the request, if available.
func (d *DNSServer) getRequestConfig(resp dns.ResponseWriter) *dnsRequestConfig {
dnsServerConfig := d.config.Load().(*dnsServerConfig)
requestDnsConfig := &dnsRequestConfig{
dnsServerConfig: dnsServerConfig,
defaultEnterpriseMeta: d.defaultEnterpriseMeta,
}
// DNS uses *dns.ServeMux, which takes a ResponseWriter interface and a DNS message both
// from the github.com/miekg/dns module, so we are limited in what we can pass as arguments.
// We can't pass a context.Context, so we have to add the RequestContext field to our
// implementation of dns.ResponseWriter to pass the context from the request.
if rw, ok := resp.(*agentdns.BufferResponseWriter); ok {
// use the ACL token from the request if available. Regular DNS hitting the
// agent will not carry a token, but gRPC requests from consul-dataplane will.
if rw.RequestContext.Token != "" {
requestDnsConfig.token = rw.RequestContext.Token
}
d.setEnterpriseMetaFromRequestContext(rw.RequestContext, requestDnsConfig)
}
return requestDnsConfig
}

View File

@ -0,0 +1,78 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"net"
)
// BufferResponseWriter writes a DNS response to a byte buffer.
type BufferResponseWriter struct {
// responseBuffer is the buffer that the response is written to.
responseBuffer []byte
// RequestContext is the context of the request that carries the ACL token and tenancy of the request.
RequestContext Context
// LocalAddress is the address of the server.
LocalAddress net.Addr
// RemoteAddress is the address of the client that sent the request.
RemoteAddress net.Addr
// Logger is the logger for the response writer.
Logger hclog.Logger
}
var _ dns.ResponseWriter = (*BufferResponseWriter)(nil)
// ResponseBuffer returns the buffer containing the response.
func (b *BufferResponseWriter) ResponseBuffer() []byte {
return b.responseBuffer
}
// LocalAddr returns the net.Addr of the server
func (b *BufferResponseWriter) LocalAddr() net.Addr {
return b.LocalAddress
}
// RemoteAddr returns the net.Addr of the client that sent the current request.
func (b *BufferResponseWriter) RemoteAddr() net.Addr {
return b.RemoteAddress
}
// WriteMsg writes a reply back to the client.
func (b *BufferResponseWriter) WriteMsg(m *dns.Msg) error {
// Pack message to bytes first.
msgBytes, err := m.Pack()
if err != nil {
b.Logger.Error("error packing message", "err", err)
return err
}
b.responseBuffer = msgBytes
return nil
}
// Write writes a raw buffer back to the client.
func (b *BufferResponseWriter) Write(m []byte) (int, error) {
b.Logger.Trace("Write was called")
return copy(b.responseBuffer, m), nil
}
// Close closes the connection.
func (b *BufferResponseWriter) Close() error {
// There's nothing for us to do here as we don't handle the connection.
return nil
}
// TsigStatus returns the status of the Tsig.
func (b *BufferResponseWriter) TsigStatus() error {
// TSIG doesn't apply to this response writer.
return nil
}
// TsigTimersOnly sets the tsig timers only boolean.
func (b *BufferResponseWriter) TsigTimersOnly(bool) {}
// Hijack lets the caller take over the connection.
// After a call to Hijack(), the DNS package will not do anything with the connection. {
func (b *BufferResponseWriter) Hijack() {}

View File

@ -1,383 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"encoding/hex"
"net"
"strings"
"github.com/miekg/dns"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/dnsutil"
)
// discoveryResultsFetcher is a facade for the DNS router to formulate
// and execute discovery queries.
type discoveryResultsFetcher struct{}
// getQueryOptions is a struct to hold the options for getQueryResults method.
type getQueryOptions struct {
req *dns.Msg
reqCtx Context
qName string
remoteAddress net.Addr
processor DiscoveryQueryProcessor
logger hclog.Logger
domain string
altDomain string
}
// getQueryResults returns a discovery.Result from a DNS message.
func (d discoveryResultsFetcher) getQueryResults(opts *getQueryOptions) ([]*discovery.Result, *discovery.Query, error) {
reqType := parseRequestType(opts.req)
switch reqType {
case requestTypeConsul:
// This is a special case of discovery.QueryByName where we know that we need to query the consul service
// regardless of the question name.
query := &discovery.Query{
QueryType: discovery.QueryTypeService,
QueryPayload: discovery.QueryPayload{
Name: structs.ConsulServiceName,
Tenancy: discovery.QueryTenancy{
// We specify the partition here so that in the case we are a client agent in a non-default partition.
// We don't want the query processors default partition to be used.
// This is a small hack because for V1 CE, this is not the correct default partition name, but we
// need to add something to disambiguate the empty field.
Partition: acl.DefaultPartitionName, //NOTE: note this won't work if we ever have V2 client agents
},
Limit: 3,
},
}
results, err := opts.processor.QueryByName(query, discovery.Context{Token: opts.reqCtx.Token})
return results, query, err
case requestTypeName:
query, err := buildQueryFromDNSMessage(opts.req, opts.reqCtx, opts.domain, opts.altDomain, opts.remoteAddress)
if err != nil {
opts.logger.Error("error building discovery query from DNS request", "error", err)
return nil, query, err
}
results, err := opts.processor.QueryByName(query, discovery.Context{Token: opts.reqCtx.Token})
if getErrorFromECSNotGlobalError(err) != nil {
opts.logger.Error("error processing discovery query", "error", err)
if structs.IsErrSamenessGroupMustBeDefaultForFailover(err) {
return nil, query, errNameNotFound
}
return nil, query, err
}
return results, query, err
case requestTypeIP:
ip := dnsutil.IPFromARPA(opts.qName)
if ip == nil {
opts.logger.Error("error building IP from DNS request", "name", opts.qName)
return nil, nil, errNameNotFound
}
results, err := opts.processor.QueryByIP(ip, discovery.Context{Token: opts.reqCtx.Token})
return results, nil, err
case requestTypeAddress:
results, err := buildAddressResults(opts.req)
if err != nil {
opts.logger.Error("error processing discovery query", "error", err)
return nil, nil, err
}
return results, nil, nil
}
opts.logger.Error("error parsing discovery query type", "requestType", reqType)
return nil, nil, errInvalidQuestion
}
// buildQueryFromDNSMessage returns a discovery.Query from a DNS message.
func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain string,
remoteAddress net.Addr) (*discovery.Query, error) {
queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain)
queryTenancy, err := getQueryTenancy(reqCtx, queryType, querySuffixes)
if err != nil {
return nil, err
}
name, tag, err := getQueryNameAndTagFromParts(queryType, queryParts)
if err != nil {
return nil, err
}
portName := parsePort(queryParts)
switch {
case queryType == discovery.QueryTypeWorkload && req.Question[0].Qtype == dns.TypeSRV:
// Currently we do not support SRV records for workloads
return nil, errNotImplemented
case queryType == discovery.QueryTypeInvalid, name == "":
return nil, errInvalidQuestion
}
return &discovery.Query{
QueryType: queryType,
QueryPayload: discovery.QueryPayload{
Name: name,
Tenancy: queryTenancy,
Tag: tag,
PortName: portName,
SourceIP: getSourceIP(req, queryType, remoteAddress),
},
}, nil
}
// buildAddressResults returns a discovery.Result from a DNS request for addr. records.
func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) {
domain := dns.CanonicalName(req.Question[0].Name)
labels := dns.SplitDomainName(domain)
hexadecimal := labels[0]
if len(hexadecimal)/2 != 4 && len(hexadecimal)/2 != 16 {
return nil, errNameNotFound
}
var ip net.IP
ip, err := hex.DecodeString(hexadecimal)
if err != nil {
return nil, errNameNotFound
}
return []*discovery.Result{
{
Node: &discovery.Location{
Address: ip.String(),
},
Type: discovery.ResultTypeNode, // We choose node by convention since we do not know the origin of the IP
},
}, nil
}
// getQueryNameAndTagFromParts returns the query name and tag from the query parts that are taken from the original dns question.
//
// Valid Query Parts:
// [<tag>.]<service>
// [<port>.port.]<service>
// _<service>._<tag> // RFC 2782 style
func getQueryNameAndTagFromParts(queryType discovery.QueryType, queryParts []string) (string, string, error) {
n := len(queryParts)
if n == 0 {
return "", "", errInvalidQuestion
}
switch queryType {
case discovery.QueryTypeService:
if n > 3 {
// Having this many fields is never valid.
return "", "", errInvalidQuestion
}
if n == 3 && queryParts[n-2] != "port" {
// This probably means that someone was trying to use a tag name with a period.
// This was deprecated in Consul 0.3.
return "", "", errInvalidQuestion
}
// Support RFC 2782 style syntax
if n == 2 && strings.HasPrefix(queryParts[1], "_") && strings.HasPrefix(queryParts[0], "_") {
// Grab the tag since we make nuke it if it's tcp
tag := queryParts[1][1:]
// Treat _name._tcp.service.consul as a default, no need to filter on that tag
if tag == "tcp" {
tag = ""
}
name := queryParts[0][1:]
// _name._tag.service.consul
return name, tag, nil
}
// Standard-style lookup w/ tag
if n == 2 {
return queryParts[1], queryParts[0], nil
}
// This works for the v1 and v2 catalog queries, even if a port name was specified.
return queryParts[n-1], "", nil
case discovery.QueryTypePreparedQuery:
name := ""
// If the first and last DNS query parts begin with _, this is an RFC 2782 style SRV lookup.
// This allows for prepared query names to include "." (for backwards compatibility).
// Otherwise, this is a standard prepared query lookup.
if n >= 2 && strings.HasPrefix(queryParts[0], "_") && strings.HasPrefix(queryParts[n-1], "_") {
// The last DNS query part is the protocol field (ignored).
// All prior parts are the prepared query name or ID.
name = strings.Join(queryParts[:n-1], ".")
// Strip leading underscore
name = name[1:]
} else {
// Allow a "." in the query name, just join all the parts.
name = strings.Join(queryParts, ".")
}
if name == "" {
return "", "", errInvalidQuestion
}
return name, "", nil
}
name := queryParts[n-1]
if name == "" {
return "", "", errInvalidQuestion
}
return queryParts[n-1], "", nil
}
// getQueryTenancy returns a discovery.QueryTenancy from a DNS message.
func getQueryTenancy(reqCtx Context, queryType discovery.QueryType, querySuffixes []string) (discovery.QueryTenancy, error) {
labels, ok := parseLabels(querySuffixes)
if !ok {
return discovery.QueryTenancy{}, errNameNotFound
}
// If we don't have an explicit partition/ns in the request, try the first fallback
// which was supplied in the request context. The agent's partition will be used as the last fallback
// later in the query processor.
if labels.Partition == "" {
labels.Partition = reqCtx.DefaultPartition
}
if labels.Namespace == "" {
labels.Namespace = reqCtx.DefaultNamespace
}
// If we have a sameness group, we can return early without further data massage.
if labels.SamenessGroup != "" {
return discovery.QueryTenancy{
Namespace: labels.Namespace,
Partition: labels.Partition,
SamenessGroup: labels.SamenessGroup,
// Datacenter is not supported
}, nil
}
if queryType == discovery.QueryTypeVirtual {
if labels.Peer == "" {
// If the peer name was not explicitly defined, fall back to the ambiguously-parsed version.
labels.Peer = labels.PeerOrDatacenter
}
}
return discovery.QueryTenancy{
Namespace: labels.Namespace,
Partition: labels.Partition,
Peer: labels.Peer,
Datacenter: getEffectiveDatacenter(labels),
}, nil
}
// getEffectiveDatacenter returns the effective datacenter from the parsed labels.
func getEffectiveDatacenter(labels *parsedLabels) string {
switch {
case labels.Datacenter != "":
return labels.Datacenter
case labels.PeerOrDatacenter != "" && labels.Peer != labels.PeerOrDatacenter:
return labels.PeerOrDatacenter
}
return ""
}
// getQueryTypePartsAndSuffixesFromDNSMessage returns the query type, the parts, and suffixes of the query name.
func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain string) (queryType discovery.QueryType, parts []string, suffixes []string) {
// Get the QName without the domain suffix
// TODO (v2-dns): we will also need to handle the "failover" and "no-failover" suffixes here.
// They come AFTER the domain. See `stripAnyFailoverSuffix` in router.go
qName := trimDomainFromQuestionName(req.Question[0].Name, domain, altDomain)
// Split into the label parts
labels := dns.SplitDomainName(qName)
done := false
for i := len(labels) - 1; i >= 0 && !done; i-- {
queryType = getQueryTypeFromLabels(labels[i])
switch queryType {
case discovery.QueryTypeService, discovery.QueryTypeWorkload,
discovery.QueryTypeConnect, discovery.QueryTypeVirtual, discovery.QueryTypeIngress,
discovery.QueryTypeNode, discovery.QueryTypePreparedQuery:
parts = labels[:i]
suffixes = labels[i+1:]
done = true
case discovery.QueryTypeInvalid:
fallthrough
default:
// If this is a SRV query the "service" label is optional, we add it back to use the
// existing code-path.
if req.Question[0].Qtype == dns.TypeSRV && strings.HasPrefix(labels[i], "_") {
queryType = discovery.QueryTypeService
parts = labels[:i+1]
suffixes = labels[i+1:]
done = true
}
}
}
return queryType, parts, suffixes
}
// trimDomainFromQuestionName returns the question name without the domain suffix.
func trimDomainFromQuestionName(questionName, domain, altDomain string) string {
qName := dns.CanonicalName(questionName)
longer := domain
shorter := altDomain
if len(shorter) > len(longer) {
longer, shorter = shorter, longer
}
if strings.HasSuffix(qName, "."+strings.TrimLeft(longer, ".")) {
return strings.TrimSuffix(qName, longer)
}
return strings.TrimSuffix(qName, shorter)
}
// getQueryTypeFromLabels returns the query type from the labels.
func getQueryTypeFromLabels(label string) discovery.QueryType {
switch label {
case "service":
return discovery.QueryTypeService
case "connect":
return discovery.QueryTypeConnect
case "virtual":
return discovery.QueryTypeVirtual
case "ingress":
return discovery.QueryTypeIngress
case "node":
return discovery.QueryTypeNode
case "query":
return discovery.QueryTypePreparedQuery
case "workload":
return discovery.QueryTypeWorkload
default:
return discovery.QueryTypeInvalid
}
}
// getSourceIP returns the source IP from the dns request.
func getSourceIP(req *dns.Msg, queryType discovery.QueryType, remoteAddr net.Addr) (sourceIP net.IP) {
if queryType == discovery.QueryTypePreparedQuery {
subnet := ednsSubnetForRequest(req)
if subnet != nil {
sourceIP = subnet.Address
} else {
switch v := remoteAddr.(type) {
case *net.UDPAddr:
sourceIP = v.IP
case *net.TCPAddr:
sourceIP = v.IP
case *net.IPAddr:
sourceIP = v.IP
}
}
}
return sourceIP
}

View File

@ -1,336 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/discovery"
)
// testCaseBuildQueryFromDNSMessage is a test case for the buildQueryFromDNSMessage function.
type testCaseBuildQueryFromDNSMessage struct {
name string
request *dns.Msg
requestContext *Context
expectedQuery *discovery.Query
expectedError string
}
// Test_buildQueryFromDNSMessage tests the buildQueryFromDNSMessage function.
func Test_buildQueryFromDNSMessage(t *testing.T) {
testCases := []testCaseBuildQueryFromDNSMessage{
// virtual ip queries
{
name: "test A 'virtual.' query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tenancy: discovery.QueryTenancy{},
},
},
},
{
name: "test A 'virtual.' with kitchen sink labels",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.banana.ns.orange.ap.foo.peer.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tenancy: discovery.QueryTenancy{
Peer: "foo",
Namespace: "banana",
Partition: "orange",
},
},
},
},
{
name: "test A 'virtual.' with implicit peer",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.foo.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tenancy: discovery.QueryTenancy{
Peer: "foo",
},
},
},
},
{
name: "test A 'virtual.' with implicit peer and namespace query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "db.virtual.frontend.foo.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeVirtual,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tenancy: discovery.QueryTenancy{
Namespace: "frontend",
Peer: "foo",
},
},
},
},
// V1 Service Queries
{
name: "test A 'service.' standard query with tag",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "primary.db.service.dc1.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeService,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tag: "primary",
Tenancy: discovery.QueryTenancy{
Datacenter: "dc1",
},
},
},
},
{
name: "test A 'service.' RFC 2782 query with tag",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "_db._primary.service.dc1.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeService,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tag: "primary",
Tenancy: discovery.QueryTenancy{
Datacenter: "dc1",
},
},
},
},
{
name: "test A 'service.' RFC 2782 query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "_db._tcp.service.dc1.consul", // the `tcp` tag should be ignored
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeService,
QueryPayload: discovery.QueryPayload{
Name: "db",
Tenancy: discovery.QueryTenancy{
Datacenter: "dc1",
},
},
},
},
{
name: "test A 'service.' with too many query parts (RFC 2782 style)",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "nope._db._tcp.service.dc1.consul", // the `tcp` tag should be ignored
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedError: "invalid question",
},
{
name: "test A 'service.' with too many query parts (standard style)",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "too.many.parts.service.dc1.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedError: "invalid question",
},
// V2 Catalog Queries
{
name: "test A 'workload.'",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.workload.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeWorkload,
QueryPayload: discovery.QueryPayload{
Name: "foo",
Tenancy: discovery.QueryTenancy{},
},
},
},
{
name: "test A 'workload.' with all possible labels",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "api.port.foo.workload.banana.ns.orange.ap.apple.peer.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
requestContext: &Context{
DefaultPartition: "default-partition",
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeWorkload,
QueryPayload: discovery.QueryPayload{
Name: "foo",
PortName: "api",
Tenancy: discovery.QueryTenancy{
Namespace: "banana",
Partition: "orange",
Peer: "apple",
},
},
},
},
{
name: "test sameness group with all possible labels",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.apple.sg.banana.ns.orange.ap.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
requestContext: &Context{
DefaultPartition: "default-partition",
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeService,
QueryPayload: discovery.QueryPayload{
Name: "foo",
Tenancy: discovery.QueryTenancy{
Namespace: "banana",
Partition: "orange",
SamenessGroup: "apple",
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
context := tc.requestContext
if context == nil {
context = &Context{}
}
query, err := buildQueryFromDNSMessage(tc.request, *context, "consul.", ".", nil)
if tc.expectedError != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
return
}
require.NoError(t, err)
assert.Equal(t, tc.expectedQuery, query)
})
}
}

View File

@ -1,87 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"github.com/miekg/dns"
"net"
"strings"
)
func newDNSAddress(addr string) *dnsAddress {
a := &dnsAddress{}
a.SetAddress(addr)
return a
}
// dnsAddress is a wrapper around a string that represents a DNS address and
// provides helper methods for determining whether it is an IP or FQDN and
// whether it is internal or external to the domain.
type dnsAddress struct {
addr string
// store an IP so helpers don't have to parse it multiple times
ip net.IP
}
// SetAddress sets the address field and the ip field if the string is an IP.
func (a *dnsAddress) SetAddress(addr string) {
a.addr = addr
a.ip = net.ParseIP(addr)
}
// IP returns the IP address if the address is an IP.
func (a *dnsAddress) IP() net.IP {
return a.ip
}
// IsIP returns true if the address is an IP.
func (a *dnsAddress) IsIP() bool {
return a.IP() != nil
}
// IsIPV4 returns true if the address is an IPv4 address.
func (a *dnsAddress) IsIPV4() bool {
if a.IP() == nil {
return false
}
return a.IP().To4() != nil
}
// FQDN returns the FQDN if the address is not an IP.
func (a *dnsAddress) FQDN() string {
if !a.IsEmptyString() && !a.IsIP() {
return dns.Fqdn(a.addr)
}
return ""
}
// IsFQDN returns true if the address is a FQDN and not an IP.
func (a *dnsAddress) IsFQDN() bool {
return !a.IsEmptyString() && !a.IsIP() && dns.IsFqdn(a.FQDN())
}
// String returns the address as a string.
func (a *dnsAddress) String() string {
return a.addr
}
// IsEmptyString returns true if the address is an empty string.
func (a *dnsAddress) IsEmptyString() bool {
return a.addr == ""
}
// IsInternalFQDN returns true if the address is a FQDN and is internal to the domain.
func (a *dnsAddress) IsInternalFQDN(domain string) bool {
return !a.IsIP() && a.IsFQDN() && strings.HasSuffix(a.FQDN(), domain)
}
// IsInternalFQDNOrIP returns true if the address is an IP or a FQDN and is internal to the domain.
func (a *dnsAddress) IsInternalFQDNOrIP(domain string) bool {
return a.IsIP() || a.IsInternalFQDN(domain)
}
// IsExternalFQDN returns true if the address is a FQDN and is external to the domain.
func (a *dnsAddress) IsExternalFQDN(domain string) bool {
return !a.IsIP() && a.IsFQDN() && strings.Count(a.FQDN(), ".") > 1 && !strings.HasSuffix(a.FQDN(), domain)
}

View File

@ -1,168 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"github.com/stretchr/testify/assert"
"testing"
)
func Test_dnsAddress(t *testing.T) {
const domain = "consul."
type expectedResults struct {
isIp bool
stringResult string
fqdn string
isFQDN bool
isEmptyString bool
isExternalFQDN bool
isInternalFQDN bool
isInternalFQDNOrIP bool
}
type testCase struct {
name string
input string
expectedResults expectedResults
}
testCases := []testCase{
{
name: "empty string",
input: "",
expectedResults: expectedResults{
isIp: false,
stringResult: "",
fqdn: "",
isFQDN: false,
isEmptyString: true,
isExternalFQDN: false,
isInternalFQDN: false,
isInternalFQDNOrIP: false,
},
},
{
name: "ipv4 address",
input: "127.0.0.1",
expectedResults: expectedResults{
isIp: true,
stringResult: "127.0.0.1",
fqdn: "",
isFQDN: false,
isEmptyString: false,
isExternalFQDN: false,
isInternalFQDN: false,
isInternalFQDNOrIP: true,
},
},
{
name: "ipv6 address",
input: "2001:db8:1:2:cafe::1337",
expectedResults: expectedResults{
isIp: true,
stringResult: "2001:db8:1:2:cafe::1337",
fqdn: "",
isFQDN: false,
isEmptyString: false,
isExternalFQDN: false,
isInternalFQDN: false,
isInternalFQDNOrIP: true,
},
},
{
name: "internal FQDN without trailing period",
input: "web.service.consul",
expectedResults: expectedResults{
isIp: false,
stringResult: "web.service.consul",
fqdn: "web.service.consul.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: false,
isInternalFQDN: true,
isInternalFQDNOrIP: true,
},
},
{
name: "internal FQDN with period",
input: "web.service.consul.",
expectedResults: expectedResults{
isIp: false,
stringResult: "web.service.consul.",
fqdn: "web.service.consul.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: false,
isInternalFQDN: true,
isInternalFQDNOrIP: true,
},
},
{
name: "server name",
input: "web.",
expectedResults: expectedResults{
isIp: false,
stringResult: "web.",
fqdn: "web.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: false,
isInternalFQDN: false,
isInternalFQDNOrIP: false,
},
},
{
name: "external FQDN without trailing period",
input: "web.service.vault",
expectedResults: expectedResults{
isIp: false,
stringResult: "web.service.vault",
fqdn: "web.service.vault.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: true,
isInternalFQDN: false,
isInternalFQDNOrIP: false,
},
},
{
name: "external FQDN with trailing period",
input: "web.service.vault.",
expectedResults: expectedResults{
isIp: false,
stringResult: "web.service.vault.",
fqdn: "web.service.vault.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: true,
isInternalFQDN: false,
isInternalFQDNOrIP: false,
},
},
{
name: "another external FQDN",
input: "www.google.com",
expectedResults: expectedResults{
isIp: false,
stringResult: "www.google.com",
fqdn: "www.google.com.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: true,
isInternalFQDN: false,
isInternalFQDNOrIP: false,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dnsAddress := newDNSAddress(tc.input)
assert.Equal(t, tc.expectedResults.isIp, dnsAddress.IsIP())
assert.Equal(t, tc.expectedResults.stringResult, dnsAddress.String())
assert.Equal(t, tc.expectedResults.isFQDN, dnsAddress.IsFQDN())
assert.Equal(t, tc.expectedResults.isEmptyString, dnsAddress.IsEmptyString())
assert.Equal(t, tc.expectedResults.isExternalFQDN, dnsAddress.IsExternalFQDN(domain))
assert.Equal(t, tc.expectedResults.isInternalFQDN, dnsAddress.IsInternalFQDN(domain))
assert.Equal(t, tc.expectedResults.isInternalFQDNOrIP, dnsAddress.IsInternalFQDNOrIP(domain))
})
}
}

View File

@ -1,151 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"regexp"
"strings"
"time"
"github.com/miekg/dns"
"github.com/hashicorp/consul/agent/discovery"
)
// dnsRecordMaker creates DNS records to be used when generating
// responses to dns requests.
type dnsRecordMaker struct{}
// makeSOA returns an SOA record for the given domain and config.
func (dnsRecordMaker) makeSOA(domain string, cfg *RouterDynamicConfig) dns.RR {
return &dns.SOA{
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
// Has to be consistent with MinTTL to avoid invalidation
Ttl: cfg.SOAConfig.Minttl,
},
Ns: "ns." + domain,
Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster." + domain,
Refresh: cfg.SOAConfig.Refresh,
Retry: cfg.SOAConfig.Retry,
Expire: cfg.SOAConfig.Expire,
Minttl: cfg.SOAConfig.Minttl,
}
}
// makeNS returns an NS record for the given domain and fqdn.
func (dnsRecordMaker) makeNS(domain, fqdn string, ttl uint32) dns.RR {
return &dns.NS{
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: ttl,
},
Ns: fqdn,
}
}
// makeIPBasedRecord returns an A or AAAA record for the given name and IP.
// Note: we might want to pass in the Query Name here, which is used in addr. and virtual. queries
// since there is only ever one result. Right now choosing to leave it off for simplification.
func (dnsRecordMaker) makeIPBasedRecord(name string, addr *dnsAddress, ttl uint32) dns.RR {
if addr.IsIPV4() {
// check if the query type is A for IPv4 or ANY
return &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: addr.IP(),
}
}
return &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: addr.IP(),
}
}
// makeCNAME returns a CNAME record for the given name and target.
func (dnsRecordMaker) makeCNAME(name string, target string, ttl uint32) *dns.CNAME {
return &dns.CNAME{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: ttl,
},
Target: dns.Fqdn(target),
}
}
// makeSRV returns an SRV record for the given name and target.
func (dnsRecordMaker) makeSRV(name, target string, weight uint16, ttl uint32, port *discovery.Port) *dns.SRV {
return &dns.SRV{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: ttl,
},
Priority: 1,
Weight: weight,
Port: uint16(port.Number),
Target: target,
}
}
// makeTXT returns a TXT record for the given name and result metadata.
func (dnsRecordMaker) makeTXT(name string, metadata map[string]string, ttl uint32) []dns.RR {
extra := make([]dns.RR, 0, len(metadata))
for key, value := range metadata {
txt := value
if !strings.HasPrefix(strings.ToLower(key), "rfc1035-") {
txt = encodeKVasRFC1464(key, value)
}
extra = append(extra, &dns.TXT{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: ttl,
},
Txt: []string{txt},
})
}
return extra
}
// encodeKVasRFC1464 encodes a key-value pair according to RFC1464
func encodeKVasRFC1464(key, value string) (txt string) {
// For details on these replacements c.f. https://www.ietf.org/rfc/rfc1464.txt
key = strings.Replace(key, "`", "``", -1)
key = strings.Replace(key, "=", "`=", -1)
// Backquote the leading spaces
leadingSpacesRE := regexp.MustCompile("^ +")
numLeadingSpaces := len(leadingSpacesRE.FindString(key))
key = leadingSpacesRE.ReplaceAllString(key, strings.Repeat("` ", numLeadingSpaces))
// Backquote the trailing spaces
numTrailingSpaces := len(trailingSpacesRE.FindString(key))
key = trailingSpacesRE.ReplaceAllString(key, strings.Repeat("` ", numTrailingSpaces))
value = strings.Replace(value, "`", "``", -1)
return key + "=" + value
}

View File

@ -1,228 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/discovery"
)
func TestDNSRecordMaker_makeSOA(t *testing.T) {
cfg := &RouterDynamicConfig{
SOAConfig: SOAConfig{
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
},
}
domain := "testdomain."
expected := &dns.SOA{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 4,
},
Ns: "ns.testdomain.",
Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster.testdomain.",
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
}
actual := dnsRecordMaker{}.makeSOA(domain, cfg)
require.Equal(t, expected, actual)
}
func TestDNSRecordMaker_makeNS(t *testing.T) {
domain := "testdomain."
fqdn := "ns.testdomain."
ttl := uint32(123)
expected := &dns.NS{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 123,
},
Ns: "ns.testdomain.",
}
actual := dnsRecordMaker{}.makeNS(domain, fqdn, ttl)
require.Equal(t, expected, actual)
}
func TestDNSRecordMaker_makeIPBasedRecord(t *testing.T) {
ipv4Addr := newDNSAddress("1.2.3.4")
ipv6Addr := newDNSAddress("2001:db8:1:2:cafe::1337")
testCases := []struct {
name string
recordHeaderName string
addr *dnsAddress
ttl uint32
expected dns.RR
}{
{
name: "IPv4",
recordHeaderName: "my.service.dc1.consul.",
addr: ipv4Addr,
ttl: 123,
expected: &dns.A{
Hdr: dns.RR_Header{
Name: "my.service.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: ipv4Addr.IP(),
},
},
{
name: "IPv6",
recordHeaderName: "my.service.dc1.consul.",
addr: ipv6Addr,
ttl: 123,
expected: &dns.AAAA{
Hdr: dns.RR_Header{
Name: "my.service.dc1.consul.",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 123,
},
AAAA: ipv6Addr.IP(),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := dnsRecordMaker{}.makeIPBasedRecord(tc.recordHeaderName, tc.addr, tc.ttl)
require.Equal(t, tc.expected, actual)
})
}
}
func TestDNSRecordMaker_makeCNAME(t *testing.T) {
name := "my.service.consul."
target := "foo"
ttl := uint32(123)
expected := &dns.CNAME{
Hdr: dns.RR_Header{
Name: "my.service.consul.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: 123,
},
Target: "foo.",
}
actual := dnsRecordMaker{}.makeCNAME(name, target, ttl)
require.Equal(t, expected, actual)
}
func TestDNSRecordMaker_makeSRV(t *testing.T) {
name := "my.service.consul."
target := "foo"
ttl := uint32(123)
expected := &dns.SRV{
Hdr: dns.RR_Header{
Name: "my.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 123,
},
Priority: 1,
Weight: uint16(345),
Port: uint16(234),
Target: "foo",
}
actual := dnsRecordMaker{}.makeSRV(name, target, uint16(345), ttl, &discovery.Port{Number: 234})
require.Equal(t, expected, actual)
}
func TestDNSRecordMaker_makeTXT(t *testing.T) {
testCases := []struct {
name string
metadata map[string]string
ttl uint32
expected []dns.RR
}{
{
name: "single metadata",
metadata: map[string]string{
"key": "value",
},
ttl: 123,
expected: []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: "my.service.consul.",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 123,
},
Txt: []string{"key=value"},
},
},
},
{
name: "multiple metadata entries",
metadata: map[string]string{
"key1": "value1",
"key2": "value2",
},
ttl: 123,
expected: []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: "my.service.consul.",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 123,
},
Txt: []string{"key1=value1"},
},
&dns.TXT{
Hdr: dns.RR_Header{
Name: "my.service.consul.",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 123,
},
Txt: []string{"key2=value2"},
},
},
},
{
name: "'rfc1035-' prefixed- metadata entry",
metadata: map[string]string{
"rfc1035-key": "value",
},
ttl: 123,
expected: []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: "my.service.consul.",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 123,
},
Txt: []string{"value"},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := dnsRecordMaker{}.makeTXT("my.service.consul.", tc.metadata, tc.ttl)
require.ElementsMatchf(t, tc.expected, actual, "expected: %v, actual: %v", tc.expected, actual)
})
}
}

View File

@ -1,651 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"encoding/hex"
"fmt"
"net"
"strings"
"time"
"github.com/miekg/dns"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/dnsutil"
)
// messageSerializer is the high level orchestrator for generating the Answer,
// Extra, and Ns records for a DNS response.
type messageSerializer struct{}
// serializeOptions are the options for serializing a discovery.Result into a DNS message.
type serializeOptions struct {
req *dns.Msg
reqCtx Context
query *discovery.Query
results []*discovery.Result
resp *dns.Msg
cfg *RouterDynamicConfig
responseDomain string
remoteAddress net.Addr
maxRecursionLevel int
dnsRecordMaker dnsRecordMaker
translateAddressFunc func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string
translateServiceAddressFunc func(dc string, address string, taggedAddresses map[string]structs.ServiceAddress, accept dnsutil.TranslateAddressAccept) string
resolveCnameFunc func(cfgContext *RouterDynamicConfig, name string, reqCtx Context, remoteAddress net.Addr, maxRecursionLevel int) []dns.RR
}
// serializeQueryResults converts a discovery.Result into a DNS message.
func (d messageSerializer) serialize(opts *serializeOptions) (*dns.Msg, error) {
resp := new(dns.Msg)
resp.SetReply(opts.req)
resp.Compress = !opts.cfg.DisableCompression
resp.Authoritative = true
resp.RecursionAvailable = canRecurse(opts.cfg)
opts.resp = resp
qType := opts.req.Question[0].Qtype
reqType := parseRequestType(opts.req)
// Always add the SOA record if requested.
if qType == dns.TypeSOA {
resp.Answer = append(resp.Answer, opts.dnsRecordMaker.makeSOA(opts.responseDomain, opts.cfg))
}
switch {
case qType == dns.TypeSOA, reqType == requestTypeAddress:
for _, result := range opts.results {
for _, port := range getPortsFromResult(result) {
ans, ex, ns := d.getAnswerExtraAndNs(serializeToGetAnswerExtraAndNsOptions(opts, result, port))
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
}
case qType == dns.TypeSRV:
handled := make(map[string]struct{})
for _, result := range opts.results {
for _, port := range getPortsFromResult(result) {
// Avoid duplicate entries, possible if a node has
// the same service the same port, etc.
// The datacenter should be empty during translation if it is a peering lookup.
// This should be fine because we should always prefer the WAN address.
address := ""
if result.Service != nil {
address = result.Service.Address
} else {
address = result.Node.Address
}
tuple := fmt.Sprintf("%s:%s:%d", result.Node.Name, address, port.Number)
if _, ok := handled[tuple]; ok {
continue
}
handled[tuple] = struct{}{}
ans, ex, ns := d.getAnswerExtraAndNs(serializeToGetAnswerExtraAndNsOptions(opts, result, port))
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
}
default:
// default will send it to where it does some de-duping while it calls getAnswerExtraAndNs and recurses.
d.appendResultsToDNSResponse(opts)
}
if opts.query != nil && opts.query.QueryType != discovery.QueryTypeVirtual &&
len(resp.Answer) == 0 && len(resp.Extra) == 0 {
return nil, discovery.ErrNoData
}
return resp, nil
}
// appendResultsToDNSResponse builds dns message from the discovery results and
// appends them to the dns response.
func (d messageSerializer) appendResultsToDNSResponse(opts *serializeOptions) {
// Always add the SOA record if requested.
if opts.req.Question[0].Qtype == dns.TypeSOA {
opts.resp.Answer = append(opts.resp.Answer, opts.dnsRecordMaker.makeSOA(opts.responseDomain, opts.cfg))
}
handled := make(map[string]struct{})
var answerCNAME []dns.RR = nil
count := 0
for _, result := range opts.results {
for _, port := range getPortsFromResult(result) {
// Add the node record
had_answer := false
ans, extra, _ := d.getAnswerExtraAndNs(serializeToGetAnswerExtraAndNsOptions(opts, result, port))
opts.resp.Extra = append(opts.resp.Extra, extra...)
if len(ans) == 0 {
continue
}
// Avoid duplicate entries, possible if a node has
// the same service on multiple ports, etc.
if _, ok := handled[ans[0].String()]; ok {
continue
}
handled[ans[0].String()] = struct{}{}
switch ans[0].(type) {
case *dns.CNAME:
// keep track of the first CNAME + associated RRs but don't add to the resp.Answer yet
// this will only be added if no non-CNAME RRs are found
if len(answerCNAME) == 0 {
answerCNAME = ans
}
default:
opts.resp.Answer = append(opts.resp.Answer, ans...)
had_answer = true
}
if had_answer {
count++
if count == opts.cfg.ARecordLimit {
// We stop only if greater than 0 or we reached the limit
return
}
}
}
}
if len(opts.resp.Answer) == 0 && len(answerCNAME) > 0 {
opts.resp.Answer = answerCNAME
}
}
// getAnswerExtraAndNsOptions are the options for getting the Answer, Extra, and Ns records for a DNS response.
type getAnswerExtraAndNsOptions struct {
port discovery.Port
result *discovery.Result
req *dns.Msg
reqCtx Context
query *discovery.Query
results []*discovery.Result
resp *dns.Msg
cfg *RouterDynamicConfig
responseDomain string
remoteAddress net.Addr
maxRecursionLevel int
ttl uint32
dnsRecordMaker dnsRecordMaker
translateAddressFunc func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string
translateServiceAddressFunc func(dc string, address string, taggedAddresses map[string]structs.ServiceAddress, accept dnsutil.TranslateAddressAccept) string
resolveCnameFunc func(cfgContext *RouterDynamicConfig, name string, reqCtx Context, remoteAddress net.Addr, maxRecursionLevel int) []dns.RR
}
// getAnswerAndExtra creates the dns answer and extra from discovery results.
func (d messageSerializer) getAnswerExtraAndNs(opts *getAnswerExtraAndNsOptions) (answer []dns.RR, extra []dns.RR, ns []dns.RR) {
serviceAddress, nodeAddress := d.getServiceAndNodeAddresses(opts)
qName := opts.req.Question[0].Name
ttlLookupName := qName
if opts.query != nil {
ttlLookupName = opts.query.QueryPayload.Name
}
opts.ttl = getTTLForResult(ttlLookupName, opts.result.DNS.TTL, opts.query, opts.cfg)
qType := opts.req.Question[0].Qtype
// TODO (v2-dns): skip records that refer to a workload/node that don't have a valid DNS name.
// Special case responses
switch {
// PTR requests are first since they are a special case of domain overriding question type
case parseRequestType(opts.req) == requestTypeIP:
ptrTarget := ""
if opts.result.Type == discovery.ResultTypeNode {
ptrTarget = opts.result.Node.Name
} else if opts.result.Type == discovery.ResultTypeService {
ptrTarget = opts.result.Service.Name
}
ptr := &dns.PTR{
Hdr: dns.RR_Header{Name: qName, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0},
Ptr: canonicalNameForResult(opts.result.Type, ptrTarget, opts.responseDomain, opts.result.Tenancy, opts.port.Name),
}
answer = append(answer, ptr)
case qType == dns.TypeNS:
resultType := opts.result.Type
target := opts.result.Node.Name
if parseRequestType(opts.req) == requestTypeConsul && resultType == discovery.ResultTypeService {
resultType = discovery.ResultTypeNode
}
fqdn := canonicalNameForResult(resultType, target, opts.responseDomain, opts.result.Tenancy, opts.port.Name)
extraRecord := opts.dnsRecordMaker.makeIPBasedRecord(fqdn, nodeAddress, opts.ttl)
answer = append(answer, opts.dnsRecordMaker.makeNS(opts.responseDomain, fqdn, opts.ttl))
extra = append(extra, extraRecord)
case qType == dns.TypeSOA:
// to be returned in the result.
fqdn := canonicalNameForResult(opts.result.Type, opts.result.Node.Name, opts.responseDomain, opts.result.Tenancy, opts.port.Name)
extraRecord := opts.dnsRecordMaker.makeIPBasedRecord(fqdn, nodeAddress, opts.ttl)
ns = append(ns, opts.dnsRecordMaker.makeNS(opts.responseDomain, fqdn, opts.ttl))
extra = append(extra, extraRecord)
case qType == dns.TypeSRV:
fallthrough
default:
a, e := d.getAnswerExtrasForAddressAndTarget(nodeAddress, serviceAddress, opts)
answer = append(answer, a...)
extra = append(extra, e...)
}
a, e := getAnswerAndExtraTXT(opts.req, opts.cfg, qName, opts.result, opts.ttl,
opts.responseDomain, opts.query, &opts.port, opts.dnsRecordMaker)
answer = append(answer, a...)
extra = append(extra, e...)
return
}
// getServiceAndNodeAddresses returns the service and node addresses from a discovery result.
func (d messageSerializer) getServiceAndNodeAddresses(opts *getAnswerExtraAndNsOptions) (*dnsAddress, *dnsAddress) {
addrTranslate := dnsutil.TranslateAddressAcceptDomain
if opts.req.Question[0].Qtype == dns.TypeA {
addrTranslate |= dnsutil.TranslateAddressAcceptIPv4
} else if opts.req.Question[0].Qtype == dns.TypeAAAA {
addrTranslate |= dnsutil.TranslateAddressAcceptIPv6
} else {
addrTranslate |= dnsutil.TranslateAddressAcceptAny
}
// The datacenter should be empty during translation if it is a peering lookup.
// This should be fine because we should always prefer the WAN address.
serviceAddress := newDNSAddress("")
if opts.result.Service != nil {
sa := opts.translateServiceAddressFunc(opts.result.Tenancy.Datacenter,
opts.result.Service.Address, getServiceAddressMapFromLocationMap(opts.result.Service.TaggedAddresses),
addrTranslate)
serviceAddress = newDNSAddress(sa)
}
nodeAddress := newDNSAddress("")
if opts.result.Node != nil {
na := opts.translateAddressFunc(opts.result.Tenancy.Datacenter, opts.result.Node.Address,
getStringAddressMapFromTaggedAddressMap(opts.result.Node.TaggedAddresses), addrTranslate)
nodeAddress = newDNSAddress(na)
}
return serviceAddress, nodeAddress
}
// getAnswerExtrasForAddressAndTarget creates the dns answer and extra from nodeAddress and serviceAddress dnsAddress pairs.
func (d messageSerializer) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress,
serviceAddress *dnsAddress, opts *getAnswerExtraAndNsOptions) (answer []dns.RR, extra []dns.RR) {
qName := opts.req.Question[0].Name
reqType := parseRequestType(opts.req)
switch {
case (reqType == requestTypeAddress || opts.result.Type == discovery.ResultTypeVirtual) &&
serviceAddress.IsEmptyString() && nodeAddress.IsIP():
a, e := getAnswerExtrasForIP(qName, nodeAddress, opts.req.Question[0], reqType, opts.result, opts.ttl, opts.responseDomain, &opts.port, opts.dnsRecordMaker, false)
answer = append(answer, a...)
extra = append(extra, e...)
case opts.result.Type == discovery.ResultTypeNode && nodeAddress.IsIP():
canonicalNodeName := canonicalNameForResult(opts.result.Type,
opts.result.Node.Name, opts.responseDomain, opts.result.Tenancy, opts.port.Name)
a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, opts.req.Question[0], reqType, opts.result, opts.ttl, opts.responseDomain, &opts.port, opts.dnsRecordMaker, false)
answer = append(answer, a...)
extra = append(extra, e...)
case opts.result.Type == discovery.ResultTypeNode && !nodeAddress.IsIP():
a, e := d.makeRecordFromFQDN(serviceAddress.FQDN(), opts)
answer = append(answer, a...)
extra = append(extra, e...)
case serviceAddress.IsEmptyString() && nodeAddress.IsEmptyString():
return nil, nil
// There is no service address and the node address is an IP
case serviceAddress.IsEmptyString() && nodeAddress.IsIP():
resultType := discovery.ResultTypeNode
if opts.result.Type == discovery.ResultTypeWorkload {
resultType = discovery.ResultTypeWorkload
}
canonicalNodeName := canonicalNameForResult(resultType, opts.result.Node.Name,
opts.responseDomain, opts.result.Tenancy, opts.port.Name)
a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, opts.req.Question[0], reqType, opts.result, opts.ttl, opts.responseDomain, &opts.port, opts.dnsRecordMaker, nodeAddress.String() == opts.result.Node.Address) // We compare the node address to the result to detect changes from the WAN translation
answer = append(answer, a...)
extra = append(extra, e...)
// There is no service address and the node address is a FQDN (external service)
case serviceAddress.IsEmptyString():
a, e := d.makeRecordFromFQDN(nodeAddress.FQDN(), opts)
answer = append(answer, a...)
extra = append(extra, e...)
case serviceAddress.IsIP() && opts.req.Question[0].Qtype == dns.TypeSRV:
a, e := getAnswerExtrasForIP(qName, serviceAddress, opts.req.Question[0], requestTypeName, opts.result, opts.ttl, opts.responseDomain, &opts.port, opts.dnsRecordMaker, false)
answer = append(answer, a...)
extra = append(extra, e...)
// The service address is an IP
case serviceAddress.IsIP():
canonicalServiceName := canonicalNameForResult(discovery.ResultTypeService,
opts.result.Service.Name, opts.responseDomain, opts.result.Tenancy, opts.port.Name)
a, e := getAnswerExtrasForIP(canonicalServiceName, serviceAddress, opts.req.Question[0], reqType, opts.result, opts.ttl, opts.responseDomain, &opts.port, opts.dnsRecordMaker, false)
answer = append(answer, a...)
extra = append(extra, e...)
// If the service address is a CNAME for the service we are looking
// for then use the node address.
case serviceAddress.FQDN() == opts.req.Question[0].Name && nodeAddress.IsIP():
canonicalNodeName := canonicalNameForResult(discovery.ResultTypeNode,
opts.result.Node.Name, opts.responseDomain, opts.result.Tenancy, opts.port.Name)
a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, opts.req.Question[0], reqType, opts.result, opts.ttl, opts.responseDomain, &opts.port, opts.dnsRecordMaker, nodeAddress.String() == opts.result.Node.Address) // We compare the node address to the result to detect changes from the WAN translation
answer = append(answer, a...)
extra = append(extra, e...)
// The service address is a FQDN (internal or external service name)
default:
a, e := d.makeRecordFromFQDN(serviceAddress.FQDN(), opts)
answer = append(answer, a...)
extra = append(extra, e...)
}
return
}
// makeRecordFromFQDN creates a DNS record from a FQDN.
func (d messageSerializer) makeRecordFromFQDN(fqdn string, opts *getAnswerExtraAndNsOptions) ([]dns.RR, []dns.RR) {
edns := opts.req.IsEdns0() != nil
q := opts.req.Question[0]
more := opts.resolveCnameFunc(opts.cfg, dns.Fqdn(fqdn), opts.reqCtx, opts.remoteAddress, opts.maxRecursionLevel)
var additional []dns.RR
extra := 0
MORE_REC:
for _, rr := range more {
switch rr.Header().Rrtype {
case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA, dns.TypeTXT:
// set the TTL manually
rr.Header().Ttl = opts.ttl
additional = append(additional, rr)
extra++
if extra == maxRecurseRecords && !edns {
break MORE_REC
}
}
}
if q.Qtype == dns.TypeSRV {
answer := opts.dnsRecordMaker.makeSRV(q.Name, fqdn, uint16(opts.result.DNS.Weight), opts.ttl, &opts.port)
return []dns.RR{answer}, additional
}
address := ""
if opts.result.Service != nil && opts.result.Service.Address != "" {
address = opts.result.Service.Address
} else if opts.result.Node != nil {
address = opts.result.Node.Address
}
answers := []dns.RR{
opts.dnsRecordMaker.makeCNAME(q.Name, address, opts.ttl),
}
answers = append(answers, additional...)
return answers, nil
}
// getAnswerAndExtraTXT determines whether a TXT needs to be create and then
// returns the TXT record in the answer or extra depending on the question type.
func getAnswerAndExtraTXT(req *dns.Msg, cfg *RouterDynamicConfig, qName string,
result *discovery.Result, ttl uint32, domain string, query *discovery.Query,
port *discovery.Port, maker dnsRecordMaker) (answer []dns.RR, extra []dns.RR) {
if !shouldAppendTXTRecord(query, cfg, req) {
return
}
recordHeaderName := qName
serviceAddress := newDNSAddress("")
if result.Service != nil {
serviceAddress = newDNSAddress(result.Service.Address)
}
if result.Type != discovery.ResultTypeNode &&
result.Type != discovery.ResultTypeVirtual &&
!serviceAddress.IsInternalFQDN(domain) &&
!serviceAddress.IsExternalFQDN(domain) {
recordHeaderName = canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name,
domain, result.Tenancy, port.Name)
}
qType := req.Question[0].Qtype
generateMeta := false
metaInAnswer := false
if qType == dns.TypeANY || qType == dns.TypeTXT {
generateMeta = true
metaInAnswer = true
} else if cfg.NodeMetaTXT {
generateMeta = true
}
// Do not generate txt records if we don't have to: https://github.com/hashicorp/consul/pull/5272
if generateMeta {
meta := maker.makeTXT(recordHeaderName, result.Metadata, ttl)
if metaInAnswer {
answer = append(answer, meta...)
} else {
extra = append(extra, meta...)
}
}
return answer, extra
}
// shouldAppendTXTRecord determines whether a TXT record should be appended to the response.
func shouldAppendTXTRecord(query *discovery.Query, cfg *RouterDynamicConfig, req *dns.Msg) bool {
qType := req.Question[0].Qtype
switch {
// Node records
case query != nil && query.QueryType == discovery.QueryTypeNode && (cfg.NodeMetaTXT || qType == dns.TypeANY || qType == dns.TypeTXT):
return true
// Service records
case query != nil && query.QueryType == discovery.QueryTypeService && cfg.NodeMetaTXT && qType == dns.TypeSRV:
return true
// Prepared query records
case query != nil && query.QueryType == discovery.QueryTypePreparedQuery && cfg.NodeMetaTXT && qType == dns.TypeSRV:
return true
}
return false
}
// getAnswerExtrasForIP creates the dns answer and extra from IP dnsAddress pairs.
func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, reqType requestType, result *discovery.Result, ttl uint32, domain string, port *discovery.Port, maker dnsRecordMaker, addressOverridden bool) (answer []dns.RR, extra []dns.RR) {
qType := question.Qtype
canReturnARecord := qType == dns.TypeSRV || qType == dns.TypeA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT
canReturnAAAARecord := qType == dns.TypeSRV || qType == dns.TypeAAAA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT
if reqType != requestTypeAddress && result.Type != discovery.ResultTypeVirtual {
switch {
// check IPV4
case addr.IsIP() && addr.IsIPV4() && !canReturnARecord,
// check IPV6
addr.IsIP() && !addr.IsIPV4() && !canReturnAAAARecord:
return
}
}
// Have to pass original question name here even if the system has recursed
// and stripped off the domain suffix.
recHdrName := question.Name
if qType == dns.TypeSRV {
nameSplit := strings.Split(name, ".")
if len(nameSplit) > 1 && nameSplit[1] == addrLabel {
recHdrName = name
} else {
recHdrName = name
}
name = question.Name
}
if reqType != requestTypeAddress && qType == dns.TypeSRV {
if addr.IsIP() && question.Name == name && !addressOverridden {
// encode the ip to be used in the header of the A/AAAA record
// as well as the target of the SRV record.
recHdrName = encodeIPAsFqdn(result, addr.IP(), domain)
}
if result.Type == discovery.ResultTypeWorkload {
recHdrName = canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, port.Name)
}
srv := maker.makeSRV(name, recHdrName, uint16(result.DNS.Weight), ttl, port)
answer = append(answer, srv)
}
record := maker.makeIPBasedRecord(recHdrName, addr, ttl)
isARecordWhenNotExplicitlyQueried := record.Header().Rrtype == dns.TypeA && qType != dns.TypeA && qType != dns.TypeANY
isAAAARecordWhenNotExplicitlyQueried := record.Header().Rrtype == dns.TypeAAAA && qType != dns.TypeAAAA && qType != dns.TypeANY
// For explicit A/AAAA queries, we must only return those records in the answer section.
if isARecordWhenNotExplicitlyQueried ||
isAAAARecordWhenNotExplicitlyQueried {
extra = append(extra, record)
} else {
answer = append(answer, record)
}
return
}
// getPortsFromResult returns the ports from a discovery result.
func getPortsFromResult(result *discovery.Result) []discovery.Port {
if len(result.Ports) > 0 {
return result.Ports
}
// return one record.
return []discovery.Port{{}}
}
// encodeIPAsFqdn encodes an IP address as a FQDN.
func encodeIPAsFqdn(result *discovery.Result, ip net.IP, responseDomain string) string {
ipv4 := ip.To4()
ipStr := hex.EncodeToString(ip)
if ipv4 != nil {
ipStr = ipStr[len(ipStr)-(net.IPv4len*2):]
}
if result.Tenancy.PeerName != "" {
// Exclude the datacenter from the FQDN on the addr for peers.
// This technically makes no difference, since the addr endpoint ignores the DC
// component of the request, but do it anyway for a less confusing experience.
return fmt.Sprintf("%s.addr.%s", ipStr, responseDomain)
}
return fmt.Sprintf("%s.addr.%s.%s", ipStr, result.Tenancy.Datacenter, responseDomain)
}
// canonicalNameForResult returns the canonical name for a discovery result.
func canonicalNameForResult(resultType discovery.ResultType, target, domain string,
tenancy discovery.ResultTenancy, portName string) string {
switch resultType {
case discovery.ResultTypeService:
if tenancy.Namespace != "" {
return fmt.Sprintf("%s.%s.%s.%s.%s", target, "service", tenancy.Namespace, tenancy.Datacenter, domain)
}
return fmt.Sprintf("%s.%s.%s.%s", target, "service", tenancy.Datacenter, domain)
case discovery.ResultTypeNode:
if tenancy.PeerName != "" && tenancy.Partition != "" {
// We must return a more-specific DNS name for peering so
// that there is no ambiguity with lookups.
// Nodes are always registered in the default namespace, so
// the `.ns` qualifier is not required.
return fmt.Sprintf("%s.node.%s.peer.%s.ap.%s",
target,
tenancy.PeerName,
tenancy.Partition,
domain)
}
if tenancy.PeerName != "" {
// We must return a more-specific DNS name for peering so
// that there is no ambiguity with lookups.
return fmt.Sprintf("%s.node.%s.peer.%s",
target,
tenancy.PeerName,
domain)
}
// Return a simpler format for non-peering nodes.
return fmt.Sprintf("%s.node.%s.%s", target, tenancy.Datacenter, domain)
case discovery.ResultTypeWorkload:
// TODO (v2-dns): it doesn't appear this is being used to return a result. Need to investigate and refactor
if portName != "" {
return fmt.Sprintf("%s.port.%s.workload.%s.ns.%s.ap.%s", portName, target, tenancy.Namespace, tenancy.Partition, domain)
}
return fmt.Sprintf("%s.workload.%s.ns.%s.ap.%s", target, tenancy.Namespace, tenancy.Partition, domain)
}
return ""
}
// getServiceAddressMapFromLocationMap converts a map of Location to a map of ServiceAddress.
func getServiceAddressMapFromLocationMap(taggedAddresses map[string]*discovery.TaggedAddress) map[string]structs.ServiceAddress {
taggedServiceAddresses := make(map[string]structs.ServiceAddress, len(taggedAddresses))
for k, v := range taggedAddresses {
taggedServiceAddresses[k] = structs.ServiceAddress{
Address: v.Address,
Port: int(v.Port.Number),
}
}
return taggedServiceAddresses
}
// getStringAddressMapFromTaggedAddressMap converts a map of Location to a map of string.
func getStringAddressMapFromTaggedAddressMap(taggedAddresses map[string]*discovery.TaggedAddress) map[string]string {
taggedServiceAddresses := make(map[string]string, len(taggedAddresses))
for k, v := range taggedAddresses {
taggedServiceAddresses[k] = v.Address
}
return taggedServiceAddresses
}
// getTTLForResult returns the TTL for a given result.
func getTTLForResult(name string, overrideTTL *uint32, query *discovery.Query, cfg *RouterDynamicConfig) uint32 {
// In the case we are not making a discovery query, such as addr. or arpa. lookups,
// use the node TTL by convention
if query == nil {
return uint32(cfg.NodeTTL / time.Second)
}
if overrideTTL != nil {
// If a result was provided with an override, use that. This is the case for some prepared queries.
return *overrideTTL
}
switch query.QueryType {
case discovery.QueryTypeService, discovery.QueryTypePreparedQuery:
ttl, ok := cfg.GetTTLForService(name)
if ok {
return uint32(ttl / time.Second)
}
fallthrough
default:
return uint32(cfg.NodeTTL / time.Second)
}
}
// serializeToGetAnswerExtraAndNsOptions converts serializeOptions to getAnswerExtraAndNsOptions.
func serializeToGetAnswerExtraAndNsOptions(opts *serializeOptions,
result *discovery.Result, port discovery.Port) *getAnswerExtraAndNsOptions {
return &getAnswerExtraAndNsOptions{
port: port,
result: result,
req: opts.req,
reqCtx: opts.reqCtx,
query: opts.query,
results: opts.results,
resp: opts.resp,
cfg: opts.cfg,
responseDomain: opts.responseDomain,
remoteAddress: opts.remoteAddress,
maxRecursionLevel: opts.maxRecursionLevel,
translateAddressFunc: opts.translateAddressFunc,
translateServiceAddressFunc: opts.translateServiceAddressFunc,
resolveCnameFunc: opts.resolveCnameFunc,
dnsRecordMaker: opts.dnsRecordMaker,
}
}

View File

@ -1,82 +0,0 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package dns
import (
config "github.com/hashicorp/consul/agent/config"
miekgdns "github.com/miekg/dns"
mock "github.com/stretchr/testify/mock"
net "net"
)
// MockDNSRouter is an autogenerated mock type for the DNSRouter type
type MockDNSRouter struct {
mock.Mock
}
// GetConfig provides a mock function with given fields:
func (_m *MockDNSRouter) GetConfig() *RouterDynamicConfig {
ret := _m.Called()
var r0 *RouterDynamicConfig
if rf, ok := ret.Get(0).(func() *RouterDynamicConfig); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*RouterDynamicConfig)
}
}
return r0
}
// HandleRequest provides a mock function with given fields: req, reqCtx, remoteAddress
func (_m *MockDNSRouter) HandleRequest(req *miekgdns.Msg, reqCtx Context, remoteAddress net.Addr) *miekgdns.Msg {
ret := _m.Called(req, reqCtx, remoteAddress)
var r0 *miekgdns.Msg
if rf, ok := ret.Get(0).(func(*miekgdns.Msg, Context, net.Addr) *miekgdns.Msg); ok {
r0 = rf(req, reqCtx, remoteAddress)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*miekgdns.Msg)
}
}
return r0
}
// ReloadConfig provides a mock function with given fields: newCfg
func (_m *MockDNSRouter) ReloadConfig(newCfg *config.RuntimeConfig) error {
ret := _m.Called(newCfg)
var r0 error
if rf, ok := ret.Get(0).(func(*config.RuntimeConfig) error); ok {
r0 = rf(newCfg)
} else {
r0 = ret.Error(0)
}
return r0
}
// ServeDNS provides a mock function with given fields: w, req
func (_m *MockDNSRouter) ServeDNS(w miekgdns.ResponseWriter, req *miekgdns.Msg) {
_m.Called(w, req)
}
// NewMockDNSRouter creates a new instance of MockDNSRouter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockDNSRouter(t interface {
mock.TestingT
Cleanup(func())
}) *MockDNSRouter {
mock := &MockDNSRouter{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -1,55 +0,0 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package dns
import (
miekgdns "github.com/miekg/dns"
mock "github.com/stretchr/testify/mock"
net "net"
)
// mockDnsRecursor is an autogenerated mock type for the dnsRecursor type
type mockDnsRecursor struct {
mock.Mock
}
// handle provides a mock function with given fields: req, cfgCtx, remoteAddr
func (_m *mockDnsRecursor) handle(req *miekgdns.Msg, cfgCtx *RouterDynamicConfig, remoteAddr net.Addr) (*miekgdns.Msg, error) {
ret := _m.Called(req, cfgCtx, remoteAddr)
var r0 *miekgdns.Msg
var r1 error
if rf, ok := ret.Get(0).(func(*miekgdns.Msg, *RouterDynamicConfig, net.Addr) (*miekgdns.Msg, error)); ok {
return rf(req, cfgCtx, remoteAddr)
}
if rf, ok := ret.Get(0).(func(*miekgdns.Msg, *RouterDynamicConfig, net.Addr) *miekgdns.Msg); ok {
r0 = rf(req, cfgCtx, remoteAddr)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*miekgdns.Msg)
}
}
if rf, ok := ret.Get(1).(func(*miekgdns.Msg, *RouterDynamicConfig, net.Addr) error); ok {
r1 = rf(req, cfgCtx, remoteAddr)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// newMockDnsRecursor creates a new instance of mockDnsRecursor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func newMockDnsRecursor(t interface {
mock.TestingT
Cleanup(func())
}) *mockDnsRecursor {
mock := &mockDnsRecursor{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -1,89 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
// parsedLabels defines valid DNS labels that are possible for ALL DNS query in Consul. (v1 and v2, CE and ENT)
// It is the job of the parser to populate the struct, the routers to call the query processor,
// and the query processor to validate is the labels.
type parsedLabels struct {
Datacenter string
Namespace string
Partition string
Peer string
PeerOrDatacenter string // deprecated: use Datacenter or Peer
SamenessGroup string
}
// ParseLabels can parse a DNS query's labels and returns a parsedLabels.
// It also does light validation according to invariants across all possible DNS queries for all Consul versions
func parseLabels(labels []string) (*parsedLabels, bool) {
var result parsedLabels
switch len(labels) {
case 2, 4, 6:
// Supports the following formats:
// - [.<namespace>.ns][.<partition>.ap][.<datacenter>.dc]
// - <namespace>.<datacenter>
// - [.<namespace>.ns][.<partition>.ap][.<peer>.peer]
// - [.<samenessGroup>.sg][.<partition>.ap][.<namespace>.ns]
for i := 0; i < len(labels); i += 2 {
switch labels[i+1] {
case "ns":
result.Namespace = labels[i]
case "ap":
result.Partition = labels[i]
case "dc", "cluster":
result.Datacenter = labels[i]
case "sg":
result.SamenessGroup = labels[i]
case "peer":
result.Peer = labels[i]
default:
// The only case in which labels[i+1] is allowed to be a value
// other than ns, ap, or dc is if n == 2 to support the format:
// <namespace>.<datacenter>.
if len(labels) == 2 {
result.PeerOrDatacenter = labels[1]
result.Namespace = labels[0]
return &result, true
}
return nil, false
}
}
// VALIDATIONS
// Return nil result and false boolean when both datacenter and peer are specified.
if result.Datacenter != "" && result.Peer != "" {
return nil, false
}
// Validate that this a valid DNS including sg
if result.SamenessGroup != "" && (result.Datacenter != "" || result.Peer != "") {
return nil, false
}
return &result, true
case 1:
result.PeerOrDatacenter = labels[0]
return &result, true
case 0:
return &result, true
}
return &result, false
}
// parsePort looks through the query parts for a named port label.
// It assumes the only valid input format is["<portName>", "port", "<targetName>"].
// The other expected formats are ["<targetName>"] and ["<tag>", "<targetName>"].
// It is expected that the queryProcessor validates if the label is allowed for the query type.
func parsePort(parts []string) string {
// The minimum number of parts would be
if len(parts) != 3 || parts[1] != "port" {
return ""
}
return parts[0]
}

View File

@ -1,141 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"github.com/stretchr/testify/require"
"testing"
)
func Test_parseLabels(t *testing.T) {
type testCase struct {
name string
labels []string
expectedOK bool
expectedResult *parsedLabels
}
testCases := []testCase{
{
name: "6 labels - with datacenter",
labels: []string{"test-ns", "ns", "test-ap", "ap", "test-dc", "dc"},
expectedResult: &parsedLabels{
Namespace: "test-ns",
Partition: "test-ap",
Datacenter: "test-dc",
},
expectedOK: true,
},
{
name: "6 labels - with cluster",
labels: []string{"test-ns", "ns", "test-ap", "ap", "test-cluster", "cluster"},
expectedResult: &parsedLabels{
Namespace: "test-ns",
Partition: "test-ap",
Datacenter: "test-cluster",
},
expectedOK: true,
},
{
name: "6 labels - with peer",
labels: []string{"test-ns", "ns", "test-ap", "ap", "test-peer", "peer"},
expectedResult: &parsedLabels{
Namespace: "test-ns",
Partition: "test-ap",
Peer: "test-peer",
},
expectedOK: true,
},
{
name: "6 labels - with sameness group",
labels: []string{"test-sg", "sg", "test-ap", "ap", "test-ns", "ns"},
expectedResult: &parsedLabels{
Namespace: "test-ns",
Partition: "test-ap",
SamenessGroup: "test-sg",
},
expectedOK: true,
},
{
name: "6 labels - invalid",
labels: []string{"test-ns", "not-ns", "test-ap", "ap", "test-dc", "dc"},
expectedResult: nil,
expectedOK: false,
},
{
name: "4 labels - namespace and datacenter",
labels: []string{"test-ns", "ns", "test-ap", "ap"},
expectedResult: &parsedLabels{
Namespace: "test-ns",
Partition: "test-ap",
},
expectedOK: true,
},
{
name: "4 labels - invalid",
labels: []string{"test-ns", "not-ns", "test-ap", "ap", "test-dc", "dc"},
expectedResult: nil,
expectedOK: false,
},
{
name: "2 labels - namespace and peer or datacenter",
labels: []string{"test-ns", "test-peer-or-dc"},
expectedResult: &parsedLabels{
Namespace: "test-ns",
PeerOrDatacenter: "test-peer-or-dc",
},
expectedOK: true,
},
{
name: "1 label - peer or datacenter",
labels: []string{"test-peer-or-dc"},
expectedResult: &parsedLabels{
PeerOrDatacenter: "test-peer-or-dc",
},
expectedOK: true,
},
{
name: "0 labels - returns empty result and true",
labels: []string{},
expectedResult: &parsedLabels{},
expectedOK: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, ok := parseLabels(tc.labels)
require.Equal(t, tc.expectedOK, ok)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_parsePort(t *testing.T) {
type testCase struct {
name string
labels []string
expectedResult string
}
testCases := []testCase{
{
name: "given 3 labels where the second label is port, the first label is returned",
labels: []string{"port-name", "port", "target-name"},
expectedResult: "port-name",
},
{
name: "given 3 labels where the second label is not port, an empty string is returned",
labels: []string{"port-name", "not-port", "target-name"},
expectedResult: "",
},
{
name: "given anything but 3 labels, an empty string is returned",
labels: []string{"port-name", "something-else"},
expectedResult: "",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expectedResult, parsePort(tc.labels))
})
}
}

View File

@ -1,123 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"errors"
"net"
"time"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/logging"
)
type recursor struct {
logger hclog.Logger
}
func newRecursor(logger hclog.Logger) *recursor {
return &recursor{
logger: logger.Named(logging.DNS),
}
}
// handle is used to process DNS queries for externally configured servers
func (r *recursor) handle(req *dns.Msg, cfgCtx *RouterDynamicConfig, remoteAddr net.Addr) (*dns.Msg, error) {
q := req.Question[0]
network := "udp"
defer func(s time.Time) {
r.logger.Trace("request served from client",
"question", q,
"network", network,
"latency", time.Since(s).String(),
"client", remoteAddr.String(),
"client_network", remoteAddr.Network(),
)
}(time.Now())
// Switch to TCP if the client is
if _, ok := remoteAddr.(*net.TCPAddr); ok {
network = "tcp"
}
// Recursively resolve
c := &dns.Client{Net: network, Timeout: cfgCtx.RecursorTimeout}
var resp *dns.Msg
var rtt time.Duration
var err error
for _, idx := range cfgCtx.RecursorStrategy.Indexes(len(cfgCtx.Recursors)) {
recurseAddr := cfgCtx.Recursors[idx]
resp, rtt, err = c.Exchange(req, recurseAddr)
// Check if the response is valid and has the desired Response code
if resp != nil && (resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError) {
r.logger.Trace("recurse failed for question",
"question", q,
"rtt", rtt,
"recursor", recurseAddr,
"rcode", dns.RcodeToString[resp.Rcode],
)
// If we still have recursors to forward the query to,
// we move forward onto the next one else the loop ends
continue
} else if err == nil || (resp != nil && resp.Truncated) {
// Compress the response; we don't know if the incoming
// response was compressed or not, so by not compressing
// we might generate an invalid packet on the way out.
resp.Compress = !cfgCtx.DisableCompression
// Forward the response
r.logger.Trace("recurse succeeded for question",
"question", q,
"rtt", rtt,
"recursor", recurseAddr,
)
return resp, nil
}
r.logger.Error("recurse failed", "error", err)
}
// If all resolvers fail, return a SERVFAIL message
r.logger.Error("all resolvers failed for question from client",
"question", q,
"client", remoteAddr.String(),
"client_network", remoteAddr.Network(),
)
return nil, errRecursionFailed
}
// formatRecursorAddress is used to add a port to the recursor if omitted.
func formatRecursorAddress(recursor string) (string, error) {
_, _, err := net.SplitHostPort(recursor)
var ae *net.AddrError
if errors.As(err, &ae) {
switch ae.Err {
case "missing port in address":
recursor = ipaddr.FormatAddressPort(recursor, 53)
case "too many colons in address":
if ip := net.ParseIP(recursor); ip != nil && ip.To4() == nil {
recursor = ipaddr.FormatAddressPort(recursor, 53)
break
}
fallthrough
default:
return "", err
}
} else if err != nil {
return "", err
}
// Get the address
addr, err := net.ResolveTCPAddr("tcp", recursor)
if err != nil {
return "", err
}
// Return string
return addr.String(), nil
}

View File

@ -1,39 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"strings"
"testing"
)
// Test_handle cases are covered by the integration tests in agent/dns_test.go.
// They should be moved here when the V1 DNS server is deprecated.
//func Test_handle(t *testing.T) {
func Test_formatRecursorAddress(t *testing.T) {
t.Parallel()
addr, err := formatRecursorAddress("8.8.8.8")
if err != nil {
t.Fatalf("err: %v", err)
}
if addr != "8.8.8.8:53" {
t.Fatalf("bad: %v", addr)
}
addr, err = formatRecursorAddress("2001:4860:4860::8888")
if err != nil {
t.Fatalf("err: %v", err)
}
if addr != "[2001:4860:4860::8888]:53" {
t.Fatalf("bad: %v", addr)
}
_, err = formatRecursorAddress("1.2.3.4::53")
if err == nil || !strings.Contains(err.Error(), "too many colons in address") {
t.Fatalf("err: %v", err)
}
_, err = formatRecursorAddress("2001:4860:4860::8888:::53")
if err == nil || !strings.Contains(err.Error(), "too many colons in address") {
t.Fatalf("err: %v", err)
}
}

View File

@ -1,425 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"errors"
"fmt"
"math"
"net"
"strings"
"github.com/miekg/dns"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/go-hclog"
)
const (
// UDP can fit ~25 A records in a 512B response, and ~14 AAAA
// records. Limit further to prevent unintentional configuration
// abuse that would have a negative effect on application response
// times.
maxUDPAnswerLimit = 8
defaultMaxUDPSize = 512
// If a consumer sets a buffer size greater than this amount we will default it down
// to this amount to ensure that consul does respond. Previously if consumer had a larger buffer
// size than 65535 - 60 bytes (maximim 60 bytes for IP header. UDP header will be offset in the
// trimUDP call) consul would fail to respond and the consumer timesout
// the request.
maxUDPDatagramSize = math.MaxUint16 - 68
)
// dnsResponseGenerator is used to:
// - generate DNS responses for errors
// - trim and truncate DNS responses
// - EDNS to the response
type dnsResponseGenerator struct{}
// createRefusedResponse returns a REFUSED message. This is the default behavior for unmatched queries in
// upstream miekg/dns.
func (d dnsResponseGenerator) createRefusedResponse(req *dns.Msg) *dns.Msg {
// Return a REFUSED message
m := &dns.Msg{}
m.SetRcode(req, dns.RcodeRefused)
return m
}
// createServerFailureResponse returns a SERVFAIL message.
func (d dnsResponseGenerator) createServerFailureResponse(req *dns.Msg, cfg *RouterDynamicConfig, recursionAvailable bool) *dns.Msg {
// Return a SERVFAIL message
m := &dns.Msg{}
m.SetReply(req)
m.Compress = !cfg.DisableCompression
m.SetRcode(req, dns.RcodeServerFailure)
m.RecursionAvailable = recursionAvailable
if edns := req.IsEdns0(); edns != nil {
d.setEDNS(req, m, true)
}
return m
}
// createAuthoritativeResponse returns an authoritative message that contains the SOA in the event that data is
// not return for a query. There can be multiple reasons for not returning data, hence the rcode argument.
func (d dnsResponseGenerator) createAuthoritativeResponse(req *dns.Msg, cfg *RouterDynamicConfig, domain string, rcode int, ecsGlobal bool) *dns.Msg {
m := &dns.Msg{}
m.SetRcode(req, rcode)
m.Compress = !cfg.DisableCompression
m.Authoritative = true
m.RecursionAvailable = canRecurse(cfg)
if edns := req.IsEdns0(); edns != nil {
d.setEDNS(req, m, ecsGlobal)
}
// We add the SOA on NameErrors
maker := &dnsRecordMaker{}
soa := maker.makeSOA(domain, cfg)
m.Ns = append(m.Ns, soa)
return m
}
// generateResponseFromErrorOpts is used to pass options to generateResponseFromError.
type generateResponseFromErrorOpts struct {
req *dns.Msg
err error
qName string
configCtx *RouterDynamicConfig
responseDomain string
isECSGlobal bool
query *discovery.Query
canRecurse bool
logger hclog.Logger
}
// generateResponseFromError generates a response from an error.
func (d dnsResponseGenerator) generateResponseFromError(opts *generateResponseFromErrorOpts) *dns.Msg {
switch {
case errors.Is(opts.err, errInvalidQuestion):
opts.logger.Error("invalid question", "name", opts.qName)
return d.createAuthoritativeResponse(opts.req, opts.configCtx, opts.responseDomain, dns.RcodeNameError, opts.isECSGlobal)
case errors.Is(opts.err, errNameNotFound):
opts.logger.Error("name not found", "name", opts.qName)
return d.createAuthoritativeResponse(opts.req, opts.configCtx, opts.responseDomain, dns.RcodeNameError, opts.isECSGlobal)
case errors.Is(opts.err, errNotImplemented):
opts.logger.Error("query not implemented", "name", opts.qName, "type", dns.Type(opts.req.Question[0].Qtype).String())
return d.createAuthoritativeResponse(opts.req, opts.configCtx, opts.responseDomain, dns.RcodeNotImplemented, opts.isECSGlobal)
case errors.Is(opts.err, discovery.ErrNotSupported):
opts.logger.Debug("query name syntax not supported", "name", opts.req.Question[0].Name)
return d.createAuthoritativeResponse(opts.req, opts.configCtx, opts.responseDomain, dns.RcodeNameError, opts.isECSGlobal)
case errors.Is(opts.err, discovery.ErrNotFound):
opts.logger.Debug("query name not found", "name", opts.req.Question[0].Name)
return d.createAuthoritativeResponse(opts.req, opts.configCtx, opts.responseDomain, dns.RcodeNameError, opts.isECSGlobal)
case errors.Is(opts.err, discovery.ErrNoData):
opts.logger.Debug("no data available", "name", opts.qName)
return d.createAuthoritativeResponse(opts.req, opts.configCtx, opts.responseDomain, dns.RcodeSuccess, opts.isECSGlobal)
case errors.Is(opts.err, discovery.ErrNoPathToDatacenter):
dc := ""
if opts.query != nil {
dc = opts.query.QueryPayload.Tenancy.Datacenter
}
opts.logger.Debug("no path to datacenter", "datacenter", dc)
return d.createAuthoritativeResponse(opts.req, opts.configCtx, opts.responseDomain, dns.RcodeNameError, opts.isECSGlobal)
}
opts.logger.Error("error processing discovery query", "error", opts.err)
return d.createServerFailureResponse(opts.req, opts.configCtx, opts.canRecurse)
}
// trimDNSResponse will trim the response for UDP and TCP
func (d dnsResponseGenerator) trimDNSResponse(cfg *RouterDynamicConfig, remoteAddress net.Addr, req, resp *dns.Msg, logger hclog.Logger) {
// Switch to TCP if the client is
network := "udp"
if _, ok := remoteAddress.(*net.TCPAddr); ok {
network = "tcp"
}
var trimmed bool
originalSize := resp.Len()
originalNumRecords := len(resp.Answer)
if network != "tcp" {
trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit)
} else {
trimmed = trimTCPResponse(req, resp)
}
// Flag that there are more records to return in the UDP response
if trimmed {
if cfg.EnableTruncate {
resp.Truncated = true
}
logger.Debug("DNS response too large, truncated",
"protocol", network,
"question", req.Question,
"records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords),
"size", fmt.Sprintf("%d/%d", resp.Len(), originalSize),
)
}
}
// setEDNS is used to set the responses EDNS size headers and
// possibly the ECS headers as well if they were present in the
// original request
func (d dnsResponseGenerator) setEDNS(request *dns.Msg, response *dns.Msg, ecsGlobal bool) {
edns := request.IsEdns0()
if edns == nil {
return
}
// cannot just use the SetEdns0 function as we need to embed
// the ECS option as well
ednsResp := new(dns.OPT)
ednsResp.Hdr.Name = "."
ednsResp.Hdr.Rrtype = dns.TypeOPT
ednsResp.SetUDPSize(edns.UDPSize())
// Set up the ECS option if present
if subnet := ednsSubnetForRequest(request); subnet != nil {
subOp := new(dns.EDNS0_SUBNET)
subOp.Code = dns.EDNS0SUBNET
subOp.Family = subnet.Family
subOp.Address = subnet.Address
subOp.SourceNetmask = subnet.SourceNetmask
if c := response.Rcode; ecsGlobal || c == dns.RcodeNameError || c == dns.RcodeServerFailure || c == dns.RcodeRefused || c == dns.RcodeNotImplemented {
// reply is globally valid and should be cached accordingly
subOp.SourceScope = 0
} else {
// reply is only valid for the subnet it was queried with
subOp.SourceScope = subnet.SourceNetmask
}
ednsResp.Option = append(ednsResp.Option, subOp)
}
response.Extra = append(response.Extra, ednsResp)
}
// ednsSubnetForRequest looks through the request to find any EDS subnet options
func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
// IsEdns0 returns the EDNS RR if present or nil otherwise
edns := req.IsEdns0()
if edns == nil {
return nil
}
for _, o := range edns.Option {
if subnet, ok := o.(*dns.EDNS0_SUBNET); ok {
return subnet
}
}
return nil
}
// trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit
// of DNS responses
func trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
hasExtra := len(resp.Extra) > 0
// There is some overhead, 65535 does not work
maxSize := 65523 // 64k - 12 bytes DNS raw overhead
// We avoid some function calls and allocations by only handling the
// extra data when necessary.
var index map[string]dns.RR
// It is not possible to return more than 4k records even with compression
// Since we are performing binary search it is not a big deal, but it
// improves a bit performance, even with binary search
truncateAt := 4096
if req.Question[0].Qtype == dns.TypeSRV {
// More than 1024 SRV records do not fit in 64k
truncateAt = 1024
}
if len(resp.Answer) > truncateAt {
resp.Answer = resp.Answer[:truncateAt]
}
if hasExtra {
index = make(map[string]dns.RR, len(resp.Extra))
indexRRs(resp.Extra, index)
}
truncated := false
// This enforces the given limit on 64k, the max limit for DNS messages
for len(resp.Answer) > 1 && resp.Len() > maxSize {
truncated = true
// first try to remove the NS section may be it will truncate enough
if len(resp.Ns) != 0 {
resp.Ns = []dns.RR{}
}
// More than 100 bytes, find with a binary search
if resp.Len()-maxSize > 100 {
bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra)
resp.Answer = resp.Answer[:bestIndex]
} else {
resp.Answer = resp.Answer[:len(resp.Answer)-1]
}
if hasExtra {
syncExtra(index, resp)
}
}
return truncated
}
// trimUDPResponse makes sure a UDP response is not longer than allowed by RFC
// 1035. Enforce an arbitrary limit that can be further ratcheted down by
// config, and then make sure the response doesn't exceed 512 bytes. Any extra
// records will be trimmed along with answers.
func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
numAnswers := len(resp.Answer)
hasExtra := len(resp.Extra) > 0
maxSize := defaultMaxUDPSize
// Update to the maximum edns size
if edns := req.IsEdns0(); edns != nil {
if size := edns.UDPSize(); size > uint16(maxSize) {
maxSize = int(size)
}
}
// Overriding maxSize as the maxSize cannot be larger than the
// maxUDPDatagram size. Reliability guarantees disappear > than this amount.
if maxSize > maxUDPDatagramSize {
maxSize = maxUDPDatagramSize
}
// We avoid some function calls and allocations by only handling the
// extra data when necessary.
var index map[string]dns.RR
if hasExtra {
index = make(map[string]dns.RR, len(resp.Extra))
indexRRs(resp.Extra, index)
}
// This cuts UDP responses to a useful but limited number of responses.
maxAnswers := lib.MinInt(maxUDPAnswerLimit, udpAnswerLimit)
compress := resp.Compress
if maxSize == defaultMaxUDPSize && numAnswers > maxAnswers {
// We disable computation of Len ONLY for non-eDNS request (512 bytes)
resp.Compress = false
resp.Answer = resp.Answer[:maxAnswers]
if hasExtra {
syncExtra(index, resp)
}
}
if maxSize == defaultMaxUDPSize && numAnswers > maxAnswers {
// We disable computation of Len ONLY for non-eDNS request (512 bytes)
resp.Compress = false
resp.Answer = resp.Answer[:maxAnswers]
if hasExtra {
syncExtra(index, resp)
}
}
// This enforces the given limit on the number bytes. The default is 512 as
// per the RFC, but EDNS0 allows for the user to specify larger sizes. Note
// that we temporarily switch to uncompressed so that we limit to a response
// that will not exceed 512 bytes uncompressed, which is more conservative and
// will allow our responses to be compliant even if some downstream server
// uncompresses them.
// Even when size is too big for one single record, try to send it anyway
// (useful for 512 bytes messages). 8 is removed from maxSize to ensure that we account
// for the udp header (8 bytes).
for len(resp.Answer) > 1 && resp.Len() > maxSize-8 {
// first try to remove the NS section may be it will truncate enough
if len(resp.Ns) != 0 {
resp.Ns = []dns.RR{}
}
// More than 100 bytes, find with a binary search
if resp.Len()-maxSize > 100 {
bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra)
resp.Answer = resp.Answer[:bestIndex]
} else {
resp.Answer = resp.Answer[:len(resp.Answer)-1]
}
if hasExtra {
syncExtra(index, resp)
}
}
// For 512 non-eDNS responses, while we compute size non-compressed,
// we send result compressed
resp.Compress = compress
return len(resp.Answer) < numAnswers
}
// syncExtra takes a DNS response message and sets the extra data to the most
// minimal set needed to cover the answer data. A pre-made index of RRs is given
// so that can be re-used between calls. This assumes that the extra data is
// only used to provide info for SRV records. If that's not the case, then this
// will wipe out any additional data.
func syncExtra(index map[string]dns.RR, resp *dns.Msg) {
extra := make([]dns.RR, 0, len(resp.Answer))
resolved := make(map[string]struct{}, len(resp.Answer))
for _, ansRR := range resp.Answer {
srv, ok := ansRR.(*dns.SRV)
if !ok {
continue
}
// Note that we always use lower case when using the index so
// that compares are not case-sensitive. We don't alter the actual
// RRs we add into the extra section, however.
target := strings.ToLower(srv.Target)
RESOLVE:
if _, ok := resolved[target]; ok {
continue
}
resolved[target] = struct{}{}
extraRR, ok := index[target]
if ok {
extra = append(extra, extraRR)
if cname, ok := extraRR.(*dns.CNAME); ok {
target = strings.ToLower(cname.Target)
goto RESOLVE
}
}
}
resp.Extra = extra
}
// dnsBinaryTruncate find the optimal number of records using a fast binary search and return
// it in order to return a DNS answer lower than maxSize parameter.
func dnsBinaryTruncate(resp *dns.Msg, maxSize int, index map[string]dns.RR, hasExtra bool) int {
originalAnswser := resp.Answer
startIndex := 0
endIndex := len(resp.Answer) + 1
for endIndex-startIndex > 1 {
median := startIndex + (endIndex-startIndex)/2
resp.Answer = originalAnswser[:median]
if hasExtra {
syncExtra(index, resp)
}
aLen := resp.Len()
if aLen <= maxSize {
if maxSize-aLen < 10 {
// We are good, increasing will go out of bounds
return median
}
startIndex = median
} else {
endIndex = median
}
}
return startIndex
}
// indexRRs populates a map which indexes a given list of RRs by name. NOTE that
// the names are all squashed to lower case so we can perform case-insensitive
// lookups; the RRs are not modified.
func indexRRs(rrs []dns.RR, index map[string]dns.RR) {
for _, rr := range rrs {
name := strings.ToLower(rr.Header().Name)
if _, ok := index[name]; !ok {
index[name] = rr
}
}
}

View File

@ -1,739 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"errors"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/sdk/testutil"
)
func TestDNSResponseGenerator_generateResponseFromError(t *testing.T) {
testCases := []struct {
name string
opts *generateResponseFromErrorOpts
expectedResponse *dns.Msg
}{
{
name: "error is nil returns server failure",
opts: &generateResponseFromErrorOpts{
req: &dns.Msg{},
logger: testutil.Logger(t),
configCtx: &RouterDynamicConfig{
DisableCompression: true,
},
err: nil,
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: false,
Rcode: dns.RcodeServerFailure,
},
},
},
{
name: "error is invalid question returns name error",
opts: &generateResponseFromErrorOpts{
req: &dns.Msg{
Question: []dns.Question{
{
Name: "invalid-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
},
qName: "invalid-question",
responseDomain: "testdomain.",
logger: testutil.Logger(t),
configCtx: &RouterDynamicConfig{
DisableCompression: true,
},
err: errInvalidQuestion,
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeNameError,
},
Question: []dns.Question{
{
Name: "invalid-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: "ns.testdomain.",
Mbox: "hostmaster.testdomain.",
Serial: uint32(time.Now().Unix()),
},
},
},
},
{
name: "error is name not found returns name error",
opts: &generateResponseFromErrorOpts{
req: &dns.Msg{
Question: []dns.Question{
{
Name: "invalid-name",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
},
qName: "invalid-name",
responseDomain: "testdomain.",
logger: testutil.Logger(t),
configCtx: &RouterDynamicConfig{
DisableCompression: true,
},
err: errNameNotFound,
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeNameError,
},
Question: []dns.Question{
{
Name: "invalid-name",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: "ns.testdomain.",
Mbox: "hostmaster.testdomain.",
Serial: uint32(time.Now().Unix()),
},
},
},
},
{
name: "error is not implemented returns not implemented error",
opts: &generateResponseFromErrorOpts{
req: &dns.Msg{
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
},
qName: "some-question",
responseDomain: "testdomain.",
logger: testutil.Logger(t),
configCtx: &RouterDynamicConfig{
DisableCompression: true,
},
err: errNotImplemented,
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeNotImplemented,
},
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: "ns.testdomain.",
Mbox: "hostmaster.testdomain.",
Serial: uint32(time.Now().Unix()),
},
},
},
},
{
name: "error is not supported returns name error",
opts: &generateResponseFromErrorOpts{
req: &dns.Msg{
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
},
qName: "some-question",
responseDomain: "testdomain.",
logger: testutil.Logger(t),
configCtx: &RouterDynamicConfig{
DisableCompression: true,
},
err: discovery.ErrNotSupported,
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeNameError,
},
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: "ns.testdomain.",
Mbox: "hostmaster.testdomain.",
Serial: uint32(time.Now().Unix()),
},
},
},
},
{
name: "error is not found returns name error",
opts: &generateResponseFromErrorOpts{
req: &dns.Msg{
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
},
qName: "some-question",
responseDomain: "testdomain.",
logger: testutil.Logger(t),
configCtx: &RouterDynamicConfig{
DisableCompression: true,
},
err: discovery.ErrNotFound,
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeNameError,
},
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: "ns.testdomain.",
Mbox: "hostmaster.testdomain.",
Serial: uint32(time.Now().Unix()),
},
},
},
},
{
name: "error is no data returns success with soa",
opts: &generateResponseFromErrorOpts{
req: &dns.Msg{
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
},
qName: "some-question",
responseDomain: "testdomain.",
logger: testutil.Logger(t),
configCtx: &RouterDynamicConfig{
DisableCompression: true,
},
err: discovery.ErrNoData,
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: "ns.testdomain.",
Mbox: "hostmaster.testdomain.",
Serial: uint32(time.Now().Unix()),
},
},
},
},
{
name: "error is no path to datacenter returns name error",
opts: &generateResponseFromErrorOpts{
req: &dns.Msg{
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
},
qName: "some-question",
responseDomain: "testdomain.",
logger: testutil.Logger(t),
configCtx: &RouterDynamicConfig{
DisableCompression: true,
},
err: discovery.ErrNoPathToDatacenter,
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeNameError,
},
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: "ns.testdomain.",
Mbox: "hostmaster.testdomain.",
Serial: uint32(time.Now().Unix()),
},
},
},
},
{
name: "error is something else returns server failure error",
opts: &generateResponseFromErrorOpts{
req: &dns.Msg{
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
},
qName: "some-question",
responseDomain: "testdomain.",
logger: testutil.Logger(t),
configCtx: &RouterDynamicConfig{
DisableCompression: true,
},
err: errors.New("KABOOM"),
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: false,
Rcode: dns.RcodeServerFailure,
},
Question: []dns.Question{
{
Name: "some-question",
Qtype: dns.TypeSRV,
Qclass: dns.ClassANY,
},
},
Ns: nil,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.opts.req.IsEdns0()
actualResponse := dnsResponseGenerator{}.generateResponseFromError(tc.opts)
require.Equal(t, tc.expectedResponse, actualResponse)
})
}
}
func TestDNSResponseGenerator_setEDNS(t *testing.T) {
testCases := []struct {
name string
req *dns.Msg
response *dns.Msg
ecsGlobal bool
expectedResponse *dns.Msg
}{
{
name: "request is not edns0, response is not edns0",
req: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Extra: []dns.RR{
&dns.OPT{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeOPT,
Class: 4096,
Ttl: 0,
},
Option: []dns.EDNS0{
&dns.EDNS0_SUBNET{
Code: 1,
Family: 2,
SourceNetmask: 3,
SourceScope: 4,
Address: net.ParseIP("255.255.255.255"),
},
},
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Extra: []dns.RR{
&dns.OPT{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeOPT,
Class: 4096,
Ttl: 0,
},
Option: []dns.EDNS0{
&dns.EDNS0_SUBNET{
Code: 8,
Family: 2,
SourceNetmask: 3,
SourceScope: 3,
Address: net.ParseIP("255.255.255.255"),
},
},
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dnsResponseGenerator{}.setEDNS(tc.req, tc.response, tc.ecsGlobal)
require.Equal(t, tc.expectedResponse, tc.response)
})
}
}
func TestDNSResponseGenerator_trimDNSResponse(t *testing.T) {
testCases := []struct {
name string
req *dns.Msg
response *dns.Msg
cfg *RouterDynamicConfig
remoteAddress net.Addr
expectedResponse *dns.Msg
}{
{
name: "network is udp, enable truncate is true, answer count of 1 is less/equal than configured max f 1, response is not trimmed",
req: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
cfg: &RouterDynamicConfig{
UDPAnswerLimit: 1,
},
remoteAddress: &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.query.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.query.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "network is udp, enable truncate is true, answer count of 2 is greater than configure UDP max f 2, response is trimmed",
req: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
cfg: &RouterDynamicConfig{
UDPAnswerLimit: 1,
},
remoteAddress: &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo1.query.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
&dns.A{
Hdr: dns.RR_Header{
Name: "foo2.query.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("2.2.3.4"),
},
},
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo1.query.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "network is tcp, enable truncate is true, answer is less than 64k limit, response is not trimmed",
req: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
cfg: &RouterDynamicConfig{},
remoteAddress: &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.query.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
expectedResponse: &dns.Msg{
MsgHdr: dns.MsgHdr{
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.query.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
logger := testutil.Logger(t)
dnsResponseGenerator{}.trimDNSResponse(tc.cfg, tc.remoteAddress, tc.req, tc.response, logger)
require.Equal(t, tc.expectedResponse, tc.response)
})
}
}

View File

@ -1,557 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"errors"
"fmt"
"net"
"regexp"
"strings"
"sync/atomic"
"time"
"github.com/armon/go-metrics"
"github.com/armon/go-radix"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/dnsutil"
"github.com/hashicorp/consul/logging"
)
const (
addrLabel = "addr"
arpaDomain = "arpa."
arpaLabel = "arpa"
suffixFailover = "failover."
suffixNoFailover = "no-failover."
maxRecursionLevelDefault = 3 // This field comes from the V1 DNS server and affects V1 catalog lookups
maxRecurseRecords = 5
)
var (
errInvalidQuestion = fmt.Errorf("invalid question")
errNameNotFound = fmt.Errorf("name not found")
errNotImplemented = fmt.Errorf("not implemented")
errRecursionFailed = fmt.Errorf("recursion failed")
trailingSpacesRE = regexp.MustCompile(" +$")
)
// RouterDynamicConfig is the dynamic configuration that can be hot-reloaded
type RouterDynamicConfig struct {
ARecordLimit int
DisableCompression bool
EnableTruncate bool
NodeMetaTXT bool
NodeTTL time.Duration
Recursors []string
RecursorTimeout time.Duration
RecursorStrategy structs.RecursorStrategy
SOAConfig SOAConfig
// TTLRadix sets service TTLs by prefix, eg: "database-*"
TTLRadix *radix.Tree
// TTLStrict sets TTLs to service by full name match. It Has higher priority than TTLRadix
TTLStrict map[string]time.Duration
UDPAnswerLimit int
}
// GetTTLForService Find the TTL for a given service.
// return ttl, true if found, 0, false otherwise
func (cfg *RouterDynamicConfig) GetTTLForService(service string) (time.Duration, bool) {
if cfg.TTLStrict != nil {
ttl, ok := cfg.TTLStrict[service]
if ok {
return ttl, true
}
}
if cfg.TTLRadix != nil {
_, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service)
if ok {
return ttlRaw.(time.Duration), true
}
}
return 0, false
}
type SOAConfig struct {
Refresh uint32 // 3600 by default
Retry uint32 // 600
Expire uint32 // 86400
Minttl uint32 // 0
}
// DiscoveryQueryProcessor is an interface that can be used by any consumer requesting Service Discovery results.
// This could be attached to a gRPC endpoint in the future in addition to DNS.
// Making this an interface means testing the router with a mock is trivial.
type DiscoveryQueryProcessor interface {
QueryByName(*discovery.Query, discovery.Context) ([]*discovery.Result, error)
QueryByIP(net.IP, discovery.Context) ([]*discovery.Result, error)
}
// dnsRecursor is an interface that can be used to mock calls to external DNS servers for unit testing.
//
//go:generate mockery --name dnsRecursor --inpackage
type dnsRecursor interface {
handle(req *dns.Msg, cfgCtx *RouterDynamicConfig, remoteAddress net.Addr) (*dns.Msg, error)
}
// Router replaces miekg/dns.ServeMux with a simpler router that only checks for the 2-3 valid domains
// that Consul supports and forwards to a single DiscoveryQueryProcessor handler. If there is no match, it will recurse.
type Router struct {
processor DiscoveryQueryProcessor
recursor dnsRecursor
domain string
altDomain string
nodeName string
logger hclog.Logger
tokenFunc func() string
translateAddressFunc func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string
translateServiceAddressFunc func(dc string, address string, taggedAddresses map[string]structs.ServiceAddress, accept dnsutil.TranslateAddressAccept) string
// dynamicConfig stores the config as an atomic value (for hot-reloading).
// It is always of type *RouterDynamicConfig
dynamicConfig atomic.Value
}
var _ = dns.Handler(&Router{})
var _ = DNSRouter(&Router{})
func NewRouter(cfg Config) (*Router, error) {
// Make sure domains are FQDN, make them case-insensitive for DNSRequestRouter
domain := dns.CanonicalName(cfg.AgentConfig.DNSDomain)
altDomain := dns.CanonicalName(cfg.AgentConfig.DNSAltDomain)
logger := cfg.Logger.Named(logging.DNS)
router := &Router{
processor: cfg.Processor,
recursor: newRecursor(logger),
domain: domain,
altDomain: altDomain,
logger: logger,
nodeName: cfg.AgentConfig.NodeName,
tokenFunc: cfg.TokenFunc,
translateAddressFunc: cfg.TranslateAddressFunc,
translateServiceAddressFunc: cfg.TranslateServiceAddressFunc,
}
if err := router.ReloadConfig(cfg.AgentConfig); err != nil {
return nil, err
}
return router, nil
}
// HandleRequest is used to process an individual DNS request. It returns a message in success or fail cases.
func (r *Router) HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.Addr) *dns.Msg {
configCtx := r.dynamicConfig.Load().(*RouterDynamicConfig)
respGenerator := dnsResponseGenerator{}
err := validateAndNormalizeRequest(req)
if err != nil {
r.logger.Error("error parsing DNS query", "error", err)
if errors.Is(err, errInvalidQuestion) {
return respGenerator.createRefusedResponse(req)
}
return respGenerator.createServerFailureResponse(req, configCtx, false)
}
r.logger.Trace("received request", "question", req.Question[0].Name, "type", dns.Type(req.Question[0].Qtype).String())
r.normalizeContext(&reqCtx)
defer func(s time.Time, q dns.Question) {
metrics.MeasureSinceWithLabels([]string{"dns", "query"}, s,
[]metrics.Label{
{Name: "node", Value: r.nodeName},
{Name: "type", Value: dns.Type(q.Qtype).String()},
})
r.logger.Trace("request served from client",
"name", q.Name,
"type", dns.Type(q.Qtype).String(),
"class", dns.Class(q.Qclass).String(),
"latency", time.Since(s).String(),
"client", remoteAddress.String(),
"client_network", remoteAddress.Network(),
)
}(time.Now(), req.Question[0])
return r.handleRequestRecursively(req, reqCtx, configCtx, remoteAddress, maxRecursionLevelDefault)
}
// handleRequestRecursively is used to process an individual DNS request. It will recurse as needed
// a maximum number of times and returns a message in success or fail cases.
func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context, configCtx *RouterDynamicConfig,
remoteAddress net.Addr, maxRecursionLevel int) *dns.Msg {
respGenerator := dnsResponseGenerator{}
r.logger.Trace(
"received request",
"question", req.Question[0].Name,
"type", dns.Type(req.Question[0].Qtype).String(),
"recursion_remaining", maxRecursionLevel)
responseDomain, needRecurse := r.parseDomain(req.Question[0].Name)
if needRecurse && !canRecurse(configCtx) {
// This is the same error as an unmatched domain
return respGenerator.createRefusedResponse(req)
}
if needRecurse {
r.logger.Trace("checking recursors to handle request", "question", req.Question[0].Name, "type", dns.Type(req.Question[0].Qtype).String())
// This assumes `canRecurse(configCtx)` is true above
resp, err := r.recursor.handle(req, configCtx, remoteAddress)
if err != nil && !errors.Is(err, errRecursionFailed) {
r.logger.Error("unhandled error recursing DNS query", "error", err)
}
if err != nil {
return respGenerator.createServerFailureResponse(req, configCtx, true)
}
return resp
}
// Need to pass the question name to properly support recursion and the
// trimming of the domain suffixes.
qName := dns.CanonicalName(req.Question[0].Name)
if maxRecursionLevel < maxRecursionLevelDefault {
// Get the QName without the domain suffix
qName = r.trimDomain(qName)
}
results, query, err := discoveryResultsFetcher{}.getQueryResults(&getQueryOptions{
req: req,
reqCtx: reqCtx,
qName: qName,
remoteAddress: remoteAddress,
processor: r.processor,
logger: r.logger,
domain: r.domain,
altDomain: r.altDomain,
})
// in case of the wrapped ECSNotGlobalError, extract the error from it.
isECSGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
err = getErrorFromECSNotGlobalError(err)
if err != nil {
return respGenerator.generateResponseFromError(&generateResponseFromErrorOpts{
req: req,
err: err,
qName: qName,
configCtx: configCtx,
responseDomain: responseDomain,
isECSGlobal: isECSGlobal,
query: query,
canRecurse: canRecurse(configCtx),
logger: r.logger,
})
}
r.logger.Trace("serializing results", "question", req.Question[0].Name, "results-found", len(results))
// This needs the question information because it affects the serialization format.
// e.g., the Consul service has the same "results" for both NS and A/AAAA queries, but the serialization differs.
serializedOpts := &serializeOptions{
req: req,
reqCtx: reqCtx,
query: query,
results: results,
cfg: configCtx,
responseDomain: responseDomain,
remoteAddress: remoteAddress,
maxRecursionLevel: maxRecursionLevel,
translateAddressFunc: r.translateAddressFunc,
translateServiceAddressFunc: r.translateServiceAddressFunc,
resolveCnameFunc: r.resolveCNAME,
}
resp, err := messageSerializer{}.serialize(serializedOpts)
if err != nil {
return respGenerator.generateResponseFromError(&generateResponseFromErrorOpts{
req: req,
err: err,
qName: qName,
configCtx: configCtx,
responseDomain: responseDomain,
isECSGlobal: isECSGlobal,
query: query,
canRecurse: false,
logger: r.logger,
})
}
respGenerator.trimDNSResponse(configCtx, remoteAddress, req, resp, r.logger)
respGenerator.setEDNS(req, resp, isECSGlobal)
return resp
}
// trimDomain trims the domain from the question name.
func (r *Router) trimDomain(questionName string) string {
longer := r.domain
shorter := r.altDomain
if len(shorter) > len(longer) {
longer, shorter = shorter, longer
}
if strings.HasSuffix(questionName, "."+strings.TrimLeft(longer, ".")) {
return strings.TrimSuffix(questionName, longer)
}
return strings.TrimSuffix(questionName, shorter)
}
// ServeDNS implements the miekg/dns.Handler interface.
// This is a standard DNS listener.
func (r *Router) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
out := r.HandleRequest(req, Context{}, w.RemoteAddr())
w.WriteMsg(out)
}
// ReloadConfig hot-reloads the router config with new parameters
func (r *Router) ReloadConfig(newCfg *config.RuntimeConfig) error {
cfg, err := getDynamicRouterConfig(newCfg)
if err != nil {
return fmt.Errorf("error loading DNS config: %w", err)
}
r.dynamicConfig.Store(cfg)
return nil
}
// resolveCNAME is used to recursively resolve CNAME records
func (r *Router) resolveCNAME(cfgContext *RouterDynamicConfig, name string, reqCtx Context,
remoteAddress net.Addr, maxRecursionLevel int) []dns.RR {
// If the CNAME record points to a Consul address, resolve it internally
// Convert query to lowercase because DNS is case-insensitive; r.domain and
// r.altDomain are already converted
if ln := strings.ToLower(name); strings.HasSuffix(ln, "."+r.domain) || strings.HasSuffix(ln, "."+r.altDomain) {
if maxRecursionLevel < 1 {
r.logger.Error("Infinite recursion detected for name, won't perform any CNAME resolution.", "name", name)
return nil
}
req := &dns.Msg{}
req.SetQuestion(name, dns.TypeANY)
// TODO: handle error response (this is a comment from the V1 DNS Server)
resp := r.handleRequestRecursively(req, reqCtx, cfgContext, nil, maxRecursionLevel-1)
return resp.Answer
}
// Do nothing if we don't have a recursor
if !canRecurse(cfgContext) {
return nil
}
// Ask for any A records
m := new(dns.Msg)
m.SetQuestion(name, dns.TypeA)
// Make a DNS lookup request
recursorResponse, err := r.recursor.handle(m, cfgContext, remoteAddress)
if err == nil {
return recursorResponse.Answer
}
r.logger.Error("all resolvers failed for name", "name", name)
return nil
}
// Request type is similar to miekg/dns.Type, but correlates to the different query processors we might need to invoke.
type requestType string
const (
requestTypeName requestType = "NAME" // A/AAAA/CNAME/SRV
requestTypeIP requestType = "IP" // PTR
requestTypeAddress requestType = "ADDR" // Custom addr. A/AAAA lookups
requestTypeConsul requestType = "CONSUL" // SOA/NS
)
// parseDomain converts a DNS message into a generic discovery request.
// If the request domain does not match "consul." or the alternative domain,
// it will return true for needRecurse. The logic is based on miekg/dns.ServeDNS matcher.
// The implementation assumes that the only valid domains are "consul." and the alternative domain, and
// that DS query types are not supported.
func (r *Router) parseDomain(questionName string) (string, bool) {
target := dns.CanonicalName(questionName)
target, _ = stripAnyFailoverSuffix(target)
for offset, overflow := 0, false; !overflow; offset, overflow = dns.NextLabel(target, offset) {
subdomain := target[offset:]
switch subdomain {
case ".":
// We don't support consul having a domain or altdomain attached to the root.
return "", true
case r.domain:
return r.domain, false
case r.altDomain:
return r.altDomain, false
case arpaDomain:
// PTR queries always respond with the primary domain.
return r.domain, false
// Default: fallthrough
}
}
// No match found; recurse if possible
return "", true
}
// GetConfig returns the current router config
func (r *Router) GetConfig() *RouterDynamicConfig {
return r.dynamicConfig.Load().(*RouterDynamicConfig)
}
// getErrorFromECSNotGlobalError returns the underlying error from an ECSNotGlobalError, if it exists.
func getErrorFromECSNotGlobalError(err error) error {
if errors.Is(err, discovery.ErrECSNotGlobal) {
return err.(discovery.ECSNotGlobalError).Unwrap()
}
return err
}
// parseRequestType inspects the DNS message type and question name to determine the requestType of request.
// We assume by the time this is called, we are responding to a question with a domain we serve.
// This is used internally to determine which query processor method (if any) to invoke.
func parseRequestType(req *dns.Msg) requestType {
switch {
case req.Question[0].Qtype == dns.TypeSOA || req.Question[0].Qtype == dns.TypeNS:
// SOA and NS type supersede the domain
// NOTE!: In V1 of the DNS server it was possible to serve a PTR lookup using the arpa domain but a SOA question type.
// This also included the SOA record. This seemed inconsistent and unnecessary - it was removed for simplicity.
return requestTypeConsul
case isPTRSubdomain(req.Question[0].Name):
return requestTypeIP
case isAddrSubdomain(req.Question[0].Name):
return requestTypeAddress
default:
return requestTypeName
}
}
// validateAndNormalizeRequest validates the DNS request and normalizes the request name.
func validateAndNormalizeRequest(req *dns.Msg) error {
// like upstream miekg/dns, we require at least one question,
// but we will only answer the first.
if len(req.Question) == 0 {
return errInvalidQuestion
}
// We mutate the request name to respond with the canonical name.
// This is Consul convention.
req.Question[0].Name = dns.CanonicalName(req.Question[0].Name)
return nil
}
// normalizeContext makes sure context information is populated with agent defaults as needed.
// Right now this is just the ACL token. We do this in the router with the token because DNS doesn't
// allow a token to be passed in the request, and we expect ACL tokens upfront in APIs when they are enabled.
// Tenancy information is left out because it is safe/expected to assume agent defaults in the backend lookup.
func (r *Router) normalizeContext(ctx *Context) {
if ctx.Token == "" {
ctx.Token = r.tokenFunc()
}
}
// stripAnyFailoverSuffix strips off the suffixes that may have been added to the request name.
func stripAnyFailoverSuffix(target string) (string, bool) {
enableFailover := false
// Strip off any suffixes that may have been added.
offset, underflow := dns.PrevLabel(target, 1)
if !underflow {
maybeSuffix := target[offset:]
switch maybeSuffix {
case suffixFailover:
target = target[:offset]
enableFailover = true
case suffixNoFailover:
target = target[:offset]
}
}
return target, enableFailover
}
// isAddrSubdomain returns true if the domain is a valid addr subdomain.
func isAddrSubdomain(domain string) bool {
labels := dns.SplitDomainName(domain)
// Looking for <hexadecimal-encoded IP>.addr.<optional datacenter>.consul.
if len(labels) > 2 {
return labels[1] == addrLabel
}
return false
}
// isPTRSubdomain returns true if the domain ends in the PTR domain, "in-addr.arpa.".
func isPTRSubdomain(domain string) bool {
labels := dns.SplitDomainName(domain)
labelCount := len(labels)
// We keep this check brief so we can have more specific error handling later.
if labelCount < 1 {
return false
}
return labels[labelCount-1] == arpaLabel
}
// getDynamicRouterConfig takes agent config and creates/resets the config used by DNS Router
func getDynamicRouterConfig(conf *config.RuntimeConfig) (*RouterDynamicConfig, error) {
cfg := &RouterDynamicConfig{
ARecordLimit: conf.DNSARecordLimit,
EnableTruncate: conf.DNSEnableTruncate,
NodeTTL: conf.DNSNodeTTL,
RecursorStrategy: conf.DNSRecursorStrategy,
RecursorTimeout: conf.DNSRecursorTimeout,
UDPAnswerLimit: conf.DNSUDPAnswerLimit,
NodeMetaTXT: conf.DNSNodeMetaTXT,
DisableCompression: conf.DNSDisableCompression,
SOAConfig: SOAConfig{
Expire: conf.DNSSOA.Expire,
Minttl: conf.DNSSOA.Minttl,
Refresh: conf.DNSSOA.Refresh,
Retry: conf.DNSSOA.Retry,
},
}
if conf.DNSServiceTTL != nil {
cfg.TTLRadix = radix.New()
cfg.TTLStrict = make(map[string]time.Duration)
for key, ttl := range conf.DNSServiceTTL {
// All suffix with '*' are put in radix
// This include '*' that will match anything
if strings.HasSuffix(key, "*") {
cfg.TTLRadix.Insert(key[:len(key)-1], ttl)
} else {
cfg.TTLStrict[key] = ttl
}
}
} else {
cfg.TTLRadix = nil
cfg.TTLStrict = nil
}
for _, r := range conf.DNSRecursors {
ra, err := formatRecursorAddress(r)
if err != nil {
return nil, fmt.Errorf("invalid recursor address: %w", err)
}
cfg.Recursors = append(cfg.Recursors, ra)
}
return cfg, nil
}
// canRecurse returns true if the router can recurse on the request.
func canRecurse(cfg *RouterDynamicConfig) bool {
return len(cfg.Recursors) > 0
}

View File

@ -1,405 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"net"
"testing"
"time"
"github.com/miekg/dns"
)
func Test_HandleRequest_ADDR(t *testing.T) {
testCases := []HandleTestCase{
{
name: "test A 'addr.' query, ipv4 response",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "c000020a.addr.dc1.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "c000020a.addr.dc1.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "c000020a.addr.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("192.0.2.10"),
},
},
},
},
{
name: "test AAAA 'addr.' query, ipv4 response",
// Since we asked for an AAAA record, the A record that resolves from the address is attached as an extra
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "c000020a.addr.dc1.consul",
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "c000020a.addr.dc1.consul.",
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "c000020a.addr.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("192.0.2.10"),
},
},
},
},
{
name: "test SRV 'addr.' query, ipv4 response",
// Since we asked for a SRV record, the A record that resolves from the address is attached as an extra
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "c000020a.addr.dc1.consul",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "c000020a.addr.dc1.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "c000020a.addr.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("192.0.2.10"),
},
},
},
},
{
name: "test ANY 'addr.' query, ipv4 response",
// The response to ANY should look the same as the A response
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "c000020a.addr.dc1.consul",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "c000020a.addr.dc1.consul.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "c000020a.addr.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("192.0.2.10"),
},
},
},
},
{
name: "test AAAA 'addr.' query, ipv6 response",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.addr.dc1.consul",
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.addr.dc1.consul.",
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.AAAA{
Hdr: dns.RR_Header{
Name: "20010db800010002cafe000000001337.addr.dc1.consul.",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 123,
},
AAAA: net.ParseIP("2001:db8:1:2:cafe::1337"),
},
},
},
},
{
name: "test A 'addr.' query, ipv6 response",
// Since we asked for an A record, the AAAA record that resolves from the address is attached as an extra
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.addr.dc1.consul",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.addr.dc1.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Extra: []dns.RR{
&dns.AAAA{
Hdr: dns.RR_Header{
Name: "20010db800010002cafe000000001337.addr.dc1.consul.",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 123,
},
AAAA: net.ParseIP("2001:db8:1:2:cafe::1337"),
},
},
},
},
{
name: "test SRV 'addr.' query, ipv6 response",
// Since we asked for an SRV record, the AAAA record that resolves from the address is attached as an extra
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.addr.dc1.consul",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.addr.dc1.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
Extra: []dns.RR{
&dns.AAAA{
Hdr: dns.RR_Header{
Name: "20010db800010002cafe000000001337.addr.dc1.consul.",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 123,
},
AAAA: net.ParseIP("2001:db8:1:2:cafe::1337"),
},
},
},
},
{
name: "test ANY 'addr.' query, ipv6 response",
// The response to ANY should look the same as the AAAA response
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.addr.dc1.consul",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.addr.dc1.consul.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.AAAA{
Hdr: dns.RR_Header{
Name: "20010db800010002cafe000000001337.addr.dc1.consul.",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 123,
},
AAAA: net.ParseIP("2001:db8:1:2:cafe::1337"),
},
},
},
},
{
name: "test malformed 'addr.' query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "c000.addr.dc1.consul", // too short
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Rcode: dns.RcodeNameError, // NXDOMAIN
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "c000.addr.dc1.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 4,
},
Ns: "ns.consul.",
Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster.consul.",
Refresh: 1,
Expire: 3,
Retry: 2,
Minttl: 4,
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,244 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/resource"
)
func Test_HandleRequest_NS(t *testing.T) {
testCases := []HandleTestCase{
{
name: "vanilla NS query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "consul.",
Qtype: dns.TypeNS,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return([]*discovery.Result{
{
Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
},
{
Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
},
}, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, discovery.LookupTypeService, reqType)
require.Equal(t, structs.ConsulServiceName, req.Name)
require.Equal(t, 3, req.Limit)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "consul.",
Qtype: dns.TypeNS,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.NS{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 123,
},
Ns: "server-one.workload.default.ns.default.ap.consul.",
},
&dns.NS{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 123,
},
Ns: "server-two.workload.default.ns.default.ap.consul.",
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "server-one.workload.default.ns.default.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
&dns.A{
Hdr: dns.RR_Header{
Name: "server-two.workload.default.ns.default.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("4.5.6.7"),
},
},
},
},
{
name: "NS query against alternate domain",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "testdomain.",
Qtype: dns.TypeNS,
Qclass: dns.ClassINET,
},
},
},
agentConfig: &config.RuntimeConfig{
DNSDomain: "consul",
DNSAltDomain: "testdomain",
DNSNodeTTL: 123 * time.Second,
DNSSOA: config.RuntimeSOAConfig{
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return([]*discovery.Result{
{
Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
},
{
Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
},
}, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, discovery.LookupTypeService, reqType)
require.Equal(t, structs.ConsulServiceName, req.Name)
require.Equal(t, 3, req.Limit)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "testdomain.",
Qtype: dns.TypeNS,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.NS{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 123,
},
Ns: "server-one.workload.default.ns.default.ap.testdomain.",
},
&dns.NS{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 123,
},
Ns: "server-two.workload.default.ns.default.ap.testdomain.",
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "server-one.workload.default.ns.default.ap.testdomain.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
&dns.A{
Hdr: dns.RR_Header{
Name: "server-two.workload.default.ns.default.ap.testdomain.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("4.5.6.7"),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,187 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
)
func Test_HandleRequest_PreparedQuery(t *testing.T) {
testCases := []HandleTestCase{
{
name: "v1 prepared query w/ TTL override, ANY query, returns A record",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
agentConfig: &config.RuntimeConfig{
DNSDomain: "consul",
DNSNodeTTL: 123 * time.Second,
DNSSOA: config.RuntimeSOAConfig{
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
// We shouldn't use this if we have the override defined
DNSServiceTTL: map[string]time.Duration{
"foo": 1 * time.Second,
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchPreparedQuery", mock.Anything, mock.Anything).
Return([]*discovery.Result{
{
Service: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Node: &discovery.Location{Name: "bar", Address: "1.2.3.4"},
Type: discovery.ResultTypeService,
Tenancy: discovery.ResultTenancy{
Datacenter: "dc1",
},
DNS: discovery.DNSConfig{
TTL: getUint32Ptr(3),
Weight: 1,
},
},
}, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.query.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.query.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 3,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "v1 prepared query w/ matching service TTL, ANY query, returns A record",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.query.dc1.cluster.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
agentConfig: &config.RuntimeConfig{
DNSDomain: "consul",
DNSNodeTTL: 123 * time.Second,
DNSSOA: config.RuntimeSOAConfig{
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
// Results should use this as the TTL
DNSServiceTTL: map[string]time.Duration{
"foo": 1 * time.Second,
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchPreparedQuery", mock.Anything, mock.Anything).
Return([]*discovery.Result{
{
Service: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Node: &discovery.Location{Name: "bar", Address: "1.2.3.4"},
Type: discovery.ResultTypeService,
Tenancy: discovery.ResultTenancy{
Datacenter: "dc1",
},
DNS: discovery.DNSConfig{
// Intentionally no TTL here.
Weight: 1,
},
},
}, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Equal(t, "dc1", req.Tenancy.Datacenter)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.query.dc1.cluster.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.query.dc1.cluster.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 1,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,563 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/discovery"
)
func Test_HandleRequest_PTR(t *testing.T) {
testCases := []HandleTestCase{
{
name: "PTR lookup for node, query type is ANY",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Service: &discovery.Location{Name: "bar", Address: "foo"},
Type: discovery.ResultTypeNode,
Tenancy: discovery.ResultTenancy{
Datacenter: "dc2",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchRecordsByIp", mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(net.IP)
require.NotNil(t, req)
require.Equal(t, "1.2.3.4", req.String())
})
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: "4.3.2.1.in-addr.arpa.",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: "foo.node.dc2.consul.",
},
},
},
},
{
name: "PTR lookup for IPV6 node",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa",
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo", Address: "2001:db8::567:89ab"},
Service: &discovery.Location{Name: "web", Address: "foo"},
Type: discovery.ResultTypeNode,
Tenancy: discovery.ResultTenancy{
Datacenter: "dc2",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchRecordsByIp", mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(net.IP)
require.NotNil(t, req)
require.Equal(t, "2001:db8::567:89ab", req.String())
})
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: "foo.node.dc2.consul.",
},
},
},
},
{
name: "PTR lookup for invalid IP address",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "257.3.2.1.in-addr.arpa",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeNameError,
},
Compress: true,
Question: []dns.Question{
{
Name: "257.3.2.1.in-addr.arpa.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 4,
},
Ns: "ns.consul.",
Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster.consul.",
Refresh: 1,
Expire: 3,
Retry: 2,
Minttl: 4,
},
},
},
},
{
name: "PTR lookup for invalid subdomain",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "4.3.2.1.blah.arpa",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeNameError,
},
Compress: true,
Question: []dns.Question{
{
Name: "4.3.2.1.blah.arpa.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 4,
},
Ns: "ns.consul.",
Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster.consul.",
Refresh: 1,
Expire: 3,
Retry: 2,
Minttl: 4,
},
},
},
},
{
name: "[ENT] PTR Lookup for node w/ peer name in default partition, query type is ANY",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Type: discovery.ResultTypeNode,
Service: &discovery.Location{Name: "foo-web", Address: "foo"},
Tenancy: discovery.ResultTenancy{
Datacenter: "dc2",
PeerName: "peer1",
Partition: "default",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchRecordsByIp", mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(net.IP)
require.NotNil(t, req)
require.Equal(t, "1.2.3.4", req.String())
})
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: "4.3.2.1.in-addr.arpa.",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: "foo.node.peer1.peer.default.ap.consul.",
},
},
},
},
{
name: "[ENT] PTR Lookup for service in default namespace, query type is PTR",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa",
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Type: discovery.ResultTypeService,
Service: &discovery.Location{Name: "foo", Address: "foo"},
Tenancy: discovery.ResultTenancy{
Datacenter: "dc2",
Namespace: "default",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchRecordsByIp", mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(net.IP)
require.NotNil(t, req)
require.Equal(t, "1.2.3.4", req.String())
})
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa.",
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: "4.3.2.1.in-addr.arpa.",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: "foo.service.default.dc2.consul.",
},
},
},
},
{
name: "[ENT] PTR Lookup for service in a non-default namespace, query type is PTR",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa",
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo-node", Address: "1.2.3.4"},
Type: discovery.ResultTypeService,
Service: &discovery.Location{Name: "foo", Address: "foo"},
Tenancy: discovery.ResultTenancy{
Datacenter: "dc2",
Namespace: "bar",
Partition: "baz",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchRecordsByIp", mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(net.IP)
require.NotNil(t, req)
require.Equal(t, "1.2.3.4", req.String())
})
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa.",
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: "4.3.2.1.in-addr.arpa.",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: "foo.service.bar.dc2.consul.",
},
},
},
},
{
name: "[CE] PTR Lookup for node w/ peer name, query type is ANY",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Type: discovery.ResultTypeNode,
Service: &discovery.Location{Name: "foo", Address: "foo"},
Tenancy: discovery.ResultTenancy{
Datacenter: "dc2",
PeerName: "peer1",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchRecordsByIp", mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(net.IP)
require.NotNil(t, req)
require.Equal(t, "1.2.3.4", req.String())
})
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: "4.3.2.1.in-addr.arpa.",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: "foo.node.peer1.peer.consul.",
},
},
},
},
{
name: "[CE] PTR Lookup for service, query type is PTR",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa",
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Service: &discovery.Location{Name: "foo", Address: "foo"},
Type: discovery.ResultTypeService,
Tenancy: discovery.ResultTenancy{
Datacenter: "dc2",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchRecordsByIp", mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(net.IP)
require.NotNil(t, req)
require.Equal(t, "1.2.3.4", req.String())
})
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "4.3.2.1.in-addr.arpa.",
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: "4.3.2.1.in-addr.arpa.",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: "foo.service.dc2.consul.",
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,296 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"errors"
"github.com/hashicorp/consul/agent/config"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"net"
"testing"
)
func Test_HandleRequest_recursor(t *testing.T) {
testCases := []HandleTestCase{
{
name: "recursors not configured, non-matching domain",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "google.com",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
// configureRecursor: call not expected.
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Rcode: dns.RcodeRefused,
},
Question: []dns.Question{
{
Name: "google.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
},
{
name: "recursors configured, matching domain",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "google.com",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
agentConfig: &config.RuntimeConfig{
DNSRecursors: []string{"8.8.8.8"},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
configureRecursor: func(recursor dnsRecursor) {
resp := &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "google.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "google.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("1.2.3.4"),
},
},
}
recursor.(*mockDnsRecursor).On("handle",
mock.Anything, mock.Anything, mock.Anything).Return(resp, nil)
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "google.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "google.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "recursors configured, no matching domain",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "google.com",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
agentConfig: &config.RuntimeConfig{
DNSRecursors: []string{"8.8.8.8"},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
configureRecursor: func(recursor dnsRecursor) {
recursor.(*mockDnsRecursor).On("handle", mock.Anything, mock.Anything, mock.Anything).
Return(nil, errRecursionFailed)
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: false,
Rcode: dns.RcodeServerFailure,
RecursionAvailable: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "google.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
},
{
name: "recursors configured, unhandled error calling recursors",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "google.com",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
agentConfig: &config.RuntimeConfig{
DNSRecursors: []string{"8.8.8.8"},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
configureRecursor: func(recursor dnsRecursor) {
err := errors.New("ahhhhh!!!!")
recursor.(*mockDnsRecursor).On("handle", mock.Anything, mock.Anything, mock.Anything).
Return(nil, err)
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: false,
Rcode: dns.RcodeServerFailure,
RecursionAvailable: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "google.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
},
{
name: "recursors configured, the root domain is handled by the recursor",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
agentConfig: &config.RuntimeConfig{
DNSRecursors: []string{"8.8.8.8"},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
configureRecursor: func(recursor dnsRecursor) {
// this response is modeled after `dig .`
resp := &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 86391,
},
Ns: "a.root-servers.net.",
Serial: 2024012200,
Mbox: "nstld.verisign-grs.com.",
Refresh: 1800,
Retry: 900,
Expire: 604800,
Minttl: 86400,
},
},
}
recursor.(*mockDnsRecursor).On("handle",
mock.Anything, mock.Anything, mock.Anything).Return(resp, nil)
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 86391,
},
Ns: "a.root-servers.net.",
Serial: 2024012200,
Mbox: "nstld.verisign-grs.com.",
Refresh: 1800,
Retry: 900,
Expire: 604800,
Minttl: 86400,
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,241 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/discovery"
)
func Test_HandleRequest_ServiceQuestions(t *testing.T) {
testCases := []HandleTestCase{
// Service Lookup
{
name: "When no data is return from a query, send SOA",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(nil, discovery.ErrNoData).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, discovery.LookupTypeService, reqType)
require.Equal(t, "foo", req.Name)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeSuccess,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Ns: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 4,
},
Ns: "ns.consul.",
Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster.consul.",
Refresh: 1,
Expire: 3,
Retry: 2,
Minttl: 4,
},
},
},
},
{
// TestDNS_ExternalServiceToConsulCNAMELookup
name: "req type: service / question type: SRV / CNAME required: no",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "alias.service.consul.",
Qtype: dns.TypeSRV,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything,
&discovery.QueryPayload{
Name: "alias",
Tenancy: discovery.QueryTenancy{},
}, discovery.LookupTypeService).
Return([]*discovery.Result{
{
Type: discovery.ResultTypeVirtual,
Service: &discovery.Location{Name: "alias", Address: "web.service.consul"},
Node: &discovery.Location{Name: "web", Address: "web.service.consul"},
},
},
nil).On("FetchEndpoints", mock.Anything,
&discovery.QueryPayload{
Name: "web",
Tenancy: discovery.QueryTenancy{},
}, discovery.LookupTypeService).
Return([]*discovery.Result{
{
Type: discovery.ResultTypeNode,
Service: &discovery.Location{Name: "web", Address: "webnode"},
Node: &discovery.Location{Name: "webnode", Address: "127.0.0.2"},
},
}, nil).On("ValidateRequest", mock.Anything,
mock.Anything).Return(nil).On("NormalizeRequest", mock.Anything)
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "alias.service.consul.",
Qtype: dns.TypeSRV,
},
},
Answer: []dns.RR{
&dns.SRV{
Hdr: dns.RR_Header{
Name: "alias.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 123,
},
Target: "web.service.consul.",
Priority: 1,
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "web.service.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("127.0.0.2"),
},
},
},
},
{
name: "req type: service / question type: SRV / CNAME required: no - multiple results without Node address + tags",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "tag.foo.service.consul.",
Qtype: dns.TypeSRV,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything,
&discovery.QueryPayload{
Name: "foo",
Tenancy: discovery.QueryTenancy{},
Tag: "tag",
}, discovery.LookupTypeService).
Return([]*discovery.Result{
{
// This result emulates an allocation registration with Nomad.
// The node name is generated by Consul and shares the service IP
Type: discovery.ResultTypeService,
Service: &discovery.Location{Name: "foo", Address: "127.0.0.1"},
Node: &discovery.Location{Name: "Node-9e46a487-f5be-2f40-ad60-5a10e32237ed", Address: "127.0.0.1"},
Tenancy: discovery.ResultTenancy{
Datacenter: "dc1",
},
},
},
nil).On("ValidateRequest", mock.Anything,
mock.Anything).Return(nil).On("NormalizeRequest", mock.Anything)
},
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "tag.foo.service.consul.",
Qtype: dns.TypeSRV,
},
},
Answer: []dns.RR{
&dns.SRV{
Hdr: dns.RR_Header{
Name: "tag.foo.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 123,
},
Target: "7f000001.addr.dc1.consul.",
Priority: 1,
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "7f000001.addr.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("127.0.0.1"),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,277 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/resource"
)
func Test_HandleRequest_SOA(t *testing.T) {
testCases := []HandleTestCase{
{
name: "vanilla SOA query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "consul.",
Qtype: dns.TypeSOA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return([]*discovery.Result{
{
Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
},
{
Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
},
}, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, discovery.LookupTypeService, reqType)
require.Equal(t, structs.ConsulServiceName, req.Name)
require.Equal(t, 3, req.Limit)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "consul.",
Qtype: dns.TypeSOA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 4,
},
Ns: "ns.consul.",
Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster.consul.",
Refresh: 1,
Expire: 3,
Retry: 2,
Minttl: 4,
},
},
Ns: []dns.RR{
&dns.NS{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 123,
},
Ns: "server-one.workload.default.ns.default.ap.consul.",
},
&dns.NS{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 123,
},
Ns: "server-two.workload.default.ns.default.ap.consul.",
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "server-one.workload.default.ns.default.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
&dns.A{
Hdr: dns.RR_Header{
Name: "server-two.workload.default.ns.default.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("4.5.6.7"),
},
},
},
},
{
name: "SOA query against alternate domain",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "testdomain.",
Qtype: dns.TypeSOA,
Qclass: dns.ClassINET,
},
},
},
agentConfig: &config.RuntimeConfig{
DNSDomain: "consul",
DNSAltDomain: "testdomain",
DNSNodeTTL: 123 * time.Second,
DNSSOA: config.RuntimeSOAConfig{
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return([]*discovery.Result{
{
Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
},
{
Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
}},
}, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, discovery.LookupTypeService, reqType)
require.Equal(t, structs.ConsulServiceName, req.Name)
require.Equal(t, 3, req.Limit)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "testdomain.",
Qtype: dns.TypeSOA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 4,
},
Ns: "ns.testdomain.",
Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster.testdomain.",
Refresh: 1,
Expire: 3,
Retry: 2,
Minttl: 4,
},
},
Ns: []dns.RR{
&dns.NS{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 123,
},
Ns: "server-one.workload.default.ns.default.ap.testdomain.",
},
&dns.NS{
Hdr: dns.RR_Header{
Name: "testdomain.",
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 123,
},
Ns: "server-two.workload.default.ns.default.ap.testdomain.",
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "server-one.workload.default.ns.default.ap.testdomain.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
&dns.A{
Hdr: dns.RR_Header{
Name: "server-two.workload.default.ns.default.ap.testdomain.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("4.5.6.7"),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,842 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"fmt"
"net"
"reflect"
"testing"
"time"
"github.com/armon/go-radix"
"github.com/hashicorp/consul/internal/dnsutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/structs"
)
// HandleTestCase is a test case for the HandleRequest function.
// Tests for HandleRequest are split into multiple files to make it easier to
// manage and understand the tests. Other test files are:
// - router_addr_test.go
// - router_ns_test.go
// - router_prepared_query_test.go
// - router_ptr_test.go
// - router_recursor_test.go
// - router_service_test.go
// - router_soa_test.go
// - router_virtual_test.go
// - router_v2_services_test.go
// - router_workload_test.go
type HandleTestCase struct {
name string
agentConfig *config.RuntimeConfig // This will override the default test Router Config
configureDataFetcher func(fetcher discovery.CatalogDataFetcher)
validateAndNormalizeExpected bool
configureRecursor func(recursor dnsRecursor)
mockProcessorError error
request *dns.Msg
requestContext *Context
remoteAddress net.Addr
response *dns.Msg
}
func Test_HandleRequest_Validation(t *testing.T) {
testCases := []HandleTestCase{
{
name: "request with empty message",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{},
},
validateAndNormalizeExpected: false,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: false,
Rcode: dns.RcodeRefused,
},
Compress: false,
Question: nil,
Answer: nil,
Ns: nil,
Extra: nil,
},
},
// Context Tests
{
name: "When a request context is provided, use those field in the query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
requestContext: &Context{
Token: "test-token",
DefaultNamespace: "test-namespace",
DefaultPartition: "test-partition",
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := []*discovery.Result{
{
Type: discovery.ResultTypeNode,
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Tenancy: discovery.ResultTenancy{
Namespace: "test-namespace",
Partition: "test-partition",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(result, nil).
Run(func(args mock.Arguments) {
ctx := args.Get(0).(discovery.Context)
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, "test-token", ctx.Token)
require.Equal(t, "foo", req.Name)
require.Equal(t, "test-namespace", req.Tenancy.Namespace)
require.Equal(t, "test-partition", req.Tenancy.Partition)
require.Equal(t, discovery.LookupTypeService, reqType)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "When a request context is provided, values do not override explicit tenancy",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.bar.ns.baz.ap.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
requestContext: &Context{
Token: "test-token",
DefaultNamespace: "test-namespace",
DefaultPartition: "test-partition",
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := []*discovery.Result{
{
Type: discovery.ResultTypeNode,
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Tenancy: discovery.ResultTenancy{
Namespace: "bar",
Partition: "baz",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(result, nil).
Run(func(args mock.Arguments) {
ctx := args.Get(0).(discovery.Context)
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, "test-token", ctx.Token)
require.Equal(t, "foo", req.Name)
require.Equal(t, "bar", req.Tenancy.Namespace)
require.Equal(t, "baz", req.Tenancy.Partition)
require.Equal(t, discovery.LookupTypeService, reqType)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.service.bar.ns.baz.ap.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.service.bar.ns.baz.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}
// runHandleTestCases runs the test cases for the HandleRequest function.
func runHandleTestCases(t *testing.T, tc HandleTestCase) {
cdf := discovery.NewMockCatalogDataFetcher(t)
if tc.validateAndNormalizeExpected {
cdf.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil)
cdf.On("NormalizeRequest", mock.Anything).Return()
}
if tc.configureDataFetcher != nil {
tc.configureDataFetcher(cdf)
}
cfg := buildDNSConfig(tc.agentConfig, cdf, tc.mockProcessorError)
router, err := NewRouter(cfg)
require.NoError(t, err)
// Replace the recursor with a mock and configure
router.recursor = newMockDnsRecursor(t)
if tc.configureRecursor != nil {
tc.configureRecursor(router.recursor)
}
ctx := tc.requestContext
if ctx == nil {
ctx = &Context{}
}
var remoteAddress net.Addr
if tc.remoteAddress != nil {
remoteAddress = tc.remoteAddress
} else {
remoteAddress = &net.UDPAddr{}
}
actual := router.HandleRequest(tc.request, *ctx, remoteAddress)
require.Equal(t, tc.response, actual)
}
func TestRouterDynamicConfig_GetTTLForService(t *testing.T) {
type testCase struct {
name string
inputKey string
shouldMatch bool
expectedDuration time.Duration
}
testCases := []testCase{
{
name: "strict match",
inputKey: "foo",
shouldMatch: true,
expectedDuration: 1 * time.Second,
},
{
name: "wildcard match",
inputKey: "bar",
shouldMatch: true,
expectedDuration: 2 * time.Second,
},
{
name: "wildcard match 2",
inputKey: "bart",
shouldMatch: true,
expectedDuration: 2 * time.Second,
},
{
name: "no match",
inputKey: "homer",
shouldMatch: false,
expectedDuration: 0 * time.Second,
},
}
rtCfg := &config.RuntimeConfig{
DNSServiceTTL: map[string]time.Duration{
"foo": 1 * time.Second,
"bar*": 2 * time.Second,
},
}
cfg, err := getDynamicRouterConfig(rtCfg)
require.NoError(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual, ok := cfg.GetTTLForService(tc.inputKey)
require.Equal(t, tc.shouldMatch, ok)
require.Equal(t, tc.expectedDuration, actual)
})
}
}
func buildDNSConfig(agentConfig *config.RuntimeConfig, cdf discovery.CatalogDataFetcher, _ error) Config {
cfg := Config{
AgentConfig: &config.RuntimeConfig{
DNSDomain: "consul",
DNSNodeTTL: 123 * time.Second,
DNSSOA: config.RuntimeSOAConfig{
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
EntMeta: acl.EnterpriseMeta{},
Logger: hclog.NewNullLogger(),
Processor: discovery.NewQueryProcessor(cdf),
TokenFunc: func() string { return "" },
TranslateServiceAddressFunc: func(dc string, address string, taggedAddresses map[string]structs.ServiceAddress, accept dnsutil.TranslateAddressAccept) string {
return address
},
TranslateAddressFunc: func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string {
return addr
},
}
if agentConfig != nil {
cfg.AgentConfig = agentConfig
}
return cfg
}
// TestDNS_BinaryTruncate tests the dnsBinaryTruncate function.
func TestDNS_BinaryTruncate(t *testing.T) {
msgSrc := new(dns.Msg)
msgSrc.Compress = true
msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV)
for i := 0; i < 5000; i++ {
target := fmt.Sprintf("host-redis-%d-%d.test.acme.com.node.dc1.consul.", i/256, i%256)
msgSrc.Answer = append(msgSrc.Answer, &dns.SRV{Hdr: dns.RR_Header{Name: "redis.service.consul.", Class: 1, Rrtype: dns.TypeSRV, Ttl: 0x3c}, Port: 0x4c57, Target: target})
msgSrc.Extra = append(msgSrc.Extra, &dns.CNAME{Hdr: dns.RR_Header{Name: target, Class: 1, Rrtype: dns.TypeCNAME, Ttl: 0x3c}, Target: fmt.Sprintf("fx.168.%d.%d.", i/256, i%256)})
}
for _, compress := range []bool{true, false} {
for idx, maxSize := range []int{12, 256, 512, 8192, 65535} {
t.Run(fmt.Sprintf("binarySearch %d", maxSize), func(t *testing.T) {
msg := new(dns.Msg)
msgSrc.Compress = compress
msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV)
msg.Answer = msgSrc.Answer
msg.Extra = msgSrc.Extra
msg.Ns = msgSrc.Ns
index := make(map[string]dns.RR, len(msg.Extra))
indexRRs(msg.Extra, index)
blen := dnsBinaryTruncate(msg, maxSize, index, true)
msg.Answer = msg.Answer[:blen]
syncExtra(index, msg)
predicted := msg.Len()
buf, err := msg.Pack()
if err != nil {
t.Error(err)
}
if predicted < len(buf) {
t.Fatalf("Bug in DNS library: %d != %d", predicted, len(buf))
}
if len(buf) > maxSize || (idx != 0 && len(buf) < 16) {
t.Fatalf("bad[%d]: %d > %d", idx, len(buf), maxSize)
}
})
}
}
}
// TestDNS_syncExtra tests the syncExtra function.
func TestDNS_syncExtra(t *testing.T) {
resp := &dns.Msg{
Answer: []dns.RR{
// These two are on the same host so the redundant extra
// records should get deduplicated.
&dns.SRV{
Hdr: dns.RR_Header{
Name: "redis-cache-redis.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Port: 1001,
Target: "ip-10-0-1-185.node.dc1.consul.",
},
&dns.SRV{
Hdr: dns.RR_Header{
Name: "redis-cache-redis.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Port: 1002,
Target: "ip-10-0-1-185.node.dc1.consul.",
},
// This one isn't in the Consul domain so it will get a
// CNAME and then an A record from the recursor.
&dns.SRV{
Hdr: dns.RR_Header{
Name: "redis-cache-redis.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Port: 1003,
Target: "demo.consul.io.",
},
// This one isn't in the Consul domain and it will get
// a CNAME and A record from a recursor that alters the
// case of the name. This proves we look up in the index
// in a case-insensitive way.
&dns.SRV{
Hdr: dns.RR_Header{
Name: "redis-cache-redis.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Port: 1001,
Target: "insensitive.consul.io.",
},
// This is also a CNAME, but it'll be set up to loop to
// make sure we don't crash.
&dns.SRV{
Hdr: dns.RR_Header{
Name: "redis-cache-redis.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Port: 1001,
Target: "deadly.consul.io.",
},
// This is also a CNAME, but it won't have another record.
&dns.SRV{
Hdr: dns.RR_Header{
Name: "redis-cache-redis.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Port: 1001,
Target: "nope.consul.io.",
},
},
Extra: []dns.RR{
// These should get deduplicated.
&dns.A{
Hdr: dns.RR_Header{
Name: "ip-10-0-1-185.node.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("10.0.1.185"),
},
&dns.A{
Hdr: dns.RR_Header{
Name: "ip-10-0-1-185.node.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("10.0.1.185"),
},
// This is a normal CNAME followed by an A record but we
// have flipped the order. The algorithm should emit them
// in the opposite order.
&dns.A{
Hdr: dns.RR_Header{
Name: "fakeserver.consul.io.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("127.0.0.1"),
},
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "demo.consul.io.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
},
Target: "fakeserver.consul.io.",
},
// These differ in case to test case insensitivity.
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "INSENSITIVE.CONSUL.IO.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
},
Target: "Another.Server.Com.",
},
&dns.A{
Hdr: dns.RR_Header{
Name: "another.server.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("127.0.0.1"),
},
// This doesn't appear in the answer, so should get
// dropped.
&dns.A{
Hdr: dns.RR_Header{
Name: "ip-10-0-1-186.node.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("10.0.1.186"),
},
// These two test edge cases with CNAME handling.
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "deadly.consul.io.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
},
Target: "deadly.consul.io.",
},
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "nope.consul.io.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
},
Target: "notthere.consul.io.",
},
},
}
index := make(map[string]dns.RR)
indexRRs(resp.Extra, index)
syncExtra(index, resp)
expected := &dns.Msg{
Answer: resp.Answer,
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "ip-10-0-1-185.node.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("10.0.1.185"),
},
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "demo.consul.io.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
},
Target: "fakeserver.consul.io.",
},
&dns.A{
Hdr: dns.RR_Header{
Name: "fakeserver.consul.io.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("127.0.0.1"),
},
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "INSENSITIVE.CONSUL.IO.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
},
Target: "Another.Server.Com.",
},
&dns.A{
Hdr: dns.RR_Header{
Name: "another.server.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("127.0.0.1"),
},
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "deadly.consul.io.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
},
Target: "deadly.consul.io.",
},
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "nope.consul.io.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
},
Target: "notthere.consul.io.",
},
},
}
if !reflect.DeepEqual(resp, expected) {
t.Fatalf("Bad %#v vs. %#v", *resp, *expected)
}
}
// getUint32Ptr return the pointer of an uint32 literal
func getUint32Ptr(i uint32) *uint32 {
return &i
}
func TestRouter_ReloadConfig(t *testing.T) {
cdf := discovery.NewMockCatalogDataFetcher(t)
cfg := buildDNSConfig(nil, cdf, nil)
router, err := NewRouter(cfg)
require.NoError(t, err)
router.recursor = newMockDnsRecursor(t)
// Reload the config
newAgentConfig := &config.RuntimeConfig{
DNSARecordLimit: 123,
DNSEnableTruncate: true,
DNSNodeTTL: 234,
DNSRecursorStrategy: "strategy-123",
DNSRecursorTimeout: 345,
DNSUDPAnswerLimit: 456,
DNSNodeMetaTXT: true,
DNSDisableCompression: true,
DNSSOA: config.RuntimeSOAConfig{
Expire: 123,
Minttl: 234,
Refresh: 345,
Retry: 456,
},
DNSServiceTTL: map[string]time.Duration{
"wildcard-config-*": 123,
"strict-config": 234,
},
DNSRecursors: []string{
"8.8.8.8",
"2001:4860:4860::8888",
},
}
expectTTLRadix := radix.New()
expectTTLRadix.Insert("wildcard-config-", time.Duration(123))
expectedCfg := &RouterDynamicConfig{
ARecordLimit: 123,
EnableTruncate: true,
NodeTTL: 234,
RecursorStrategy: "strategy-123",
RecursorTimeout: 345,
UDPAnswerLimit: 456,
NodeMetaTXT: true,
DisableCompression: true,
SOAConfig: SOAConfig{
Expire: 123,
Minttl: 234,
Refresh: 345,
Retry: 456,
},
TTLRadix: expectTTLRadix,
TTLStrict: map[string]time.Duration{
"strict-config": 234,
},
Recursors: []string{
"8.8.8.8:53",
"[2001:4860:4860::8888]:53",
},
}
err = router.ReloadConfig(newAgentConfig)
require.NoError(t, err)
savedCfg := router.dynamicConfig.Load().(*RouterDynamicConfig)
// Ensure the new config is used
require.Equal(t, expectedCfg, savedCfg)
}
func Test_isPTRSubdomain(t *testing.T) {
testCases := []struct {
name string
domain string
expected bool
}{
{
name: "empty domain returns false",
domain: "",
expected: false,
},
{
name: "last label is 'arpa' returns true",
domain: "my-addr.arpa.",
expected: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := isPTRSubdomain(tc.domain)
require.Equal(t, tc.expected, actual)
})
}
}
func Test_isAddrSubdomain(t *testing.T) {
testCases := []struct {
name string
domain string
expected bool
}{
{
name: "empty domain returns false",
domain: "",
expected: false,
},
{
name: "'c000020a.addr.dc1.consul.' returns true",
domain: "c000020a.addr.dc1.consul.",
expected: true,
},
{
name: "'c000020a.addr.consul.' returns true",
domain: "c000020a.addr.consul.",
expected: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := isAddrSubdomain(tc.domain)
require.Equal(t, tc.expected, actual)
})
}
}
func Test_stripAnyFailoverSuffix(t *testing.T) {
testCases := []struct {
name string
target string
expectedEnableFailover bool
expectedResult string
}{
{
name: "my-addr.service.dc1.consul.failover. returns 'my-addr.service.dc1.consul' and true",
target: "my-addr.service.dc1.consul.failover.",
expectedEnableFailover: true,
expectedResult: "my-addr.service.dc1.consul.",
},
{
name: "my-addr.service.dc1.consul.no-failover. returns 'my-addr.service.dc1.consul' and false",
target: "my-addr.service.dc1.consul.no-failover.",
expectedEnableFailover: false,
expectedResult: "my-addr.service.dc1.consul.",
},
{
name: "my-addr.service.dc1.consul. returns 'my-addr.service.dc1.consul' and false",
target: "my-addr.service.dc1.consul.",
expectedEnableFailover: false,
expectedResult: "my-addr.service.dc1.consul.",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual, actualEnableFailover := stripAnyFailoverSuffix(tc.target)
require.Equal(t, tc.expectedEnableFailover, actualEnableFailover)
require.Equal(t, tc.expectedResult, actual)
})
}
}
func Test_trimDomain(t *testing.T) {
testCases := []struct {
name string
domain string
altDomain string
questionName string
expectedResult string
}{
{
name: "given domain is 'consul.' and altDomain is 'my.consul.', when calling trimDomain with 'my-service.my.consul.', it returns 'my-service.'",
questionName: "my-service.my.consul.",
domain: "consul.",
altDomain: "my.consul.",
expectedResult: "my-service.",
},
{
name: "given domain is 'consul.' and altDomain is 'my.consul.', when calling trimDomain with 'my-service.consul.', it returns 'my-service.'",
questionName: "my-service.consul.",
domain: "consul.",
altDomain: "my.consul.",
expectedResult: "my-service.",
},
{
name: "given domain is 'consul.' and altDomain is 'my-consul.', when calling trimDomain with 'my-service.consul.', it returns 'my-service.'",
questionName: "my-service.consul.",
domain: "consul.",
altDomain: "my-consul.",
expectedResult: "my-service.",
},
{
name: "given domain is 'consul.' and altDomain is 'my-consul.', when calling trimDomain with 'my-service.my-consul.', it returns 'my-service.'",
questionName: "my-service.my-consul.",
domain: "consul.",
altDomain: "my-consul.",
expectedResult: "my-service.",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
router := Router{
domain: tc.domain,
altDomain: tc.altDomain,
}
actual := router.trimDomain(tc.questionName)
require.Equal(t, tc.expectedResult, actual)
})
}
}

View File

@ -1,628 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/internal/resource"
)
func Test_HandleRequest_V2Services(t *testing.T) {
testCases := []HandleTestCase{
{
name: "A/AAAA Query a service and return multiple A records",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo-1", Address: "10.0.0.1"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
// Intentionally not in the mesh
},
DNS: discovery.DNSConfig{
Weight: 2,
},
},
{
Node: &discovery.Location{Name: "foo-2", Address: "10.0.0.2"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
DNS: discovery.DNSConfig{
Weight: 3,
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, "foo", req.Name)
require.Empty(t, req.PortName)
require.Equal(t, discovery.LookupTypeService, reqType)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(123),
},
A: net.ParseIP("10.0.0.1"),
},
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(123),
},
A: net.ParseIP("10.0.0.2"),
},
},
},
},
{
name: "SRV Query with a multi-port service return multiple SRV records",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo-1", Address: "10.0.0.1"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
// Intentionally not in the mesh
},
DNS: discovery.DNSConfig{
Weight: 2,
},
},
{
Node: &discovery.Location{Name: "foo-2", Address: "10.0.0.2"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
DNS: discovery.DNSConfig{
Weight: 3,
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, "foo", req.Name)
require.Empty(t, req.PortName)
require.Equal(t, discovery.LookupTypeService, reqType)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.SRV{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: uint32(123),
},
Weight: 2,
Priority: 1,
Port: 5678,
Target: "api.port.foo-1.workload.default.ns.default.ap.consul.",
},
&dns.SRV{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: uint32(123),
},
Weight: 3,
Priority: 1,
Port: 5678,
Target: "api.port.foo-2.workload.default.ns.default.ap.consul.",
},
&dns.SRV{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: uint32(123),
},
Weight: 3,
Priority: 1,
Port: 21000,
Target: "mesh.port.foo-2.workload.default.ns.default.ap.consul.",
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "api.port.foo-1.workload.default.ns.default.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(123),
},
A: net.ParseIP("10.0.0.1"),
},
&dns.A{
Hdr: dns.RR_Header{
Name: "api.port.foo-2.workload.default.ns.default.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(123),
},
A: net.ParseIP("10.0.0.2"),
},
&dns.A{
Hdr: dns.RR_Header{
Name: "mesh.port.foo-2.workload.default.ns.default.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(123),
},
A: net.ParseIP("10.0.0.2"),
},
},
},
},
{
name: "SRV Query with a multi-port service where the client requests a specific port, returns SRV and A records",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "mesh.port.foo.service.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo-2", Address: "10.0.0.2"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Ports: []discovery.Port{
{
Name: "mesh",
Number: 21000,
},
},
DNS: discovery.DNSConfig{
Weight: 3,
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, "foo", req.Name)
require.Equal(t, "mesh", req.PortName)
require.Equal(t, discovery.LookupTypeService, reqType)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "mesh.port.foo.service.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.SRV{
Hdr: dns.RR_Header{
Name: "mesh.port.foo.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: uint32(123),
},
Weight: 3,
Priority: 1,
Port: 21000,
Target: "mesh.port.foo-2.workload.default.ns.default.ap.consul.",
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "mesh.port.foo-2.workload.default.ns.default.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(123),
},
A: net.ParseIP("10.0.0.2"),
},
},
},
},
{
name: "SRV Query with a multi-port service that has workloads w/ hostnames (no recursors)",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo-1", Address: "foo-1.example.com"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
{
Name: "web",
Number: 8080,
},
},
DNS: discovery.DNSConfig{
Weight: 2,
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, "foo", req.Name)
require.Empty(t, req.PortName)
require.Equal(t, discovery.LookupTypeService, reqType)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.SRV{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: uint32(123),
},
Weight: 2,
Priority: 1,
Port: 5678,
Target: "foo-1.example.com.",
},
&dns.SRV{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: uint32(123),
},
Weight: 2,
Priority: 1,
Port: 8080,
Target: "foo-1.example.com.",
},
},
},
},
{
name: "SRV Query with a multi-port service that has workloads w/ hostnames (with recursor)",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
results := []*discovery.Result{
{
Node: &discovery.Location{Name: "foo-1", Address: "foo-1.example.com"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
{
Name: "web",
Number: 8080,
},
},
DNS: discovery.DNSConfig{
Weight: 2,
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(results, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, "foo", req.Name)
require.Empty(t, req.PortName)
require.Equal(t, discovery.LookupTypeService, reqType)
})
},
agentConfig: &config.RuntimeConfig{
DNSRecursors: []string{"8.8.8.8"},
DNSDomain: "consul",
DNSNodeTTL: 123 * time.Second,
DNSSOA: config.RuntimeSOAConfig{
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
configureRecursor: func(recursor dnsRecursor) {
resp := &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "foo-1.example.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo-1.example.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("1.2.3.4"),
},
},
}
recursor.(*mockDnsRecursor).On("handle",
mock.Anything, mock.Anything, mock.Anything).Return(resp, nil)
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
RecursionAvailable: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeSRV,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.SRV{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: uint32(123),
},
Weight: 2,
Priority: 1,
Port: 5678,
Target: "foo-1.example.com.",
},
&dns.SRV{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: uint32(123),
},
Weight: 2,
Priority: 1,
Port: 8080,
Target: "foo-1.example.com.",
},
},
Extra: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo-1.example.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(123),
},
A: net.ParseIP("1.2.3.4"),
},
// TODO (v2-dns): This needs to be de-duplicated (NET-8064)
&dns.A{
Hdr: dns.RR_Header{
Name: "foo-1.example.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(123),
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,122 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"github.com/hashicorp/consul/agent/discovery"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"net"
"testing"
)
func Test_HandleRequest_Virtual(t *testing.T) {
testCases := []HandleTestCase{
{
name: "test A 'virtual.' query, ipv4 response",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "c000020a.virtual.dc1.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).On("FetchVirtualIP",
mock.Anything, mock.Anything).Return(&discovery.Result{
Node: &discovery.Location{Address: "240.0.0.2"},
Type: discovery.ResultTypeVirtual,
}, nil)
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "c000020a.virtual.dc1.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "c000020a.virtual.dc1.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("240.0.0.2"),
},
},
},
},
{
name: "test A 'virtual.' query, ipv6 response",
// Since we asked for an A record, the AAAA record that resolves from the address is attached as an extra
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.virtual.dc1.consul",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
fetcher.(*discovery.MockCatalogDataFetcher).On("FetchVirtualIP",
mock.Anything, mock.Anything).Return(&discovery.Result{
Node: &discovery.Location{Address: "2001:db8:1:2:cafe::1337"},
Type: discovery.ResultTypeVirtual,
}, nil)
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "20010db800010002cafe000000001337.virtual.dc1.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Extra: []dns.RR{
&dns.AAAA{
Hdr: dns.RR_Header{
Name: "20010db800010002cafe000000001337.virtual.dc1.consul.",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 123,
},
AAAA: net.ParseIP("2001:db8:1:2:cafe::1337"),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,515 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
)
func Test_HandleRequest_workloads(t *testing.T) {
testCases := []HandleTestCase{
{
name: "workload A query w/ port, returns A record",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := &discovery.Result{
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchWorkload", mock.Anything, mock.Anything).
Return(result, nil). //TODO
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Equal(t, "api", req.PortName)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "api.port.foo.workload.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "workload ANY query w/o port, returns A record",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.workload.consul.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := &discovery.Result{
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchWorkload", mock.Anything, mock.Anything).
Return(result, nil). //TODO
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Empty(t, req.PortName)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.workload.consul.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.workload.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "workload A query with namespace, partition, and cluster id; IPV4 address; returns A record",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.workload.bar.ns.baz.ap.dc3.dc.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := &discovery.Result{
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{
Namespace: "bar",
Partition: "baz",
// We currently don't set the datacenter in any of the V2 results.
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchWorkload", mock.Anything, mock.Anything).
Return(result, nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Empty(t, req.PortName)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.workload.bar.ns.baz.ap.dc3.dc.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.workload.bar.ns.baz.ap.dc3.dc.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "workload w/hostname address, ANY query (no recursor)",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := &discovery.Result{
Node: &discovery.Location{Name: "foo", Address: "foo.example.com"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchWorkload", mock.Anything, mock.Anything).
Return(result, nil). //TODO
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Equal(t, "api", req.PortName)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "api.port.foo.workload.consul.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: 123,
},
Target: "foo.example.com.",
},
},
},
},
{
name: "workload w/hostname address, ANY query (w/ recursor)",
// https://datatracker.ietf.org/doc/html/rfc1034#section-3.6.2 both the CNAME and the A record should be in the answer
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := &discovery.Result{
Node: &discovery.Location{Name: "foo", Address: "foo.example.com"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchWorkload", mock.Anything, mock.Anything).
Return(result, nil). //TODO
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Equal(t, "api", req.PortName)
})
},
agentConfig: &config.RuntimeConfig{
DNSRecursors: []string{"8.8.8.8"},
DNSDomain: "consul",
DNSNodeTTL: 123 * time.Second,
DNSSOA: config.RuntimeSOAConfig{
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
configureRecursor: func(recursor dnsRecursor) {
resp := &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "foo.example.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.example.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP("1.2.3.4"),
},
},
}
recursor.(*mockDnsRecursor).On("handle",
mock.Anything, mock.Anything, mock.Anything).Return(resp, nil)
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
RecursionAvailable: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "api.port.foo.workload.consul.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: 123,
},
Target: "foo.example.com.",
},
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.example.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "workload w/hostname address, CNAME query (w/ recursor)",
// https://datatracker.ietf.org/doc/html/rfc1034#section-3.6.2 only the CNAME should be in the answer
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeCNAME,
Qclass: dns.ClassINET,
},
},
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := &discovery.Result{
Node: &discovery.Location{Name: "foo", Address: "foo.example.com"},
Type: discovery.ResultTypeWorkload,
Tenancy: discovery.ResultTenancy{},
Ports: []discovery.Port{
{
Name: "api",
Number: 5678,
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchWorkload", mock.Anything, mock.Anything).
Return(result, nil). //TODO
Run(func(args mock.Arguments) {
req := args.Get(1).(*discovery.QueryPayload)
require.Equal(t, "foo", req.Name)
require.Equal(t, "api", req.PortName)
})
},
agentConfig: &config.RuntimeConfig{
DNSRecursors: []string{"8.8.8.8"},
DNSDomain: "consul",
DNSNodeTTL: 123 * time.Second,
DNSSOA: config.RuntimeSOAConfig{
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
},
DNSUDPAnswerLimit: maxUDPAnswerLimit,
},
configureRecursor: func(recursor dnsRecursor) {
resp := &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{
{
Name: "foo.example.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.example.com.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
},
A: net.ParseIP("1.2.3.4"),
},
},
}
recursor.(*mockDnsRecursor).On("handle",
mock.Anything, mock.Anything, mock.Anything).Return(resp, nil)
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
RecursionAvailable: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "api.port.foo.workload.consul.",
Qtype: dns.TypeCNAME,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.CNAME{
Hdr: dns.RR_Header{
Name: "api.port.foo.workload.consul.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: 123,
},
Target: "foo.example.com.",
},
// TODO (v2-dns): this next record is wrong per the RFC-1034 mentioned in the comment above (NET-8060)
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.example.com.",
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
runHandleTestCases(t, tc)
})
}
}

View File

@ -1,106 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"fmt"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/dnsutil"
"net"
"github.com/miekg/dns"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/logging"
)
// DNSRouter is a mock for Router that can be used for testing.
//
//go:generate mockery --name DNSRouter --inpackage
type DNSRouter interface {
HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.Addr) *dns.Msg
ServeDNS(w dns.ResponseWriter, req *dns.Msg)
GetConfig() *RouterDynamicConfig
ReloadConfig(newCfg *config.RuntimeConfig) error
}
// Server is used to expose service discovery queries using a DNS interface.
// It implements the agent.dnsServer interface.
type Server struct {
*dns.Server // Used for setting up listeners
Router DNSRouter // Used to routes and parse DNS requests
logger hclog.Logger
}
// Config represent all the DNS configuration required to construct a DNS server.
type Config struct {
AgentConfig *config.RuntimeConfig
EntMeta acl.EnterpriseMeta
Logger hclog.Logger
Processor DiscoveryQueryProcessor
TokenFunc func() string
TranslateAddressFunc func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string
TranslateServiceAddressFunc func(dc string, address string, taggedAddresses map[string]structs.ServiceAddress, accept dnsutil.TranslateAddressAccept) string
}
// NewServer creates a new DNS server.
func NewServer(config Config) (*Server, error) {
router, err := NewRouter(config)
if err != nil {
return nil, fmt.Errorf("error creating DNS router: %w", err)
}
srv := &Server{
Router: router,
logger: config.Logger.Named(logging.DNS),
}
return srv, nil
}
// ListenAndServe starts the DNS server.
func (d *Server) ListenAndServe(network, addr string, notif func()) error {
d.Server = &dns.Server{
Addr: addr,
Net: network,
Handler: d.Router,
NotifyStartedFunc: notif,
}
if network == "udp" {
d.UDPSize = 65535
}
return d.Server.ListenAndServe()
}
// ReloadConfig hot-reloads the server config with new parameters under config.RuntimeConfig.DNS*
func (d *Server) ReloadConfig(newCfg *config.RuntimeConfig) error {
return d.Router.ReloadConfig(newCfg)
}
// Shutdown stops the DNS server.
func (d *Server) Shutdown() {
if d.Server != nil {
d.logger.Info("Stopping server",
"protocol", "DNS",
"address", d.Server.Addr,
"network", d.Server.Net,
)
err := d.Server.Shutdown()
if err != nil {
d.logger.Error("Error stopping DNS server", "error", err)
}
}
d.Router = nil
}
// GetAddr is a function to return the server address if is not nil.
func (d *Server) GetAddr() string {
if d.Server != nil {
return d.Server.Addr
}
return ""
}

View File

@ -1,78 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/stretchr/testify/require"
"testing"
)
// TestServer_ReloadConfig tests that the ReloadConfig method calls the router's ReloadConfig method.
func TestDNSServer_ReloadConfig(t *testing.T) {
srv, err := NewServer(Config{
AgentConfig: &config.RuntimeConfig{
DNSDomain: "test-domain",
DNSAltDomain: "test-alt-domain",
},
Logger: testutil.Logger(t),
})
srv.Router = NewMockDNSRouter(t)
require.NoError(t, err)
cfg := &config.RuntimeConfig{
DNSARecordLimit: 123,
DNSEnableTruncate: true,
DNSNodeTTL: 123,
DNSRecursorStrategy: "test",
DNSRecursorTimeout: 123,
DNSUDPAnswerLimit: 123,
DNSNodeMetaTXT: true,
DNSDisableCompression: true,
DNSSOA: config.RuntimeSOAConfig{
Expire: 123,
Refresh: 123,
Retry: 123,
Minttl: 123,
},
}
srv.Router.(*MockDNSRouter).On("ReloadConfig", cfg).Return(nil)
err = srv.ReloadConfig(cfg)
require.NoError(t, err)
require.True(t, srv.Router.(*MockDNSRouter).AssertExpectations(t))
}
// TestDNSServer_Lifecycle tests that the server can be started and shutdown.
func TestDNSServer_Lifecycle(t *testing.T) {
// Arrange
srv, err := NewServer(Config{
AgentConfig: &config.RuntimeConfig{
DNSDomain: "test-domain",
DNSAltDomain: "test-alt-domain",
},
Logger: testutil.Logger(t),
})
defer srv.Shutdown()
require.NotNil(t, srv.Router)
require.NoError(t, err)
require.NotNil(t, srv)
ch := make(chan bool)
go func() {
err = srv.ListenAndServe("udp", "127.0.0.1:8500", func() {
ch <- true
})
require.NoError(t, err)
}()
started, ok := <-ch
require.True(t, ok)
require.True(t, started)
require.NotNil(t, srv.Handler)
require.NotNil(t, srv.Handler.(*Router))
require.NotNil(t, srv.PacketConn)
//Shutdown
srv.Shutdown()
require.Nil(t, srv.Router)
}

View File

@ -1,516 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package agent
import (
"context"
"fmt"
"net"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc"
)
// Similar to TestDNS_ServiceLookup, but removes config for features unsupported in v2 and
// tests against DNS v2 and Catalog v2 explicitly using a resource API client.
func TestDNS_CatalogV2_Basic(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
var err error
a := NewTestAgent(t, `experiments=["resource-apis"]`) // v2dns is implicit w/ resource-apis
defer a.Shutdown()
testrpc.WaitForRaftLeader(t, a.RPC, "dc1")
client := a.delegate.ResourceServiceClient()
// Smoke test for `consul-server` service.
readResource(t, client, &pbresource.ID{
Name: structs.ConsulServiceName,
Type: pbcatalog.ServiceType,
Tenancy: resource.DefaultNamespacedTenancy(),
}, new(pbcatalog.Service))
// Register a new service.
dbServiceId := &pbresource.ID{
Name: "db",
Type: pbcatalog.ServiceType,
Tenancy: resource.DefaultNamespacedTenancy(),
}
emptyServiceId := &pbresource.ID{
Name: "empty",
Type: pbcatalog.ServiceType,
Tenancy: resource.DefaultNamespacedTenancy(),
}
dbService := &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Prefixes: []string{"db-"},
},
Ports: []*pbcatalog.ServicePort{
{
TargetPort: "tcp",
Protocol: pbcatalog.Protocol_PROTOCOL_TCP,
},
{
TargetPort: "admin",
Protocol: pbcatalog.Protocol_PROTOCOL_HTTP,
},
{
TargetPort: "mesh",
Protocol: pbcatalog.Protocol_PROTOCOL_MESH,
},
},
}
emptyService := &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Prefixes: []string{"empty-"},
},
Ports: []*pbcatalog.ServicePort{
{
TargetPort: "tcp",
Protocol: pbcatalog.Protocol_PROTOCOL_TCP,
},
{
TargetPort: "admin",
Protocol: pbcatalog.Protocol_PROTOCOL_HTTP,
},
{
TargetPort: "mesh",
Protocol: pbcatalog.Protocol_PROTOCOL_MESH,
},
},
}
dbServiceResource := &pbresource.Resource{
Id: dbServiceId,
Data: toAny(t, dbService),
}
emptyServiceResource := &pbresource.Resource{
Id: emptyServiceId,
Data: toAny(t, emptyService),
}
for _, r := range []*pbresource.Resource{dbServiceResource, emptyServiceResource} {
_, err := client.Write(context.Background(), &pbresource.WriteRequest{Resource: r})
if err != nil {
t.Fatalf("failed to create the %s service: %v", r.Id.Name, err)
}
}
// Validate services written.
readResource(t, client, dbServiceId, new(pbcatalog.Service))
readResource(t, client, emptyServiceId, new(pbcatalog.Service))
// Register workloads.
dbWorkloadId1 := &pbresource.ID{
Name: "db-1",
Type: pbcatalog.WorkloadType,
Tenancy: resource.DefaultNamespacedTenancy(),
}
dbWorkloadId2 := &pbresource.ID{
Name: "db-2",
Type: pbcatalog.WorkloadType,
Tenancy: resource.DefaultNamespacedTenancy(),
}
dbWorkloadId3 := &pbresource.ID{
Name: "db-3",
Type: pbcatalog.WorkloadType,
Tenancy: resource.DefaultNamespacedTenancy(),
}
dbWorkloadPorts := map[string]*pbcatalog.WorkloadPort{
"tcp": {
Port: 12345,
Protocol: pbcatalog.Protocol_PROTOCOL_TCP,
},
"admin": {
Port: 23456,
Protocol: pbcatalog.Protocol_PROTOCOL_HTTP,
},
"mesh": {
Port: 20000,
Protocol: pbcatalog.Protocol_PROTOCOL_MESH,
},
}
dbWorkloadFn := func(ip string) *pbcatalog.Workload {
return &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{
Host: ip,
},
},
Identity: "test-identity",
Ports: dbWorkloadPorts,
}
}
dbWorkload1 := dbWorkloadFn("172.16.1.1")
_, err = client.Write(context.Background(), &pbresource.WriteRequest{Resource: &pbresource.Resource{
Id: dbWorkloadId1,
Data: toAny(t, dbWorkload1),
}})
if err != nil {
t.Fatalf("failed to create the %s workload: %v", dbWorkloadId1.Name, err)
}
dbWorkload2 := dbWorkloadFn("172.16.1.2")
_, err = client.Write(context.Background(), &pbresource.WriteRequest{Resource: &pbresource.Resource{
Id: dbWorkloadId2,
Data: toAny(t, dbWorkload2),
}})
if err != nil {
t.Fatalf("failed to create the %s workload: %v", dbWorkloadId2.Name, err)
}
dbWorkload3 := dbWorkloadFn("2001:db8:85a3::8a2e:370:7334") // test IPv6
_, err = client.Write(context.Background(), &pbresource.WriteRequest{Resource: &pbresource.Resource{
Id: dbWorkloadId3,
Data: toAny(t, dbWorkload3),
}})
if err != nil {
t.Fatalf("failed to create the %s workload: %v", dbWorkloadId2.Name, err)
}
// Validate workloads written.
dbWorkloads := make(map[string]*pbcatalog.Workload)
dbWorkloads["db-1"] = readResource(t, client, dbWorkloadId1, new(pbcatalog.Workload)).(*pbcatalog.Workload)
dbWorkloads["db-2"] = readResource(t, client, dbWorkloadId2, new(pbcatalog.Workload)).(*pbcatalog.Workload)
dbWorkloads["db-3"] = readResource(t, client, dbWorkloadId3, new(pbcatalog.Workload)).(*pbcatalog.Workload)
// Ensure endpoints exist and have health status, which is required for inclusion in DNS results.
retry.Run(t, func(r *retry.R) {
endpoints := readResource(r, client, resource.ReplaceType(pbcatalog.ServiceEndpointsType, dbServiceId), new(pbcatalog.ServiceEndpoints)).(*pbcatalog.ServiceEndpoints)
require.Equal(r, 3, len(endpoints.GetEndpoints()))
for _, e := range endpoints.GetEndpoints() {
require.True(r,
// We only return results for passing and warning health checks.
e.HealthStatus == pbcatalog.Health_HEALTH_PASSING || e.HealthStatus == pbcatalog.Health_HEALTH_WARNING,
fmt.Sprintf("unexpected health status: %v", e.HealthStatus))
}
})
// Test UDP and TCP clients.
for _, client := range []*dns.Client{
newDNSClient(false),
newDNSClient(true),
} {
// Lookup a service without matching workloads, we should receive an SOA and no answers.
questions := []string{
"empty.service.consul.",
"_empty._tcp.service.consul.",
}
for _, question := range questions {
for _, dnsType := range []uint16{dns.TypeSRV, dns.TypeA, dns.TypeAAAA} {
m := new(dns.Msg)
m.SetQuestion(question, dnsType)
in, _, err := client.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
require.Equal(t, 0, len(in.Answer), "Bad: %s", in.String())
require.Equal(t, 0, len(in.Extra), "Bad: %s", in.String())
require.Equal(t, 1, len(in.Ns), "Bad: %s", in.String())
soaRec, ok := in.Ns[0].(*dns.SOA)
require.True(t, ok, "Bad: %s", in.Ns[0].String())
require.EqualValues(t, 0, soaRec.Hdr.Ttl, "Bad: %s", in.Ns[0].String())
}
}
// Look up the service directly including all ports.
questions = []string{
"db.service.consul.",
"_db._tcp.service.consul.", // RFC 2782 query. All ports are TCP, so this should return the same result.
}
for _, question := range questions {
m := new(dns.Msg)
m.SetQuestion(question, dns.TypeSRV)
in, _, err := client.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
// This check only runs for a TCP client because a UDP client will truncate the response.
if client.Net == "tcp" {
for portName, port := range dbWorkloadPorts {
for workloadName, workload := range dbWorkloads {
workloadTarget := fmt.Sprintf("%s.port.%s.workload.default.ns.default.ap.consul.", portName, workloadName)
workloadHost := workload.Addresses[0].Host
srvRec := findSrvAnswerForTarget(t, in, workloadTarget)
require.EqualValues(t, port.Port, srvRec.Port, "Bad: %s", srvRec.String())
require.EqualValues(t, 0, srvRec.Hdr.Ttl, "Bad: %s", srvRec.String())
a := findAorAAAAForName(t, in, in.Extra, workloadTarget)
require.Equal(t, workloadHost, a.AorAAAA.String(), "Bad: %s", a.Original.String())
require.EqualValues(t, 0, a.Hdr.Ttl, "Bad: %s", a.Original.String())
}
}
// Expect 1 result per port, per workload.
require.Equal(t, 9, len(in.Answer), "answer count did not match expected\n\n%s", in.String())
require.Equal(t, 9, len(in.Extra), "extra answer count did not match expected\n\n%s", in.String())
} else {
// Expect 1 result per port, per workload, up to the default limit of 3. In practice the results are truncated
// at 2 because of the record byte size.
require.Equal(t, 2, len(in.Answer), "answer count did not match expected\n\n%s", in.String())
require.Equal(t, 2, len(in.Extra), "extra answer count did not match expected\n\n%s", in.String())
}
}
// Look up the service directly by each port.
for portName, port := range dbWorkloadPorts {
question := fmt.Sprintf("%s.port.db.service.consul.", portName)
for workloadName, workload := range dbWorkloads {
workloadTarget := fmt.Sprintf("%s.port.%s.workload.default.ns.default.ap.consul.", portName, workloadName)
workloadHost := workload.Addresses[0].Host
m := new(dns.Msg)
m.SetQuestion(question, dns.TypeSRV)
in, _, err := client.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
srvRec := findSrvAnswerForTarget(t, in, workloadTarget)
require.EqualValues(t, port.Port, srvRec.Port, "Bad: %s", srvRec.String())
require.EqualValues(t, 0, srvRec.Hdr.Ttl, "Bad: %s", srvRec.String())
a := findAorAAAAForName(t, in, in.Extra, workloadTarget)
require.Equal(t, workloadHost, a.AorAAAA.String(), "Bad: %s", a.Original.String())
require.EqualValues(t, 0, a.Hdr.Ttl, "Bad: %s", a.Original.String())
// Expect 1 result per port.
require.Equal(t, 3, len(in.Answer), "answer count did not match expected\n\n%s", in.String())
require.Equal(t, 3, len(in.Extra), "extra answer count did not match expected\n\n%s", in.String())
}
}
// Look up A/AAAA by service.
questions = []string{
"db.service.consul.",
}
for _, question := range questions {
for workloadName, dnsType := range map[string]uint16{
"db-1": dns.TypeA,
"db-2": dns.TypeA,
"db-3": dns.TypeAAAA,
} {
workload := dbWorkloads[workloadName]
m := new(dns.Msg)
m.SetQuestion(question, dnsType)
in, _, err := client.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
workloadHost := workload.Addresses[0].Host
a := findAorAAAAForAddress(t, in, in.Answer, workloadHost)
require.Equal(t, question, a.Hdr.Name, "Bad: %s", a.Original.String())
require.EqualValues(t, 0, a.Hdr.Ttl, "Bad: %s", a.Original.String())
// Expect 1 answer per workload. For A records, expect 2 answers because there's 2 IPv4 workloads.
if dnsType == dns.TypeA {
require.Equal(t, 2, len(in.Answer), "answer count did not match expected\n\n%s", in.String())
} else {
require.Equal(t, 1, len(in.Answer), "answer count did not match expected\n\n%s", in.String())
}
require.Equal(t, 0, len(in.Extra), "extra answer count did not match expected\n\n%s", in.String())
}
}
// Lookup a non-existing service, we should receive an SOA.
questions = []string{
"nodb.service.consul.",
"nope.query.consul.", // prepared query is not supported in v2
}
for _, question := range questions {
m := new(dns.Msg)
m.SetQuestion(question, dns.TypeSRV)
in, _, err := client.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
require.Equal(t, 1, len(in.Ns), "Bad: %s", in.String())
soaRec, ok := in.Ns[0].(*dns.SOA)
require.True(t, ok, "Bad: %s", in.Ns[0].String())
require.EqualValues(t, 0, soaRec.Hdr.Ttl, "Bad: %s", in.Ns[0].String())
}
// Lookup workloads directly with a port.
for workloadName, dnsType := range map[string]uint16{
"db-1": dns.TypeA,
"db-2": dns.TypeA,
"db-3": dns.TypeAAAA,
} {
for _, question := range []string{
fmt.Sprintf("%s.workload.default.ns.default.ap.consul.", workloadName),
fmt.Sprintf("tcp.port.%s.workload.default.ns.default.ap.consul.", workloadName),
fmt.Sprintf("admin.port.%s.workload.default.ns.default.ap.consul.", workloadName),
fmt.Sprintf("mesh.port.%s.workload.default.ns.default.ap.consul.", workloadName),
} {
workload := dbWorkloads[workloadName]
workloadHost := workload.Addresses[0].Host
m := new(dns.Msg)
m.SetQuestion(question, dnsType)
in, _, err := client.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
require.Equal(t, 1, len(in.Answer), "Bad: %s", in.String())
a := findAorAAAAForName(t, in, in.Answer, question)
require.Equal(t, workloadHost, a.AorAAAA.String(), "Bad: %s", a.Original.String())
require.EqualValues(t, 0, a.Hdr.Ttl, "Bad: %s", a.Original.String())
}
}
// Lookup a non-existing workload, we should receive an NXDOMAIN response.
for _, aType := range []uint16{dns.TypeA, dns.TypeAAAA} {
question := "unknown.workload.consul."
m := new(dns.Msg)
m.SetQuestion(question, aType)
in, _, err := client.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
require.Equal(t, 0, len(in.Answer), "Bad: %s", in.String())
require.Equal(t, dns.RcodeNameError, in.Rcode, "Bad: %s", in.String())
}
}
}
func findSrvAnswerForTarget(t *testing.T, in *dns.Msg, target string) *dns.SRV {
t.Helper()
for _, a := range in.Answer {
srvRec, ok := a.(*dns.SRV)
if ok && srvRec.Target == target {
return srvRec
}
}
t.Fatalf("could not find SRV record for target: %s\n\n%s", target, in.String())
return nil
}
func findAorAAAAForName(t *testing.T, in *dns.Msg, rrs []dns.RR, name string) *dnsAOrAAAA {
t.Helper()
for _, rr := range rrs {
a := newAOrAAAA(t, rr)
if a.Hdr.Name == name {
return a
}
}
t.Fatalf("could not find A/AAAA record for name: %s\n\n%+v", name, in.String())
return nil
}
func findAorAAAAForAddress(t *testing.T, in *dns.Msg, rrs []dns.RR, address string) *dnsAOrAAAA {
t.Helper()
for _, rr := range rrs {
a := newAOrAAAA(t, rr)
if a.AorAAAA.String() == address {
return a
}
}
t.Fatalf("could not find A/AAAA record for address: %s\n\n%+v", address, in.String())
return nil
}
func readResource(t retry.TestingTB, client pbresource.ResourceServiceClient, id *pbresource.ID, m proto.Message) proto.Message {
t.Helper()
retry.Run(t, func(r *retry.R) {
res, err := client.Read(context.Background(), &pbresource.ReadRequest{Id: id})
if err != nil {
r.Fatalf("err: %v", err)
}
data := res.GetResource()
require.NotEmpty(r, data)
err = data.Data.UnmarshalTo(m)
require.NoError(r, err)
})
return m
}
func toAny(t retry.TestingTB, m proto.Message) *anypb.Any {
t.Helper()
a, err := anypb.New(m)
if err != nil {
t.Fatalf("could not convert proto to `any` message: %v", err)
}
return a
}
// dnsAOrAAAA unifies A and AAAA records for simpler testing when the IP type doesn't matter.
type dnsAOrAAAA struct {
Original dns.RR
Hdr dns.RR_Header
AorAAAA net.IP
isAAAA bool
}
func newAOrAAAA(t *testing.T, rr dns.RR) *dnsAOrAAAA {
t.Helper()
if aRec, ok := rr.(*dns.A); ok {
return &dnsAOrAAAA{
Original: rr,
Hdr: aRec.Hdr,
AorAAAA: aRec.A,
isAAAA: false,
}
}
if aRec, ok := rr.(*dns.AAAA); ok {
return &dnsAOrAAAA{
Original: rr,
Hdr: aRec.Hdr,
AorAAAA: aRec.AAAA,
isAAAA: true,
}
}
t.Fatalf("Bad A or AAAA record: %#v", rr)
return nil
}
func newDNSClient(tcp bool) *dns.Client {
c := new(dns.Client)
// Use TCP to avoid truncation of larger responses and
// sidestep the default UDP size limit of 3 answers
// set by config.DefaultSource() in agent/config/default.go.
if tcp {
c.Net = "tcp"
}
return c
}

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
agentdns "github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/agent/structs"
)
@ -26,9 +27,9 @@ func getEnterpriseDNSConfig(conf *config.RuntimeConfig) enterpriseDNSConfig {
// parseLocality can parse peer name or datacenter from a DNS query's labels.
// Peer name is parsed from the same query part that datacenter is, so given this ambiguity
// we parse a "peerOrDatacenter". The caller or RPC handler are responsible for disambiguating.
func (d *DNSServer) parseLocality(labels []string, cfg *dnsConfig) (queryLocality, bool) {
func (d *DNSServer) parseLocality(labels []string, cfg *dnsRequestConfig) (queryLocality, bool) {
locality := queryLocality{
EnterpriseMeta: d.defaultEnterpriseMeta,
EnterpriseMeta: cfg.defaultEnterpriseMeta,
}
switch len(labels) {
@ -52,7 +53,10 @@ func (d *DNSServer) parseLocality(labels []string, cfg *dnsConfig) (queryLocalit
}
return locality, true
case 1:
return queryLocality{peerOrDatacenter: labels[0]}, true
return queryLocality{
peerOrDatacenter: labels[0],
EnterpriseMeta: cfg.defaultEnterpriseMeta,
}, true
case 0:
return queryLocality{}, true
@ -64,10 +68,12 @@ func (d *DNSServer) parseLocality(labels []string, cfg *dnsConfig) (queryLocalit
type querySameness struct{}
// parseSamenessGroupLocality wraps parseLocality in CE
func (d *DNSServer) parseSamenessGroupLocality(cfg *dnsConfig, labels []string, errfnc func() error) (queryLocality, error) {
func (d *DNSServer) parseSamenessGroupLocality(cfg *dnsRequestConfig, labels []string, errfnc func() error) (queryLocality, error) {
locality, ok := d.parseLocality(labels, cfg)
if !ok {
return queryLocality{}, errfnc()
return queryLocality{
EnterpriseMeta: cfg.defaultEnterpriseMeta,
}, errfnc()
}
return locality, nil
}
@ -88,3 +94,9 @@ func nodeCanonicalDNSName(node *structs.Node, respDomain string) string {
// Return a simpler format for non-peering nodes.
return fmt.Sprintf("%s.node.%s.%s", node.Node, node.Datacenter, respDomain)
}
// setEnterpriseMetaFromRequestContext sets the DefaultNamespace and DefaultPartition on the requestDnsConfig
// based on the requestContext's DefaultNamespace and DefaultPartition.
func (d *DNSServer) setEnterpriseMetaFromRequestContext(requestContext agentdns.Context, requestDnsConfig *dnsRequestConfig) {
// do nothing
}

View File

@ -22,116 +22,112 @@ func TestDNS_CE_PeeredServices(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := StartTestAgent(t, TestAgent{HCL: ``, Overrides: `peering = { test_allow_peer_registrations = true } ` + experimentsHCL})
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
a := StartTestAgent(t, TestAgent{HCL: ``, Overrides: `peering = { test_allow_peer_registrations = true } `})
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
makeReq := func() *structs.RegisterRequest {
return &structs.RegisterRequest{
PeerName: "peer1",
Datacenter: "dc1",
Node: "peernode1",
Address: "198.18.1.1",
Service: &structs.NodeService{
PeerName: "peer1",
Kind: structs.ServiceKindConnectProxy,
Service: "web-proxy",
Address: "199.0.0.1",
Port: 12345,
Proxy: structs.ConnectProxyConfig{
DestinationServiceName: "peer-web",
},
EnterpriseMeta: *acl.DefaultEnterpriseMeta(),
},
EnterpriseMeta: *acl.DefaultEnterpriseMeta(),
}
}
dnsQuery := func(t *testing.T, question string, typ uint16) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(question, typ)
c := new(dns.Client)
reply, _, err := c.Exchange(m, a.DNSAddr())
require.NoError(t, err)
require.Len(t, reply.Answer, 1, "zero valid records found for %q", question)
return reply
}
assertARec := func(t *testing.T, rec dns.RR, expectName, expectIP string) {
aRec, ok := rec.(*dns.A)
require.True(t, ok, "Extra is not an A record: %T", rec)
require.Equal(t, expectName, aRec.Hdr.Name)
require.Equal(t, expectIP, aRec.A.String())
}
assertSRVRec := func(t *testing.T, rec dns.RR, expectName string, expectPort uint16) {
srvRec, ok := rec.(*dns.SRV)
require.True(t, ok, "Answer is not a SRV record: %T", rec)
require.Equal(t, expectName, srvRec.Target)
require.Equal(t, expectPort, srvRec.Port)
}
t.Run("srv-with-addr-reply", func(t *testing.T) {
require.NoError(t, a.RPC(context.Background(), "Catalog.Register", makeReq(), &struct{}{}))
q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeSRV)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 1)
addr := "c7000001.addr.consul."
assertSRVRec(t, q.Answer[0], addr, 12345)
assertARec(t, q.Extra[0], addr, "199.0.0.1")
// Query the addr to make sure it's also valid.
q = dnsQuery(t, addr, dns.TypeA)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 0)
assertARec(t, q.Answer[0], addr, "199.0.0.1")
})
t.Run("srv-with-node-reply", func(t *testing.T) {
req := makeReq()
// Clear service address to trigger node response
req.Service.Address = ""
require.NoError(t, a.RPC(context.Background(), "Catalog.Register", req, &struct{}{}))
q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeSRV)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 1)
nodeName := "peernode1.node.peer1.peer.consul."
assertSRVRec(t, q.Answer[0], nodeName, 12345)
assertARec(t, q.Extra[0], nodeName, "198.18.1.1")
// Query the node to make sure it's also valid.
q = dnsQuery(t, nodeName, dns.TypeA)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 0)
assertARec(t, q.Answer[0], nodeName, "198.18.1.1")
})
t.Run("srv-with-fqdn-reply", func(t *testing.T) {
req := makeReq()
// Set non-ip address to trigger external response
req.Address = "localhost"
req.Service.Address = ""
require.NoError(t, a.RPC(context.Background(), "Catalog.Register", req, &struct{}{}))
q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeSRV)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 0)
assertSRVRec(t, q.Answer[0], "localhost.", 12345)
})
t.Run("a-reply", func(t *testing.T) {
require.NoError(t, a.RPC(context.Background(), "Catalog.Register", makeReq(), &struct{}{}))
q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeA)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 0)
assertARec(t, q.Answer[0], "web-proxy.service.peer1.peer.consul.", "199.0.0.1")
})
})
makeReq := func() *structs.RegisterRequest {
return &structs.RegisterRequest{
PeerName: "peer1",
Datacenter: "dc1",
Node: "peernode1",
Address: "198.18.1.1",
Service: &structs.NodeService{
PeerName: "peer1",
Kind: structs.ServiceKindConnectProxy,
Service: "web-proxy",
Address: "199.0.0.1",
Port: 12345,
Proxy: structs.ConnectProxyConfig{
DestinationServiceName: "peer-web",
},
EnterpriseMeta: *acl.DefaultEnterpriseMeta(),
},
EnterpriseMeta: *acl.DefaultEnterpriseMeta(),
}
}
dnsQuery := func(t *testing.T, question string, typ uint16) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(question, typ)
c := new(dns.Client)
reply, _, err := c.Exchange(m, a.DNSAddr())
require.NoError(t, err)
require.Len(t, reply.Answer, 1, "zero valid records found for %q", question)
return reply
}
assertARec := func(t *testing.T, rec dns.RR, expectName, expectIP string) {
aRec, ok := rec.(*dns.A)
require.True(t, ok, "Extra is not an A record: %T", rec)
require.Equal(t, expectName, aRec.Hdr.Name)
require.Equal(t, expectIP, aRec.A.String())
}
assertSRVRec := func(t *testing.T, rec dns.RR, expectName string, expectPort uint16) {
srvRec, ok := rec.(*dns.SRV)
require.True(t, ok, "Answer is not a SRV record: %T", rec)
require.Equal(t, expectName, srvRec.Target)
require.Equal(t, expectPort, srvRec.Port)
}
t.Run("srv-with-addr-reply", func(t *testing.T) {
require.NoError(t, a.RPC(context.Background(), "Catalog.Register", makeReq(), &struct{}{}))
q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeSRV)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 1)
addr := "c7000001.addr.consul."
assertSRVRec(t, q.Answer[0], addr, 12345)
assertARec(t, q.Extra[0], addr, "199.0.0.1")
// Query the addr to make sure it's also valid.
q = dnsQuery(t, addr, dns.TypeA)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 0)
assertARec(t, q.Answer[0], addr, "199.0.0.1")
})
t.Run("srv-with-node-reply", func(t *testing.T) {
req := makeReq()
// Clear service address to trigger node response
req.Service.Address = ""
require.NoError(t, a.RPC(context.Background(), "Catalog.Register", req, &struct{}{}))
q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeSRV)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 1)
nodeName := "peernode1.node.peer1.peer.consul."
assertSRVRec(t, q.Answer[0], nodeName, 12345)
assertARec(t, q.Extra[0], nodeName, "198.18.1.1")
// Query the node to make sure it's also valid.
q = dnsQuery(t, nodeName, dns.TypeA)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 0)
assertARec(t, q.Answer[0], nodeName, "198.18.1.1")
})
t.Run("srv-with-fqdn-reply", func(t *testing.T) {
req := makeReq()
// Set non-ip address to trigger external response
req.Address = "localhost"
req.Service.Address = ""
require.NoError(t, a.RPC(context.Background(), "Catalog.Register", req, &struct{}{}))
q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeSRV)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 0)
assertSRVRec(t, q.Answer[0], "localhost.", 12345)
})
t.Run("a-reply", func(t *testing.T) {
require.NoError(t, a.RPC(context.Background(), "Catalog.Register", makeReq(), &struct{}{}))
q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeA)
require.Len(t, q.Answer, 1)
require.Len(t, q.Extra, 0)
assertARec(t, q.Answer[0], "web-proxy.service.peer1.peer.consul.", "199.0.0.1")
})
}
func getTestCasesParseLocality() []testCaseParseLocality {

File diff suppressed because it is too large Load Diff

View File

@ -17,45 +17,41 @@ func TestDNS_ReverseLookup(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Register node
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo2",
Address: "127.0.0.2",
}
// Register node
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo2",
Address: "127.0.0.2",
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != "foo2.node.dc1.consul." {
t.Fatalf("Bad: %#v", ptrRec)
}
})
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != "foo2.node.dc1.consul." {
t.Fatalf("Bad: %#v", ptrRec)
}
}
@ -64,47 +60,43 @@ func TestDNS_ReverseLookup_CustomDomain(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, `
a := NewTestAgent(t, `
domain = "custom"
`+experimentsHCL)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
`)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Register node
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo2",
Address: "127.0.0.2",
}
// Register node
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo2",
Address: "127.0.0.2",
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != "foo2.node.dc1.custom." {
t.Fatalf("Bad: %#v", ptrRec)
}
})
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != "foo2.node.dc1.custom." {
t.Fatalf("Bad: %#v", ptrRec)
}
}
@ -113,45 +105,41 @@ func TestDNS_ReverseLookup_IPV6(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Register node
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "::4242:4242",
}
// Register node
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "::4242:4242",
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("2.4.2.4.2.4.2.4.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", dns.TypeANY)
m := new(dns.Msg)
m.SetQuestion("2.4.2.4.2.4.2.4.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != "bar.node.dc1.consul." {
t.Fatalf("Bad: %#v", ptrRec)
}
})
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != "bar.node.dc1.consul." {
t.Fatalf("Bad: %#v", ptrRec)
}
}
@ -160,58 +148,53 @@ func TestDNS_Compression_ReverseLookup(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Register node.
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo2",
Address: "127.0.0.2",
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
// Register node.
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo2",
Address: "127.0.0.2",
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
conn, err := dns.Dial("udp", a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
conn, err := dns.Dial("udp", a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
// Do a manual exchange with compression on (the default).
if err := conn.WriteMsg(m); err != nil {
t.Fatalf("err: %v", err)
}
p := make([]byte, dns.MaxMsgSize)
compressed, err := conn.Read(p)
if err != nil {
t.Fatalf("err: %v", err)
}
// Do a manual exchange with compression on (the default).
if err := conn.WriteMsg(m); err != nil {
t.Fatalf("err: %v", err)
}
p := make([]byte, dns.MaxMsgSize)
compressed, err := conn.Read(p)
if err != nil {
t.Fatalf("err: %v", err)
}
// Disable compression and try again.
a.DNSDisableCompression(true)
if err := conn.WriteMsg(m); err != nil {
t.Fatalf("err: %v", err)
}
unc, err := conn.Read(p)
if err != nil {
t.Fatalf("err: %v", err)
}
// Disable compression and try again.
a.DNSDisableCompression(true)
if err := conn.WriteMsg(m); err != nil {
t.Fatalf("err: %v", err)
}
unc, err := conn.Read(p)
if err != nil {
t.Fatalf("err: %v", err)
}
// We can't see the compressed status given the DNS API, so we just make
// sure the message is smaller to see if it's respecting the flag.
if compressed == 0 || unc == 0 || compressed >= unc {
t.Fatalf("doesn't look compressed: %d vs. %d", compressed, unc)
}
})
// We can't see the compressed status given the DNS API, so we just make
// sure the message is smaller to see if it's respecting the flag.
if compressed == 0 || unc == 0 || compressed >= unc {
t.Fatalf("doesn't look compressed: %d vs. %d", compressed, unc)
}
}
@ -220,53 +203,49 @@ func TestDNS_ServiceReverseLookup(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Register a node with a service.
{
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
Service: "db",
Tags: []string{"primary"},
Port: 12345,
Address: "127.0.0.2",
},
}
// Register a node with a service.
{
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
Service: "db",
Tags: []string{"primary"},
Port: 12345,
Address: "127.0.0.2",
},
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != serviceCanonicalDNSName("db", "service", "dc1", "consul", nil)+"." {
t.Fatalf("Bad: %#v", ptrRec)
}
})
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != serviceCanonicalDNSName("db", "service", "dc1", "consul", nil)+"." {
t.Fatalf("Bad: %#v", ptrRec)
}
}
@ -275,53 +254,49 @@ func TestDNS_ServiceReverseLookup_IPV6(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Register a node with a service.
{
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "2001:db8::1",
Service: &structs.NodeService{
Service: "db",
Tags: []string{"primary"},
Port: 12345,
Address: "2001:db8::ff00:42:8329",
},
}
// Register a node with a service.
{
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "2001:db8::1",
Service: &structs.NodeService{
Service: "db",
Tags: []string{"primary"},
Port: 12345,
Address: "2001:db8::ff00:42:8329",
},
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
m := new(dns.Msg)
m.SetQuestion("9.2.3.8.2.4.0.0.0.0.f.f.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", dns.TypeANY)
m := new(dns.Msg)
m.SetQuestion("9.2.3.8.2.4.0.0.0.0.f.f.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != serviceCanonicalDNSName("db", "service", "dc1", "consul", nil)+"." {
t.Fatalf("Bad: %#v", ptrRec)
}
})
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != serviceCanonicalDNSName("db", "service", "dc1", "consul", nil)+"." {
t.Fatalf("Bad: %#v", ptrRec)
}
}
@ -330,55 +305,51 @@ func TestDNS_ServiceReverseLookup_CustomDomain(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, `
a := NewTestAgent(t, `
domain = "custom"
`+experimentsHCL)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
`)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Register a node with a service.
{
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
Service: "db",
Tags: []string{"primary"},
Port: 12345,
Address: "127.0.0.2",
},
}
// Register a node with a service.
{
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
Service: "db",
Tags: []string{"primary"},
Port: 12345,
Address: "127.0.0.2",
},
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != serviceCanonicalDNSName("db", "service", "dc1", "custom", nil)+"." {
t.Fatalf("Bad: %#v", ptrRec)
}
})
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != serviceCanonicalDNSName("db", "service", "dc1", "custom", nil)+"." {
t.Fatalf("Bad: %#v", ptrRec)
}
}
@ -387,53 +358,49 @@ func TestDNS_ServiceReverseLookupNodeAddress(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Register a node with a service.
{
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
Service: "db",
Tags: []string{"primary"},
Port: 12345,
Address: "127.0.0.1",
},
}
// Register a node with a service.
{
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
Service: "db",
Tags: []string{"primary"},
Port: 12345,
Address: "127.0.0.1",
},
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
var out struct{}
if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
m := new(dns.Msg)
m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypeANY)
m := new(dns.Msg)
m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != "foo.node.dc1.consul." {
t.Fatalf("Bad: %#v", ptrRec)
}
})
ptrRec, ok := in.Answer[0].(*dns.PTR)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if ptrRec.Ptr != "foo.node.dc1.consul." {
t.Fatalf("Bad: %#v", ptrRec)
}
}
@ -442,35 +409,31 @@ func TestDNS_ReverseLookup_NotFound(t *testing.T) {
t.Skip("too slow for testing.Short")
}
for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
// do not configure recursors
a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// do not configure recursors
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Do not register any nodes
m := new(dns.Msg)
qName := "2.0.0.127.in-addr.arpa."
m.SetQuestion(qName, dns.TypeANY)
// Do not register any nodes
m := new(dns.Msg)
qName := "2.0.0.127.in-addr.arpa."
m.SetQuestion(qName, dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
require.NoError(t, err)
require.Nil(t, in.Answer)
require.Nil(t, in.Extra)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
require.NoError(t, err)
require.Nil(t, in.Answer)
require.Nil(t, in.Extra)
require.Equal(t, dns.RcodeNameError, in.Rcode)
require.Equal(t, dns.RcodeNameError, in.Rcode)
question := in.Question[0]
require.Equal(t, qName, question.Name)
require.Equal(t, dns.TypeANY, question.Qtype)
require.Equal(t, uint16(dns.ClassINET), question.Qclass)
question := in.Question[0]
require.Equal(t, qName, question.Name)
require.Equal(t, dns.TypeANY, question.Qtype)
require.Equal(t, uint16(dns.ClassINET), question.Qclass)
soa, ok := in.Ns[0].(*dns.SOA)
require.True(t, ok)
require.Equal(t, "ns.consul.", soa.Ns)
require.Equal(t, "hostmaster.consul.", soa.Mbox)
})
}
soa, ok := in.Ns[0].(*dns.SOA)
require.True(t, ok)
require.Equal(t, "ns.consul.", soa.Ns)
require.Equal(t, "hostmaster.consul.", soa.Mbox)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -6,6 +6,7 @@ package dns
import (
"context"
"fmt"
agentdns "github.com/hashicorp/consul/agent/dns"
"net"
"github.com/hashicorp/go-hclog"
@ -41,61 +42,6 @@ func (s *Server) Register(registrar grpc.ServiceRegistrar) {
pbdns.RegisterDNSServiceServer(registrar, s)
}
// BufferResponseWriter writes a DNS response to a byte buffer.
type BufferResponseWriter struct {
responseBuffer []byte
LocalAddress net.Addr
RemoteAddress net.Addr
Logger hclog.Logger
}
// LocalAddr returns the net.Addr of the server
func (b *BufferResponseWriter) LocalAddr() net.Addr {
return b.LocalAddress
}
// RemoteAddr returns the net.Addr of the client that sent the current request.
func (b *BufferResponseWriter) RemoteAddr() net.Addr {
return b.RemoteAddress
}
// WriteMsg writes a reply back to the client.
func (b *BufferResponseWriter) WriteMsg(m *dns.Msg) error {
// Pack message to bytes first.
msgBytes, err := m.Pack()
if err != nil {
b.Logger.Error("error packing message", "err", err)
return err
}
b.responseBuffer = msgBytes
return nil
}
// Write writes a raw buffer back to the client.
func (b *BufferResponseWriter) Write(m []byte) (int, error) {
b.Logger.Debug("Write was called")
return copy(b.responseBuffer, m), nil
}
// Close closes the connection.
func (b *BufferResponseWriter) Close() error {
// There's nothing for us to do here as we don't handle the connection.
return nil
}
// TsigStatus returns the status of the Tsig.
func (b *BufferResponseWriter) TsigStatus() error {
// TSIG doesn't apply to this response writer.
return nil
}
// TsigTimersOnly sets the tsig timers only boolean.
func (b *BufferResponseWriter) TsigTimersOnly(bool) {}
// Hijack lets the caller take over the connection.
// After a call to Hijack(), the DNS package will not do anything with the connection. {
func (b *BufferResponseWriter) Hijack() {}
// Query is a gRPC endpoint that will serve dns requests. It will be consumed primarily by the
// consul dataplane to proxy dns requests to consul.
func (s *Server) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.QueryResponse, error) {
@ -121,21 +67,29 @@ func (s *Server) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.Que
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("error protocol type not set: %v", req.GetProtocol()))
}
respWriter := &BufferResponseWriter{
LocalAddress: local,
RemoteAddress: remote,
Logger: s.Logger,
reqCtx, err := agentdns.NewContextFromGRPCContext(ctx)
if err != nil {
s.Logger.Error("error parsing DNS context from grpc metadata", "err", err)
return nil, status.Error(codes.Internal, fmt.Sprintf("error parsing DNS context from grpc metadata: %s", err.Error()))
}
respWriter := &agentdns.BufferResponseWriter{
LocalAddress: local,
RemoteAddress: remote,
Logger: s.Logger,
RequestContext: reqCtx,
}
msg := &dns.Msg{}
err := msg.Unpack(req.Msg)
err = msg.Unpack(req.Msg)
if err != nil {
s.Logger.Error("error unpacking message", "err", err)
return nil, status.Error(codes.Internal, fmt.Sprintf("failure decoding dns request: %s", err.Error()))
}
s.DNSServeMux.ServeDNS(respWriter, msg)
queryResponse := &pbdns.QueryResponse{Msg: respWriter.responseBuffer}
queryResponse := &pbdns.QueryResponse{Msg: respWriter.ResponseBuffer()}
return queryResponse, nil
}

View File

@ -1,90 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"context"
"fmt"
"net"
"github.com/miekg/dns"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/hashicorp/go-hclog"
agentdns "github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/proto-public/pbdns"
)
type ConfigV2 struct {
DNSRouter agentdns.DNSRouter
Logger hclog.Logger
TokenFunc func() string
}
var _ pbdns.DNSServiceServer = (*ServerV2)(nil)
// ServerV2 is a gRPC server that implements pbdns.DNSServiceServer.
// It is compatible with the refactored V2 DNS server and suitable for
// passing additional metadata along the grpc connection to catalog queries.
type ServerV2 struct {
ConfigV2
}
func NewServerV2(cfg ConfigV2) *ServerV2 {
return &ServerV2{cfg}
}
func (s *ServerV2) Register(registrar grpc.ServiceRegistrar) {
pbdns.RegisterDNSServiceServer(registrar, s)
}
// Query is a gRPC endpoint that will serve dns requests. It will be consumed primarily by the
// consul dataplane to proxy dns requests to consul.
func (s *ServerV2) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.QueryResponse, error) {
pr, ok := peer.FromContext(ctx)
if !ok {
return nil, fmt.Errorf("error retrieving peer information from context")
}
var remote net.Addr
// We do this so that we switch to udp/tcp when handling the request since it will be proxied
// through consul through gRPC, and we need to 'fake' the protocol so that the message is trimmed
// according to whether it is UDP or TCP.
switch req.GetProtocol() {
case pbdns.Protocol_PROTOCOL_TCP:
remote = pr.Addr
case pbdns.Protocol_PROTOCOL_UDP:
remoteAddr := pr.Addr.(*net.TCPAddr)
remote = &net.UDPAddr{IP: remoteAddr.IP, Port: remoteAddr.Port}
default:
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("error protocol type not set: %v", req.GetProtocol()))
}
msg := &dns.Msg{}
err := msg.Unpack(req.Msg)
if err != nil {
s.Logger.Error("error unpacking message", "err", err)
return nil, status.Error(codes.Internal, fmt.Sprintf("failure decoding dns request: %s", err.Error()))
}
reqCtx, err := agentdns.NewContextFromGRPCContext(ctx)
if err != nil {
s.Logger.Error("error parsing DNS context from grpc metadata", "err", err)
return nil, status.Error(codes.Internal, fmt.Sprintf("error parsing DNS context from grpc metadata: %s", err.Error()))
}
resp := s.DNSRouter.HandleRequest(msg, reqCtx, remote)
data, err := resp.Pack()
if err != nil {
s.Logger.Error("error packing message", "err", err)
return nil, status.Error(codes.Internal, fmt.Sprintf("failure encoding dns request: %s", err.Error()))
}
queryResponse := &pbdns.QueryResponse{Msg: data}
return queryResponse, nil
}

View File

@ -1,164 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"context"
"errors"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
agentdns "github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/proto-public/pbdns"
)
func basicResponse() *dns.Msg {
return &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "abc.com.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Extra: []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: "abc.com.",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 0,
},
Txt: txtRR,
},
},
}
}
func (s *DNSTestSuite) TestProxy_V2Success() {
testCases := map[string]struct {
question string
configureRouter func(router *agentdns.MockDNSRouter)
clientQuery func(qR *pbdns.QueryRequest)
metadata map[string]string
expectedErr error
}{
"happy path udp": {
question: "abc.com.",
configureRouter: func(router *agentdns.MockDNSRouter) {
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
Return(basicResponse(), nil)
},
clientQuery: func(qR *pbdns.QueryRequest) {
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
},
},
"happy path tcp": {
question: "abc.com.",
configureRouter: func(router *agentdns.MockDNSRouter) {
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
Return(basicResponse(), nil)
},
clientQuery: func(qR *pbdns.QueryRequest) {
qR.Protocol = pbdns.Protocol_PROTOCOL_TCP
},
},
"happy path with context variables set": {
question: "abc.com.",
configureRouter: func(router *agentdns.MockDNSRouter) {
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
Run(func(args mock.Arguments) {
ctx, ok := args.Get(1).(agentdns.Context)
require.True(s.T(), ok, "error casting to agentdns.Context")
require.Equal(s.T(), "test-token", ctx.Token, "token not set in context")
require.Equal(s.T(), "test-namespace", ctx.DefaultNamespace, "namespace not set in context")
require.Equal(s.T(), "test-partition", ctx.DefaultPartition, "partition not set in context")
}).
Return(basicResponse(), nil)
},
clientQuery: func(qR *pbdns.QueryRequest) {
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
},
metadata: map[string]string{
"x-consul-token": "test-token",
"x-consul-namespace": "test-namespace",
"x-consul-partition": "test-partition",
},
},
"No protocol set": {
question: "abc.com.",
clientQuery: func(qR *pbdns.QueryRequest) {},
expectedErr: errors.New("error protocol type not set: PROTOCOL_UNSET_UNSPECIFIED"),
},
"Invalid question": {
question: "notvalid",
clientQuery: func(qR *pbdns.QueryRequest) {
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
},
expectedErr: errors.New("failure decoding dns request"),
},
}
for name, tc := range testCases {
s.Run(name, func() {
router := agentdns.NewMockDNSRouter(s.T())
if tc.configureRouter != nil {
tc.configureRouter(router)
}
server := NewServerV2(ConfigV2{
Logger: hclog.Default(),
DNSRouter: router,
TokenFunc: func() string { return "" },
})
client := testClient(s.T(), server)
req := dns.Msg{}
req.SetQuestion(tc.question, dns.TypeA)
bytes, _ := req.Pack()
ctx := context.Background()
if len(tc.metadata) > 0 {
md := metadata.MD{}
for k, v := range tc.metadata {
md.Set(k, v)
}
ctx = metadata.NewOutgoingContext(ctx, md)
}
clientReq := &pbdns.QueryRequest{Msg: bytes}
tc.clientQuery(clientReq)
clientResp, err := client.Query(ctx, clientReq)
if tc.expectedErr != nil {
s.Require().Error(err, "no errror calling gRPC endpoint")
s.Require().ErrorContains(err, tc.expectedErr.Error())
} else {
s.Require().NoError(err, "error calling gRPC endpoint")
resp := clientResp.GetMsg()
var dnsResp dns.Msg
err = dnsResp.Unpack(resp)
s.Require().NoError(err, "error unpacking dns response")
rr := dnsResp.Extra[0].(*dns.TXT)
s.Require().EqualValues(rr.Txt, txtRR)
}
})
}
}

View File

@ -28,7 +28,6 @@ import (
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/consul/usagemetrics"
"github.com/hashicorp/consul/agent/consul/xdscapacity"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/grpc-external/limiter"
grpcInt "github.com/hashicorp/consul/agent/grpc-internal"
"github.com/hashicorp/consul/agent/grpc-internal/balancer"
@ -437,7 +436,6 @@ func getPrometheusDefs(cfg *config.RuntimeConfig, isServer bool) ([]prometheus.G
consul.CatalogCounters,
consul.ClientCounters,
consul.RPCCounters,
discovery.DNSCounters,
grpcWare.StatsCounters,
local.StateCounters,
xds.StatsCounters,

View File

@ -24,6 +24,7 @@ const (
errNotPrimaryDatacenter = "not the primary datacenter"
errStateReadOnly = "CA Provider State is read-only"
errUsingV2CatalogExperiment = "V1 catalog is disabled when V2 is enabled"
errSamenessGroupNotFound = "Sameness Group not found"
errSamenessGroupMustBeDefaultForFailover = "Sameness Group must have DefaultForFailover set to true in order to use this endpoint"
)
@ -42,6 +43,7 @@ var (
ErrNotPrimaryDatacenter = errors.New(errNotPrimaryDatacenter)
ErrStateReadOnly = errors.New(errStateReadOnly)
ErrUsingV2CatalogExperiment = errors.New(errUsingV2CatalogExperiment)
ErrSamenessGroupNotFound = errors.New(errSamenessGroupNotFound)
ErrSamenessGroupMustBeDefaultForFailover = errors.New(errSamenessGroupMustBeDefaultForFailover)
)
@ -65,6 +67,10 @@ func IsErrUsingV2CatalogExperiment(err error) bool {
return err != nil && strings.Contains(err.Error(), errUsingV2CatalogExperiment)
}
func IsErrSamenessGroupNotFound(err error) bool {
return err != nil && strings.Contains(err.Error(), errSamenessGroupNotFound)
}
func IsErrSamenessGroupMustBeDefaultForFailover(err error) bool {
return err != nil && strings.Contains(err.Error(), errSamenessGroupMustBeDefaultForFailover)
}

View File

@ -766,14 +766,11 @@ Use these links to navigate to a particular top-level stanza.
- `experiments` ((#v-global-experiments)) (`array<string>: []`) - Consul feature flags that will be enabled across components.
Supported feature flags:
- `v1dns`:
When this flag is set, Consul agents use the legacy DNS implementation.
This setting exists in the case a DNS bug is found after the refactoring introduced in v1.19.0.
Example:
```yaml
experiments: [ "v1dns" ]
experiments: [ "<experiment name>" ]
```
### server ((#h-server))
@ -1787,7 +1784,7 @@ Use these links to navigate to a particular top-level stanza.
or may not be broadly accessible depending on your Kubernetes cluster.
Set this to false to skip syncing ClusterIP services.
- `syncLoadBalancerEndpoints` ((#v-synccatalog-syncloadbalancerendpoints)) (`boolean: false`) - If true, LoadBalancer service endpoints instead of ingress addresses will be synced to Consul.
- `syncLoadBalancerEndpoints` ((#v-synccatalog-syncloadbalancerendpoints)) (`boolean: false`) - If true, LoadBalancer service endpoints instead of ingress addresses will be synced to Consul.
If false, LoadBalancer endpoints are not synced to Consul.
- `ingress` ((#v-synccatalog-ingress))

View File

@ -17,13 +17,13 @@ We are pleased to announce the following Consul updates.
- **API Gateway metrics**: The Consul API Gateway now provides a Prometheus metrics endpoint you can use to gather information about the health of the gateway, as well as traffic for proxied connections or requests.
- **File system certificate configuration entry**: A new [`file-system-certificate` configuration entry](/consul/docs/connect/config-entries/file-system-certificate) supports specifying a filepath to the certificate and private key for Consul API Gateway on VMs on the local system. Previously, the certificate and private key were specified directly in the `inline-certificate` configuration entry. When using the file system certificates, the Consul server never sees the contents of these files.
- **File system certificate configuration entry**: A new [`file-system-certificate` configuration entry](/consul/docs/connect/config-entries/file-system-certificate) supports specifying a filepath to the certificate and private key for Consul API Gateway on VMs on the local system. Previously, the certificate and private key were specified directly in the `inline-certificate` configuration entry. When using the file system certificates, the Consul server never sees the contents of these files.
File system certificates also include a file system watch that allows for changing the certificate and key without restarting the gateway. This feature requires that you have access to the gateways file system in order to place the certificate or update it.
Consul on Kubernetes deployments that use `consul-k8s` Helm chart v1.5.0 or later use file system certificates without additional configuration. For more information, refer to [File system certificate configuration reference](/consul/docs/connect/config-entries/file-system-certificate).
- **Argo Rollouts Plugin**: A new Argo Rollouts plugin for progressive delivery is now available for `consul-k8s`. This plugin supports canary deployments by allowing you to incrementally release and test new versions of applications and roll back to previous versions by splitting traffic between subsets of services. Refer to [Argo Rollouts Progressive Delivery with Consul on Kubernetes](/consul/docs/k8s/deployment-configurations/argo-rollouts-configuration) for more information.
- **Argo Rollouts Plugin**: A new Argo Rollouts plugin for progressive delivery is now available for `consul-k8s`. This plugin supports canary deployments by allowing you to incrementally release and test new versions of applications and roll back to previous versions by splitting traffic between subsets of services. Refer to [Argo Rollouts Progressive Delivery with Consul on Kubernetes](/consul/docs/k8s/deployment-configurations/argo-rollouts-configuration) for more information.
## What's deprecated
@ -37,14 +37,14 @@ For more detailed information, please refer to the [upgrade details page](/consu
The following issues are known to exist in the v1.19.x releases:
- v1.19.0 - There are known issues with the new DNS server implementation.
To revert to the old DNS behavior, set `experiments: [ "v1dns" ]` in the agent configuration.
Fixes will be included in Consul v1.19.1.
- v1.19.0 & v1.19.1 - There are known issues with the DNS server implementation.
To revert to the old DNS behavior on 1.19.0 and 1.19.1, set `experiments: [ "v1dns" ]` in the agent configuration.
In v1.19.2, the modified DNS subsystem will be reverted and the old DNS behavior will be restored resolving these issues.
- DNS SRV records for registrations that specify a service address instead of a node address return identical `.service` hostnames instead of unique `.addr` addresses.
As a result, it is impossible to resolve the individual service addresses.
This bug can affect Nomad installations using Consul for service discovery because the service address field is always specified to Consul.
[[GH-21325](https://github.com/hashicorp/consul/issues/21325)].
- DNS Tags are not resolved when using the Standard query format, `tag.name.service.consul`.
- DNS Tags are not resolved when using the Standard query format, `tag.name.service.consul`.
[[GH-21326](https://github.com/hashicorp/consul/issues/21336)].
## Changelogs