feat(dataplane): allow token and tenancy information for proxied DNS (#20899)

* feat(dataplane): allow token and tenancy information for proxied DNS

* changelog
This commit is contained in:
Dan Stough 2024-04-22 14:30:43 -04:00 committed by GitHub
parent 057ad7e952
commit 03ab7367a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 376 additions and 68 deletions

4
.changelog/20899.txt Normal file
View File

@ -0,0 +1,4 @@
```release-note:improvement
dns: DNS-over-grpc when using `consul-dataplane` now accepts partition, namespace, token as metadata to default those query parameters.
`consul-dataplane` v1.5+ will send this information automatically.
```

View File

@ -116,11 +116,17 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e
// Nodes are not namespaced, so this is a name error
return nil, ErrNotFound
}
cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig)
// If no datacenter is passed, default to our own
datacenter := cfg.Datacenter
if req.Tenancy.Datacenter != "" {
datacenter = req.Tenancy.Datacenter
}
// Make an RPC request
args := &structs.NodeSpecificRequest{
Datacenter: req.Tenancy.Datacenter,
Datacenter: datacenter,
PeerName: req.Tenancy.Peer,
Node: req.Name,
QueryOptions: structs.QueryOptions{
@ -299,9 +305,15 @@ func (f *V1DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result,
func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) {
cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig)
// If no datacenter is passed, default to our own
datacenter := cfg.Datacenter
if req.Tenancy.Datacenter != "" {
datacenter = req.Tenancy.Datacenter
}
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: req.Tenancy.Datacenter,
Datacenter: datacenter,
QueryIDOrName: req.Name,
QueryOptions: structs.QueryOptions{
Token: ctx.Token,
@ -548,7 +560,11 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa
return nil, errors.New("sameness groups are not allowed for service lookups based on tenancy")
}
datacenter := req.Tenancy.Datacenter
// If no datacenter is passed, default to our own
datacenter := cfg.Datacenter
if req.Tenancy.Datacenter != "" {
datacenter = req.Tenancy.Datacenter
}
if req.Tenancy.Peer != "" {
datacenter = ""
}

View File

@ -128,6 +128,7 @@ func Test_FetchVirtualIP(t *testing.T) {
func Test_FetchEndpoints(t *testing.T) {
// set these to confirm that RPC call does not use them for this particular RPC
rc := &config.RuntimeConfig{
Datacenter: "dc2",
DNSAllowStale: true,
DNSMaxStale: 100,
DNSUseCache: true,

View File

@ -347,6 +347,7 @@ func queryTenancyToResourceTenancy(qTenancy QueryTenancy) *pbresource.Tenancy {
rTenancy.Namespace = qTenancy.Namespace
}
// In the case of partition, we have the agent's partition as the fallback.
// That is handled in NormalizeRequest.
if qTenancy.Partition != "" {
rTenancy.Partition = qTenancy.Partition
}

56
agent/dns/context.go Normal file
View File

@ -0,0 +1,56 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"context"
"fmt"
"github.com/mitchellh/mapstructure"
"google.golang.org/grpc/metadata"
)
// Context is used augment a DNS message with Consul-specific metadata.
type Context struct {
Token string `mapstructure:"x-consul-token,omitempty"`
DefaultNamespace string `mapstructure:"x-consul-namespace,omitempty"`
DefaultPartition string `mapstructure:"x-consul-partition,omitempty"`
}
// NewContextFromGRPCContext returns the request context using the gRPC metadata attached to the
// given context. If there is no gRPC metadata, it returns an empty context.
func NewContextFromGRPCContext(ctx context.Context) (Context, error) {
if ctx == nil {
return Context{}, nil
}
reqCtx := Context{}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return reqCtx, nil
}
m := map[string]string{}
for k, v := range md {
m[k] = v[0]
}
decoderConfig := &mapstructure.DecoderConfig{
Metadata: nil,
Result: &reqCtx,
WeaklyTypedInput: true,
}
decoder, err := mapstructure.NewDecoder(decoderConfig)
if err != nil {
return Context{}, fmt.Errorf("error creating mapstructure decoder: %w", err)
}
err = decoder.Decode(m)
if err != nil {
return Context{}, fmt.Errorf("error decoding metadata: %w", err)
}
return reqCtx, nil
}

70
agent/dns/context_test.go Normal file
View File

@ -0,0 +1,70 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
)
func TestNewContextFromGRPCContext(t *testing.T) {
t.Parallel()
md := metadata.MD{}
testMeta := map[string]string{
"x-consul-token": "test-token",
"x-consul-namespace": "test-namespace",
"x-consul-partition": "test-partition",
}
for k, v := range testMeta {
md.Set(k, v)
}
testGRPCContext := metadata.NewIncomingContext(context.Background(), md)
testCases := []struct {
name string
grpcCtx context.Context
expected *Context
error error
}{
{
name: "nil grpc context",
grpcCtx: nil,
expected: &Context{},
},
{
name: "grpc context w/o metadata",
grpcCtx: context.Background(),
expected: &Context{},
},
{
name: "grpc context w/ kitchen sink",
grpcCtx: testGRPCContext,
expected: &Context{
Token: "test-token",
DefaultNamespace: "test-namespace",
DefaultPartition: "test-partition",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx, err := NewContextFromGRPCContext(tc.grpcCtx)
if tc.error != nil {
require.Error(t, err)
require.Equal(t, Context{}, &ctx)
require.Equal(t, tc.error, err)
return
}
require.NotNil(t, ctx)
require.Equal(t, tc.expected, &ctx)
})
}
}

View File

@ -206,20 +206,24 @@ func getQueryTenancy(reqCtx Context, queryType discovery.QueryType, querySuffixe
return discovery.QueryTenancy{}, errNameNotFound
}
// If we don't have an explicit partition in the request, try the first fallback
// If we don't have an explicit partition/ns in the request, try the first fallback
// which was supplied in the request context. The agent's partition will be used as the last fallback
// later in the query processor.
if labels.Partition == "" {
labels.Partition = reqCtx.DefaultPartition
}
if labels.Namespace == "" {
labels.Namespace = reqCtx.DefaultNamespace
}
// If we have a sameness group, we can return early without further data massage.
if labels.SamenessGroup != "" {
return discovery.QueryTenancy{
Namespace: labels.Namespace,
Partition: labels.Partition,
SamenessGroup: labels.SamenessGroup,
Datacenter: reqCtx.DefaultDatacenter,
// Datacenter is not supported
}, nil
}
@ -234,19 +238,19 @@ func getQueryTenancy(reqCtx Context, queryType discovery.QueryType, querySuffixe
Namespace: labels.Namespace,
Partition: labels.Partition,
Peer: labels.Peer,
Datacenter: getEffectiveDatacenter(labels, reqCtx.DefaultDatacenter),
Datacenter: getEffectiveDatacenter(labels),
}, nil
}
// getEffectiveDatacenter returns the effective datacenter from the parsed labels.
func getEffectiveDatacenter(labels *parsedLabels, defaultDC string) string {
func getEffectiveDatacenter(labels *parsedLabels) string {
switch {
case labels.Datacenter != "":
return labels.Datacenter
case labels.PeerOrDatacenter != "" && labels.Peer != labels.PeerOrDatacenter:
return labels.PeerOrDatacenter
}
return defaultDC
return ""
}
// getQueryTypePartsAndSuffixesFromDNSMessage returns the query type, the parts, and suffixes of the query name.

View File

@ -160,8 +160,7 @@ func Test_buildQueryFromDNSMessage(t *testing.T) {
},
},
requestContext: &Context{
DefaultDatacenter: "default-dc",
DefaultPartition: "default-partition",
DefaultPartition: "default-partition",
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeWorkload,
@ -169,10 +168,9 @@ func Test_buildQueryFromDNSMessage(t *testing.T) {
Name: "foo",
PortName: "api",
Tenancy: discovery.QueryTenancy{
Namespace: "banana",
Partition: "orange",
Peer: "apple",
Datacenter: "default-dc",
Namespace: "banana",
Partition: "orange",
Peer: "apple",
},
},
},
@ -192,8 +190,7 @@ func Test_buildQueryFromDNSMessage(t *testing.T) {
},
},
requestContext: &Context{
DefaultDatacenter: "default-dc",
DefaultPartition: "default-partition",
DefaultPartition: "default-partition",
},
expectedQuery: &discovery.Query{
QueryType: discovery.QueryTypeService,
@ -203,7 +200,6 @@ func Test_buildQueryFromDNSMessage(t *testing.T) {
Namespace: "banana",
Partition: "orange",
SamenessGroup: "apple",
Datacenter: "default-dc",
},
},
},

View File

@ -14,9 +14,8 @@ import (
"github.com/armon/go-metrics"
"github.com/armon/go-radix"
"github.com/miekg/dns"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
@ -46,13 +45,6 @@ var (
trailingSpacesRE = regexp.MustCompile(" +$")
)
// Context is used augment a DNS message with Consul-specific metadata.
type Context struct {
Token string
DefaultPartition string
DefaultDatacenter string
}
// RouterDynamicConfig is the dynamic configuration that can be hot-reloaded
type RouterDynamicConfig struct {
ARecordLimit int
@ -114,13 +106,12 @@ type dnsRecursor interface {
// Router replaces miekg/dns.ServeMux with a simpler router that only checks for the 2-3 valid domains
// that Consul supports and forwards to a single DiscoveryQueryProcessor handler. If there is no match, it will recurse.
type Router struct {
processor DiscoveryQueryProcessor
recursor dnsRecursor
domain string
altDomain string
datacenter string
nodeName string
logger hclog.Logger
processor DiscoveryQueryProcessor
recursor dnsRecursor
domain string
altDomain string
nodeName string
logger hclog.Logger
tokenFunc func() string
translateAddressFunc func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string
@ -146,7 +137,6 @@ func NewRouter(cfg Config) (*Router, error) {
recursor: newRecursor(logger),
domain: domain,
altDomain: altDomain,
datacenter: cfg.AgentConfig.Datacenter,
logger: logger,
nodeName: cfg.AgentConfig.NodeName,
tokenFunc: cfg.TokenFunc,
@ -176,6 +166,7 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.A
}
r.logger.Trace("received request", "question", req.Question[0].Name, "type", dns.Type(req.Question[0].Qtype).String())
r.normalizeContext(&reqCtx)
defer func(s time.Time, q dns.Question) {
metrics.MeasureSinceWithLabels([]string{"dns", "query"}, s,
@ -319,10 +310,9 @@ func (r *Router) trimDomain(questionName string) string {
}
// ServeDNS implements the miekg/dns.Handler interface.
// This is a standard DNS listener, so we inject a default request context based on the agent's config.
// This is a standard DNS listener.
func (r *Router) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
reqCtx := r.defaultAgentDNSRequestContext()
out := r.HandleRequest(req, reqCtx, w.RemoteAddr())
out := r.HandleRequest(req, Context{}, w.RemoteAddr())
w.WriteMsg(out)
}
@ -420,16 +410,6 @@ func (r *Router) GetConfig() *RouterDynamicConfig {
return r.dynamicConfig.Load().(*RouterDynamicConfig)
}
// defaultAgentDNSRequestContext returns a default request context based on the agent's config.
func (r *Router) defaultAgentDNSRequestContext() Context {
return Context{
Token: r.tokenFunc(),
DefaultDatacenter: r.datacenter,
// We don't need to specify the agent's partition here because that will be handled further down the stack
// in the query processor.
}
}
// getErrorFromECSNotGlobalError returns the underlying error from an ECSNotGlobalError, if it exists.
func getErrorFromECSNotGlobalError(err error) error {
if errors.Is(err, discovery.ErrECSNotGlobal) {
@ -471,6 +451,16 @@ func validateAndNormalizeRequest(req *dns.Msg) error {
return nil
}
// normalizeContext makes sure context information is populated with agent defaults as needed.
// Right now this is just the ACL token. We do this in the router with the token because DNS doesn't
// allow a token to be passed in the request, and we expect ACL tokens upfront in APIs when they are enabled.
// Tenancy information is left out because it is safe/expected to assume agent defaults in the backend lookup.
func (r *Router) normalizeContext(ctx *Context) {
if ctx.Token == "" {
ctx.Token = r.tokenFunc()
}
}
// stripAnyFailoverSuffix strips off the suffixes that may have been added to the request name.
func stripAnyFailoverSuffix(target string) (string, bool) {
enableFailover := false

View File

@ -52,22 +52,6 @@ type HandleTestCase struct {
response *dns.Msg
}
var testSOA = &dns.SOA{
Hdr: dns.RR_Header{
Name: "consul.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 4,
},
Ns: "ns.consul.",
Mbox: "hostmaster.consul.",
Serial: uint32(time.Now().Unix()),
Refresh: 1,
Retry: 2,
Expire: 3,
Minttl: 4,
}
func Test_HandleRequest_Validation(t *testing.T) {
testCases := []HandleTestCase{
{
@ -93,6 +77,157 @@ func Test_HandleRequest_Validation(t *testing.T) {
Extra: nil,
},
},
// Context Tests
{
name: "When a request context is provided, use those field in the query",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
requestContext: &Context{
Token: "test-token",
DefaultNamespace: "test-namespace",
DefaultPartition: "test-partition",
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := []*discovery.Result{
{
Type: discovery.ResultTypeNode,
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Tenancy: discovery.ResultTenancy{
Namespace: "test-namespace",
Partition: "test-partition",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(result, nil).
Run(func(args mock.Arguments) {
ctx := args.Get(0).(discovery.Context)
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, "test-token", ctx.Token)
require.Equal(t, "foo", req.Name)
require.Equal(t, "test-namespace", req.Tenancy.Namespace)
require.Equal(t, "test-partition", req.Tenancy.Partition)
require.Equal(t, discovery.LookupTypeService, reqType)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.service.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.service.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
{
name: "When a request context is provided, values do not override explicit tenancy",
request: &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
},
Question: []dns.Question{
{
Name: "foo.service.bar.ns.baz.ap.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
},
requestContext: &Context{
Token: "test-token",
DefaultNamespace: "test-namespace",
DefaultPartition: "test-partition",
},
configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) {
result := []*discovery.Result{
{
Type: discovery.ResultTypeNode,
Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"},
Tenancy: discovery.ResultTenancy{
Namespace: "bar",
Partition: "baz",
},
},
}
fetcher.(*discovery.MockCatalogDataFetcher).
On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).
Return(result, nil).
Run(func(args mock.Arguments) {
ctx := args.Get(0).(discovery.Context)
req := args.Get(1).(*discovery.QueryPayload)
reqType := args.Get(2).(discovery.LookupType)
require.Equal(t, "test-token", ctx.Token)
require.Equal(t, "foo", req.Name)
require.Equal(t, "bar", req.Tenancy.Namespace)
require.Equal(t, "baz", req.Tenancy.Partition)
require.Equal(t, discovery.LookupTypeService, reqType)
})
},
validateAndNormalizeExpected: true,
response: &dns.Msg{
MsgHdr: dns.MsgHdr{
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "foo.service.bar.ns.baz.ap.consul.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "foo.service.bar.ns.baz.ap.consul.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 123,
},
A: net.ParseIP("1.2.3.4"),
},
},
},
},
}
for _, tc := range testCases {

View File

@ -72,9 +72,10 @@ func (s *ServerV2) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.Q
return nil, status.Error(codes.Internal, fmt.Sprintf("failure decoding dns request: %s", err.Error()))
}
// TODO (v2-dns): parse token and other context metadata from the grpc request/metadata (NET-7885)
reqCtx := agentdns.Context{
Token: s.TokenFunc(),
reqCtx, err := agentdns.NewContextFromGRPCContext(ctx)
if err != nil {
s.Logger.Error("error parsing DNS context from grpc metadata", "err", err)
return nil, status.Error(codes.Internal, fmt.Sprintf("error parsing DNS context from grpc metadata: %s", err.Error()))
}
resp := s.DNSRouter.HandleRequest(msg, reqCtx, remote)

View File

@ -10,6 +10,8 @@ import (
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
agentdns "github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/proto-public/pbdns"
@ -50,6 +52,7 @@ func (s *DNSTestSuite) TestProxy_V2Success() {
question string
configureRouter func(router *agentdns.MockDNSRouter)
clientQuery func(qR *pbdns.QueryRequest)
metadata map[string]string
expectedErr error
}{
@ -73,6 +76,28 @@ func (s *DNSTestSuite) TestProxy_V2Success() {
qR.Protocol = pbdns.Protocol_PROTOCOL_TCP
},
},
"happy path with context variables set": {
question: "abc.com.",
configureRouter: func(router *agentdns.MockDNSRouter) {
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
Run(func(args mock.Arguments) {
ctx, ok := args.Get(1).(agentdns.Context)
require.True(s.T(), ok, "error casting to agentdns.Context")
require.Equal(s.T(), "test-token", ctx.Token, "token not set in context")
require.Equal(s.T(), "test-namespace", ctx.DefaultNamespace, "namespace not set in context")
require.Equal(s.T(), "test-partition", ctx.DefaultPartition, "partition not set in context")
}).
Return(basicResponse(), nil)
},
clientQuery: func(qR *pbdns.QueryRequest) {
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
},
metadata: map[string]string{
"x-consul-token": "test-token",
"x-consul-namespace": "test-namespace",
"x-consul-partition": "test-partition",
},
},
"No protocol set": {
question: "abc.com.",
clientQuery: func(qR *pbdns.QueryRequest) {},
@ -108,9 +133,18 @@ func (s *DNSTestSuite) TestProxy_V2Success() {
bytes, _ := req.Pack()
ctx := context.Background()
if len(tc.metadata) > 0 {
md := metadata.MD{}
for k, v := range tc.metadata {
md.Set(k, v)
}
ctx = metadata.NewOutgoingContext(ctx, md)
}
clientReq := &pbdns.QueryRequest{Msg: bytes}
tc.clientQuery(clientReq)
clientResp, err := client.Query(context.Background(), clientReq)
clientResp, err := client.Query(ctx, clientReq)
if tc.expectedErr != nil {
s.Require().Error(err, "no errror calling gRPC endpoint")
s.Require().ErrorContains(err, tc.expectedErr.Error())