diff --git a/agent/http.go b/agent/http.go index 7ccd00cadc..e7d4166af4 100644 --- a/agent/http.go +++ b/agent/http.go @@ -499,22 +499,22 @@ func (s *HTTPServer) parseToken(req *http.Request, token *string) { } func sourceAddrFromRequest(req *http.Request) (string, error) { + forwardHost := req.Header.Get("X-Forwarded-For") + forwardIp := net.ParseIP(forwardHost) + if forwardIp != nil { + return forwardIp.String(), nil + } + host, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { return "", err } ip := net.ParseIP(host) - if ip == nil { - return "", fmt.Errorf("Could not get IP from request") - } - - forwardHost := req.Header.Get("X-Forwarded-For") - forwardIp := net.ParseIP(forwardHost) - if forwardIp != nil { - return forwardIp.String(), nil - } else { + if ip != nil { return ip.String(), nil + } else { + return "", fmt.Errorf("Could not get remote IP from HTTP Request") } } diff --git a/agent/prepared_query_endpoint_test.go b/agent/prepared_query_endpoint_test.go index 58e27df2a5..350f1a5590 100644 --- a/agent/prepared_query_endpoint_test.go +++ b/agent/prepared_query_endpoint_test.go @@ -324,6 +324,63 @@ func TestPreparedQuery_Execute(t *testing.T) { t.Fatalf("bad: %v", r) } }) + + t.Run("", func(t *testing.T) { + a := NewTestAgent(t.Name(), "") + defer a.Shutdown() + + m := MockPreparedQuery{ + executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error { + expected := &structs.PreparedQueryExecuteRequest{ + Datacenter: "dc1", + QueryIDOrName: "my-id", + Limit: 5, + Source: structs.QuerySource{ + Datacenter: "dc1", + Node: "_ip", + Ip: "127.0.0.1", + }, + Agent: structs.QuerySource{ + Datacenter: a.Config.Datacenter, + Node: a.Config.NodeName, + }, + QueryOptions: structs.QueryOptions{ + Token: "my-token", + RequireConsistent: true, + }, + } + if !reflect.DeepEqual(args, expected) { + t.Fatalf("bad: %v", args) + } + + // Just set something so we can tell this is returned. + reply.Failovers = 99 + return nil + }, + } + if err := a.registerEndpoint("PreparedQuery", &m); err != nil { + t.Fatalf("err: %v", err) + } + + body := bytes.NewBuffer(nil) + req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?token=my-token&consistent=true&near=_ip&limit=5", body) + req.Header.Add("X-Forwarded-For", "127.0.0.1") + resp := httptest.NewRecorder() + obj, err := a.srv.PreparedQuerySpecific(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp.Code != 200 { + t.Fatalf("bad code: %d", resp.Code) + } + r, ok := obj.(structs.PreparedQueryExecuteResponse) + if !ok { + t.Fatalf("unexpected: %T", obj) + } + if r.Failovers != 99 { + t.Fatalf("bad: %v", r) + } + }) // Ensure the proper params are set when no special args are passed t.Run("", func(t *testing.T) {