From e47d7eeddb93bac09b7e5e94231ef5fa17cc112b Mon Sep 17 00:00:00 2001 From: "R.B. Boyer" Date: Fri, 26 Apr 2019 12:49:28 -0500 Subject: [PATCH] acl: adding support for kubernetes auth provider login (#5600) * auth providers * binding rules * auth provider for kubernetes * login/logout --- agent/acl_endpoint.go | 327 +++ agent/acl_endpoint_test.go | 733 ++++- agent/consul/acl.go | 69 +- agent/consul/acl_authmethod.go | 169 ++ agent/consul/acl_authmethod_test.go | 48 + agent/consul/acl_endpoint.go | 668 ++++- agent/consul/acl_endpoint_legacy.go | 2 +- agent/consul/acl_endpoint_test.go | 2499 ++++++++++++++++- agent/consul/acl_replication_legacy.go | 2 +- agent/consul/acl_replication_legacy_test.go | 4 +- agent/consul/acl_replication_test.go | 10 +- agent/consul/acl_replication_types.go | 2 +- agent/consul/acl_server.go | 11 + agent/consul/acl_test.go | 197 +- agent/consul/acl_token_exp_test.go | 2 +- agent/consul/authmethod/authmethods.go | 112 + agent/consul/authmethod/kubeauth/k8s.go | 202 ++ agent/consul/authmethod/kubeauth/k8s_test.go | 144 + agent/consul/authmethod/kubeauth/testing.go | 532 ++++ agent/consul/authmethod/testauth/testing.go | 166 ++ agent/consul/fsm/commands_oss.go | 48 + agent/consul/fsm/snapshot_oss.go | 46 + agent/consul/fsm/snapshot_oss_test.go | 66 +- agent/consul/leader.go | 4 + agent/consul/server.go | 3 + agent/consul/state/acl.go | 593 +++- agent/consul/state/acl_test.go | 1080 ++++++- agent/consul/state/state_store.go | 22 +- agent/consul/util.go | 42 + agent/consul/util_test.go | 131 + agent/http_oss.go | 8 + agent/structs/acl.go | 316 ++- agent/structs/acl_cache.go | 11 +- agent/structs/acl_cache_test.go | 1 + agent/structs/acl_test.go | 3 +- agent/structs/structs.go | 54 +- api/acl.go | 364 ++- api/api.go | 28 + api/api_test.go | 16 +- command/acl/acl_helpers.go | 211 +- command/acl/authmethod/authmethod.go | 64 + .../authmethod/create/authmethod_create.go | 186 ++ .../create/authmethod_create_test.go | 226 ++ .../authmethod/delete/authmethod_delete.go | 82 + .../delete/authmethod_delete_test.go | 131 + .../acl/authmethod/list/authmethod_list.go | 83 + .../authmethod/list/authmethod_list_test.go | 109 + .../acl/authmethod/read/authmethod_read.go | 96 + .../authmethod/read/authmethod_read_test.go | 118 + .../authmethod/update/authmethod_update.go | 220 ++ .../update/authmethod_update_test.go | 647 +++++ command/acl/bindingrule/bindingrule.go | 60 + .../bindingrule/create/bindingrule_create.go | 148 + .../create/bindingrule_create_test.go | 178 ++ .../bindingrule/delete/bindingrule_delete.go | 97 + .../delete/bindingrule_delete_test.go | 187 ++ .../acl/bindingrule/list/bindingrule_list.go | 98 + .../bindingrule/list/bindingrule_list_test.go | 167 ++ .../acl/bindingrule/read/bindingrule_read.go | 108 + .../bindingrule/read/bindingrule_read_test.go | 152 + .../bindingrule/update/bindingrule_update.go | 212 ++ .../update/bindingrule_update_test.go | 768 +++++ command/acl/role/delete/role_delete.go | 15 +- command/acl/role/delete/role_delete_test.go | 4 +- command/commands_oss.go | 28 + command/connect/envoy/envoy.go | 2 +- command/connect/proxy/proxy.go | 2 +- command/flags/http.go | 30 + command/login/login.go | 148 + command/login/login_test.go | 321 +++ command/logout/logout.go | 70 + command/logout/logout_test.go | 299 ++ command/watch/watch.go | 12 +- go.mod | 6 +- go.sum | 10 +- vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go | 77 + .../square/go-jose.v2/.gitcookies.sh.enc | 1 + vendor/gopkg.in/square/go-jose.v2/.gitignore | 7 + vendor/gopkg.in/square/go-jose.v2/.travis.yml | 46 + .../gopkg.in/square/go-jose.v2/BUG-BOUNTY.md | 10 + .../square/go-jose.v2/CONTRIBUTING.md | 14 + vendor/gopkg.in/square/go-jose.v2/LICENSE | 202 ++ vendor/gopkg.in/square/go-jose.v2/README.md | 118 + .../gopkg.in/square/go-jose.v2/asymmetric.go | 592 ++++ .../square/go-jose.v2/cipher/cbc_hmac.go | 196 ++ .../square/go-jose.v2/cipher/concat_kdf.go | 75 + .../square/go-jose.v2/cipher/ecdh_es.go | 62 + .../square/go-jose.v2/cipher/key_wrap.go | 109 + vendor/gopkg.in/square/go-jose.v2/crypter.go | 535 ++++ vendor/gopkg.in/square/go-jose.v2/doc.go | 27 + vendor/gopkg.in/square/go-jose.v2/encoding.go | 179 ++ .../gopkg.in/square/go-jose.v2/json/LICENSE | 27 + .../gopkg.in/square/go-jose.v2/json/README.md | 13 + .../gopkg.in/square/go-jose.v2/json/decode.go | 1183 ++++++++ .../gopkg.in/square/go-jose.v2/json/encode.go | 1197 ++++++++ .../gopkg.in/square/go-jose.v2/json/indent.go | 141 + .../square/go-jose.v2/json/scanner.go | 623 ++++ .../gopkg.in/square/go-jose.v2/json/stream.go | 480 ++++ .../gopkg.in/square/go-jose.v2/json/tags.go | 44 + vendor/gopkg.in/square/go-jose.v2/jwe.go | 294 ++ vendor/gopkg.in/square/go-jose.v2/jwk.go | 608 ++++ vendor/gopkg.in/square/go-jose.v2/jws.go | 321 +++ .../gopkg.in/square/go-jose.v2/jwt/builder.go | 334 +++ .../gopkg.in/square/go-jose.v2/jwt/claims.go | 120 + vendor/gopkg.in/square/go-jose.v2/jwt/doc.go | 22 + .../gopkg.in/square/go-jose.v2/jwt/errors.go | 53 + vendor/gopkg.in/square/go-jose.v2/jwt/jwt.go | 163 ++ .../square/go-jose.v2/jwt/validation.go | 114 + vendor/gopkg.in/square/go-jose.v2/opaque.go | 83 + vendor/gopkg.in/square/go-jose.v2/shared.go | 499 ++++ vendor/gopkg.in/square/go-jose.v2/signing.go | 389 +++ .../gopkg.in/square/go-jose.v2/symmetric.go | 482 ++++ .../apimachinery/pkg/api/errors/errors.go | 24 + .../apimachinery/pkg/apis/meta/v1/types.go | 4 + .../k8s.io/apimachinery/pkg/runtime/codec.go | 20 + .../runtime/serializer/streaming/streaming.go | 2 +- .../apimachinery/pkg/util/runtime/runtime.go | 6 +- vendor/modules.txt | 54 +- 118 files changed, 23163 insertions(+), 417 deletions(-) create mode 100644 agent/consul/acl_authmethod.go create mode 100644 agent/consul/acl_authmethod_test.go create mode 100644 agent/consul/authmethod/authmethods.go create mode 100644 agent/consul/authmethod/kubeauth/k8s.go create mode 100644 agent/consul/authmethod/kubeauth/k8s_test.go create mode 100644 agent/consul/authmethod/kubeauth/testing.go create mode 100644 agent/consul/authmethod/testauth/testing.go create mode 100644 command/acl/authmethod/authmethod.go create mode 100644 command/acl/authmethod/create/authmethod_create.go create mode 100644 command/acl/authmethod/create/authmethod_create_test.go create mode 100644 command/acl/authmethod/delete/authmethod_delete.go create mode 100644 command/acl/authmethod/delete/authmethod_delete_test.go create mode 100644 command/acl/authmethod/list/authmethod_list.go create mode 100644 command/acl/authmethod/list/authmethod_list_test.go create mode 100644 command/acl/authmethod/read/authmethod_read.go create mode 100644 command/acl/authmethod/read/authmethod_read_test.go create mode 100644 command/acl/authmethod/update/authmethod_update.go create mode 100644 command/acl/authmethod/update/authmethod_update_test.go create mode 100644 command/acl/bindingrule/bindingrule.go create mode 100644 command/acl/bindingrule/create/bindingrule_create.go create mode 100644 command/acl/bindingrule/create/bindingrule_create_test.go create mode 100644 command/acl/bindingrule/delete/bindingrule_delete.go create mode 100644 command/acl/bindingrule/delete/bindingrule_delete_test.go create mode 100644 command/acl/bindingrule/list/bindingrule_list.go create mode 100644 command/acl/bindingrule/list/bindingrule_list_test.go create mode 100644 command/acl/bindingrule/read/bindingrule_read.go create mode 100644 command/acl/bindingrule/read/bindingrule_read_test.go create mode 100644 command/acl/bindingrule/update/bindingrule_update.go create mode 100644 command/acl/bindingrule/update/bindingrule_update_test.go create mode 100644 command/login/login.go create mode 100644 command/login/login_test.go create mode 100644 command/logout/logout.go create mode 100644 command/logout/logout_test.go create mode 100644 vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc create mode 100644 vendor/gopkg.in/square/go-jose.v2/.gitignore create mode 100644 vendor/gopkg.in/square/go-jose.v2/.travis.yml create mode 100644 vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md create mode 100644 vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md create mode 100644 vendor/gopkg.in/square/go-jose.v2/LICENSE create mode 100644 vendor/gopkg.in/square/go-jose.v2/README.md create mode 100644 vendor/gopkg.in/square/go-jose.v2/asymmetric.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/crypter.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/doc.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/encoding.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/LICENSE create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/README.md create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/decode.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/encode.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/indent.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/scanner.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/stream.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/json/tags.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwe.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwk.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jws.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/builder.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/claims.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/doc.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/errors.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/jwt.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/jwt/validation.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/opaque.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/shared.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/signing.go create mode 100644 vendor/gopkg.in/square/go-jose.v2/symmetric.go diff --git a/agent/acl_endpoint.go b/agent/acl_endpoint.go index cafe6e11c3..12c6b313a2 100644 --- a/agent/acl_endpoint.go +++ b/agent/acl_endpoint.go @@ -376,6 +376,7 @@ func (s *HTTPServer) ACLTokenList(resp http.ResponseWriter, req *http.Request) ( args.Policy = req.URL.Query().Get("policy") args.Role = req.URL.Query().Get("role") + args.AuthMethod = req.URL.Query().Get("authmethod") var out structs.ACLTokenListResponse defer setMeta(resp, &out.QueryMeta) @@ -701,3 +702,329 @@ func (s *HTTPServer) ACLRoleDelete(resp http.ResponseWriter, req *http.Request, return true, nil } + +func (s *HTTPServer) ACLBindingRuleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var args structs.ACLBindingRuleListRequest + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + args.AuthMethod = req.URL.Query().Get("authmethod") + + var out structs.ACLBindingRuleListResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.BindingRuleList", &args, &out); err != nil { + return nil, err + } + + // make sure we return an array and not nil + if out.BindingRules == nil { + out.BindingRules = make(structs.ACLBindingRules, 0) + } + + return out.BindingRules, nil +} + +func (s *HTTPServer) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var fn func(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) + + switch req.Method { + case "GET": + fn = s.ACLBindingRuleRead + + case "PUT": + fn = s.ACLBindingRuleWrite + + case "DELETE": + fn = s.ACLBindingRuleDelete + + default: + return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}} + } + + bindingRuleID := strings.TrimPrefix(req.URL.Path, "/v1/acl/binding-rule/") + if bindingRuleID == "" && req.Method != "PUT" { + return nil, BadRequestError{Reason: "Missing binding rule ID"} + } + + return fn(resp, req, bindingRuleID) +} + +func (s *HTTPServer) ACLBindingRuleRead(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) { + args := structs.ACLBindingRuleGetRequest{ + Datacenter: s.agent.config.Datacenter, + BindingRuleID: bindingRuleID, + } + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLBindingRuleResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.BindingRuleRead", &args, &out); err != nil { + return nil, err + } + + if out.BindingRule == nil { + resp.WriteHeader(http.StatusNotFound) + return nil, nil + } + + return out.BindingRule, nil +} + +func (s *HTTPServer) ACLBindingRuleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + return s.ACLBindingRuleWrite(resp, req, "") +} + +func (s *HTTPServer) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) { + args := structs.ACLBindingRuleSetRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseToken(req, &args.Token) + + if err := decodeBody(req, &args.BindingRule, fixTimeAndHashFields); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)} + } + + if args.BindingRule.ID != "" && args.BindingRule.ID != bindingRuleID { + return nil, BadRequestError{Reason: "BindingRule ID in URL and payload do not match"} + } else if args.BindingRule.ID == "" { + args.BindingRule.ID = bindingRuleID + } + + var out structs.ACLBindingRule + if err := s.agent.RPC("ACL.BindingRuleSet", args, &out); err != nil { + return nil, err + } + + return &out, nil +} + +func (s *HTTPServer) ACLBindingRuleDelete(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) { + args := structs.ACLBindingRuleDeleteRequest{ + Datacenter: s.agent.config.Datacenter, + BindingRuleID: bindingRuleID, + } + s.parseToken(req, &args.Token) + + var ignored bool + if err := s.agent.RPC("ACL.BindingRuleDelete", args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +func (s *HTTPServer) ACLAuthMethodList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var args structs.ACLAuthMethodListRequest + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLAuthMethodListResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.AuthMethodList", &args, &out); err != nil { + return nil, err + } + + // make sure we return an array and not nil + if out.AuthMethods == nil { + out.AuthMethods = make(structs.ACLAuthMethodListStubs, 0) + } + + return out.AuthMethods, nil +} + +func (s *HTTPServer) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + var fn func(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) + + switch req.Method { + case "GET": + fn = s.ACLAuthMethodRead + + case "PUT": + fn = s.ACLAuthMethodWrite + + case "DELETE": + fn = s.ACLAuthMethodDelete + + default: + return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}} + } + + methodName := strings.TrimPrefix(req.URL.Path, "/v1/acl/auth-method/") + if methodName == "" && req.Method != "PUT" { + return nil, BadRequestError{Reason: "Missing auth method name"} + } + + return fn(resp, req, methodName) +} + +func (s *HTTPServer) ACLAuthMethodRead(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) { + args := structs.ACLAuthMethodGetRequest{ + Datacenter: s.agent.config.Datacenter, + AuthMethodName: methodName, + } + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + if args.Datacenter == "" { + args.Datacenter = s.agent.config.Datacenter + } + + var out structs.ACLAuthMethodResponse + defer setMeta(resp, &out.QueryMeta) + if err := s.agent.RPC("ACL.AuthMethodRead", &args, &out); err != nil { + return nil, err + } + + if out.AuthMethod == nil { + resp.WriteHeader(http.StatusNotFound) + return nil, nil + } + + fixupAuthMethodConfig(out.AuthMethod) + return out.AuthMethod, nil +} + +func (s *HTTPServer) ACLAuthMethodCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + return s.ACLAuthMethodWrite(resp, req, "") +} + +func (s *HTTPServer) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) { + args := structs.ACLAuthMethodSetRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseToken(req, &args.Token) + + if err := decodeBody(req, &args.AuthMethod, fixTimeAndHashFields); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)} + } + + if methodName != "" { + if args.AuthMethod.Name != "" && args.AuthMethod.Name != methodName { + return nil, BadRequestError{Reason: "AuthMethod Name in URL and payload do not match"} + } else if args.AuthMethod.Name == "" { + args.AuthMethod.Name = methodName + } + } + + var out structs.ACLAuthMethod + if err := s.agent.RPC("ACL.AuthMethodSet", args, &out); err != nil { + return nil, err + } + + fixupAuthMethodConfig(&out) + return &out, nil +} + +func (s *HTTPServer) ACLAuthMethodDelete(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) { + args := structs.ACLAuthMethodDeleteRequest{ + Datacenter: s.agent.config.Datacenter, + AuthMethodName: methodName, + } + s.parseToken(req, &args.Token) + + var ignored bool + if err := s.agent.RPC("ACL.AuthMethodDelete", args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +func (s *HTTPServer) ACLLogin(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + args := &structs.ACLLoginRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseDC(req, &args.Datacenter) + + if err := decodeBody(req, &args.Auth, nil); err != nil { + return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body:: %v", err)} + } + + var out structs.ACLToken + if err := s.agent.RPC("ACL.Login", args, &out); err != nil { + return nil, err + } + + return &out, nil +} + +func (s *HTTPServer) ACLLogout(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkACLDisabled(resp, req) { + return nil, nil + } + + args := structs.ACLLogoutRequest{ + Datacenter: s.agent.config.Datacenter, + } + s.parseDC(req, &args.Datacenter) + s.parseToken(req, &args.Token) + + if args.Token == "" { + return nil, acl.ErrNotFound + } + + var ignored bool + if err := s.agent.RPC("ACL.Logout", &args, &ignored); err != nil { + return nil, err + } + + return true, nil +} + +// A hack to fix up the config types inside of the map[string]interface{} +// so that they get formatted correctly during json.Marshal. Without this, +// string values that get converted to []uint8 end up getting output back +// to the user in base64-encoded form. +func fixupAuthMethodConfig(method *structs.ACLAuthMethod) { + for k, v := range method.Config { + if raw, ok := v.([]uint8); ok { + strVal := structs.Uint8ToString(raw) + method.Config[k] = strVal + } + } +} diff --git a/agent/acl_endpoint_test.go b/agent/acl_endpoint_test.go index 754b931cc6..ab92c1457a 100644 --- a/agent/acl_endpoint_test.go +++ b/agent/acl_endpoint_test.go @@ -8,6 +8,8 @@ import ( "net/http/httptest" "testing" + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/testrpc" "github.com/stretchr/testify/require" @@ -40,6 +42,17 @@ func TestACL_Disabled_Response(t *testing.T) { {"ACLTokenCreate", a.srv.ACLTokenCreate}, {"ACLTokenSelf", a.srv.ACLTokenSelf}, {"ACLTokenCRUD", a.srv.ACLTokenCRUD}, + {"ACLRoleList", a.srv.ACLRoleList}, + {"ACLRoleCreate", a.srv.ACLRoleCreate}, + {"ACLRoleCRUD", a.srv.ACLRoleCRUD}, + {"ACLBindingRuleList", a.srv.ACLBindingRuleList}, + {"ACLBindingRuleCreate", a.srv.ACLBindingRuleCreate}, + {"ACLBindingRuleCRUD", a.srv.ACLBindingRuleCRUD}, + {"ACLAuthMethodList", a.srv.ACLAuthMethodList}, + {"ACLAuthMethodCreate", a.srv.ACLAuthMethodCreate}, + {"ACLAuthMethodCRUD", a.srv.ACLAuthMethodCRUD}, + {"ACLLogin", a.srv.ACLLogin}, + {"ACLLogout", a.srv.ACLLogout}, } testrpc.WaitForLeader(t, a.RPC, "dc1") for _, tt := range tests { @@ -119,6 +132,7 @@ func TestACL_HTTP(t *testing.T) { idMap := make(map[string]string) policyMap := make(map[string]*structs.ACLPolicy) + roleMap := make(map[string]*structs.ACLRole) tokenMap := make(map[string]*structs.ACLToken) // This is all done as a subtest for a couple reasons @@ -220,7 +234,7 @@ func TestACL_HTTP(t *testing.T) { policyMap[policy.ID] = policy }) - t.Run("Update Name ID Mistmatch", func(t *testing.T) { + t.Run("Update Name ID Mismatch", func(t *testing.T) { policyInput := &structs.ACLPolicy{ ID: "ac7560be-7f11-4d6d-bfcf-15633c2090fd", Name: "read-all-nodes", @@ -355,6 +369,222 @@ func TestACL_HTTP(t *testing.T) { }) }) + t.Run("Role", func(t *testing.T) { + t.Run("Create", func(t *testing.T) { + roleInput := &structs.ACLRole{ + Name: "test", + Description: "test", + Policies: []structs.ACLRolePolicyLink{ + structs.ACLRolePolicyLink{ + ID: idMap["policy-test"], + Name: policyMap[idMap["policy-test"]].Name, + }, + structs.ACLRolePolicyLink{ + ID: idMap["policy-read-all-nodes"], + Name: policyMap[idMap["policy-read-all-nodes"]].Name, + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLRoleCreate(resp, req) + require.NoError(t, err) + + role, ok := obj.(*structs.ACLRole) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, role.ID, 36) + require.Equal(t, roleInput.Name, role.Name) + require.Equal(t, roleInput.Description, role.Description) + require.Equal(t, roleInput.Policies, role.Policies) + require.True(t, role.CreateIndex > 0) + require.Equal(t, role.CreateIndex, role.ModifyIndex) + require.NotNil(t, role.Hash) + require.NotEqual(t, role.Hash, []byte{}) + + idMap["role-test"] = role.ID + roleMap[role.ID] = role + }) + + t.Run("Name Chars", func(t *testing.T) { + roleInput := &structs.ACLRole{ + Name: "service-id-web", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLRoleCreate(resp, req) + require.NoError(t, err) + + role, ok := obj.(*structs.ACLRole) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, role.ID, 36) + require.Equal(t, roleInput.Name, role.Name) + require.Equal(t, roleInput.Description, role.Description) + require.Equal(t, roleInput.ServiceIdentities, role.ServiceIdentities) + require.True(t, role.CreateIndex > 0) + require.Equal(t, role.CreateIndex, role.ModifyIndex) + require.NotNil(t, role.Hash) + require.NotEqual(t, role.Hash, []byte{}) + + idMap["role-service-id-web"] = role.ID + roleMap[role.ID] = role + }) + + t.Run("Update Name ID Mismatch", func(t *testing.T) { + roleInput := &structs.ACLRole{ + ID: "ac7560be-7f11-4d6d-bfcf-15633c2090fd", + Name: "test", + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "db", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role/"+idMap["role-test"]+"?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Role CRUD Missing ID in URL", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/role/?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Update", func(t *testing.T) { + roleInput := &structs.ACLRole{ + Name: "test", + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "web-indexer", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role/"+idMap["role-test"]+"?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLRoleCRUD(resp, req) + require.NoError(t, err) + + role, ok := obj.(*structs.ACLRole) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, role.ID, 36) + require.Equal(t, roleInput.Name, role.Name) + require.Equal(t, roleInput.Description, role.Description) + require.Equal(t, roleInput.Policies, role.Policies) + require.Equal(t, roleInput.ServiceIdentities, role.ServiceIdentities) + require.True(t, role.CreateIndex > 0) + require.True(t, role.CreateIndex < role.ModifyIndex) + require.NotNil(t, role.Hash) + require.NotEqual(t, role.Hash, []byte{}) + + idMap["role-test"] = role.ID + roleMap[role.ID] = role + }) + + t.Run("ID Supplied", func(t *testing.T) { + roleInput := &structs.ACLRole{ + ID: "12123d01-37f1-47e6-b55b-32328652bd38", + Name: "with-id", + Description: "test", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "foobar", + }, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Invalid payload", func(t *testing.T) { + body := bytes.NewBuffer(nil) + body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", body) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Delete", func(t *testing.T) { + req, _ := http.NewRequest("DELETE", "/v1/acl/role/"+idMap["role-service-id-web"]+"?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLRoleCRUD(resp, req) + require.NoError(t, err) + delete(roleMap, idMap["role-service-id-web"]) + delete(idMap, "role-service-id-web") + }) + + t.Run("List", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/roles?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLRoleList(resp, req) + require.NoError(t, err) + roles, ok := raw.(structs.ACLRoles) + require.True(t, ok) + + // 1 we just created + require.Len(t, roles, 1) + + for roleID, expected := range roleMap { + found := false + for _, actual := range roles { + if actual.ID == roleID { + require.Equal(t, expected.Name, actual.Name) + require.Equal(t, expected.Policies, actual.Policies) + require.Equal(t, expected.ServiceIdentities, actual.ServiceIdentities) + require.Equal(t, expected.Hash, actual.Hash) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + + require.True(t, found) + } + }) + + t.Run("Read", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/role/"+idMap["role-test"]+"?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLRoleCRUD(resp, req) + require.NoError(t, err) + role, ok := raw.(*structs.ACLRole) + require.True(t, ok) + require.Equal(t, roleMap[idMap["role-test"]], role) + }) + }) + t.Run("Token", func(t *testing.T) { t.Run("Create", func(t *testing.T) { tokenInput := &structs.ACLToken{ @@ -594,3 +824,504 @@ func TestACL_HTTP(t *testing.T) { }) }) } + +func TestACL_LoginProcedure_HTTP(t *testing.T) { + // This tests AuthMethods, BindingRules, Login, and Logout. + t.Parallel() + a := NewTestAgent(t, t.Name(), TestACLConfig()) + defer a.Shutdown() + + testrpc.WaitForLeader(t, a.RPC, "dc1") + + idMap := make(map[string]string) + methodMap := make(map[string]*structs.ACLAuthMethod) + ruleMap := make(map[string]*structs.ACLBindingRule) + tokenMap := make(map[string]*structs.ACLToken) + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + // This is all done as a subtest for a couple reasons + // 1. It uses only 1 test agent and these are + // somewhat expensive to bring up and tear down often + // 2. Instead of having to bring up a new agent and prime + // the ACL system with some data before running the test + // we can intelligently order these tests so we can still + // test everything with less actual operations and do + // so in a manner that is less prone to being flaky + // 3. While this test will be large it should + t.Run("AuthMethod", func(t *testing.T) { + t.Run("Create", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLAuthMethodCreate(resp, req) + require.NoError(t, err) + + method, ok := obj.(*structs.ACLAuthMethod) + require.True(t, ok) + + require.Equal(t, methodInput.Name, method.Name) + require.Equal(t, methodInput.Type, method.Type) + require.Equal(t, methodInput.Description, method.Description) + require.Equal(t, methodInput.Config, method.Config) + require.True(t, method.CreateIndex > 0) + require.Equal(t, method.CreateIndex, method.ModifyIndex) + + methodMap[method.Name] = method + }) + + t.Run("Create other", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "other", + Type: "testing", + Description: "test", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLAuthMethodCreate(resp, req) + require.NoError(t, err) + + method, ok := obj.(*structs.ACLAuthMethod) + require.True(t, ok) + + require.Equal(t, methodInput.Name, method.Name) + require.Equal(t, methodInput.Type, method.Type) + require.Equal(t, methodInput.Description, method.Description) + require.Equal(t, methodInput.Config, method.Config) + require.True(t, method.CreateIndex > 0) + require.Equal(t, method.CreateIndex, method.ModifyIndex) + + methodMap[method.Name] = method + }) + + t.Run("Update Name URL Mismatch", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method/not-test?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Update", func(t *testing.T) { + methodInput := &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "updated description", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + } + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method/test?token=root", jsonBody(methodInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.NoError(t, err) + + method, ok := obj.(*structs.ACLAuthMethod) + require.True(t, ok) + + require.Equal(t, methodInput.Name, method.Name) + require.Equal(t, methodInput.Type, method.Type) + require.Equal(t, methodInput.Description, method.Description) + require.Equal(t, methodInput.Config, method.Config) + require.True(t, method.CreateIndex > 0) + require.True(t, method.CreateIndex < method.ModifyIndex) + + methodMap[method.Name] = method + }) + + t.Run("Invalid payload", func(t *testing.T) { + body := bytes.NewBuffer(nil) + body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", body) + resp := httptest.NewRecorder() + _, err := a.srv.ACLAuthMethodCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("List", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/auth-methods?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLAuthMethodList(resp, req) + require.NoError(t, err) + methods, ok := raw.(structs.ACLAuthMethodListStubs) + require.True(t, ok) + + // 2 we just created + require.Len(t, methods, 2) + + for methodName, expected := range methodMap { + found := false + for _, actual := range methods { + if actual.Name == methodName { + require.Equal(t, expected.Name, actual.Name) + require.Equal(t, expected.Type, actual.Type) + require.Equal(t, expected.Description, actual.Description) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + + require.True(t, found) + } + }) + + t.Run("Delete", func(t *testing.T) { + req, _ := http.NewRequest("DELETE", "/v1/acl/auth-method/other?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.NoError(t, err) + delete(methodMap, "other") + }) + + t.Run("Read", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/auth-method/test?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLAuthMethodCRUD(resp, req) + require.NoError(t, err) + method, ok := raw.(*structs.ACLAuthMethod) + require.True(t, ok) + require.Equal(t, methodMap["test"], method) + }) + }) + + t.Run("BindingRule", func(t *testing.T) { + t.Run("Create", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + Description: "test", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "web", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLBindingRuleCreate(resp, req) + require.NoError(t, err) + + rule, ok := obj.(*structs.ACLBindingRule) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, rule.ID, 36) + require.Equal(t, ruleInput.Description, rule.Description) + require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod) + require.Equal(t, ruleInput.Selector, rule.Selector) + require.Equal(t, ruleInput.BindType, rule.BindType) + require.Equal(t, ruleInput.BindName, rule.BindName) + require.True(t, rule.CreateIndex > 0) + require.Equal(t, rule.CreateIndex, rule.ModifyIndex) + + idMap["rule-test"] = rule.ID + ruleMap[rule.ID] = rule + }) + + t.Run("Create other", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + Description: "other", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeRole, + BindName: "fancy-role", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLBindingRuleCreate(resp, req) + require.NoError(t, err) + + rule, ok := obj.(*structs.ACLBindingRule) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, rule.ID, 36) + require.Equal(t, ruleInput.Description, rule.Description) + require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod) + require.Equal(t, ruleInput.Selector, rule.Selector) + require.Equal(t, ruleInput.BindType, rule.BindType) + require.Equal(t, ruleInput.BindName, rule.BindName) + require.True(t, rule.CreateIndex > 0) + require.Equal(t, rule.CreateIndex, rule.ModifyIndex) + + idMap["rule-other"] = rule.ID + ruleMap[rule.ID] = rule + }) + + t.Run("BindingRule CRUD Missing ID in URL", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/binding-rule/?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Update", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + Description: "updated", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule/"+idMap["rule-test"]+"?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.NoError(t, err) + + rule, ok := obj.(*structs.ACLBindingRule) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, rule.ID, 36) + require.Equal(t, ruleInput.Description, rule.Description) + require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod) + require.Equal(t, ruleInput.Selector, rule.Selector) + require.Equal(t, ruleInput.BindType, rule.BindType) + require.Equal(t, ruleInput.BindName, rule.BindName) + require.True(t, rule.CreateIndex > 0) + require.True(t, rule.CreateIndex < rule.ModifyIndex) + + idMap["rule-test"] = rule.ID + ruleMap[rule.ID] = rule + }) + + t.Run("ID Supplied", func(t *testing.T) { + ruleInput := &structs.ACLBindingRule{ + ID: "12123d01-37f1-47e6-b55b-32328652bd38", + Description: "with-id", + AuthMethod: "test", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "vault", + } + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput)) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("Invalid payload", func(t *testing.T) { + body := bytes.NewBuffer(nil) + body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", body) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCreate(resp, req) + require.Error(t, err) + _, ok := err.(BadRequestError) + require.True(t, ok) + }) + + t.Run("List", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/binding-rules?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLBindingRuleList(resp, req) + require.NoError(t, err) + rules, ok := raw.(structs.ACLBindingRules) + require.True(t, ok) + + // 2 we just created + require.Len(t, rules, 2) + + for ruleID, expected := range ruleMap { + found := false + for _, actual := range rules { + if actual.ID == ruleID { + require.Equal(t, expected.Description, actual.Description) + require.Equal(t, expected.AuthMethod, actual.AuthMethod) + require.Equal(t, expected.Selector, actual.Selector) + require.Equal(t, expected.BindType, actual.BindType) + require.Equal(t, expected.BindName, actual.BindName) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + + require.True(t, found) + } + }) + + t.Run("Delete", func(t *testing.T) { + req, _ := http.NewRequest("DELETE", "/v1/acl/binding-rule/"+idMap["rule-other"]+"?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.NoError(t, err) + delete(ruleMap, idMap["rule-other"]) + delete(idMap, "rule-other") + }) + + t.Run("Read", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/binding-rule/"+idMap["rule-test"]+"?token=root", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLBindingRuleCRUD(resp, req) + require.NoError(t, err) + rule, ok := raw.(*structs.ACLBindingRule) + require.True(t, ok) + require.Equal(t, ruleMap[idMap["rule-test"]], rule) + }) + }) + + testauth.InstallSessionToken(testSessionID, "token1", "default", "demo1", "abc123") + testauth.InstallSessionToken(testSessionID, "token2", "default", "demo2", "def456") + + t.Run("Login", func(t *testing.T) { + t.Run("Create Token 1", func(t *testing.T) { + loginInput := &structs.ACLLoginParams{ + AuthMethod: "test", + BearerToken: "token1", + Meta: map[string]string{"foo": "bar"}, + } + + req, _ := http.NewRequest("POST", "/v1/acl/login?token=root", jsonBody(loginInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLLogin(resp, req) + require.NoError(t, err) + + token, ok := obj.(*structs.ACLToken) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, token.AccessorID, 36) + require.Len(t, token.SecretID, 36) + require.Equal(t, `token created via login: {"foo":"bar"}`, token.Description) + require.True(t, token.Local) + require.Len(t, token.Policies, 0) + require.Len(t, token.Roles, 0) + require.Len(t, token.ServiceIdentities, 1) + require.Equal(t, "demo1", token.ServiceIdentities[0].ServiceName) + require.Len(t, token.ServiceIdentities[0].Datacenters, 0) + require.True(t, token.CreateIndex > 0) + require.Equal(t, token.CreateIndex, token.ModifyIndex) + require.NotNil(t, token.Hash) + require.NotEqual(t, token.Hash, []byte{}) + + idMap["token-test-1"] = token.AccessorID + tokenMap[token.AccessorID] = token + }) + t.Run("Create Token 2", func(t *testing.T) { + loginInput := &structs.ACLLoginParams{ + AuthMethod: "test", + BearerToken: "token2", + Meta: map[string]string{"blah": "woot"}, + } + + req, _ := http.NewRequest("POST", "/v1/acl/login?token=root", jsonBody(loginInput)) + resp := httptest.NewRecorder() + obj, err := a.srv.ACLLogin(resp, req) + require.NoError(t, err) + + token, ok := obj.(*structs.ACLToken) + require.True(t, ok) + + // 36 = length of the string form of uuids + require.Len(t, token.AccessorID, 36) + require.Len(t, token.SecretID, 36) + require.Equal(t, `token created via login: {"blah":"woot"}`, token.Description) + require.True(t, token.Local) + require.Len(t, token.Policies, 0) + require.Len(t, token.Roles, 0) + require.Len(t, token.ServiceIdentities, 1) + require.Equal(t, "demo2", token.ServiceIdentities[0].ServiceName) + require.Len(t, token.ServiceIdentities[0].Datacenters, 0) + require.True(t, token.CreateIndex > 0) + require.Equal(t, token.CreateIndex, token.ModifyIndex) + require.NotNil(t, token.Hash) + require.NotEqual(t, token.Hash, []byte{}) + + idMap["token-test-2"] = token.AccessorID + tokenMap[token.AccessorID] = token + }) + + t.Run("List Tokens by (incorrect) Method", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/tokens?token=root&authmethod=other", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLTokenList(resp, req) + require.NoError(t, err) + tokens, ok := raw.(structs.ACLTokenListStubs) + require.True(t, ok) + require.Len(t, tokens, 0) + }) + + t.Run("List Tokens by (correct) Method", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/tokens?token=root&authmethod=test", nil) + resp := httptest.NewRecorder() + raw, err := a.srv.ACLTokenList(resp, req) + require.NoError(t, err) + tokens, ok := raw.(structs.ACLTokenListStubs) + require.True(t, ok) + require.Len(t, tokens, 2) + + for tokenID, expected := range tokenMap { + found := false + for _, actual := range tokens { + if actual.AccessorID == tokenID { + require.Equal(t, expected.Description, actual.Description) + require.Equal(t, expected.Policies, actual.Policies) + require.Equal(t, expected.Roles, actual.Roles) + require.Equal(t, expected.ServiceIdentities, actual.ServiceIdentities) + require.Equal(t, expected.Local, actual.Local) + require.Equal(t, expected.CreateTime, actual.CreateTime) + require.Equal(t, expected.Hash, actual.Hash) + require.Equal(t, expected.CreateIndex, actual.CreateIndex) + require.Equal(t, expected.ModifyIndex, actual.ModifyIndex) + found = true + break + } + } + require.True(t, found) + } + }) + + t.Run("Logout", func(t *testing.T) { + tok := tokenMap[idMap["token-test-1"]] + req, _ := http.NewRequest("POST", "/v1/acl/logout?token="+tok.SecretID, nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLLogout(resp, req) + require.NoError(t, err) + }) + + t.Run("Token is gone after Logout", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/v1/acl/token/"+idMap["token-test-1"]+"?token=root", nil) + resp := httptest.NewRecorder() + _, err := a.srv.ACLTokenCRUD(resp, req) + require.Error(t, err) + require.True(t, acl.IsErrNotFound(err), err.Error()) + }) + }) +} diff --git a/agent/consul/acl.go b/agent/consul/acl.go index 74ebb90385..6e8130af27 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -447,25 +447,8 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent return out, nil } - if acl.IsErrNotFound(err) { - // make sure to indicate that this identity is no longer valid within - // the cache - r.cache.PutIdentity(identity.SecretToken(), nil) - - // Do not touch the policy cache. Getting a top level ACL not found error - // only indicates that the secret token used in the request - // no longer exists - return nil, &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()} - } - - if acl.IsErrPermissionDenied(err) { - // invalidate our ID cache so that identity resolution will take place - // again in the future - r.cache.RemoveIdentity(identity.SecretToken()) - - // Do not remove from the policy cache for permission denied - // what this does indicate is that our view of the token is out of date - return nil, &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + if handledErr := r.maybeHandleIdentityErrorDuringFetch(identity, err); handledErr != nil { + return nil, handledErr } // other RPC error - use cache if available @@ -519,25 +502,8 @@ func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity return out, nil } - if acl.IsErrNotFound(err) { - // make sure to indicate that this identity is no longer valid within - // the cache - r.cache.PutIdentity(identity.SecretToken(), nil) - - // Do not touch the cache. Getting a top level ACL not found error - // only indicates that the secret token used in the request - // no longer exists - return nil, &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()} - } - - if acl.IsErrPermissionDenied(err) { - // invalidate our ID cache so that identity resolution will take place - // again in the future - r.cache.RemoveIdentity(identity.SecretToken()) - - // Do not remove from the cache for permission denied - // what this does indicate is that our view of the token is out of date - return nil, &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + if handledErr := r.maybeHandleIdentityErrorDuringFetch(identity, err); handledErr != nil { + return nil, handledErr } // other RPC error - use cache if available @@ -557,12 +523,39 @@ func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity insufficientCache = true } } + if insufficientCache { return nil, ACLRemoteError{Err: err} } + return out, nil } +func (r *ACLResolver) maybeHandleIdentityErrorDuringFetch(identity structs.ACLIdentity, err error) error { + if acl.IsErrNotFound(err) { + // make sure to indicate that this identity is no longer valid within + // the cache + r.cache.PutIdentity(identity.SecretToken(), nil) + + // Do not touch the cache. Getting a top level ACL not found error + // only indicates that the secret token used in the request + // no longer exists + return &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()} + } + + if acl.IsErrPermissionDenied(err) { + // invalidate our ID cache so that identity resolution will take place + // again in the future + r.cache.RemoveIdentity(identity.SecretToken()) + + // Do not remove from the cache for permission denied + // what this does indicate is that our view of the token is out of date + return &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()} + } + + return nil +} + func (r *ACLResolver) filterPoliciesByScope(policies structs.ACLPolicies) structs.ACLPolicies { var out structs.ACLPolicies for _, policy := range policies { diff --git a/agent/consul/acl_authmethod.go b/agent/consul/acl_authmethod.go new file mode 100644 index 0000000000..ba3b3772da --- /dev/null +++ b/agent/consul/acl_authmethod.go @@ -0,0 +1,169 @@ +package consul + +import ( + "fmt" + + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-bexpr" + + // register this as a builtin auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" +) + +type authMethodValidatorEntry struct { + Validator authmethod.Validator + ModifyIndex uint64 // the raft index when this last changed +} + +// loadAuthMethodValidator returns an authmethod.Validator for the given auth +// method configuration. If the cache is up to date as-of the provided index +// then the cached version is returned, otherwise a new validator is created +// and cached. +func (s *Server) loadAuthMethodValidator(idx uint64, method *structs.ACLAuthMethod) (authmethod.Validator, error) { + if prevIdx, v, ok := s.getCachedAuthMethodValidator(method.Name); ok && idx <= prevIdx { + return v, nil + } + + v, err := authmethod.NewValidator(method) + if err != nil { + return nil, fmt.Errorf("auth method validator for %q could not be initialized: %v", method.Name, err) + } + + v = s.getOrReplaceAuthMethodValidator(method.Name, idx, v) + + return v, nil +} + +// getCachedAuthMethodValidator returns an AuthMethodValidator for +// the given name exclusively from the cache. If one is not found in the cache +// nil is returned. +func (s *Server) getCachedAuthMethodValidator(name string) (uint64, authmethod.Validator, bool) { + s.aclAuthMethodValidatorLock.RLock() + defer s.aclAuthMethodValidatorLock.RUnlock() + + if s.aclAuthMethodValidators != nil { + v, ok := s.aclAuthMethodValidators[name] + if ok { + return v.ModifyIndex, v.Validator, true + } + } + return 0, nil, false +} + +// getOrReplaceAuthMethodValidator updates the cached validator with the +// provided one UNLESS it has been updated by another goroutine in which case +// the updated one is returned. +func (s *Server) getOrReplaceAuthMethodValidator(name string, idx uint64, v authmethod.Validator) authmethod.Validator { + s.aclAuthMethodValidatorLock.Lock() + defer s.aclAuthMethodValidatorLock.Unlock() + + if s.aclAuthMethodValidators == nil { + s.aclAuthMethodValidators = make(map[string]*authMethodValidatorEntry) + } + + prev, ok := s.aclAuthMethodValidators[name] + if ok { + if prev.ModifyIndex >= idx { + return prev.Validator + } + } + + s.logger.Printf("[DEBUG] acl: updating cached auth method validator for %q", name) + + s.aclAuthMethodValidators[name] = &authMethodValidatorEntry{ + Validator: v, + ModifyIndex: idx, + } + return v +} + +// purgeAuthMethodValidators resets the cache of validators. +func (s *Server) purgeAuthMethodValidators() { + s.aclAuthMethodValidatorLock.Lock() + s.aclAuthMethodValidators = make(map[string]*authMethodValidatorEntry) + s.aclAuthMethodValidatorLock.Unlock() +} + +// evaluateRoleBindings evaluates all current binding rules associated with the +// given auth method against the verified data returned from the authentication +// process. +// +// A list of role links and service identities are returned. +func (s *Server) evaluateRoleBindings( + validator authmethod.Validator, + verifiedFields map[string]string, +) ([]*structs.ACLServiceIdentity, []structs.ACLTokenRoleLink, error) { + // Only fetch rules that are relevant for this method. + _, rules, err := s.fsm.State().ACLBindingRuleList(nil, validator.Name()) + if err != nil { + return nil, nil, err + } else if len(rules) == 0 { + return nil, nil, nil + } + + // Convert the fields into something suitable for go-bexpr. + selectableVars := validator.MakeFieldMapSelectable(verifiedFields) + + // Find all binding rules that match the provided fields. + var matchingRules []*structs.ACLBindingRule + for _, rule := range rules { + if doesBindingRuleMatch(rule, selectableVars) { + matchingRules = append(matchingRules, rule) + } + } + if len(matchingRules) == 0 { + return nil, nil, nil + } + + // For all matching rules compute the attributes of a token. + var ( + roleLinks []structs.ACLTokenRoleLink + serviceIdentities []*structs.ACLServiceIdentity + ) + for _, rule := range matchingRules { + bindName, valid, err := computeBindingRuleBindName(rule.BindType, rule.BindName, verifiedFields) + if err != nil { + return nil, nil, fmt.Errorf("cannot compute %q bind name for bind target: %v", rule.BindType, err) + } else if !valid { + return nil, nil, fmt.Errorf("computed %q bind name for bind target is invalid: %q", rule.BindType, bindName) + } + + switch rule.BindType { + case structs.BindingRuleBindTypeService: + serviceIdentities = append(serviceIdentities, &structs.ACLServiceIdentity{ + ServiceName: bindName, + }) + + case structs.BindingRuleBindTypeRole: + roleLinks = append(roleLinks, structs.ACLTokenRoleLink{ + Name: bindName, + }) + + default: + // skip unknown bind type; don't grant privileges + } + } + + return serviceIdentities, roleLinks, nil +} + +// doesBindingRuleMatch checks that a single binding rule matches the provided +// vars. +func doesBindingRuleMatch(rule *structs.ACLBindingRule, selectableVars interface{}) bool { + if rule.Selector == "" { + return true // catch-all + } + + eval, err := bexpr.CreateEvaluatorForType(rule.Selector, nil, selectableVars) + if err != nil { + return false // fails to match if selector is invalid + } + + result, err := eval.Evaluate(selectableVars) + if err != nil { + return false // fails to match if evaluation fails + } + + return result +} diff --git a/agent/consul/acl_authmethod_test.go b/agent/consul/acl_authmethod_test.go new file mode 100644 index 0000000000..45e3021e44 --- /dev/null +++ b/agent/consul/acl_authmethod_test.go @@ -0,0 +1,48 @@ +package consul + +import ( + "testing" + + "github.com/hashicorp/consul/agent/structs" + "github.com/stretchr/testify/require" +) + +func TestDoesBindingRuleMatch(t *testing.T) { + type matchable struct { + A string `bexpr:"a"` + C string `bexpr:"c"` + } + + for _, test := range []struct { + name string + selector string + details interface{} + ok bool + }{ + {"no fields", + "a==b", nil, false}, + {"1 term ok", + "a==b", &matchable{A: "b"}, true}, + {"1 term no field", + "a==b", &matchable{C: "d"}, false}, + {"1 term wrong value", + "a==b", &matchable{A: "z"}, false}, + {"2 terms ok", + "a==b and c==d", &matchable{A: "b", C: "d"}, true}, + {"2 terms one missing field", + "a==b and c==d", &matchable{A: "b"}, false}, + {"2 terms one wrong value", + "a==b and c==d", &matchable{A: "z", C: "d"}, false}, + /////////////////////////////// + {"no fields (no selectors)", + "", nil, true}, + {"1 term ok (no selectors)", + "", &matchable{A: "b"}, true}, + } { + t.Run(test.name, func(t *testing.T) { + rule := structs.ACLBindingRule{Selector: test.selector} + ok := doesBindingRuleMatch(&rule, test.details) + require.Equal(t, test.ok, ok) + }) + } +} diff --git a/agent/consul/acl_endpoint.go b/agent/consul/acl_endpoint.go index 62c17496f9..cc3b5da2e7 100644 --- a/agent/consul/acl_endpoint.go +++ b/agent/consul/acl_endpoint.go @@ -1,6 +1,8 @@ package consul import ( + "encoding/json" + "errors" "fmt" "io/ioutil" "os" @@ -10,9 +12,11 @@ import ( metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib" + "github.com/hashicorp/go-bexpr" memdb "github.com/hashicorp/go-memdb" uuid "github.com/hashicorp/go-uuid" ) @@ -29,6 +33,7 @@ var ( validServiceIdentityName = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-_]*[a-z0-9])?$`) serviceIdentityNameMaxLength = 256 validRoleName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,256}$`) + validAuthMethod = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,128}$`) ) // ACL endpoint is used to manipulate ACLs @@ -273,6 +278,10 @@ func (a *ACL) TokenClone(args *structs.ACLTokenSetRequest, reply *structs.ACLTok return a.srv.forwardDC("ACL.TokenClone", a.srv.config.ACLDatacenter, args, reply) } + if token.AuthMethod != "" { + return fmt.Errorf("Cannot clone a token created from an auth method") + } + if token.Rules != "" { return fmt.Errorf("Cannot clone a legacy ACL with this endpoint") } @@ -324,7 +333,7 @@ func (a *ACL) TokenSet(args *structs.ACLTokenSetRequest, reply *structs.ACLToken return a.tokenSetInternal(args, reply, false) } -func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.ACLToken, upgrade bool) error { +func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.ACLToken, fromLogin bool) error { token := &args.ACLToken if !a.srv.LocalTokensEnabled() { @@ -354,6 +363,19 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. token.CreateTime = time.Now() + if fromLogin { + if token.AuthMethod == "" { + return fmt.Errorf("AuthMethod field is required during Login") + } + if !token.Local { + return fmt.Errorf("Cannot create Global token via Login") + } + } else { + if token.AuthMethod != "" { + return fmt.Errorf("AuthMethod field is disallowed outside of Login") + } + } + // Ensure an ExpirationTTL is valid if provided. if token.ExpirationTTL != 0 { if token.ExpirationTTL < 0 { @@ -418,6 +440,12 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return fmt.Errorf("cannot toggle local mode of %s", token.AccessorID) } + if token.AuthMethod == "" { + token.AuthMethod = existing.AuthMethod + } else if token.AuthMethod != existing.AuthMethod { + return fmt.Errorf("Cannot change AuthMethod of %s", token.AccessorID) + } + if token.ExpirationTTL != 0 { return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) } @@ -430,11 +458,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID) } - if upgrade { - token.CreateTime = time.Now() - } else { - token.CreateTime = existing.CreateTime - } + token.CreateTime = existing.CreateTime } policyIDs := make(map[string]struct{}) @@ -467,7 +491,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. roleIDs := make(map[string]struct{}) var roles []structs.ACLTokenRoleLink - // Validate all the role names and convert them to role IDs + // Validate all the role names and convert them to role IDs. for _, link := range token.Roles { if link.ID == "" { _, role, err := state.ACLRoleGetByName(nil, link.Name) @@ -502,6 +526,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName) } } + token.ServiceIdentities = dedupeServiceIdentities(token.ServiceIdentities) if token.Rules != "" { return fmt.Errorf("Rules cannot be specified for this token") @@ -540,6 +565,51 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs. return nil } +func validateBindingRuleBindName(bindType, bindName string, availableFields []string) (bool, error) { + if bindType == "" || bindName == "" { + return false, nil + } + + fakeVarMap := make(map[string]string) + for _, v := range availableFields { + fakeVarMap[v] = "fake" + } + + _, valid, err := computeBindingRuleBindName(bindType, bindName, fakeVarMap) + if err != nil { + return false, err + } + return valid, nil +} + +// computeBindingRuleBindName processes the HIL for the provided bind type+name +// using the verified fields. +// +// - If the HIL is invalid ("", false, AN_ERROR) is returned. +// - If the computed name is not valid for the type ("INVALID_NAME", false, nil) is returned. +// - If the computed name is valid for the type ("VALID_NAME", true, nil) is returned. +func computeBindingRuleBindName(bindType, bindName string, verifiedFields map[string]string) (string, bool, error) { + bindName, err := InterpolateHIL(bindName, verifiedFields) + if err != nil { + return "", false, err + } + + valid := false + + switch bindType { + case structs.BindingRuleBindTypeService: + valid = isValidServiceIdentityName(bindName) + + case structs.BindingRuleBindTypeRole: + valid = validRoleName.MatchString(bindName) + + default: + return "", false, fmt.Errorf("unknown binding rule bind type: %s", bindType) + } + + return bindName, valid, nil +} + // isValidServiceIdentityName returns true if the provided name can be used as // an ACLServiceIdentity ServiceName. This is more restrictive than standard // catalog registration, which basically takes the view that "everything is @@ -652,7 +722,7 @@ func (a *ACL) TokenList(args *structs.ACLTokenListRequest, reply *structs.ACLTok return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy, args.Role) + index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy, args.Role, args.AuthMethod) if err != nil { return err } @@ -1252,17 +1322,11 @@ func (a *ACL) RoleSet(args *structs.ACLRoleSetRequest, reply *structs.ACLRole) e if svcid.ServiceName == "" { return fmt.Errorf("Service identity is missing the service name field on this role") } - // TODO(rb): ugh if a local token gets a role that has a service - // identity that has datacenters set, we won't be anble to enforce this - // next blob here. This makes me lean more towards nuking ServiceIdentity.Datacenters again - // - // if token.Local && len(svcid.Datacenters) > 0 { - // return fmt.Errorf("Service identity %q cannot specify a list of datacenters on a local token", svcid.ServiceName) - // } if !isValidServiceIdentityName(svcid.ServiceName) { return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName) } } + role.ServiceIdentities = dedupeServiceIdentities(role.ServiceIdentities) // calculate the hash for this role role.SetHash(true) @@ -1412,3 +1476,577 @@ func (a *ACL) RoleResolve(args *structs.ACLRoleBatchGetRequest, reply *structs.A return nil } + +var errAuthMethodsRequireTokenReplication = errors.New("Token replication is required for auth methods to function") + +func (a *ACL) BindingRuleRead(args *structs.ACLBindingRuleGetRequest, reply *structs.ACLBindingRuleResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, rule, err := state.ACLBindingRuleGetByID(ws, args.BindingRuleID) + + if err != nil { + return err + } + + reply.Index, reply.BindingRule = index, rule + return nil + }) +} + +func (a *ACL) BindingRuleSet(args *structs.ACLBindingRuleSetRequest, reply *structs.ACLBindingRule) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleSet", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "bindingrule", "upsert"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + rule := &args.BindingRule + state := a.srv.fsm.State() + + if rule.ID == "" { + // with no binding rule ID one will be generated + var err error + + rule.ID, err = lib.GenerateUUID(a.srv.checkBindingRuleUUID) + if err != nil { + return err + } + } else { + if _, err := uuid.ParseUUID(rule.ID); err != nil { + return fmt.Errorf("Binding Rule ID invalid UUID") + } + + // Verify the role exists + _, existing, err := state.ACLBindingRuleGetByID(nil, rule.ID) + if err != nil { + return fmt.Errorf("acl binding rule lookup failed: %v", err) + } else if existing == nil { + return fmt.Errorf("cannot find binding rule %s", rule.ID) + } + + if rule.AuthMethod == "" { + rule.AuthMethod = existing.AuthMethod + } else if existing.AuthMethod != rule.AuthMethod { + return fmt.Errorf("the AuthMethod field of an Binding Rule is immutable") + } + } + + if rule.AuthMethod == "" { + return fmt.Errorf("Invalid Binding Rule: no AuthMethod is set") + } + + methodIdx, method, err := state.ACLAuthMethodGetByName(nil, rule.AuthMethod) + if err != nil { + return fmt.Errorf("acl auth method lookup failed: %v", err) + } else if method == nil { + return fmt.Errorf("cannot find auth method with name %q", rule.AuthMethod) + } + validator, err := a.srv.loadAuthMethodValidator(methodIdx, method) + if err != nil { + return err + } + + if rule.Selector != "" { + selectableVars := validator.MakeFieldMapSelectable(map[string]string{}) + _, err := bexpr.CreateEvaluatorForType(rule.Selector, nil, selectableVars) + if err != nil { + return fmt.Errorf("invalid Binding Rule: Selector is invalid: %v", err) + } + } + + if rule.BindType == "" { + return fmt.Errorf("Invalid Binding Rule: no BindType is set") + } + + if rule.BindName == "" { + return fmt.Errorf("Invalid Binding Rule: no BindName is set") + } + + switch rule.BindType { + case structs.BindingRuleBindTypeService: + case structs.BindingRuleBindTypeRole: + default: + return fmt.Errorf("Invalid Binding Rule: unknown BindType %q", rule.BindType) + } + + if valid, err := validateBindingRuleBindName(rule.BindType, rule.BindName, validator.AvailableFields()); err != nil { + return fmt.Errorf("Invalid Binding Rule: invalid BindName: %v", err) + } else if !valid { + return fmt.Errorf("Invalid Binding Rule: invalid BindName") + } + + req := &structs.ACLBindingRuleBatchSetRequest{ + BindingRules: structs.ACLBindingRules{rule}, + } + + resp, err := a.srv.raftApply(structs.ACLBindingRuleSetRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply binding rule upsert request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + if _, rule, err := a.srv.fsm.State().ACLBindingRuleGetByID(nil, rule.ID); err == nil && rule != nil { + *reply = *rule + } + + return nil +} + +func (a *ACL) BindingRuleDelete(args *structs.ACLBindingRuleDeleteRequest, reply *bool) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleDelete", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "bindingrule", "delete"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + _, rule, err := a.srv.fsm.State().ACLBindingRuleGetByID(nil, args.BindingRuleID) + if err != nil { + return err + } + + if rule == nil { + return nil + } + + req := structs.ACLBindingRuleBatchDeleteRequest{ + BindingRuleIDs: []string{args.BindingRuleID}, + } + + resp, err := a.srv.raftApply(structs.ACLBindingRuleDeleteRequestType, &req) + if err != nil { + return fmt.Errorf("Failed to apply binding rule delete request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + *reply = true + + return nil +} + +func (a *ACL) BindingRuleList(args *structs.ACLBindingRuleListRequest, reply *structs.ACLBindingRuleListResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.BindingRuleList", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, rules, err := state.ACLBindingRuleList(ws, args.AuthMethod) + if err != nil { + return err + } + + reply.Index, reply.BindingRules = index, rules + return nil + }) +} + +func (a *ACL) AuthMethodRead(args *structs.ACLAuthMethodGetRequest, reply *structs.ACLAuthMethodResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodRead", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, method, err := state.ACLAuthMethodGetByName(ws, args.AuthMethodName) + + if err != nil { + return err + } + + reply.Index, reply.AuthMethod = index, method + return nil + }) +} + +func (a *ACL) AuthMethodSet(args *structs.ACLAuthMethodSetRequest, reply *structs.ACLAuthMethod) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodSet", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "authmethod", "upsert"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + method := &args.AuthMethod + state := a.srv.fsm.State() + + // ensure a name is set + if method.Name == "" { + return fmt.Errorf("Invalid Auth Method: no Name is set") + } + if !validAuthMethod.MatchString(method.Name) { + return fmt.Errorf("Invalid Auth Method: invalid Name. Only alphanumeric characters, '-' and '_' are allowed") + } + + // Check to see if the method exists first. + _, existing, err := state.ACLAuthMethodGetByName(nil, method.Name) + if err != nil { + return fmt.Errorf("acl auth method lookup failed: %v", err) + } + + if existing != nil { + if method.Type == "" { + method.Type = existing.Type + } else if existing.Type != method.Type { + return fmt.Errorf("the Type field of an Auth Method is immutable") + } + } + + if !authmethod.IsRegisteredType(method.Type) { + return fmt.Errorf("Invalid Auth Method: Type should be one of: %v", authmethod.Types()) + } + + // Instantiate a validator but do not cache it yet. This will validate the + // configuration. + if _, err := authmethod.NewValidator(method); err != nil { + return fmt.Errorf("Invalid Auth Method: %v", err) + } + + req := &structs.ACLAuthMethodBatchSetRequest{ + AuthMethods: structs.ACLAuthMethods{method}, + } + + resp, err := a.srv.raftApply(structs.ACLAuthMethodSetRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply auth method upsert request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + if _, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, method.Name); err == nil && method != nil { + *reply = *method + } + + return nil +} + +func (a *ACL) AuthMethodDelete(args *structs.ACLAuthMethodDeleteRequest, reply *bool) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodDelete", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "authmethod", "delete"}, time.Now()) + + // Verify token is permitted to modify ACLs + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLWrite() { + return acl.ErrPermissionDenied + } + + _, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, args.AuthMethodName) + if err != nil { + return err + } + + if method == nil { + return nil + } + + req := structs.ACLAuthMethodBatchDeleteRequest{ + AuthMethodNames: []string{args.AuthMethodName}, + } + + resp, err := a.srv.raftApply(structs.ACLAuthMethodDeleteRequestType, &req) + if err != nil { + return fmt.Errorf("Failed to apply auth method delete request: %v", err) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + *reply = true + + return nil +} + +func (a *ACL) AuthMethodList(args *structs.ACLAuthMethodListRequest, reply *structs.ACLAuthMethodListResponse) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if done, err := a.srv.forward("ACL.AuthMethodList", args, args, reply); done { + return err + } + + if rule, err := a.srv.ResolveToken(args.Token); err != nil { + return err + } else if rule == nil || !rule.ACLRead() { + return acl.ErrPermissionDenied + } + + return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, + func(ws memdb.WatchSet, state *state.Store) error { + index, methods, err := state.ACLAuthMethodList(ws) + if err != nil { + return err + } + + var stubs structs.ACLAuthMethodListStubs + for _, method := range methods { + stubs = append(stubs, method.Stub()) + } + + reply.Index, reply.AuthMethods = index, stubs + return nil + }) +} + +func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if args.Token != "" { // This shouldn't happen. + return errors.New("do not provide a token when logging in") + } + + if done, err := a.srv.forward("ACL.Login", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "login"}, time.Now()) + + auth := args.Auth + + // 1. take args.Data.AuthMethod to get an AuthMethod Validator + idx, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, auth.AuthMethod) + if err != nil { + return err + } else if method == nil { + return acl.ErrNotFound + } + + validator, err := a.srv.loadAuthMethodValidator(idx, method) + if err != nil { + return err + } + + // 2. Send args.Data.BearerToken to method validator and get back a fields map + verifiedFields, err := validator.ValidateLogin(auth.BearerToken) + if err != nil { + return err + } + + // 3. send map through role bindings + serviceIdentities, roleLinks, err := a.srv.evaluateRoleBindings(validator, verifiedFields) + if err != nil { + return err + } + + if len(serviceIdentities) == 0 && len(roleLinks) == 0 { + return acl.ErrPermissionDenied + } + + description := "token created via login" + loginMeta, err := encodeLoginMeta(auth.Meta) + if err != nil { + return err + } + if loginMeta != "" { + description += ": " + loginMeta + } + + // 4. create token + createReq := structs.ACLTokenSetRequest{ + Datacenter: args.Datacenter, + ACLToken: structs.ACLToken{ + Description: description, + Local: true, + AuthMethod: auth.AuthMethod, + ServiceIdentities: serviceIdentities, + Roles: roleLinks, + }, + WriteRequest: args.WriteRequest, + } + + // 5. return token information like a TokenCreate would + return a.tokenSetInternal(&createReq, reply, true) +} + +func encodeLoginMeta(meta map[string]string) (string, error) { + if len(meta) == 0 { + return "", nil + } + + d, err := json.Marshal(meta) + if err != nil { + return "", err + } + return string(d), nil +} + +func (a *ACL) Logout(args *structs.ACLLogoutRequest, reply *bool) error { + if err := a.aclPreCheck(); err != nil { + return err + } + + if !a.srv.LocalTokensEnabled() { + return errAuthMethodsRequireTokenReplication + } + + if args.Token == "" { + return acl.ErrNotFound + } + + if done, err := a.srv.forward("ACL.Logout", args, args, reply); done { + return err + } + + defer metrics.MeasureSince([]string{"acl", "logout"}, time.Now()) + + _, token, err := a.srv.fsm.State().ACLTokenGetBySecret(nil, args.Token) + if err != nil { + return err + + } else if token == nil { + return acl.ErrNotFound + + } else if token.AuthMethod == "" { + // Can't "logout" of a token that wasn't a result of login. + return acl.ErrPermissionDenied + + } else if !a.srv.InACLDatacenter() && !token.Local { + // global token writes must be forwarded to the primary DC + args.Datacenter = a.srv.config.ACLDatacenter + return a.srv.forwardDC("ACL.Logout", a.srv.config.ACLDatacenter, args, reply) + } + + // No need to check expiration time because it's being deleted. + + req := &structs.ACLTokenBatchDeleteRequest{ + TokenIDs: []string{token.AccessorID}, + } + + resp, err := a.srv.raftApply(structs.ACLTokenDeleteRequestType, req) + if err != nil { + return fmt.Errorf("Failed to apply token delete request: %v", err) + } + + // Purge the identity from the cache to prevent using the previous definition of the identity + if token != nil { + a.srv.acls.cache.RemoveIdentity(token.SecretID) + } + + if respErr, ok := resp.(error); ok { + return respErr + } + + *reply = true + + return nil +} diff --git a/agent/consul/acl_endpoint_legacy.go b/agent/consul/acl_endpoint_legacy.go index 3b5ee22c6e..16379faa2e 100644 --- a/agent/consul/acl_endpoint_legacy.go +++ b/agent/consul/acl_endpoint_legacy.go @@ -255,7 +255,7 @@ func (a *ACL) List(args *structs.DCSpecificRequest, return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, tokens, err := state.ACLTokenList(ws, false, true, "", "") + index, tokens, err := state.ACLTokenList(ws, false, true, "", "", "") if err != nil { return err } diff --git a/agent/consul/acl_endpoint_test.go b/agent/consul/acl_endpoint_test.go index 99de10be32..bfdecafea9 100644 --- a/agent/consul/acl_endpoint_test.go +++ b/agent/consul/acl_endpoint_test.go @@ -12,6 +12,8 @@ import ( "time" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" "github.com/hashicorp/consul/agent/structs" tokenStore "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/lib" @@ -964,6 +966,117 @@ func TestACLEndpoint_TokenSet(t *testing.T) { require.Len(t, token.Roles, 0) }) + t.Run("Create it with AuthMethod set outside of login", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + AuthMethod: "fakemethod", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "AuthMethod field is disallowed outside of Login") + }) + + t.Run("Update auth method linked token and try to change auth method", func(t *testing.T) { + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + testauth.InstallSessionToken(testSessionID, "fake-token", "default", "demo", "abc123") + + method1, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + _, err = upsertTestBindingRule(codec, "root", "dc1", method1.Name, "", structs.BindingRuleBindTypeService, "demo") + require.NoError(t, err) + + // create a token in one method + methodToken := structs.ACLToken{} + require.NoError(t, acl.Login(&structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method1.Name, + BearerToken: "fake-token", + }, + Datacenter: "dc1", + }, &methodToken)) + + method2, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + // try to update the token and change the method + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + AccessorID: methodToken.AccessorID, + SecretID: methodToken.SecretID, + AuthMethod: method2.Name, + Description: "updated token", + Local: true, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err = acl.TokenSet(&req, &resp) + requireErrorContains(t, err, "Cannot change AuthMethod") + }) + + t.Run("Update auth method linked token and let the SecretID and AuthMethod be defaulted", func(t *testing.T) { + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + testauth.InstallSessionToken(testSessionID, "fake-token", "default", "demo", "abc123") + + method, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + _, err = upsertTestBindingRule(codec, "root", "dc1", method.Name, "", structs.BindingRuleBindTypeService, "demo") + require.NoError(t, err) + + methodToken := structs.ACLToken{} + require.NoError(t, acl.Login(&structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-token", + }, + Datacenter: "dc1", + }, &methodToken)) + + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + AccessorID: methodToken.AccessorID, + // SecretID: methodToken.SecretID, + // AuthMethod: method.Name, + Description: "updated token", + Local: true, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + require.NoError(t, acl.TokenSet(&req, &resp)) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.Len(t, token.Roles, 0) + require.Equal(t, "updated token", token.Description) + require.True(t, token.Local) + require.Equal(t, methodToken.SecretID, token.SecretID) + require.Equal(t, methodToken.AuthMethod, token.AuthMethod) + }) + t.Run("Create it with invalid service identity (empty)", func(t *testing.T) { req := structs.ACLTokenSetRequest{ Datacenter: "dc1", @@ -1062,6 +1175,69 @@ func TestACLEndpoint_TokenSet(t *testing.T) { }) } + t.Run("Create it with two of the same service identities", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: "example"}, + &structs.ACLServiceIdentity{ServiceName: "example"}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + require.Len(t, token.ServiceIdentities, 1) + }) + + t.Run("Create it with two of the same service identities and different DCs", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "foobar", + Policies: nil, + Local: false, + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc2", "dc3"}, + }, + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc1", "dc2"}, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + require.Len(t, token.ServiceIdentities, 1) + svcid := token.ServiceIdentities[0] + require.Equal(t, "example", svcid.ServiceName) + require.ElementsMatch(t, []string{"dc1", "dc2", "dc3"}, svcid.Datacenters) + }) + t.Run("Create it with invalid service identity (datacenters set on local token)", func(t *testing.T) { req := structs.ACLTokenSetRequest{ Datacenter: "dc1", @@ -1241,11 +1417,37 @@ func TestACLEndpoint_TokenSet(t *testing.T) { // do not insert another test at this point: these tests need to be serial + t.Run("Update anything except expiration time is ok - omit expiration time and let it default", func(t *testing.T) { + req := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "new-description-1", + AccessorID: tokenID, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLToken{} + + err := acl.TokenSet(&req, &resp) + require.NoError(t, err) + + // Get the token directly to validate that it exists + tokenResp, err := retrieveTestToken(codec, "root", "dc1", resp.AccessorID) + require.NoError(t, err) + token := tokenResp.Token + + require.NotNil(t, token.AccessorID) + require.Equal(t, token.Description, "new-description-1") + require.Equal(t, token.AccessorID, resp.AccessorID) + requireTimeEquals(t, &expTime, resp.ExpirationTime) + }) + t.Run("Update anything except expiration time is ok", func(t *testing.T) { req := structs.ACLTokenSetRequest{ Datacenter: "dc1", ACLToken: structs.ACLToken{ - Description: "new-description", + Description: "new-description-2", AccessorID: tokenID, ExpirationTime: &expTime, }, @@ -1263,7 +1465,7 @@ func TestACLEndpoint_TokenSet(t *testing.T) { token := tokenResp.Token require.NotNil(t, token.AccessorID) - require.Equal(t, token.Description, "new-description") + require.Equal(t, token.Description, "new-description-2") require.Equal(t, token.AccessorID, resp.AccessorID) requireTimeEquals(t, &expTime, resp.ExpirationTime) }) @@ -1615,12 +1817,7 @@ func TestACLEndpoint_TokenList(t *testing.T) { t2.AccessorID, t3.AccessorID, } - - var retrievedTokens []string - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.ElementsMatch(t, retrievedTokens, tokens) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) }) time.Sleep(20 * time.Millisecond) // now 't3' is expired @@ -1642,12 +1839,7 @@ func TestACLEndpoint_TokenList(t *testing.T) { t1.AccessorID, t2.AccessorID, } - - var retrievedTokens []string - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.ElementsMatch(t, retrievedTokens, tokens) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) }) } @@ -1694,13 +1886,7 @@ func TestACLEndpoint_TokenBatchRead(t *testing.T) { err = acl.TokenBatchRead(&req, &resp) require.NoError(t, err) - - var retrievedTokens []string - - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.EqualValues(t, retrievedTokens, tokens) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) }) time.Sleep(20 * time.Millisecond) // now 't3' is expired @@ -1718,13 +1904,7 @@ func TestACLEndpoint_TokenBatchRead(t *testing.T) { err = acl.TokenBatchRead(&req, &resp) require.NoError(t, err) - - var retrievedTokens []string - - for _, v := range resp.Tokens { - retrievedTokens = append(retrievedTokens, v.AccessorID) - } - require.EqualValues(t, retrievedTokens, tokens) + require.ElementsMatch(t, gatherIDs(t, resp.Tokens), tokens) }) } @@ -1801,13 +1981,7 @@ func TestACLEndpoint_PolicyBatchRead(t *testing.T) { err = acl.PolicyBatchRead(&req, &resp) require.NoError(t, err) - - var retrievedPolicies []string - - for _, v := range resp.Policies { - retrievedPolicies = append(retrievedPolicies, v.ID) - } - require.EqualValues(t, retrievedPolicies, policies) + require.ElementsMatch(t, gatherIDs(t, resp.Policies), []string{p1.ID, p2.ID}) } func TestACLEndpoint_PolicySet(t *testing.T) { @@ -2053,12 +2227,7 @@ func TestACLEndpoint_PolicyList(t *testing.T) { p1.ID, p2.ID, } - var retrievedPolicies []string - - for _, v := range resp.Policies { - retrievedPolicies = append(retrievedPolicies, v.ID) - } - require.ElementsMatch(t, retrievedPolicies, policies) + require.ElementsMatch(t, gatherIDs(t, resp.Policies), policies) } func TestACLEndpoint_PolicyResolve(t *testing.T) { @@ -2114,13 +2283,7 @@ func TestACLEndpoint_PolicyResolve(t *testing.T) { } err = acl.PolicyResolve(&req, &resp) require.NoError(t, err) - - var retrievedPolicies []string - - for _, v := range resp.Policies { - retrievedPolicies = append(retrievedPolicies, v.ID) - } - require.EqualValues(t, retrievedPolicies, policies) + require.ElementsMatch(t, gatherIDs(t, resp.Policies), policies) } func TestACLEndpoint_RoleRead(t *testing.T) { @@ -2189,13 +2352,7 @@ func TestACLEndpoint_RoleBatchRead(t *testing.T) { err = acl.RoleBatchRead(&req, &resp) require.NoError(t, err) - - var retrievedRoles []string - - for _, v := range resp.Roles { - retrievedRoles = append(retrievedRoles, v.ID) - } - require.EqualValues(t, retrievedRoles, roles) + require.ElementsMatch(t, gatherIDs(t, resp.Roles), roles) } func TestACLEndpoint_RoleSet(t *testing.T) { @@ -2432,6 +2589,67 @@ func TestACLEndpoint_RoleSet(t *testing.T) { } }) } + + t.Run("Create it with two of the same service identities", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ServiceName: "example"}, + &structs.ACLServiceIdentity{ServiceName: "example"}, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.Len(t, role.ServiceIdentities, 1) + }) + + t.Run("Create it with two of the same service identities and different DCs", func(t *testing.T) { + req := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Description: "foobar", + Name: roleNameGen(t), + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc2", "dc3"}, + }, + &structs.ACLServiceIdentity{ + ServiceName: "example", + Datacenters: []string{"dc1", "dc2"}, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLRole{} + + err := acl.RoleSet(&req, &resp) + require.NoError(t, err) + + // Get the role directly to validate that it exists + roleResp, err := retrieveTestRole(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + role := roleResp.Role + require.Len(t, role.ServiceIdentities, 1) + svcid := role.ServiceIdentities[0] + require.Equal(t, "example", svcid.ServiceName) + require.ElementsMatch(t, []string{"dc1", "dc2", "dc3"}, svcid.Datacenters) + }) } func TestACLEndpoint_RoleSet_names(t *testing.T) { @@ -2589,14 +2807,2009 @@ func TestACLEndpoint_RoleList(t *testing.T) { err = acl.RoleList(&req, &resp) require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Roles), []string{r1.ID, r2.ID}) +} - roles := []string{r1.ID, r2.ID} - var retrievedRoles []string +func TestACLEndpoint_RoleResolve(t *testing.T) { + t.Parallel() - for _, v := range resp.Roles { - retrievedRoles = append(retrievedRoles, v.ID) + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + t.Run("Normal", func(t *testing.T) { + r1, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + r2, err := upsertTestRole(codec, "root", "dc1") + require.NoError(t, err) + + acl := ACL{srv: s1} + + // Assign the roles to a token + tokenUpsertReq := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: r1.ID, + }, + structs.ACLTokenRoleLink{ + ID: r2.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + token := structs.ACLToken{} + err = acl.TokenSet(&tokenUpsertReq, &token) + require.NoError(t, err) + require.NotEmpty(t, token.SecretID) + + resp := structs.ACLRoleBatchResponse{} + req := structs.ACLRoleBatchGetRequest{ + Datacenter: "dc1", + RoleIDs: []string{r1.ID, r2.ID}, + QueryOptions: structs.QueryOptions{Token: token.SecretID}, + } + err = acl.RoleResolve(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.Roles), []string{r1.ID, r2.ID}) + }) +} + +func TestACLEndpoint_AuthMethodSet(t *testing.T) { + t.Parallel() + + tempDir, err := ioutil.TempDir("", "consul") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + newAuthMethod := func(name string) structs.ACLAuthMethod { + return structs.ACLAuthMethod{ + Name: name, + Description: "test", + Type: "testing", + } + } + + t.Run("Create", func(t *testing.T) { + reqMethod := newAuthMethod("test") + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, "test") + require.Equal(t, method.Description, "test") + require.Equal(t, method.Type, "testing") + }) + + t.Run("Update fails; not allowed to change types", func(t *testing.T) { + reqMethod := newAuthMethod("test") + reqMethod.Type = "invalid" + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.Error(t, err) + }) + + t.Run("Update - allow type to default", func(t *testing.T) { + reqMethod := newAuthMethod("test") + reqMethod.Description = "test modified 1" + reqMethod.Type = "" // unset + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, "test") + require.Equal(t, method.Description, "test modified 1") + require.Equal(t, method.Type, "testing") + }) + + t.Run("Update - specify type", func(t *testing.T) { + reqMethod := newAuthMethod("test") + reqMethod.Description = "test modified 2" + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: reqMethod, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, "test") + require.Equal(t, method.Description, "test modified 2") + require.Equal(t, method.Type, "testing") + }) + + t.Run("Create with no name", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: newAuthMethod(""), + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.Error(t, err) + }) + + t.Run("Create with invalid type", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: structs.ACLAuthMethod{ + Name: "invalid", + Description: "invalid test", + Type: "invalid", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + require.Error(t, err) + }) + + for _, test := range []struct { + name string + ok bool + }{ + {strings.Repeat("x", 129), false}, + {strings.Repeat("x", 128), true}, + {"-abc", true}, + {"abc-", true}, + {"a-bc", true}, + {"_abc", true}, + {"abc_", true}, + {"a_bc", true}, + {":abc", false}, + {"abc:", false}, + {"a:bc", false}, + {"Abc", true}, + {"aBc", true}, + {"abC", true}, + {"0abc", true}, + {"abc0", true}, + {"a0bc", true}, + } { + var testName string + if test.ok { + testName = "Create with valid name (by regex): " + test.name + } else { + testName = "Create with invalid name (by regex): " + test.name + } + t.Run(testName, func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: newAuthMethod(test.name), + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + err := acl.AuthMethodSet(&req, &resp) + + if test.ok { + require.NoError(t, err) + + // Get the method directly to validate that it exists + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", resp.Name) + require.NoError(t, err) + method := methodResp.AuthMethod + + require.Equal(t, method.Name, test.name) + require.Equal(t, method.Type, "testing") + } else { + require.Error(t, err) + } + }) + } +} + +func TestACLEndpoint_AuthMethodDelete(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + existingMethod, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + acl := ACL{srv: s1} + + t.Run("normal", func(t *testing.T) { + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc1", + AuthMethodName: existingMethod.Name, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.AuthMethodDelete(&req, &ignored) + require.NoError(t, err) + + // Make sure the method is gone + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", existingMethod.Name) + require.NoError(t, err) + require.Nil(t, methodResp.AuthMethod) + }) + + t.Run("delete something that doesn't exist", func(t *testing.T) { + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc1", + AuthMethodName: "missing", + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.AuthMethodDelete(&req, &ignored) + require.NoError(t, err) + }) +} + +// Deleting an auth method atomically deletes all rules and tokens as well. +func TestACLEndpoint_AuthMethodDelete_RuleAndTokenCascade(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testSessionID1 := testauth.StartSession() + defer testauth.ResetSession(testSessionID1) + testauth.InstallSessionToken(testSessionID1, "fake-token1", "default", "abc", "abc123") + + testSessionID2 := testauth.StartSession() + defer testauth.ResetSession(testSessionID2) + testauth.InstallSessionToken(testSessionID2, "fake-token2", "default", "abc", "abc123") + + createToken := func(methodName, bearerToken string) *structs.ACLToken { + acl := ACL{srv: s1} + + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: methodName, + BearerToken: bearerToken, + }, + Datacenter: "dc1", + }, &resp)) + + return &resp + } + + method1, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID1) + require.NoError(t, err) + i1_r1, err := upsertTestBindingRule( + codec, "root", "dc1", + method1.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + i1_r2, err := upsertTestBindingRule( + codec, "root", "dc1", + method1.Name, + "serviceaccount.name==def", + structs.BindingRuleBindTypeService, + "def", + ) + require.NoError(t, err) + i1_t1 := createToken(method1.Name, "fake-token1") + i1_t2 := createToken(method1.Name, "fake-token1") + + method2, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID2) + require.NoError(t, err) + i2_r1, err := upsertTestBindingRule( + codec, "root", "dc1", + method2.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + i2_r2, err := upsertTestBindingRule( + codec, "root", "dc1", + method2.Name, + "serviceaccount.name==def", + structs.BindingRuleBindTypeService, + "def", + ) + require.NoError(t, err) + i2_t1 := createToken(method2.Name, "fake-token2") + i2_t2 := createToken(method2.Name, "fake-token2") + + acl := ACL{srv: s1} + + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc1", + AuthMethodName: method1.Name, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.AuthMethodDelete(&req, &ignored) + require.NoError(t, err) + + // Make sure the method is gone. + methodResp, err := retrieveTestAuthMethod(codec, "root", "dc1", method1.Name) + require.NoError(t, err) + require.Nil(t, methodResp.AuthMethod) + + // Make sure the rules and tokens are gone. + for _, id := range []string{i1_r1.ID, i1_r2.ID} { + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", id) + require.NoError(t, err) + require.Nil(t, ruleResp.BindingRule) + } + for _, id := range []string{i1_t1.AccessorID, i1_t2.AccessorID} { + tokResp, err := retrieveTestToken(codec, "root", "dc1", id) + require.NoError(t, err) + require.Nil(t, tokResp.Token) + } + + // Make sure the rules and tokens for the untouched auth method are still there. + for _, id := range []string{i2_r1.ID, i2_r2.ID} { + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", id) + require.NoError(t, err) + require.NotNil(t, ruleResp.BindingRule) + } + for _, id := range []string{i2_t1.AccessorID, i2_t2.AccessorID} { + tokResp, err := retrieveTestToken(codec, "root", "dc1", id) + require.NoError(t, err) + require.NotNil(t, tokResp.Token) + } +} + +func TestACLEndpoint_AuthMethodList(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + i1, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + i2, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLAuthMethodListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLAuthMethodListResponse{} + + err = acl.AuthMethodList(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.AuthMethods), []string{i1.Name, i2.Name}) +} + +func TestACLEndpoint_BindingRuleSet(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + var ruleID string + + testAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + otherTestAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + newRule := func() structs.ACLBindingRule { + return structs.ACLBindingRule{ + Description: "foobar", + AuthMethod: testAuthMethod.Name, + Selector: "serviceaccount.name==abc", + BindType: structs.BindingRuleBindTypeService, + BindName: "abc", + } + } + + requireSetErrors := func(t *testing.T, reqRule structs.ACLBindingRule) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.Error(t, err) + } + + requireOK := func(t *testing.T, reqRule structs.ACLBindingRule) *structs.ACLBindingRule { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotEmpty(t, resp.ID) + return &resp + } + + t.Run("Create it", func(t *testing.T) { + reqRule := newRule() + + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the rule directly to validate that it exists + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + rule := ruleResp.BindingRule + + require.NotEmpty(t, rule.ID) + require.Equal(t, rule.Description, "foobar") + require.Equal(t, rule.AuthMethod, testAuthMethod.Name) + require.Equal(t, "serviceaccount.name==abc", rule.Selector) + require.Equal(t, structs.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "abc", rule.BindName) + + ruleID = rule.ID + }) + + t.Run("Update fails; cannot change method name", func(t *testing.T) { + reqRule := newRule() + reqRule.ID = ruleID + reqRule.AuthMethod = otherTestAuthMethod.Name + requireSetErrors(t, reqRule) + }) + + t.Run("Update it - omit method name", func(t *testing.T) { + reqRule := newRule() + reqRule.ID = ruleID + reqRule.Description = "foobar modified 1" + reqRule.Selector = "serviceaccount.namespace==def" + reqRule.BindType = structs.BindingRuleBindTypeRole + reqRule.BindName = "def" + reqRule.AuthMethod = "" // clear + + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the rule directly to validate that it exists + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + rule := ruleResp.BindingRule + + require.NotEmpty(t, rule.ID) + require.Equal(t, rule.Description, "foobar modified 1") + require.Equal(t, rule.AuthMethod, testAuthMethod.Name) + require.Equal(t, "serviceaccount.namespace==def", rule.Selector) + require.Equal(t, structs.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "def", rule.BindName) + }) + + t.Run("Update it - specify method name", func(t *testing.T) { + reqRule := newRule() + reqRule.ID = ruleID + reqRule.Description = "foobar modified 2" + reqRule.Selector = "serviceaccount.namespace==def" + reqRule.BindType = structs.BindingRuleBindTypeRole + reqRule.BindName = "def" + + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: reqRule, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLBindingRule{} + + err := acl.BindingRuleSet(&req, &resp) + require.NoError(t, err) + require.NotNil(t, resp.ID) + + // Get the rule directly to validate that it exists + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", resp.ID) + require.NoError(t, err) + rule := ruleResp.BindingRule + + require.NotEmpty(t, rule.ID) + require.Equal(t, rule.Description, "foobar modified 2") + require.Equal(t, rule.AuthMethod, testAuthMethod.Name) + require.Equal(t, "serviceaccount.namespace==def", rule.Selector) + require.Equal(t, structs.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "def", rule.BindName) + }) + + t.Run("Create fails; empty method name", func(t *testing.T) { + reqRule := newRule() + reqRule.AuthMethod = "" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; unknown method name", func(t *testing.T) { + reqRule := newRule() + reqRule.AuthMethod = "unknown" + requireSetErrors(t, reqRule) + }) + + t.Run("Create with no explicit selector", func(t *testing.T) { + reqRule := newRule() + reqRule.Selector = "" + + rule := requireOK(t, reqRule) + require.Empty(t, rule.Selector, 0) + }) + + t.Run("Create fails; match selector with unknown vars", func(t *testing.T) { + reqRule := newRule() + reqRule.Selector = "serviceaccount.name==a and serviceaccount.bizarroname==b" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; match selector invalid", func(t *testing.T) { + reqRule := newRule() + reqRule.Selector = "serviceaccount.name" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; empty bind type", func(t *testing.T) { + reqRule := newRule() + reqRule.BindType = "" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; empty bind name", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; invalid bind type", func(t *testing.T) { + reqRule := newRule() + reqRule.BindType = "invalid" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; bind name with unknown vars", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "method-${serviceaccount.bizarroname}" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; invalid bind name no template", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "-abc:" + requireSetErrors(t, reqRule) + }) + + t.Run("Create fails; invalid bind name with template", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "method-${serviceaccount.name" + requireSetErrors(t, reqRule) + }) + t.Run("Create fails; invalid bind name after template computed", func(t *testing.T) { + reqRule := newRule() + reqRule.BindName = "method-${serviceaccount.name}:blah-" + requireSetErrors(t, reqRule) + }) +} + +func TestACLEndpoint_BindingRuleDelete(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + existingRule, err := upsertTestBindingRule( + codec, "root", "dc1", + testAuthMethod.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + + acl := ACL{srv: s1} + + t.Run("normal", func(t *testing.T) { + req := structs.ACLBindingRuleDeleteRequest{ + Datacenter: "dc1", + BindingRuleID: existingRule.ID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.BindingRuleDelete(&req, &ignored) + require.NoError(t, err) + + // Make sure the rule is gone + ruleResp, err := retrieveTestBindingRule(codec, "root", "dc1", existingRule.ID) + require.NoError(t, err) + require.Nil(t, ruleResp.BindingRule) + }) + + t.Run("delete something that doesn't exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + req := structs.ACLBindingRuleDeleteRequest{ + Datacenter: "dc1", + BindingRuleID: fakeID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + err = acl.BindingRuleDelete(&req, &ignored) + require.NoError(t, err) + }) +} + +func TestACLEndpoint_BindingRuleList(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + testAuthMethod, err := upsertTestAuthMethod(codec, "root", "dc1", "") + require.NoError(t, err) + + r1, err := upsertTestBindingRule( + codec, "root", "dc1", + testAuthMethod.Name, + "serviceaccount.name==abc", + structs.BindingRuleBindTypeService, + "abc", + ) + require.NoError(t, err) + + r2, err := upsertTestBindingRule( + codec, "root", "dc1", + testAuthMethod.Name, + "serviceaccount.name==def", + structs.BindingRuleBindTypeService, + "def", + ) + require.NoError(t, err) + + acl := ACL{srv: s1} + + req := structs.ACLBindingRuleListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + + resp := structs.ACLBindingRuleListResponse{} + + err = acl.BindingRuleList(&req, &resp) + require.NoError(t, err) + require.ElementsMatch(t, gatherIDs(t, resp.BindingRules), []string{r1.ID, r2.ID}) +} + +func TestACLEndpoint_SecureIntroEndpoints_LocalTokensDisabled(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + dir2, s2 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.Datacenter = "dc2" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + // disable local tokens + c.ACLTokenReplication = false + }) + defer os.RemoveAll(dir2) + defer s2.Shutdown() + codec2 := rpcClient(t, s2) + defer codec2.Close() + + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForLeader(t, s2.RPC, "dc2") + + // Try to join + joinWAN(t, s2, s1) + + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + + acl2 := ACL{srv: s2} + var ignored bool + + errString := errAuthMethodsRequireTokenReplication.Error() + + t.Run("AuthMethodRead", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodRead(&structs.ACLAuthMethodGetRequest{Datacenter: "dc2"}, + &structs.ACLAuthMethodResponse{}), + errString, + ) + }) + t.Run("AuthMethodSet", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodSet(&structs.ACLAuthMethodSetRequest{Datacenter: "dc2"}, + &structs.ACLAuthMethod{}), + errString, + ) + }) + t.Run("AuthMethodDelete", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodDelete(&structs.ACLAuthMethodDeleteRequest{Datacenter: "dc2"}, &ignored), + errString, + ) + }) + t.Run("AuthMethodList", func(t *testing.T) { + requireErrorContains(t, + acl2.AuthMethodList(&structs.ACLAuthMethodListRequest{Datacenter: "dc2"}, + &structs.ACLAuthMethodListResponse{}), + errString, + ) + }) + + t.Run("BindingRuleRead", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleRead(&structs.ACLBindingRuleGetRequest{Datacenter: "dc2"}, + &structs.ACLBindingRuleResponse{}), + errString, + ) + }) + t.Run("BindingRuleSet", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleSet(&structs.ACLBindingRuleSetRequest{Datacenter: "dc2"}, + &structs.ACLBindingRule{}), + errString, + ) + }) + t.Run("BindingRuleDelete", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleDelete(&structs.ACLBindingRuleDeleteRequest{Datacenter: "dc2"}, &ignored), + errString, + ) + }) + t.Run("BindingRuleList", func(t *testing.T) { + requireErrorContains(t, + acl2.BindingRuleList(&structs.ACLBindingRuleListRequest{Datacenter: "dc2"}, + &structs.ACLBindingRuleListResponse{}), + errString, + ) + }) + + t.Run("Login", func(t *testing.T) { + requireErrorContains(t, + acl2.Login(&structs.ACLLoginRequest{Datacenter: "dc2"}, + &structs.ACLToken{}), + errString, + ) + }) + t.Run("Logout", func(t *testing.T) { + requireErrorContains(t, + acl2.Logout(&structs.ACLLogoutRequest{Datacenter: "dc2"}, &ignored), + errString, + ) + }) +} + +func TestACLEndpoint_SecureIntroEndpoints_OnlyCreateLocalData(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec1 := rpcClient(t, s1) + defer codec1.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + dir2, s2 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.Datacenter = "dc2" + c.ACLTokenMinExpirationTTL = 10 * time.Millisecond + c.ACLTokenMaxExpirationTTL = 5 * time.Second + // enable token replication so secure intro works + c.ACLTokenReplication = true + }) + defer os.RemoveAll(dir2) + defer s2.Shutdown() + codec2 := rpcClient(t, s2) + defer codec2.Close() + + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForLeader(t, s2.RPC, "dc2") + + // Try to join + joinWAN(t, s2, s1) + + waitForNewACLs(t, s1) + waitForNewACLs(t, s2) + + acl := ACL{srv: s1} + acl2 := ACL{srv: s2} + + // + // this order is specific so that we can do it in one pass + // + + testSessionID_1 := testauth.StartSession() + defer testauth.ResetSession(testSessionID_1) + + testSessionID_2 := testauth.StartSession() + defer testauth.ResetSession(testSessionID_2) + + testauth.InstallSessionToken( + testSessionID_1, + "fake-web1-token", + "default", "web1", "abc123", + ) + testauth.InstallSessionToken( + testSessionID_2, + "fake-web2-token", + "default", "web2", "def456", + ) + + t.Run("create auth method", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc2", + AuthMethod: structs.ACLAuthMethod{ + Name: "testmethod", + Description: "test original", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID_2, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + require.NoError(t, acl2.AuthMethodSet(&req, &resp)) + + // present in dc2 + resp2, err := retrieveTestAuthMethod(codec2, "root", "dc2", "testmethod") + require.NoError(t, err) + require.NotNil(t, resp2.AuthMethod) + require.Equal(t, "test original", resp2.AuthMethod.Description) + // absent in dc1 + resp2, err = retrieveTestAuthMethod(codec1, "root", "dc1", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + }) + + t.Run("update auth method", func(t *testing.T) { + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc2", + AuthMethod: structs.ACLAuthMethod{ + Name: "testmethod", + Description: "test updated", + Config: map[string]interface{}{ + "SessionID": testSessionID_2, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLAuthMethod{} + + require.NoError(t, acl2.AuthMethodSet(&req, &resp)) + + // present in dc2 + resp2, err := retrieveTestAuthMethod(codec2, "root", "dc2", "testmethod") + require.NoError(t, err) + require.NotNil(t, resp2.AuthMethod) + require.Equal(t, "test updated", resp2.AuthMethod.Description) + // absent in dc1 + resp2, err = retrieveTestAuthMethod(codec1, "root", "dc1", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + }) + + t.Run("read auth method", func(t *testing.T) { + // present in dc2 + req := structs.ACLAuthMethodGetRequest{ + Datacenter: "dc2", + AuthMethodName: "testmethod", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLAuthMethodResponse{} + require.NoError(t, acl2.AuthMethodRead(&req, &resp)) + require.NotNil(t, resp.AuthMethod) + require.Equal(t, "test updated", resp.AuthMethod.Description) + + // absent in dc1 + req = structs.ACLAuthMethodGetRequest{ + Datacenter: "dc1", + AuthMethodName: "testmethod", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLAuthMethodResponse{} + require.NoError(t, acl.AuthMethodRead(&req, &resp)) + require.Nil(t, resp.AuthMethod) + }) + + t.Run("list auth method", func(t *testing.T) { + // present in dc2 + req := structs.ACLAuthMethodListRequest{ + Datacenter: "dc2", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLAuthMethodListResponse{} + require.NoError(t, acl2.AuthMethodList(&req, &resp)) + require.Len(t, resp.AuthMethods, 1) + + // absent in dc1 + req = structs.ACLAuthMethodListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLAuthMethodListResponse{} + require.NoError(t, acl.AuthMethodList(&req, &resp)) + require.Len(t, resp.AuthMethods, 0) + }) + + var ruleID string + t.Run("create binding rule", func(t *testing.T) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc2", + BindingRule: structs.ACLBindingRule{ + Description: "test original", + AuthMethod: "testmethod", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLBindingRule{} + + require.NoError(t, acl2.BindingRuleSet(&req, &resp)) + ruleID = resp.ID + + // present in dc2 + resp2, err := retrieveTestBindingRule(codec2, "root", "dc2", ruleID) + require.NoError(t, err) + require.NotNil(t, resp2.BindingRule) + require.Equal(t, "test original", resp2.BindingRule.Description) + // absent in dc1 + resp2, err = retrieveTestBindingRule(codec1, "root", "dc1", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + }) + + t.Run("update binding rule", func(t *testing.T) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc2", + BindingRule: structs.ACLBindingRule{ + ID: ruleID, + Description: "test updated", + AuthMethod: "testmethod", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + resp := structs.ACLBindingRule{} + + require.NoError(t, acl2.BindingRuleSet(&req, &resp)) + ruleID = resp.ID + + // present in dc2 + resp2, err := retrieveTestBindingRule(codec2, "root", "dc2", ruleID) + require.NoError(t, err) + require.NotNil(t, resp2.BindingRule) + require.Equal(t, "test updated", resp2.BindingRule.Description) + // absent in dc1 + resp2, err = retrieveTestBindingRule(codec1, "root", "dc1", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + }) + + t.Run("read binding rule", func(t *testing.T) { + // present in dc2 + req := structs.ACLBindingRuleGetRequest{ + Datacenter: "dc2", + BindingRuleID: ruleID, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLBindingRuleResponse{} + require.NoError(t, acl2.BindingRuleRead(&req, &resp)) + require.NotNil(t, resp.BindingRule) + require.Equal(t, "test updated", resp.BindingRule.Description) + + // absent in dc1 + req = structs.ACLBindingRuleGetRequest{ + Datacenter: "dc1", + BindingRuleID: ruleID, + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLBindingRuleResponse{} + require.NoError(t, acl.BindingRuleRead(&req, &resp)) + require.Nil(t, resp.BindingRule) + }) + + t.Run("list binding rule", func(t *testing.T) { + // present in dc2 + req := structs.ACLBindingRuleListRequest{ + Datacenter: "dc2", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp := structs.ACLBindingRuleListResponse{} + require.NoError(t, acl2.BindingRuleList(&req, &resp)) + require.Len(t, resp.BindingRules, 1) + + // absent in dc1 + req = structs.ACLBindingRuleListRequest{ + Datacenter: "dc1", + QueryOptions: structs.QueryOptions{Token: "root"}, + } + resp = structs.ACLBindingRuleListResponse{} + require.NoError(t, acl.BindingRuleList(&req, &resp)) + require.Len(t, resp.BindingRules, 0) + }) + + var remoteToken *structs.ACLToken + t.Run("login in remote", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Datacenter: "dc2", + Auth: &structs.ACLLoginParams{ + AuthMethod: "testmethod", + BearerToken: "fake-web2-token", + }, + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + remoteToken = &resp + + // present in dc2 + resp2, err := retrieveTestToken(codec2, "root", "dc2", remoteToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, resp2.Token) + require.Len(t, resp2.Token.ServiceIdentities, 1) + require.Equal(t, "web2", resp2.Token.ServiceIdentities[0].ServiceName) + // absent in dc1 + resp2, err = retrieveTestToken(codec1, "root", "dc1", remoteToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + // We delay until now to setup an auth method and binding rule in the + // primary so our earlier listing tests were sane. We need to be able to + // use auth methods in both datacenters in order to verify Logout is + // properly scoped. + t.Run("initialize primary so we can test logout", func(t *testing.T) { + reqAM := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: structs.ACLAuthMethod{ + Name: "primarymethod", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID_1, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + respAM := structs.ACLAuthMethod{} + require.NoError(t, acl.AuthMethodSet(&reqAM, &respAM)) + + reqBR := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: structs.ACLBindingRule{ + AuthMethod: "primarymethod", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + respBR := structs.ACLBindingRule{} + require.NoError(t, acl.BindingRuleSet(&reqBR, &respBR)) + }) + + var primaryToken *structs.ACLToken + t.Run("login in primary", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Datacenter: "dc1", + Auth: &structs.ACLLoginParams{ + AuthMethod: "primarymethod", + BearerToken: "fake-web1-token", + }, + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + primaryToken = &resp + + // present in dc1 + resp2, err := retrieveTestToken(codec1, "root", "dc1", primaryToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, resp2.Token) + require.Len(t, resp2.Token.ServiceIdentities, 1) + require.Equal(t, "web1", resp2.Token.ServiceIdentities[0].ServiceName) + // absent in dc2 + resp2, err = retrieveTestToken(codec2, "root", "dc2", primaryToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + t.Run("logout of remote token in remote dc", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc2", + WriteRequest: structs.WriteRequest{Token: remoteToken.SecretID}, + } + + var ignored bool + require.NoError(t, acl.Logout(&req, &ignored)) + + // absent in dc2 + resp2, err := retrieveTestToken(codec2, "root", "dc2", remoteToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + // absent in dc1 + resp2, err = retrieveTestToken(codec1, "root", "dc1", remoteToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + t.Run("logout of primary token in remote dc should not work", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc2", + WriteRequest: structs.WriteRequest{Token: primaryToken.SecretID}, + } + + var ignored bool + requireErrorContains(t, acl.Logout(&req, &ignored), "ACL not found") + + // present in dc1 + resp2, err := retrieveTestToken(codec1, "root", "dc1", primaryToken.AccessorID) + require.NoError(t, err) + require.NotNil(t, resp2.Token) + require.Len(t, resp2.Token.ServiceIdentities, 1) + require.Equal(t, "web1", resp2.Token.ServiceIdentities[0].ServiceName) + // absent in dc2 + resp2, err = retrieveTestToken(codec2, "root", "dc2", primaryToken.AccessorID) + require.NoError(t, err) + require.Nil(t, resp2.Token) + }) + + // Don't trigger the auth method delete cascade so we know the individual + // endpoints follow the rules. + + t.Run("delete binding rule", func(t *testing.T) { + req := structs.ACLBindingRuleDeleteRequest{ + Datacenter: "dc2", + BindingRuleID: ruleID, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + require.NoError(t, acl2.BindingRuleDelete(&req, &ignored)) + + // absent in dc2 + resp2, err := retrieveTestBindingRule(codec2, "root", "dc2", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + // absent in dc1 + resp2, err = retrieveTestBindingRule(codec1, "root", "dc1", ruleID) + require.NoError(t, err) + require.Nil(t, resp2.BindingRule) + }) + + t.Run("delete auth method", func(t *testing.T) { + req := structs.ACLAuthMethodDeleteRequest{ + Datacenter: "dc2", + AuthMethodName: "testmethod", + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored bool + require.NoError(t, acl2.AuthMethodDelete(&req, &ignored)) + + // absent in dc2 + resp2, err := retrieveTestAuthMethod(codec2, "root", "dc2", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + // absent in dc1 + resp2, err = retrieveTestAuthMethod(codec1, "root", "dc1", "testmethod") + require.NoError(t, err) + require.Nil(t, resp2.AuthMethod) + }) +} + +func TestACLEndpoint_Login(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + testauth.InstallSessionToken( + testSessionID, + "fake-web", // no rules + "default", "web", "abc123", + ) + testauth.InstallSessionToken( + testSessionID, + "fake-db", // 1 rule + "default", "db", "def456", + ) + testauth.InstallSessionToken( + testSessionID, + "fake-monolith", // 1 rule, must exist + "default", "monolith", "ghi789", + ) + + method, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + ruleDB, err := upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "serviceaccount.namespace==default and serviceaccount.name==db", + structs.BindingRuleBindTypeService, + "method-${serviceaccount.name}", + ) + _, err = upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "serviceaccount.namespace==default and serviceaccount.name==monolith", + structs.BindingRuleBindTypeRole, + "method-${serviceaccount.name}", + ) + require.NoError(t, err) + + t.Run("do not provide a token", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-web", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + req.Token = "nope" + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "do not provide a token") + }) + + t.Run("unknown method", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name + "-notexist", + BearerToken: "fake-web", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "ACL not found") + }) + + t.Run("invalid method token", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "invalid", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.Error(t, acl.Login(&req, &resp)) + }) + + t.Run("valid method token no bindings", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-web", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "Permission denied") + }) + + t.Run("valid method token 1 role binding must exist and does not exist", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-monolith", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.Error(t, acl.Login(&req, &resp)) + }) + + // create the role so that the bindtype=existing login works + var monolithRoleID string + { + arg := structs.ACLRoleSetRequest{ + Datacenter: "dc1", + Role: structs.ACLRole{ + Name: "method-monolith", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var out structs.ACLRole + require.NoError(t, acl.RoleSet(&arg, &out)) + + monolithRoleID = out.ID + } + s1.purgeAuthMethodValidators() + + t.Run("valid bearer token 1 role binding must exist and now exists", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-monolith", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.ServiceIdentities, 0) + require.Len(t, resp.Roles, 1) + role := resp.Roles[0] + require.Equal(t, monolithRoleID, role.ID) + require.Equal(t, "method-monolith", role.Name) + }) + + t.Run("valid bearer token 1 service binding", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "method-db", svcid.ServiceName) + }) + + { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: "dc1", + BindingRule: structs.ACLBindingRule{ + AuthMethod: ruleDB.AuthMethod, + BindType: structs.BindingRuleBindTypeService, + BindName: ruleDB.BindName, + Selector: "", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var out structs.ACLBindingRule + require.NoError(t, acl.BindingRuleSet(&req, &out)) + } + + t.Run("valid bearer token 1 binding (no selectors this time)", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "method-db", svcid.ServiceName) + }) + + testSessionID_2 := testauth.StartSession() + defer testauth.ResetSession(testSessionID_2) + { + // Update the method to force the cache to invalidate for the next + // subtest. + updated := *method + updated.Description = "updated for the test" + updated.Config = map[string]interface{}{ + "SessionID": testSessionID_2, + } + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: "dc1", + AuthMethod: updated, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + var ignored structs.ACLAuthMethod + require.NoError(t, acl.AuthMethodSet(&req, &ignored)) + } + + t.Run("updating the method invalidates the cache", func(t *testing.T) { + // We'll try to login with the 'fake-db' cred which DOES exist in the + // old fake validator, but no longer exists in the new fake validator. + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "ACL not found") + }) +} + +func TestACLEndpoint_Login_k8s(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + // spin up a fake api server + testSrv := kubeauth.StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(goodJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + goodJWT_B, + ) + + method, err := upsertTestKubernetesAuthMethod( + codec, "root", "dc1", + testSrv.CACert(), + testSrv.Addr(), + goodJWT_A, + ) + require.NoError(t, err) + + t.Run("invalid bearer token", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "invalid", + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.Error(t, acl.Login(&req, &resp)) + }) + + t.Run("valid bearer token no bindings", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: goodJWT_B, + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + requireErrorContains(t, acl.Login(&req, &resp), "Permission denied") + }) + + _, err = upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "serviceaccount.namespace==default", + structs.BindingRuleBindTypeService, + "${serviceaccount.name}", + ) + require.NoError(t, err) + + t.Run("valid bearer token 1 service binding", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: goodJWT_B, + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "demo", svcid.ServiceName) + }) + + // annotate the account + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "alternate-name", + goodJWT_B, + ) + + t.Run("valid bearer token 1 service binding - with annotation", func(t *testing.T) { + req := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: goodJWT_B, + Meta: map[string]string{"pod": "pod1"}, + }, + Datacenter: "dc1", + } + resp := structs.ACLToken{} + + require.NoError(t, acl.Login(&req, &resp)) + + require.Equal(t, method.Name, resp.AuthMethod) + require.Equal(t, `token created via login: {"pod":"pod1"}`, resp.Description) + require.True(t, resp.Local) + require.Len(t, resp.Roles, 0) + require.Len(t, resp.ServiceIdentities, 1) + svcid := resp.ServiceIdentities[0] + require.Len(t, svcid.Datacenters, 0) + require.Equal(t, "alternate-name", svcid.ServiceName) + }) +} + +func TestACLEndpoint_Logout(t *testing.T) { + t.Parallel() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + acl := ACL{srv: s1} + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + testauth.InstallSessionToken( + testSessionID, + "fake-db", // 1 rule + "default", "db", "def456", + ) + + method, err := upsertTestAuthMethod(codec, "root", "dc1", testSessionID) + require.NoError(t, err) + + _, err = upsertTestBindingRule( + codec, "root", "dc1", method.Name, + "", + structs.BindingRuleBindTypeService, + "method-${serviceaccount.name}", + ) + require.NoError(t, err) + + t.Run("you must provide a token", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + // WriteRequest: structs.WriteRequest{Token: "root"}, + } + req.Token = "" + var ignored bool + + requireErrorContains(t, acl.Logout(&req, &ignored), "ACL not found") + }) + + t.Run("logout from deleted token", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + WriteRequest: structs.WriteRequest{Token: "not-found"}, + } + var ignored bool + requireErrorContains(t, acl.Logout(&req, &ignored), "ACL not found") + }) + + t.Run("logout from non-auth method-linked token should fail", func(t *testing.T) { + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var ignored bool + requireErrorContains(t, acl.Logout(&req, &ignored), "Permission denied") + }) + + t.Run("login then logout", func(t *testing.T) { + // Create a totally legit Login token. + loginReq := structs.ACLLoginRequest{ + Auth: &structs.ACLLoginParams{ + AuthMethod: method.Name, + BearerToken: "fake-db", + }, + Datacenter: "dc1", + } + loginToken := structs.ACLToken{} + + require.NoError(t, acl.Login(&loginReq, &loginToken)) + require.NotEmpty(t, loginToken.SecretID) + + // Now turn around and nuke it. + req := structs.ACLLogoutRequest{ + Datacenter: "dc1", + WriteRequest: structs.WriteRequest{Token: loginToken.SecretID}, + } + + var ignored bool + require.NoError(t, acl.Logout(&req, &ignored)) + }) +} + +func gatherIDs(t *testing.T, v interface{}) []string { + t.Helper() + + var out []string + switch x := v.(type) { + case []*structs.ACLRole: + for _, r := range x { + out = append(out, r.ID) + } + case structs.ACLRoles: + for _, r := range x { + out = append(out, r.ID) + } + case []*structs.ACLPolicy: + for _, p := range x { + out = append(out, p.ID) + } + case structs.ACLPolicyListStubs: + for _, p := range x { + out = append(out, p.ID) + } + case []*structs.ACLToken: + for _, p := range x { + out = append(out, p.AccessorID) + } + case structs.ACLTokenListStubs: + for _, p := range x { + out = append(out, p.AccessorID) + } + case []*structs.ACLAuthMethod: + for _, p := range x { + out = append(out, p.Name) + } + case structs.ACLAuthMethodListStubs: + for _, p := range x { + out = append(out, p.Name) + } + case []*structs.ACLBindingRule: + for _, p := range x { + out = append(out, p.ID) + } + case structs.ACLBindingRules: + for _, p := range x { + out = append(out, p.ID) + } + default: + t.Fatalf("unknown type: %T", x) + } + return out +} + +func TestValidateBindingRuleBindName(t *testing.T) { + t.Parallel() + + type testcase struct { + name string + bindType string + bindName string + fields string + valid bool // valid HIL, invalid contents + err bool // invalid HIL + } + + for _, test := range []testcase{ + {"no bind type", + "", "", "", false, false}, + {"bad bind type", + "invalid", "blah", "", false, true}, + // valid HIL, invalid name + {"empty", + "both", "", "", false, false}, + {"just end", + "both", "}", "", false, false}, + {"var without start", + "both", " item }", "item", false, false}, + {"two vars missing second start", + "both", "before-${ item }after--more }", "item,more", false, false}, + // names for the two types are validated differently + {"@ is disallowed", + "both", "bad@name", "", false, false}, + {"leading dash", + "role", "-name", "", true, false}, + {"leading dash", + "service", "-name", "", false, false}, + {"trailing dash", + "role", "name-", "", true, false}, + {"trailing dash", + "service", "name-", "", false, false}, + {"inner dash", + "both", "name-end", "", true, false}, + {"upper case", + "role", "NAME", "", true, false}, + {"upper case", + "service", "NAME", "", false, false}, + // valid HIL, valid name + {"no vars", + "both", "nothing", "", true, false}, + {"just var", + "both", "${item}", "item", true, false}, + {"var in middle", + "both", "before-${item}after", "item", true, false}, + {"two vars", + "both", "before-${item}after-${more}", "item,more", true, false}, + // bad + {"no bind name", + "both", "", "", false, false}, + {"just start", + "both", "${", "", false, true}, + {"backwards", + "both", "}${", "", false, true}, + {"no varname", + "both", "${}", "", false, true}, + {"missing map key", + "both", "${item}", "", false, true}, + {"var without end", + "both", "${ item ", "item", false, true}, + {"two vars missing first end", + "both", "before-${ item after-${ more }", "item,more", false, true}, + } { + var cases []testcase + if test.bindType == "both" { + test1 := test + test1.bindType = "role" + test2 := test + test2.bindType = "service" + cases = []testcase{test1, test2} + } else { + cases = []testcase{test} + } + + for _, test := range cases { + test := test + t.Run(test.bindType+"--"+test.name, func(t *testing.T) { + t.Parallel() + valid, err := validateBindingRuleBindName( + test.bindType, + test.bindName, + strings.Split(test.fields, ","), + ) + if test.err { + require.NotNil(t, err) + require.False(t, valid) + } else { + require.NoError(t, err) + require.Equal(t, test.valid, valid) + } + }) + } } - require.ElementsMatch(t, retrievedRoles, roles) } // upsertTestToken creates a token for testing purposes @@ -2836,6 +5049,166 @@ func retrieveTestRoleByName(codec rpc.ClientCodec, masterToken string, datacente return &out, nil } +func deleteTestAuthMethod(codec rpc.ClientCodec, masterToken string, datacenter string, methodName string) error { + arg := structs.ACLAuthMethodDeleteRequest{ + Datacenter: datacenter, + AuthMethodName: methodName, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodDelete", &arg, &ignored) + return err +} +func upsertTestAuthMethod( + codec rpc.ClientCodec, masterToken string, datacenter string, + sessionID string, +) (*structs.ACLAuthMethod, error) { + name, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: datacenter, + AuthMethod: structs.ACLAuthMethod{ + Name: "test-method-" + name, + Type: "testing", + Config: map[string]interface{}{ + "SessionID": sessionID, + }, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLAuthMethod + + err = msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodSet", &req, &out) + if err != nil { + return nil, err + } + + return &out, nil +} + +func upsertTestKubernetesAuthMethod( + codec rpc.ClientCodec, masterToken string, datacenter string, + caCert, kubeHost, kubeJWT string, +) (*structs.ACLAuthMethod, error) { + name, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + if kubeHost == "" { + kubeHost = "https://abc:8443" + } + if kubeJWT == "" { + kubeJWT = goodJWT_A + } + + req := structs.ACLAuthMethodSetRequest{ + Datacenter: datacenter, + AuthMethod: structs.ACLAuthMethod{ + Name: "test-method-" + name, + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": kubeHost, + "CACert": caCert, + "ServiceAccountJWT": kubeJWT, + }, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLAuthMethod + + err = msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodSet", &req, &out) + if err != nil { + return nil, err + } + + return &out, nil +} + +func retrieveTestAuthMethod(codec rpc.ClientCodec, masterToken string, datacenter string, name string) (*structs.ACLAuthMethodResponse, error) { + arg := structs.ACLAuthMethodGetRequest{ + Datacenter: datacenter, + AuthMethodName: name, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLAuthMethodResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.AuthMethodRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + +func deleteTestBindingRule(codec rpc.ClientCodec, masterToken string, datacenter string, ruleID string) error { + arg := structs.ACLBindingRuleDeleteRequest{ + Datacenter: datacenter, + BindingRuleID: ruleID, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var ignored string + err := msgpackrpc.CallWithCodec(codec, "ACL.BindingRuleDelete", &arg, &ignored) + return err +} + +func upsertTestBindingRule( + codec rpc.ClientCodec, + masterToken string, + datacenter string, + methodName string, + selector string, + bindType string, + bindName string, +) (*structs.ACLBindingRule, error) { + req := structs.ACLBindingRuleSetRequest{ + Datacenter: datacenter, + BindingRule: structs.ACLBindingRule{ + AuthMethod: methodName, + BindType: bindType, + BindName: bindName, + Selector: selector, + }, + WriteRequest: structs.WriteRequest{Token: masterToken}, + } + + var out structs.ACLBindingRule + + err := msgpackrpc.CallWithCodec(codec, "ACL.BindingRuleSet", &req, &out) + if err != nil { + return nil, err + } + + return &out, nil +} + +func retrieveTestBindingRule(codec rpc.ClientCodec, masterToken string, datacenter string, ruleID string) (*structs.ACLBindingRuleResponse, error) { + arg := structs.ACLBindingRuleGetRequest{ + Datacenter: datacenter, + BindingRuleID: ruleID, + QueryOptions: structs.QueryOptions{Token: masterToken}, + } + + var out structs.ACLBindingRuleResponse + + err := msgpackrpc.CallWithCodec(codec, "ACL.BindingRuleRead", &arg, &out) + + if err != nil { + return nil, err + } + + return &out, nil +} + func requireTimeEquals(t *testing.T, expect, got *time.Time) { t.Helper() if expect == nil && got == nil { @@ -2858,3 +5231,9 @@ func requireErrorContains(t *testing.T, err error, expectedErrorMessage string) t.Fatalf("unexpected error: %v", err) } } + +// 'default/admin' +const goodJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA" + +// 'default/demo' +const goodJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ" diff --git a/agent/consul/acl_replication_legacy.go b/agent/consul/acl_replication_legacy.go index 010c220a96..b933f714e9 100644 --- a/agent/consul/acl_replication_legacy.go +++ b/agent/consul/acl_replication_legacy.go @@ -138,7 +138,7 @@ func reconcileLegacyACLs(local, remote structs.ACLs, lastRemoteIndex uint64) str // FetchLocalACLs returns the ACLs in the local state store. func (s *Server) fetchLocalLegacyACLs() (structs.ACLs, error) { - _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "", "") + _, local, err := s.fsm.State().ACLTokenList(nil, false, true, "", "", "") if err != nil { return nil, err } diff --git a/agent/consul/acl_replication_legacy_test.go b/agent/consul/acl_replication_legacy_test.go index f5a2601d54..171f71c359 100644 --- a/agent/consul/acl_replication_legacy_test.go +++ b/agent/consul/acl_replication_legacy_test.go @@ -396,11 +396,11 @@ func TestACLReplication_LegacyTokens(t *testing.T) { } checkSame := func() error { - index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "", "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "", "", "") if err != nil { return err } - _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "", "") if err != nil { return err } diff --git a/agent/consul/acl_replication_test.go b/agent/consul/acl_replication_test.go index 730527eedc..e8a6a7d693 100644 --- a/agent/consul/acl_replication_test.go +++ b/agent/consul/acl_replication_test.go @@ -351,9 +351,9 @@ func TestACLReplication_Tokens(t *testing.T) { checkSame := func(t *retry.R) { // only account for global tokens - local tokens shouldn't be replicated - index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "") require.NoError(t, err) - _, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "", "") + _, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "", "", "") require.NoError(t, err) require.Len(t, local, len(remote)) @@ -444,7 +444,7 @@ func TestACLReplication_Tokens(t *testing.T) { }) // verify dc2 local tokens didn't get blown away - _, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "", "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "", "", "") require.NoError(t, err) require.Len(t, local, 50) @@ -779,10 +779,10 @@ func TestACLReplication_AllTypes(t *testing.T) { checkSameTokens := func(t *retry.R) { // only account for global tokens - local tokens shouldn't be replicated - index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "") + index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "") require.NoError(t, err) // Query for all of them, so that we can prove that no globals snuck in. - _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "") + _, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Len(t, remote, len(local)) diff --git a/agent/consul/acl_replication_types.go b/agent/consul/acl_replication_types.go index 7044442fdf..8efc229632 100644 --- a/agent/consul/acl_replication_types.go +++ b/agent/consul/acl_replication_types.go @@ -34,7 +34,7 @@ func (r *aclTokenReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (i func (r *aclTokenReplicator) FetchLocal(srv *Server) (int, uint64, error) { r.local = nil - idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "") + idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "", "") if err != nil { return 0, 0, err } diff --git a/agent/consul/acl_server.go b/agent/consul/acl_server.go index d895d922a2..34ca09584b 100644 --- a/agent/consul/acl_server.go +++ b/agent/consul/acl_server.go @@ -73,6 +73,17 @@ func (s *Server) checkRoleUUID(id string) (bool, error) { return !structs.ACLIDReserved(id), nil } +func (s *Server) checkBindingRuleUUID(id string) (bool, error) { + state := s.fsm.State() + if _, rule, err := state.ACLBindingRuleGetByID(nil, id); err != nil { + return false, err + } else if rule != nil { + return false, nil + } + + return !structs.ACLIDReserved(id), nil +} + func (s *Server) updateACLAdvertisement() { // One thing to note is that once in new ACL mode the server will // never transition to legacy ACL mode. This is not currently a diff --git a/agent/consul/acl_test.go b/agent/consul/acl_test.go index 65f05d9875..e2c84afe21 100644 --- a/agent/consul/acl_test.go +++ b/agent/consul/acl_test.go @@ -2,7 +2,6 @@ package consul import ( "fmt" - "log" "os" "reflect" "strings" @@ -12,6 +11,7 @@ import ( "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -139,6 +139,26 @@ func testIdentityForToken(token string) (bool, structs.ACLIdentity, error) { }, }, }, nil + case "found-synthetic-policy-1": + return true, &structs.ACLToken{ + AccessorID: "f6c5a5fb-4da4-422b-9abf-2c942813fc71", + SecretID: "55cb7d69-2bea-42c3-a68f-2a1443d2abbc", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "service1", + }, + }, + }, nil + case "found-synthetic-policy-2": + return true, &structs.ACLToken{ + AccessorID: "7c87dfad-be37-446e-8305-299585677cb5", + SecretID: "dfca9676-ac80-453a-837b-4c0cf923473c", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "service2", + }, + }, + }, nil case "acl-ro": return true, &structs.ACLToken{ AccessorID: "435a75af-1763-4980-89f4-f0951dda53b4", @@ -430,57 +450,87 @@ type ACLResolverTestDelegate struct { roleCached bool } +func (d *ACLResolverTestDelegate) Reset() { + d.tokenCached = false + d.policyCached = false + d.roleCached = false +} + var errRPC = fmt.Errorf("Induced RPC Error") func (d *ACLResolverTestDelegate) defaultTokenReadFn(errAfterCached error) func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error { return func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { if !d.tokenCached { - _, token, _ := testIdentityForToken(args.TokenID) - reply.Token = token.(*structs.ACLToken) - + err := d.plainTokenReadFn(args, reply) d.tokenCached = true - return nil + return err } return errAfterCached } } +func (d *ACLResolverTestDelegate) plainTokenReadFn(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { + _, token, err := testIdentityForToken(args.TokenID) + if token != nil { + reply.Token = token.(*structs.ACLToken) + } + return err +} + func (d *ACLResolverTestDelegate) defaultPolicyResolveFn(errAfterCached error) func(*structs.ACLPolicyBatchGetRequest, *structs.ACLPolicyBatchResponse) error { return func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { if !d.policyCached { - for _, policyID := range args.PolicyIDs { - _, policy, _ := testPolicyForID(policyID) - if policy != nil { - reply.Policies = append(reply.Policies, policy) - } - } - + err := d.plainPolicyResolveFn(args, reply) d.policyCached = true - return nil + return err } return errAfterCached } } +func (d *ACLResolverTestDelegate) plainPolicyResolveFn(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error { + // TODO: if we were being super correct about it, we'd verify the token first + // TODO: and possibly return a not-found or permission-denied here + + for _, policyID := range args.PolicyIDs { + _, policy, _ := testPolicyForID(policyID) + if policy != nil { + reply.Policies = append(reply.Policies, policy) + } + } + + return nil +} + func (d *ACLResolverTestDelegate) defaultRoleResolveFn(errAfterCached error) func(*structs.ACLRoleBatchGetRequest, *structs.ACLRoleBatchResponse) error { return func(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { if !d.roleCached { - for _, roleID := range args.RoleIDs { - _, role, _ := testRoleForID(roleID) - if role != nil { - reply.Roles = append(reply.Roles, role) - } - } - + err := d.plainRoleResolveFn(args, reply) d.roleCached = true - return nil + return err } return errAfterCached } } +// plainRoleResolveFn tries to follow the normal logic of ACL.RoleResolve using +// the test fixtures. +func (d *ACLResolverTestDelegate) plainRoleResolveFn(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error { + // TODO: if we were being super correct about it, we'd verify the token first + // TODO: and possibly return a not-found or permission-denied here + + for _, roleID := range args.RoleIDs { + _, role, _ := testRoleForID(roleID) + if role != nil { + reply.Roles = append(reply.Roles, role) + } + } + + return nil +} + func (d *ACLResolverTestDelegate) ACLsEnabled() bool { return d.enabled } @@ -549,7 +599,7 @@ func newTestACLResolver(t *testing.T, delegate ACLResolverDelegate, cb func(*ACL config.ACLDownPolicy = "extend-cache" rconf := &ACLResolverConfig{ Config: config, - Logger: log.New(os.Stdout, t.Name()+" - ", log.LstdFlags|log.Lmicroseconds), + Logger: testutil.TestLoggerWithName(t, t.Name()), CacheConfig: &structs.ACLCachesConfig{ Identities: 4, Policies: 4, @@ -1058,7 +1108,7 @@ func TestACLResolver_DownPolicy(t *testing.T) { require.NoError(t, err) require.NotNil(t, authz2) // testing pointer equality - these will be the same object because it is cached. - require.True(t, authz == authz2) + require.True(t, authz == authz2, "\n[1]={%+v} != \n[2]={%+v}", authz, authz2) require.True(t, authz2.NodeWrite("foo", nil)) }) @@ -1445,6 +1495,23 @@ func TestACLResolver_Client(t *testing.T) { }) } +func TestACLResolver_Client_TokensPoliciesAndRoles(t *testing.T) { + t.Parallel() + delegate := &ACLResolverTestDelegate{ + enabled: true, + datacenter: "dc1", + legacy: false, + localTokens: false, + localPolicies: false, + localRoles: false, + } + delegate.tokenReadFn = delegate.plainTokenReadFn + delegate.policyResolveFn = delegate.plainPolicyResolveFn + delegate.roleResolveFn = delegate.plainRoleResolveFn + + testACLResolver_variousTokens(t, delegate) +} + func TestACLResolver_LocalTokensPoliciesAndRoles(t *testing.T) { t.Parallel() delegate := &ACLResolverTestDelegate{ @@ -1470,31 +1537,43 @@ func TestACLResolver_LocalPoliciesAndRoles(t *testing.T) { localTokens: false, localPolicies: true, localRoles: true, - tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error { - _, token, err := testIdentityForToken(args.TokenID) - - if token != nil { - reply.Token = token.(*structs.ACLToken) - } - return err - }, } + delegate.tokenReadFn = delegate.plainTokenReadFn testACLResolver_variousTokens(t, delegate) } func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelegate) { t.Helper() - r := newTestACLResolver(t, delegate, nil) + r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) { + config.Config.ACLTokenTTL = 600 * time.Second + config.Config.ACLPolicyTTL = 30 * time.Millisecond + config.Config.ACLRoleTTL = 30 * time.Millisecond + config.Config.ACLDownPolicy = "extend-cache" + }) + reset := func() { + // prevent subtest bleedover + r.cache.Purge() + delegate.Reset() + } - t.Run("Missing Identity", func(t *testing.T) { + runTwiceAndReset := func(name string, f func(t *testing.T)) { + t.Helper() + defer reset() // reset the stateful resolve AND blow away the cache + + t.Run(name+" (no-cache)", f) + delegate.Reset() // allow the stateful resolve functions to reset + t.Run(name+" (cached)", f) + } + + runTwiceAndReset("Missing Identity", func(t *testing.T) { authz, err := r.ResolveToken("doesn't exist") require.Nil(t, authz) require.Error(t, err) require.True(t, acl.IsErrNotFound(err)) }) - t.Run("Missing Policy", func(t *testing.T) { + runTwiceAndReset("Missing Policy", func(t *testing.T) { authz, err := r.ResolveToken("missing-policy") require.NoError(t, err) require.NotNil(t, authz) @@ -1502,7 +1581,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.False(t, authz.NodeWrite("foo", nil)) }) - t.Run("Missing Role", func(t *testing.T) { + runTwiceAndReset("Missing Role", func(t *testing.T) { authz, err := r.ResolveToken("missing-role") require.NoError(t, err) require.NotNil(t, authz) @@ -1510,7 +1589,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.False(t, authz.NodeWrite("foo", nil)) }) - t.Run("Missing Policy on Role", func(t *testing.T) { + runTwiceAndReset("Missing Policy on Role", func(t *testing.T) { authz, err := r.ResolveToken("missing-policy-on-role") require.NoError(t, err) require.NotNil(t, authz) @@ -1518,7 +1597,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.False(t, authz.NodeWrite("foo", nil)) }) - t.Run("Normal with Policy", func(t *testing.T) { + runTwiceAndReset("Normal with Policy", func(t *testing.T) { authz, err := r.ResolveToken("found") require.NotNil(t, authz) require.NoError(t, err) @@ -1526,7 +1605,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.NodeWrite("foo", nil)) }) - t.Run("Normal with Role", func(t *testing.T) { + runTwiceAndReset("Normal with Role", func(t *testing.T) { authz, err := r.ResolveToken("found-role") require.NotNil(t, authz) require.NoError(t, err) @@ -1534,7 +1613,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.NodeWrite("foo", nil)) }) - t.Run("Normal with Policy and Role", func(t *testing.T) { + runTwiceAndReset("Normal with Policy and Role", func(t *testing.T) { authz, err := r.ResolveToken("found-policy-and-role") require.NotNil(t, authz) require.NoError(t, err) @@ -1543,7 +1622,41 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.ServiceRead("bar")) }) - t.Run("Anonymous", func(t *testing.T) { + runTwiceAndReset("Synthetic Policies Independently Cache", func(t *testing.T) { + // We resolve both of these tokens in the same cache session + // to verify that the keys for caching synthetic policies don't bleed + // over between each other. + { + authz, err := r.ResolveToken("found-synthetic-policy-1") + require.NotNil(t, authz) + require.NoError(t, err) + // spot check some random perms + require.False(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + // ensure we didn't bleed over to the other synthetic policy + require.False(t, authz.ServiceWrite("service2", nil)) + // check our own synthetic policy + require.True(t, authz.ServiceWrite("service1", nil)) + require.True(t, authz.ServiceRead("literally-anything")) + require.True(t, authz.NodeRead("any-node")) + } + { + authz, err := r.ResolveToken("found-synthetic-policy-2") + require.NotNil(t, authz) + require.NoError(t, err) + // spot check some random perms + require.False(t, authz.ACLRead()) + require.False(t, authz.NodeWrite("foo", nil)) + // ensure we didn't bleed over to the other synthetic policy + require.False(t, authz.ServiceWrite("service1", nil)) + // check our own synthetic policy + require.True(t, authz.ServiceWrite("service2", nil)) + require.True(t, authz.ServiceRead("literally-anything")) + require.True(t, authz.NodeRead("any-node")) + } + }) + + runTwiceAndReset("Anonymous", func(t *testing.T) { authz, err := r.ResolveToken("") require.NotNil(t, authz) require.NoError(t, err) @@ -1551,7 +1664,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.NodeWrite("foo", nil)) }) - t.Run("legacy-management", func(t *testing.T) { + runTwiceAndReset("legacy-management", func(t *testing.T) { authz, err := r.ResolveToken("legacy-management") require.NotNil(t, authz) require.NoError(t, err) @@ -1559,7 +1672,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega require.True(t, authz.KeyRead("foo")) }) - t.Run("legacy-client", func(t *testing.T) { + runTwiceAndReset("legacy-client", func(t *testing.T) { authz, err := r.ResolveToken("legacy-client") require.NoError(t, err) require.NotNil(t, authz) diff --git a/agent/consul/acl_token_exp_test.go b/agent/consul/acl_token_exp_test.go index 20ae878afc..a851b4dc33 100644 --- a/agent/consul/acl_token_exp_test.go +++ b/agent/consul/acl_token_exp_test.go @@ -51,7 +51,7 @@ func testACLTokenReap_Primary(t *testing.T, local, global bool) { codec := rpcClient(t, s1) defer codec.Close() - acl := ACL{s1} + acl := ACL{srv: s1} masterTokenAccessorID, err := retrieveTestTokenAccessorForSecret(codec, "root", "dc1", "root") require.NoError(t, err) diff --git a/agent/consul/authmethod/authmethods.go b/agent/consul/authmethod/authmethods.go new file mode 100644 index 0000000000..8fd477d0f2 --- /dev/null +++ b/agent/consul/authmethod/authmethods.go @@ -0,0 +1,112 @@ +package authmethod + +import ( + "fmt" + "sort" + "sync" + + "github.com/hashicorp/consul/agent/structs" + "github.com/mitchellh/mapstructure" +) + +type ValidatorFactory func(method *structs.ACLAuthMethod) (Validator, error) + +type Validator interface { + // Name returns the name of the auth method backing this validator. + Name() string + + // ValidateLogin takes raw user-provided auth method metadata and ensures + // it is sane, provably correct, and currently valid. Relevant identifying + // data is extracted and returned for immediate use by the role binding + // process. + // + // Depending upon the method, it may make sense to use these calls to + // continue to extend the life of the underlying token. + // + // Returns auth method specific metadata suitable for the Role Binding + // process. + ValidateLogin(loginToken string) (map[string]string, error) + + // AvailableFields returns a slice of all fields that are returned as a + // result of ValidateLogin. These are valid fields for use in any + // BindingRule tied to this auth method. + AvailableFields() []string + + // MakeFieldMapSelectable converts a field map as returned by ValidateLogin + // into a structure suitable for selection with a binding rule. + MakeFieldMapSelectable(fieldMap map[string]string) interface{} +} + +var ( + typesMu sync.RWMutex + types = make(map[string]ValidatorFactory) +) + +// Register makes an auth method with the given type available for use. If +// Register is called twice with the same name or if validator is nil, it +// panics. +func Register(name string, factory ValidatorFactory) { + typesMu.Lock() + defer typesMu.Unlock() + if factory == nil { + panic("authmethod: Register factory is nil for type " + name) + } + if _, dup := types[name]; dup { + panic("authmethod: Register called twice for type " + name) + } + types[name] = factory +} + +func IsRegisteredType(typeName string) bool { + typesMu.RLock() + _, ok := types[typeName] + typesMu.RUnlock() + return ok +} + +// NewValidator instantiates a new Validator for the given auth method +// configuration. If no auth method is registered with the provided type an +// error is returned. +func NewValidator(method *structs.ACLAuthMethod) (Validator, error) { + typesMu.RLock() + factory, ok := types[method.Type] + typesMu.RUnlock() + + if !ok { + return nil, fmt.Errorf("no auth method registered with type: %s", method.Type) + } + + return factory(method) +} + +// Types returns a sorted list of the names of the registered types. +func Types() []string { + typesMu.RLock() + defer typesMu.RUnlock() + var list []string + for name := range types { + list = append(list, name) + } + sort.Strings(list) + return list +} + +// ParseConfig parses the config block for a auth method. +func ParseConfig(rawConfig map[string]interface{}, out interface{}) error { + decodeConf := &mapstructure.DecoderConfig{ + Result: out, + WeaklyTypedInput: true, + ErrorUnused: true, + } + + decoder, err := mapstructure.NewDecoder(decodeConf) + if err != nil { + return err + } + + if err := decoder.Decode(rawConfig); err != nil { + return fmt.Errorf("error decoding config: %s", err) + } + + return nil +} diff --git a/agent/consul/authmethod/kubeauth/k8s.go b/agent/consul/authmethod/kubeauth/k8s.go new file mode 100644 index 0000000000..88c4b32e3d --- /dev/null +++ b/agent/consul/authmethod/kubeauth/k8s.go @@ -0,0 +1,202 @@ +package kubeauth + +import ( + "errors" + "fmt" + "strings" + + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + cleanhttp "github.com/hashicorp/go-cleanhttp" + "gopkg.in/square/go-jose.v2/jwt" + authv1 "k8s.io/api/authentication/v1" + client_metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8s "k8s.io/client-go/kubernetes" + client_authv1 "k8s.io/client-go/kubernetes/typed/authentication/v1" + client_corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + client_rest "k8s.io/client-go/rest" + cert "k8s.io/client-go/util/cert" +) + +func init() { + // register this as an available auth method type + authmethod.Register("kubernetes", func(method *structs.ACLAuthMethod) (authmethod.Validator, error) { + v, err := NewValidator(method) + if err != nil { + return nil, err + } + return v, nil + }) +} + +const ( + serviceAccountNamespaceField = "serviceaccount.namespace" + serviceAccountNameField = "serviceaccount.name" + serviceAccountUIDField = "serviceaccount.uid" + + serviceAccountServiceNameAnnotation = "consul.hashicorp.com/service-name" +) + +type Config struct { + // Host must be a host string, a host:port pair, or a URL to the base of + // the Kubernetes API server. + Host string `json:",omitempty"` + + // PEM encoded CA cert for use by the TLS client used to talk with the + // Kubernetes API. Every line must end with a newline: \n + CACert string `json:",omitempty"` + + // A service account JWT used to access the TokenReview API to validate + // other JWTs during login. It also must be able to read ServiceAccount + // annotations. + ServiceAccountJWT string `json:",omitempty"` +} + +// Validator is the wrapper around the relevant portions of the Kubernetes API +// that also conforms to the authmethod.Validator interface. +type Validator struct { + name string + config *Config + saGetter client_corev1.ServiceAccountsGetter + trGetter client_authv1.TokenReviewsGetter +} + +func NewValidator(method *structs.ACLAuthMethod) (*Validator, error) { + if method.Type != "kubernetes" { + return nil, fmt.Errorf("%q is not a kubernetes auth method", method.Name) + } + + var config Config + if err := authmethod.ParseConfig(method.Config, &config); err != nil { + return nil, err + } + + if config.Host == "" { + return nil, fmt.Errorf("Config.Host is required") + } + + if config.CACert == "" { + return nil, fmt.Errorf("Config.CACert is required") + } + if _, err := cert.ParseCertsPEM([]byte(config.CACert)); err != nil { + return nil, fmt.Errorf("error parsing kubernetes ca cert: %v", err) + } + + // This is the bearer token we give the apiserver to use the API. + if config.ServiceAccountJWT == "" { + return nil, fmt.Errorf("Config.ServiceAccountJWT is required") + } + if _, err := jwt.ParseSigned(config.ServiceAccountJWT); err != nil { + return nil, fmt.Errorf("Config.ServiceAccountJWT is not a valid JWT: %v", err) + } + + transport := cleanhttp.DefaultTransport() + client, err := k8s.NewForConfig(&client_rest.Config{ + Host: config.Host, + BearerToken: config.ServiceAccountJWT, + Dial: transport.DialContext, + TLSClientConfig: client_rest.TLSClientConfig{ + CAData: []byte(config.CACert), + }, + ContentConfig: client_rest.ContentConfig{ + ContentType: "application/json", + }, + }) + if err != nil { + return nil, err + } + + return &Validator{ + name: method.Name, + config: &config, + saGetter: client.CoreV1(), + trGetter: client.AuthenticationV1(), + }, nil +} + +func (v *Validator) Name() string { return v.name } + +func (v *Validator) ValidateLogin(loginToken string) (map[string]string, error) { + if _, err := jwt.ParseSigned(loginToken); err != nil { + return nil, fmt.Errorf("failed to parse and validate JWT: %v", err) + } + + // Check TokenReview for the bulk of the work. + trResp, err := v.trGetter.TokenReviews().Create(&authv1.TokenReview{ + Spec: authv1.TokenReviewSpec{ + Token: loginToken, + }, + }) + + if err != nil { + return nil, err + } else if trResp.Status.Error != "" { + return nil, fmt.Errorf("lookup failed: %s", trResp.Status.Error) + } + + if !trResp.Status.Authenticated { + return nil, errors.New("lookup failed: service account jwt not valid") + } + + // The username is of format: system:serviceaccount:(NAMESPACE):(SERVICEACCOUNT) + parts := strings.Split(trResp.Status.User.Username, ":") + if len(parts) != 4 { + return nil, errors.New("lookup failed: unexpected username format") + } + + // Validate the user that comes back from token review is a service account + if parts[0] != "system" || parts[1] != "serviceaccount" { + return nil, errors.New("lookup failed: username returned is not a service account") + } + + var ( + saNamespace = parts[2] + saName = parts[3] + saUID = string(trResp.Status.User.UID) + ) + + // Check to see if there is an override name on the ServiceAccount object. + sa, err := v.saGetter.ServiceAccounts(saNamespace).Get(saName, client_metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("annotation lookup failed: %v", err) + } + + annotations := sa.GetObjectMeta().GetAnnotations() + if serviceNameOverride, ok := annotations[serviceAccountServiceNameAnnotation]; ok { + saName = serviceNameOverride + } + + return map[string]string{ + serviceAccountNamespaceField: saNamespace, + serviceAccountNameField: saName, + serviceAccountUIDField: saUID, + }, nil +} + +func (p *Validator) AvailableFields() []string { + return []string{ + serviceAccountNamespaceField, + serviceAccountNameField, + serviceAccountUIDField, + } +} + +func (v *Validator) MakeFieldMapSelectable(fieldMap map[string]string) interface{} { + return &k8sFieldDetails{ + ServiceAccount: k8sFieldDetailsServiceAccount{ + Namespace: fieldMap[serviceAccountNamespaceField], + Name: fieldMap[serviceAccountNameField], + UID: fieldMap[serviceAccountUIDField], + }, + } +} + +type k8sFieldDetails struct { + ServiceAccount k8sFieldDetailsServiceAccount `bexpr:"serviceaccount"` +} + +type k8sFieldDetailsServiceAccount struct { + Namespace string `bexpr:"namespace"` + Name string `bexpr:"name"` + UID string `bexpr:"uid"` +} diff --git a/agent/consul/authmethod/kubeauth/k8s_test.go b/agent/consul/authmethod/kubeauth/k8s_test.go new file mode 100644 index 0000000000..614538c40e --- /dev/null +++ b/agent/consul/authmethod/kubeauth/k8s_test.go @@ -0,0 +1,144 @@ +package kubeauth + +import ( + "testing" + + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/agent/structs" + "github.com/stretchr/testify/require" +) + +func TestValidateLogin(t *testing.T) { + testSrv := StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(goodJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + goodJWT_B, + ) + + method := &structs.ACLAuthMethod{ + Name: "test-k8s", + Description: "k8s test", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": testSrv.Addr(), + "CACert": testSrv.CACert(), + "ServiceAccountJWT": goodJWT_A, + }, + } + validator, err := NewValidator(method) + require.NoError(t, err) + + t.Run("invalid bearer token", func(t *testing.T) { + _, err := validator.ValidateLogin("invalid") + require.Error(t, err) + }) + + t.Run("valid bearer token", func(t *testing.T) { + fields, err := validator.ValidateLogin(goodJWT_B) + require.NoError(t, err) + require.Equal(t, map[string]string{ + "serviceaccount.namespace": "default", + "serviceaccount.name": "demo", + "serviceaccount.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe", + }, fields) + }) + + // annotate the account + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "alternate-name", + goodJWT_B, + ) + + t.Run("valid bearer token with annotation", func(t *testing.T) { + fields, err := validator.ValidateLogin(goodJWT_B) + require.NoError(t, err) + require.Equal(t, map[string]string{ + "serviceaccount.namespace": "default", + "serviceaccount.name": "alternate-name", + "serviceaccount.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe", + }, fields) + }) +} + +func TestNewValidator(t *testing.T) { + ca := connect.TestCA(t, nil) + + type AM = *structs.ACLAuthMethod + + makeAuthMethod := func(f func(method AM)) *structs.ACLAuthMethod { + method := &structs.ACLAuthMethod{ + Name: "test-k8s", + Description: "k8s test", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": "https://abc:8443", + "CACert": ca.RootCert, + "ServiceAccountJWT": goodJWT_A, + }, + } + if f != nil { + f(method) + } + return method + } + + for _, test := range []struct { + name string + method *structs.ACLAuthMethod + ok bool + }{ + // bad + {"wrong type", makeAuthMethod(func(method AM) { + method.Type = "invalid" + }), false}, + {"extra config", makeAuthMethod(func(method AM) { + method.Config["extra"] = "config" + }), false}, + {"wrong type of config", makeAuthMethod(func(method AM) { + method.Config["Host"] = []int{12345} + }), false}, + {"missing host", makeAuthMethod(func(method AM) { + delete(method.Config, "Host") + }), false}, + {"missing ca cert", makeAuthMethod(func(method AM) { + delete(method.Config, "CACert") + }), false}, + {"invalid ca cert", makeAuthMethod(func(method AM) { + method.Config["CACert"] = "invalid" + }), false}, + {"invalid jwt", makeAuthMethod(func(method AM) { + method.Config["ServiceAccountJWT"] = "invalid" + }), false}, + {"garbage host", makeAuthMethod(func(method AM) { + method.Config["Host"] = "://:12345" + }), false}, + // good + {"normal", makeAuthMethod(nil), true}, + } { + t.Run(test.name, func(t *testing.T) { + v, err := NewValidator(test.method) + if test.ok { + require.NoError(t, err) + require.NotNil(t, v) + } else { + require.NotNil(t, err) + require.Nil(t, v) + } + }) + } +} + +// 'default/admin' +const goodJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA" + +// 'default/demo' +const goodJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ" diff --git a/agent/consul/authmethod/kubeauth/testing.go b/agent/consul/authmethod/kubeauth/testing.go new file mode 100644 index 0000000000..7e6340dd9d --- /dev/null +++ b/agent/consul/authmethod/kubeauth/testing.go @@ -0,0 +1,532 @@ +package kubeauth + +import ( + "bytes" + "encoding/json" + "encoding/pem" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + authv1 "k8s.io/api/authentication/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +// TestAPIServer is a way to mock the Kubernetes API server as it is used by +// the consul kubernetes auth method. +// +// - POST /apis/authentication.k8s.io/v1/tokenreviews +// - GET /api/v1/namespaces//serviceaccounts/ +// +type TestAPIServer struct { + t *testing.T + srv *httptest.Server + caCert string + + mu sync.Mutex + authorizedJWT string // token review and sa read + allowedServiceAccountJWT string // general service account + replyStatus *authv1.TokenReview // general service account + replyRead *corev1.ServiceAccount // general service account +} + +// StartTestAPIServer creates a disposable TestAPIServer and binds it to a +// random free port. +func StartTestAPIServer(t *testing.T) *TestAPIServer { + s := &TestAPIServer{t: t} + + s.srv = httptest.NewTLSServer(s) + + bs := s.srv.TLS.Certificates[0].Certificate[0] + + var buf bytes.Buffer + require.NoError(t, pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})) + s.caCert = buf.String() + + return s +} + +// AuthorizeJWT whitelists the given JWT as able to use the API server. +func (s *TestAPIServer) AuthorizeJWT(jwt string) { + s.mu.Lock() + defer s.mu.Unlock() + + s.authorizedJWT = jwt +} + +// SetAllowedServiceAccount configures the singular known Service Account +// installed in this API server. If any of namespace/name/uid/jwt are empty +// it removes anything previously configured. +// +// It is up to the caller to ensure that the provided JWT matches the other +// data. +func (s *TestAPIServer) SetAllowedServiceAccount( + namespace, name, uid, overrideAnnotation, jwt string, +) { + s.mu.Lock() + defer s.mu.Unlock() + + if namespace == "" || name == "" || uid == "" || jwt == "" { + s.allowedServiceAccountJWT = "" + s.replyStatus = nil + s.replyRead = nil + return + } + + s.allowedServiceAccountJWT = jwt + s.replyRead = createReadServiceAccountFound(namespace, name, uid, overrideAnnotation, jwt) + s.replyStatus = createTokenReviewFound(namespace, name, uid, jwt) +} + +// Stop stops the running TestAPIServer. +func (s *TestAPIServer) Stop() { + s.srv.Close() +} + +// Addr returns the current base URL for the running webserver. +func (s *TestAPIServer) Addr() string { return s.srv.URL } + +// CACert returns the pem-encoded CA certificate used by the HTTPS server. +func (s *TestAPIServer) CACert() string { return s.caCert } + +var readServiceAccountPathRE = regexp.MustCompile("^/api/v1/namespaces/([^/]+)/serviceaccounts/([^/]+)$") + +func (s *TestAPIServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + w.Header().Set("content-type", "application/json") + + if req.URL.Path == "/apis/authentication.k8s.io/v1/tokenreviews" { + s.handleTokenReview(w, req) + return + } + + if m := readServiceAccountPathRE.FindStringSubmatch(req.URL.Path); m != nil { + namespace, err := url.QueryUnescape(m[1]) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + name, err := url.QueryUnescape(m[2]) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + s.handleReadServiceAccount(namespace, name, w, req) + return + } + + w.WriteHeader(http.StatusNotFound) +} + +func writeJSON(w http.ResponseWriter, out interface{}) error { + enc := json.NewEncoder(w) + return enc.Encode(out) +} + +func (s *TestAPIServer) handleTokenReview(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + if auth, anon := s.isAuthenticated(req); !auth { + var out interface{} + if anon { + out = createTokenReviewForbidden_NoAuthz() + } else { + out = createTokenReviewForbidden("default", "fake-account") + } + + w.WriteHeader(http.StatusForbidden) + if err := writeJSON(w, out); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + return + } + + if req.Body == nil { + w.WriteHeader(http.StatusBadRequest) + return + } + defer req.Body.Close() + + b, err := ioutil.ReadAll(req.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + var trReq authv1.TokenReview + if err := json.Unmarshal(b, &trReq); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + reviewingJWT := trReq.Spec.Token + + var out interface{} + if s.replyStatus == nil || reviewingJWT != s.allowedServiceAccountJWT { + out = createTokenReviewNotFound(reviewingJWT) + } else { + out = s.replyStatus + } + w.WriteHeader(http.StatusCreated) + + if err := writeJSON(w, out); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (s *TestAPIServer) handleReadServiceAccount( + namespace, name string, + w http.ResponseWriter, + req *http.Request, +) { + if req.Method != "GET" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + var out interface{} + if auth, anon := s.isAuthenticated(req); !auth { + if anon { + out = createReadServiceAccountForbidden_NoAuthz() + } else { + out = createReadServiceAccountForbidden(namespace, name) + } + w.WriteHeader(http.StatusForbidden) + } else if s.replyRead == nil { + out = createReadServiceAccountNotFound(namespace, name) + w.WriteHeader(http.StatusNotFound) + } else if s.replyRead.Namespace != namespace || s.replyRead.Name != name { + out = createReadServiceAccountNotFound(namespace, name) + w.WriteHeader(http.StatusNotFound) + } else { + out = s.replyRead + w.WriteHeader(http.StatusOK) + } + + if err := writeJSON(w, out); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (s *TestAPIServer) isAuthenticated(req *http.Request) (auth, anonymous bool) { + authz := req.Header.Get("Authorization") + if !strings.HasPrefix(authz, "Bearer ") { + return false, true + } + jwt := strings.TrimPrefix(authz, "Bearer ") + + return s.authorizedJWT == jwt, false +} + +func createTokenReviewForbidden_NoAuthz() *metav1.Status { + /* + STATUS: 403 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "tokenreviews.authentication.k8s.io is forbidden: User \"system:anonymous\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" at the cluster scope", + "reason": "Forbidden", + "details": { + "group": "authentication.k8s.io", + "kind": "tokenreviews" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "tokenreviews.authentication.k8s.io is forbidden: User \"system:anonymous\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" in the cluster scope", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Group: "authentication.k8s.io", + Kind: "tokenreviews", + }, + 403, + ) +} + +func createTokenReviewForbidden(namespace, name string) *metav1.Status { + /* + STATUS: 403 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "tokenreviews.authentication.k8s.io is forbidden: User \"system:serviceaccount:default:admin\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" at the cluster scope", + "reason": "Forbidden", + "details": { + "group": "authentication.k8s.io", + "kind": "tokenreviews" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "tokenreviews.authentication.k8s.io is forbidden: User \"system:serviceaccount:"+namespace+":"+name+"\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" in the cluster scope", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Group: "authentication.k8s.io", + Kind: "tokenreviews", + }, + 403, + ) +} + +func createTokenReviewNotFound(jwt string) *authv1.TokenReview { + /* + STATUS: 201 + { + "kind": "TokenReview", + "apiVersion": "authentication.k8s.io/v1", + "metadata": { + "creationTimestamp": null + }, + "spec": { + "token": "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImZha2UtdG9rZW4tano2YnYiLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZmFrZSIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6IjgxYTY1Mjg2LTU3YzEtMTFlOS1iYzJhLTQ4ZTZjOGI4ZWNiNSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmZha2UifQ.DqjUXe34SzCP4NCwbhqV9EuksfzmTSLhJzkE_URyufeGJDn-Gw0_JS-_KmxZSdAO0XXNzB1tJNM1NCVW-V6YbThnPUw5WY4V2J6U1W72c2dzNBx_ipBxGBZ632ZnpViIRu6tL2guT36lWa8YnMDF_OY8sHhl_3kJ6MRxNxY41vAuf45mohi3gri46Kpzc3pf1g6PJ-0oogvUsZ2nBFv1mIdciGBV0zejMKc5Bnxur1L-hEQ9EgZrJ7o0yQRCWYgam_yo_M38EsB8b-suTzQJMA-pRgApOb9dHIV6YAE_b3g_pGkJjrPYzV4IJC1CiPfdz1SAjm7e0ARXtZmaoPltjQ" + }, + "status": { + "user": {}, + "error": "[invalid bearer token, Token has been invalidated]" + } + } + */ + return &authv1.TokenReview{ + TypeMeta: metav1.TypeMeta{ + Kind: "TokenReview", + APIVersion: "authentication.k8s.io/v1", + }, + ObjectMeta: metav1.ObjectMeta{}, + Spec: authv1.TokenReviewSpec{ + Token: jwt, + }, + Status: authv1.TokenReviewStatus{ + User: authv1.UserInfo{}, + Error: "[invalid bearer token, Token has been invalidated]", + }, + } +} + +func createTokenReviewFound(namespace, name, uid, jwt string) *authv1.TokenReview { + /* + STATUS: 201 + { + "kind": "TokenReview", + "apiVersion": "authentication.k8s.io/v1", + "metadata": { + "creationTimestamp": null + }, + "spec": { + "token": "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4tbTljdm4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6IjlmZjUxZmY0LTU1N2UtMTFlOS05Njg3LTQ4ZTZjOGI4ZWNiNSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.UJEphtrN261gy9WCl4ZKjm2PRDLDkc3Xg9VcDGfzyroOqFQ6sog5dVAb9voc5Nc0-H5b1yGwxDViEMucwKvZpA5pi7VEx_OskK-KTWXSmafM0Xg_AvzpU9Ed5TSRno-OhXaAraxdjXoC4myh1ay2DMeHUusJg_ibqcYJrWx-6MO1bH_ObORtAKhoST_8fzkqNAlZmsQ87FinQvYN5mzDXYukl-eeRdBgQUBkWvEb-Ju6cc0-QE4sUQ4IH_fs0fUyX_xc0om0SZGWLP909FTz4V8LxV8kr6L7irxROiS1jn3Fvyc9ur1PamVf3JOPPrOyfmKbaGRiWJM32b3buQw7cg" + }, + "status": { + "authenticated": true, + "user": { + "username": "system:serviceaccount:default:demo", + "uid": "9ff51ff4-557e-11e9-9687-48e6c8b8ecb5", + "groups": [ + "system:serviceaccounts", + "system:serviceaccounts:default", + "system:authenticated" + ] + } + } + } + */ + return &authv1.TokenReview{ + TypeMeta: metav1.TypeMeta{ + Kind: "TokenReview", + APIVersion: "authentication.k8s.io/v1", + }, + ObjectMeta: metav1.ObjectMeta{}, + Spec: authv1.TokenReviewSpec{ + Token: jwt, + }, + Status: authv1.TokenReviewStatus{ + Authenticated: true, + User: authv1.UserInfo{ + Username: "system:serviceaccount:" + namespace + ":" + name, + UID: uid, + Groups: []string{ + "system:serviceaccounts", + "system:serviceaccounts:default", + "system:authenticated", + }, + }, + }, + } +} + +func createReadServiceAccountForbidden(namespace, name string) *metav1.Status { + /* + STATUS: 403 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "serviceaccounts \"demo\" is forbidden: User \"system:serviceaccount:default:admin\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"", + "reason": "Forbidden", + "details": { + "name": "demo", + "kind": "serviceaccounts" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "serviceaccounts \""+name+"\" is forbidden: User \"system:serviceaccount:"+namespace+":"+name+"\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \""+namespace+"\"", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Kind: "serviceaccounts", + Name: name, + }, + 403, + ) +} + +func createReadServiceAccountForbidden_NoAuthz() *metav1.Status { + // missing bearer token header 403 + /* + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "serviceaccounts \"demo\" is forbidden: User \"system:anonymous\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"", + "reason": "Forbidden", + "details": { + "name": "demo", + "kind": "serviceaccounts" + }, + "code": 403 + } + */ + return createStatus( + metav1.StatusFailure, + "serviceaccounts \"PLACEHOLDER\" is forbidden: User \"system:anonymous\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"", + metav1.StatusReasonForbidden, + &metav1.StatusDetails{ + Kind: "serviceaccounts", + Name: "PLACEHOLDER", + }, + 403, + ) +} + +func createReadServiceAccountNotFound(namespace, name string) *metav1.Status { + /* + STATUS: 404 + { + "kind": "Status", + "apiVersion": "v1", + "metadata": {}, + "status": "Failure", + "message": "serviceaccounts \"demo\" not found", + "reason": "NotFound", + "details": { + "name": "demo", + "kind": "serviceaccounts" + }, + "code": 404 + } + */ + return createStatus( + metav1.StatusFailure, + "serviceaccounts \""+name+"\" not found", + metav1.StatusReasonNotFound, + &metav1.StatusDetails{ + Kind: "serviceaccounts", + Name: name, + }, + 404, + ) +} + +func createReadServiceAccountFound(namespace, name, uid, overrideAnnotation, jwt string) *corev1.ServiceAccount { + /* + STATUS: 200 + { + "kind": "ServiceAccount", + "apiVersion": "v1", + "metadata": { + "name": "demo", + "namespace": "default", + "selfLink": "/api/v1/namespaces/default/serviceaccounts/demo", + "uid": "9ff51ff4-557e-11e9-9687-48e6c8b8ecb5", + "resourceVersion": "2101", + "creationTimestamp": "2019-04-02T19:36:34Z", + "annotations": { + "consul.hashicorp.com/service-name": "actual", + "kubectl.kubernetes.io/last-applied-configuration": "{\"apiVersion\":\"v1\",\"kind\":\"ServiceAccount\",\"metadata\":{\"annotations\":{\"consul.hashicorp.com/service-name\":\"actual\"},\"name\":\"demo\",\"namespace\":\"default\"}}\n" + } + }, + "secrets": [ + { + "name": "demo-token-m9cvn" + } + ] + } + */ + sa := &corev1.ServiceAccount{ + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceAccount", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + SelfLink: "/api/v1/namespaces/" + namespace + "/serviceaccounts/" + name, + UID: types.UID(uid), + ResourceVersion: "123", + CreationTimestamp: metav1.Time{Time: time.Now()}, + }, + Secrets: []corev1.ObjectReference{ + corev1.ObjectReference{ + Name: name + "-token-m9cvn", + }, + }, + } + if overrideAnnotation != "" { + sa.ObjectMeta.Annotations = map[string]string{ + "consul.hashicorp.com/service-name": overrideAnnotation, + } + } + + return sa +} + +func createStatus(status, message string, reason metav1.StatusReason, details *metav1.StatusDetails, code int32) *metav1.Status { + return &metav1.Status{ + TypeMeta: metav1.TypeMeta{ + Kind: "Status", + APIVersion: "v1", + }, + ListMeta: metav1.ListMeta{}, + Status: status, + Message: message, + Reason: reason, + Details: details, + Code: code, + } +} diff --git a/agent/consul/authmethod/testauth/testing.go b/agent/consul/authmethod/testauth/testing.go new file mode 100644 index 0000000000..638450d94d --- /dev/null +++ b/agent/consul/authmethod/testauth/testing.go @@ -0,0 +1,166 @@ +package testauth + +import ( + "fmt" + "sync" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/authmethod" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-uuid" +) + +func init() { + authmethod.Register("testing", newValidator) +} + +var ( + tokenDatabaseMu sync.Mutex + tokenDatabase map[string]map[string]map[string]string // session => token => fieldmap +) + +func StartSession() string { + sessionID, err := uuid.GenerateUUID() + if err != nil { + panic(err) + } + return sessionID +} + +func ResetSession(sessionID string) { + tokenDatabaseMu.Lock() + defer tokenDatabaseMu.Unlock() + if tokenDatabase != nil { + delete(tokenDatabase, sessionID) + } +} + +func InstallSessionToken(sessionID string, token string, namespace, name, uid string) { + fields := map[string]string{ + serviceAccountNamespaceField: namespace, + serviceAccountNameField: name, + serviceAccountUIDField: uid, + } + + tokenDatabaseMu.Lock() + defer tokenDatabaseMu.Unlock() + if tokenDatabase == nil { + tokenDatabase = make(map[string]map[string]map[string]string) + } + sdb, ok := tokenDatabase[sessionID] + if !ok { + sdb = make(map[string]map[string]string) + tokenDatabase[sessionID] = sdb + } + sdb[token] = fields +} + +func GetSessionToken(sessionID string, token string) (map[string]string, bool) { + tokenDatabaseMu.Lock() + defer tokenDatabaseMu.Unlock() + if tokenDatabase == nil { + return nil, false + } + sdb, ok := tokenDatabase[sessionID] + if !ok { + return nil, false + } + fields, ok := sdb[token] + if !ok { + return nil, false + } + + fmCopy := make(map[string]string) + for k, v := range fields { + fmCopy[k] = v + } + + return fmCopy, true +} + +type Config struct { + SessionID string // unique identifier for this set of tokens in the database +} + +func newValidator(method *structs.ACLAuthMethod) (authmethod.Validator, error) { + if method.Type != "testing" { + return nil, fmt.Errorf("%q is not a testing auth method", method.Name) + } + + var config Config + if err := authmethod.ParseConfig(method.Config, &config); err != nil { + return nil, err + } + + if config.SessionID == "" { + // If you don't explicitly create one, we create a random one but you + // won't have access to it. Useful if you are testing everything EXCEPT + // ValidateToken(). + config.SessionID = StartSession() + } + + return &Validator{ + name: method.Name, + config: &config, + }, nil +} + +type Validator struct { + name string + config *Config +} + +func (v *Validator) Name() string { return v.name } + +// ValidateLogin takes raw user-provided auth method metadata and ensures it is +// sane, provably correct, and currently valid. Relevant identifying data is +// extracted and returned for immediate use by the role binding process. +// +// Depending upon the method, it may make sense to use these calls to continue +// to extend the life of the underlying token. +// +// Returns auth method specific metadata suitable for the Role Binding process. +func (v *Validator) ValidateLogin(loginToken string) (map[string]string, error) { + fields, valid := GetSessionToken(v.config.SessionID, loginToken) + if !valid { + return nil, acl.ErrNotFound + } + + return fields, nil +} + +func (v *Validator) AvailableFields() []string { return availableFields } + +const ( + serviceAccountNamespaceField = "serviceaccount.namespace" + serviceAccountNameField = "serviceaccount.name" + serviceAccountUIDField = "serviceaccount.uid" +) + +var availableFields = []string{ + serviceAccountNamespaceField, + serviceAccountNameField, + serviceAccountUIDField, +} + +// MakeFieldMapSelectable converts a field map as returned by ValidateLogin +// into a structure suitable for selection with a binding rule. +func (v *Validator) MakeFieldMapSelectable(fieldMap map[string]string) interface{} { + return &selectableVars{ + ServiceAccount: selectableServiceAccount{ + Namespace: fieldMap[serviceAccountNamespaceField], + Name: fieldMap[serviceAccountNameField], + UID: fieldMap[serviceAccountUIDField], + }, + } +} + +type selectableVars struct { + ServiceAccount selectableServiceAccount `bexpr:"serviceaccount"` +} + +type selectableServiceAccount struct { + Namespace string `bexpr:"namespace"` + Name string `bexpr:"name"` + UID string `bexpr:"uid"` +} diff --git a/agent/consul/fsm/commands_oss.go b/agent/consul/fsm/commands_oss.go index 36a09174df..f093aa0abd 100644 --- a/agent/consul/fsm/commands_oss.go +++ b/agent/consul/fsm/commands_oss.go @@ -32,6 +32,10 @@ func init() { registerCommand(structs.ConfigEntryRequestType, (*FSM).applyConfigEntryOperation) registerCommand(structs.ACLRoleSetRequestType, (*FSM).applyACLRoleSetOperation) registerCommand(structs.ACLRoleDeleteRequestType, (*FSM).applyACLRoleDeleteOperation) + registerCommand(structs.ACLBindingRuleSetRequestType, (*FSM).applyACLBindingRuleSetOperation) + registerCommand(structs.ACLBindingRuleDeleteRequestType, (*FSM).applyACLBindingRuleDeleteOperation) + registerCommand(structs.ACLAuthMethodSetRequestType, (*FSM).applyACLAuthMethodSetOperation) + registerCommand(structs.ACLAuthMethodDeleteRequestType, (*FSM).applyACLAuthMethodDeleteOperation) } func (c *FSM) applyRegister(buf []byte, index uint64) interface{} { @@ -476,3 +480,47 @@ func (c *FSM) applyACLRoleDeleteOperation(buf []byte, index uint64) interface{} return c.state.ACLRoleBatchDelete(index, req.RoleIDs) } + +func (c *FSM) applyACLBindingRuleSetOperation(buf []byte, index uint64) interface{} { + var req structs.ACLBindingRuleBatchSetRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "bindingrule"}, time.Now(), + []metrics.Label{{Name: "op", Value: "upsert"}}) + + return c.state.ACLBindingRuleBatchSet(index, req.BindingRules) +} + +func (c *FSM) applyACLBindingRuleDeleteOperation(buf []byte, index uint64) interface{} { + var req structs.ACLBindingRuleBatchDeleteRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "bindingrule"}, time.Now(), + []metrics.Label{{Name: "op", Value: "delete"}}) + + return c.state.ACLBindingRuleBatchDelete(index, req.BindingRuleIDs) +} + +func (c *FSM) applyACLAuthMethodSetOperation(buf []byte, index uint64) interface{} { + var req structs.ACLAuthMethodBatchSetRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "authmethod"}, time.Now(), + []metrics.Label{{Name: "op", Value: "upsert"}}) + + return c.state.ACLAuthMethodBatchSet(index, req.AuthMethods) +} + +func (c *FSM) applyACLAuthMethodDeleteOperation(buf []byte, index uint64) interface{} { + var req structs.ACLAuthMethodBatchDeleteRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "authmethod"}, time.Now(), + []metrics.Label{{Name: "op", Value: "delete"}}) + + return c.state.ACLAuthMethodBatchDelete(index, req.AuthMethodNames) +} diff --git a/agent/consul/fsm/snapshot_oss.go b/agent/consul/fsm/snapshot_oss.go index 3ad281434b..195e6cf136 100644 --- a/agent/consul/fsm/snapshot_oss.go +++ b/agent/consul/fsm/snapshot_oss.go @@ -29,6 +29,8 @@ func init() { registerRestorer(structs.ACLPolicySetRequestType, restorePolicy) registerRestorer(structs.ConfigEntryRequestType, restoreConfigEntry) registerRestorer(structs.ACLRoleSetRequestType, restoreRole) + registerRestorer(structs.ACLBindingRuleSetRequestType, restoreBindingRule) + registerRestorer(structs.ACLAuthMethodSetRequestType, restoreAuthMethod) } func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error { @@ -218,6 +220,34 @@ func (s *snapshot) persistACLs(sink raft.SnapshotSink, } } + rules, err := s.state.ACLBindingRules() + if err != nil { + return err + } + + for rule := rules.Next(); rule != nil; rule = rules.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLBindingRuleSetRequestType)}); err != nil { + return err + } + if err := encoder.Encode(rule.(*structs.ACLBindingRule)); err != nil { + return err + } + } + + methods, err := s.state.ACLAuthMethods() + if err != nil { + return err + } + + for method := methods.Next(); method != nil; method = rules.Next() { + if _, err := sink.Write([]byte{byte(structs.ACLAuthMethodSetRequestType)}); err != nil { + return err + } + if err := encoder.Encode(method.(*structs.ACLAuthMethod)); err != nil { + return err + } + } + return nil } @@ -626,3 +656,19 @@ func restoreRole(header *snapshotHeader, restore *state.Restore, decoder *codec. } return restore.ACLRole(&req) } + +func restoreBindingRule(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACLBindingRule + if err := decoder.Decode(&req); err != nil { + return err + } + return restore.ACLBindingRule(&req) +} + +func restoreAuthMethod(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.ACLAuthMethod + if err := decoder.Decode(&req); err != nil { + return err + } + return restore.ACLAuthMethod(&req) +} diff --git a/agent/consul/fsm/snapshot_oss_test.go b/agent/consul/fsm/snapshot_oss_test.go index b8b5bb1327..9571b5f1a1 100644 --- a/agent/consul/fsm/snapshot_oss_test.go +++ b/agent/consul/fsm/snapshot_oss_test.go @@ -86,7 +86,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { }) session := &structs.Session{ID: generateUUID(), Node: "foo"} fsm.state.SessionCreate(9, session) - policy := structs.ACLPolicy{ + policy := &structs.ACLPolicy{ ID: structs.ACLPolicyGlobalManagementID, Name: "global-management", Description: "Builtin Policy that grants unlimited access", @@ -94,7 +94,20 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { Syntax: acl.SyntaxCurrent, } policy.SetHash(true) - require.NoError(fsm.state.ACLPolicySet(1, &policy)) + require.NoError(fsm.state.ACLPolicySet(1, policy)) + + role := &structs.ACLRole{ + ID: "86dedd19-8fae-4594-8294-4e6948a81f9a", + Name: "some-role", + Description: "test snapshot role", + ServiceIdentities: []*structs.ACLServiceIdentity{ + &structs.ACLServiceIdentity{ + ServiceName: "example", + }, + }, + } + role.SetHash(true) + require.NoError(fsm.state.ACLRoleSet(1, role)) token := &structs.ACLToken{ AccessorID: "30fca056-9fbb-4455-b94a-bf0e2bc575d6", @@ -112,6 +125,26 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { } require.NoError(fsm.state.ACLBootstrap(10, 0, token, false)) + method := &structs.ACLAuthMethod{ + Name: "some-method", + Type: "testing", + Description: "test snapshot auth method", + Config: map[string]interface{}{ + "SessionID": "952ebfa8-2a42-46f0-bcd3-fd98a842000e", + }, + } + require.NoError(fsm.state.ACLAuthMethodSet(1, method)) + + bindingRule := &structs.ACLBindingRule{ + ID: "85184c52-5997-4a84-9817-5945f2632a17", + Description: "test snapshot binding rule", + AuthMethod: "some-method", + Selector: "serviceaccount.namespace==default", + BindType: structs.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + } + require.NoError(fsm.state.ACLBindingRuleSet(1, bindingRule)) + fsm.state.KVSSet(11, &structs.DirEntry{ Key: "/remove", Value: []byte("foo"), @@ -314,21 +347,40 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { t.Fatalf("bad index: %d", idx) } - // Verify ACL Token is restored - _, a, err := fsm2.state.ACLTokenGetByAccessor(nil, token.AccessorID) + // Verify ACL Binding Rule is restored + _, bindingRule2, err := fsm2.state.ACLBindingRuleGetByID(nil, bindingRule.ID) require.NoError(err) - require.Equal(token.AccessorID, a.AccessorID) - require.Equal(token.ModifyIndex, a.ModifyIndex) + require.Equal(bindingRule, bindingRule2) + + // Verify ACL Auth Method is restored + _, method2, err := fsm2.state.ACLAuthMethodGetByName(nil, method.Name) + require.NoError(err) + require.Equal(method, method2) + + // Verify ACL Token is restored + _, token2, err := fsm2.state.ACLTokenGetByAccessor(nil, token.AccessorID) + require.NoError(err) + { + // time.Time is tricky to compare generically when it takes a ser/deserialization round trip. + require.True(token.CreateTime.Equal(token2.CreateTime)) + token2.CreateTime = token.CreateTime + } + require.Equal(token, token2) // Verify the acl-token-bootstrap index was restored canBootstrap, index, err := fsm2.state.CanBootstrapACLToken() require.False(canBootstrap) require.True(index > 0) + // Verify ACL Role is restored + _, role2, err := fsm2.state.ACLRoleGetByID(nil, role.ID) + require.NoError(err) + require.Equal(role, role2) + // Verify ACL Policy is restored _, policy2, err := fsm2.state.ACLPolicyGetByID(nil, structs.ACLPolicyGlobalManagementID) require.NoError(err) - require.Equal(policy.Name, policy2.Name) + require.Equal(policy, policy2) // Verify tombstones are restored func() { diff --git a/agent/consul/leader.go b/agent/consul/leader.go index cbef94da47..6f7799cfe6 100644 --- a/agent/consul/leader.go +++ b/agent/consul/leader.go @@ -427,6 +427,10 @@ func (s *Server) initializeACLs(upgrade bool) error { // leader. s.acls.cache.Purge() + // Purge the auth method validators since they could've changed while we + // were not leader. + s.purgeAuthMethodValidators() + // Remove any token affected by CVE-2019-8336 if !s.InACLDatacenter() { _, token, err := s.fsm.State().ACLTokenGetBySecret(nil, redactedToken) diff --git a/agent/consul/server.go b/agent/consul/server.go index b44b417816..d76a46a9b8 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -115,6 +115,9 @@ type Server struct { aclTokenReapLock sync.RWMutex aclTokenReapEnabled bool + aclAuthMethodValidators map[string]*authMethodValidatorEntry + aclAuthMethodValidatorLock sync.RWMutex + // DEPRECATED (ACL-Legacy-Compat) - only needed while we support both // useNewACLs is used to determine whether we can use new ACLs or not useNewACLs int32 diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 8ff538922a..1249844c86 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -240,12 +240,20 @@ func tokensTableSchema() *memdb.TableSchema { Indexer: &TokenPoliciesIndex{}, }, "roles": &memdb.IndexSchema{ - Name: "roles", - // Need to allow missing for the anonymous token + Name: "roles", AllowMissing: true, Unique: false, Indexer: &TokenRolesIndex{}, }, + "authmethod": &memdb.IndexSchema{ + Name: "authmethod", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "AuthMethod", + Lowercase: false, + }, + }, "local": &memdb.IndexSchema{ Name: "local", AllowMissing: false, @@ -349,10 +357,54 @@ func rolesTableSchema() *memdb.TableSchema { } } +func bindingRulesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "acl-binding-rules", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "ID", + }, + }, + "authmethod": &memdb.IndexSchema{ + Name: "authmethod", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "AuthMethod", + Lowercase: true, + }, + }, + }, + } +} + +func authMethodsTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "acl-auth-methods", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Name", + Lowercase: true, + }, + }, + }, + } +} + func init() { registerSchema(tokensTableSchema) registerSchema(policiesTableSchema) registerSchema(rolesTableSchema) + registerSchema(bindingRulesTableSchema) + registerSchema(authMethodsTableSchema) } // ACLTokens is used when saving a snapshot @@ -417,6 +469,46 @@ func (s *Restore) ACLRole(role *structs.ACLRole) error { return nil } +// ACLBindingRules is used when saving a snapshot +func (s *Snapshot) ACLBindingRules() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acl-binding-rules", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +func (s *Restore) ACLBindingRule(rule *structs.ACLBindingRule) error { + if err := s.tx.Insert("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed restoring acl binding rule: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, rule.ModifyIndex, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + +// ACLAuthMethods is used when saving a snapshot +func (s *Snapshot) ACLAuthMethods() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acl-auth-methods", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +func (s *Restore) ACLAuthMethod(method *structs.ACLAuthMethod) error { + if err := s.tx.Insert("acl-auth-methods", method); err != nil { + return fmt.Errorf("failed restoring acl auth method: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, method.ModifyIndex, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + // ACLBootstrap is used to perform a one-time ACL bootstrap operation on a // cluster to get the first management token. func (s *Store) ACLBootstrap(idx, resetIndex uint64, token *structs.ACLToken, legacy bool) error { @@ -789,6 +881,15 @@ func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToke return err } + if token.AuthMethod != "" { + method, err := s.getAuthMethodWithTxn(tx, nil, token.AuthMethod, "id") + if err != nil { + return err + } else if method == nil { + return fmt.Errorf("No such auth method with Name: %s", token.AuthMethod) + } + } + for _, svcid := range token.ServiceIdentities { if svcid.ServiceName == "" { return fmt.Errorf("Encountered a Token with an empty service identity name in the state store") @@ -890,7 +991,7 @@ func (s *Store) aclTokenGetTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index st } // ACLTokenList is used to list out all of the ACLs in the state store. -func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role string) (uint64, structs.ACLTokens, error) { +func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role, methodName string) (uint64, structs.ACLTokens, error) { tx := s.db.Txn(false) defer tx.Abort() @@ -901,57 +1002,53 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role // to false but for defaulted structs (zero values for both) we want it to list out // all tokens so our checks just ensure that global == local - if policy != "" && role != "" { - return 0, nil, fmt.Errorf("cannot filter by role and policy at the same time") - } + needLocalityFilter := false + if policy == "" && role == "" && methodName == "" { + if global == local { + iter, err = tx.Get("acl-tokens", "id") + } else if global { + iter, err = tx.Get("acl-tokens", "local", false) + } else { + iter, err = tx.Get("acl-tokens", "local", true) + } - if policy != "" { + } else if policy != "" && role == "" && methodName == "" { iter, err = tx.Get("acl-tokens", "policies", policy) - if err == nil && global != local { - iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool { - token, ok := raw.(*structs.ACLToken) - if !ok { - return true - } + needLocalityFilter = true - if global && !token.Local { - return false - } else if local && token.Local { - return false - } - - return true - }) - } - } else if role != "" { + } else if policy == "" && role != "" && methodName == "" { iter, err = tx.Get("acl-tokens", "roles", role) - if err == nil && global != local { - iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool { - token, ok := raw.(*structs.ACLToken) - if !ok { - return true - } + needLocalityFilter = true - if global && !token.Local { - return false - } else if local && token.Local { - return false - } + } else if policy == "" && role == "" && methodName != "" { + iter, err = tx.Get("acl-tokens", "authmethod", methodName) + needLocalityFilter = true - return true - }) - } - } else if global == local { - iter, err = tx.Get("acl-tokens", "id") - } else if global { - iter, err = tx.Get("acl-tokens", "local", false) } else { - iter, err = tx.Get("acl-tokens", "local", true) + return 0, nil, fmt.Errorf("can only filter by one of policy, role, or methodName at a time") } if err != nil { return 0, nil, fmt.Errorf("failed acl token lookup: %v", err) } + + if needLocalityFilter && global != local { + iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool { + token, ok := raw.(*structs.ACLToken) + if !ok { + return true + } + + if global && !token.Local { + return false + } else if local && token.Local { + return false + } + + return true + }) + } + ws.Add(iter.WatchCh()) var result structs.ACLTokens @@ -1114,6 +1211,35 @@ func (s *Store) aclTokenDeleteTxn(tx *memdb.Txn, idx uint64, value, index string return nil } +func (s *Store) aclTokenDeleteAllForAuthMethodTxn(tx *memdb.Txn, idx uint64, methodName string) error { + // collect them all + iter, err := tx.Get("acl-tokens", "authmethod", methodName) + if err != nil { + return fmt.Errorf("failed acl token lookup: %v", err) + } + + var tokens structs.ACLTokens + for raw := iter.Next(); raw != nil; raw = iter.Next() { + token := raw.(*structs.ACLToken) + tokens = append(tokens, token) + } + + if len(tokens) > 0 { + // delete them all + for _, token := range tokens { + if err := tx.Delete("acl-tokens", token); err != nil { + return fmt.Errorf("failed deleting acl token: %v", err) + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-tokens"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + } + + return nil +} + func (s *Store) ACLPolicyBatchSet(idx uint64, policies structs.ACLPolicies) error { tx := s.db.Txn(true) defer tx.Abort() @@ -1437,7 +1563,7 @@ func (s *Store) ACLRoleBatchGet(ws memdb.WatchSet, ids []string) (uint64, struct tx := s.db.Txn(false) defer tx.Abort() - roles := make(structs.ACLRoles, 0) + roles := make(structs.ACLRoles, 0, len(ids)) for _, rid := range ids { role, err := s.getRoleWithTxn(tx, ws, rid, "id") if err != nil { @@ -1579,3 +1705,384 @@ func (s *Store) aclRoleDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) } return nil } + +func (s *Store) ACLBindingRuleBatchSet(idx uint64, rules structs.ACLBindingRules) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, rule := range rules { + if err := s.aclBindingRuleSetTxn(tx, idx, rule); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) ACLBindingRuleSet(idx uint64, rule *structs.ACLBindingRule) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclBindingRuleSetTxn(tx, idx, rule); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclBindingRuleSetTxn(tx *memdb.Txn, idx uint64, rule *structs.ACLBindingRule) error { + // Check that the ID and AuthMethod are set + if rule.ID == "" { + return ErrMissingACLBindingRuleID + } else if rule.AuthMethod == "" { + return ErrMissingACLBindingRuleAuthMethod + } + + existing, err := tx.First("acl-binding-rules", "id", rule.ID) + if err != nil { + return fmt.Errorf("failed acl binding rule lookup: %v", err) + } + + // Set the indexes + if existing != nil { + rule.CreateIndex = existing.(*structs.ACLBindingRule).CreateIndex + rule.ModifyIndex = idx + } else { + rule.CreateIndex = idx + rule.ModifyIndex = idx + } + + if method, err := tx.First("acl-auth-methods", "id", rule.AuthMethod); err != nil { + return fmt.Errorf("failed acl auth method lookup: %v", err) + } else if method == nil { + return fmt.Errorf("failed inserting acl binding rule: auth method not found") + } + + if err := tx.Insert("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed inserting acl binding rule: %v", err) + } + return nil +} + +func (s *Store) ACLBindingRuleGetByID(ws memdb.WatchSet, id string) (uint64, *structs.ACLBindingRule, error) { + return s.aclBindingRuleGet(ws, id, "id") +} + +func (s *Store) aclBindingRuleGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLBindingRule, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + watchCh, rawRule, err := tx.FirstWatch("acl-binding-rules", index, value) + if err != nil { + return 0, nil, fmt.Errorf("failed acl binding rule lookup: %v", err) + } + ws.Add(watchCh) + + var rule *structs.ACLBindingRule + if rawRule != nil { + rule = rawRule.(*structs.ACLBindingRule) + } + + idx := maxIndexTxn(tx, "acl-binding-rules") + + return idx, rule, nil +} + +func (s *Store) ACLBindingRuleList(ws memdb.WatchSet, methodName string) (uint64, structs.ACLBindingRules, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + var ( + iter memdb.ResultIterator + err error + ) + if methodName != "" { + iter, err = tx.Get("acl-binding-rules", "authmethod", methodName) + } else { + iter, err = tx.Get("acl-binding-rules", "id") + } + if err != nil { + return 0, nil, fmt.Errorf("failed acl binding rule lookup: %v", err) + } + ws.Add(iter.WatchCh()) + + var result structs.ACLBindingRules + for raw := iter.Next(); raw != nil; raw = iter.Next() { + rule := raw.(*structs.ACLBindingRule) + result = append(result, rule) + } + + // Get the table index. + idx := maxIndexTxn(tx, "acl-binding-rules") + + return idx, result, nil +} + +func (s *Store) ACLBindingRuleDeleteByID(idx uint64, id string) error { + return s.aclBindingRuleDelete(idx, id, "id") +} + +func (s *Store) ACLBindingRuleBatchDelete(idx uint64, bindingRuleIDs []string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, bindingRuleID := range bindingRuleIDs { + s.aclBindingRuleDeleteTxn(tx, idx, bindingRuleID, "id") + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + tx.Commit() + return nil +} + +func (s *Store) aclBindingRuleDelete(idx uint64, value, index string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclBindingRuleDeleteTxn(tx, idx, value, index); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclBindingRuleDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error { + // Look up the existing binding rule + rawRule, err := tx.First("acl-binding-rules", index, value) + if err != nil { + return fmt.Errorf("failed acl binding rule lookup: %v", err) + } + + if rawRule == nil { + return nil + } + + rule := rawRule.(*structs.ACLBindingRule) + + if err := tx.Delete("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed deleting acl binding rule: %v", err) + } + return nil +} + +func (s *Store) aclBindingRuleDeleteAllForAuthMethodTxn(tx *memdb.Txn, idx uint64, methodName string) error { + // collect them all + iter, err := tx.Get("acl-binding-rules", "authmethod", methodName) + if err != nil { + return fmt.Errorf("failed acl binding rule lookup: %v", err) + } + + var rules structs.ACLBindingRules + for raw := iter.Next(); raw != nil; raw = iter.Next() { + rule := raw.(*structs.ACLBindingRule) + rules = append(rules, rule) + } + + if len(rules) > 0 { + // delete them all + for _, rule := range rules { + if err := tx.Delete("acl-binding-rules", rule); err != nil { + return fmt.Errorf("failed deleting acl binding rule: %v", err) + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + } + + return nil +} + +func (s *Store) ACLAuthMethodBatchSet(idx uint64, methods structs.ACLAuthMethods) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, method := range methods { + // this is only used when doing batch insertions for upgrades and replication. Therefore + // we take whatever those said. + if err := s.aclAuthMethodSetTxn(tx, idx, method); err != nil { + return err + } + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) ACLAuthMethodSet(idx uint64, method *structs.ACLAuthMethod) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclAuthMethodSetTxn(tx, idx, method); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclAuthMethodSetTxn(tx *memdb.Txn, idx uint64, method *structs.ACLAuthMethod) error { + // Check that the Name and Type are set + if method.Name == "" { + return ErrMissingACLAuthMethodName + } else if method.Type == "" { + return ErrMissingACLAuthMethodType + } + + existing, err := tx.First("acl-auth-methods", "id", method.Name) + if err != nil { + return fmt.Errorf("failed acl auth method lookup: %v", err) + } + + // Set the indexes + if existing != nil { + method.CreateIndex = existing.(*structs.ACLAuthMethod).CreateIndex + method.ModifyIndex = idx + } else { + method.CreateIndex = idx + method.ModifyIndex = idx + } + + if err := tx.Insert("acl-auth-methods", method); err != nil { + return fmt.Errorf("failed inserting acl auth method: %v", err) + } + return nil +} + +func (s *Store) ACLAuthMethodGetByName(ws memdb.WatchSet, name string) (uint64, *structs.ACLAuthMethod, error) { + return s.aclAuthMethodGet(ws, name, "id") +} + +func (s *Store) aclAuthMethodGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLAuthMethod, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + method, err := s.getAuthMethodWithTxn(tx, ws, value, index) + if err != nil { + return 0, nil, err + } + + idx := maxIndexTxn(tx, "acl-auth-methods") + + return idx, method, nil +} + +func (s *Store) getAuthMethodWithTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index string) (*structs.ACLAuthMethod, error) { + watchCh, rawMethod, err := tx.FirstWatch("acl-auth-methods", index, value) + if err != nil { + return nil, fmt.Errorf("failed acl auth method lookup: %v", err) + } + ws.Add(watchCh) + + if rawMethod != nil { + return rawMethod.(*structs.ACLAuthMethod), nil + } + + return nil, nil +} + +func (s *Store) ACLAuthMethodList(ws memdb.WatchSet) (uint64, structs.ACLAuthMethods, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + iter, err := tx.Get("acl-auth-methods", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed acl auth method lookup: %v", err) + } + ws.Add(iter.WatchCh()) + + var result structs.ACLAuthMethods + for raw := iter.Next(); raw != nil; raw = iter.Next() { + method := raw.(*structs.ACLAuthMethod) + result = append(result, method) + } + + // Get the table index. + idx := maxIndexTxn(tx, "acl-auth-methods") + + return idx, result, nil +} + +func (s *Store) ACLAuthMethodDeleteByName(idx uint64, name string) error { + return s.aclAuthMethodDelete(idx, name, "id") +} + +func (s *Store) ACLAuthMethodBatchDelete(idx uint64, names []string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, name := range names { + s.aclAuthMethodDeleteTxn(tx, idx, name, "id") + } + + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + tx.Commit() + return nil +} + +func (s *Store) aclAuthMethodDelete(idx uint64, value, index string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.aclAuthMethodDeleteTxn(tx, idx, value, index); err != nil { + return err + } + if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil { + return fmt.Errorf("failed updating index: %v", err) + } + + tx.Commit() + return nil +} + +func (s *Store) aclAuthMethodDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error { + // Look up the existing method + rawMethod, err := tx.First("acl-auth-methods", index, value) + if err != nil { + return fmt.Errorf("failed acl auth method lookup: %v", err) + } + + if rawMethod == nil { + return nil + } + + method := rawMethod.(*structs.ACLAuthMethod) + + if err := s.aclBindingRuleDeleteAllForAuthMethodTxn(tx, idx, method.Name); err != nil { + return err + } + + if err := s.aclTokenDeleteAllForAuthMethodTxn(tx, idx, method.Name); err != nil { + return err + } + + if err := tx.Delete("acl-auth-methods", method); err != nil { + return fmt.Errorf("failed deleting acl auth method: %v", err) + } + return nil +} diff --git a/agent/consul/state/acl_test.go b/agent/consul/state/acl_test.go index b561e1e54c..7dddbaccb7 100644 --- a/agent/consul/state/acl_test.go +++ b/agent/consul/state/acl_test.go @@ -1,6 +1,7 @@ package state import ( + "fmt" "math/rand" "strconv" "testing" @@ -53,6 +54,17 @@ func testACLStateStore(t *testing.T) *Store { return s } +func setupExtraAuthMethods(t *testing.T, s *Store) { + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + }, + } + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) +} + func setupExtraPolicies(t *testing.T, s *Store) { policies := structs.ACLPolicies{ &structs.ACLPolicy{ @@ -205,7 +217,7 @@ func TestStateStore_ACLBootstrap(t *testing.T) { require.Equal(t, uint64(3), index) // Make sure the ACLs are in an expected state. - _, tokens, err := s.ACLTokenList(nil, true, true, "", "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Len(t, tokens, 1) compareTokens(t, token1, tokens[0]) @@ -219,7 +231,7 @@ func TestStateStore_ACLBootstrap(t *testing.T) { err = s.ACLBootstrap(32, index, token2.Clone(), false) require.NoError(t, err) - _, tokens, err = s.ACLTokenList(nil, true, true, "", "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Len(t, tokens, 2) } @@ -447,6 +459,19 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Error(t, err) }) + t.Run("Unresolvable AuthMethod", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + AuthMethod: "test", + } + + err := s.ACLTokenSet(2, token, false) + require.Error(t, err) + }) + t.Run("New", func(t *testing.T) { t.Parallel() s := testACLTokensStateStore(t) @@ -543,6 +568,37 @@ func TestStateStore_ACLToken_SetGet(t *testing.T) { require.Len(t, rtoken.ServiceIdentities, 1) require.Equal(t, "db", rtoken.ServiceIdentities[0].ServiceName) }) + + t.Run("New with auth method", func(t *testing.T) { + t.Parallel() + s := testACLTokensStateStore(t) + setupExtraAuthMethods(t, s) + + token := &structs.ACLToken{ + AccessorID: "daf37c07-d04d-4fd5-9678-a8206a57d61a", + SecretID: "39171632-6f34-4411-827f-9416403687f4", + AuthMethod: "test", + Roles: []structs.ACLTokenRoleLink{ + structs.ACLTokenRoleLink{ + ID: testRoleID_A, + }, + }, + } + + require.NoError(t, s.ACLTokenSet(2, token.Clone(), false)) + + idx, rtoken, err := s.ACLTokenGetByAccessor(nil, "daf37c07-d04d-4fd5-9678-a8206a57d61a") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + compareTokens(t, token, rtoken) + require.Equal(t, uint64(2), rtoken.CreateIndex) + require.Equal(t, uint64(2), rtoken.ModifyIndex) + require.Equal(t, "test", rtoken.AuthMethod) + require.Len(t, rtoken.Policies, 0) + require.Len(t, rtoken.ServiceIdentities, 0) + require.Len(t, rtoken.Roles, 1) + require.Equal(t, "node-read-role", rtoken.Roles[0].Name) + }) } func TestStateStore_ACLTokens_UpsertBatchRead(t *testing.T) { @@ -828,6 +884,7 @@ func TestStateStore_ACLTokens_ListUpgradeable(t *testing.T) { func TestStateStore_ACLToken_List(t *testing.T) { t.Parallel() s := testACLTokensStateStore(t) + setupExtraAuthMethods(t, s) tokens := structs.ACLTokens{ // the local token @@ -893,118 +950,167 @@ func TestStateStore_ACLToken_List(t *testing.T) { }, Local: true, }, + // the method specific token + &structs.ACLToken{ + AccessorID: "74277ae1-6a9b-4035-b444-2370fe6a2cb5", + SecretID: "ab8ac834-0d35-4cb7-83c3-168203f986cd", + AuthMethod: "test", + }, + // the method specific token and local + &structs.ACLToken{ + AccessorID: "211f0360-ef53-41d3-9d4d-db84396eb6c0", + SecretID: "087a0eb4-366f-4190-ab4c-a4aa3d2562aa", + AuthMethod: "test", + Local: true, + }, } require.NoError(t, s.ACLTokenBatchSet(2, tokens, false)) type testCase struct { - name string - local bool - global bool - policy string - role string - accessors []string + name string + local bool + global bool + policy string + role string + methodName string + accessors []string } cases := []testCase{ { - name: "Global", - local: false, - global: true, - policy: "", - role: "", + name: "Global", + local: false, + global: true, + policy: "", + role: "", + methodName: "", accessors: []string{ structs.ACLTokenAnonymousID, "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global "54866514-3cf2-4fec-8a8a-710583831834", // mgmt + global + "74277ae1-6a9b-4035-b444-2370fe6a2cb5", // authMethod + global "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global }, }, { - name: "Local", - local: true, - global: false, - policy: "", - role: "", + name: "Local", + local: true, + global: false, + policy: "", + role: "", + methodName: "", accessors: []string{ + "211f0360-ef53-41d3-9d4d-db84396eb6c0", // authMethod + local "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local "f1093997-b6c7-496d-bfb8-6b1b1895641b", // mgmt + local }, }, { - name: "Policy", - local: true, - global: true, - policy: testPolicyID_A, - role: "", + name: "Policy", + local: true, + global: true, + policy: testPolicyID_A, + role: "", + methodName: "", accessors: []string{ "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local }, }, { - name: "Policy - Local", - local: true, - global: false, - policy: testPolicyID_A, - role: "", + name: "Policy - Local", + local: true, + global: false, + policy: testPolicyID_A, + role: "", + methodName: "", accessors: []string{ "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local }, }, { - name: "Policy - Global", - local: false, - global: true, - policy: testPolicyID_A, - role: "", + name: "Policy - Global", + local: false, + global: true, + policy: testPolicyID_A, + role: "", + methodName: "", accessors: []string{ "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global }, }, { - name: "Role", - local: true, - global: true, - policy: "", - role: testRoleID_A, + name: "Role", + local: true, + global: true, + policy: "", + role: testRoleID_A, + methodName: "", accessors: []string{ "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local }, }, { - name: "Role - Local", - local: true, - global: false, - policy: "", - role: testRoleID_A, + name: "Role - Local", + local: true, + global: false, + policy: "", + role: testRoleID_A, + methodName: "", accessors: []string{ "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local }, }, { - name: "Role - Global", - local: false, - global: true, - policy: "", - role: testRoleID_A, + name: "Role - Global", + local: false, + global: true, + policy: "", + role: testRoleID_A, + methodName: "", accessors: []string{ "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global }, }, { - name: "All", - local: true, - global: true, - policy: "", - role: "", + name: "AuthMethod - Local", + local: true, + global: false, + policy: "", + role: "", + methodName: "test", + accessors: []string{ + "211f0360-ef53-41d3-9d4d-db84396eb6c0", // authMethod + local + }, + }, + { + name: "AuthMethod - Global", + local: false, + global: true, + policy: "", + role: "", + methodName: "test", + accessors: []string{ + "74277ae1-6a9b-4035-b444-2370fe6a2cb5", // authMethod + global + }, + }, + { + name: "All", + local: true, + global: true, + policy: "", + role: "", + methodName: "", accessors: []string{ structs.ACLTokenAnonymousID, + "211f0360-ef53-41d3-9d4d-db84396eb6c0", // authMethod + local "47eea4da-bda1-48a6-901c-3e36d2d9262f", // policy + global "4915fc9d-3726-4171-b588-6c271f45eecd", // policy + local "54866514-3cf2-4fec-8a8a-710583831834", // mgmt + global + "74277ae1-6a9b-4035-b444-2370fe6a2cb5", // authMethod + global "a7715fde-8954-4c92-afbc-d84c6ecdc582", // role + global "cadb4f13-f62a-49ab-ab3f-5a7e01b925d9", // role + local "f1093997-b6c7-496d-bfb8-6b1b1895641b", // mgmt + local @@ -1012,16 +1118,23 @@ func TestStateStore_ACLToken_List(t *testing.T) { }, } - t.Run("can't filter on both", func(t *testing.T) { - _, _, err := s.ACLTokenList(nil, false, false, testPolicyID_A, testRoleID_A) - require.Error(t, err) - }) + for _, tc := range []struct{ policy, role, methodName string }{ + {testPolicyID_A, testRoleID_A, "test"}, + {"", testRoleID_A, "test"}, + {testPolicyID_A, "", "test"}, + {testPolicyID_A, testRoleID_A, ""}, + } { + t.Run(fmt.Sprintf("can't filter on more than one: %s/%s/%s", tc.policy, tc.role, tc.methodName), func(t *testing.T) { + _, _, err := s.ACLTokenList(nil, false, false, tc.policy, tc.role, tc.methodName) + require.Error(t, err) + }) + } for _, tc := range cases { tc := tc // capture range variable t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, tokens, err := s.ACLTokenList(nil, tc.local, tc.global, tc.policy, tc.role) + _, tokens, err := s.ACLTokenList(nil, tc.local, tc.global, tc.policy, tc.role, tc.methodName) require.NoError(t, err) require.Len(t, tokens, len(tc.accessors)) tokens.Sort() @@ -1082,7 +1195,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.Equal(t, "node-read-renamed", retrieved.Policies[0].Name) // list tokens without stale links - _, tokens, err := s.ACLTokenList(nil, true, true, "", "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found := false @@ -1126,7 +1239,7 @@ func TestStateStore_ACLToken_FixupPolicyLinks(t *testing.T) { require.Len(t, retrieved.Policies, 0) // list tokens without stale links - _, tokens, err = s.ACLTokenList(nil, true, true, "", "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found = false @@ -1211,7 +1324,7 @@ func TestStateStore_ACLToken_FixupRoleLinks(t *testing.T) { require.Equal(t, "node-read-role-renamed", retrieved.Roles[0].Name) // list tokens without stale links - _, tokens, err := s.ACLTokenList(nil, true, true, "", "") + _, tokens, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found := false @@ -1255,7 +1368,7 @@ func TestStateStore_ACLToken_FixupRoleLinks(t *testing.T) { require.Len(t, retrieved.Roles, 0) // list tokens without stale links - _, tokens, err = s.ACLTokenList(nil, true, true, "", "") + _, tokens, err = s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) found = false @@ -2504,6 +2617,730 @@ func TestStateStore_ACLRole_Delete(t *testing.T) { }) } +func TestStateStore_ACLAuthMethod_SetGet(t *testing.T) { + t.Parallel() + + // The state store only validates key pieces of data, so we only have to + // care about filling in Name+Type. + + t.Run("Missing Name", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "", + Type: "testing", + Description: "test", + } + + require.Error(t, s.ACLAuthMethodSet(3, &method)) + }) + + t.Run("Missing Type", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "test", + Type: "", + Description: "test", + } + + require.Error(t, s.ACLAuthMethodSet(3, &method)) + }) + + t.Run("New", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + } + + require.NoError(t, s.ACLAuthMethodSet(3, &method)) + + idx, rmethod, err := s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rmethod) + require.Equal(t, "test", rmethod.Name) + require.Equal(t, "testing", rmethod.Type) + require.Equal(t, "test", rmethod.Description) + require.Equal(t, uint64(3), rmethod.CreateIndex) + require.Equal(t, uint64(3), rmethod.ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // Create the initial method + method := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + } + + require.NoError(t, s.ACLAuthMethodSet(2, &method)) + + // Now make sure we can update it + update := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "modified", + Config: map[string]interface{}{ + "Host": "https://localhost:8443", + }, + } + + require.NoError(t, s.ACLAuthMethodSet(3, &update)) + + idx, rmethod, err := s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rmethod) + require.Equal(t, "test", rmethod.Name) + require.Equal(t, "testing", rmethod.Type) + require.Equal(t, "modified", rmethod.Description) + require.Equal(t, update.Config, rmethod.Config) + require.Equal(t, uint64(2), rmethod.CreateIndex) + require.Equal(t, uint64(3), rmethod.ModifyIndex) + }) +} + +func TestStateStore_ACLAuthMethods_UpsertBatchRead(t *testing.T) { + t.Parallel() + + t.Run("Normal", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-1", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + idx, rmethods, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.Len(t, rmethods, 2) + rmethods.Sort() + require.ElementsMatch(t, methods, rmethods) + require.Equal(t, uint64(2), rmethods[0].CreateIndex) + require.Equal(t, uint64(2), rmethods[0].ModifyIndex) + require.Equal(t, uint64(2), rmethods[1].CreateIndex) + require.Equal(t, uint64(2), rmethods[1].ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // Seed initial data. + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + // Update two methods at the same time. + updates := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1 modified", + Config: map[string]interface{}{ + "Host": "https://localhost:8443", + }, + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2 modified", + Config: map[string]interface{}{ + "Host": "https://localhost:8444", + }, + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(3, updates)) + + idx, rmethods, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Len(t, rmethods, 2) + rmethods.Sort() + require.ElementsMatch(t, updates, rmethods) + require.Equal(t, uint64(2), rmethods[0].CreateIndex) + require.Equal(t, uint64(3), rmethods[0].ModifyIndex) + require.Equal(t, uint64(2), rmethods[1].CreateIndex) + require.Equal(t, uint64(3), rmethods[1].ModifyIndex) + }) +} + +func TestStateStore_ACLAuthMethod_List(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + _, rmethods, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + + require.Len(t, rmethods, 2) + rmethods.Sort() + + require.Equal(t, "test-1", rmethods[0].Name) + require.Equal(t, "testing", rmethods[0].Type) + require.Equal(t, "test-1", rmethods[0].Description) + require.Equal(t, uint64(2), rmethods[0].CreateIndex) + require.Equal(t, uint64(2), rmethods[0].ModifyIndex) + + require.Equal(t, "test-2", rmethods[1].Name) + require.Equal(t, "testing", rmethods[1].Type) + require.Equal(t, "test-2", rmethods[1].Description) + require.Equal(t, uint64(2), rmethods[1].CreateIndex) + require.Equal(t, uint64(2), rmethods[1].ModifyIndex) +} + +func TestStateStore_ACLAuthMethod_Delete(t *testing.T) { + t.Parallel() + + t.Run("Name", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + method := structs.ACLAuthMethod{ + Name: "test", + Type: "testing", + Description: "test", + } + + require.NoError(t, s.ACLAuthMethodSet(2, &method)) + + _, rmethod, err := s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.NotNil(t, rmethod) + + require.NoError(t, s.ACLAuthMethodDeleteByName(3, "test")) + require.NoError(t, err) + + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test") + require.NoError(t, err) + require.Nil(t, rmethod) + }) + + t.Run("Multiple", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + _, rmethod, err := s.ACLAuthMethodGetByName(nil, "test-1") + require.NoError(t, err) + require.NotNil(t, rmethod) + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test-2") + require.NoError(t, err) + require.NotNil(t, rmethod) + + require.NoError(t, s.ACLAuthMethodBatchDelete(3, []string{"test-1", "test-2"})) + + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test-1") + require.NoError(t, err) + require.Nil(t, rmethod) + _, rmethod, err = s.ACLAuthMethodGetByName(nil, "test-2") + require.NoError(t, err) + require.Nil(t, rmethod) + }) + + t.Run("Not Found", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // deletion of non-existant methods is not an error + require.NoError(t, s.ACLAuthMethodDeleteByName(3, "not-found")) + }) +} + +// Deleting an auth method atomically deletes all rules and tokens as well. +func TestStateStore_ACLAuthMethod_Delete_RuleAndTokenCascade(t *testing.T) { + t.Parallel() + + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + }, + } + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + const ( + method1_rule1 = "dff6f8a3-0115-4b22-8661-04a497ebb23c" + method1_rule2 = "69e2d304-703d-4889-bd94-4a720c061fc3" + method2_rule1 = "997ee45c-d6ba-4da1-a98e-aaa012e7d1e2" + method2_rule2 = "9ebae132-f1f1-4b72-b1d9-a4313ac22075" + ) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: method1_rule1, + AuthMethod: "test-1", + Description: "test-m1-r1", + }, + &structs.ACLBindingRule{ + ID: method1_rule2, + AuthMethod: "test-1", + Description: "test-m1-r2", + }, + &structs.ACLBindingRule{ + ID: method2_rule1, + AuthMethod: "test-2", + Description: "test-m2-r1", + }, + &structs.ACLBindingRule{ + ID: method2_rule2, + AuthMethod: "test-2", + Description: "test-m2-r2", + }, + } + require.NoError(t, s.ACLBindingRuleBatchSet(3, rules)) + + const ( // accessors + method1_tok1 = "6d020c5d-c4fd-4348-ba79-beac37ed0b9c" + method1_tok2 = "169160dc-34ab-45c6-aba7-ff65e9ace9cb" + method2_tok1 = "8e14628e-7dde-4573-aca1-6386c0f2095d" + method2_tok2 = "291e5af9-c68e-4dd3-8824-b2bdfdcc89e6" + ) + + tokens := structs.ACLTokens{ + &structs.ACLToken{ + AccessorID: method1_tok1, + SecretID: "7a1950c6-79dc-441c-acd2-e22cd3db0240", + Description: "test-m1-t1", + AuthMethod: "test-1", + }, + &structs.ACLToken{ + AccessorID: method1_tok2, + SecretID: "442cee4c-353f-4957-adbb-33db2f9e267f", + Description: "test-m1-t2", + AuthMethod: "test-1", + }, + &structs.ACLToken{ + AccessorID: method2_tok1, + SecretID: "d9399b7d-6c34-46bd-a675-c1352fadb6fd", + Description: "test-m2-t1", + AuthMethod: "test-2", + }, + &structs.ACLToken{ + AccessorID: method2_tok2, + SecretID: "3b72fc27-9230-42ab-a1e8-02cb489ab177", + Description: "test-m2-t2", + AuthMethod: "test-2", + }, + } + require.NoError(t, s.ACLTokenBatchSet(4, tokens, false)) + + // Delete one method. + require.NoError(t, s.ACLAuthMethodDeleteByName(4, "test-1")) + + // Make sure the method is gone. + _, rmethod, err := s.ACLAuthMethodGetByName(nil, "test-1") + require.NoError(t, err) + require.Nil(t, rmethod) + + // Make sure the rules and tokens are gone. + for _, ruleID := range []string{method1_rule1, method1_rule2} { + _, rrule, err := s.ACLBindingRuleGetByID(nil, ruleID) + require.NoError(t, err) + require.Nil(t, rrule) + } + for _, tokID := range []string{method1_tok1, method1_tok2} { + _, tok, err := s.ACLTokenGetByAccessor(nil, tokID) + require.NoError(t, err) + require.Nil(t, tok) + } + + // Make sure the rules and tokens for the untouched method are still there. + for _, ruleID := range []string{method2_rule1, method2_rule2} { + _, rrule, err := s.ACLBindingRuleGetByID(nil, ruleID) + require.NoError(t, err) + require.NotNil(t, rrule) + } + for _, tokID := range []string{method2_tok1, method2_tok2} { + _, tok, err := s.ACLTokenGetByAccessor(nil, tokID) + require.NoError(t, err) + require.NotNil(t, tok) + } +} + +func TestStateStore_ACLBindingRule_SetGet(t *testing.T) { + t.Parallel() + + // The state store only validates key pieces of data, so we only have to + // care about filling in ID+AuthMethod. + + t.Run("Missing ID", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "", + AuthMethod: "test", + Description: "test", + } + + require.Error(t, s.ACLBindingRuleSet(3, &rule)) + }) + + t.Run("Missing AuthMethod", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "", + Description: "test", + } + + require.Error(t, s.ACLBindingRuleSet(3, &rule)) + }) + + t.Run("Unknown AuthMethod", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "unknown", + Description: "test", + } + + require.Error(t, s.ACLBindingRuleSet(3, &rule)) + }) + + t.Run("New", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test", + } + + require.NoError(t, s.ACLBindingRuleSet(3, &rule)) + + idx, rrule, err := s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rrule) + require.Equal(t, rule.ID, rrule.ID) + require.Equal(t, "test", rrule.AuthMethod) + require.Equal(t, "test", rrule.Description) + require.Equal(t, uint64(3), rrule.CreateIndex) + require.Equal(t, uint64(3), rrule.ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + // Create the initial rule + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test", + } + + require.NoError(t, s.ACLBindingRuleSet(2, &rule)) + + // Now make sure we can update it + update := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "modified", + BindType: structs.BindingRuleBindTypeService, + BindName: "web", + } + + require.NoError(t, s.ACLBindingRuleSet(3, &update)) + + idx, rrule, err := s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.NotNil(t, rrule) + require.Equal(t, rule.ID, rrule.ID) + require.Equal(t, "test", rrule.AuthMethod) + require.Equal(t, "modified", rrule.Description) + require.Equal(t, structs.BindingRuleBindTypeService, rrule.BindType) + require.Equal(t, "web", rrule.BindName) + require.Equal(t, uint64(2), rrule.CreateIndex) + require.Equal(t, uint64(3), rrule.ModifyIndex) + }) +} + +func TestStateStore_ACLBindingRules_UpsertBatchRead(t *testing.T) { + t.Parallel() + + t.Run("Normal", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + idx, rrules, err := s.ACLBindingRuleList(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.Len(t, rrules, 2) + rrules.Sort() + require.ElementsMatch(t, rules, rrules) + require.Equal(t, uint64(2), rrules[0].CreateIndex) + require.Equal(t, uint64(2), rrules[0].ModifyIndex) + require.Equal(t, uint64(2), rrules[1].CreateIndex) + require.Equal(t, uint64(2), rrules[1].ModifyIndex) + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + // Seed initial data. + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + // Update two rules at the same time. + updates := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1 modified", + BindType: structs.BindingRuleBindTypeService, + BindName: "web-1", + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2 modified", + BindType: structs.BindingRuleBindTypeService, + BindName: "web-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(3, updates)) + + idx, rrules, err := s.ACLBindingRuleList(nil, "test") + require.NoError(t, err) + require.Equal(t, uint64(3), idx) + require.Len(t, rrules, 2) + rrules.Sort() + require.ElementsMatch(t, updates, rrules) + require.Equal(t, uint64(2), rrules[0].CreateIndex) + require.Equal(t, uint64(3), rrules[0].ModifyIndex) + require.Equal(t, uint64(2), rrules[1].CreateIndex) + require.Equal(t, uint64(3), rrules[1].ModifyIndex) + }) +} + +func TestStateStore_ACLBindingRule_List(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + _, rrules, err := s.ACLBindingRuleList(nil, "") + require.NoError(t, err) + + require.Len(t, rrules, 2) + rrules.Sort() + + require.Equal(t, "3ebcc27b-f8ba-4611-b385-79a065dfb983", rrules[0].ID) + require.Equal(t, "test", rrules[0].AuthMethod) + require.Equal(t, "test-1", rrules[0].Description) + require.Equal(t, uint64(2), rrules[0].CreateIndex) + require.Equal(t, uint64(2), rrules[0].ModifyIndex) + + require.Equal(t, "9669b2d7-455c-4d70-b0ac-457fd7969a2e", rrules[1].ID) + require.Equal(t, "test", rrules[1].AuthMethod) + require.Equal(t, "test-2", rrules[1].Description) + require.Equal(t, uint64(2), rrules[1].CreateIndex) + require.Equal(t, uint64(2), rrules[1].ModifyIndex) +} + +func TestStateStore_ACLBindingRule_Delete(t *testing.T) { + t.Parallel() + + t.Run("Name", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rule := structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test", + } + + require.NoError(t, s.ACLBindingRuleSet(2, &rule)) + + _, rrule, err := s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.NotNil(t, rrule) + + require.NoError(t, s.ACLBindingRuleDeleteByID(3, rule.ID)) + require.NoError(t, err) + + _, rrule, err = s.ACLBindingRuleGetByID(nil, rule.ID) + require.NoError(t, err) + require.Nil(t, rrule) + }) + + t.Run("Multiple", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-1", + }, + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-2", + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + _, rrule, err := s.ACLBindingRuleGetByID(nil, rules[0].ID) + require.NoError(t, err) + require.NotNil(t, rrule) + _, rrule, err = s.ACLBindingRuleGetByID(nil, rules[1].ID) + require.NoError(t, err) + require.NotNil(t, rrule) + + require.NoError(t, s.ACLBindingRuleBatchDelete(3, []string{rules[0].ID, rules[1].ID})) + + _, rrule, err = s.ACLBindingRuleGetByID(nil, rules[0].ID) + require.NoError(t, err) + require.Nil(t, rrule) + _, rrule, err = s.ACLBindingRuleGetByID(nil, rules[1].ID) + require.NoError(t, err) + require.Nil(t, rrule) + }) + + t.Run("Not Found", func(t *testing.T) { + t.Parallel() + s := testACLStateStore(t) + + // deletion of non-existant rules is not an error + require.NoError(t, s.ACLBindingRuleDeleteByID(3, "ed3ce1b8-3a16-4e2f-b82e-f92e3b92410d")) + }) +} + func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { s := testStateStore(t) @@ -2651,7 +3488,7 @@ func TestStateStore_ACLTokens_Snapshot_Restore(t *testing.T) { require.NoError(t, s.ACLRoleBatchSet(2, roles)) // Read the restored ACLs back out and verify that they match. - idx, res, err := s.ACLTokenList(nil, true, true, "", "") + idx, res, err := s.ACLTokenList(nil, true, true, "", "", "") require.NoError(t, err) require.Equal(t, uint64(4), idx) require.ElementsMatch(t, tokens, res) @@ -2991,3 +3828,120 @@ func TestStateStore_ACLRoles_Snapshot_Restore(t *testing.T) { require.Equal(t, uint64(2), s.maxIndex("acl-roles")) }() } + +func TestStateStore_ACLAuthMethods_Snapshot_Restore(t *testing.T) { + s := testACLStateStore(t) + + methods := structs.ACLAuthMethods{ + &structs.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + Description: "test-1", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + &structs.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + Description: "test-2", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + } + + require.NoError(t, s.ACLAuthMethodBatchSet(2, methods)) + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + require.NoError(t, s.ACLAuthMethodDeleteByName(3, "test-1")) + + // Verify the snapshot. + require.Equal(t, uint64(2), snap.LastIndex()) + + iter, err := snap.ACLAuthMethods() + require.NoError(t, err) + + var dump structs.ACLAuthMethods + for method := iter.Next(); method != nil; method = iter.Next() { + dump = append(dump, method.(*structs.ACLAuthMethod)) + } + require.ElementsMatch(t, dump, methods) + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, method := range dump { + require.NoError(t, restore.ACLAuthMethod(method)) + } + restore.Commit() + + // Read the restored methods back out and verify that they match. + idx, res, err := s.ACLAuthMethodList(nil) + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.ElementsMatch(t, methods, res) + require.Equal(t, uint64(2), s.maxIndex("acl-auth-methods")) + }() +} + +func TestStateStore_ACLBindingRules_Snapshot_Restore(t *testing.T) { + s := testACLStateStore(t) + setupExtraAuthMethods(t, s) + + rules := structs.ACLBindingRules{ + &structs.ACLBindingRule{ + ID: "9669b2d7-455c-4d70-b0ac-457fd7969a2e", + AuthMethod: "test", + Description: "test-1", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + &structs.ACLBindingRule{ + ID: "3ebcc27b-f8ba-4611-b385-79a065dfb983", + AuthMethod: "test", + Description: "test-2", + RaftIndex: structs.RaftIndex{CreateIndex: 1, ModifyIndex: 2}, + }, + } + + require.NoError(t, s.ACLBindingRuleBatchSet(2, rules)) + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + require.NoError(t, s.ACLBindingRuleDeleteByID(3, rules[0].ID)) + + // Verify the snapshot. + require.Equal(t, uint64(2), snap.LastIndex()) + + iter, err := snap.ACLBindingRules() + require.NoError(t, err) + + var dump structs.ACLBindingRules + for rule := iter.Next(); rule != nil; rule = iter.Next() { + dump = append(dump, rule.(*structs.ACLBindingRule)) + } + require.ElementsMatch(t, dump, rules) + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + setupExtraAuthMethods(t, s) + + restore := s.Restore() + for _, rule := range dump { + require.NoError(t, restore.ACLBindingRule(rule)) + } + restore.Commit() + + // Read the restored rules back out and verify that they match. + idx, res, err := s.ACLBindingRuleList(nil, "") + require.NoError(t, err) + require.Equal(t, uint64(2), idx) + require.ElementsMatch(t, rules, res) + require.Equal(t, uint64(2), s.maxIndex("acl-binding-rules")) + }() +} diff --git a/agent/consul/state/state_store.go b/agent/consul/state/state_store.go index 4dcc74ddde..f38bc42da8 100644 --- a/agent/consul/state/state_store.go +++ b/agent/consul/state/state_store.go @@ -37,17 +37,29 @@ var ( // policy with an empty Name. ErrMissingACLPolicyName = errors.New("Missing ACL Policy Name") - // ErrMissingACLRoleID is returned when an role set is called on + // ErrMissingACLRoleID is returned when a role set is called on // a role with an empty ID. ErrMissingACLRoleID = errors.New("Missing ACL Role ID") - // ErrMissingACLRoleName is returned when an role set is called on + // ErrMissingACLRoleName is returned when a role set is called on // a role with an empty Name. ErrMissingACLRoleName = errors.New("Missing ACL Role Name") - // ErrInvalidACLRoleName is returned when an role set is called on - // a role with an invalid Name. - ErrInvalidACLRoleName = errors.New("Invalid ACL Role Name") + // ErrMissingACLBindingRuleID is returned when a binding rule set + // is called on a binding rule with an empty ID. + ErrMissingACLBindingRuleID = errors.New("Missing ACL Binding Rule ID") + + // ErrMissingACLBindingRuleAuthMethod is returned when a binding rule set + // is called on a binding rule with an empty AuthMethod. + ErrMissingACLBindingRuleAuthMethod = errors.New("Missing ACL Binding Rule Auth Method") + + // ErrMissingACLAuthMethodName is returned when an auth method set is + // called on an auth method with an empty Name. + ErrMissingACLAuthMethodName = errors.New("Missing ACL Auth Method Name") + + // ErrMissingACLAuthMethodType is returned when an auth method set is + // called on an auth method with an empty Type. + ErrMissingACLAuthMethodType = errors.New("Missing ACL Auth Method Type") // ErrMissingQueryID is returned when a Query set is called on // a Query with an empty ID. diff --git a/agent/consul/util.go b/agent/consul/util.go index cb0134240a..19ff2cc2b2 100644 --- a/agent/consul/util.go +++ b/agent/consul/util.go @@ -6,10 +6,13 @@ import ( "net" "runtime" "strconv" + "strings" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/go-version" + "github.com/hashicorp/hil" + "github.com/hashicorp/hil/ast" "github.com/hashicorp/serf/serf" ) @@ -322,3 +325,42 @@ func ServersGetACLMode(members []serf.Member, leader string, datacenter string) return } + +// InterpolateHIL processes the string as if it were HIL and interpolates only +// the provided string->string map as possible variables. +func InterpolateHIL(s string, vars map[string]string) (string, error) { + if strings.Index(s, "${") == -1 { + // Skip going to the trouble of parsing something that has no HIL. + return s, nil + } + + tree, err := hil.Parse(s) + if err != nil { + return "", err + } + + vm := make(map[string]ast.Variable) + for k, v := range vars { + vm[k] = ast.Variable{ + Type: ast.TypeString, + Value: v, + } + } + + config := &hil.EvalConfig{ + GlobalScope: &ast.BasicScope{ + VarMap: vm, + }, + } + + result, err := hil.Eval(tree, config) + if err != nil { + return "", err + } + + if result.Type != hil.TypeString { + return "", fmt.Errorf("generated unexpected hil type: %s", result.Type) + } + + return result.Value.(string), nil +} diff --git a/agent/consul/util_test.go b/agent/consul/util_test.go index b0c6e04d30..654673b628 100644 --- a/agent/consul/util_test.go +++ b/agent/consul/util_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/go-version" "github.com/hashicorp/serf/serf" + "github.com/stretchr/testify/require" ) func TestGetPrivateIP(t *testing.T) { @@ -403,3 +404,133 @@ func TestServersMeetMinimumVersion(t *testing.T) { } } } + +func TestInterpolateHIL(t *testing.T) { + for _, test := range []struct { + name string + in string + vars map[string]string + exp string + ok bool + }{ + // valid HIL + { + "empty", + "", + map[string]string{}, + "", + true, + }, + { + "no vars", + "nothing", + map[string]string{}, + "nothing", + true, + }, + { + "just var", + "${item}", + map[string]string{"item": "value"}, + "value", + true, + }, + { + "var in middle", + "before ${item}after", + map[string]string{"item": "value"}, + "before valueafter", + true, + }, + { + "two vars", + "before ${item}after ${more}", + map[string]string{"item": "value", "more": "xyz"}, + "before valueafter xyz", + true, + }, + { + "missing map val", + "${item}", + map[string]string{"item": ""}, + "", + true, + }, + // "weird" HIL, but not technically invalid + { + "just end", + "}", + map[string]string{}, + "}", + true, + }, + { + "var without start", + " item }", + map[string]string{"item": "value"}, + " item }", + true, + }, + { + "two vars missing second start", + "before ${ item }after more }", + map[string]string{"item": "value", "more": "xyz"}, + "before valueafter more }", + true, + }, + // invalid HIL + { + "just start", + "${", + map[string]string{}, + "", + false, + }, + { + "backwards", + "}${", + map[string]string{}, + "", + false, + }, + { + "no varname", + "${}", + map[string]string{}, + "", + false, + }, + { + "missing map key", + "${item}", + map[string]string{}, + "", + false, + }, + { + "var without end", + "${ item ", + map[string]string{"item": "value"}, + "", + false, + }, + { + "two vars missing first end", + "before ${ item after ${ more }", + map[string]string{"item": "value", "more": "xyz"}, + "", + false, + }, + } { + t.Run(test.name, func(t *testing.T) { + out, err := InterpolateHIL(test.in, test.vars) + if test.ok { + require.NoError(t, err) + require.Equal(t, test.exp, out) + } else { + require.NotNil(t, err) + require.Equal(t, out, "") + } + }) + } +} diff --git a/agent/http_oss.go b/agent/http_oss.go index 6a0d5917bc..a4584a5a49 100644 --- a/agent/http_oss.go +++ b/agent/http_oss.go @@ -10,6 +10,8 @@ func init() { registerEndpoint("/v1/acl/info/", []string{"GET"}, (*HTTPServer).ACLGet) registerEndpoint("/v1/acl/clone/", []string{"PUT"}, (*HTTPServer).ACLClone) registerEndpoint("/v1/acl/list", []string{"GET"}, (*HTTPServer).ACLList) + registerEndpoint("/v1/acl/login", []string{"POST"}, (*HTTPServer).ACLLogin) + registerEndpoint("/v1/acl/logout", []string{"POST"}, (*HTTPServer).ACLLogout) registerEndpoint("/v1/acl/replication", []string{"GET"}, (*HTTPServer).ACLReplicationStatus) registerEndpoint("/v1/acl/policies", []string{"GET"}, (*HTTPServer).ACLPolicyList) registerEndpoint("/v1/acl/policy", []string{"PUT"}, (*HTTPServer).ACLPolicyCreate) @@ -18,6 +20,12 @@ func init() { registerEndpoint("/v1/acl/role", []string{"PUT"}, (*HTTPServer).ACLRoleCreate) registerEndpoint("/v1/acl/role/name/", []string{"GET"}, (*HTTPServer).ACLRoleReadByName) registerEndpoint("/v1/acl/role/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLRoleCRUD) + registerEndpoint("/v1/acl/binding-rules", []string{"GET"}, (*HTTPServer).ACLBindingRuleList) + registerEndpoint("/v1/acl/binding-rule", []string{"PUT"}, (*HTTPServer).ACLBindingRuleCreate) + registerEndpoint("/v1/acl/binding-rule/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLBindingRuleCRUD) + registerEndpoint("/v1/acl/auth-methods", []string{"GET"}, (*HTTPServer).ACLAuthMethodList) + registerEndpoint("/v1/acl/auth-method", []string{"PUT"}, (*HTTPServer).ACLAuthMethodCreate) + registerEndpoint("/v1/acl/auth-method/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLAuthMethodCRUD) registerEndpoint("/v1/acl/rules/translate", []string{"POST"}, (*HTTPServer).ACLRulesTranslate) registerEndpoint("/v1/acl/rules/translate/", []string{"GET"}, (*HTTPServer).ACLRulesTranslateLegacyToken) registerEndpoint("/v1/acl/tokens", []string{"GET"}, (*HTTPServer).ACLTokenList) diff --git a/agent/structs/acl.go b/agent/structs/acl.go index 0b46b21e31..0fc63a12cd 100644 --- a/agent/structs/acl.go +++ b/agent/structs/acl.go @@ -186,9 +186,12 @@ func (s *ACLServiceIdentity) SyntheticPolicy() *ACLPolicy { rules := fmt.Sprintf(aclPolicyTemplateServiceIdentity, s.ServiceName, s.ServiceName) hasher := fnv.New128a() + hashID := fmt.Sprintf("%x", hasher.Sum([]byte(rules))) + policy := &ACLPolicy{} - policy.ID = fmt.Sprintf("%x", hasher.Sum([]byte(rules))) - policy.Name = fmt.Sprintf("synthetic-policy-%s", policy.ID) + policy.ID = hashID + policy.Name = fmt.Sprintf("synthetic-policy-%s", hashID) + policy.Description = "synthetic policy" policy.Rules = rules policy.Syntax = acl.SyntaxCurrent policy.Datacenters = s.Datacenters @@ -234,6 +237,9 @@ type ACLToken struct { // to the ACL datacenter and replicated to others. Local bool + // AuthMethod is the name of the auth method used to create this token. + AuthMethod string `json:",omitempty"` + // ExpirationTime represents the point after which a token should be // considered revoked and is eligible for destruction. The zero value // represents NO expiration. @@ -309,7 +315,11 @@ func (t *ACLToken) PolicyIDs() []string { } func (t *ACLToken) RoleIDs() []string { - var ids []string + if len(t.Roles) == 0 { + return nil + } + + ids := make([]string, 0, len(t.Roles)) for _, link := range t.Roles { ids = append(ids, link.ID) } @@ -345,7 +355,8 @@ func (t *ACLToken) UsesNonLegacyFields() bool { len(t.Roles) > 0 || t.Type == "" || t.HasExpirationTime() || - t.ExpirationTTL != 0 + t.ExpirationTTL != 0 || + t.AuthMethod != "" } func (t *ACLToken) EmbeddedPolicy() *ACLPolicy { @@ -428,7 +439,7 @@ func (t *ACLToken) SetHash(force bool) []byte { func (t *ACLToken) EstimateSize() int { // 41 = 16 (RaftIndex) + 8 (Hash) + 8 (ExpirationTime) + 8 (CreateTime) + 1 (Local) - size := 41 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules) + size := 41 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules) + len(t.AuthMethod) for _, link := range t.Policies { size += len(link.ID) + len(link.Name) } @@ -451,6 +462,7 @@ type ACLTokenListStub struct { Roles []ACLTokenRoleLink `json:",omitempty"` ServiceIdentities []*ACLServiceIdentity `json:",omitempty"` Local bool + AuthMethod string `json:",omitempty"` ExpirationTime *time.Time `json:",omitempty"` CreateTime time.Time `json:",omitempty"` Hash []byte @@ -469,6 +481,7 @@ func (token *ACLToken) Stub() *ACLTokenListStub { Roles: token.Roles, ServiceIdentities: token.ServiceIdentities, Local: token.Local, + AuthMethod: token.AuthMethod, ExpirationTime: token.ExpirationTime, CreateTime: token.CreateTime, Hash: token.Hash, @@ -722,8 +735,6 @@ type ACLRole struct { ID string // Name is the unique name to reference the role by. - // - // Validated with structs.isValidRoleName() Name string // Description is a human readable description (Optional) @@ -819,6 +830,136 @@ func (r *ACLRole) EstimateSize() int { return size } +const ( + // BindingRuleBindTypeService is the binding rule bind type that + // assigns a Service Identity to the token that is created using the value + // of the computed BindName as the ServiceName like: + // + // &ACLToken{ + // ...other fields... + // ServiceIdentities: []*ACLServiceIdentity{ + // &ACLServiceIdentity{ + // ServiceName: "", + // }, + // }, + // } + BindingRuleBindTypeService = "service" + + // BindingRuleBindTypeRole is the binding rule bind type that only allows + // the binding rule to function if a role with the given name (BindName) + // exists at login-time. If it does the token that is created is directly + // linked to that role like: + // + // &ACLToken{ + // ...other fields... + // Roles: []ACLTokenRoleLink{ + // { Name: "" } + // } + // } + // + // If it does not exist at login-time the rule is ignored. + BindingRuleBindTypeRole = "role" +) + +type ACLBindingRule struct { + // ID is the internal UUID associated with the binding rule + ID string + + // Description is a human readable description (Optional) + Description string + + // AuthMethod is the name of the auth method for which this rule applies. + AuthMethod string + + // Selector is an expression that matches against verified identity + // attributes returned from the auth method during login. + Selector string + + // BindType adjusts how this binding rule is applied at login time. The + // valid values are: + // + // - BindingRuleBindTypeService = "service" + // - BindingRuleBindTypeRole = "role" + BindType string + + // BindName is the target of the binding. Can be lightly templated using + // HIL ${foo} syntax from available field names. How it is used depends + // upon the BindType. + BindName string + + // Embedded Raft Metadata + RaftIndex `hash:"ignore"` +} + +func (r *ACLBindingRule) Clone() *ACLBindingRule { + r2 := *r + return &r2 +} + +type ACLBindingRules []*ACLBindingRule + +func (rules ACLBindingRules) Sort() { + sort.Slice(rules, func(i, j int) bool { + return rules[i].ID < rules[j].ID + }) +} + +type ACLAuthMethodListStub struct { + Name string + Description string + Type string + CreateIndex uint64 + ModifyIndex uint64 +} + +func (p *ACLAuthMethod) Stub() *ACLAuthMethodListStub { + return &ACLAuthMethodListStub{ + Name: p.Name, + Description: p.Description, + Type: p.Type, + CreateIndex: p.CreateIndex, + ModifyIndex: p.ModifyIndex, + } +} + +type ACLAuthMethods []*ACLAuthMethod +type ACLAuthMethodListStubs []*ACLAuthMethodListStub + +func (methods ACLAuthMethods) Sort() { + sort.Slice(methods, func(i, j int) bool { + return methods[i].Name < methods[j].Name + }) +} + +func (methods ACLAuthMethodListStubs) Sort() { + sort.Slice(methods, func(i, j int) bool { + return methods[i].Name < methods[j].Name + }) +} + +type ACLAuthMethod struct { + // Name is a unique identifier for this specific auth method. + // + // Immutable once set and only settable during create. + Name string + + // Type is the type of the auth method this is. + // + // Immutable once set and only settable during create. + Type string + + // Description is just an optional bunch of explanatory text. + Description string + + // Configuration is arbitrary configuration for the auth method. This + // should only contain primitive values and containers (such as lists and + // maps). + Config map[string]interface{} + + // Embedded Raft Metadata + RaftIndex `hash:"ignore"` +} + type ACLReplicationType string const ( @@ -898,6 +1039,7 @@ type ACLTokenListRequest struct { IncludeGlobal bool // Whether global tokens should be included Policy string // Policy filter Role string // Role filter + AuthMethod string // Auth Method filter Datacenter string // The datacenter to perform the request within QueryOptions } @@ -1068,7 +1210,7 @@ func cloneStringSlice(s []string) []string { // ACLRoleSetRequest is used at the RPC layer for creation and update requests type ACLRoleSetRequest struct { - Role ACLRole // The policy to upsert + Role ACLRole // The role to upsert Datacenter string // The datacenter to perform the request within WriteRequest } @@ -1154,3 +1296,161 @@ type ACLRoleBatchSetRequest struct { type ACLRoleBatchDeleteRequest struct { RoleIDs []string } + +// ACLBindingRuleSetRequest is used at the RPC layer for creation and update requests +type ACLBindingRuleSetRequest struct { + BindingRule ACLBindingRule // The rule to upsert + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLBindingRuleSetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLBindingRuleDeleteRequest is used at the RPC layer deletion requests +type ACLBindingRuleDeleteRequest struct { + BindingRuleID string // id of the rule to delete + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLBindingRuleDeleteRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLBindingRuleGetRequest is used at the RPC layer to perform rule read operations +type ACLBindingRuleGetRequest struct { + BindingRuleID string // id used for the rule lookup + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLBindingRuleGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLBindingRuleListRequest is used at the RPC layer to request a listing of rules +type ACLBindingRuleListRequest struct { + AuthMethod string // optional filter + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLBindingRuleListRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLBindingRuleListResponse struct { + BindingRules ACLBindingRules + QueryMeta +} + +// ACLBindingRuleResponse returns a single binding + metadata +type ACLBindingRuleResponse struct { + BindingRule *ACLBindingRule + QueryMeta +} + +// ACLBindingRuleBatchSetRequest is used at the Raft layer for batching +// multiple rule creations and updates +type ACLBindingRuleBatchSetRequest struct { + BindingRules ACLBindingRules +} + +// ACLBindingRuleBatchDeleteRequest is used at the Raft layer for batching +// multiple rule deletions +type ACLBindingRuleBatchDeleteRequest struct { + BindingRuleIDs []string +} + +// ACLAuthMethodSetRequest is used at the RPC layer for creation and update requests +type ACLAuthMethodSetRequest struct { + AuthMethod ACLAuthMethod // The auth method to upsert + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLAuthMethodSetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLAuthMethodDeleteRequest is used at the RPC layer deletion requests +type ACLAuthMethodDeleteRequest struct { + AuthMethodName string // name of the auth method to delete + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLAuthMethodDeleteRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLAuthMethodGetRequest is used at the RPC layer to perform rule read operations +type ACLAuthMethodGetRequest struct { + AuthMethodName string // name used for the auth method lookup + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLAuthMethodGetRequest) RequestDatacenter() string { + return r.Datacenter +} + +// ACLAuthMethodListRequest is used at the RPC layer to request a listing of auth methods +type ACLAuthMethodListRequest struct { + Datacenter string // The datacenter to perform the request within + QueryOptions +} + +func (r *ACLAuthMethodListRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLAuthMethodListResponse struct { + AuthMethods ACLAuthMethodListStubs + QueryMeta +} + +// ACLAuthMethodResponse returns a single auth method + metadata +type ACLAuthMethodResponse struct { + AuthMethod *ACLAuthMethod + QueryMeta +} + +// ACLAuthMethodBatchSetRequest is used at the Raft layer for batching +// multiple auth method creations and updates +type ACLAuthMethodBatchSetRequest struct { + AuthMethods ACLAuthMethods +} + +// ACLAuthMethodBatchDeleteRequest is used at the Raft layer for batching +// multiple auth method deletions +type ACLAuthMethodBatchDeleteRequest struct { + AuthMethodNames []string +} + +type ACLLoginParams struct { + AuthMethod string + BearerToken string + Meta map[string]string `json:",omitempty"` +} + +type ACLLoginRequest struct { + Auth *ACLLoginParams + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLLoginRequest) RequestDatacenter() string { + return r.Datacenter +} + +type ACLLogoutRequest struct { + Datacenter string // The datacenter to perform the request within + WriteRequest +} + +func (r *ACLLogoutRequest) RequestDatacenter() string { + return r.Datacenter +} diff --git a/agent/structs/acl_cache.go b/agent/structs/acl_cache.go index 8a4f494194..1494727070 100644 --- a/agent/structs/acl_cache.go +++ b/agent/structs/acl_cache.go @@ -60,7 +60,6 @@ func (e *AuthorizerCacheEntry) Age() time.Duration { return time.Since(e.CacheTime) } -// RoleCacheEntry is the payload for by by-id and by-name caches. type RoleCacheEntry struct { Role *ACLRole CacheTime time.Time @@ -173,7 +172,7 @@ func (c *ACLCaches) GetAuthorizer(id string) *AuthorizerCacheEntry { return nil } -// GetRoleByID fetches a role from the cache by id and returns it +// GetRole fetches a role from the cache by id and returns it func (c *ACLCaches) GetRole(roleID string) *RoleCacheEntry { if c == nil || c.roles == nil { return nil @@ -228,9 +227,11 @@ func (c *ACLCaches) PutAuthorizerWithTTL(id string, authorizer acl.Authorizer, t } func (c *ACLCaches) PutRole(roleID string, role *ACLRole) { - if c != nil && c.roles != nil { - c.roles.Add(roleID, &RoleCacheEntry{Role: role, CacheTime: time.Now()}) + if c == nil || c.roles == nil { + return } + + c.roles.Add(roleID, &RoleCacheEntry{Role: role, CacheTime: time.Now()}) } func (c *ACLCaches) RemoveIdentity(id string) { @@ -246,7 +247,7 @@ func (c *ACLCaches) RemovePolicy(policyID string) { } func (c *ACLCaches) RemoveRole(roleID string) { - if c != nil && c.roles != nil && roleID != "" { + if c != nil && c.roles != nil { c.roles.Remove(roleID) } } diff --git a/agent/structs/acl_cache_test.go b/agent/structs/acl_cache_test.go index dbbf717c88..337d1860f3 100644 --- a/agent/structs/acl_cache_test.go +++ b/agent/structs/acl_cache_test.go @@ -113,6 +113,7 @@ func TestStructs_ACLCaches(t *testing.T) { require.NotNil(t, cache) cache.PutRole("foo", &ACLRole{}) + entry := cache.GetRole("foo") require.NotNil(t, entry) require.NotNil(t, entry.Role) diff --git a/agent/structs/acl_test.go b/agent/structs/acl_test.go index 0d69c9886d..a7860a49d6 100644 --- a/agent/structs/acl_test.go +++ b/agent/structs/acl_test.go @@ -188,12 +188,13 @@ node_prefix "" { expect := &ACLPolicy{ Syntax: acl.SyntaxCurrent, Datacenters: test.datacenters, + Description: "synthetic policy", Rules: test.expectRules, } got := svcid.SyntheticPolicy() require.NotEmpty(t, got.ID) - require.Equal(t, got.Name, "synthetic-policy-"+got.ID) + require.True(t, strings.HasPrefix(got.Name, "synthetic-policy-")) // strip irrelevant fields before equality got.ID = "" got.Name = "" diff --git a/agent/structs/structs.go b/agent/structs/structs.go index 2ad0a4e97c..56a19fc4e7 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -33,31 +33,35 @@ type RaftIndex struct { // These are serialized between Consul servers and stored in Consul snapshots, // so entries must only ever be added. const ( - RegisterRequestType MessageType = 0 - DeregisterRequestType = 1 - KVSRequestType = 2 - SessionRequestType = 3 - ACLRequestType = 4 // DEPRECATED (ACL-Legacy-Compat) - TombstoneRequestType = 5 - CoordinateBatchUpdateType = 6 - PreparedQueryRequestType = 7 - TxnRequestType = 8 - AutopilotRequestType = 9 - AreaRequestType = 10 - ACLBootstrapRequestType = 11 - IntentionRequestType = 12 - ConnectCARequestType = 13 - ConnectCAProviderStateType = 14 - ConnectCAConfigType = 15 // FSM snapshots only. - IndexRequestType = 16 // FSM snapshots only. - ACLTokenSetRequestType = 17 - ACLTokenDeleteRequestType = 18 - ACLPolicySetRequestType = 19 - ACLPolicyDeleteRequestType = 20 - ConnectCALeafRequestType = 21 - ConfigEntryRequestType = 22 - ACLRoleSetRequestType = 23 - ACLRoleDeleteRequestType = 24 + RegisterRequestType MessageType = 0 + DeregisterRequestType = 1 + KVSRequestType = 2 + SessionRequestType = 3 + ACLRequestType = 4 // DEPRECATED (ACL-Legacy-Compat) + TombstoneRequestType = 5 + CoordinateBatchUpdateType = 6 + PreparedQueryRequestType = 7 + TxnRequestType = 8 + AutopilotRequestType = 9 + AreaRequestType = 10 + ACLBootstrapRequestType = 11 + IntentionRequestType = 12 + ConnectCARequestType = 13 + ConnectCAProviderStateType = 14 + ConnectCAConfigType = 15 // FSM snapshots only. + IndexRequestType = 16 // FSM snapshots only. + ACLTokenSetRequestType = 17 + ACLTokenDeleteRequestType = 18 + ACLPolicySetRequestType = 19 + ACLPolicyDeleteRequestType = 20 + ConnectCALeafRequestType = 21 + ConfigEntryRequestType = 22 + ACLRoleSetRequestType = 23 + ACLRoleDeleteRequestType = 24 + ACLBindingRuleSetRequestType = 25 + ACLBindingRuleDeleteRequestType = 26 + ACLAuthMethodSetRequestType = 27 + ACLAuthMethodDeleteRequestType = 28 ) const ( diff --git a/api/acl.go b/api/acl.go index 2713d0ddc9..3327f667c3 100644 --- a/api/acl.go +++ b/api/acl.go @@ -6,6 +6,8 @@ import ( "io/ioutil" "net/url" "time" + + "github.com/mitchellh/mapstructure" ) const ( @@ -132,6 +134,96 @@ type ACLRole struct { ModifyIndex uint64 } +// BindingRuleBindType is the type of binding rule mechanism used. +type BindingRuleBindType string + +const ( + // BindingRuleBindTypeService binds to a service identity with the given name. + BindingRuleBindTypeService BindingRuleBindType = "service" + + // BindingRuleBindTypeRole binds to pre-existing roles with the given name. + BindingRuleBindTypeRole BindingRuleBindType = "role" +) + +type ACLBindingRule struct { + ID string + Description string + AuthMethod string + Selector string + BindType BindingRuleBindType + BindName string + + CreateIndex uint64 + ModifyIndex uint64 +} + +type ACLAuthMethod struct { + Name string + Type string + Description string + + // Configuration is arbitrary configuration for the auth method. This + // should only contain primitive values and containers (such as lists and + // maps). + Config map[string]interface{} + + CreateIndex uint64 + ModifyIndex uint64 +} + +type ACLAuthMethodListEntry struct { + Name string + Type string + Description string + CreateIndex uint64 + ModifyIndex uint64 +} + +// ParseKubernetesAuthMethodConfig takes a raw config map and returns a parsed +// KubernetesAuthMethodConfig. +func ParseKubernetesAuthMethodConfig(raw map[string]interface{}) (*KubernetesAuthMethodConfig, error) { + var config KubernetesAuthMethodConfig + decodeConf := &mapstructure.DecoderConfig{ + Result: &config, + WeaklyTypedInput: true, + } + + decoder, err := mapstructure.NewDecoder(decodeConf) + if err != nil { + return nil, err + } + + if err := decoder.Decode(raw); err != nil { + return nil, fmt.Errorf("error decoding config: %s", err) + } + + return &config, nil +} + +// KubernetesAuthMethodConfig is the config for the built-in Consul auth method +// for Kubernetes. +type KubernetesAuthMethodConfig struct { + Host string `json:",omitempty"` + CACert string `json:",omitempty"` + ServiceAccountJWT string `json:",omitempty"` +} + +// RenderToConfig converts this into a map[string]interface{} suitable for use +// in the ACLAuthMethod.Config field. +func (c *KubernetesAuthMethodConfig) RenderToConfig() map[string]interface{} { + return map[string]interface{}{ + "Host": c.Host, + "CACert": c.CACert, + "ServiceAccountJWT": c.ServiceAccountJWT, + } +} + +type ACLLoginParams struct { + AuthMethod string + BearerToken string + Meta map[string]string `json:",omitempty"` +} + // ACL can be used to query the ACL endpoints type ACL struct { c *Client @@ -498,7 +590,7 @@ func (a *ACL) PolicyCreate(policy *ACLPolicy, q *WriteOptions) (*ACLPolicy, *Wri // existing policy ID func (a *ACL) PolicyUpdate(policy *ACLPolicy, q *WriteOptions) (*ACLPolicy, *WriteMeta, error) { if policy.ID == "" { - return nil, nil, fmt.Errorf("Must specify an ID in Policy Creation") + return nil, nil, fmt.Errorf("Must specify an ID in Policy Update") } r := a.c.newRequest("PUT", "/v1/acl/policy/"+policy.ID) @@ -654,7 +746,7 @@ func (a *ACL) RoleCreate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta, // existing role ID func (a *ACL) RoleUpdate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta, error) { if role.ID == "" { - return nil, nil, fmt.Errorf("Must specify an ID in Role Creation") + return nil, nil, fmt.Errorf("Must specify an ID in Role Update") } r := a.c.newRequest("PUT", "/v1/acl/role/"+role.ID) @@ -763,3 +855,271 @@ func (a *ACL) RoleList(q *QueryOptions) ([]*ACLRole, *QueryMeta, error) { } return entries, qm, nil } + +// AuthMethodCreate will create a new auth method. +func (a *ACL) AuthMethodCreate(method *ACLAuthMethod, q *WriteOptions) (*ACLAuthMethod, *WriteMeta, error) { + if method.Name == "" { + return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Creation") + } + + r := a.c.newRequest("PUT", "/v1/acl/auth-method") + r.setWriteOptions(q) + r.obj = method + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLAuthMethod + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// AuthMethodUpdate updates an auth method. +func (a *ACL) AuthMethodUpdate(method *ACLAuthMethod, q *WriteOptions) (*ACLAuthMethod, *WriteMeta, error) { + if method.Name == "" { + return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Update") + } + + r := a.c.newRequest("PUT", "/v1/acl/auth-method/"+url.QueryEscape(method.Name)) + r.setWriteOptions(q) + r.obj = method + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLAuthMethod + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// AuthMethodDelete deletes an auth method given its Name. +func (a *ACL) AuthMethodDelete(methodName string, q *WriteOptions) (*WriteMeta, error) { + if methodName == "" { + return nil, fmt.Errorf("Must specify a Name in Auth Method Delete") + } + + r := a.c.newRequest("DELETE", "/v1/acl/auth-method/"+url.QueryEscape(methodName)) + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} + +// AuthMethodRead retrieves the auth method. Returns nil if not found. +func (a *ACL) AuthMethodRead(methodName string, q *QueryOptions) (*ACLAuthMethod, *QueryMeta, error) { + if methodName == "" { + return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Read") + } + + r := a.c.newRequest("GET", "/v1/acl/auth-method/"+url.QueryEscape(methodName)) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLAuthMethod + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// AuthMethodList retrieves a listing of all auth methods. The listing does not +// include some metadata for the auth method as those should be retrieved by +// subsequent calls to AuthMethodRead. +func (a *ACL) AuthMethodList(q *QueryOptions) ([]*ACLAuthMethodListEntry, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/auth-methods") + r.setQueryOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + var entries []*ACLAuthMethodListEntry + if err := decodeBody(resp, &entries); err != nil { + return nil, nil, err + } + return entries, qm, nil +} + +// BindingRuleCreate will create a new binding rule. It is not allowed for the +// binding rule parameter's ID field to be set as this will be generated by +// Consul while processing the request. +func (a *ACL) BindingRuleCreate(rule *ACLBindingRule, q *WriteOptions) (*ACLBindingRule, *WriteMeta, error) { + if rule.ID != "" { + return nil, nil, fmt.Errorf("Cannot specify an ID in Binding Rule Creation") + } + + r := a.c.newRequest("PUT", "/v1/acl/binding-rule") + r.setWriteOptions(q) + r.obj = rule + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLBindingRule + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// BindingRuleUpdate updates a binding rule. The ID field of the role binding +// rule parameter must be set to an existing binding rule ID. +func (a *ACL) BindingRuleUpdate(rule *ACLBindingRule, q *WriteOptions) (*ACLBindingRule, *WriteMeta, error) { + if rule.ID == "" { + return nil, nil, fmt.Errorf("Must specify an ID in Binding Rule Update") + } + + r := a.c.newRequest("PUT", "/v1/acl/binding-rule/"+rule.ID) + r.setWriteOptions(q) + r.obj = rule + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLBindingRule + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, wm, nil +} + +// BindingRuleDelete deletes a binding rule given its ID. +func (a *ACL) BindingRuleDelete(bindingRuleID string, q *WriteOptions) (*WriteMeta, error) { + r := a.c.newRequest("DELETE", "/v1/acl/binding-rule/"+bindingRuleID) + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} + +// BindingRuleRead retrieves the binding rule details. Returns nil if not found. +func (a *ACL) BindingRuleRead(bindingRuleID string, q *QueryOptions) (*ACLBindingRule, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/binding-rule/"+bindingRuleID) + r.setQueryOptions(q) + found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + if !found { + return nil, qm, nil + } + + var out ACLBindingRule + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + + return &out, qm, nil +} + +// BindingRuleList retrieves a listing of all binding rules. +func (a *ACL) BindingRuleList(methodName string, q *QueryOptions) ([]*ACLBindingRule, *QueryMeta, error) { + r := a.c.newRequest("GET", "/v1/acl/binding-rules") + if methodName != "" { + r.params.Set("authmethod", methodName) + } + r.setQueryOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + qm := &QueryMeta{} + parseQueryMeta(resp, qm) + qm.RequestTime = rtt + + var entries []*ACLBindingRule + if err := decodeBody(resp, &entries); err != nil { + return nil, nil, err + } + return entries, qm, nil +} + +// Login is used to exchange auth method credentials for a newly-minted Consul Token. +func (a *ACL) Login(auth *ACLLoginParams, q *WriteOptions) (*ACLToken, *WriteMeta, error) { + r := a.c.newRequest("POST", "/v1/acl/login") + r.setWriteOptions(q) + r.obj = auth + + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + var out ACLToken + if err := decodeBody(resp, &out); err != nil { + return nil, nil, err + } + return &out, wm, nil +} + +// Logout is used to destroy a Consul Token created via Login(). +func (a *ACL) Logout(q *WriteOptions) (*WriteMeta, error) { + r := a.c.newRequest("POST", "/v1/acl/logout") + r.setWriteOptions(q) + rtt, resp, err := requireOK(a.c.doRequest(r)) + if err != nil { + return nil, err + } + resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + return wm, nil +} diff --git a/api/api.go b/api/api.go index e8370d0441..4b17ff6cda 100644 --- a/api/api.go +++ b/api/api.go @@ -30,6 +30,10 @@ const ( // the HTTP token. HTTPTokenEnvName = "CONSUL_HTTP_TOKEN" + // HTTPTokenFileEnvName defines an environment variable name which sets + // the HTTP token file. + HTTPTokenFileEnvName = "CONSUL_HTTP_TOKEN_FILE" + // HTTPAuthEnvName defines an environment variable name which sets // the HTTP authentication header. HTTPAuthEnvName = "CONSUL_HTTP_AUTH" @@ -280,6 +284,10 @@ type Config struct { // which overrides the agent's default token. Token string + // TokenFile is a file containing the current token to use for this client. + // If provided it is read once at startup and never again. + TokenFile string + TLSConfig TLSConfig } @@ -343,6 +351,10 @@ func defaultConfig(transportFn func() *http.Transport) *Config { config.Address = addr } + if tokenFile := os.Getenv(HTTPTokenFileEnvName); tokenFile != "" { + config.TokenFile = tokenFile + } + if token := os.Getenv(HTTPTokenEnvName); token != "" { config.Token = token } @@ -449,6 +461,7 @@ func (c *Config) GenerateEnv() []string { env = append(env, fmt.Sprintf("%s=%s", HTTPAddrEnvName, c.Address), fmt.Sprintf("%s=%s", HTTPTokenEnvName, c.Token), + fmt.Sprintf("%s=%s", HTTPTokenFileEnvName, c.TokenFile), fmt.Sprintf("%s=%t", HTTPSSLEnvName, c.Scheme == "https"), fmt.Sprintf("%s=%s", HTTPCAFile, c.TLSConfig.CAFile), fmt.Sprintf("%s=%s", HTTPCAPath, c.TLSConfig.CAPath), @@ -541,6 +554,19 @@ func NewClient(config *Config) (*Client, error) { config.Address = parts[1] } + // If the TokenFile is set, always use that, even if a Token is configured. + // This is because when TokenFile is set it is read into the Token field. + // We want any derived clients to have to re-read the token file. + if config.TokenFile != "" { + data, err := ioutil.ReadFile(config.TokenFile) + if err != nil { + return nil, fmt.Errorf("Error loading token file: %s", err) + } + + if token := strings.TrimSpace(string(data)); token != "" { + config.Token = token + } + } if config.Token == "" { config.Token = defConfig.Token } @@ -820,6 +846,8 @@ func (c *Client) write(endpoint string, in, out interface{}, q *WriteOptions) (* } // parseQueryMeta is used to help parse query meta-data +// +// TODO(rb): bug? the error from this function is never handled func parseQueryMeta(resp *http.Response, q *QueryMeta) error { header := resp.Header diff --git a/api/api_test.go b/api/api_test.go index eca799e022..7934ed87b1 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -875,9 +875,10 @@ func TestAPI_GenerateEnv(t *testing.T) { t.Parallel() c := &Config{ - Address: "127.0.0.1:8500", - Token: "test", - Scheme: "http", + Address: "127.0.0.1:8500", + Token: "test", + TokenFile: "test.file", + Scheme: "http", TLSConfig: TLSConfig{ CAFile: "", CAPath: "", @@ -891,6 +892,7 @@ func TestAPI_GenerateEnv(t *testing.T) { expected := []string{ "CONSUL_HTTP_ADDR=127.0.0.1:8500", "CONSUL_HTTP_TOKEN=test", + "CONSUL_HTTP_TOKEN_FILE=test.file", "CONSUL_HTTP_SSL=false", "CONSUL_CACERT=", "CONSUL_CAPATH=", @@ -908,9 +910,10 @@ func TestAPI_GenerateEnvHTTPS(t *testing.T) { t.Parallel() c := &Config{ - Address: "127.0.0.1:8500", - Token: "test", - Scheme: "https", + Address: "127.0.0.1:8500", + Token: "test", + TokenFile: "test.file", + Scheme: "https", TLSConfig: TLSConfig{ CAFile: "/var/consul/ca.crt", CAPath: "/var/consul/ca.dir", @@ -928,6 +931,7 @@ func TestAPI_GenerateEnvHTTPS(t *testing.T) { expected := []string{ "CONSUL_HTTP_ADDR=127.0.0.1:8500", "CONSUL_HTTP_TOKEN=test", + "CONSUL_HTTP_TOKEN_FILE=test.file", "CONSUL_HTTP_SSL=true", "CONSUL_CACERT=/var/consul/ca.crt", "CONSUL_CAPATH=/var/consul/ca.dir", diff --git a/command/acl/acl_helpers.go b/command/acl/acl_helpers.go index 928f8beb55..57c51ce14b 100644 --- a/command/acl/acl_helpers.go +++ b/command/acl/acl_helpers.go @@ -1,6 +1,7 @@ package acl import ( + "encoding/json" "fmt" "strings" @@ -23,20 +24,26 @@ func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) { ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) } - ui.Info(fmt.Sprintf("Policies:")) - for _, policy := range token.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + if len(token.Policies) > 0 { + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range token.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } } - ui.Info(fmt.Sprintf("Roles:")) - for _, role := range token.Roles { - ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + if len(token.Roles) > 0 { + ui.Info(fmt.Sprintf("Roles:")) + for _, role := range token.Roles { + ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + } } - ui.Info(fmt.Sprintf("Service Identities:")) - for _, svcid := range token.ServiceIdentities { - if len(svcid.Datacenters) > 0 { - ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) - } else { - ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + if len(token.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range token.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } } } if token.Rules != "" { @@ -59,20 +66,26 @@ func PrintTokenListEntry(token *api.ACLTokenListEntry, ui cli.Ui, showMeta bool) ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex)) ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex)) } - ui.Info(fmt.Sprintf("Policies:")) - for _, policy := range token.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + if len(token.Policies) > 0 { + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range token.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } } - ui.Info(fmt.Sprintf("Roles:")) - for _, role := range token.Roles { - ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + if len(token.Roles) > 0 { + ui.Info(fmt.Sprintf("Roles:")) + for _, role := range token.Roles { + ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name)) + } } - ui.Info(fmt.Sprintf("Service Identities:")) - for _, svcid := range token.ServiceIdentities { - if len(svcid.Datacenters) > 0 { - ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) - } else { - ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + if len(token.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range token.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } } } } @@ -112,16 +125,20 @@ func PrintRole(role *api.ACLRole, ui cli.Ui, showMeta bool) { ui.Info(fmt.Sprintf("Create Index: %d", role.CreateIndex)) ui.Info(fmt.Sprintf("Modify Index: %d", role.ModifyIndex)) } - ui.Info(fmt.Sprintf("Policies:")) - for _, policy := range role.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + if len(role.Policies) > 0 { + ui.Info(fmt.Sprintf("Policies:")) + for _, policy := range role.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) + } } - ui.Info(fmt.Sprintf("Service Identities:")) - for _, svcid := range role.ServiceIdentities { - if len(svcid.Datacenters) > 0 { - ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) - } else { - ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + if len(role.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf("Service Identities:")) + for _, svcid := range role.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } } } } @@ -135,18 +152,74 @@ func PrintRoleListEntry(role *api.ACLRole, ui cli.Ui, showMeta bool) { ui.Info(fmt.Sprintf(" Create Index: %d", role.CreateIndex)) ui.Info(fmt.Sprintf(" Modify Index: %d", role.ModifyIndex)) } - ui.Info(fmt.Sprintf(" Policies:")) - for _, policy := range role.Policies { - ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) - } - ui.Info(fmt.Sprintf(" Service Identities:")) - for _, svcid := range role.ServiceIdentities { - if len(svcid.Datacenters) > 0 { - ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) - } else { - ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + if len(role.Policies) > 0 { + ui.Info(fmt.Sprintf(" Policies:")) + for _, policy := range role.Policies { + ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name)) } } + if len(role.ServiceIdentities) > 0 { + ui.Info(fmt.Sprintf(" Service Identities:")) + for _, svcid := range role.ServiceIdentities { + if len(svcid.Datacenters) > 0 { + ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", "))) + } else { + ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName)) + } + } + } +} + +func PrintAuthMethod(method *api.ACLAuthMethod, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("Name: %s", method.Name)) + ui.Info(fmt.Sprintf("Type: %s", method.Type)) + ui.Info(fmt.Sprintf("Description: %s", method.Description)) + if showMeta { + ui.Info(fmt.Sprintf("Create Index: %d", method.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", method.ModifyIndex)) + } + ui.Info(fmt.Sprintf("Config:")) + output, err := json.MarshalIndent(method.Config, "", " ") + if err != nil { + ui.Error(fmt.Sprintf("Error formatting auth method configuration: %s", err)) + } + ui.Output(string(output)) +} + +func PrintAuthMethodListEntry(method *api.ACLAuthMethodListEntry, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("%s:", method.Name)) + ui.Info(fmt.Sprintf(" Type: %s", method.Type)) + ui.Info(fmt.Sprintf(" Description: %s", method.Description)) + if showMeta { + ui.Info(fmt.Sprintf(" Create Index: %d", method.CreateIndex)) + ui.Info(fmt.Sprintf(" Modify Index: %d", method.ModifyIndex)) + } +} + +func PrintBindingRule(rule *api.ACLBindingRule, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("ID: %s", rule.ID)) + ui.Info(fmt.Sprintf("AuthMethod: %s", rule.AuthMethod)) + ui.Info(fmt.Sprintf("Description: %s", rule.Description)) + ui.Info(fmt.Sprintf("BindType: %s", rule.BindType)) + ui.Info(fmt.Sprintf("BindName: %s", rule.BindName)) + ui.Info(fmt.Sprintf("Selector: %s", rule.Selector)) + if showMeta { + ui.Info(fmt.Sprintf("Create Index: %d", rule.CreateIndex)) + ui.Info(fmt.Sprintf("Modify Index: %d", rule.ModifyIndex)) + } +} + +func PrintBindingRuleListEntry(rule *api.ACLBindingRule, ui cli.Ui, showMeta bool) { + ui.Info(fmt.Sprintf("%s:", rule.ID)) + ui.Info(fmt.Sprintf(" AuthMethod: %s", rule.AuthMethod)) + ui.Info(fmt.Sprintf(" Description: %s", rule.Description)) + ui.Info(fmt.Sprintf(" BindType: %s", rule.BindType)) + ui.Info(fmt.Sprintf(" BindName: %s", rule.BindName)) + ui.Info(fmt.Sprintf(" Selector: %s", rule.Selector)) + if showMeta { + ui.Info(fmt.Sprintf(" Create Index: %d", rule.CreateIndex)) + ui.Info(fmt.Sprintf(" Modify Index: %d", rule.ModifyIndex)) + } } func GetTokenIDFromPartial(client *api.Client, partialID string) (string, error) { @@ -309,6 +382,34 @@ func GetRoleIDByName(client *api.Client, name string) (string, error) { return "", fmt.Errorf("No such role with name %s", name) } +func GetBindingRuleIDFromPartial(client *api.Client, partialID string) (string, error) { + // the full UUID string was given + if len(partialID) == 36 { + return partialID, nil + } + + rules, _, err := client.ACL().BindingRuleList("", nil) + if err != nil { + return "", err + } + + ruleID := "" + for _, rule := range rules { + if strings.HasPrefix(rule.ID, partialID) { + if ruleID != "" { + return "", fmt.Errorf("Partial rule ID is not unique") + } + ruleID = rule.ID + } + } + + if ruleID == "" { + return "", fmt.Errorf("No such rule ID with prefix: %s", partialID) + } + + return ruleID, nil +} + func ExtractServiceIdentities(serviceIdents []string) ([]*api.ACLServiceIdentity, error) { var out []*api.ACLServiceIdentity for _, svcidRaw := range serviceIdents { @@ -329,3 +430,27 @@ func ExtractServiceIdentities(serviceIdents []string) ([]*api.ACLServiceIdentity } return out, nil } + +// TestKubernetesJWT_A is a valid service account jwt extracted from a minikube setup. +// +// { +// "iss": "kubernetes/serviceaccount", +// "kubernetes.io/serviceaccount/namespace": "default", +// "kubernetes.io/serviceaccount/secret.name": "admin-token-qlz42", +// "kubernetes.io/serviceaccount/service-account.name": "admin", +// "kubernetes.io/serviceaccount/service-account.uid": "738bc251-6532-11e9-b67f-48e6c8b8ecb5", +// "sub": "system:serviceaccount:default:admin" +// } +const TestKubernetesJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA" + +// TestKubernetesJWT_B is a valid service account jwt extracted from a minikube setup. +// +// { +// "iss": "kubernetes/serviceaccount", +// "kubernetes.io/serviceaccount/namespace": "default", +// "kubernetes.io/serviceaccount/secret.name": "demo-token-kmb9n", +// "kubernetes.io/serviceaccount/service-account.name": "demo", +// "kubernetes.io/serviceaccount/service-account.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe", +// "sub": "system:serviceaccount:default:demo" +// } +const TestKubernetesJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ" diff --git a/command/acl/authmethod/authmethod.go b/command/acl/authmethod/authmethod.go new file mode 100644 index 0000000000..d8be7857ab --- /dev/null +++ b/command/acl/authmethod/authmethod.go @@ -0,0 +1,64 @@ +package authmethod + +import ( + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New() *cmd { + return &cmd{} +} + +type cmd struct{} + +func (c *cmd) Run(args []string) int { + return cli.RunResultHelp +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(help, nil) +} + +const synopsis = "Manage Consul's ACL Auth Methods" +const help = ` +Usage: consul acl auth-method [options] [args] + + This command has subcommands for managing Consul's ACL Auth Methods. + Here are some simple examples, and more detailed examples are available in + the subcommands or the documentation. + + Create a new auth method: + + $ consul acl auth-method create -type "kubernetes" \ + -name "my-k8s" \ + -description "This is an example kube auth method" \ + -kubernetes-host "https://apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/kube.ca.crt \ + -kubernetes-service-account-jwt "JWT_CONTENTS" + + List all auth methods: + + $ consul acl auth-method list + + Update all editable fields of the auth method: + + $ consul acl auth-method update -name "my-k8s" \ + -description "new description" \ + -kubernetes-host "https://new-apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/new-kube.ca.crt \ + -kubernetes-service-account-jwt "NEW_JWT_CONTENTS" + + Read an auth method: + + $ consul acl auth-method read -name my-k8s + + Delete an auth method: + + $ consul acl auth-method delete -name my-k8s + + For more examples, ask for subcommand help or view the documentation. +` diff --git a/command/acl/authmethod/create/authmethod_create.go b/command/acl/authmethod/create/authmethod_create.go new file mode 100644 index 0000000000..46a55882b6 --- /dev/null +++ b/command/acl/authmethod/create/authmethod_create.go @@ -0,0 +1,186 @@ +package authmethodcreate + +import ( + "flag" + "fmt" + "io" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/command/helpers" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + authMethodType string + name string + description string + + k8sHost string + k8sCACert string + k8sServiceAccountJWT string + + showMeta bool + + testStdin io.Reader +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.authMethodType, + "type", + "", + "The new auth method's type. This flag is required.", + ) + c.flags.StringVar( + &c.name, + "name", + "", + "The new auth method's name. This flag is required.", + ) + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the auth method.", + ) + + c.flags.StringVar( + &c.k8sHost, + "kubernetes-host", + "", + "Address of the Kubernetes API server. This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sCACert, + "kubernetes-ca-cert", + "", + "PEM encoded CA cert for use by the TLS client used to talk with the "+ + "Kubernetes API. May be prefixed with '@' to indicate that the "+ + "value is a file path to load the cert from. "+ + "This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sServiceAccountJWT, + "kubernetes-service-account-jwt", + "", + "A kubernetes service account JWT used to access the TokenReview API to "+ + "validate other JWTs during login. "+ + "This flag is required for type=kubernetes.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.authMethodType == "" { + c.UI.Error(fmt.Sprintf("Missing required '-type' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.name == "" { + c.UI.Error(fmt.Sprintf("Missing required '-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + newAuthMethod := &api.ACLAuthMethod{ + Type: c.authMethodType, + Name: c.name, + Description: c.description, + } + + if c.authMethodType == "kubernetes" { + if c.k8sHost == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-host' flag")) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-ca-cert' flag")) + return 1 + } else if c.k8sServiceAccountJWT == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-service-account-jwt' flag")) + return 1 + } + + c.k8sCACert, err = helpers.LoadDataSource(c.k8sCACert, c.testStdin) + if err != nil { + c.UI.Error(fmt.Sprintf("Invalid '-kubernetes-ca-cert' value: %v", err)) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Kubernetes CA Cert is empty")) + return 1 + } + + newAuthMethod.Config = map[string]interface{}{ + "Host": c.k8sHost, + "CACert": c.k8sCACert, + "ServiceAccountJWT": c.k8sServiceAccountJWT, + } + } + + method, _, err := client.ACL().AuthMethodCreate(newAuthMethod, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to create new auth method: %v", err)) + return 1 + } + + acl.PrintAuthMethod(method, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Create an ACL Auth Method" + +const help = ` +Usage: consul acl auth-method create -name NAME -type TYPE [options] + + Create a new auth method: + + $ consul acl auth-method create -type "kubernetes" \ + -name "my-k8s" \ + -description "This is an example kube method" \ + -kubernetes-host "https://apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/kube.ca.crt \ + -kubernetes-service-account-jwt "JWT_CONTENTS" +` diff --git a/command/acl/authmethod/create/authmethod_create_test.go b/command/acl/authmethod/create/authmethod_create_test.go new file mode 100644 index 0000000000..a5bb222dd1 --- /dev/null +++ b/command/acl/authmethod/create/authmethod_create_test.go @@ -0,0 +1,226 @@ +package authmethodcreate + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodCreateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodCreateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + t.Run("type required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-type' flag") + }) + + t.Run("name required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=testing", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-name' flag") + }) + + t.Run("invalid type", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=invalid", + "-name=my-method", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Invalid Auth Method: Type should be one of") + }) + + t.Run("create testing", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=testing", + "-name=test", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) +} + +func TestAuthMethodCreateCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + t.Run("k8s host required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-host' flag") + }) + + t.Run("k8s ca cert required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host=https://foo.internal:8443", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-ca-cert' flag") + }) + + ca := connect.TestCA(t, nil) + + t.Run("k8s jwt required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host=https://foo.internal:8443", + "-kubernetes-ca-cert", ca.RootCert, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-service-account-jwt' flag") + }) + + t.Run("create k8s", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host", "https://foo.internal:8443", + "-kubernetes-ca-cert", ca.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_A, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) + + caFile := filepath.Join(testDir, "ca.crt") + require.NoError(t, ioutil.WriteFile(caFile, []byte(ca.RootCert), 0600)) + + t.Run("create k8s with cert file", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-type=kubernetes", + "-name=k8s", + "-kubernetes-host", "https://foo.internal:8443", + "-kubernetes-ca-cert", "@" + caFile, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_A, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) +} diff --git a/command/acl/authmethod/delete/authmethod_delete.go b/command/acl/authmethod/delete/authmethod_delete.go new file mode 100644 index 0000000000..d8c341c989 --- /dev/null +++ b/command/acl/authmethod/delete/authmethod_delete.go @@ -0,0 +1,82 @@ +package authmethoddelete + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.StringVar( + &c.name, + "name", + "", + "The name of the auth method to delete.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Must specify the -name parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + if _, err := client.ACL().AuthMethodDelete(c.name, nil); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting auth method %q: %v", c.name, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Auth method %q deleted successfully", c.name)) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Delete an ACL Auth Method" +const help = ` +Usage: consul acl auth-method delete -name NAME [options] + + Delete an auth method: + + $ consul acl auth-method delete -name "my-auth-method" +` diff --git a/command/acl/authmethod/delete/authmethod_delete_test.go b/command/acl/authmethod/delete/authmethod_delete_test.go new file mode 100644 index 0000000000..5d0638727c --- /dev/null +++ b/command/acl/authmethod/delete/authmethod_delete_test.go @@ -0,0 +1,131 @@ +package authmethoddelete + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodDeleteCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodDeleteCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("name required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -name parameter") + }) + + t.Run("delete notfound", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=notfound", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, "notfound") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("delete works", func(t *testing.T) { + name := createAuthMethod(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, name) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, method) + }) +} diff --git a/command/acl/authmethod/list/authmethod_list.go b/command/acl/authmethod/list/authmethod_list.go new file mode 100644 index 0000000000..837d5f9ce8 --- /dev/null +++ b/command/acl/authmethod/list/authmethod_list.go @@ -0,0 +1,83 @@ +package authmethodlist + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + methods, _, err := client.ACL().AuthMethodList(nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to retrieve the auth method list: %v", err)) + return 1 + } + + for _, method := range methods { + acl.PrintAuthMethodListEntry(method, c.UI, c.showMeta) + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Lists ACL Auth Methods" +const help = ` +Usage: consul acl auth-method list [options] + + List all auth methods: + + $ consul acl auth-method list +` diff --git a/command/acl/authmethod/list/authmethod_list_test.go b/command/acl/authmethod/list/authmethod_list_test.go new file mode 100644 index 0000000000..a8a650393e --- /dev/null +++ b/command/acl/authmethod/list/authmethod_list_test.go @@ -0,0 +1,109 @@ +package authmethodlist + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodListCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodListCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + t.Run("found none", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + require.Empty(t, ui.OutputWriter.String()) + }) + + client := a.Client() + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + var methodNames []string + for i := 0; i < 5; i++ { + methodName := createAuthMethod(t) + methodNames = append(methodNames, methodName) + } + + t.Run("found some", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for _, methodName := range methodNames { + require.Contains(t, output, methodName) + } + }) +} diff --git a/command/acl/authmethod/read/authmethod_read.go b/command/acl/authmethod/read/authmethod_read.go new file mode 100644 index 0000000000..1a98bbf64d --- /dev/null +++ b/command/acl/authmethod/read/authmethod_read.go @@ -0,0 +1,96 @@ +package authmethodread + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.name, + "name", + "", + "The name of the auth method to read.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Must specify the -name parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + method, _, err := client.ACL().AuthMethodRead(c.name, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading auth method %q: %v", c.name, err)) + return 1 + } else if method == nil { + c.UI.Error(fmt.Sprintf("Auth method not found with name %q", c.name)) + return 1 + } + acl.PrintAuthMethod(method, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Read an ACL Auth Method" +const help = ` +Usage: consul acl auth-method read -name NAME [options] + + Read an auth method: + + $ consul acl auth-method read -name my-auth-method +` diff --git a/command/acl/authmethod/read/authmethod_read_test.go b/command/acl/authmethod/read/authmethod_read_test.go new file mode 100644 index 0000000000..72b78e8005 --- /dev/null +++ b/command/acl/authmethod/read/authmethod_read_test.go @@ -0,0 +1,118 @@ +package authmethodread + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodReadCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodReadCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("name required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -name parameter") + }) + + t.Run("not found", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=notfound", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("read by name", func(t *testing.T) { + name := createAuthMethod(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, name) + }) +} diff --git a/command/acl/authmethod/update/authmethod_update.go b/command/acl/authmethod/update/authmethod_update.go new file mode 100644 index 0000000000..6f77235f51 --- /dev/null +++ b/command/acl/authmethod/update/authmethod_update.go @@ -0,0 +1,220 @@ +package authmethodupdate + +import ( + "flag" + "fmt" + "io" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/command/helpers" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + name string + + description string + + k8sHost string + k8sCACert string + k8sServiceAccountJWT string + + noMerge bool + showMeta bool + + testStdin io.Reader +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that auth method metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.name, + "name", + "", + "The auth method name.", + ) + + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the auth method.", + ) + + c.flags.StringVar( + &c.k8sHost, + "kubernetes-host", + "", + "Address of the Kubernetes API server. This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sCACert, + "kubernetes-ca-cert", + "", + "PEM encoded CA cert for use by the TLS client used to talk with the "+ + "Kubernetes API. May be prefixed with '@' to indicate that the "+ + "value is a file path to load the cert from. "+ + "This flag is required for type=kubernetes.", + ) + c.flags.StringVar( + &c.k8sServiceAccountJWT, + "kubernetes-service-account-jwt", + "", + "A kubernetes service account JWT used to access the TokenReview API to "+ + "validate other JWTs during login. "+ + "This flag is required for type=kubernetes.", + ) + + c.flags.BoolVar(&c.noMerge, "no-merge", false, "Do not merge the current auth method "+ + "information with what is provided to the command. Instead overwrite all fields "+ + "with the exception of the name which is immutable.") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.name == "" { + c.UI.Error(fmt.Sprintf("Cannot update an auth method without specifying the -name parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + // Regardless of merge, we need to fetch the prior immutable fields first. + currentAuthMethod, _, err := client.ACL().AuthMethodRead(c.name, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error when retrieving current auth method: %v", err)) + return 1 + } else if currentAuthMethod == nil { + c.UI.Error(fmt.Sprintf("Auth method not found with name %q", c.name)) + return 1 + } + + if c.k8sCACert != "" { + c.k8sCACert, err = helpers.LoadDataSource(c.k8sCACert, c.testStdin) + if err != nil { + c.UI.Error(fmt.Sprintf("Invalid '-kubernetes-ca-cert' value: %v", err)) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Kubernetes CA Cert is empty")) + return 1 + } + } + + var method *api.ACLAuthMethod + if c.noMerge { + method = &api.ACLAuthMethod{ + Name: currentAuthMethod.Name, + Type: currentAuthMethod.Type, + Description: c.description, + } + + if currentAuthMethod.Type == "kubernetes" { + if c.k8sHost == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-host' flag")) + return 1 + } else if c.k8sCACert == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-ca-cert' flag")) + return 1 + } else if c.k8sServiceAccountJWT == "" { + c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-service-account-jwt' flag")) + return 1 + } + + method.Config = map[string]interface{}{ + "Host": c.k8sHost, + "CACert": c.k8sCACert, + "ServiceAccountJWT": c.k8sServiceAccountJWT, + } + } + } else { + methodCopy := *currentAuthMethod + method = &methodCopy + + if c.description != "" { + method.Description = c.description + } + if method.Config == nil { + method.Config = make(map[string]interface{}) + } + if currentAuthMethod.Type == "kubernetes" { + if c.k8sHost != "" { + method.Config["Host"] = c.k8sHost + } + if c.k8sCACert != "" { + method.Config["CACert"] = c.k8sCACert + } + if c.k8sServiceAccountJWT != "" { + method.Config["ServiceAccountJWT"] = c.k8sServiceAccountJWT + } + } + } + + method, _, err = client.ACL().AuthMethodUpdate(method, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error updating auth method %q: %v", c.name, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Auth method updated successfully")) + acl.PrintAuthMethod(method, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Update an ACL Auth Method" +const help = ` +Usage: consul acl auth-method update -name NAME [options] + + Updates an auth method. By default it will merge the auth method + information with its current state so that you do not have to provide all + parameters. This behavior can be disabled by passing -no-merge. + + Update all editable fields of the auth method: + + $ consul acl auth-method update -name "my-k8s" \ + -description "new description" \ + -kubernetes-host "https://new-apiserver.example.com:8443" \ + -kubernetes-ca-file /path/to/new-kube.ca.crt \ + -kubernetes-service-account-jwt "NEW_JWT_CONTENTS" +` diff --git a/command/acl/authmethod/update/authmethod_update_test.go b/command/acl/authmethod/update/authmethod_update_test.go new file mode 100644 index 0000000000..ba5d92758e --- /dev/null +++ b/command/acl/authmethod/update/authmethod_update_test.go @@ -0,0 +1,647 @@ +package authmethodupdate + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestAuthMethodUpdateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestAuthMethodUpdateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("update without name", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update an auth method without specifying the -name parameter") + }) + + t.Run("update nonexistent method", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=test", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + }) +} + +func TestAuthMethodUpdateCommand_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("update without name", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update an auth method without specifying the -name parameter") + }) + + t.Run("update nonexistent method", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=test", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name") + }) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "test-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "testing", + Description: "test", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + }) +} + +func TestAuthMethodUpdateCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + ca := connect.TestCA(t, nil) + ca2 := connect.TestCA(t, nil) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "k8s-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "kubernetes", + Description: "test", + Config: map[string]interface{}{ + "Host": "https://foo.internal:8443", + "CACert": ca.RootCert, + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + ca2File := filepath.Join(testDir, "ca2.crt") + require.NoError(t, ioutil.WriteFile(ca2File, []byte(ca2.RootCert), 0600)) + + t.Run("update all fields with cert file", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", "@" + ca2File, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + t.Run("update all fields but k8s host", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + t.Run("update all fields but k8s ca cert", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + t.Run("update all fields but k8s jwt", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_A, config.ServiceAccountJWT) + }) +} + +func TestAuthMethodUpdateCommand_k8s_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + ca := connect.TestCA(t, nil) + ca2 := connect.TestCA(t, nil) + + createAuthMethod := func(t *testing.T) string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + methodName := "k8s-" + id + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: methodName, + Type: "kubernetes", + Description: "test", + Config: map[string]interface{}{ + "Host": "https://foo.internal:8443", + "CACert": ca.RootCert, + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + return methodName + } + + t.Run("update missing k8s host", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-host' flag") + }) + + t.Run("update missing k8s ca cert", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-ca-cert' flag") + }) + + t.Run("update missing k8s jwt", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-service-account-jwt' flag") + }) + + t.Run("update all fields", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", ca2.RootCert, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) + + ca2File := filepath.Join(testDir, "ca2.crt") + require.NoError(t, ioutil.WriteFile(ca2File, []byte(ca2.RootCert), 0600)) + + t.Run("update all fields with cert file", func(t *testing.T) { + name := createAuthMethod(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-name=" + name, + "-description", "updated description", + "-kubernetes-host", "https://foo-new.internal:8443", + "-kubernetes-ca-cert", "@" + ca2File, + "-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B, + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + method, _, err := client.ACL().AuthMethodRead( + name, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, method) + require.Equal(t, "updated description", method.Description) + + config, err := api.ParseKubernetesAuthMethodConfig(method.Config) + require.NoError(t, err) + + require.Equal(t, "https://foo-new.internal:8443", config.Host) + require.Equal(t, ca2.RootCert, config.CACert) + require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT) + }) +} diff --git a/command/acl/bindingrule/bindingrule.go b/command/acl/bindingrule/bindingrule.go new file mode 100644 index 0000000000..2b94139463 --- /dev/null +++ b/command/acl/bindingrule/bindingrule.go @@ -0,0 +1,60 @@ +package bindingrule + +import ( + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New() *cmd { + return &cmd{} +} + +type cmd struct{} + +func (c *cmd) Run(args []string) int { + return cli.RunResultHelp +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(help, nil) +} + +const synopsis = "Manage Consul's ACL Binding Rules" +const help = ` +Usage: consul acl binding-rule [options] [args] + + This command has subcommands for managing Consul's ACL Binding Rules. Here + are some simple examples, and more detailed examples are available in the + subcommands or the documentation. + + Create a new binding rule: + + $ consul acl binding-rule create \ + -method=minikube \ + -bind-type=service \ + -bind-name='k8s-${serviceaccount.name}' \ + -selector='serviceaccount.namespace==default and serviceaccount.name==web' + + List all binding rules: + + $ consul acl binding-rule list + + Update a binding rule: + + $ consul acl binding-rule update -id=43cb72df-9c6f-4315-ac8a-01a9d98155ef \ + -bind-name='k8s-${serviceaccount.name}' + + Read a binding rule: + + $ consul acl binding-rule read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba + + Delete a binding rule: + + $ consul acl binding-rule delete -id b6b856da-5193-4e78-845a-7d61ca8371ba + + For more examples, ask for subcommand help or view the documentation. +` diff --git a/command/acl/bindingrule/create/bindingrule_create.go b/command/acl/bindingrule/create/bindingrule_create.go new file mode 100644 index 0000000000..01bcfcbe72 --- /dev/null +++ b/command/acl/bindingrule/create/bindingrule_create.go @@ -0,0 +1,148 @@ +package bindingrulecreate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + authMethodName string + description string + selector string + bindType string + bindName string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.authMethodName, + "method", + "", + "The auth method's name for which this binding rule applies. "+ + "This flag is required.", + ) + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the binding rule.", + ) + c.flags.StringVar( + &c.selector, + "selector", + "", + "Selector is an expression that matches against verified identity "+ + "attributes returned from the auth method during login.", + ) + c.flags.StringVar( + &c.bindType, + "bind-type", + string(api.BindingRuleBindTypeService), + "Type of binding to perform (\"service\" or \"role\").", + ) + c.flags.StringVar( + &c.bindName, + "bind-name", + "", + "Name to bind on match. Can use ${var} interpolation. "+ + "This flag is required.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.authMethodName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-method' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.bindType == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-type' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.bindName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + newRule := &api.ACLBindingRule{ + Description: c.description, + AuthMethod: c.authMethodName, + BindType: api.BindingRuleBindType(c.bindType), + BindName: c.bindName, + Selector: c.selector, + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + rule, _, err := client.ACL().BindingRuleCreate(newRule, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to create new binding rule: %v", err)) + return 1 + } + + acl.PrintBindingRule(rule, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Create an ACL Binding Rule" + +const help = ` +Usage: consul acl binding-rule create [options] + + Create a new binding rule: + + $ consul acl binding-rule create \ + -method=minikube \ + -bind-type=service \ + -bind-name='k8s-${serviceaccount.name}' \ + -selector='serviceaccount.namespace==default and serviceaccount.name==web' +` diff --git a/command/acl/bindingrule/create/bindingrule_create_test.go b/command/acl/bindingrule/create/bindingrule_create_test.go new file mode 100644 index 0000000000..0e8b510963 --- /dev/null +++ b/command/acl/bindingrule/create/bindingrule_create_test.go @@ -0,0 +1,178 @@ +package bindingrulecreate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleCreateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleCreateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("method is required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-method' flag") + }) + + t.Run("bind type required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-type' flag") + }) + + t.Run("bind name required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-name' flag") + }) + + t.Run("must use roughly valid selector", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + "-bind-name=demo", + "-selector", "foo", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid") + }) + + t.Run("create it with no selector", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + "-bind-name=demo", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) + + t.Run("create it with a match selector", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=service", + "-bind-name=demo", + "-selector", "serviceaccount.namespace==default and serviceaccount.name==vault", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) + + t.Run("create it with type role", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-bind-type=role", + "-bind-name=demo", + "-selector", "serviceaccount.namespace==default and serviceaccount.name==vault", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + }) +} diff --git a/command/acl/bindingrule/delete/bindingrule_delete.go b/command/acl/bindingrule/delete/bindingrule_delete.go new file mode 100644 index 0000000000..7956e1e3aa --- /dev/null +++ b/command/acl/bindingrule/delete/bindingrule_delete.go @@ -0,0 +1,97 @@ +package bindingruledelete + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + ruleID string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.StringVar( + &c.ruleID, + "id", + "", + "The ID of the binding rule to delete. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple binding rule IDs", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.ruleID == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err)) + return 1 + } + + if _, err := client.ACL().BindingRuleDelete(ruleID, nil); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting binding rule %q: %v", ruleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Binding rule %q deleted successfully", ruleID)) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Delete an ACL Binding Rule" +const help = ` +Usage: consul acl binding-rule delete -id ID [options] + + Deletes an ACL binding rule by providing the ID or a unique ID prefix. + + Delete by prefix: + + $ consul acl binding-rule delete -id b6b85 + + Delete by full ID: + + $ consul acl binding-rule delete -id b6b856da-5193-4e78-845a-7d61ca8371ba +` diff --git a/command/acl/bindingrule/delete/bindingrule_delete_test.go b/command/acl/bindingrule/delete/bindingrule_delete_test.go new file mode 100644 index 0000000000..497f26b21c --- /dev/null +++ b/command/acl/bindingrule/delete/bindingrule_delete_test.go @@ -0,0 +1,187 @@ +package bindingruledelete + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleDeleteCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleDeleteCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + createDupe := func(t *testing.T) string { + for { + // Check for 1-char duplicates. + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + m := make(map[byte]struct{}) + for _, rule := range rules { + c := rule.ID[0] + + if _, ok := m[c]; ok { + return string(c) + } + m[c] = struct{}{} + } + + _ = createRule(t) + } + } + + t.Run("id required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter") + }) + + t.Run("delete works", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, id) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, rule) + }) + + t.Run("delete works via prefixes", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("deleted successfully")) + require.Contains(t, output, id) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.Nil(t, rule) + }) + + t.Run("delete fails when prefix matches more than one rule", func(t *testing.T) { + prefix := createDupe(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + prefix, + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) +} diff --git a/command/acl/bindingrule/list/bindingrule_list.go b/command/acl/bindingrule/list/bindingrule_list.go new file mode 100644 index 0000000000..1150ac42c2 --- /dev/null +++ b/command/acl/bindingrule/list/bindingrule_list.go @@ -0,0 +1,98 @@ +package bindingrulelist + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + authMethodName string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.authMethodName, + "method", + "", + "Only show rules linked to the auth method with the given name.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + rules, _, err := client.ACL().BindingRuleList(c.authMethodName, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to retrieve the binding rule list: %v", err)) + return 1 + } + + for _, rule := range rules { + acl.PrintBindingRuleListEntry(rule, c.UI, c.showMeta) + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Lists ACL Binding Rules" +const help = ` +Usage: consul acl binding-rule list [options] + + Lists all the ACL binding rules. + + Show all: + + $ consul acl binding-rule list + + Show all for a specific auth method: + + $ consul acl binding-rule list -method="my-method" +` diff --git a/command/acl/bindingrule/list/bindingrule_list_test.go b/command/acl/bindingrule/list/bindingrule_list_test.go new file mode 100644 index 0000000000..2d935857e3 --- /dev/null +++ b/command/acl/bindingrule/list/bindingrule_list_test.go @@ -0,0 +1,167 @@ +package bindingrulelist + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleListCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleListCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test-1", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + _, _, err = client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test-2", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + createRule := func(t *testing.T, methodName, description string) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: methodName, + Description: description, + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + var ruleIDs []string + for i := 0; i < 10; i++ { + name := fmt.Sprintf("test-rule-%d", i) + + var methodName string + if i%2 == 0 { + methodName = "test-1" + } else { + methodName = "test-2" + } + + id := createRule(t, methodName, name) + + ruleIDs = append(ruleIDs, id) + } + + t.Run("normal", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range ruleIDs { + require.Contains(t, output, fmt.Sprintf("test-rule-%d", i)) + require.Contains(t, output, v) + } + }) + + t.Run("filter by method 1", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test-1", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range ruleIDs { + if i%2 == 0 { + require.Contains(t, output, fmt.Sprintf("test-rule-%d", i)) + require.Contains(t, output, v) + } + } + }) + + t.Run("filter by method 2", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test-2", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + output := ui.OutputWriter.String() + + for i, v := range ruleIDs { + if i%2 == 1 { + require.Contains(t, output, fmt.Sprintf("test-rule-%d", i)) + require.Contains(t, output, v) + } + } + }) +} diff --git a/command/acl/bindingrule/read/bindingrule_read.go b/command/acl/bindingrule/read/bindingrule_read.go new file mode 100644 index 0000000000..677a950cf2 --- /dev/null +++ b/command/acl/bindingrule/read/bindingrule_read.go @@ -0,0 +1,108 @@ +package bindingruleread + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + ruleID string + + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.ruleID, + "id", + "", + "The ID of the binding rule to read. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple binding rule IDs", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.ruleID == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id parameter.")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err)) + return 1 + } + + rule, _, err := client.ACL().BindingRuleRead(ruleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading binding rule %q: %v", ruleID, err)) + return 1 + } else if rule == nil { + c.UI.Error(fmt.Sprintf("Binding rule not found with ID %q", ruleID)) + return 1 + } + + acl.PrintBindingRule(rule, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Read an ACL Binding Rule" +const help = ` +Usage: consul acl binding-rule read -id ID [options] + + This command will retrieve and print out the details of a single binding + rule. + + Read a binding rule: + + $ consul acl binding-rule read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba +` diff --git a/command/acl/bindingrule/read/bindingrule_read_test.go b/command/acl/bindingrule/read/bindingrule_read_test.go new file mode 100644 index 0000000000..205e29e2fa --- /dev/null +++ b/command/acl/bindingrule/read/bindingrule_read_test.go @@ -0,0 +1,152 @@ +package bindingruleread + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleReadCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleReadCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + t.Run("id required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter") + }) + + t.Run("read by id not found", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID") + }) + + t.Run("read by id", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + id, + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test rule")) + require.Contains(t, output, id) + }) + + t.Run("read by id prefix", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id=" + id[0:5], + } + + code := cmd.Run(args) + require.Equal(t, code, 0) + require.Empty(t, ui.ErrorWriter.String()) + + output := ui.OutputWriter.String() + require.Contains(t, output, fmt.Sprintf("test rule")) + require.Contains(t, output, id) + }) +} diff --git a/command/acl/bindingrule/update/bindingrule_update.go b/command/acl/bindingrule/update/bindingrule_update.go new file mode 100644 index 0000000000..0f6d23f659 --- /dev/null +++ b/command/acl/bindingrule/update/bindingrule_update.go @@ -0,0 +1,212 @@ +package bindingruleupdate + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + ruleID string + + description string + selector string + bindType string + bindName string + + noMerge bool + showMeta bool +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.flags.BoolVar( + &c.showMeta, + "meta", + false, + "Indicates that binding rule metadata such "+ + "as the content hash and raft indices should be shown for each entry.", + ) + + c.flags.StringVar( + &c.ruleID, + "id", + "", + "The ID of the binding rule to update. "+ + "It may be specified as a unique ID prefix but will error if the prefix "+ + "matches multiple binding rule IDs", + ) + + c.flags.StringVar( + &c.description, + "description", + "", + "A description of the binding rule.", + ) + c.flags.StringVar( + &c.selector, + "selector", + "", + "Selector is an expression that matches against verified identity "+ + "attributes returned from the auth method during login.", + ) + c.flags.StringVar( + &c.bindType, + "bind-type", + string(api.BindingRuleBindTypeService), + "Type of binding to perform (\"service\" or \"role\").", + ) + c.flags.StringVar( + &c.bindName, + "bind-name", + "", + "Name to bind on match. Can use ${var} interpolation. "+ + "This flag is required.", + ) + + c.flags.BoolVar( + &c.noMerge, + "no-merge", + false, + "Do not merge the current binding rule "+ + "information with what is provided to the command. Instead overwrite all fields "+ + "with the exception of the binding rule ID which is immutable.", + ) + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + + if c.ruleID == "" { + c.UI.Error(fmt.Sprintf("Cannot update a binding rule without specifying the -id parameter")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID) + if err != nil { + c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err)) + return 1 + } + + // Read the current binding rule in both cases so we can fail better if not found. + currentRule, _, err := client.ACL().BindingRuleRead(ruleID, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error when retrieving current binding rule: %v", err)) + return 1 + } else if currentRule == nil { + c.UI.Error(fmt.Sprintf("Binding rule not found with ID %q", ruleID)) + return 1 + } + + var rule *api.ACLBindingRule + if c.noMerge { + if c.bindType == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-type' flag")) + c.UI.Error(c.Help()) + return 1 + } else if c.bindName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bind-name' flag")) + c.UI.Error(c.Help()) + return 1 + } + + rule = &api.ACLBindingRule{ + ID: ruleID, + AuthMethod: currentRule.AuthMethod, // immutable + Description: c.description, + BindType: api.BindingRuleBindType(c.bindType), + BindName: c.bindName, + Selector: c.selector, + } + + } else { + rule = currentRule + + if c.description != "" { + rule.Description = c.description + } + if c.bindType != "" { + rule.BindType = api.BindingRuleBindType(c.bindType) + } + if c.bindName != "" { + rule.BindName = c.bindName + } + if isFlagSet(c.flags, "selector") { + rule.Selector = c.selector // empty is valid + } + } + + rule, _, err = client.ACL().BindingRuleUpdate(rule, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error updating binding rule %q: %v", ruleID, err)) + return 1 + } + + c.UI.Info(fmt.Sprintf("Binding rule updated successfully")) + acl.PrintBindingRule(rule, c.UI, c.showMeta) + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +func isFlagSet(flags *flag.FlagSet, name string) bool { + found := false + flags.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} + +const synopsis = "Update an ACL Binding Rule" +const help = ` +Usage: consul acl binding-rule update -id ID [options] + + Updates a binding rule. By default it will merge the binding rule + information with its current state so that you do not have to provide all + parameters. This behavior can be disabled by passing -no-merge. + + Update all editable fields of the binding rule: + + $ consul acl binding-rule update \ + -id=43cb72df-9c6f-4315-ac8a-01a9d98155ef \ + -description="new description" \ + -bind-type=role \ + -bind-name='k8s-${serviceaccount.name}' \ + -selector='serviceaccount.namespace==default and serviceaccount.name==web' +` diff --git a/command/acl/bindingrule/update/bindingrule_update_test.go b/command/acl/bindingrule/update/bindingrule_update_test.go new file mode 100644 index 0000000000..82a6e1fbb4 --- /dev/null +++ b/command/acl/bindingrule/update/bindingrule_update_test.go @@ -0,0 +1,768 @@ +package bindingruleupdate + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + uuid "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" + + // activate testing auth method + _ "github.com/hashicorp/consul/agent/consul/authmethod/testauth" +) + +func TestBindingRuleUpdateCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestBindingRuleUpdateCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + deleteRules := func(t *testing.T) { + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + for _, rule := range rules { + _, err := client.ACL().BindingRuleDelete( + rule.ID, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + } + + t.Run("rule id required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update a binding rule without specifying the -id parameter") + }) + + t.Run("rule id partial matches nothing", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID[0:5], + "-token=root", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("rule id exact match doesn't exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID") + }) + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeService, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + createDupe := func(t *testing.T) string { + for { + // Check for 1-char duplicates. + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + m := make(map[byte]struct{}) + for _, rule := range rules { + c := rule.ID[0] + + if _, ok := m[c]; ok { + return string(c) + } + m[c] = struct{}{} + } + + _ = createRule(t) + } + } + + t.Run("rule id partial matches multiple", func(t *testing.T) { + prefix := createDupe(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + prefix, + "-token=root", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("must use roughly valid selector", func(t *testing.T) { + id := createRule(t) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-selector", "foo", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid") + }) + + t.Run("update all fields", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields - partial", func(t *testing.T) { + deleteRules(t) // reset since we created a bunch that might be dupes + + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id[0:5], + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but description", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but bind name", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "test-${serviceaccount.name}", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but must exist", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==default", rule.Selector) + }) + + t.Run("update all fields clear selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-id", id, + "-description=test rule edited", + "-bind-type", "role", + "-bind-name=role-updated", + "-selector=", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Empty(t, rule.Selector) + }) +} + +func TestBindingRuleUpdateCommand_noMerge(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + // create an auth method in advance + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + deleteRules := func(t *testing.T) { + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + for _, rule := range rules { + _, err := client.ACL().BindingRuleDelete( + rule.ID, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + } + + t.Run("rule id required", func(t *testing.T) { + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + } + + ui := cli.NewMockUi() + cmd := New(ui) + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Cannot update a binding rule without specifying the -id parameter") + }) + + t.Run("rule id partial matches nothing", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID[0:5], + "-token=root", + "-no-merge", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("rule id exact match doesn't exist", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + fakeID, + "-token=root", + "-no-merge", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID") + }) + + createRule := func(t *testing.T) string { + rule, _, err := client.ACL().BindingRuleCreate( + &api.ACLBindingRule{ + AuthMethod: "test", + Description: "test rule", + BindType: api.BindingRuleBindTypeRole, + BindName: "test-${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + return rule.ID + } + + createDupe := func(t *testing.T) string { + for { + // Check for 1-char duplicates. + rules, _, err := client.ACL().BindingRuleList( + "test", + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + + m := make(map[byte]struct{}) + for _, rule := range rules { + c := rule.ID[0] + + if _, ok := m[c]; ok { + return string(c) + } + m[c] = struct{}{} + } + + _ = createRule(t) + } + } + + t.Run("rule id partial matches multiple", func(t *testing.T) { + prefix := createDupe(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-id=" + prefix, + "-token=root", + "-no-merge", + "-description=test rule edited", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID") + }) + + t.Run("must use roughly valid selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + "-selector", "foo", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid") + }) + + t.Run("update all fields", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields - partial", func(t *testing.T) { + deleteRules(t) // reset since we created a bunch that might be dupes + + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id[0:5], + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("update all fields but description", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-bind-type", "service", + "-bind-name=role-updated", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Empty(t, rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector) + }) + + t.Run("missing bind name", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id=" + id, + "-description=test rule edited", + "-bind-type", "role", + "-selector=serviceaccount.namespace==alt and serviceaccount.name==demo", + } + + code := cmd.Run(args) + require.Equal(t, code, 1) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-name' flag") + }) + + t.Run("update all fields but selector", func(t *testing.T) { + id := createRule(t) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-no-merge", + "-id", id, + "-description=test rule edited", + "-bind-type", "service", + "-bind-name=role-updated", + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + + rule, _, err := client.ACL().BindingRuleRead( + id, + &api.QueryOptions{Token: "root"}, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + require.Equal(t, "test rule edited", rule.Description) + require.Equal(t, api.BindingRuleBindTypeService, rule.BindType) + require.Equal(t, "role-updated", rule.BindName) + require.Empty(t, rule.Selector) + }) +} diff --git a/command/acl/role/delete/role_delete.go b/command/acl/role/delete/role_delete.go index 543133ae1b..5e1b17ad4b 100644 --- a/command/acl/role/delete/role_delete.go +++ b/command/acl/role/delete/role_delete.go @@ -21,7 +21,8 @@ type cmd struct { http *flags.HTTPFlags help string - roleID string + roleID string + roleName string } func (c *cmd) init() { @@ -29,6 +30,7 @@ func (c *cmd) init() { c.flags.StringVar(&c.roleID, "id", "", "The ID of the role to delete. "+ "It may be specified as a unique ID prefix but will error if the prefix "+ "matches multiple role IDs") + c.flags.StringVar(&c.roleName, "name", "", "The name of the role to delete.") c.http = &flags.HTTPFlags{} flags.Merge(c.flags, c.http.ClientFlags()) flags.Merge(c.flags, c.http.ServerFlags()) @@ -40,8 +42,8 @@ func (c *cmd) Run(args []string) int { return 1 } - if c.roleID == "" { - c.UI.Error(fmt.Sprintf("Must specify the -id parameter")) + if c.roleID == "" && c.roleName == "" { + c.UI.Error(fmt.Sprintf("Must specify the -id or -name parameters")) return 1 } @@ -51,7 +53,12 @@ func (c *cmd) Run(args []string) int { return 1 } - roleID, err := acl.GetRoleIDFromPartial(client, c.roleID) + var roleID string + if c.roleID != "" { + roleID, err = acl.GetRoleIDFromPartial(client, c.roleID) + } else { + roleID, err = acl.GetRoleIDByName(client, c.roleName) + } if err != nil { c.UI.Error(fmt.Sprintf("Error determining role ID: %v", err)) return 1 diff --git a/command/acl/role/delete/role_delete_test.go b/command/acl/role/delete/role_delete_test.go index e6523941f6..25f2faf0af 100644 --- a/command/acl/role/delete/role_delete_test.go +++ b/command/acl/role/delete/role_delete_test.go @@ -45,7 +45,7 @@ func TestRoleDeleteCommand(t *testing.T) { client := a.Client() - t.Run("id required", func(t *testing.T) { + t.Run("id or name required", func(t *testing.T) { ui := cli.NewMockUi() cmd := New(ui) @@ -56,7 +56,7 @@ func TestRoleDeleteCommand(t *testing.T) { code := cmd.Run(args) require.Equal(t, code, 1) - require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter") + require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id or -name parameters") }) t.Run("delete works", func(t *testing.T) { diff --git a/command/commands_oss.go b/command/commands_oss.go index f92fcf8bba..89c3a2d408 100644 --- a/command/commands_oss.go +++ b/command/commands_oss.go @@ -3,6 +3,18 @@ package command import ( "github.com/hashicorp/consul/command/acl" aclagent "github.com/hashicorp/consul/command/acl/agenttokens" + aclam "github.com/hashicorp/consul/command/acl/authmethod" + aclamcreate "github.com/hashicorp/consul/command/acl/authmethod/create" + aclamdelete "github.com/hashicorp/consul/command/acl/authmethod/delete" + aclamlist "github.com/hashicorp/consul/command/acl/authmethod/list" + aclamread "github.com/hashicorp/consul/command/acl/authmethod/read" + aclamupdate "github.com/hashicorp/consul/command/acl/authmethod/update" + aclbr "github.com/hashicorp/consul/command/acl/bindingrule" + aclbrcreate "github.com/hashicorp/consul/command/acl/bindingrule/create" + aclbrdelete "github.com/hashicorp/consul/command/acl/bindingrule/delete" + aclbrlist "github.com/hashicorp/consul/command/acl/bindingrule/list" + aclbrread "github.com/hashicorp/consul/command/acl/bindingrule/read" + aclbrupdate "github.com/hashicorp/consul/command/acl/bindingrule/update" aclbootstrap "github.com/hashicorp/consul/command/acl/bootstrap" aclpolicy "github.com/hashicorp/consul/command/acl/policy" aclpcreate "github.com/hashicorp/consul/command/acl/policy/create" @@ -57,6 +69,8 @@ import ( kvput "github.com/hashicorp/consul/command/kv/put" "github.com/hashicorp/consul/command/leave" "github.com/hashicorp/consul/command/lock" + login "github.com/hashicorp/consul/command/login" + logout "github.com/hashicorp/consul/command/logout" "github.com/hashicorp/consul/command/maint" "github.com/hashicorp/consul/command/members" "github.com/hashicorp/consul/command/monitor" @@ -118,6 +132,18 @@ func init() { Register("acl role read", func(ui cli.Ui) (cli.Command, error) { return aclrread.New(ui), nil }) Register("acl role update", func(ui cli.Ui) (cli.Command, error) { return aclrupdate.New(ui), nil }) Register("acl role delete", func(ui cli.Ui) (cli.Command, error) { return aclrdelete.New(ui), nil }) + Register("acl auth-method", func(cli.Ui) (cli.Command, error) { return aclam.New(), nil }) + Register("acl auth-method create", func(ui cli.Ui) (cli.Command, error) { return aclamcreate.New(ui), nil }) + Register("acl auth-method list", func(ui cli.Ui) (cli.Command, error) { return aclamlist.New(ui), nil }) + Register("acl auth-method read", func(ui cli.Ui) (cli.Command, error) { return aclamread.New(ui), nil }) + Register("acl auth-method update", func(ui cli.Ui) (cli.Command, error) { return aclamupdate.New(ui), nil }) + Register("acl auth-method delete", func(ui cli.Ui) (cli.Command, error) { return aclamdelete.New(ui), nil }) + Register("acl binding-rule", func(cli.Ui) (cli.Command, error) { return aclbr.New(), nil }) + Register("acl binding-rule create", func(ui cli.Ui) (cli.Command, error) { return aclbrcreate.New(ui), nil }) + Register("acl binding-rule list", func(ui cli.Ui) (cli.Command, error) { return aclbrlist.New(ui), nil }) + Register("acl binding-rule read", func(ui cli.Ui) (cli.Command, error) { return aclbrread.New(ui), nil }) + Register("acl binding-rule update", func(ui cli.Ui) (cli.Command, error) { return aclbrupdate.New(ui), nil }) + Register("acl binding-rule delete", func(ui cli.Ui) (cli.Command, error) { return aclbrdelete.New(ui), nil }) Register("agent", func(ui cli.Ui) (cli.Command, error) { return agent.New(ui, rev, ver, verPre, verHuman, make(chan struct{})), nil }) @@ -153,6 +179,8 @@ func init() { Register("kv put", func(ui cli.Ui) (cli.Command, error) { return kvput.New(ui), nil }) Register("leave", func(ui cli.Ui) (cli.Command, error) { return leave.New(ui), nil }) Register("lock", func(ui cli.Ui) (cli.Command, error) { return lock.New(ui), nil }) + Register("login", func(ui cli.Ui) (cli.Command, error) { return login.New(ui), nil }) + Register("logout", func(ui cli.Ui) (cli.Command, error) { return logout.New(ui), nil }) Register("maint", func(ui cli.Ui) (cli.Command, error) { return maint.New(ui), nil }) Register("members", func(ui cli.Ui) (cli.Command, error) { return members.New(ui), nil }) Register("monitor", func(ui cli.Ui) (cli.Command, error) { return monitor.New(ui, MakeShutdownCh()), nil }) diff --git a/command/connect/envoy/envoy.go b/command/connect/envoy/envoy.go index 3500b69476..5aa9bea182 100644 --- a/command/connect/envoy/envoy.go +++ b/command/connect/envoy/envoy.go @@ -104,7 +104,7 @@ func (c *cmd) Run(args []string) int { // enabled. c.grpcAddr = "localhost:8502" } - if c.http.Token() == "" { + if c.http.Token() == "" && c.http.TokenFile() == "" { // Extra check needed since CONSUL_HTTP_TOKEN has not been consulted yet but // calling SetToken with empty will force that to override the if proxyToken := os.Getenv(proxyAgent.EnvProxyToken); proxyToken != "" { diff --git a/command/connect/proxy/proxy.go b/command/connect/proxy/proxy.go index a60f2217e7..4c99ab5730 100644 --- a/command/connect/proxy/proxy.go +++ b/command/connect/proxy/proxy.go @@ -129,7 +129,7 @@ func (c *cmd) Run(args []string) int { if c.sidecarFor == "" { c.sidecarFor = os.Getenv(proxyAgent.EnvSidecarFor) } - if c.http.Token() == "" { + if c.http.Token() == "" && c.http.TokenFile() == "" { c.http.SetToken(os.Getenv(proxyAgent.EnvProxyToken)) } diff --git a/command/flags/http.go b/command/flags/http.go index 7d02f6ab3b..e2688fab8c 100644 --- a/command/flags/http.go +++ b/command/flags/http.go @@ -2,6 +2,8 @@ package flags import ( "flag" + "io/ioutil" + "strings" "github.com/hashicorp/consul/api" ) @@ -10,6 +12,7 @@ type HTTPFlags struct { // client api flags address StringValue token StringValue + tokenFile StringValue caFile StringValue caPath StringValue certFile StringValue @@ -33,6 +36,10 @@ func (f *HTTPFlags) ClientFlags() *flag.FlagSet { "ACL token to use in the request. This can also be specified via the "+ "CONSUL_HTTP_TOKEN environment variable. If unspecified, the query will "+ "default to the token of the Consul agent at the HTTP address.") + fs.Var(&f.tokenFile, "token-file", + "File containing the ACL token to use in the request instead of one specified "+ + "via the -token argument or CONSUL_HTTP_TOKEN environment variable. "+ + "This can also be specified via the CONSUL_HTTP_TOKEN_FILE environment variable.") fs.Var(&f.caFile, "ca-file", "Path to a CA file to use for TLS when communicating with Consul. This "+ "can also be specified via the CONSUL_CACERT environment variable.") @@ -88,6 +95,28 @@ func (f *HTTPFlags) SetToken(v string) error { return f.token.Set(v) } +func (f *HTTPFlags) TokenFile() string { + return f.tokenFile.String() +} + +func (f *HTTPFlags) SetTokenFile(v string) error { + return f.tokenFile.Set(v) +} + +func (f *HTTPFlags) ReadTokenFile() (string, error) { + tokenFile := f.tokenFile.String() + if tokenFile == "" { + return "", nil + } + + data, err := ioutil.ReadFile(tokenFile) + if err != nil { + return "", err + } + + return strings.TrimSpace(string(data)), nil +} + func (f *HTTPFlags) APIClient() (*api.Client, error) { c := api.DefaultConfig() @@ -99,6 +128,7 @@ func (f *HTTPFlags) APIClient() (*api.Client, error) { func (f *HTTPFlags) MergeOntoConfig(c *api.Config) { f.address.Merge(&c.Address) f.token.Merge(&c.Token) + f.tokenFile.Merge(&c.TokenFile) f.caFile.Merge(&c.TLSConfig.CAFile) f.caPath.Merge(&c.TLSConfig.CAPath) f.certFile.Merge(&c.TLSConfig.CertFile) diff --git a/command/login/login.go b/command/login/login.go new file mode 100644 index 0000000000..ada268cac8 --- /dev/null +++ b/command/login/login.go @@ -0,0 +1,148 @@ +package login + +import ( + "flag" + "fmt" + "io/ioutil" + "strings" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/lib/file" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + + shutdownCh <-chan struct{} + + bearerToken string + + // flags + authMethodName string + bearerTokenFile string + tokenSinkFile string + meta map[string]string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + + c.flags.StringVar(&c.authMethodName, "method", "", + "Name of the auth method to login to.") + + c.flags.StringVar(&c.bearerTokenFile, "bearer-token-file", "", + "Path to a file containing a secret bearer token to use with this auth method.") + + c.flags.StringVar(&c.tokenSinkFile, "token-sink-file", "", + "The most recent token's SecretID is kept up to date in this file.") + + c.flags.Var((*flags.FlagMapValue)(&c.meta), "meta", + "Metadata to set on the token, formatted as key=value. This flag "+ + "may be specified multiple times to set multiple meta fields.") + + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + if len(c.flags.Args()) > 0 { + c.UI.Error(fmt.Sprintf("Should have no non-flag arguments.")) + return 1 + } + + if c.authMethodName == "" { + c.UI.Error(fmt.Sprintf("Missing required '-method' flag")) + return 1 + } + if c.tokenSinkFile == "" { + c.UI.Error(fmt.Sprintf("Missing required '-token-sink-file' flag")) + return 1 + } + + if c.bearerTokenFile == "" { + c.UI.Error(fmt.Sprintf("Missing required '-bearer-token-file' flag")) + return 1 + } + + data, err := ioutil.ReadFile(c.bearerTokenFile) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + c.bearerToken = strings.TrimSpace(string(data)) + + if c.bearerToken == "" { + c.UI.Error(fmt.Sprintf("No bearer token found in %s", c.bearerTokenFile)) + return 1 + } + + // Ensure that we don't try to use a token when performing a login + // operation. + c.http.SetToken("") + c.http.SetTokenFile("") + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + // Do the login. + req := &api.ACLLoginParams{ + AuthMethod: c.authMethodName, + BearerToken: c.bearerToken, + Meta: c.meta, + } + tok, _, err := client.ACL().Login(req, nil) + if err != nil { + c.UI.Error(fmt.Sprintf("Error logging in: %s", err)) + return 1 + } + + if err := c.writeToSink(tok); err != nil { + c.UI.Error(fmt.Sprintf("Error writing token to file sink: %s", err)) + return 1 + } + + return 0 +} + +func (c *cmd) writeToSink(tok *api.ACLToken) error { + payload := []byte(tok.SecretID) + return file.WriteAtomicWithPerms(c.tokenSinkFile, payload, 0600) +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Login to Consul using an Auth Method" + +const help = ` +Usage: consul login [options] + + The login command will exchange the provided third party credentials with the + requested auth method for a newly minted Consul ACL Token. The companion + command 'consul logout' should be used to destroy any tokens created this way + to avoid a resource leak. +` diff --git a/command/login/login_test.go b/command/login/login_test.go new file mode 100644 index 0000000000..c2988d8626 --- /dev/null +++ b/command/login/login_test.go @@ -0,0 +1,321 @@ +package login + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestLoginCommand_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestLoginCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("method is required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-method' flag") + }) + + tokenSinkFile := filepath.Join(testDir, "test.token") + + t.Run("token-sink-file is required", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-token-sink-file' flag") + }) + + t.Run("bearer-token-file is required", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bearer-token-file' flag") + }) + + bearerTokenFile := filepath.Join(testDir, "bearer.token") + + t.Run("bearer-token-file is empty", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(""), 0600)) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "No bearer token found in") + }) + + require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte("demo-token"), 0600)) + + t.Run("try login with no method configured", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + testauth.InstallSessionToken( + testSessionID, + "demo-token", + "default", "demo", "76091af4-4b56-11e9-ac4b-708b11801cbe", + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("try login with method configured but no binding rules", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, 1, code, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)") + }) + + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "test", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("try login with method configured and binding rules", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=test", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + require.Empty(t, ui.OutputWriter.String()) + + raw, err := ioutil.ReadFile(tokenSinkFile) + require.NoError(t, err) + + token := strings.TrimSpace(string(raw)) + require.Len(t, token, 36, "must be a valid uid: %s", token) + }) +} + +func TestLoginCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + tokenSinkFile := filepath.Join(testDir, "test.token") + bearerTokenFile := filepath.Join(testDir, "bearer.token") + + // the "B" jwt will be the one being reviewed + require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(acl.TestKubernetesJWT_B), 0600)) + + // spin up a fake api server + testSrv := kubeauth.StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(acl.TestKubernetesJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + acl.TestKubernetesJWT_B, + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "k8s", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": testSrv.Addr(), + "CACert": testSrv.CACert(), + // the "A" jwt will be the one with token review privs + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "k8s", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + Selector: "serviceaccount.namespace==default", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + t.Run("try login with method configured and binding rules", func(t *testing.T) { + defer os.Remove(tokenSinkFile) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=root", + "-method=k8s", + "-token-sink-file", tokenSinkFile, + "-bearer-token-file", bearerTokenFile, + } + + code := cmd.Run(args) + require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + require.Empty(t, ui.OutputWriter.String()) + + raw, err := ioutil.ReadFile(tokenSinkFile) + require.NoError(t, err) + + token := strings.TrimSpace(string(raw)) + require.Len(t, token, 36, "must be a valid uid: %s", token) + }) +} diff --git a/command/logout/logout.go b/command/logout/logout.go new file mode 100644 index 0000000000..eca9c416be --- /dev/null +++ b/command/logout/logout.go @@ -0,0 +1,70 @@ +package logout + +import ( + "flag" + "fmt" + + "github.com/hashicorp/consul/command/flags" + "github.com/mitchellh/cli" +) + +func New(ui cli.Ui) *cmd { + c := &cmd{UI: ui} + c.init() + return c +} + +type cmd struct { + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string +} + +func (c *cmd) init() { + c.flags = flag.NewFlagSet("", flag.ContinueOnError) + c.http = &flags.HTTPFlags{} + flags.Merge(c.flags, c.http.ClientFlags()) + flags.Merge(c.flags, c.http.ServerFlags()) + c.help = flags.Usage(help, c.flags) +} + +func (c *cmd) Run(args []string) int { + if err := c.flags.Parse(args); err != nil { + return 1 + } + if len(c.flags.Args()) > 0 { + c.UI.Error(fmt.Sprintf("Should have no non-flag arguments.")) + return 1 + } + + client, err := c.http.APIClient() + if err != nil { + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err)) + return 1 + } + + if _, err := client.ACL().Logout(nil); err != nil { + c.UI.Error(fmt.Sprintf("Error destroying token: %v", err)) + return 1 + } + + return 0 +} + +func (c *cmd) Synopsis() string { + return synopsis +} + +func (c *cmd) Help() string { + return flags.Usage(c.help, nil) +} + +const synopsis = "Destroy a Consul Token created with Login" + +const help = ` +Usage: consul logout [options] + + The logout command will destroy the provided token if it was created from + 'consul login'. +` diff --git a/command/logout/logout_test.go b/command/logout/logout_test.go new file mode 100644 index 0000000000..5596297b9c --- /dev/null +++ b/command/logout/logout_test.go @@ -0,0 +1,299 @@ +package logout + +import ( + "os" + "strings" + "testing" + + "github.com/hashicorp/consul/agent" + "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth" + "github.com/hashicorp/consul/agent/consul/authmethod/testauth" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/command/acl" + "github.com/hashicorp/consul/logger" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/go-uuid" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/require" +) + +func TestLogout_noTabs(t *testing.T) { + t.Parallel() + + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { + t.Fatal("help has tabs") + } +} + +func TestLogoutCommand(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("no token specified", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + t.Run("logout of deleted token", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + plainToken, _, err := client.ACL().TokenCreate( + &api.ACLToken{Description: "test"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + t.Run("logout of ordinary token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + plainToken.SecretID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)") + }) + + testSessionID := testauth.StartSession() + defer testauth.ResetSession(testSessionID) + + testauth.InstallSessionToken( + testSessionID, + "demo-token", + "default", "demo", "76091af4-4b56-11e9-ac4b-708b11801cbe", + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "test", + Type: "testing", + Config: map[string]interface{}{ + "SessionID": testSessionID, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "test", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + var loginTokenSecret string + { + tok, _, err := client.ACL().Login(&api.ACLLoginParams{ + AuthMethod: "test", + BearerToken: "demo-token", + }, nil) + require.NoError(t, err) + + loginTokenSecret = tok.SecretID + } + + t.Run("logout of login token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + loginTokenSecret, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + }) +} + +func TestLogoutCommand_k8s(t *testing.T) { + t.Parallel() + + testDir := testutil.TempDir(t, "acl") + defer os.RemoveAll(testDir) + + a := agent.NewTestAgent(t, t.Name(), ` + primary_datacenter = "dc1" + acl { + enabled = true + tokens { + master = "root" + } + }`) + + a.Agent.LogWriter = logger.NewLogWriter(512) + + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + client := a.Client() + + t.Run("no token specified", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + t.Run("logout of deleted token", func(t *testing.T) { + fakeID, err := uuid.GenerateUUID() + require.NoError(t, err) + + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + fakeID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)") + }) + + plainToken, _, err := client.ACL().TokenCreate( + &api.ACLToken{Description: "test"}, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + + t.Run("logout of ordinary token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + plainToken.SecretID, + } + + code := cmd.Run(args) + require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String()) + require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)") + }) + + // go to the trouble of creating a login token + // require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(acl.TestKubernetesJWT_B), 0600)) + + // spin up a fake api server + testSrv := kubeauth.StartTestAPIServer(t) + defer testSrv.Stop() + + testSrv.AuthorizeJWT(acl.TestKubernetesJWT_A) + testSrv.SetAllowedServiceAccount( + "default", + "demo", + "76091af4-4b56-11e9-ac4b-708b11801cbe", + "", + acl.TestKubernetesJWT_B, + ) + + { + _, _, err := client.ACL().AuthMethodCreate( + &api.ACLAuthMethod{ + Name: "k8s", + Type: "kubernetes", + Config: map[string]interface{}{ + "Host": testSrv.Addr(), + "CACert": testSrv.CACert(), + // the "A" jwt will be the one with token review privs + "ServiceAccountJWT": acl.TestKubernetesJWT_A, + }, + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + { + _, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{ + AuthMethod: "k8s", + BindType: api.BindingRuleBindTypeService, + BindName: "${serviceaccount.name}", + }, + &api.WriteOptions{Token: "root"}, + ) + require.NoError(t, err) + } + + var loginTokenSecret string + { + tok, _, err := client.ACL().Login(&api.ACLLoginParams{ + AuthMethod: "k8s", + BearerToken: acl.TestKubernetesJWT_B, + }, nil) + require.NoError(t, err) + + loginTokenSecret = tok.SecretID + } + + t.Run("logout of login token", func(t *testing.T) { + ui := cli.NewMockUi() + cmd := New(ui) + + args := []string{ + "-http-addr=" + a.HTTPAddr(), + "-token=" + loginTokenSecret, + } + + code := cmd.Run(args) + require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String()) + require.Empty(t, ui.ErrorWriter.String()) + }) +} diff --git a/command/watch/watch.go b/command/watch/watch.go index fd44e81fbb..92178a1f06 100644 --- a/command/watch/watch.go +++ b/command/watch/watch.go @@ -87,6 +87,14 @@ func (c *cmd) Run(args []string) int { return 1 } + token := c.http.Token() + if tokenFromFile, err := c.http.ReadTokenFile(); err != nil { + c.UI.Error(fmt.Sprintf("Error loading token file: %s", err)) + return 1 + } else if tokenFromFile != "" { + token = tokenFromFile + } + // Compile the watch parameters params := make(map[string]interface{}) if c.watchType != "" { @@ -95,8 +103,8 @@ func (c *cmd) Run(args []string) int { if c.http.Datacenter() != "" { params["datacenter"] = c.http.Datacenter() } - if c.http.Token() != "" { - params["token"] = c.http.Token() + if token != "" { + params["token"] = token } if c.key != "" { params["key"] = c.key diff --git a/go.mod b/go.mod index 4d87ae9f00..642f2d47d7 100644 --- a/go.mod +++ b/go.mod @@ -123,7 +123,9 @@ require ( gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528 // indirect gopkg.in/ory-am/dockertest.v3 v3.3.4 // indirect + gopkg.in/square/go-jose.v2 v2.3.1 gotest.tools v2.2.0+incompatible // indirect - k8s.io/api v0.0.0-20190118113203-912cbe2bfef3 // indirect - k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512 // indirect + k8s.io/api v0.0.0-20190325185214-7544f9db76f6 + k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841 + k8s.io/client-go v8.0.0+incompatible ) diff --git a/go.sum b/go.sum index 8bba6865b0..649673f1f2 100644 --- a/go.sum +++ b/go.sum @@ -383,6 +383,8 @@ gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528 h1:/saqWwm73dLmuzbNhe92F0QsZ/ gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/ory-am/dockertest.v3 v3.3.4 h1:oen8RiwxVNxtQ1pRoV4e4jqh6UjNsOuIZ1NXns6jdcw= gopkg.in/ory-am/dockertest.v3 v3.3.4/go.mod h1:s9mmoLkaGeAh97qygnNj4xWkiN7e1SKekYC6CovU+ek= +gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= +gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= @@ -391,10 +393,10 @@ gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= k8s.io/api v0.0.0-20180806132203-61b11ee65332/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= -k8s.io/api v0.0.0-20190118113203-912cbe2bfef3 h1:lV0+KGoNkvZOt4zGT4H83hQrzWMt/US/LSz4z4+BQS4= -k8s.io/api v0.0.0-20190118113203-912cbe2bfef3/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= +k8s.io/api v0.0.0-20190325185214-7544f9db76f6 h1:9MWtbqhwTyDvF4cS1qAhxDb9Mi8taXiAu+5nEacl7gY= +k8s.io/api v0.0.0-20190325185214-7544f9db76f6/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= k8s.io/apimachinery v0.0.0-20180821005732-488889b0007f/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= -k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512 h1:/Z1m/6oEN6hE2SzWP4BHW2yATeUrBRr+1GxNf1Ny58Y= -k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= +k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841 h1:Q4RZrHNtlC/mSdC1sTrcZ5RchC/9vxLVj57pWiCBKv4= +k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= k8s.io/client-go v8.0.0+incompatible h1:tTI4hRmb1DRMl4fG6Vclfdi6nTM82oIrTT7HfitmxC4= k8s.io/client-go v8.0.0+incompatible/go.mod h1:7vJpHMYJwNQCWgzmNV+VYUl1zCObLyodBc8nIyt8L5s= diff --git a/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go b/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go new file mode 100644 index 0000000000..593f653008 --- /dev/null +++ b/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go @@ -0,0 +1,77 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC +2898 / PKCS #5 v2.0. + +A key derivation function is useful when encrypting data based on a password +or any other not-fully-random data. It uses a pseudorandom function to derive +a secure encryption key based on the password. + +While v2.0 of the standard defines only one pseudorandom function to use, +HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved +Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To +choose, you can pass the `New` functions from the different SHA packages to +pbkdf2.Key. +*/ +package pbkdf2 // import "golang.org/x/crypto/pbkdf2" + +import ( + "crypto/hmac" + "hash" +) + +// Key derives a key from the password, salt and iteration count, returning a +// []byte of length keylen that can be used as cryptographic key. The key is +// derived based on the method described as PBKDF2 with the HMAC variant using +// the supplied hash function. +// +// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you +// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by +// doing: +// +// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New) +// +// Remember to get a good random salt. At least 8 bytes is recommended by the +// RFC. +// +// Using a higher iteration count will increase the cost of an exhaustive +// search but will also make derivation proportionally slower. +func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { + prf := hmac.New(h, password) + hashLen := prf.Size() + numBlocks := (keyLen + hashLen - 1) / hashLen + + var buf [4]byte + dk := make([]byte, 0, numBlocks*hashLen) + U := make([]byte, hashLen) + for block := 1; block <= numBlocks; block++ { + // N.B.: || means concatenation, ^ means XOR + // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter + // U_1 = PRF(password, salt || uint(i)) + prf.Reset() + prf.Write(salt) + buf[0] = byte(block >> 24) + buf[1] = byte(block >> 16) + buf[2] = byte(block >> 8) + buf[3] = byte(block) + prf.Write(buf[:4]) + dk = prf.Sum(dk) + T := dk[len(dk)-hashLen:] + copy(U, T) + + // U_n = PRF(password, U_(n-1)) + for n := 2; n <= iter; n++ { + prf.Reset() + prf.Write(U) + U = U[:0] + U = prf.Sum(U) + for x := range U { + T[x] ^= U[x] + } + } + } + return dk[:keyLen] +} diff --git a/vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc b/vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc new file mode 100644 index 0000000000..730e569b06 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/.gitcookies.sh.enc @@ -0,0 +1 @@ +'|Ê&{tÄU|gGê(ìCy=+¨œòcû:u:/pœ#~žü["±4¤!­nÙAªDK<ŠufÿhÅa¿Â:ºü¸¡´B/£Ø¤¹¤ò_hÎÛSãT*wÌx¼¯¹-ç|àÀÓƒÑÄäóÌ㣗A$$â6£ÁâG)8nÏpûÆË¡3ÌšœoïÏvŽB–3¿­]xÝ“Ó2l§G•|qRÞ¯ ö2 5R–Ó×Ç$´ñ½Yè¡ÞÝ™l‘Ë«yAI"ÛŒ˜®íû¹¼kÄ|Kåþ[9ÆâÒå=°úÿŸñ|@S•3 ó#æx?¾V„,¾‚SÆÝõœwPíogÒ6&V6 ©D.dBŠ 7 \ No newline at end of file diff --git a/vendor/gopkg.in/square/go-jose.v2/.gitignore b/vendor/gopkg.in/square/go-jose.v2/.gitignore new file mode 100644 index 0000000000..5b4d73b681 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/.gitignore @@ -0,0 +1,7 @@ +*~ +.*.swp +*.out +*.test +*.pem +*.cov +jose-util/jose-util diff --git a/vendor/gopkg.in/square/go-jose.v2/.travis.yml b/vendor/gopkg.in/square/go-jose.v2/.travis.yml new file mode 100644 index 0000000000..fc501ca9b7 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/.travis.yml @@ -0,0 +1,46 @@ +language: go + +sudo: false + +matrix: + fast_finish: true + allow_failures: + - go: tip + +go: +- '1.7.x' +- '1.8.x' +- '1.9.x' +- '1.10.x' +- '1.11.x' + +go_import_path: gopkg.in/square/go-jose.v2 + +before_script: +- export PATH=$HOME/.local/bin:$PATH + +before_install: +# Install encrypted gitcookies to get around bandwidth-limits +# that is causing Travis-CI builds to fail. For more info, see +# https://github.com/golang/go/issues/12933 +- openssl aes-256-cbc -K $encrypted_1528c3c2cafd_key -iv $encrypted_1528c3c2cafd_iv -in .gitcookies.sh.enc -out .gitcookies.sh -d || true +- bash .gitcookies.sh || true +- go get github.com/wadey/gocovmerge +- go get github.com/mattn/goveralls +- go get github.com/stretchr/testify/assert +- go get golang.org/x/tools/cmd/cover || true +- go get code.google.com/p/go.tools/cmd/cover || true +- pip install cram --user + +script: +- go test . -v -covermode=count -coverprofile=profile.cov +- go test ./cipher -v -covermode=count -coverprofile=cipher/profile.cov +- go test ./jwt -v -covermode=count -coverprofile=jwt/profile.cov +- go test ./json -v # no coverage for forked encoding/json package +- cd jose-util && go build && PATH=$PWD:$PATH cram -v jose-util.t +- cd .. + +after_success: +- gocovmerge *.cov */*.cov > merged.coverprofile +- $HOME/gopath/bin/goveralls -coverprofile merged.coverprofile -service=travis-ci + diff --git a/vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md b/vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md new file mode 100644 index 0000000000..3305db0f65 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/BUG-BOUNTY.md @@ -0,0 +1,10 @@ +Serious about security +====================== + +Square recognizes the important contributions the security research community +can make. We therefore encourage reporting security issues with the code +contained in this repository. + +If you believe you have discovered a security vulnerability, please follow the +guidelines at . + diff --git a/vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md b/vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md new file mode 100644 index 0000000000..61b183651c --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/CONTRIBUTING.md @@ -0,0 +1,14 @@ +# Contributing + +If you would like to contribute code to go-jose you can do so through GitHub by +forking the repository and sending a pull request. + +When submitting code, please make every effort to follow existing conventions +and style in order to keep the code as readable as possible. Please also make +sure all tests pass by running `go test`, and format your code with `go fmt`. +We also recommend using `golint` and `errcheck`. + +Before your code can be accepted into the project you must also sign the +[Individual Contributor License Agreement][1]. + + [1]: https://spreadsheets.google.com/spreadsheet/viewform?formkey=dDViT2xzUHAwRkI3X3k5Z0lQM091OGc6MQ&ndplr=1 diff --git a/vendor/gopkg.in/square/go-jose.v2/LICENSE b/vendor/gopkg.in/square/go-jose.v2/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/gopkg.in/square/go-jose.v2/README.md b/vendor/gopkg.in/square/go-jose.v2/README.md new file mode 100644 index 0000000000..1791bfa8f6 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/README.md @@ -0,0 +1,118 @@ +# Go JOSE + +[![godoc](http://img.shields.io/badge/godoc-version_1-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v1) +[![godoc](http://img.shields.io/badge/godoc-version_2-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v2) +[![license](http://img.shields.io/badge/license-apache_2.0-blue.svg?style=flat)](https://raw.githubusercontent.com/square/go-jose/master/LICENSE) +[![build](https://travis-ci.org/square/go-jose.svg?branch=v2)](https://travis-ci.org/square/go-jose) +[![coverage](https://coveralls.io/repos/github/square/go-jose/badge.svg?branch=v2)](https://coveralls.io/r/square/go-jose) + +Package jose aims to provide an implementation of the Javascript Object Signing +and Encryption set of standards. This includes support for JSON Web Encryption, +JSON Web Signature, and JSON Web Token standards. + +**Disclaimer**: This library contains encryption software that is subject to +the U.S. Export Administration Regulations. You may not export, re-export, +transfer or download this code or any part of it in violation of any United +States law, directive or regulation. In particular this software may not be +exported or re-exported in any form or on any media to Iran, North Sudan, +Syria, Cuba, or North Korea, or to denied persons or entities mentioned on any +US maintained blocked list. + +## Overview + +The implementation follows the +[JSON Web Encryption](http://dx.doi.org/10.17487/RFC7516) (RFC 7516), +[JSON Web Signature](http://dx.doi.org/10.17487/RFC7515) (RFC 7515), and +[JSON Web Token](http://dx.doi.org/10.17487/RFC7519) (RFC 7519). +Tables of supported algorithms are shown below. The library supports both +the compact and full serialization formats, and has optional support for +multiple recipients. It also comes with a small command-line utility +([`jose-util`](https://github.com/square/go-jose/tree/v2/jose-util)) +for dealing with JOSE messages in a shell. + +**Note**: We use a forked version of the `encoding/json` package from the Go +standard library which uses case-sensitive matching for member names (instead +of [case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html)). +This is to avoid differences in interpretation of messages between go-jose and +libraries in other languages. + +### Versions + +We use [gopkg.in](https://gopkg.in) for versioning. + +[Version 2](https://gopkg.in/square/go-jose.v2) +([branch](https://github.com/square/go-jose/tree/v2), +[doc](https://godoc.org/gopkg.in/square/go-jose.v2)) is the current version: + + import "gopkg.in/square/go-jose.v2" + +The old `v1` branch ([go-jose.v1](https://gopkg.in/square/go-jose.v1)) will +still receive backported bug fixes and security fixes, but otherwise +development is frozen. All new feature development takes place on the `v2` +branch. Version 2 also contains additional sub-packages such as the +[jwt](https://godoc.org/gopkg.in/square/go-jose.v2/jwt) implementation +contributed by [@shaxbee](https://github.com/shaxbee). + +### Supported algorithms + +See below for a table of supported algorithms. Algorithm identifiers match +the names in the [JSON Web Algorithms](http://dx.doi.org/10.17487/RFC7518) +standard where possible. The Godoc reference has a list of constants. + + Key encryption | Algorithm identifier(s) + :------------------------- | :------------------------------ + RSA-PKCS#1v1.5 | RSA1_5 + RSA-OAEP | RSA-OAEP, RSA-OAEP-256 + AES key wrap | A128KW, A192KW, A256KW + AES-GCM key wrap | A128GCMKW, A192GCMKW, A256GCMKW + ECDH-ES + AES key wrap | ECDH-ES+A128KW, ECDH-ES+A192KW, ECDH-ES+A256KW + ECDH-ES (direct) | ECDH-ES1 + Direct encryption | dir1 + +1. Not supported in multi-recipient mode + + Signing / MAC | Algorithm identifier(s) + :------------------------- | :------------------------------ + RSASSA-PKCS#1v1.5 | RS256, RS384, RS512 + RSASSA-PSS | PS256, PS384, PS512 + HMAC | HS256, HS384, HS512 + ECDSA | ES256, ES384, ES512 + Ed25519 | EdDSA2 + +2. Only available in version 2 of the package + + Content encryption | Algorithm identifier(s) + :------------------------- | :------------------------------ + AES-CBC+HMAC | A128CBC-HS256, A192CBC-HS384, A256CBC-HS512 + AES-GCM | A128GCM, A192GCM, A256GCM + + Compression | Algorithm identifiers(s) + :------------------------- | ------------------------------- + DEFLATE (RFC 1951) | DEF + +### Supported key types + +See below for a table of supported key types. These are understood by the +library, and can be passed to corresponding functions such as `NewEncrypter` or +`NewSigner`. Each of these keys can also be wrapped in a JWK if desired, which +allows attaching a key id. + + Algorithm(s) | Corresponding types + :------------------------- | ------------------------------- + RSA | *[rsa.PublicKey](http://golang.org/pkg/crypto/rsa/#PublicKey), *[rsa.PrivateKey](http://golang.org/pkg/crypto/rsa/#PrivateKey) + ECDH, ECDSA | *[ecdsa.PublicKey](http://golang.org/pkg/crypto/ecdsa/#PublicKey), *[ecdsa.PrivateKey](http://golang.org/pkg/crypto/ecdsa/#PrivateKey) + EdDSA1 | [ed25519.PublicKey](https://godoc.org/golang.org/x/crypto/ed25519#PublicKey), [ed25519.PrivateKey](https://godoc.org/golang.org/x/crypto/ed25519#PrivateKey) + AES, HMAC | []byte + +1. Only available in version 2 of the package + +## Examples + +[![godoc](http://img.shields.io/badge/godoc-version_1-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v1) +[![godoc](http://img.shields.io/badge/godoc-version_2-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v2) + +Examples can be found in the Godoc +reference for this package. The +[`jose-util`](https://github.com/square/go-jose/tree/v2/jose-util) +subdirectory also contains a small command-line utility which might be useful +as an example. diff --git a/vendor/gopkg.in/square/go-jose.v2/asymmetric.go b/vendor/gopkg.in/square/go-jose.v2/asymmetric.go new file mode 100644 index 0000000000..67935561bc --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/asymmetric.go @@ -0,0 +1,592 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package jose + +import ( + "crypto" + "crypto/aes" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "errors" + "fmt" + "math/big" + + "golang.org/x/crypto/ed25519" + "gopkg.in/square/go-jose.v2/cipher" + "gopkg.in/square/go-jose.v2/json" +) + +// A generic RSA-based encrypter/verifier +type rsaEncrypterVerifier struct { + publicKey *rsa.PublicKey +} + +// A generic RSA-based decrypter/signer +type rsaDecrypterSigner struct { + privateKey *rsa.PrivateKey +} + +// A generic EC-based encrypter/verifier +type ecEncrypterVerifier struct { + publicKey *ecdsa.PublicKey +} + +type edEncrypterVerifier struct { + publicKey ed25519.PublicKey +} + +// A key generator for ECDH-ES +type ecKeyGenerator struct { + size int + algID string + publicKey *ecdsa.PublicKey +} + +// A generic EC-based decrypter/signer +type ecDecrypterSigner struct { + privateKey *ecdsa.PrivateKey +} + +type edDecrypterSigner struct { + privateKey ed25519.PrivateKey +} + +// newRSARecipient creates recipientKeyInfo based on the given key. +func newRSARecipient(keyAlg KeyAlgorithm, publicKey *rsa.PublicKey) (recipientKeyInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch keyAlg { + case RSA1_5, RSA_OAEP, RSA_OAEP_256: + default: + return recipientKeyInfo{}, ErrUnsupportedAlgorithm + } + + if publicKey == nil { + return recipientKeyInfo{}, errors.New("invalid public key") + } + + return recipientKeyInfo{ + keyAlg: keyAlg, + keyEncrypter: &rsaEncrypterVerifier{ + publicKey: publicKey, + }, + }, nil +} + +// newRSASigner creates a recipientSigInfo based on the given key. +func newRSASigner(sigAlg SignatureAlgorithm, privateKey *rsa.PrivateKey) (recipientSigInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch sigAlg { + case RS256, RS384, RS512, PS256, PS384, PS512: + default: + return recipientSigInfo{}, ErrUnsupportedAlgorithm + } + + if privateKey == nil { + return recipientSigInfo{}, errors.New("invalid private key") + } + + return recipientSigInfo{ + sigAlg: sigAlg, + publicKey: staticPublicKey(&JSONWebKey{ + Key: privateKey.Public(), + }), + signer: &rsaDecrypterSigner{ + privateKey: privateKey, + }, + }, nil +} + +func newEd25519Signer(sigAlg SignatureAlgorithm, privateKey ed25519.PrivateKey) (recipientSigInfo, error) { + if sigAlg != EdDSA { + return recipientSigInfo{}, ErrUnsupportedAlgorithm + } + + if privateKey == nil { + return recipientSigInfo{}, errors.New("invalid private key") + } + return recipientSigInfo{ + sigAlg: sigAlg, + publicKey: staticPublicKey(&JSONWebKey{ + Key: privateKey.Public(), + }), + signer: &edDecrypterSigner{ + privateKey: privateKey, + }, + }, nil +} + +// newECDHRecipient creates recipientKeyInfo based on the given key. +func newECDHRecipient(keyAlg KeyAlgorithm, publicKey *ecdsa.PublicKey) (recipientKeyInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch keyAlg { + case ECDH_ES, ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW: + default: + return recipientKeyInfo{}, ErrUnsupportedAlgorithm + } + + if publicKey == nil || !publicKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) { + return recipientKeyInfo{}, errors.New("invalid public key") + } + + return recipientKeyInfo{ + keyAlg: keyAlg, + keyEncrypter: &ecEncrypterVerifier{ + publicKey: publicKey, + }, + }, nil +} + +// newECDSASigner creates a recipientSigInfo based on the given key. +func newECDSASigner(sigAlg SignatureAlgorithm, privateKey *ecdsa.PrivateKey) (recipientSigInfo, error) { + // Verify that key management algorithm is supported by this encrypter + switch sigAlg { + case ES256, ES384, ES512: + default: + return recipientSigInfo{}, ErrUnsupportedAlgorithm + } + + if privateKey == nil { + return recipientSigInfo{}, errors.New("invalid private key") + } + + return recipientSigInfo{ + sigAlg: sigAlg, + publicKey: staticPublicKey(&JSONWebKey{ + Key: privateKey.Public(), + }), + signer: &ecDecrypterSigner{ + privateKey: privateKey, + }, + }, nil +} + +// Encrypt the given payload and update the object. +func (ctx rsaEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) { + encryptedKey, err := ctx.encrypt(cek, alg) + if err != nil { + return recipientInfo{}, err + } + + return recipientInfo{ + encryptedKey: encryptedKey, + header: &rawHeader{}, + }, nil +} + +// Encrypt the given payload. Based on the key encryption algorithm, +// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256). +func (ctx rsaEncrypterVerifier) encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) { + switch alg { + case RSA1_5: + return rsa.EncryptPKCS1v15(RandReader, ctx.publicKey, cek) + case RSA_OAEP: + return rsa.EncryptOAEP(sha1.New(), RandReader, ctx.publicKey, cek, []byte{}) + case RSA_OAEP_256: + return rsa.EncryptOAEP(sha256.New(), RandReader, ctx.publicKey, cek, []byte{}) + } + + return nil, ErrUnsupportedAlgorithm +} + +// Decrypt the given payload and return the content encryption key. +func (ctx rsaDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) { + return ctx.decrypt(recipient.encryptedKey, headers.getAlgorithm(), generator) +} + +// Decrypt the given payload. Based on the key encryption algorithm, +// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256). +func (ctx rsaDecrypterSigner) decrypt(jek []byte, alg KeyAlgorithm, generator keyGenerator) ([]byte, error) { + // Note: The random reader on decrypt operations is only used for blinding, + // so stubbing is meanlingless (hence the direct use of rand.Reader). + switch alg { + case RSA1_5: + defer func() { + // DecryptPKCS1v15SessionKey sometimes panics on an invalid payload + // because of an index out of bounds error, which we want to ignore. + // This has been fixed in Go 1.3.1 (released 2014/08/13), the recover() + // only exists for preventing crashes with unpatched versions. + // See: https://groups.google.com/forum/#!topic/golang-dev/7ihX6Y6kx9k + // See: https://code.google.com/p/go/source/detail?r=58ee390ff31602edb66af41ed10901ec95904d33 + _ = recover() + }() + + // Perform some input validation. + keyBytes := ctx.privateKey.PublicKey.N.BitLen() / 8 + if keyBytes != len(jek) { + // Input size is incorrect, the encrypted payload should always match + // the size of the public modulus (e.g. using a 2048 bit key will + // produce 256 bytes of output). Reject this since it's invalid input. + return nil, ErrCryptoFailure + } + + cek, _, err := generator.genKey() + if err != nil { + return nil, ErrCryptoFailure + } + + // When decrypting an RSA-PKCS1v1.5 payload, we must take precautions to + // prevent chosen-ciphertext attacks as described in RFC 3218, "Preventing + // the Million Message Attack on Cryptographic Message Syntax". We are + // therefore deliberately ignoring errors here. + _ = rsa.DecryptPKCS1v15SessionKey(rand.Reader, ctx.privateKey, jek, cek) + + return cek, nil + case RSA_OAEP: + // Use rand.Reader for RSA blinding + return rsa.DecryptOAEP(sha1.New(), rand.Reader, ctx.privateKey, jek, []byte{}) + case RSA_OAEP_256: + // Use rand.Reader for RSA blinding + return rsa.DecryptOAEP(sha256.New(), rand.Reader, ctx.privateKey, jek, []byte{}) + } + + return nil, ErrUnsupportedAlgorithm +} + +// Sign the given payload +func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) { + var hash crypto.Hash + + switch alg { + case RS256, PS256: + hash = crypto.SHA256 + case RS384, PS384: + hash = crypto.SHA384 + case RS512, PS512: + hash = crypto.SHA512 + default: + return Signature{}, ErrUnsupportedAlgorithm + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + var out []byte + var err error + + switch alg { + case RS256, RS384, RS512: + out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed) + case PS256, PS384, PS512: + out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + }) + } + + if err != nil { + return Signature{}, err + } + + return Signature{ + Signature: out, + protected: &rawHeader{}, + }, nil +} + +// Verify the given payload +func (ctx rsaEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error { + var hash crypto.Hash + + switch alg { + case RS256, PS256: + hash = crypto.SHA256 + case RS384, PS384: + hash = crypto.SHA384 + case RS512, PS512: + hash = crypto.SHA512 + default: + return ErrUnsupportedAlgorithm + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + switch alg { + case RS256, RS384, RS512: + return rsa.VerifyPKCS1v15(ctx.publicKey, hash, hashed, signature) + case PS256, PS384, PS512: + return rsa.VerifyPSS(ctx.publicKey, hash, hashed, signature, nil) + } + + return ErrUnsupportedAlgorithm +} + +// Encrypt the given payload and update the object. +func (ctx ecEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) { + switch alg { + case ECDH_ES: + // ECDH-ES mode doesn't wrap a key, the shared secret is used directly as the key. + return recipientInfo{ + header: &rawHeader{}, + }, nil + case ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW: + default: + return recipientInfo{}, ErrUnsupportedAlgorithm + } + + generator := ecKeyGenerator{ + algID: string(alg), + publicKey: ctx.publicKey, + } + + switch alg { + case ECDH_ES_A128KW: + generator.size = 16 + case ECDH_ES_A192KW: + generator.size = 24 + case ECDH_ES_A256KW: + generator.size = 32 + } + + kek, header, err := generator.genKey() + if err != nil { + return recipientInfo{}, err + } + + block, err := aes.NewCipher(kek) + if err != nil { + return recipientInfo{}, err + } + + jek, err := josecipher.KeyWrap(block, cek) + if err != nil { + return recipientInfo{}, err + } + + return recipientInfo{ + encryptedKey: jek, + header: &header, + }, nil +} + +// Get key size for EC key generator +func (ctx ecKeyGenerator) keySize() int { + return ctx.size +} + +// Get a content encryption key for ECDH-ES +func (ctx ecKeyGenerator) genKey() ([]byte, rawHeader, error) { + priv, err := ecdsa.GenerateKey(ctx.publicKey.Curve, RandReader) + if err != nil { + return nil, rawHeader{}, err + } + + out := josecipher.DeriveECDHES(ctx.algID, []byte{}, []byte{}, priv, ctx.publicKey, ctx.size) + + b, err := json.Marshal(&JSONWebKey{ + Key: &priv.PublicKey, + }) + if err != nil { + return nil, nil, err + } + + headers := rawHeader{ + headerEPK: makeRawMessage(b), + } + + return out, headers, nil +} + +// Decrypt the given payload and return the content encryption key. +func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) { + epk, err := headers.getEPK() + if err != nil { + return nil, errors.New("square/go-jose: invalid epk header") + } + if epk == nil { + return nil, errors.New("square/go-jose: missing epk header") + } + + publicKey, ok := epk.Key.(*ecdsa.PublicKey) + if publicKey == nil || !ok { + return nil, errors.New("square/go-jose: invalid epk header") + } + + if !ctx.privateKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) { + return nil, errors.New("square/go-jose: invalid public key in epk header") + } + + apuData, err := headers.getAPU() + if err != nil { + return nil, errors.New("square/go-jose: invalid apu header") + } + apvData, err := headers.getAPV() + if err != nil { + return nil, errors.New("square/go-jose: invalid apv header") + } + + deriveKey := func(algID string, size int) []byte { + return josecipher.DeriveECDHES(algID, apuData.bytes(), apvData.bytes(), ctx.privateKey, publicKey, size) + } + + var keySize int + + algorithm := headers.getAlgorithm() + switch algorithm { + case ECDH_ES: + // ECDH-ES uses direct key agreement, no key unwrapping necessary. + return deriveKey(string(headers.getEncryption()), generator.keySize()), nil + case ECDH_ES_A128KW: + keySize = 16 + case ECDH_ES_A192KW: + keySize = 24 + case ECDH_ES_A256KW: + keySize = 32 + default: + return nil, ErrUnsupportedAlgorithm + } + + key := deriveKey(string(algorithm), keySize) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + return josecipher.KeyUnwrap(block, recipient.encryptedKey) +} + +func (ctx edDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) { + if alg != EdDSA { + return Signature{}, ErrUnsupportedAlgorithm + } + + sig, err := ctx.privateKey.Sign(RandReader, payload, crypto.Hash(0)) + if err != nil { + return Signature{}, err + } + + return Signature{ + Signature: sig, + protected: &rawHeader{}, + }, nil +} + +func (ctx edEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error { + if alg != EdDSA { + return ErrUnsupportedAlgorithm + } + ok := ed25519.Verify(ctx.publicKey, payload, signature) + if !ok { + return errors.New("square/go-jose: ed25519 signature failed to verify") + } + return nil +} + +// Sign the given payload +func (ctx ecDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) { + var expectedBitSize int + var hash crypto.Hash + + switch alg { + case ES256: + expectedBitSize = 256 + hash = crypto.SHA256 + case ES384: + expectedBitSize = 384 + hash = crypto.SHA384 + case ES512: + expectedBitSize = 521 + hash = crypto.SHA512 + } + + curveBits := ctx.privateKey.Curve.Params().BitSize + if expectedBitSize != curveBits { + return Signature{}, fmt.Errorf("square/go-jose: expected %d bit key, got %d bits instead", expectedBitSize, curveBits) + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + r, s, err := ecdsa.Sign(RandReader, ctx.privateKey, hashed) + if err != nil { + return Signature{}, err + } + + keyBytes := curveBits / 8 + if curveBits%8 > 0 { + keyBytes++ + } + + // We serialize the outputs (r and s) into big-endian byte arrays and pad + // them with zeros on the left to make sure the sizes work out. Both arrays + // must be keyBytes long, and the output must be 2*keyBytes long. + rBytes := r.Bytes() + rBytesPadded := make([]byte, keyBytes) + copy(rBytesPadded[keyBytes-len(rBytes):], rBytes) + + sBytes := s.Bytes() + sBytesPadded := make([]byte, keyBytes) + copy(sBytesPadded[keyBytes-len(sBytes):], sBytes) + + out := append(rBytesPadded, sBytesPadded...) + + return Signature{ + Signature: out, + protected: &rawHeader{}, + }, nil +} + +// Verify the given payload +func (ctx ecEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error { + var keySize int + var hash crypto.Hash + + switch alg { + case ES256: + keySize = 32 + hash = crypto.SHA256 + case ES384: + keySize = 48 + hash = crypto.SHA384 + case ES512: + keySize = 66 + hash = crypto.SHA512 + default: + return ErrUnsupportedAlgorithm + } + + if len(signature) != 2*keySize { + return fmt.Errorf("square/go-jose: invalid signature size, have %d bytes, wanted %d", len(signature), 2*keySize) + } + + hasher := hash.New() + + // According to documentation, Write() on hash never fails + _, _ = hasher.Write(payload) + hashed := hasher.Sum(nil) + + r := big.NewInt(0).SetBytes(signature[:keySize]) + s := big.NewInt(0).SetBytes(signature[keySize:]) + + match := ecdsa.Verify(ctx.publicKey, hashed, r, s) + if !match { + return errors.New("square/go-jose: ecdsa signature failed to verify") + } + + return nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go b/vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go new file mode 100644 index 0000000000..126b85ce25 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/cbc_hmac.go @@ -0,0 +1,196 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "bytes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/sha512" + "crypto/subtle" + "encoding/binary" + "errors" + "hash" +) + +const ( + nonceBytes = 16 +) + +// NewCBCHMAC instantiates a new AEAD based on CBC+HMAC. +func NewCBCHMAC(key []byte, newBlockCipher func([]byte) (cipher.Block, error)) (cipher.AEAD, error) { + keySize := len(key) / 2 + integrityKey := key[:keySize] + encryptionKey := key[keySize:] + + blockCipher, err := newBlockCipher(encryptionKey) + if err != nil { + return nil, err + } + + var hash func() hash.Hash + switch keySize { + case 16: + hash = sha256.New + case 24: + hash = sha512.New384 + case 32: + hash = sha512.New + } + + return &cbcAEAD{ + hash: hash, + blockCipher: blockCipher, + authtagBytes: keySize, + integrityKey: integrityKey, + }, nil +} + +// An AEAD based on CBC+HMAC +type cbcAEAD struct { + hash func() hash.Hash + authtagBytes int + integrityKey []byte + blockCipher cipher.Block +} + +func (ctx *cbcAEAD) NonceSize() int { + return nonceBytes +} + +func (ctx *cbcAEAD) Overhead() int { + // Maximum overhead is block size (for padding) plus auth tag length, where + // the length of the auth tag is equivalent to the key size. + return ctx.blockCipher.BlockSize() + ctx.authtagBytes +} + +// Seal encrypts and authenticates the plaintext. +func (ctx *cbcAEAD) Seal(dst, nonce, plaintext, data []byte) []byte { + // Output buffer -- must take care not to mangle plaintext input. + ciphertext := make([]byte, uint64(len(plaintext))+uint64(ctx.Overhead()))[:len(plaintext)] + copy(ciphertext, plaintext) + ciphertext = padBuffer(ciphertext, ctx.blockCipher.BlockSize()) + + cbc := cipher.NewCBCEncrypter(ctx.blockCipher, nonce) + + cbc.CryptBlocks(ciphertext, ciphertext) + authtag := ctx.computeAuthTag(data, nonce, ciphertext) + + ret, out := resize(dst, uint64(len(dst))+uint64(len(ciphertext))+uint64(len(authtag))) + copy(out, ciphertext) + copy(out[len(ciphertext):], authtag) + + return ret +} + +// Open decrypts and authenticates the ciphertext. +func (ctx *cbcAEAD) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { + if len(ciphertext) < ctx.authtagBytes { + return nil, errors.New("square/go-jose: invalid ciphertext (too short)") + } + + offset := len(ciphertext) - ctx.authtagBytes + expectedTag := ctx.computeAuthTag(data, nonce, ciphertext[:offset]) + match := subtle.ConstantTimeCompare(expectedTag, ciphertext[offset:]) + if match != 1 { + return nil, errors.New("square/go-jose: invalid ciphertext (auth tag mismatch)") + } + + cbc := cipher.NewCBCDecrypter(ctx.blockCipher, nonce) + + // Make copy of ciphertext buffer, don't want to modify in place + buffer := append([]byte{}, []byte(ciphertext[:offset])...) + + if len(buffer)%ctx.blockCipher.BlockSize() > 0 { + return nil, errors.New("square/go-jose: invalid ciphertext (invalid length)") + } + + cbc.CryptBlocks(buffer, buffer) + + // Remove padding + plaintext, err := unpadBuffer(buffer, ctx.blockCipher.BlockSize()) + if err != nil { + return nil, err + } + + ret, out := resize(dst, uint64(len(dst))+uint64(len(plaintext))) + copy(out, plaintext) + + return ret, nil +} + +// Compute an authentication tag +func (ctx *cbcAEAD) computeAuthTag(aad, nonce, ciphertext []byte) []byte { + buffer := make([]byte, uint64(len(aad))+uint64(len(nonce))+uint64(len(ciphertext))+8) + n := 0 + n += copy(buffer, aad) + n += copy(buffer[n:], nonce) + n += copy(buffer[n:], ciphertext) + binary.BigEndian.PutUint64(buffer[n:], uint64(len(aad))*8) + + // According to documentation, Write() on hash.Hash never fails. + hmac := hmac.New(ctx.hash, ctx.integrityKey) + _, _ = hmac.Write(buffer) + + return hmac.Sum(nil)[:ctx.authtagBytes] +} + +// resize ensures the the given slice has a capacity of at least n bytes. +// If the capacity of the slice is less than n, a new slice is allocated +// and the existing data will be copied. +func resize(in []byte, n uint64) (head, tail []byte) { + if uint64(cap(in)) >= n { + head = in[:n] + } else { + head = make([]byte, n) + copy(head, in) + } + + tail = head[len(in):] + return +} + +// Apply padding +func padBuffer(buffer []byte, blockSize int) []byte { + missing := blockSize - (len(buffer) % blockSize) + ret, out := resize(buffer, uint64(len(buffer))+uint64(missing)) + padding := bytes.Repeat([]byte{byte(missing)}, missing) + copy(out, padding) + return ret +} + +// Remove padding +func unpadBuffer(buffer []byte, blockSize int) ([]byte, error) { + if len(buffer)%blockSize != 0 { + return nil, errors.New("square/go-jose: invalid padding") + } + + last := buffer[len(buffer)-1] + count := int(last) + + if count == 0 || count > blockSize || count > len(buffer) { + return nil, errors.New("square/go-jose: invalid padding") + } + + padding := bytes.Repeat([]byte{last}, count) + if !bytes.HasSuffix(buffer, padding) { + return nil, errors.New("square/go-jose: invalid padding") + } + + return buffer[:len(buffer)-count], nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go b/vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go new file mode 100644 index 0000000000..f62c3bdba5 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/concat_kdf.go @@ -0,0 +1,75 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "crypto" + "encoding/binary" + "hash" + "io" +) + +type concatKDF struct { + z, info []byte + i uint32 + cache []byte + hasher hash.Hash +} + +// NewConcatKDF builds a KDF reader based on the given inputs. +func NewConcatKDF(hash crypto.Hash, z, algID, ptyUInfo, ptyVInfo, supPubInfo, supPrivInfo []byte) io.Reader { + buffer := make([]byte, uint64(len(algID))+uint64(len(ptyUInfo))+uint64(len(ptyVInfo))+uint64(len(supPubInfo))+uint64(len(supPrivInfo))) + n := 0 + n += copy(buffer, algID) + n += copy(buffer[n:], ptyUInfo) + n += copy(buffer[n:], ptyVInfo) + n += copy(buffer[n:], supPubInfo) + copy(buffer[n:], supPrivInfo) + + hasher := hash.New() + + return &concatKDF{ + z: z, + info: buffer, + hasher: hasher, + cache: []byte{}, + i: 1, + } +} + +func (ctx *concatKDF) Read(out []byte) (int, error) { + copied := copy(out, ctx.cache) + ctx.cache = ctx.cache[copied:] + + for copied < len(out) { + ctx.hasher.Reset() + + // Write on a hash.Hash never fails + _ = binary.Write(ctx.hasher, binary.BigEndian, ctx.i) + _, _ = ctx.hasher.Write(ctx.z) + _, _ = ctx.hasher.Write(ctx.info) + + hash := ctx.hasher.Sum(nil) + chunkCopied := copy(out[copied:], hash) + copied += chunkCopied + ctx.cache = hash[chunkCopied:] + + ctx.i++ + } + + return copied, nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go b/vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go new file mode 100644 index 0000000000..c128e327f3 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/ecdh_es.go @@ -0,0 +1,62 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "crypto" + "crypto/ecdsa" + "encoding/binary" +) + +// DeriveECDHES derives a shared encryption key using ECDH/ConcatKDF as described in JWE/JWA. +// It is an error to call this function with a private/public key that are not on the same +// curve. Callers must ensure that the keys are valid before calling this function. Output +// size may be at most 1<<16 bytes (64 KiB). +func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, size int) []byte { + if size > 1<<16 { + panic("ECDH-ES output size too large, must be less than or equal to 1<<16") + } + + // algId, partyUInfo, partyVInfo inputs must be prefixed with the length + algID := lengthPrefixed([]byte(alg)) + ptyUInfo := lengthPrefixed(apuData) + ptyVInfo := lengthPrefixed(apvData) + + // suppPubInfo is the encoded length of the output size in bits + supPubInfo := make([]byte, 4) + binary.BigEndian.PutUint32(supPubInfo, uint32(size)*8) + + if !priv.PublicKey.Curve.IsOnCurve(pub.X, pub.Y) { + panic("public key not on same curve as private key") + } + + z, _ := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes()) + reader := NewConcatKDF(crypto.SHA256, z.Bytes(), algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{}) + + key := make([]byte, size) + + // Read on the KDF will never fail + _, _ = reader.Read(key) + return key +} + +func lengthPrefixed(data []byte) []byte { + out := make([]byte, len(data)+4) + binary.BigEndian.PutUint32(out, uint32(len(data))) + copy(out[4:], data) + return out +} diff --git a/vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go b/vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go new file mode 100644 index 0000000000..1d36d50151 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/cipher/key_wrap.go @@ -0,0 +1,109 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package josecipher + +import ( + "crypto/cipher" + "crypto/subtle" + "encoding/binary" + "errors" +) + +var defaultIV = []byte{0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6} + +// KeyWrap implements NIST key wrapping; it wraps a content encryption key (cek) with the given block cipher. +func KeyWrap(block cipher.Block, cek []byte) ([]byte, error) { + if len(cek)%8 != 0 { + return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks") + } + + n := len(cek) / 8 + r := make([][]byte, n) + + for i := range r { + r[i] = make([]byte, 8) + copy(r[i], cek[i*8:]) + } + + buffer := make([]byte, 16) + tBytes := make([]byte, 8) + copy(buffer, defaultIV) + + for t := 0; t < 6*n; t++ { + copy(buffer[8:], r[t%n]) + + block.Encrypt(buffer, buffer) + + binary.BigEndian.PutUint64(tBytes, uint64(t+1)) + + for i := 0; i < 8; i++ { + buffer[i] = buffer[i] ^ tBytes[i] + } + copy(r[t%n], buffer[8:]) + } + + out := make([]byte, (n+1)*8) + copy(out, buffer[:8]) + for i := range r { + copy(out[(i+1)*8:], r[i]) + } + + return out, nil +} + +// KeyUnwrap implements NIST key unwrapping; it unwraps a content encryption key (cek) with the given block cipher. +func KeyUnwrap(block cipher.Block, ciphertext []byte) ([]byte, error) { + if len(ciphertext)%8 != 0 { + return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks") + } + + n := (len(ciphertext) / 8) - 1 + r := make([][]byte, n) + + for i := range r { + r[i] = make([]byte, 8) + copy(r[i], ciphertext[(i+1)*8:]) + } + + buffer := make([]byte, 16) + tBytes := make([]byte, 8) + copy(buffer[:8], ciphertext[:8]) + + for t := 6*n - 1; t >= 0; t-- { + binary.BigEndian.PutUint64(tBytes, uint64(t+1)) + + for i := 0; i < 8; i++ { + buffer[i] = buffer[i] ^ tBytes[i] + } + copy(buffer[8:], r[t%n]) + + block.Decrypt(buffer, buffer) + + copy(r[t%n], buffer[8:]) + } + + if subtle.ConstantTimeCompare(buffer[:8], defaultIV) == 0 { + return nil, errors.New("square/go-jose: failed to unwrap key") + } + + out := make([]byte, n*8) + for i := range r { + copy(out[i*8:], r[i]) + } + + return out, nil +} diff --git a/vendor/gopkg.in/square/go-jose.v2/crypter.go b/vendor/gopkg.in/square/go-jose.v2/crypter.go new file mode 100644 index 0000000000..c45c71206b --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/crypter.go @@ -0,0 +1,535 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package jose + +import ( + "crypto/ecdsa" + "crypto/rsa" + "errors" + "fmt" + "reflect" + + "gopkg.in/square/go-jose.v2/json" +) + +// Encrypter represents an encrypter which produces an encrypted JWE object. +type Encrypter interface { + Encrypt(plaintext []byte) (*JSONWebEncryption, error) + EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error) + Options() EncrypterOptions +} + +// A generic content cipher +type contentCipher interface { + keySize() int + encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error) + decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error) +} + +// A key generator (for generating/getting a CEK) +type keyGenerator interface { + keySize() int + genKey() ([]byte, rawHeader, error) +} + +// A generic key encrypter +type keyEncrypter interface { + encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key +} + +// A generic key decrypter +type keyDecrypter interface { + decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key +} + +// A generic encrypter based on the given key encrypter and content cipher. +type genericEncrypter struct { + contentAlg ContentEncryption + compressionAlg CompressionAlgorithm + cipher contentCipher + recipients []recipientKeyInfo + keyGenerator keyGenerator + extraHeaders map[HeaderKey]interface{} +} + +type recipientKeyInfo struct { + keyID string + keyAlg KeyAlgorithm + keyEncrypter keyEncrypter +} + +// EncrypterOptions represents options that can be set on new encrypters. +type EncrypterOptions struct { + Compression CompressionAlgorithm + + // Optional map of additional keys to be inserted into the protected header + // of a JWS object. Some specifications which make use of JWS like to insert + // additional values here. All values must be JSON-serializable. + ExtraHeaders map[HeaderKey]interface{} +} + +// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it +// if necessary. It returns itself and so can be used in a fluent style. +func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions { + if eo.ExtraHeaders == nil { + eo.ExtraHeaders = map[HeaderKey]interface{}{} + } + eo.ExtraHeaders[k] = v + return eo +} + +// WithContentType adds a content type ("cty") header and returns the updated +// EncrypterOptions. +func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions { + return eo.WithHeader(HeaderContentType, contentType) +} + +// WithType adds a type ("typ") header and returns the updated EncrypterOptions. +func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions { + return eo.WithHeader(HeaderType, typ) +} + +// Recipient represents an algorithm/key to encrypt messages to. +// +// PBES2Count and PBES2Salt correspond with the "p2c" and "p2s" headers used +// on the password-based encryption algorithms PBES2-HS256+A128KW, +// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided a safe +// default of 100000 will be used for the count and a 128-bit random salt will +// be generated. +type Recipient struct { + Algorithm KeyAlgorithm + Key interface{} + KeyID string + PBES2Count int + PBES2Salt []byte +} + +// NewEncrypter creates an appropriate encrypter based on the key type +func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) { + encrypter := &genericEncrypter{ + contentAlg: enc, + recipients: []recipientKeyInfo{}, + cipher: getContentCipher(enc), + } + if opts != nil { + encrypter.compressionAlg = opts.Compression + encrypter.extraHeaders = opts.ExtraHeaders + } + + if encrypter.cipher == nil { + return nil, ErrUnsupportedAlgorithm + } + + var keyID string + var rawKey interface{} + switch encryptionKey := rcpt.Key.(type) { + case JSONWebKey: + keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key + case *JSONWebKey: + keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key + default: + rawKey = encryptionKey + } + + switch rcpt.Algorithm { + case DIRECT: + // Direct encryption mode must be treated differently + if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) { + return nil, ErrUnsupportedKeyType + } + if encrypter.cipher.keySize() != len(rawKey.([]byte)) { + return nil, ErrInvalidKeySize + } + encrypter.keyGenerator = staticKeyGenerator{ + key: rawKey.([]byte), + } + recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte)) + recipientInfo.keyID = keyID + if rcpt.KeyID != "" { + recipientInfo.keyID = rcpt.KeyID + } + encrypter.recipients = []recipientKeyInfo{recipientInfo} + return encrypter, nil + case ECDH_ES: + // ECDH-ES (w/o key wrapping) is similar to DIRECT mode + typeOf := reflect.TypeOf(rawKey) + if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) { + return nil, ErrUnsupportedKeyType + } + encrypter.keyGenerator = ecKeyGenerator{ + size: encrypter.cipher.keySize(), + algID: string(enc), + publicKey: rawKey.(*ecdsa.PublicKey), + } + recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey)) + recipientInfo.keyID = keyID + if rcpt.KeyID != "" { + recipientInfo.keyID = rcpt.KeyID + } + encrypter.recipients = []recipientKeyInfo{recipientInfo} + return encrypter, nil + default: + // Can just add a standard recipient + encrypter.keyGenerator = randomKeyGenerator{ + size: encrypter.cipher.keySize(), + } + err := encrypter.addRecipient(rcpt) + return encrypter, err + } +} + +// NewMultiEncrypter creates a multi-encrypter based on the given parameters +func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) { + cipher := getContentCipher(enc) + + if cipher == nil { + return nil, ErrUnsupportedAlgorithm + } + if rcpts == nil || len(rcpts) == 0 { + return nil, fmt.Errorf("square/go-jose: recipients is nil or empty") + } + + encrypter := &genericEncrypter{ + contentAlg: enc, + recipients: []recipientKeyInfo{}, + cipher: cipher, + keyGenerator: randomKeyGenerator{ + size: cipher.keySize(), + }, + } + + if opts != nil { + encrypter.compressionAlg = opts.Compression + } + + for _, recipient := range rcpts { + err := encrypter.addRecipient(recipient) + if err != nil { + return nil, err + } + } + + return encrypter, nil +} + +func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) { + var recipientInfo recipientKeyInfo + + switch recipient.Algorithm { + case DIRECT, ECDH_ES: + return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm) + } + + recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key) + if recipient.KeyID != "" { + recipientInfo.keyID = recipient.KeyID + } + + switch recipient.Algorithm { + case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW: + if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok { + sr.p2c = recipient.PBES2Count + sr.p2s = recipient.PBES2Salt + } + } + + if err == nil { + ctx.recipients = append(ctx.recipients, recipientInfo) + } + return err +} + +func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) { + switch encryptionKey := encryptionKey.(type) { + case *rsa.PublicKey: + return newRSARecipient(alg, encryptionKey) + case *ecdsa.PublicKey: + return newECDHRecipient(alg, encryptionKey) + case []byte: + return newSymmetricRecipient(alg, encryptionKey) + case string: + return newSymmetricRecipient(alg, []byte(encryptionKey)) + case *JSONWebKey: + recipient, err := makeJWERecipient(alg, encryptionKey.Key) + recipient.keyID = encryptionKey.KeyID + return recipient, err + default: + return recipientKeyInfo{}, ErrUnsupportedKeyType + } +} + +// newDecrypter creates an appropriate decrypter based on the key type +func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) { + switch decryptionKey := decryptionKey.(type) { + case *rsa.PrivateKey: + return &rsaDecrypterSigner{ + privateKey: decryptionKey, + }, nil + case *ecdsa.PrivateKey: + return &ecDecrypterSigner{ + privateKey: decryptionKey, + }, nil + case []byte: + return &symmetricKeyCipher{ + key: decryptionKey, + }, nil + case string: + return &symmetricKeyCipher{ + key: []byte(decryptionKey), + }, nil + case JSONWebKey: + return newDecrypter(decryptionKey.Key) + case *JSONWebKey: + return newDecrypter(decryptionKey.Key) + default: + return nil, ErrUnsupportedKeyType + } +} + +// Implementation of encrypt method producing a JWE object. +func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) { + return ctx.EncryptWithAuthData(plaintext, nil) +} + +// Implementation of encrypt method producing a JWE object. +func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) { + obj := &JSONWebEncryption{} + obj.aad = aad + + obj.protected = &rawHeader{} + err := obj.protected.set(headerEncryption, ctx.contentAlg) + if err != nil { + return nil, err + } + + obj.recipients = make([]recipientInfo, len(ctx.recipients)) + + if len(ctx.recipients) == 0 { + return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to") + } + + cek, headers, err := ctx.keyGenerator.genKey() + if err != nil { + return nil, err + } + + obj.protected.merge(&headers) + + for i, info := range ctx.recipients { + recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg) + if err != nil { + return nil, err + } + + err = recipient.header.set(headerAlgorithm, info.keyAlg) + if err != nil { + return nil, err + } + + if info.keyID != "" { + err = recipient.header.set(headerKeyID, info.keyID) + if err != nil { + return nil, err + } + } + obj.recipients[i] = recipient + } + + if len(ctx.recipients) == 1 { + // Move per-recipient headers into main protected header if there's + // only a single recipient. + obj.protected.merge(obj.recipients[0].header) + obj.recipients[0].header = nil + } + + if ctx.compressionAlg != NONE { + plaintext, err = compress(ctx.compressionAlg, plaintext) + if err != nil { + return nil, err + } + + err = obj.protected.set(headerCompression, ctx.compressionAlg) + if err != nil { + return nil, err + } + } + + for k, v := range ctx.extraHeaders { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + (*obj.protected)[k] = makeRawMessage(b) + } + + authData := obj.computeAuthData() + parts, err := ctx.cipher.encrypt(cek, authData, plaintext) + if err != nil { + return nil, err + } + + obj.iv = parts.iv + obj.ciphertext = parts.ciphertext + obj.tag = parts.tag + + return obj, nil +} + +func (ctx *genericEncrypter) Options() EncrypterOptions { + return EncrypterOptions{ + Compression: ctx.compressionAlg, + ExtraHeaders: ctx.extraHeaders, + } +} + +// Decrypt and validate the object and return the plaintext. Note that this +// function does not support multi-recipient, if you desire multi-recipient +// decryption use DecryptMulti instead. +func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) { + headers := obj.mergedHeaders(nil) + + if len(obj.recipients) > 1 { + return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one") + } + + critical, err := headers.getCritical() + if err != nil { + return nil, fmt.Errorf("square/go-jose: invalid crit header") + } + + if len(critical) > 0 { + return nil, fmt.Errorf("square/go-jose: unsupported crit header") + } + + decrypter, err := newDecrypter(decryptionKey) + if err != nil { + return nil, err + } + + cipher := getContentCipher(headers.getEncryption()) + if cipher == nil { + return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption())) + } + + generator := randomKeyGenerator{ + size: cipher.keySize(), + } + + parts := &aeadParts{ + iv: obj.iv, + ciphertext: obj.ciphertext, + tag: obj.tag, + } + + authData := obj.computeAuthData() + + var plaintext []byte + recipient := obj.recipients[0] + recipientHeaders := obj.mergedHeaders(&recipient) + + cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator) + if err == nil { + // Found a valid CEK -- let's try to decrypt. + plaintext, err = cipher.decrypt(cek, authData, parts) + } + + if plaintext == nil { + return nil, ErrCryptoFailure + } + + // The "zip" header parameter may only be present in the protected header. + if comp := obj.protected.getCompression(); comp != "" { + plaintext, err = decompress(comp, plaintext) + } + + return plaintext, err +} + +// DecryptMulti decrypts and validates the object and returns the plaintexts, +// with support for multiple recipients. It returns the index of the recipient +// for which the decryption was successful, the merged headers for that recipient, +// and the plaintext. +func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) { + globalHeaders := obj.mergedHeaders(nil) + + critical, err := globalHeaders.getCritical() + if err != nil { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header") + } + + if len(critical) > 0 { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header") + } + + decrypter, err := newDecrypter(decryptionKey) + if err != nil { + return -1, Header{}, nil, err + } + + encryption := globalHeaders.getEncryption() + cipher := getContentCipher(encryption) + if cipher == nil { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption)) + } + + generator := randomKeyGenerator{ + size: cipher.keySize(), + } + + parts := &aeadParts{ + iv: obj.iv, + ciphertext: obj.ciphertext, + tag: obj.tag, + } + + authData := obj.computeAuthData() + + index := -1 + var plaintext []byte + var headers rawHeader + + for i, recipient := range obj.recipients { + recipientHeaders := obj.mergedHeaders(&recipient) + + cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator) + if err == nil { + // Found a valid CEK -- let's try to decrypt. + plaintext, err = cipher.decrypt(cek, authData, parts) + if err == nil { + index = i + headers = recipientHeaders + break + } + } + } + + if plaintext == nil || err != nil { + return -1, Header{}, nil, ErrCryptoFailure + } + + // The "zip" header parameter may only be present in the protected header. + if comp := obj.protected.getCompression(); comp != "" { + plaintext, err = decompress(comp, plaintext) + } + + sanitized, err := headers.sanitized() + if err != nil { + return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err) + } + + return index, sanitized, plaintext, err +} diff --git a/vendor/gopkg.in/square/go-jose.v2/doc.go b/vendor/gopkg.in/square/go-jose.v2/doc.go new file mode 100644 index 0000000000..dd1387f3f0 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/doc.go @@ -0,0 +1,27 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + +Package jose aims to provide an implementation of the Javascript Object Signing +and Encryption set of standards. It implements encryption and signing based on +the JSON Web Encryption and JSON Web Signature standards, with optional JSON +Web Token support available in a sub-package. The library supports both the +compact and full serialization formats, and has optional support for multiple +recipients. + +*/ +package jose diff --git a/vendor/gopkg.in/square/go-jose.v2/encoding.go b/vendor/gopkg.in/square/go-jose.v2/encoding.go new file mode 100644 index 0000000000..b9687c647d --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/encoding.go @@ -0,0 +1,179 @@ +/*- + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package jose + +import ( + "bytes" + "compress/flate" + "encoding/base64" + "encoding/binary" + "io" + "math/big" + "regexp" + + "gopkg.in/square/go-jose.v2/json" +) + +var stripWhitespaceRegex = regexp.MustCompile("\\s") + +// Helper function to serialize known-good objects. +// Precondition: value is not a nil pointer. +func mustSerializeJSON(value interface{}) []byte { + out, err := json.Marshal(value) + if err != nil { + panic(err) + } + // We never want to serialize the top-level value "null," since it's not a + // valid JOSE message. But if a caller passes in a nil pointer to this method, + // MarshalJSON will happily serialize it as the top-level value "null". If + // that value is then embedded in another operation, for instance by being + // base64-encoded and fed as input to a signing algorithm + // (https://github.com/square/go-jose/issues/22), the result will be + // incorrect. Because this method is intended for known-good objects, and a nil + // pointer is not a known-good object, we are free to panic in this case. + // Note: It's not possible to directly check whether the data pointed at by an + // interface is a nil pointer, so we do this hacky workaround. + // https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I + if string(out) == "null" { + panic("Tried to serialize a nil pointer.") + } + return out +} + +// Strip all newlines and whitespace +func stripWhitespace(data string) string { + return stripWhitespaceRegex.ReplaceAllString(data, "") +} + +// Perform compression based on algorithm +func compress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) { + switch algorithm { + case DEFLATE: + return deflate(input) + default: + return nil, ErrUnsupportedAlgorithm + } +} + +// Perform decompression based on algorithm +func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) { + switch algorithm { + case DEFLATE: + return inflate(input) + default: + return nil, ErrUnsupportedAlgorithm + } +} + +// Compress with DEFLATE +func deflate(input []byte) ([]byte, error) { + output := new(bytes.Buffer) + + // Writing to byte buffer, err is always nil + writer, _ := flate.NewWriter(output, 1) + _, _ = io.Copy(writer, bytes.NewBuffer(input)) + + err := writer.Close() + return output.Bytes(), err +} + +// Decompress with DEFLATE +func inflate(input []byte) ([]byte, error) { + output := new(bytes.Buffer) + reader := flate.NewReader(bytes.NewBuffer(input)) + + _, err := io.Copy(output, reader) + if err != nil { + return nil, err + } + + err = reader.Close() + return output.Bytes(), err +} + +// byteBuffer represents a slice of bytes that can be serialized to url-safe base64. +type byteBuffer struct { + data []byte +} + +func newBuffer(data []byte) *byteBuffer { + if data == nil { + return nil + } + return &byteBuffer{ + data: data, + } +} + +func newFixedSizeBuffer(data []byte, length int) *byteBuffer { + if len(data) > length { + panic("square/go-jose: invalid call to newFixedSizeBuffer (len(data) > length)") + } + pad := make([]byte, length-len(data)) + return newBuffer(append(pad, data...)) +} + +func newBufferFromInt(num uint64) *byteBuffer { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, num) + return newBuffer(bytes.TrimLeft(data, "\x00")) +} + +func (b *byteBuffer) MarshalJSON() ([]byte, error) { + return json.Marshal(b.base64()) +} + +func (b *byteBuffer) UnmarshalJSON(data []byte) error { + var encoded string + err := json.Unmarshal(data, &encoded) + if err != nil { + return err + } + + if encoded == "" { + return nil + } + + decoded, err := base64.RawURLEncoding.DecodeString(encoded) + if err != nil { + return err + } + + *b = *newBuffer(decoded) + + return nil +} + +func (b *byteBuffer) base64() string { + return base64.RawURLEncoding.EncodeToString(b.data) +} + +func (b *byteBuffer) bytes() []byte { + // Handling nil here allows us to transparently handle nil slices when serializing. + if b == nil { + return nil + } + return b.data +} + +func (b byteBuffer) bigInt() *big.Int { + return new(big.Int).SetBytes(b.data) +} + +func (b byteBuffer) toInt() int { + return int(b.bigInt().Int64()) +} diff --git a/vendor/gopkg.in/square/go-jose.v2/json/LICENSE b/vendor/gopkg.in/square/go-jose.v2/json/LICENSE new file mode 100644 index 0000000000..7448756763 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/square/go-jose.v2/json/README.md b/vendor/gopkg.in/square/go-jose.v2/json/README.md new file mode 100644 index 0000000000..86de5e5581 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/README.md @@ -0,0 +1,13 @@ +# Safe JSON + +This repository contains a fork of the `encoding/json` package from Go 1.6. + +The following changes were made: + +* Object deserialization uses case-sensitive member name matching instead of + [case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html). + This is to avoid differences in the interpretation of JOSE messages between + go-jose and libraries written in other languages. +* When deserializing a JSON object, we check for duplicate keys and reject the + input whenever we detect a duplicate. Rather than trying to work with malformed + data, we prefer to reject it right away. diff --git a/vendor/gopkg.in/square/go-jose.v2/json/decode.go b/vendor/gopkg.in/square/go-jose.v2/json/decode.go new file mode 100644 index 0000000000..37457e5a83 --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/decode.go @@ -0,0 +1,1183 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "errors" + "fmt" + "reflect" + "runtime" + "strconv" + "unicode" + "unicode/utf16" + "unicode/utf8" +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. +// +// Unmarshal uses the inverse of the encodings that +// Marshal uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal JSON into a pointer, Unmarshal first handles the case of +// the JSON being the JSON literal null. In that case, Unmarshal sets +// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into +// the value pointed at by the pointer. If the pointer is nil, Unmarshal +// allocates a new value for it to point to. +// +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by Marshal (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. +// Unmarshal will only set exported fields of the struct. +// +// To unmarshal JSON into an interface value, +// Unmarshal stores one of these in the interface value: +// +// bool, for JSON booleans +// float64, for JSON numbers +// string, for JSON strings +// []interface{}, for JSON arrays +// map[string]interface{}, for JSON objects +// nil for JSON null +// +// To unmarshal a JSON array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// As a special case, to unmarshal an empty JSON array into a slice, +// Unmarshal replaces the slice with a new empty slice. +// +// To unmarshal a JSON array into a Go array, Unmarshal decodes +// JSON array elements into corresponding Go array elements. +// If the Go array is smaller than the JSON array, +// the additional JSON array elements are discarded. +// If the JSON array is smaller than the Go array, +// the additional Go array elements are set to zero values. +// +// To unmarshal a JSON object into a string-keyed map, Unmarshal first +// establishes a map to use, If the map is nil, Unmarshal allocates a new map. +// Otherwise Unmarshal reuses the existing map, keeping existing entries. +// Unmarshal then stores key-value pairs from the JSON object into the map. +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshaling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an UnmarshalTypeError describing the earliest such error. +// +// The JSON null value unmarshals into an interface, map, pointer, or slice +// by setting that Go value to nil. Because null is often used in JSON to mean +// ``not present,'' unmarshaling a JSON null into any other Go type has no effect +// on the value and produces no error. +// +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +// +func Unmarshal(data []byte, v interface{}) error { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + var d decodeState + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + d.init(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by objects +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +type Unmarshaler interface { + UnmarshalJSON([]byte) error +} + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. +type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to + Offset int64 // error occurred after reading Offset bytes +} + +func (e *UnmarshalTypeError) Error() string { + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +// (No longer used; kept for compatibility.) +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) Error() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. +// (The argument to Unmarshal must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Ptr { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(error) + } + }() + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + d.scan.reset() + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + d.value(rv) + return d.savedError +} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// isValidNumber reports whether s is a valid JSON number literal. +func isValidNumber(s string) bool { + // This function implements the JSON numbers grammar. + // See https://tools.ietf.org/html/rfc7159#section-6 + // and http://json.org/number.gif + + if s == "" { + return false + } + + // Optional - + if s[0] == '-' { + s = s[1:] + if s == "" { + return false + } + } + + // Digits + switch { + default: + return false + + case s[0] == '0': + s = s[1:] + + case '1' <= s[0] && s[0] <= '9': + s = s[1:] + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // . followed by 1 or more digits. + if len(s) >= 2 && s[0] == '.' && '0' <= s[1] && s[1] <= '9' { + s = s[2:] + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // e or E followed by an optional - or + and + // 1 or more digits. + if len(s) >= 2 && (s[0] == 'e' || s[0] == 'E') { + s = s[1:] + if s[0] == '+' || s[0] == '-' { + s = s[1:] + if s == "" { + return false + } + } + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // Make sure we are at the end. + return s == "" +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // read offset in data + scan scanner + nextscan scanner // for calls to nextValue + savedError error + useNumber bool +} + +// errPhase is used for errors that should not happen unless +// there is a bug in the JSON decoder or something is editing +// the data slice while the decoder executes. +var errPhase = errors.New("JSON decoder out of sync - data changing underfoot?") + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + return d +} + +// error aborts the decoding by panicking with err. +func (d *decodeState) error(err error) { + panic(err) +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err error) { + if d.savedError == nil { + d.savedError = err + } +} + +// next cuts off and returns the next full JSON value in d.data[d.off:]. +// The next value is known to be an object or array, not a literal. +func (d *decodeState) next() []byte { + c := d.data[d.off] + item, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // Our scanner has seen the opening brace/bracket + // and thinks we're still in the middle of the object. + // invent a closing brace/bracket to get it out. + if c == '{' { + d.scan.step(&d.scan, '}') + } else { + d.scan.step(&d.scan, ']') + } + + return item +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +// It updates d.off and returns the new scan code. +func (d *decodeState) scanWhile(op int) int { + var newOp int + for { + if d.off >= len(d.data) { + newOp = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } else { + c := d.data[d.off] + d.off++ + newOp = d.scan.step(&d.scan, c) + } + if newOp != op { + break + } + } + return newOp +} + +// value decodes a JSON value from d.data[d.off:] into the value. +// it updates d.off to point past the decoded value. +func (d *decodeState) value(v reflect.Value) { + if !v.IsValid() { + _, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // d.scan thinks we're still at the beginning of the item. + // Feed in an empty string - the shortest, simplest value - + // so that it knows we got to the end of the value. + if d.scan.redo { + // rewind. + d.scan.redo = false + d.scan.step = stateBeginValue + } + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + + n := len(d.scan.parseState) + if n > 0 && d.scan.parseState[n-1] == parseObjectKey { + // d.scan thinks we just read an object key; finish the object + d.scan.step(&d.scan, ':') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '}') + } + + return + } + + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(v) + + case scanBeginObject: + d.object(v) + + case scanBeginLiteral: + d.literal(v) + } +} + +type unquotedValue struct{} + +// valueQuoted is like value but decodes a +// quoted string literal or literal null into an interface value. +// If it finds anything other than a quoted string literal or null, +// valueQuoted returns unquotedValue{}. +func (d *decodeState) valueQuoted() interface{} { + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(reflect.Value{}) + + case scanBeginObject: + d.object(reflect.Value{}) + + case scanBeginLiteral: + switch v := d.literalInterface().(type) { + case nil, string: + return v + } + } + return unquotedValue{} +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// if it encounters an Unmarshaler, indirect stops and returns that. +// if decodingNull is true, indirect stops at the last pointer so it can be set to nil. +func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + v = e + continue + } + } + + if v.Kind() != reflect.Ptr { + break + } + + if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + } + v = v.Elem() + } + return nil, nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into the value v. +// the first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) { + // Check for unmarshaler. + u, ut, pv := d.indirect(v, false) + if u != nil { + d.off-- + err := u.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + d.saveError(&UnmarshalTypeError{"array", v.Type(), int64(d.off)}) + d.off-- + d.next() + return + } + + v = pv + + // Check type of target. + switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + v.Set(reflect.ValueOf(d.arrayInterface())) + return + } + // Otherwise it's invalid. + fallthrough + default: + d.saveError(&UnmarshalTypeError{"array", v.Type(), int64(d.off)}) + d.off-- + d.next() + return + case reflect.Array: + case reflect.Slice: + break + } + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + // Get element of array, growing if necessary. + if v.Kind() == reflect.Slice { + // Grow slice if necessary + if i >= v.Cap() { + newcap := v.Cap() + v.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(v.Type(), v.Len(), newcap) + reflect.Copy(newv, v) + v.Set(newv) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + // Decode into element. + d.value(v.Index(i)) + } else { + // Ran out of fixed array: skip. + d.value(reflect.Value{}) + } + i++ + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + // Array. Zero the rest. + z := reflect.Zero(v.Type().Elem()) + for ; i < v.Len(); i++ { + v.Index(i).Set(z) + } + } else { + v.SetLen(i) + } + } + if i == 0 && v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } +} + +var nullLiteral = []byte("null") + +// object consumes an object from d.data[d.off-1:], decoding into the value v. +// the first byte ('{') of the object has been read already. +func (d *decodeState) object(v reflect.Value) { + // Check for unmarshaler. + u, ut, pv := d.indirect(v, false) + if u != nil { + d.off-- + err := u.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + v = pv + + // Decoding into nil interface? Switch to non-reflect code. + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + v.Set(reflect.ValueOf(d.objectInterface())) + return + } + + // Check type of target: struct or map[string]T + switch v.Kind() { + case reflect.Map: + // map must have string kind + t := v.Type() + if t.Key().Kind() != reflect.String { + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + + default: + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + + var mapElem reflect.Value + keys := map[string]bool{} + + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Check for duplicate keys. + _, ok = keys[key] + if !ok { + keys[key] = true + } else { + d.error(fmt.Errorf("json: duplicate key '%s' in object", key)) + } + + // Figure out field corresponding to key. + var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := v.Type().Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.Set(reflect.Zero(elemType)) + } + subv = mapElem + } else { + var f *field + fields := cachedTypeFields(v.Type()) + for i := range fields { + ff := &fields[i] + if bytes.Equal(ff.nameBytes, []byte(key)) { + f = ff + break + } + } + if f != nil { + subv = v + destring = f.quoted + for _, i := range f.index { + if subv.Kind() == reflect.Ptr { + if subv.IsNil() { + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + subv = subv.Field(i) + } + } + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + if destring { + switch qv := d.valueQuoted().(type) { + case nil: + d.literalStore(nullLiteral, subv, false) + case string: + d.literalStore([]byte(qv), subv, true) + default: + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type())) + } + } else { + d.value(subv) + } + + // Write value back to map; + // if using struct, subv points into struct already. + if v.Kind() == reflect.Map { + kv := reflect.ValueOf(key).Convert(v.Type().Key()) + v.SetMapIndex(kv, subv) + } + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } +} + +// literal consumes a literal from d.data[d.off-1:], decoding into the value v. +// The first byte of the literal has been read already +// (that's how the caller knows it's a literal). +func (d *decodeState) literal(v reflect.Value) { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + + d.literalStore(d.data[start:d.off], v, false) +} + +// convertNumber converts the number literal s to a float64 or a Number +// depending on the setting of d.useNumber. +func (d *decodeState) convertNumber(s string) (interface{}, error) { + if d.useNumber { + return Number(s), nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, &UnmarshalTypeError{"number " + s, reflect.TypeOf(0.0), int64(d.off)} + } + return f, nil +} + +var numberType = reflect.TypeOf(Number("")) + +// literalStore decodes a literal stored in item into v. +// +// fromQuoted indicates whether this literal came from unwrapping a +// string from the ",string" struct tag option. this is used only to +// produce more helpful error messages. +func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) { + // Check for unmarshaler. + if len(item) == 0 { + //Empty string given + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return + } + wantptr := item[0] == 'n' // null + u, ut, pv := d.indirect(v, wantptr) + if u != nil { + err := u.UnmarshalJSON(item) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + if item[0] != '"' { + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + } + return + } + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + err := ut.UnmarshalText(s) + if err != nil { + d.error(err) + } + return + } + + v = pv + + switch c := item[0]; c { + case 'n': // null + switch v.Kind() { + case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: + v.Set(reflect.Zero(v.Type())) + // otherwise, ignore null for primitives/string + } + case 't', 'f': // true, false + value := c == 't' + switch v.Kind() { + default: + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{"bool", v.Type(), int64(d.off)}) + } + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(value)) + } else { + d.saveError(&UnmarshalTypeError{"bool", v.Type(), int64(d.off)}) + } + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.SetBytes(b[:n]) + case reflect.String: + v.SetString(string(s)) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(string(s))) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type(), int64(d.off)}) + } + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + s := string(item) + switch v.Kind() { + default: + if v.Kind() == reflect.String && v.Type() == numberType { + v.SetString(s) + if !isValidNumber(s) { + d.error(fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item)) + } + break + } + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(&UnmarshalTypeError{"number", v.Type(), int64(d.off)}) + } + case reflect.Interface: + n, err := d.convertNumber(s) + if err != nil { + d.saveError(err) + break + } + if v.NumMethod() != 0 { + d.saveError(&UnmarshalTypeError{"number", v.Type(), int64(d.off)}) + break + } + v.Set(reflect.ValueOf(n)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || v.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type(), int64(d.off)}) + break + } + v.SetInt(n) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || v.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type(), int64(d.off)}) + break + } + v.SetUint(n) + + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(s, v.Type().Bits()) + if err != nil || v.OverflowFloat(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type(), int64(d.off)}) + break + } + v.SetFloat(n) + } + } +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns interface{} +func (d *decodeState) valueInterface() interface{} { + switch d.scanWhile(scanSkipSpace) { + default: + d.error(errPhase) + panic("unreachable") + case scanBeginArray: + return d.arrayInterface() + case scanBeginObject: + return d.objectInterface() + case scanBeginLiteral: + return d.literalInterface() + } +} + +// arrayInterface is like array but returns []interface{}. +func (d *decodeState) arrayInterface() []interface{} { + var v = make([]interface{}, 0) + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + v = append(v, d.valueInterface()) + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + return v +} + +// objectInterface is like object but returns map[string]interface{}. +func (d *decodeState) objectInterface() map[string]interface{} { + m := make(map[string]interface{}) + keys := map[string]bool{} + + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read string key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Check for duplicate keys. + _, ok = keys[key] + if !ok { + keys[key] = true + } else { + d.error(fmt.Errorf("json: duplicate key '%s' in object", key)) + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } + return m +} + +// literalInterface is like literal but returns an interface value. +func (d *decodeState) literalInterface() interface{} { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + item := d.data[start:d.off] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + d.error(errPhase) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + d.error(errPhase) + } + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + } + return n + } +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) rune { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + r, err := strconv.ParseUint(string(s[2:6]), 16, 64) + if err != nil { + return -1 + } + return rune(r) +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. +func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. + r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rr, size := utf8.DecodeRune(s[r:]) + if rr == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. + if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rr := getu4(s[r:]) + if rr < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rr) { + rr1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rr = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rr) + } + + // Quote, control characters are invalid. + case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. + default: + rr, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rr) + } + } + return b[0:w], true +} diff --git a/vendor/gopkg.in/square/go-jose.v2/json/encode.go b/vendor/gopkg.in/square/go-jose.v2/json/encode.go new file mode 100644 index 0000000000..1dae8bb7cd --- /dev/null +++ b/vendor/gopkg.in/square/go-jose.v2/json/encode.go @@ -0,0 +1,1197 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package json implements encoding and decoding of JSON objects as defined in +// RFC 4627. The mapping between JSON objects and Go values is described +// in the documentation for the Marshal and Unmarshal functions. +// +// See "JSON and Go" for an introduction to this package: +// https://golang.org/doc/articles/json_and_go.html +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "fmt" + "math" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements the Marshaler interface +// and is not a nil pointer, Marshal calls its MarshalJSON method +// to produce JSON. If no MarshalJSON method is present but the +// value implements encoding.TextMarshaler instead, Marshal calls +// its MarshalText method. +// The nil pointer exception is not strictly necessary +// but mimics a similar, necessary exception in the behavior of +// UnmarshalJSON. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point, integer, and Number values encode as JSON numbers. +// +// String values encode as JSON strings coerced to valid UTF-8, +// replacing invalid bytes with the Unicode replacement rune. +// The angle brackets "<" and ">" are escaped to "\u003c" and "\u003e" +// to keep some browsers from misinterpreting JSON output as HTML. +// Ampersand "&" is also escaped to "\u0026" for the same reason. +// +// Array and slice values encode as JSON arrays, except that +// []byte encodes as a base64-encoded string, and a nil slice +// encodes as the null JSON object. +// +// Struct values encode as JSON objects. Each exported struct field +// becomes a member of the object unless +// - the field's tag is "-", or +// - the field is empty and its tag specifies the "omitempty" option. +// The empty values are false, 0, any +// nil pointer or interface value, and any array, slice, map, or string of +// length zero. The object's default key string is the struct field name +// but can be specified in the struct field's tag value. The "json" key in +// the struct field's tag value is the key name, followed by an optional comma +// and options. Examples: +// +// // Field is ignored by this package. +// Field int `json:"-"` +// +// // Field appears in JSON as key "myName". +// Field int `json:"myName"` +// +// // Field appears in JSON as key "myName" and +// // the field is omitted from the object if its value is empty, +// // as defined above. +// Field int `json:"myName,omitempty"` +// +// // Field appears in JSON as key "Field" (the default), but +// // the field is skipped if empty. +// // Note the leading comma. +// Field int `json:",omitempty"` +// +// The "string" option signals that a field is stored as JSON inside a +// JSON-encoded string. It applies only to fields of string, floating point, +// integer, or boolean types. This extra level of encoding is sometimes used +// when communicating with JavaScript programs: +// +// Int64String int64 `json:",string"` +// +// The key name will be used if it's a non-empty string consisting of +// only Unicode letters, digits, dollar signs, percent signs, hyphens, +// underscores and slashes. +// +// Anonymous struct fields are usually marshaled as if their inner exported fields +// were fields in the outer struct, subject to the usual Go visibility rules amended +// as described in the next paragraph. +// An anonymous struct field with a name given in its JSON tag is treated as +// having that name, rather than being anonymous. +// An anonymous struct field of interface type is treated the same as having +// that type as its name, rather than being anonymous. +// +// The Go visibility rules for struct fields are amended for JSON when +// deciding which field to marshal or unmarshal. If there are +// multiple fields at the same level, and that level is the least +// nested (and would therefore be the nesting level selected by the +// usual Go rules), the following extra rules apply: +// +// 1) Of those fields, if any are JSON-tagged, only tagged fields are considered, +// even if there are multiple untagged fields that would otherwise conflict. +// 2) If there is exactly one field (tagged or not according to the first rule), that is selected. +// 3) Otherwise there are multiple fields, and all are ignored; no error occurs. +// +// Handling of anonymous struct fields is new in Go 1.1. +// Prior to Go 1.1, anonymous struct fields were ignored. To force ignoring of +// an anonymous struct field in both current and earlier versions, give the field +// a JSON tag of "-". +// +// Map values encode as JSON objects. +// The map's key type must be string; the map keys are used as JSON object +// keys, subject to the UTF-8 coercion described for string values above. +// +// Pointer values encode as the value pointed to. +// A nil pointer encodes as the null JSON object. +// +// Interface values encode as the value contained in the interface. +// A nil interface value encodes as the null JSON object. +// +// Channel, complex, and function values cannot be encoded in JSON. +// Attempting to encode such a value causes Marshal to return +// an UnsupportedTypeError. +// +// JSON cannot represent cyclic data structures and Marshal does not +// handle them. Passing cyclic structures to Marshal will result in +// an infinite recursion. +// +func Marshal(v interface{}) ([]byte, error) { + e := &encodeState{} + err := e.marshal(v) + if err != nil { + return nil, err + } + return e.Bytes(), nil +} + +// MarshalIndent is like Marshal but applies Indent to format the output. +func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { + b, err := Marshal(v) + if err != nil { + return nil, err + } + var buf bytes.Buffer + err = Indent(&buf, b, prefix, indent) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029 +// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029 +// so that the JSON will be safe to embed inside HTML