diff --git a/command/agent/agent.go b/command/agent/agent.go index 79b94d55db..fcd647bd18 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -1069,6 +1069,18 @@ func (a *Agent) UpdateCheck(checkID types.CheckID, status, output string) error return nil } +// TranslateAddr is used to provide the final, translated address for a node, +// depending on how this agent and the other node are configured. +func (a *Agent) TranslateAddr(dc string, addr string, taggedAddr map[string]string) string { + if a.config.TranslateWanAddrs && (a.config.Datacenter != dc) { + wanAddr := taggedAddr["wan"] + if wanAddr != "" { + addr = wanAddr + } + } + return addr +} + // persistCheckState is used to record the check status into the data dir. // This allows the state to be restored on a later agent start. Currently // only useful for TTL based checks. diff --git a/command/agent/catalog_endpoint.go b/command/agent/catalog_endpoint.go index e8df2ee17b..1aa05447c6 100644 --- a/command/agent/catalog_endpoint.go +++ b/command/agent/catalog_endpoint.go @@ -77,6 +77,12 @@ func (s *HTTPServer) CatalogNodes(resp http.ResponseWriter, req *http.Request) ( if out.Nodes == nil { out.Nodes = make(structs.Nodes, 0) } + + for _, node := range out.Nodes { + addr := s.agent.TranslateAddr(args.Datacenter, node.Address, node.TaggedAddresses) + node.Address = addr + } + return out.Nodes, nil } @@ -129,6 +135,12 @@ func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Req if out.ServiceNodes == nil { out.ServiceNodes = make(structs.ServiceNodes, 0) } + + for _, serviceNode := range out.ServiceNodes { + addr := s.agent.TranslateAddr(args.Datacenter, serviceNode.Address, serviceNode.TaggedAddresses) + serviceNode.Address = addr + } + return out.ServiceNodes, nil } @@ -153,5 +165,12 @@ func (s *HTTPServer) CatalogNodeServices(resp http.ResponseWriter, req *http.Req if err := s.agent.RPC("Catalog.NodeServices", &args, &out); err != nil { return nil, err } + + if out.NodeServices != nil { + node := out.NodeServices.Node + addr := s.agent.TranslateAddr(args.Datacenter, node.Address, node.TaggedAddresses) + node.Address = addr + } + return out.NodeServices, nil } diff --git a/command/agent/catalog_endpoint_test.go b/command/agent/catalog_endpoint_test.go index f4ea16d45f..ab78306b2d 100644 --- a/command/agent/catalog_endpoint_test.go +++ b/command/agent/catalog_endpoint_test.go @@ -145,6 +145,95 @@ func TestCatalogNodes(t *testing.T) { } } +func TestCatalogNodes_WanTranslation(t *testing.T) { + httpCtx1, httpCtx2 := setupWanHTTPServers(t) + defer shutdownHTTPServer(httpCtx1) + defer shutdownHTTPServer(httpCtx2) + srv1 := httpCtx1.srv + srv2 := httpCtx2.srv + + // Register a node with DC2 + { + args := &structs.RegisterRequest{ + Datacenter: "dc2", + Node: "wan_translation_test", + Address: "127.0.0.1", + TaggedAddresses: map[string]string{ + "wan": "127.0.0.2", + }, + Service: &structs.NodeService{ + Service: "http_wan_translation_test", + }, + } + + var out struct{} + if err := srv2.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + } + + req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc=dc2", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // get nodes for DC2 from DC1 + resp1 := httptest.NewRecorder() + obj1, err1 := srv1.CatalogNodes(resp1, req) + if err1 != nil { + t.Fatalf("err: %v", err1) + } + + // Verify an index is set + assertIndex(t, resp1) + + nodes1 := obj1.(structs.Nodes) + if len(nodes1) != 2 { + t.Fatalf("bad: %v", obj1) + } + + var node1 *structs.Node + + for _, node := range nodes1 { + if node.Node == "wan_translation_test" { + node1 = node + } + } + + // Expect that DC1 gives us a public address (since the node is in DC2) + if node1.Address != "127.0.0.2" { + t.Fatalf("bad: %v", node1) + } + + // get nodes for DC2 from DC2 + resp2 := httptest.NewRecorder() + obj2, err2 := srv2.CatalogNodes(resp2, req) + if err2 != nil { + t.Fatalf("err: %v", err2) + } + + // Verify an index is set + assertIndex(t, resp2) + + nodes2 := obj2.(structs.Nodes) + if len(nodes2) != 2 { + t.Fatalf("bad: %v", obj2) + } + + var node2 *structs.Node + + for _, node := range nodes2 { + if node.Node == "wan_translation_test" { + node2 = node + } + } + + // Expect that DC2 gives us a private address (since the node is in DC2) + if node2.Address != "127.0.0.1" { + t.Fatalf("bad: %v", node2) + } +} + func TestCatalogNodes_Blocking(t *testing.T) { dir, srv := makeHTTPServer(t) defer os.RemoveAll(dir) @@ -407,6 +496,81 @@ func TestCatalogServiceNodes(t *testing.T) { } } +func TestCatalogServiceNodes_WanTranslation(t *testing.T) { + httpCtx1, httpCtx2 := setupWanHTTPServers(t) + defer shutdownHTTPServer(httpCtx1) + defer shutdownHTTPServer(httpCtx2) + srv1 := httpCtx1.srv + srv2 := httpCtx2.srv + + // Register a node with DC2 + { + args := &structs.RegisterRequest{ + Datacenter: "dc2", + Node: "foo", + Address: "127.0.0.1", + TaggedAddresses: map[string]string{ + "wan": "127.0.0.2", + }, + Service: &structs.NodeService{ + Service: "http_wan_translation_test", + }, + } + + var out struct{} + if err := srv2.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + } + + req, err := http.NewRequest("GET", "/v1/catalog/service/http_wan_translation_test?dc=dc2", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Ask HTTP server on DC1 for the node + resp1 := httptest.NewRecorder() + obj1, err1 := srv1.CatalogServiceNodes(resp1, req) + if err1 != nil { + t.Fatalf("err: %v", err1) + } + + assertIndex(t, resp1) + + nodes1 := obj1.(structs.ServiceNodes) + if len(nodes1) != 1 { + t.Fatalf("bad: %v", obj1) + } + + node1 := nodes1[0] + + // Expect that DC1 gives us a public address (since the node is in DC2) + if node1.Address != "127.0.0.2" { + t.Fatalf("bad: %v", node1) + } + + // Ask HTTP server on DC2 for the node + resp2 := httptest.NewRecorder() + obj2, err2 := srv2.CatalogServiceNodes(resp2, req) + if err2 != nil { + t.Fatalf("err: %v", err2) + } + + assertIndex(t, resp2) + + nodes2 := obj2.(structs.ServiceNodes) + if len(nodes2) != 1 { + t.Fatalf("bad: %v", obj2) + } + + node2 := nodes2[0] + + // Expect that DC2 gives us a local address (since the node is in DC2) + if node2.Address != "127.0.0.1" { + t.Fatalf("bad: %v", node2) + } +} + func TestCatalogServiceNodes_DistanceSort(t *testing.T) { dir, srv := makeHTTPServer(t) defer os.RemoveAll(dir) @@ -550,3 +714,76 @@ func TestCatalogNodeServices(t *testing.T) { t.Fatalf("bad: %v", obj) } } + +func TestCatalogNodeServices_WanTranslation(t *testing.T) { + httpCtx1, httpCtx2 := setupWanHTTPServers(t) + defer shutdownHTTPServer(httpCtx1) + defer shutdownHTTPServer(httpCtx2) + srv1 := httpCtx1.srv + srv2 := httpCtx2.srv + + // Register a node with DC2 + { + args := &structs.RegisterRequest{ + Datacenter: "dc2", + Node: "foo", + Address: "127.0.0.1", + TaggedAddresses: map[string]string{ + "wan": "127.0.0.2", + }, + Service: &structs.NodeService{ + Service: "http_wan_translation_test", + }, + } + + var out struct{} + if err := srv2.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + } + + req, err := http.NewRequest("GET", "/v1/catalog/node/foo?dc=dc2", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // ask DC1 for node in DC2 + resp1 := httptest.NewRecorder() + obj1, err1 := srv1.CatalogNodeServices(resp1, req) + if err1 != nil { + t.Fatalf("err: %v", err1) + } + assertIndex(t, resp1) + + services1 := obj1.(*structs.NodeServices) + if len(services1.Services) != 1 { + t.Fatalf("bad: %v", obj1) + } + + service1 := services1.Node + + // Expect that DC1 gives us a public address (since the node is in DC2) + if service1.Address != "127.0.0.2" { + t.Fatalf("bad: %v", service1) + } + + // ask DC2 for node in DC2 + resp2 := httptest.NewRecorder() + obj2, err2 := srv2.CatalogNodeServices(resp2, req) + if err2 != nil { + t.Fatalf("err: %v", err2) + } + assertIndex(t, resp2) + + services2 := obj2.(*structs.NodeServices) + if len(services2.Services) != 1 { + t.Fatalf("bad: %v", obj2) + } + + service2 := services2.Node + + // Expect that DC2 gives us a private address (since the node is in DC2) + if service2.Address != "127.0.0.1" { + t.Fatalf("bad: %v", service2) + } +} diff --git a/command/agent/dns.go b/command/agent/dns.go index 3755bc0502..1a8eec1442 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -370,19 +370,6 @@ INVALID: resp.SetRcode(req, dns.RcodeNameError) } -// translateAddr is used to provide the final, translated address for a node, -// depending on how this agent and the other node are configured. -func (d *DNSServer) translateAddr(dc string, node *structs.Node) string { - addr := node.Address - if d.agent.config.TranslateWanAddrs && (d.agent.config.Datacenter != dc) { - wanAddr := node.TaggedAddresses["wan"] - if wanAddr != "" { - addr = wanAddr - } - } - return addr -} - // nodeLookup is used to handle a node query func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg) { // Only handle ANY, A and AAAA type requests @@ -423,7 +410,8 @@ RPC: } // Add the node record - addr := d.translateAddr(datacenter, out.NodeServices.Node) + n := out.NodeServices.Node + addr := d.agent.TranslateAddr(datacenter, n.Address, n.TaggedAddresses) records := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL) if records != nil { @@ -776,7 +764,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode for _, node := range nodes { // Start with the translated address but use the service address, // if specified. - addr := d.translateAddr(dc, node.Node) + addr := d.agent.TranslateAddr(dc, node.Node.Address, node.Node.TaggedAddresses) if node.Service.Address != "" { addr = node.Service.Address } @@ -825,7 +813,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes // Start with the translated address but use the service address, // if specified. - addr := d.translateAddr(dc, node.Node) + addr := d.agent.TranslateAddr(dc, node.Node.Address, node.Node.TaggedAddresses) if node.Service.Address != "" { addr = node.Service.Address } diff --git a/command/agent/health_endpoint.go b/command/agent/health_endpoint.go index b252a4126a..f2ddbafb9c 100644 --- a/command/agent/health_endpoint.go +++ b/command/agent/health_endpoint.go @@ -143,6 +143,13 @@ func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Requ if out.Nodes == nil { out.Nodes = make(structs.CheckServiceNodes, 0) } + + for _, checkServiceNode := range out.Nodes { + node := checkServiceNode.Node + addr := s.agent.TranslateAddr(args.Datacenter, node.Address, node.TaggedAddresses) + node.Address = addr + } + return out.Nodes, nil } diff --git a/command/agent/health_endpoint_test.go b/command/agent/health_endpoint_test.go index 7bcbc91169..584a0ae2e6 100644 --- a/command/agent/health_endpoint_test.go +++ b/command/agent/health_endpoint_test.go @@ -554,6 +554,83 @@ func TestHealthServiceNodes_PassingFilter(t *testing.T) { } } +func TestHealthServiceNodes_WanTranslation(t *testing.T) { + httpCtx1, httpCtx2 := setupWanHTTPServers(t) + defer shutdownHTTPServer(httpCtx1) + defer shutdownHTTPServer(httpCtx2) + srv1 := httpCtx1.srv + srv2 := httpCtx2.srv + + // Register a node with DC2 + { + args := &structs.RegisterRequest{ + Datacenter: "dc2", + Node: "foo", + Address: "127.0.0.1", + TaggedAddresses: map[string]string{ + "wan": "127.0.0.2", + }, + Service: &structs.NodeService{ + Service: "http_wan_translation_test", + }, + } + + var out struct{} + if err := srv2.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + } + + req, err := http.NewRequest("GET", "/v1/health/service/http_wan_translation_test?dc=dc2", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // ask DC1 for node in DC2 + resp1 := httptest.NewRecorder() + obj1, err1 := srv1.HealthServiceNodes(resp1, req) + if err1 != nil { + t.Fatalf("err: %v", err1) + } + + assertIndex(t, resp1) + + // Should be 1 health check for consul + nodes1 := obj1.(structs.CheckServiceNodes) + if len(nodes1) != 1 { + t.Fatalf("bad: %v", obj1) + } + + node1 := nodes1[0].Node + + // Expect that DC1 gives us a public address (since the node is in DC2) + if node1.Address != "127.0.0.2" { + t.Fatalf("bad: %v", node1) + } + + // ask DC2 for node in DC2 + resp2 := httptest.NewRecorder() + obj2, err2 := srv2.HealthServiceNodes(resp2, req) + if err2 != nil { + t.Fatalf("err: %v", err2) + } + + assertIndex(t, resp2) + + // Should be 1 health check for consul + nodes2 := obj2.(structs.CheckServiceNodes) + if len(nodes2) != 1 { + t.Fatalf("bad: %v", obj2) + } + + node2 := nodes2[0].Node + + // Expect that DC2 gives us a private address (since the node is in DC2) + if node2.Address != "127.0.0.1" { + t.Fatalf("bad: %v", node2) + } +} + func TestFilterNonPassing(t *testing.T) { nodes := structs.CheckServiceNodes{ structs.CheckServiceNode{ diff --git a/command/agent/http_test.go b/command/agent/http_test.go index b6618977f3..7352c6186c 100644 --- a/command/agent/http_test.go +++ b/command/agent/http_test.go @@ -23,6 +23,50 @@ import ( "github.com/hashicorp/go-cleanhttp" ) +type HTTPServerCtx struct { + dir string + srv *HTTPServer +} + +func setupWanHTTPServers(t *testing.T) (HTTPServerCtx, HTTPServerCtx) { + dir1, srv1 := makeHTTPServerWithConfig(t, + func(c *Config) { + c.Datacenter = "dc1" + c.TranslateWanAddrs = true + }) + + dir2, srv2 := makeHTTPServerWithConfig(t, + func(c *Config) { + c.Datacenter = "dc2" + c.TranslateWanAddrs = true + }) + + testutil.WaitForLeader(t, srv1.agent.RPC, "dc1") + testutil.WaitForLeader(t, srv2.agent.RPC, "dc2") + + // Join WAN cluster + addr := fmt.Sprintf("127.0.0.1:%d", + srv1.agent.config.Ports.SerfWan) + if _, err := srv2.agent.JoinWAN([]string{addr}); err != nil { + t.Fatalf("err: %v", err) + } + + testutil.WaitForResult( + func() (bool, error) { + return len(srv1.agent.WANMembers()) > 1, nil + }, + func(err error) { + t.Fatalf("Failed waiting for WAN join: %v", err) + }) + return HTTPServerCtx{dir1, srv1}, HTTPServerCtx{dir2, srv2} +} + +func shutdownHTTPServer(httpCtx HTTPServerCtx) { + os.RemoveAll(httpCtx.dir) + httpCtx.srv.Shutdown() + httpCtx.srv.agent.Shutdown() +} + func makeHTTPServer(t *testing.T) (string, *HTTPServer) { return makeHTTPServerWithConfig(t, nil) } diff --git a/command/agent/prepared_query_endpoint.go b/command/agent/prepared_query_endpoint.go index 1a6ff6d72e..df6f51f4d4 100644 --- a/command/agent/prepared_query_endpoint.go +++ b/command/agent/prepared_query_endpoint.go @@ -126,6 +126,13 @@ func (s *HTTPServer) preparedQueryExecute(id string, resp http.ResponseWriter, r if reply.Nodes == nil { reply.Nodes = make(structs.CheckServiceNodes, 0) } + + for _, checkServiceNode := range reply.Nodes { + node := checkServiceNode.Node + addr := s.agent.TranslateAddr(args.Datacenter, node.Address, node.TaggedAddresses) + node.Address = addr + } + return reply, nil } diff --git a/command/agent/prepared_query_endpoint_test.go b/command/agent/prepared_query_endpoint_test.go index ff757e0acf..ceac2f4408 100644 --- a/command/agent/prepared_query_endpoint_test.go +++ b/command/agent/prepared_query_endpoint_test.go @@ -359,6 +359,52 @@ func TestPreparedQuery_Execute(t *testing.T) { } }) + // testing WAN translation in the response + httpTestWithConfig(t, func(srv *HTTPServer) { + m := MockPreparedQuery{} + if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil { + t.Fatalf("err: %v", err) + } + + m.executeFn = func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error { + nodesResponse := make(structs.CheckServiceNodes, 1) + nodesResponse[0].Node = &structs.Node{Node: "foo", Address: "127.0.0.1", + TaggedAddresses: map[string]string{"wan": "127.0.0.2"}} + reply.Nodes = nodesResponse + return nil + } + + body := bytes.NewBuffer(nil) + req, err := http.NewRequest("GET", "/v1/query/my-id/execute?dc=dc2", body) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + obj, err := srv.PreparedQuerySpecific(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp.Code != 200 { + t.Fatalf("bad code: %d", resp.Code) + } + r, ok := obj.(structs.PreparedQueryExecuteResponse) + if !ok { + t.Fatalf("unexpected: %T", obj) + } + if r.Nodes == nil || len(r.Nodes) != 1 { + t.Fatalf("bad: %v", r) + } + + node := r.Nodes[0] + if node.Node.Address != "127.0.0.2" { + t.Fatalf("bad: %v", node.Node) + } + }, func(c *Config) { + c.Datacenter = "dc1" + c.TranslateWanAddrs = true + }) + httpTest(t, func(srv *HTTPServer) { body := bytes.NewBuffer(nil) req, err := http.NewRequest("GET", "/v1/query/not-there/execute", body) diff --git a/consul/state/state_store.go b/consul/state/state_store.go index 7c48a1feee..0f1d95fb90 100644 --- a/consul/state/state_store.go +++ b/consul/state/state_store.go @@ -839,7 +839,9 @@ func (s *StateStore) parseServiceNodes(tx *memdb.Txn, services structs.ServiceNo if err != nil { return nil, fmt.Errorf("failed node lookup: %s", err) } - s.Address = n.(*structs.Node).Address + node := n.(*structs.Node) + s.Address = node.Address + s.TaggedAddresses = node.TaggedAddresses results = append(results, s) } return results, nil diff --git a/consul/structs/structs.go b/consul/structs/structs.go index c9ee486fc1..23c4c78eb4 100644 --- a/consul/structs/structs.go +++ b/consul/structs/structs.go @@ -264,6 +264,7 @@ type Services map[string][]string type ServiceNode struct { Node string Address string + TaggedAddresses map[string]string ServiceID string ServiceName string ServiceTags []string