feat(v2dns): catalog v2 service query support (#20564)

This commit is contained in:
Dan Stough 2024-02-09 17:41:40 -05:00 committed by GitHub
parent e24b73a6dd
commit 01001f630e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1357 additions and 292 deletions

View File

@ -105,13 +105,16 @@ const (
// It is the responsibility of the DNS encoder to know what to do with
// each Result, based on the query type.
type Result struct {
Service *Location // The name and address of the service.
Node *Location // The name and address of the node.
PortName string // Used to generate a fgdn when a specifc port was queried
PortNumber uint32 // SRV queries
Metadata map[string]string // Used to collect metadata into TXT Records
Type ResultType // Used to reconstruct the fqdn name of the resource
DNS DNSConfig // Used for DNS-specific configuration for this result
Service *Location // The name and address of the service.
Node *Location // The name and address of the node.
Weight uint32 // SRV queries
Metadata map[string]string // Used to collect metadata into TXT Records
Type ResultType // Used to reconstruct the fqdn name of the resource
DNS DNSConfig // Used for DNS-specific configuration for this result
// Ports include anything the node/service/workload implements. These are filtered if requested by the client.
// They are used in to generate the FQDN and SRV port numbers in V2 Catalog responses.
Ports []Port
Tenancy ResultTenancy
}
@ -127,6 +130,11 @@ type DNSConfig struct {
Weight uint32 // SRV queries
}
type Port struct {
Name string
Number uint32
}
// ResultTenancy is used to reconstruct the fqdn name of the resource.
type ResultTenancy struct {
Namespace string

View File

@ -422,8 +422,10 @@ func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServic
TTL: ttlOverride,
Weight: uint32(findWeight(n)),
},
PortNumber: uint32(f.translateServicePortFunc(n.Node.Datacenter, n.Service.Port, n.Service.TaggedAddresses)),
Metadata: n.Node.Meta,
Ports: []Port{
{Number: uint32(f.translateServicePortFunc(n.Node.Datacenter, n.Service.Port, n.Service.TaggedAddresses))},
},
Metadata: n.Node.Meta,
Tenancy: ResultTenancy{
Namespace: n.Service.NamespaceOrEmpty(),
Partition: n.Service.PartitionOrEmpty(),

View File

@ -151,6 +151,11 @@ func Test_FetchEndpoints(t *testing.T) {
DNS: DNSConfig{
Weight: 1,
},
Ports: []Port{
{
Number: 0,
},
},
},
}

View File

@ -11,6 +11,7 @@ import (
"strings"
"sync/atomic"
"golang.org/x/exp/slices"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
@ -73,11 +74,13 @@ func (f *V2DataFetcher) FetchEndpoints(reqContext Context, req *QueryPayload, lo
configCtx := f.dynamicConfig.Load().(*v2DataFetcherDynamicConfig)
serviceEndpoints := pbcatalog.ServiceEndpoints{}
resourceObj, err := f.fetchResource(reqContext, *req, pbcatalog.ServiceEndpointsType, &serviceEndpoints)
serviceEndpointsResource, err := f.fetchResource(reqContext, *req, pbcatalog.ServiceEndpointsType, &serviceEndpoints)
if err != nil {
return nil, err
}
f.logger.Trace("shuffling endpoints", "name", req.Name, "endpoints", len(serviceEndpoints.Endpoints))
// Shuffle the endpoints slice
shuffleFunc := func(i, j int) {
serviceEndpoints.Endpoints[i], serviceEndpoints.Endpoints[j] = serviceEndpoints.Endpoints[j], serviceEndpoints.Endpoints[i]
@ -91,10 +94,15 @@ func (f *V2DataFetcher) FetchEndpoints(reqContext Context, req *QueryPayload, lo
}
results := make([]*Result, 0, limit)
for idx := 0; idx < limit; idx++ {
endpoint := serviceEndpoints.Endpoints[idx]
for _, endpoint := range serviceEndpoints.Endpoints[:limit] {
// TODO (v2-dns): filter based on the port name requested
// First we check the endpoint first to make sure that the requested port is matched from the service.
// We error here because we expect all endpoints to have the same ports as the service.
ports := getResultPorts(req, endpoint.Ports) //assuming the logic changed in getResultPorts
if len(ports) == 0 {
f.logger.Debug("could not find matching port in endpoint", "name", req.Name, "port", req.PortName)
return nil, ErrNotFound
}
address, err := f.addressFromWorkloadAddresses(endpoint.Addresses, req.Name)
if err != nil {
@ -103,6 +111,7 @@ func (f *V2DataFetcher) FetchEndpoints(reqContext Context, req *QueryPayload, lo
weight, ok := getEndpointWeight(endpoint, configCtx)
if !ok {
f.logger.Debug("endpoint filtered out because of health status", "name", req.Name, "endpoint", endpoint.GetTargetRef().GetName())
continue
}
@ -111,14 +120,15 @@ func (f *V2DataFetcher) FetchEndpoints(reqContext Context, req *QueryPayload, lo
Address: address,
Name: endpoint.GetTargetRef().GetName(),
},
Type: ResultTypeWorkload, // TODO (v2-dns): I'm not really sure if it's better to have SERVICE OR WORKLOAD here
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resourceObj.GetId().GetTenancy().GetNamespace(),
Partition: resourceObj.GetId().GetTenancy().GetPartition(),
Namespace: serviceEndpointsResource.GetId().GetTenancy().GetNamespace(),
Partition: serviceEndpointsResource.GetId().GetTenancy().GetPartition(),
},
DNS: DNSConfig{
Weight: weight,
},
Ports: ports,
}
results = append(results, result)
}
@ -145,6 +155,14 @@ func (f *V2DataFetcher) FetchWorkload(reqContext Context, req *QueryPayload) (*R
return nil, err
}
// First we check the endpoint first to make sure that the requested port is matched from the service.
// We error here because we expect all endpoints to have the same ports as the service.
ports := getResultPorts(req, workload.Ports) //assuming the logic changed in getResultPorts
if ports == nil || len(ports) == 0 {
f.logger.Debug("could not find matching port in endpoint", "name", req.Name, "port", req.PortName)
return nil, ErrNotFound
}
address, err := f.addressFromWorkloadAddresses(workload.Addresses, req.Name)
if err != nil {
return nil, err
@ -161,24 +179,10 @@ func (f *V2DataFetcher) FetchWorkload(reqContext Context, req *QueryPayload) (*R
Namespace: tenancy.GetNamespace(),
Partition: tenancy.GetPartition(),
},
Ports: ports,
}
if req.PortName == "" {
return result, nil
}
// If a port is specified, make sure the workload implements that port name.
for name, port := range workload.Ports {
if name == req.PortName {
result.PortName = req.PortName
result.PortNumber = port.Port
return result, nil
}
}
f.logger.Debug("could not find matching port for workload", "name", req.Name, "port", req.PortName)
// Return an ErrNotFound, which is equivalent to NXDOMAIN
return nil, ErrNotFound
return result, nil
}
// FetchPreparedQuery is used to fetch a prepared query from the V2 catalog.
@ -285,6 +289,46 @@ func getEndpointWeight(endpoint *pbcatalog.Endpoint, configCtx *v2DataFetcherDyn
return weight, true
}
// getResultPorts conditionally returns ports from a map based on a query. The results are sorted by name.
func getResultPorts(req *QueryPayload, workloadPorts map[string]*pbcatalog.WorkloadPort) []Port {
if len(workloadPorts) == 0 {
return nil
}
var ports []Port
if req.PortName != "" {
// Make sure the workload implements that port name.
if _, ok := workloadPorts[req.PortName]; !ok {
return nil
}
// In the case that the query asked for a specific port, we only return that port.
ports = []Port{
{
Name: req.PortName,
Number: workloadPorts[req.PortName].Port,
},
}
} else {
// If the client didn't specify a particular port, return all the workload ports.
for name, port := range workloadPorts {
ports = append(ports, Port{
Name: name,
Number: port.Port,
})
}
// Stable Sort
slices.SortStableFunc(ports, func(i, j Port) int {
if i.Name < j.Name {
return -1
} else if i.Name > j.Name {
return 1
}
return 0
})
}
return ports
}
// queryTenancyToResourceTenancy converts a QueryTenancy to a pbresource.Tenancy.
func queryTenancyToResourceTenancy(qTenancy QueryTenancy) *pbresource.Tenancy {
rTenancy := resource.DefaultNamespacedTenancy()

View File

@ -5,6 +5,7 @@ package discovery
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/mock"
@ -50,7 +51,7 @@ func Test_FetchWorkload(t *testing.T) {
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "", "")
result := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
@ -62,6 +63,16 @@ func Test_FetchWorkload(t *testing.T) {
expectedResult: &Result{
Node: &Location{Name: "foo-1234", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
@ -78,7 +89,7 @@ func Test_FetchWorkload(t *testing.T) {
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
input := getTestWorkloadResponse(t, "", "")
input := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(nil, status.Error(codes.NotFound, "not found")).
Once().
@ -99,7 +110,7 @@ func Test_FetchWorkload(t *testing.T) {
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
input := getTestWorkloadResponse(t, "", "")
input := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(nil, unknownErr).
Once().
@ -121,7 +132,7 @@ func Test_FetchWorkload(t *testing.T) {
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "", "")
result := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
@ -131,10 +142,14 @@ func Test_FetchWorkload(t *testing.T) {
})
},
expectedResult: &Result{
Node: &Location{Name: "foo-1234", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
PortName: "api",
PortNumber: 5678,
Node: &Location{Name: "foo-1234", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
},
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
@ -152,7 +167,7 @@ func Test_FetchWorkload(t *testing.T) {
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "", "")
result := getTestWorkloadResponse(t, "foo-1234", "", "")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
@ -177,7 +192,7 @@ func Test_FetchWorkload(t *testing.T) {
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
result := getTestWorkloadResponse(t, "test-namespace", "test-partition")
result := getTestWorkloadResponse(t, "foo-1234", "test-namespace", "test-partition")
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Once().
@ -191,6 +206,16 @@ func Test_FetchWorkload(t *testing.T) {
expectedResult: &Result{
Node: &Location{Name: "foo-1234", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
Tenancy: ResultTenancy{
Namespace: "test-namespace",
Partition: "test-partition",
@ -240,23 +265,33 @@ func Test_V2FetchEndpoints(t *testing.T) {
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
results := []*pbcatalog.Endpoint{
endpoints := []*pbcatalog.Endpoint{
makeEndpoint("consul-1", "1.2.3.4", pbcatalog.Health_HEALTH_PASSING, 0, 0),
}
result := getTestEndpointsResponse(t, "", "", results...)
serviceEndpoints := getTestEndpointsResponse(t, "", "", endpoints...)
mockClient.Read(mock.Anything, mock.Anything).
Return(result, nil).
Return(serviceEndpoints, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
require.Equal(t, serviceEndpoints.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: []*Result{
{
Node: &Location{Name: "consul-1", Address: "1.2.3.4"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
@ -365,6 +400,16 @@ func Test_V2FetchEndpoints(t *testing.T) {
DNS: DNSConfig{
Weight: 2,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
{
Node: &Location{Name: "consul-2", Address: "2.3.4.5"},
@ -376,6 +421,16 @@ func Test_V2FetchEndpoints(t *testing.T) {
DNS: DNSConfig{
Weight: 3,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
},
},
@ -417,6 +472,16 @@ func Test_V2FetchEndpoints(t *testing.T) {
DNS: DNSConfig{
Weight: 2,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
},
},
@ -452,118 +517,37 @@ func Test_V2FetchEndpoints(t *testing.T) {
require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: []*Result{
{
Node: &Location{Name: "consul-1", Address: "10.0.0.1"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
{
Node: &Location{Name: "consul-2", Address: "10.0.0.2"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
{
Node: &Location{Name: "consul-3", Address: "10.0.0.3"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
{
Node: &Location{Name: "consul-4", Address: "10.0.0.4"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
{
Node: &Location{Name: "consul-5", Address: "10.0.0.5"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
{
Node: &Location{Name: "consul-6", Address: "10.0.0.6"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
{
Node: &Location{Name: "consul-7", Address: "10.0.0.7"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
{
Node: &Location{Name: "consul-8", Address: "10.0.0.8"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
{
Node: &Location{Name: "consul-9", Address: "10.0.0.9"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
{
Node: &Location{Name: "consul-10", Address: "10.0.0.10"},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
},
expectedResult: func() []*Result {
results := make([]*Result, 0, 10)
for i := 0; i < 10; i++ {
name := fmt.Sprintf("consul-%d", i+1)
address := fmt.Sprintf("10.0.0.%d", i+1)
result := &Result{
Node: &Location{Name: name, Address: address},
Type: ResultTypeWorkload,
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
DNS: DNSConfig{
Weight: 1,
},
}
results = append(results, result)
}
return results
}(),
verifyShuffle: true,
},
{
@ -603,6 +587,16 @@ func Test_V2FetchEndpoints(t *testing.T) {
DNS: DNSConfig{
Weight: 1,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
},
},
@ -646,9 +640,88 @@ func Test_V2FetchEndpoints(t *testing.T) {
DNS: DNSConfig{
Weight: 1,
},
Ports: []Port{
{
Name: "api",
Number: 5678,
},
{
Name: "mesh",
Number: 21000,
},
},
},
},
},
{
name: "FetchEndpoints returns only a specific port if is one requested",
queryPayload: &QueryPayload{
Name: "consul",
PortName: "api",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
endpoints := []*pbcatalog.Endpoint{
makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0),
}
serviceEndpoints := getTestEndpointsResponse(t, "", "", endpoints...)
mockClient.Read(mock.Anything, mock.Anything).
Return(serviceEndpoints, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, serviceEndpoints.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedResult: []*Result{
{
Node: &Location{Name: "consul-1", Address: "10.0.0.1"},
Type: ResultTypeWorkload,
Ports: []Port{
{
Name: "api",
Number: 5678,
},
// No mesh port this time
},
Tenancy: ResultTenancy{
Namespace: resource.DefaultNamespaceName,
Partition: resource.DefaultPartitionName,
},
DNS: DNSConfig{
Weight: 1,
},
},
},
},
{
name: "FetchEndpoints returns a name error when a service doesn't implement the requested port",
queryPayload: &QueryPayload{
Name: "consul",
PortName: "banana",
},
context: Context{
Token: "test-token",
},
configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) {
endpoints := []*pbcatalog.Endpoint{
makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0),
}
serviceEndpoints := getTestEndpointsResponse(t, "", "", endpoints...)
mockClient.Read(mock.Anything, mock.Anything).
Return(serviceEndpoints, nil).
Once().
Run(func(args mock.Arguments) {
req := args.Get(1).(*pbresource.ReadRequest)
require.Equal(t, serviceEndpoints.GetResource().GetId().GetName(), req.Id.Name)
})
},
expectedErr: ErrNotFound,
},
}
for _, tc := range tests {
@ -679,18 +752,21 @@ func Test_V2FetchEndpoints(t *testing.T) {
}
}
func getTestWorkloadResponse(t *testing.T, nsOverride string, partitionOverride string) *pbresource.ReadResponse {
func getTestWorkloadResponse(t *testing.T, name string, nsOverride string, partitionOverride string) *pbresource.ReadResponse {
workload := &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{
Host: "1.2.3.4",
Ports: []string{"api"},
Ports: []string{"api", "mesh"},
},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"api": {
Port: 5678,
},
"mesh": {
Port: 21000,
},
},
Identity: "test-identity",
}
@ -701,7 +777,7 @@ func getTestWorkloadResponse(t *testing.T, nsOverride string, partitionOverride
resp := &pbresource.ReadResponse{
Resource: &pbresource.Resource{
Id: &pbresource.ID{
Name: "foo-1234",
Name: name,
Type: pbcatalog.WorkloadType,
Tenancy: resource.DefaultNamespacedTenancy(),
},
@ -723,7 +799,16 @@ func makeEndpoint(name string, address string, health pbcatalog.Health, weightPa
endpoint := &pbcatalog.Endpoint{
Addresses: []*pbcatalog.WorkloadAddress{
{
Host: address,
Host: address,
Ports: []string{"api"},
},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"api": {
Port: 5678,
},
"mesh": {
Port: 21000,
},
},
HealthStatus: health,

View File

@ -161,6 +161,8 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context,
remoteAddress net.Addr, maxRecursionLevel int) *dns.Msg {
configCtx := r.dynamicConfig.Load().(*RouterDynamicConfig)
r.logger.Trace("received request", "question", req.Question[0].Name, "type", dns.Type(req.Question[0].Qtype).String())
err := validateAndNormalizeRequest(req)
if err != nil {
r.logger.Error("error parsing DNS query", "error", err)
@ -177,6 +179,8 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context,
}
if needRecurse {
r.logger.Trace("checking recursors to handle request", "question", req.Question[0].Name, "type", dns.Type(req.Question[0].Qtype).String())
// This assumes `canRecurse(configCtx)` is true above
resp, err := r.recursor.handle(req, configCtx, remoteAddress)
if err != nil && !errors.Is(err, errRecursionFailed) {
@ -207,6 +211,8 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context,
isECSGlobal, query, canRecurse(configCtx))
}
r.logger.Trace("serializing results", "question", req.Question[0].Name, "results-found", len(results))
// This needs the question information because it affects the serialization format.
// e.g., the Consul service has the same "results" for both NS and A/AAAA queries, but the serialization differs.
resp, err := r.serializeQueryResults(req, reqCtx, query, results, configCtx, responseDomain, remoteAddress, maxRecursionLevel)
@ -468,6 +474,14 @@ func parseRequestType(req *dns.Msg) requestType {
}
}
func getPortsFromResult(result *discovery.Result) []discovery.Port {
if len(result.Ports) > 0 {
return result.Ports
}
// return one record.
return []discovery.Port{{}}
}
// serializeQueryResults converts a discovery.Result into a DNS message.
func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx Context,
query *discovery.Query, results []*discovery.Result, cfg *RouterDynamicConfig,
@ -486,39 +500,54 @@ func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx Context,
case qType == dns.TypeSOA:
resp.Answer = append(resp.Answer, makeSOARecord(responseDomain, cfg))
for _, result := range results {
ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
for _, port := range getPortsFromResult(result) {
ans, ex, ns := r.getAnswerExtraAndNs(result, port, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
}
case reqType == requestTypeAddress:
for _, result := range results {
ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
for _, port := range getPortsFromResult(result) {
ans, ex, ns := r.getAnswerExtraAndNs(result, port, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
}
case qType == dns.TypeSRV:
handled := make(map[string]struct{})
for _, result := range results {
// Avoid duplicate entries, possible if a node has
// the same service the same port, etc.
for _, port := range getPortsFromResult(result) {
// The datacenter should be empty during translation if it is a peering lookup.
// This should be fine because we should always prefer the WAN address.
//serviceAddress := d.agent.TranslateServiceAddress(lookup.Datacenter, node.Service.Address, node.Service.TaggedAddresses, TranslateAddressAcceptAny)
//servicePort := d.agent.TranslateServicePort(lookup.Datacenter, node.Service.Port, node.Service.TaggedAddresses)
//tuple := fmt.Sprintf("%s:%s:%d", node.Node.Node, serviceAddress, servicePort)
// Avoid duplicate entries, possible if a node has
// the same service the same port, etc.
tuple := fmt.Sprintf("%s:%s:%d", result.Node.Name, result.Service.Address, result.PortNumber)
if _, ok := handled[tuple]; ok {
continue
// The datacenter should be empty during translation if it is a peering lookup.
// This should be fine because we should always prefer the WAN address.
//serviceAddress := d.agent.TranslateServiceAddress(lookup.Datacenter, node.Service.Address, node.Service.TaggedAddresses, TranslateAddressAcceptAny)
//servicePort := d.agent.TranslateServicePort(lookup.Datacenter, node.Service.Port, node.Service.TaggedAddresses)
//tuple := fmt.Sprintf("%s:%s:%d", node.Node.Node, serviceAddress, servicePort)
// TODO (v2-dns): this needs a clean up so we're not assuming this everywhere.
address := ""
if result.Service != nil {
address = result.Service.Address
} else {
address = result.Node.Address
}
tuple := fmt.Sprintf("%s:%s:%d", result.Node.Name, address, port.Number)
if _, ok := handled[tuple]; ok {
continue
}
handled[tuple] = struct{}{}
ans, ex, ns := r.getAnswerExtraAndNs(result, port, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
handled[tuple] = struct{}{}
ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
default:
// default will send it to where it does some de-duping while it calls getAnswerExtraAndNs and recurses.
@ -548,43 +577,45 @@ func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx Context,
count := 0
for _, result := range results {
// Add the node record
had_answer := false
ans, extra, _ := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Extra = append(resp.Extra, extra...)
for _, port := range getPortsFromResult(result) {
if len(ans) == 0 {
continue
}
// Add the node record
had_answer := false
ans, extra, _ := r.getAnswerExtraAndNs(result, port, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Extra = append(resp.Extra, extra...)
// Avoid duplicate entries, possible if a node has
// the same service on multiple ports, etc.
if _, ok := handled[ans[0].String()]; ok {
continue
}
handled[ans[0].String()] = struct{}{}
switch ans[0].(type) {
case *dns.CNAME:
// keep track of the first CNAME + associated RRs but don't add to the resp.Answer yet
// this will only be added if no non-CNAME RRs are found
if len(answerCNAME) == 0 {
answerCNAME = ans
if len(ans) == 0 {
continue
}
default:
resp.Answer = append(resp.Answer, ans...)
had_answer = true
}
if had_answer {
count++
if count == cfg.ARecordLimit {
// We stop only if greater than 0 or we reached the limit
return
// Avoid duplicate entries, possible if a node has
// the same service on multiple ports, etc.
if _, ok := handled[ans[0].String()]; ok {
continue
}
handled[ans[0].String()] = struct{}{}
switch ans[0].(type) {
case *dns.CNAME:
// keep track of the first CNAME + associated RRs but don't add to the resp.Answer yet
// this will only be added if no non-CNAME RRs are found
if len(answerCNAME) == 0 {
answerCNAME = ans
}
default:
resp.Answer = append(resp.Answer, ans...)
had_answer = true
}
if had_answer {
count++
if count == cfg.ARecordLimit {
// We stop only if greater than 0 or we reached the limit
return
}
}
}
}
if len(resp.Answer) == 0 && len(answerCNAME) > 0 {
resp.Answer = answerCNAME
}
@ -872,9 +903,10 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) {
}
// getAnswerAndExtra creates the dns answer and extra from discovery results.
func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, reqCtx Context,
func (r *Router) getAnswerExtraAndNs(result *discovery.Result, port discovery.Port, req *dns.Msg, reqCtx Context,
query *discovery.Query, cfg *RouterDynamicConfig, domain string, remoteAddress net.Addr,
maxRecursionLevel int) (answer []dns.RR, extra []dns.RR, ns []dns.RR) {
serviceAddress := newDNSAddress("")
if result.Service != nil {
serviceAddress = newDNSAddress(result.Service.Address)
@ -908,7 +940,7 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req
ptr := &dns.PTR{
Hdr: dns.RR_Header{Name: qName, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0},
Ptr: canonicalNameForResult(result.Type, ptrTarget, domain, result.Tenancy, result.PortName),
Ptr: canonicalNameForResult(result.Type, ptrTarget, domain, result.Tenancy, port.Name),
}
answer = append(answer, ptr)
case qType == dns.TypeNS:
@ -918,7 +950,7 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req
if parseRequestType(req) == requestTypeConsul && resultType == discovery.ResultTypeService {
resultType = discovery.ResultTypeNode
}
fqdn := canonicalNameForResult(resultType, target, domain, result.Tenancy, result.PortName)
fqdn := canonicalNameForResult(resultType, target, domain, result.Tenancy, port.Name)
extraRecord := makeIPBasedRecord(fqdn, nodeAddress, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported
answer = append(answer, makeNSRecord(domain, fqdn, ttl))
@ -926,7 +958,7 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req
case qType == dns.TypeSOA:
// TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result
// to be returned in the result.
fqdn := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, result.PortName)
fqdn := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, port.Name)
extraRecord := makeIPBasedRecord(fqdn, nodeAddress, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported
ns = append(ns, makeNSRecord(domain, fqdn, ttl))
@ -934,18 +966,18 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req
case qType == dns.TypeSRV:
// We put A/AAAA/CNAME records in the additional section for SRV requests
a, e := r.getAnswerExtrasForAddressAndTarget(nodeAddress, serviceAddress, req, reqCtx,
result, ttl, remoteAddress, cfg, domain, maxRecursionLevel)
result, port, ttl, remoteAddress, cfg, domain, maxRecursionLevel)
answer = append(answer, a...)
extra = append(extra, e...)
default:
a, e := r.getAnswerExtrasForAddressAndTarget(nodeAddress, serviceAddress, req, reqCtx,
result, ttl, remoteAddress, cfg, domain, maxRecursionLevel)
result, port, ttl, remoteAddress, cfg, domain, maxRecursionLevel)
answer = append(answer, a...)
extra = append(extra, e...)
}
a, e := getAnswerAndExtraTXT(req, cfg, qName, result, ttl, domain, query)
a, e := getAnswerAndExtraTXT(req, cfg, qName, result, ttl, domain, query, &port)
answer = append(answer, a...)
extra = append(extra, e...)
return
@ -953,7 +985,7 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req
// getAnswerExtrasForAddressAndTarget creates the dns answer and extra from nodeAddress and serviceAddress dnsAddress pairs.
func (r *Router) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress, serviceAddress *dnsAddress, req *dns.Msg,
reqCtx Context, result *discovery.Result, ttl uint32, remoteAddress net.Addr,
reqCtx Context, result *discovery.Result, port discovery.Port, ttl uint32, remoteAddress net.Addr,
cfg *RouterDynamicConfig, domain string, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR) {
qName := req.Question[0].Name
reqType := parseRequestType(req)
@ -961,21 +993,20 @@ func (r *Router) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress, ser
switch {
case (reqType == requestTypeAddress || result.Type == discovery.ResultTypeVirtual) &&
serviceAddress.IsEmptyString() && nodeAddress.IsIP():
a, e := getAnswerExtrasForIP(qName, nodeAddress, req.Question[0], reqType,
result, ttl, domain)
a, e := getAnswerExtrasForIP(qName, nodeAddress, req.Question[0], reqType, result, ttl, domain, &port)
answer = append(answer, a...)
extra = append(extra, e...)
case result.Type == discovery.ResultTypeNode && nodeAddress.IsIP():
canonicalNodeName := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, result.PortName)
canonicalNodeName := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, port.Name)
a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType,
result, ttl, domain)
result, ttl, domain, &port)
answer = append(answer, a...)
extra = append(extra, e...)
case result.Type == discovery.ResultTypeNode && !nodeAddress.IsIP():
a, e := r.makeRecordFromFQDN(serviceAddress.FQDN(), result, req, reqCtx, cfg,
ttl, remoteAddress, maxRecursionLevel)
a, e := r.makeRecordFromFQDN(result, req, reqCtx, cfg,
ttl, remoteAddress, maxRecursionLevel, serviceAddress.FQDN(), &port)
answer = append(answer, a...)
extra = append(extra, e...)
@ -984,40 +1015,39 @@ func (r *Router) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress, ser
// There is no service address and the node address is an IP
case serviceAddress.IsEmptyString() && nodeAddress.IsIP():
canonicalNodeName := canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name, domain, result.Tenancy, result.PortName)
a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType,
result, ttl, domain)
resultType := discovery.ResultTypeNode
if result.Type == discovery.ResultTypeWorkload {
resultType = discovery.ResultTypeWorkload
}
canonicalNodeName := canonicalNameForResult(resultType, result.Node.Name, domain, result.Tenancy, port.Name)
a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType, result, ttl, domain, &port)
answer = append(answer, a...)
extra = append(extra, e...)
// There is no service address and the node address is a FQDN (external service)
case serviceAddress.IsEmptyString():
a, e := r.makeRecordFromFQDN(nodeAddress.FQDN(), result, req, reqCtx, cfg,
ttl, remoteAddress, maxRecursionLevel)
a, e := r.makeRecordFromFQDN(result, req, reqCtx, cfg, ttl, remoteAddress, maxRecursionLevel, nodeAddress.FQDN(), &port)
answer = append(answer, a...)
extra = append(extra, e...)
// The service address is an IP
case serviceAddress.IsIP():
canonicalServiceName := canonicalNameForResult(discovery.ResultTypeService, result.Service.Name, domain, result.Tenancy, result.PortName)
a, e := getAnswerExtrasForIP(canonicalServiceName, serviceAddress, req.Question[0], reqType,
result, ttl, domain)
canonicalServiceName := canonicalNameForResult(discovery.ResultTypeService, result.Service.Name, domain, result.Tenancy, port.Name)
a, e := getAnswerExtrasForIP(canonicalServiceName, serviceAddress, req.Question[0], reqType, result, ttl, domain, &port)
answer = append(answer, a...)
extra = append(extra, e...)
// If the service address is a CNAME for the service we are looking
// for then use the node address.
case serviceAddress.FQDN() == req.Question[0].Name && nodeAddress.IsIP():
canonicalNodeName := canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name, domain, result.Tenancy, result.PortName)
a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType,
result, ttl, domain)
canonicalNodeName := canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name, domain, result.Tenancy, port.Name)
a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType, result, ttl, domain, &port)
answer = append(answer, a...)
extra = append(extra, e...)
// The service address is a FQDN (internal or external service name)
default:
a, e := r.makeRecordFromFQDN(serviceAddress.FQDN(), result, req, reqCtx, cfg,
ttl, remoteAddress, maxRecursionLevel)
a, e := r.makeRecordFromFQDN(result, req, reqCtx, cfg, ttl, remoteAddress, maxRecursionLevel, serviceAddress.FQDN(), &port)
answer = append(answer, a...)
extra = append(extra, e...)
}
@ -1028,7 +1058,7 @@ func (r *Router) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress, ser
// getAnswerAndExtraTXT determines whether a TXT needs to be create and then
// returns the TXT record in the answer or extra depending on the question type.
func getAnswerAndExtraTXT(req *dns.Msg, cfg *RouterDynamicConfig, qName string,
result *discovery.Result, ttl uint32, domain string, query *discovery.Query) (answer []dns.RR, extra []dns.RR) {
result *discovery.Result, ttl uint32, domain string, query *discovery.Query, port *discovery.Port) (answer []dns.RR, extra []dns.RR) {
if !shouldAppendTXTRecord(query, cfg, req) {
return
}
@ -1042,7 +1072,7 @@ func getAnswerAndExtraTXT(req *dns.Msg, cfg *RouterDynamicConfig, qName string,
!serviceAddress.IsInternalFQDN(domain) &&
!serviceAddress.IsExternalFQDN(domain) {
recordHeaderName = canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name,
domain, result.Tenancy, result.PortName)
domain, result.Tenancy, port.Name)
}
qType := req.Question[0].Qtype
generateMeta := false
@ -1085,7 +1115,7 @@ func shouldAppendTXTRecord(query *discovery.Query, cfg *RouterDynamicConfig, req
// getAnswerExtrasForIP creates the dns answer and extra from IP dnsAddress pairs.
func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question,
reqType requestType, result *discovery.Result, ttl uint32, domain string) (answer []dns.RR, extra []dns.RR) {
reqType requestType, result *discovery.Result, ttl uint32, domain string, port *discovery.Port) (answer []dns.RR, extra []dns.RR) {
qType := question.Qtype
canReturnARecord := qType == dns.TypeSRV || qType == dns.TypeA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT
canReturnAAAARecord := qType == dns.TypeSRV || qType == dns.TypeAAAA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT
@ -1119,7 +1149,10 @@ func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question,
// as well as the target of the SRV record.
recHdrName = encodeIPAsFqdn(result, addr.IP(), domain)
}
srv := makeSRVRecord(name, recHdrName, result, ttl)
if result.Type == discovery.ResultTypeWorkload {
recHdrName = canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, port.Name)
}
srv := makeSRVRecord(name, recHdrName, result, ttl, port)
answer = append(answer, srv)
}
@ -1215,9 +1248,7 @@ func makeIPBasedRecord(name string, addr *dnsAddress, ttl uint32) dns.RR {
}
}
func (r *Router) makeRecordFromFQDN(fqdn string, result *discovery.Result,
req *dns.Msg, reqCtx Context, cfg *RouterDynamicConfig, ttl uint32,
remoteAddress net.Addr, maxRecursionLevel int) ([]dns.RR, []dns.RR) {
func (r *Router) makeRecordFromFQDN(result *discovery.Result, req *dns.Msg, reqCtx Context, cfg *RouterDynamicConfig, ttl uint32, remoteAddress net.Addr, maxRecursionLevel int, fqdn string, port *discovery.Port) ([]dns.RR, []dns.RR) {
edns := req.IsEdns0() != nil
q := req.Question[0]
@ -1240,10 +1271,8 @@ MORE_REC:
}
if q.Qtype == dns.TypeSRV {
answers := []dns.RR{
makeSRVRecord(q.Name, fqdn, result, ttl),
}
return answers, additional
answer := makeSRVRecord(q.Name, fqdn, result, ttl, port)
return []dns.RR{answer}, additional
}
address := ""
@ -1275,7 +1304,7 @@ func makeCNAMERecord(name string, target string, ttl uint32) *dns.CNAME {
}
// func makeSRVRecord returns an SRV record for the given name and target.
func makeSRVRecord(name, target string, result *discovery.Result, ttl uint32) *dns.SRV {
func makeSRVRecord(name, target string, result *discovery.Result, ttl uint32, port *discovery.Port) *dns.SRV {
return &dns.SRV{
Hdr: dns.RR_Header{
Name: name,
@ -1285,7 +1314,7 @@ func makeSRVRecord(name, target string, result *discovery.Result, ttl uint32) *d
},
Priority: 1,
Weight: uint16(result.DNS.Weight),
Port: uint16(result.PortNumber),
Port: uint16(port.Number),
Target: target,
}
}

View File

@ -36,8 +36,3 @@ func canonicalNameForResult(resultType discovery.ResultType, target, domain stri
}
return ""
}
// getDefaultPartitionName returns the default partition name.
func getDefaultPartitionName() string {
return ""
}

File diff suppressed because it is too large Load Diff