mirror of https://github.com/status-im/consul.git
acl: adding support for kubernetes auth provider login (#5600)
* auth providers * binding rules * auth provider for kubernetes * login/logout
This commit is contained in:
parent
cc1aa3f973
commit
e47d7eeddb
|
@ -376,6 +376,7 @@ func (s *HTTPServer) ACLTokenList(resp http.ResponseWriter, req *http.Request) (
|
|||
|
||||
args.Policy = req.URL.Query().Get("policy")
|
||||
args.Role = req.URL.Query().Get("role")
|
||||
args.AuthMethod = req.URL.Query().Get("authmethod")
|
||||
|
||||
var out structs.ACLTokenListResponse
|
||||
defer setMeta(resp, &out.QueryMeta)
|
||||
|
@ -701,3 +702,329 @@ func (s *HTTPServer) ACLRoleDelete(resp http.ResponseWriter, req *http.Request,
|
|||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLBindingRuleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var args structs.ACLBindingRuleListRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if args.Datacenter == "" {
|
||||
args.Datacenter = s.agent.config.Datacenter
|
||||
}
|
||||
|
||||
args.AuthMethod = req.URL.Query().Get("authmethod")
|
||||
|
||||
var out structs.ACLBindingRuleListResponse
|
||||
defer setMeta(resp, &out.QueryMeta)
|
||||
if err := s.agent.RPC("ACL.BindingRuleList", &args, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// make sure we return an array and not nil
|
||||
if out.BindingRules == nil {
|
||||
out.BindingRules = make(structs.ACLBindingRules, 0)
|
||||
}
|
||||
|
||||
return out.BindingRules, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var fn func(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error)
|
||||
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
fn = s.ACLBindingRuleRead
|
||||
|
||||
case "PUT":
|
||||
fn = s.ACLBindingRuleWrite
|
||||
|
||||
case "DELETE":
|
||||
fn = s.ACLBindingRuleDelete
|
||||
|
||||
default:
|
||||
return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}}
|
||||
}
|
||||
|
||||
bindingRuleID := strings.TrimPrefix(req.URL.Path, "/v1/acl/binding-rule/")
|
||||
if bindingRuleID == "" && req.Method != "PUT" {
|
||||
return nil, BadRequestError{Reason: "Missing binding rule ID"}
|
||||
}
|
||||
|
||||
return fn(resp, req, bindingRuleID)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLBindingRuleRead(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) {
|
||||
args := structs.ACLBindingRuleGetRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
BindingRuleID: bindingRuleID,
|
||||
}
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if args.Datacenter == "" {
|
||||
args.Datacenter = s.agent.config.Datacenter
|
||||
}
|
||||
|
||||
var out structs.ACLBindingRuleResponse
|
||||
defer setMeta(resp, &out.QueryMeta)
|
||||
if err := s.agent.RPC("ACL.BindingRuleRead", &args, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if out.BindingRule == nil {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return out.BindingRule, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLBindingRuleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s.ACLBindingRuleWrite(resp, req, "")
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) {
|
||||
args := structs.ACLBindingRuleSetRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
}
|
||||
s.parseToken(req, &args.Token)
|
||||
|
||||
if err := decodeBody(req, &args.BindingRule, fixTimeAndHashFields); err != nil {
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)}
|
||||
}
|
||||
|
||||
if args.BindingRule.ID != "" && args.BindingRule.ID != bindingRuleID {
|
||||
return nil, BadRequestError{Reason: "BindingRule ID in URL and payload do not match"}
|
||||
} else if args.BindingRule.ID == "" {
|
||||
args.BindingRule.ID = bindingRuleID
|
||||
}
|
||||
|
||||
var out structs.ACLBindingRule
|
||||
if err := s.agent.RPC("ACL.BindingRuleSet", args, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLBindingRuleDelete(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error) {
|
||||
args := structs.ACLBindingRuleDeleteRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
BindingRuleID: bindingRuleID,
|
||||
}
|
||||
s.parseToken(req, &args.Token)
|
||||
|
||||
var ignored bool
|
||||
if err := s.agent.RPC("ACL.BindingRuleDelete", args, &ignored); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLAuthMethodList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var args structs.ACLAuthMethodListRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if args.Datacenter == "" {
|
||||
args.Datacenter = s.agent.config.Datacenter
|
||||
}
|
||||
|
||||
var out structs.ACLAuthMethodListResponse
|
||||
defer setMeta(resp, &out.QueryMeta)
|
||||
if err := s.agent.RPC("ACL.AuthMethodList", &args, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// make sure we return an array and not nil
|
||||
if out.AuthMethods == nil {
|
||||
out.AuthMethods = make(structs.ACLAuthMethodListStubs, 0)
|
||||
}
|
||||
|
||||
return out.AuthMethods, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var fn func(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error)
|
||||
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
fn = s.ACLAuthMethodRead
|
||||
|
||||
case "PUT":
|
||||
fn = s.ACLAuthMethodWrite
|
||||
|
||||
case "DELETE":
|
||||
fn = s.ACLAuthMethodDelete
|
||||
|
||||
default:
|
||||
return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}}
|
||||
}
|
||||
|
||||
methodName := strings.TrimPrefix(req.URL.Path, "/v1/acl/auth-method/")
|
||||
if methodName == "" && req.Method != "PUT" {
|
||||
return nil, BadRequestError{Reason: "Missing auth method name"}
|
||||
}
|
||||
|
||||
return fn(resp, req, methodName)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLAuthMethodRead(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) {
|
||||
args := structs.ACLAuthMethodGetRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
AuthMethodName: methodName,
|
||||
}
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if args.Datacenter == "" {
|
||||
args.Datacenter = s.agent.config.Datacenter
|
||||
}
|
||||
|
||||
var out structs.ACLAuthMethodResponse
|
||||
defer setMeta(resp, &out.QueryMeta)
|
||||
if err := s.agent.RPC("ACL.AuthMethodRead", &args, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if out.AuthMethod == nil {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fixupAuthMethodConfig(out.AuthMethod)
|
||||
return out.AuthMethod, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLAuthMethodCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s.ACLAuthMethodWrite(resp, req, "")
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) {
|
||||
args := structs.ACLAuthMethodSetRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
}
|
||||
s.parseToken(req, &args.Token)
|
||||
|
||||
if err := decodeBody(req, &args.AuthMethod, fixTimeAndHashFields); err != nil {
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)}
|
||||
}
|
||||
|
||||
if methodName != "" {
|
||||
if args.AuthMethod.Name != "" && args.AuthMethod.Name != methodName {
|
||||
return nil, BadRequestError{Reason: "AuthMethod Name in URL and payload do not match"}
|
||||
} else if args.AuthMethod.Name == "" {
|
||||
args.AuthMethod.Name = methodName
|
||||
}
|
||||
}
|
||||
|
||||
var out structs.ACLAuthMethod
|
||||
if err := s.agent.RPC("ACL.AuthMethodSet", args, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fixupAuthMethodConfig(&out)
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLAuthMethodDelete(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error) {
|
||||
args := structs.ACLAuthMethodDeleteRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
AuthMethodName: methodName,
|
||||
}
|
||||
s.parseToken(req, &args.Token)
|
||||
|
||||
var ignored bool
|
||||
if err := s.agent.RPC("ACL.AuthMethodDelete", args, &ignored); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLLogin(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
args := &structs.ACLLoginRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
|
||||
if err := decodeBody(req, &args.Auth, nil); err != nil {
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body:: %v", err)}
|
||||
}
|
||||
|
||||
var out structs.ACLToken
|
||||
if err := s.agent.RPC("ACL.Login", args, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) ACLLogout(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
if s.checkACLDisabled(resp, req) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
args := structs.ACLLogoutRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
|
||||
if args.Token == "" {
|
||||
return nil, acl.ErrNotFound
|
||||
}
|
||||
|
||||
var ignored bool
|
||||
if err := s.agent.RPC("ACL.Logout", &args, &ignored); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// A hack to fix up the config types inside of the map[string]interface{}
|
||||
// so that they get formatted correctly during json.Marshal. Without this,
|
||||
// string values that get converted to []uint8 end up getting output back
|
||||
// to the user in base64-encoded form.
|
||||
func fixupAuthMethodConfig(method *structs.ACLAuthMethod) {
|
||||
for k, v := range method.Config {
|
||||
if raw, ok := v.([]uint8); ok {
|
||||
strVal := structs.Uint8ToString(raw)
|
||||
method.Config[k] = strVal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -40,6 +42,17 @@ func TestACL_Disabled_Response(t *testing.T) {
|
|||
{"ACLTokenCreate", a.srv.ACLTokenCreate},
|
||||
{"ACLTokenSelf", a.srv.ACLTokenSelf},
|
||||
{"ACLTokenCRUD", a.srv.ACLTokenCRUD},
|
||||
{"ACLRoleList", a.srv.ACLRoleList},
|
||||
{"ACLRoleCreate", a.srv.ACLRoleCreate},
|
||||
{"ACLRoleCRUD", a.srv.ACLRoleCRUD},
|
||||
{"ACLBindingRuleList", a.srv.ACLBindingRuleList},
|
||||
{"ACLBindingRuleCreate", a.srv.ACLBindingRuleCreate},
|
||||
{"ACLBindingRuleCRUD", a.srv.ACLBindingRuleCRUD},
|
||||
{"ACLAuthMethodList", a.srv.ACLAuthMethodList},
|
||||
{"ACLAuthMethodCreate", a.srv.ACLAuthMethodCreate},
|
||||
{"ACLAuthMethodCRUD", a.srv.ACLAuthMethodCRUD},
|
||||
{"ACLLogin", a.srv.ACLLogin},
|
||||
{"ACLLogout", a.srv.ACLLogout},
|
||||
}
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
for _, tt := range tests {
|
||||
|
@ -119,6 +132,7 @@ func TestACL_HTTP(t *testing.T) {
|
|||
|
||||
idMap := make(map[string]string)
|
||||
policyMap := make(map[string]*structs.ACLPolicy)
|
||||
roleMap := make(map[string]*structs.ACLRole)
|
||||
tokenMap := make(map[string]*structs.ACLToken)
|
||||
|
||||
// This is all done as a subtest for a couple reasons
|
||||
|
@ -220,7 +234,7 @@ func TestACL_HTTP(t *testing.T) {
|
|||
policyMap[policy.ID] = policy
|
||||
})
|
||||
|
||||
t.Run("Update Name ID Mistmatch", func(t *testing.T) {
|
||||
t.Run("Update Name ID Mismatch", func(t *testing.T) {
|
||||
policyInput := &structs.ACLPolicy{
|
||||
ID: "ac7560be-7f11-4d6d-bfcf-15633c2090fd",
|
||||
Name: "read-all-nodes",
|
||||
|
@ -355,6 +369,222 @@ func TestACL_HTTP(t *testing.T) {
|
|||
})
|
||||
})
|
||||
|
||||
t.Run("Role", func(t *testing.T) {
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
roleInput := &structs.ACLRole{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
Policies: []structs.ACLRolePolicyLink{
|
||||
structs.ACLRolePolicyLink{
|
||||
ID: idMap["policy-test"],
|
||||
Name: policyMap[idMap["policy-test"]].Name,
|
||||
},
|
||||
structs.ACLRolePolicyLink{
|
||||
ID: idMap["policy-read-all-nodes"],
|
||||
Name: policyMap[idMap["policy-read-all-nodes"]].Name,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLRoleCreate(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
role, ok := obj.(*structs.ACLRole)
|
||||
require.True(t, ok)
|
||||
|
||||
// 36 = length of the string form of uuids
|
||||
require.Len(t, role.ID, 36)
|
||||
require.Equal(t, roleInput.Name, role.Name)
|
||||
require.Equal(t, roleInput.Description, role.Description)
|
||||
require.Equal(t, roleInput.Policies, role.Policies)
|
||||
require.True(t, role.CreateIndex > 0)
|
||||
require.Equal(t, role.CreateIndex, role.ModifyIndex)
|
||||
require.NotNil(t, role.Hash)
|
||||
require.NotEqual(t, role.Hash, []byte{})
|
||||
|
||||
idMap["role-test"] = role.ID
|
||||
roleMap[role.ID] = role
|
||||
})
|
||||
|
||||
t.Run("Name Chars", func(t *testing.T) {
|
||||
roleInput := &structs.ACLRole{
|
||||
Name: "service-id-web",
|
||||
ServiceIdentities: []*structs.ACLServiceIdentity{
|
||||
&structs.ACLServiceIdentity{
|
||||
ServiceName: "web",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLRoleCreate(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
role, ok := obj.(*structs.ACLRole)
|
||||
require.True(t, ok)
|
||||
|
||||
// 36 = length of the string form of uuids
|
||||
require.Len(t, role.ID, 36)
|
||||
require.Equal(t, roleInput.Name, role.Name)
|
||||
require.Equal(t, roleInput.Description, role.Description)
|
||||
require.Equal(t, roleInput.ServiceIdentities, role.ServiceIdentities)
|
||||
require.True(t, role.CreateIndex > 0)
|
||||
require.Equal(t, role.CreateIndex, role.ModifyIndex)
|
||||
require.NotNil(t, role.Hash)
|
||||
require.NotEqual(t, role.Hash, []byte{})
|
||||
|
||||
idMap["role-service-id-web"] = role.ID
|
||||
roleMap[role.ID] = role
|
||||
})
|
||||
|
||||
t.Run("Update Name ID Mismatch", func(t *testing.T) {
|
||||
roleInput := &structs.ACLRole{
|
||||
ID: "ac7560be-7f11-4d6d-bfcf-15633c2090fd",
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
ServiceIdentities: []*structs.ACLServiceIdentity{
|
||||
&structs.ACLServiceIdentity{
|
||||
ServiceName: "db",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/role/"+idMap["role-test"]+"?token=root", jsonBody(roleInput))
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLRoleCRUD(resp, req)
|
||||
require.Error(t, err)
|
||||
_, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Role CRUD Missing ID in URL", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/role/?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLRoleCRUD(resp, req)
|
||||
require.Error(t, err)
|
||||
_, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Update", func(t *testing.T) {
|
||||
roleInput := &structs.ACLRole{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
ServiceIdentities: []*structs.ACLServiceIdentity{
|
||||
&structs.ACLServiceIdentity{
|
||||
ServiceName: "web-indexer",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/role/"+idMap["role-test"]+"?token=root", jsonBody(roleInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLRoleCRUD(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
role, ok := obj.(*structs.ACLRole)
|
||||
require.True(t, ok)
|
||||
|
||||
// 36 = length of the string form of uuids
|
||||
require.Len(t, role.ID, 36)
|
||||
require.Equal(t, roleInput.Name, role.Name)
|
||||
require.Equal(t, roleInput.Description, role.Description)
|
||||
require.Equal(t, roleInput.Policies, role.Policies)
|
||||
require.Equal(t, roleInput.ServiceIdentities, role.ServiceIdentities)
|
||||
require.True(t, role.CreateIndex > 0)
|
||||
require.True(t, role.CreateIndex < role.ModifyIndex)
|
||||
require.NotNil(t, role.Hash)
|
||||
require.NotEqual(t, role.Hash, []byte{})
|
||||
|
||||
idMap["role-test"] = role.ID
|
||||
roleMap[role.ID] = role
|
||||
})
|
||||
|
||||
t.Run("ID Supplied", func(t *testing.T) {
|
||||
roleInput := &structs.ACLRole{
|
||||
ID: "12123d01-37f1-47e6-b55b-32328652bd38",
|
||||
Name: "with-id",
|
||||
Description: "test",
|
||||
ServiceIdentities: []*structs.ACLServiceIdentity{
|
||||
&structs.ACLServiceIdentity{
|
||||
ServiceName: "foobar",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", jsonBody(roleInput))
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLRoleCreate(resp, req)
|
||||
require.Error(t, err)
|
||||
_, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Invalid payload", func(t *testing.T) {
|
||||
body := bytes.NewBuffer(nil)
|
||||
body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/role?token=root", body)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLRoleCreate(resp, req)
|
||||
require.Error(t, err)
|
||||
_, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("DELETE", "/v1/acl/role/"+idMap["role-service-id-web"]+"?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLRoleCRUD(resp, req)
|
||||
require.NoError(t, err)
|
||||
delete(roleMap, idMap["role-service-id-web"])
|
||||
delete(idMap, "role-service-id-web")
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/roles?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
raw, err := a.srv.ACLRoleList(resp, req)
|
||||
require.NoError(t, err)
|
||||
roles, ok := raw.(structs.ACLRoles)
|
||||
require.True(t, ok)
|
||||
|
||||
// 1 we just created
|
||||
require.Len(t, roles, 1)
|
||||
|
||||
for roleID, expected := range roleMap {
|
||||
found := false
|
||||
for _, actual := range roles {
|
||||
if actual.ID == roleID {
|
||||
require.Equal(t, expected.Name, actual.Name)
|
||||
require.Equal(t, expected.Policies, actual.Policies)
|
||||
require.Equal(t, expected.ServiceIdentities, actual.ServiceIdentities)
|
||||
require.Equal(t, expected.Hash, actual.Hash)
|
||||
require.Equal(t, expected.CreateIndex, actual.CreateIndex)
|
||||
require.Equal(t, expected.ModifyIndex, actual.ModifyIndex)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, found)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Read", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/role/"+idMap["role-test"]+"?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
raw, err := a.srv.ACLRoleCRUD(resp, req)
|
||||
require.NoError(t, err)
|
||||
role, ok := raw.(*structs.ACLRole)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, roleMap[idMap["role-test"]], role)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Token", func(t *testing.T) {
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
tokenInput := &structs.ACLToken{
|
||||
|
@ -594,3 +824,504 @@ func TestACL_HTTP(t *testing.T) {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestACL_LoginProcedure_HTTP(t *testing.T) {
|
||||
// This tests AuthMethods, BindingRules, Login, and Logout.
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t, t.Name(), TestACLConfig())
|
||||
defer a.Shutdown()
|
||||
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
idMap := make(map[string]string)
|
||||
methodMap := make(map[string]*structs.ACLAuthMethod)
|
||||
ruleMap := make(map[string]*structs.ACLBindingRule)
|
||||
tokenMap := make(map[string]*structs.ACLToken)
|
||||
|
||||
testSessionID := testauth.StartSession()
|
||||
defer testauth.ResetSession(testSessionID)
|
||||
|
||||
// This is all done as a subtest for a couple reasons
|
||||
// 1. It uses only 1 test agent and these are
|
||||
// somewhat expensive to bring up and tear down often
|
||||
// 2. Instead of having to bring up a new agent and prime
|
||||
// the ACL system with some data before running the test
|
||||
// we can intelligently order these tests so we can still
|
||||
// test everything with less actual operations and do
|
||||
// so in a manner that is less prone to being flaky
|
||||
// 3. While this test will be large it should
|
||||
t.Run("AuthMethod", func(t *testing.T) {
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
methodInput := &structs.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
Description: "test",
|
||||
Config: map[string]interface{}{
|
||||
"SessionID": testSessionID,
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", jsonBody(methodInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLAuthMethodCreate(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
method, ok := obj.(*structs.ACLAuthMethod)
|
||||
require.True(t, ok)
|
||||
|
||||
require.Equal(t, methodInput.Name, method.Name)
|
||||
require.Equal(t, methodInput.Type, method.Type)
|
||||
require.Equal(t, methodInput.Description, method.Description)
|
||||
require.Equal(t, methodInput.Config, method.Config)
|
||||
require.True(t, method.CreateIndex > 0)
|
||||
require.Equal(t, method.CreateIndex, method.ModifyIndex)
|
||||
|
||||
methodMap[method.Name] = method
|
||||
})
|
||||
|
||||
t.Run("Create other", func(t *testing.T) {
|
||||
methodInput := &structs.ACLAuthMethod{
|
||||
Name: "other",
|
||||
Type: "testing",
|
||||
Description: "test",
|
||||
Config: map[string]interface{}{
|
||||
"SessionID": testSessionID,
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", jsonBody(methodInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLAuthMethodCreate(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
method, ok := obj.(*structs.ACLAuthMethod)
|
||||
require.True(t, ok)
|
||||
|
||||
require.Equal(t, methodInput.Name, method.Name)
|
||||
require.Equal(t, methodInput.Type, method.Type)
|
||||
require.Equal(t, methodInput.Description, method.Description)
|
||||
require.Equal(t, methodInput.Config, method.Config)
|
||||
require.True(t, method.CreateIndex > 0)
|
||||
require.Equal(t, method.CreateIndex, method.ModifyIndex)
|
||||
|
||||
methodMap[method.Name] = method
|
||||
})
|
||||
|
||||
t.Run("Update Name URL Mismatch", func(t *testing.T) {
|
||||
methodInput := &structs.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
Description: "test",
|
||||
Config: map[string]interface{}{
|
||||
"SessionID": testSessionID,
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/auth-method/not-test?token=root", jsonBody(methodInput))
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLAuthMethodCRUD(resp, req)
|
||||
require.Error(t, err)
|
||||
_, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Update", func(t *testing.T) {
|
||||
methodInput := &structs.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
Description: "updated description",
|
||||
Config: map[string]interface{}{
|
||||
"SessionID": testSessionID,
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/auth-method/test?token=root", jsonBody(methodInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLAuthMethodCRUD(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
method, ok := obj.(*structs.ACLAuthMethod)
|
||||
require.True(t, ok)
|
||||
|
||||
require.Equal(t, methodInput.Name, method.Name)
|
||||
require.Equal(t, methodInput.Type, method.Type)
|
||||
require.Equal(t, methodInput.Description, method.Description)
|
||||
require.Equal(t, methodInput.Config, method.Config)
|
||||
require.True(t, method.CreateIndex > 0)
|
||||
require.True(t, method.CreateIndex < method.ModifyIndex)
|
||||
|
||||
methodMap[method.Name] = method
|
||||
})
|
||||
|
||||
t.Run("Invalid payload", func(t *testing.T) {
|
||||
body := bytes.NewBuffer(nil)
|
||||
body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/auth-method?token=root", body)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLAuthMethodCreate(resp, req)
|
||||
require.Error(t, err)
|
||||
_, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/auth-methods?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
raw, err := a.srv.ACLAuthMethodList(resp, req)
|
||||
require.NoError(t, err)
|
||||
methods, ok := raw.(structs.ACLAuthMethodListStubs)
|
||||
require.True(t, ok)
|
||||
|
||||
// 2 we just created
|
||||
require.Len(t, methods, 2)
|
||||
|
||||
for methodName, expected := range methodMap {
|
||||
found := false
|
||||
for _, actual := range methods {
|
||||
if actual.Name == methodName {
|
||||
require.Equal(t, expected.Name, actual.Name)
|
||||
require.Equal(t, expected.Type, actual.Type)
|
||||
require.Equal(t, expected.Description, actual.Description)
|
||||
require.Equal(t, expected.CreateIndex, actual.CreateIndex)
|
||||
require.Equal(t, expected.ModifyIndex, actual.ModifyIndex)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, found)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("DELETE", "/v1/acl/auth-method/other?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLAuthMethodCRUD(resp, req)
|
||||
require.NoError(t, err)
|
||||
delete(methodMap, "other")
|
||||
})
|
||||
|
||||
t.Run("Read", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/auth-method/test?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
raw, err := a.srv.ACLAuthMethodCRUD(resp, req)
|
||||
require.NoError(t, err)
|
||||
method, ok := raw.(*structs.ACLAuthMethod)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, methodMap["test"], method)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("BindingRule", func(t *testing.T) {
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
ruleInput := &structs.ACLBindingRule{
|
||||
Description: "test",
|
||||
AuthMethod: "test",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
BindType: structs.BindingRuleBindTypeService,
|
||||
BindName: "web",
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLBindingRuleCreate(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule, ok := obj.(*structs.ACLBindingRule)
|
||||
require.True(t, ok)
|
||||
|
||||
// 36 = length of the string form of uuids
|
||||
require.Len(t, rule.ID, 36)
|
||||
require.Equal(t, ruleInput.Description, rule.Description)
|
||||
require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod)
|
||||
require.Equal(t, ruleInput.Selector, rule.Selector)
|
||||
require.Equal(t, ruleInput.BindType, rule.BindType)
|
||||
require.Equal(t, ruleInput.BindName, rule.BindName)
|
||||
require.True(t, rule.CreateIndex > 0)
|
||||
require.Equal(t, rule.CreateIndex, rule.ModifyIndex)
|
||||
|
||||
idMap["rule-test"] = rule.ID
|
||||
ruleMap[rule.ID] = rule
|
||||
})
|
||||
|
||||
t.Run("Create other", func(t *testing.T) {
|
||||
ruleInput := &structs.ACLBindingRule{
|
||||
Description: "other",
|
||||
AuthMethod: "test",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
BindType: structs.BindingRuleBindTypeRole,
|
||||
BindName: "fancy-role",
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLBindingRuleCreate(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule, ok := obj.(*structs.ACLBindingRule)
|
||||
require.True(t, ok)
|
||||
|
||||
// 36 = length of the string form of uuids
|
||||
require.Len(t, rule.ID, 36)
|
||||
require.Equal(t, ruleInput.Description, rule.Description)
|
||||
require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod)
|
||||
require.Equal(t, ruleInput.Selector, rule.Selector)
|
||||
require.Equal(t, ruleInput.BindType, rule.BindType)
|
||||
require.Equal(t, ruleInput.BindName, rule.BindName)
|
||||
require.True(t, rule.CreateIndex > 0)
|
||||
require.Equal(t, rule.CreateIndex, rule.ModifyIndex)
|
||||
|
||||
idMap["rule-other"] = rule.ID
|
||||
ruleMap[rule.ID] = rule
|
||||
})
|
||||
|
||||
t.Run("BindingRule CRUD Missing ID in URL", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/binding-rule/?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLBindingRuleCRUD(resp, req)
|
||||
require.Error(t, err)
|
||||
_, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Update", func(t *testing.T) {
|
||||
ruleInput := &structs.ACLBindingRule{
|
||||
Description: "updated",
|
||||
AuthMethod: "test",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
BindType: structs.BindingRuleBindTypeService,
|
||||
BindName: "${serviceaccount.name}",
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule/"+idMap["rule-test"]+"?token=root", jsonBody(ruleInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLBindingRuleCRUD(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule, ok := obj.(*structs.ACLBindingRule)
|
||||
require.True(t, ok)
|
||||
|
||||
// 36 = length of the string form of uuids
|
||||
require.Len(t, rule.ID, 36)
|
||||
require.Equal(t, ruleInput.Description, rule.Description)
|
||||
require.Equal(t, ruleInput.AuthMethod, rule.AuthMethod)
|
||||
require.Equal(t, ruleInput.Selector, rule.Selector)
|
||||
require.Equal(t, ruleInput.BindType, rule.BindType)
|
||||
require.Equal(t, ruleInput.BindName, rule.BindName)
|
||||
require.True(t, rule.CreateIndex > 0)
|
||||
require.True(t, rule.CreateIndex < rule.ModifyIndex)
|
||||
|
||||
idMap["rule-test"] = rule.ID
|
||||
ruleMap[rule.ID] = rule
|
||||
})
|
||||
|
||||
t.Run("ID Supplied", func(t *testing.T) {
|
||||
ruleInput := &structs.ACLBindingRule{
|
||||
ID: "12123d01-37f1-47e6-b55b-32328652bd38",
|
||||
Description: "with-id",
|
||||
AuthMethod: "test",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
BindType: structs.BindingRuleBindTypeService,
|
||||
BindName: "vault",
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", jsonBody(ruleInput))
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLBindingRuleCreate(resp, req)
|
||||
require.Error(t, err)
|
||||
_, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Invalid payload", func(t *testing.T) {
|
||||
body := bytes.NewBuffer(nil)
|
||||
body.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/acl/binding-rule?token=root", body)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLBindingRuleCreate(resp, req)
|
||||
require.Error(t, err)
|
||||
_, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/binding-rules?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
raw, err := a.srv.ACLBindingRuleList(resp, req)
|
||||
require.NoError(t, err)
|
||||
rules, ok := raw.(structs.ACLBindingRules)
|
||||
require.True(t, ok)
|
||||
|
||||
// 2 we just created
|
||||
require.Len(t, rules, 2)
|
||||
|
||||
for ruleID, expected := range ruleMap {
|
||||
found := false
|
||||
for _, actual := range rules {
|
||||
if actual.ID == ruleID {
|
||||
require.Equal(t, expected.Description, actual.Description)
|
||||
require.Equal(t, expected.AuthMethod, actual.AuthMethod)
|
||||
require.Equal(t, expected.Selector, actual.Selector)
|
||||
require.Equal(t, expected.BindType, actual.BindType)
|
||||
require.Equal(t, expected.BindName, actual.BindName)
|
||||
require.Equal(t, expected.CreateIndex, actual.CreateIndex)
|
||||
require.Equal(t, expected.ModifyIndex, actual.ModifyIndex)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, found)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("DELETE", "/v1/acl/binding-rule/"+idMap["rule-other"]+"?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLBindingRuleCRUD(resp, req)
|
||||
require.NoError(t, err)
|
||||
delete(ruleMap, idMap["rule-other"])
|
||||
delete(idMap, "rule-other")
|
||||
})
|
||||
|
||||
t.Run("Read", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/binding-rule/"+idMap["rule-test"]+"?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
raw, err := a.srv.ACLBindingRuleCRUD(resp, req)
|
||||
require.NoError(t, err)
|
||||
rule, ok := raw.(*structs.ACLBindingRule)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, ruleMap[idMap["rule-test"]], rule)
|
||||
})
|
||||
})
|
||||
|
||||
testauth.InstallSessionToken(testSessionID, "token1", "default", "demo1", "abc123")
|
||||
testauth.InstallSessionToken(testSessionID, "token2", "default", "demo2", "def456")
|
||||
|
||||
t.Run("Login", func(t *testing.T) {
|
||||
t.Run("Create Token 1", func(t *testing.T) {
|
||||
loginInput := &structs.ACLLoginParams{
|
||||
AuthMethod: "test",
|
||||
BearerToken: "token1",
|
||||
Meta: map[string]string{"foo": "bar"},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("POST", "/v1/acl/login?token=root", jsonBody(loginInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLLogin(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, ok := obj.(*structs.ACLToken)
|
||||
require.True(t, ok)
|
||||
|
||||
// 36 = length of the string form of uuids
|
||||
require.Len(t, token.AccessorID, 36)
|
||||
require.Len(t, token.SecretID, 36)
|
||||
require.Equal(t, `token created via login: {"foo":"bar"}`, token.Description)
|
||||
require.True(t, token.Local)
|
||||
require.Len(t, token.Policies, 0)
|
||||
require.Len(t, token.Roles, 0)
|
||||
require.Len(t, token.ServiceIdentities, 1)
|
||||
require.Equal(t, "demo1", token.ServiceIdentities[0].ServiceName)
|
||||
require.Len(t, token.ServiceIdentities[0].Datacenters, 0)
|
||||
require.True(t, token.CreateIndex > 0)
|
||||
require.Equal(t, token.CreateIndex, token.ModifyIndex)
|
||||
require.NotNil(t, token.Hash)
|
||||
require.NotEqual(t, token.Hash, []byte{})
|
||||
|
||||
idMap["token-test-1"] = token.AccessorID
|
||||
tokenMap[token.AccessorID] = token
|
||||
})
|
||||
t.Run("Create Token 2", func(t *testing.T) {
|
||||
loginInput := &structs.ACLLoginParams{
|
||||
AuthMethod: "test",
|
||||
BearerToken: "token2",
|
||||
Meta: map[string]string{"blah": "woot"},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("POST", "/v1/acl/login?token=root", jsonBody(loginInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLLogin(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, ok := obj.(*structs.ACLToken)
|
||||
require.True(t, ok)
|
||||
|
||||
// 36 = length of the string form of uuids
|
||||
require.Len(t, token.AccessorID, 36)
|
||||
require.Len(t, token.SecretID, 36)
|
||||
require.Equal(t, `token created via login: {"blah":"woot"}`, token.Description)
|
||||
require.True(t, token.Local)
|
||||
require.Len(t, token.Policies, 0)
|
||||
require.Len(t, token.Roles, 0)
|
||||
require.Len(t, token.ServiceIdentities, 1)
|
||||
require.Equal(t, "demo2", token.ServiceIdentities[0].ServiceName)
|
||||
require.Len(t, token.ServiceIdentities[0].Datacenters, 0)
|
||||
require.True(t, token.CreateIndex > 0)
|
||||
require.Equal(t, token.CreateIndex, token.ModifyIndex)
|
||||
require.NotNil(t, token.Hash)
|
||||
require.NotEqual(t, token.Hash, []byte{})
|
||||
|
||||
idMap["token-test-2"] = token.AccessorID
|
||||
tokenMap[token.AccessorID] = token
|
||||
})
|
||||
|
||||
t.Run("List Tokens by (incorrect) Method", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/tokens?token=root&authmethod=other", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
raw, err := a.srv.ACLTokenList(resp, req)
|
||||
require.NoError(t, err)
|
||||
tokens, ok := raw.(structs.ACLTokenListStubs)
|
||||
require.True(t, ok)
|
||||
require.Len(t, tokens, 0)
|
||||
})
|
||||
|
||||
t.Run("List Tokens by (correct) Method", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/tokens?token=root&authmethod=test", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
raw, err := a.srv.ACLTokenList(resp, req)
|
||||
require.NoError(t, err)
|
||||
tokens, ok := raw.(structs.ACLTokenListStubs)
|
||||
require.True(t, ok)
|
||||
require.Len(t, tokens, 2)
|
||||
|
||||
for tokenID, expected := range tokenMap {
|
||||
found := false
|
||||
for _, actual := range tokens {
|
||||
if actual.AccessorID == tokenID {
|
||||
require.Equal(t, expected.Description, actual.Description)
|
||||
require.Equal(t, expected.Policies, actual.Policies)
|
||||
require.Equal(t, expected.Roles, actual.Roles)
|
||||
require.Equal(t, expected.ServiceIdentities, actual.ServiceIdentities)
|
||||
require.Equal(t, expected.Local, actual.Local)
|
||||
require.Equal(t, expected.CreateTime, actual.CreateTime)
|
||||
require.Equal(t, expected.Hash, actual.Hash)
|
||||
require.Equal(t, expected.CreateIndex, actual.CreateIndex)
|
||||
require.Equal(t, expected.ModifyIndex, actual.ModifyIndex)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, found)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Logout", func(t *testing.T) {
|
||||
tok := tokenMap[idMap["token-test-1"]]
|
||||
req, _ := http.NewRequest("POST", "/v1/acl/logout?token="+tok.SecretID, nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLLogout(resp, req)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Token is gone after Logout", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/acl/token/"+idMap["token-test-1"]+"?token=root", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLTokenCRUD(resp, req)
|
||||
require.Error(t, err)
|
||||
require.True(t, acl.IsErrNotFound(err), err.Error())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
@ -447,25 +447,8 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent
|
|||
return out, nil
|
||||
}
|
||||
|
||||
if acl.IsErrNotFound(err) {
|
||||
// make sure to indicate that this identity is no longer valid within
|
||||
// the cache
|
||||
r.cache.PutIdentity(identity.SecretToken(), nil)
|
||||
|
||||
// Do not touch the policy cache. Getting a top level ACL not found error
|
||||
// only indicates that the secret token used in the request
|
||||
// no longer exists
|
||||
return nil, &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()}
|
||||
}
|
||||
|
||||
if acl.IsErrPermissionDenied(err) {
|
||||
// invalidate our ID cache so that identity resolution will take place
|
||||
// again in the future
|
||||
r.cache.RemoveIdentity(identity.SecretToken())
|
||||
|
||||
// Do not remove from the policy cache for permission denied
|
||||
// what this does indicate is that our view of the token is out of date
|
||||
return nil, &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()}
|
||||
if handledErr := r.maybeHandleIdentityErrorDuringFetch(identity, err); handledErr != nil {
|
||||
return nil, handledErr
|
||||
}
|
||||
|
||||
// other RPC error - use cache if available
|
||||
|
@ -519,25 +502,8 @@ func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity
|
|||
return out, nil
|
||||
}
|
||||
|
||||
if acl.IsErrNotFound(err) {
|
||||
// make sure to indicate that this identity is no longer valid within
|
||||
// the cache
|
||||
r.cache.PutIdentity(identity.SecretToken(), nil)
|
||||
|
||||
// Do not touch the cache. Getting a top level ACL not found error
|
||||
// only indicates that the secret token used in the request
|
||||
// no longer exists
|
||||
return nil, &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()}
|
||||
}
|
||||
|
||||
if acl.IsErrPermissionDenied(err) {
|
||||
// invalidate our ID cache so that identity resolution will take place
|
||||
// again in the future
|
||||
r.cache.RemoveIdentity(identity.SecretToken())
|
||||
|
||||
// Do not remove from the cache for permission denied
|
||||
// what this does indicate is that our view of the token is out of date
|
||||
return nil, &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()}
|
||||
if handledErr := r.maybeHandleIdentityErrorDuringFetch(identity, err); handledErr != nil {
|
||||
return nil, handledErr
|
||||
}
|
||||
|
||||
// other RPC error - use cache if available
|
||||
|
@ -557,12 +523,39 @@ func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity
|
|||
insufficientCache = true
|
||||
}
|
||||
}
|
||||
|
||||
if insufficientCache {
|
||||
return nil, ACLRemoteError{Err: err}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *ACLResolver) maybeHandleIdentityErrorDuringFetch(identity structs.ACLIdentity, err error) error {
|
||||
if acl.IsErrNotFound(err) {
|
||||
// make sure to indicate that this identity is no longer valid within
|
||||
// the cache
|
||||
r.cache.PutIdentity(identity.SecretToken(), nil)
|
||||
|
||||
// Do not touch the cache. Getting a top level ACL not found error
|
||||
// only indicates that the secret token used in the request
|
||||
// no longer exists
|
||||
return &policyOrRoleTokenError{acl.ErrNotFound, identity.SecretToken()}
|
||||
}
|
||||
|
||||
if acl.IsErrPermissionDenied(err) {
|
||||
// invalidate our ID cache so that identity resolution will take place
|
||||
// again in the future
|
||||
r.cache.RemoveIdentity(identity.SecretToken())
|
||||
|
||||
// Do not remove from the cache for permission denied
|
||||
// what this does indicate is that our view of the token is out of date
|
||||
return &policyOrRoleTokenError{acl.ErrPermissionDenied, identity.SecretToken()}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ACLResolver) filterPoliciesByScope(policies structs.ACLPolicies) structs.ACLPolicies {
|
||||
var out structs.ACLPolicies
|
||||
for _, policy := range policies {
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-bexpr"
|
||||
|
||||
// register this as a builtin auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth"
|
||||
)
|
||||
|
||||
type authMethodValidatorEntry struct {
|
||||
Validator authmethod.Validator
|
||||
ModifyIndex uint64 // the raft index when this last changed
|
||||
}
|
||||
|
||||
// loadAuthMethodValidator returns an authmethod.Validator for the given auth
|
||||
// method configuration. If the cache is up to date as-of the provided index
|
||||
// then the cached version is returned, otherwise a new validator is created
|
||||
// and cached.
|
||||
func (s *Server) loadAuthMethodValidator(idx uint64, method *structs.ACLAuthMethod) (authmethod.Validator, error) {
|
||||
if prevIdx, v, ok := s.getCachedAuthMethodValidator(method.Name); ok && idx <= prevIdx {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
v, err := authmethod.NewValidator(method)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth method validator for %q could not be initialized: %v", method.Name, err)
|
||||
}
|
||||
|
||||
v = s.getOrReplaceAuthMethodValidator(method.Name, idx, v)
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// getCachedAuthMethodValidator returns an AuthMethodValidator for
|
||||
// the given name exclusively from the cache. If one is not found in the cache
|
||||
// nil is returned.
|
||||
func (s *Server) getCachedAuthMethodValidator(name string) (uint64, authmethod.Validator, bool) {
|
||||
s.aclAuthMethodValidatorLock.RLock()
|
||||
defer s.aclAuthMethodValidatorLock.RUnlock()
|
||||
|
||||
if s.aclAuthMethodValidators != nil {
|
||||
v, ok := s.aclAuthMethodValidators[name]
|
||||
if ok {
|
||||
return v.ModifyIndex, v.Validator, true
|
||||
}
|
||||
}
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
// getOrReplaceAuthMethodValidator updates the cached validator with the
|
||||
// provided one UNLESS it has been updated by another goroutine in which case
|
||||
// the updated one is returned.
|
||||
func (s *Server) getOrReplaceAuthMethodValidator(name string, idx uint64, v authmethod.Validator) authmethod.Validator {
|
||||
s.aclAuthMethodValidatorLock.Lock()
|
||||
defer s.aclAuthMethodValidatorLock.Unlock()
|
||||
|
||||
if s.aclAuthMethodValidators == nil {
|
||||
s.aclAuthMethodValidators = make(map[string]*authMethodValidatorEntry)
|
||||
}
|
||||
|
||||
prev, ok := s.aclAuthMethodValidators[name]
|
||||
if ok {
|
||||
if prev.ModifyIndex >= idx {
|
||||
return prev.Validator
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Printf("[DEBUG] acl: updating cached auth method validator for %q", name)
|
||||
|
||||
s.aclAuthMethodValidators[name] = &authMethodValidatorEntry{
|
||||
Validator: v,
|
||||
ModifyIndex: idx,
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// purgeAuthMethodValidators resets the cache of validators.
|
||||
func (s *Server) purgeAuthMethodValidators() {
|
||||
s.aclAuthMethodValidatorLock.Lock()
|
||||
s.aclAuthMethodValidators = make(map[string]*authMethodValidatorEntry)
|
||||
s.aclAuthMethodValidatorLock.Unlock()
|
||||
}
|
||||
|
||||
// evaluateRoleBindings evaluates all current binding rules associated with the
|
||||
// given auth method against the verified data returned from the authentication
|
||||
// process.
|
||||
//
|
||||
// A list of role links and service identities are returned.
|
||||
func (s *Server) evaluateRoleBindings(
|
||||
validator authmethod.Validator,
|
||||
verifiedFields map[string]string,
|
||||
) ([]*structs.ACLServiceIdentity, []structs.ACLTokenRoleLink, error) {
|
||||
// Only fetch rules that are relevant for this method.
|
||||
_, rules, err := s.fsm.State().ACLBindingRuleList(nil, validator.Name())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
} else if len(rules) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// Convert the fields into something suitable for go-bexpr.
|
||||
selectableVars := validator.MakeFieldMapSelectable(verifiedFields)
|
||||
|
||||
// Find all binding rules that match the provided fields.
|
||||
var matchingRules []*structs.ACLBindingRule
|
||||
for _, rule := range rules {
|
||||
if doesBindingRuleMatch(rule, selectableVars) {
|
||||
matchingRules = append(matchingRules, rule)
|
||||
}
|
||||
}
|
||||
if len(matchingRules) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// For all matching rules compute the attributes of a token.
|
||||
var (
|
||||
roleLinks []structs.ACLTokenRoleLink
|
||||
serviceIdentities []*structs.ACLServiceIdentity
|
||||
)
|
||||
for _, rule := range matchingRules {
|
||||
bindName, valid, err := computeBindingRuleBindName(rule.BindType, rule.BindName, verifiedFields)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("cannot compute %q bind name for bind target: %v", rule.BindType, err)
|
||||
} else if !valid {
|
||||
return nil, nil, fmt.Errorf("computed %q bind name for bind target is invalid: %q", rule.BindType, bindName)
|
||||
}
|
||||
|
||||
switch rule.BindType {
|
||||
case structs.BindingRuleBindTypeService:
|
||||
serviceIdentities = append(serviceIdentities, &structs.ACLServiceIdentity{
|
||||
ServiceName: bindName,
|
||||
})
|
||||
|
||||
case structs.BindingRuleBindTypeRole:
|
||||
roleLinks = append(roleLinks, structs.ACLTokenRoleLink{
|
||||
Name: bindName,
|
||||
})
|
||||
|
||||
default:
|
||||
// skip unknown bind type; don't grant privileges
|
||||
}
|
||||
}
|
||||
|
||||
return serviceIdentities, roleLinks, nil
|
||||
}
|
||||
|
||||
// doesBindingRuleMatch checks that a single binding rule matches the provided
|
||||
// vars.
|
||||
func doesBindingRuleMatch(rule *structs.ACLBindingRule, selectableVars interface{}) bool {
|
||||
if rule.Selector == "" {
|
||||
return true // catch-all
|
||||
}
|
||||
|
||||
eval, err := bexpr.CreateEvaluatorForType(rule.Selector, nil, selectableVars)
|
||||
if err != nil {
|
||||
return false // fails to match if selector is invalid
|
||||
}
|
||||
|
||||
result, err := eval.Evaluate(selectableVars)
|
||||
if err != nil {
|
||||
return false // fails to match if evaluation fails
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDoesBindingRuleMatch(t *testing.T) {
|
||||
type matchable struct {
|
||||
A string `bexpr:"a"`
|
||||
C string `bexpr:"c"`
|
||||
}
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
selector string
|
||||
details interface{}
|
||||
ok bool
|
||||
}{
|
||||
{"no fields",
|
||||
"a==b", nil, false},
|
||||
{"1 term ok",
|
||||
"a==b", &matchable{A: "b"}, true},
|
||||
{"1 term no field",
|
||||
"a==b", &matchable{C: "d"}, false},
|
||||
{"1 term wrong value",
|
||||
"a==b", &matchable{A: "z"}, false},
|
||||
{"2 terms ok",
|
||||
"a==b and c==d", &matchable{A: "b", C: "d"}, true},
|
||||
{"2 terms one missing field",
|
||||
"a==b and c==d", &matchable{A: "b"}, false},
|
||||
{"2 terms one wrong value",
|
||||
"a==b and c==d", &matchable{A: "z", C: "d"}, false},
|
||||
///////////////////////////////
|
||||
{"no fields (no selectors)",
|
||||
"", nil, true},
|
||||
{"1 term ok (no selectors)",
|
||||
"", &matchable{A: "b"}, true},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
rule := structs.ACLBindingRule{Selector: test.selector}
|
||||
ok := doesBindingRuleMatch(&rule, test.details)
|
||||
require.Equal(t, test.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
@ -10,9 +12,11 @@ import (
|
|||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/go-bexpr"
|
||||
memdb "github.com/hashicorp/go-memdb"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
@ -29,6 +33,7 @@ var (
|
|||
validServiceIdentityName = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-_]*[a-z0-9])?$`)
|
||||
serviceIdentityNameMaxLength = 256
|
||||
validRoleName = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,256}$`)
|
||||
validAuthMethod = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,128}$`)
|
||||
)
|
||||
|
||||
// ACL endpoint is used to manipulate ACLs
|
||||
|
@ -273,6 +278,10 @@ func (a *ACL) TokenClone(args *structs.ACLTokenSetRequest, reply *structs.ACLTok
|
|||
return a.srv.forwardDC("ACL.TokenClone", a.srv.config.ACLDatacenter, args, reply)
|
||||
}
|
||||
|
||||
if token.AuthMethod != "" {
|
||||
return fmt.Errorf("Cannot clone a token created from an auth method")
|
||||
}
|
||||
|
||||
if token.Rules != "" {
|
||||
return fmt.Errorf("Cannot clone a legacy ACL with this endpoint")
|
||||
}
|
||||
|
@ -324,7 +333,7 @@ func (a *ACL) TokenSet(args *structs.ACLTokenSetRequest, reply *structs.ACLToken
|
|||
return a.tokenSetInternal(args, reply, false)
|
||||
}
|
||||
|
||||
func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.ACLToken, upgrade bool) error {
|
||||
func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.ACLToken, fromLogin bool) error {
|
||||
token := &args.ACLToken
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
|
@ -354,6 +363,19 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.
|
|||
|
||||
token.CreateTime = time.Now()
|
||||
|
||||
if fromLogin {
|
||||
if token.AuthMethod == "" {
|
||||
return fmt.Errorf("AuthMethod field is required during Login")
|
||||
}
|
||||
if !token.Local {
|
||||
return fmt.Errorf("Cannot create Global token via Login")
|
||||
}
|
||||
} else {
|
||||
if token.AuthMethod != "" {
|
||||
return fmt.Errorf("AuthMethod field is disallowed outside of Login")
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure an ExpirationTTL is valid if provided.
|
||||
if token.ExpirationTTL != 0 {
|
||||
if token.ExpirationTTL < 0 {
|
||||
|
@ -418,6 +440,12 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.
|
|||
return fmt.Errorf("cannot toggle local mode of %s", token.AccessorID)
|
||||
}
|
||||
|
||||
if token.AuthMethod == "" {
|
||||
token.AuthMethod = existing.AuthMethod
|
||||
} else if token.AuthMethod != existing.AuthMethod {
|
||||
return fmt.Errorf("Cannot change AuthMethod of %s", token.AccessorID)
|
||||
}
|
||||
|
||||
if token.ExpirationTTL != 0 {
|
||||
return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID)
|
||||
}
|
||||
|
@ -430,11 +458,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.
|
|||
return fmt.Errorf("Cannot change expiration time of %s", token.AccessorID)
|
||||
}
|
||||
|
||||
if upgrade {
|
||||
token.CreateTime = time.Now()
|
||||
} else {
|
||||
token.CreateTime = existing.CreateTime
|
||||
}
|
||||
token.CreateTime = existing.CreateTime
|
||||
}
|
||||
|
||||
policyIDs := make(map[string]struct{})
|
||||
|
@ -467,7 +491,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.
|
|||
roleIDs := make(map[string]struct{})
|
||||
var roles []structs.ACLTokenRoleLink
|
||||
|
||||
// Validate all the role names and convert them to role IDs
|
||||
// Validate all the role names and convert them to role IDs.
|
||||
for _, link := range token.Roles {
|
||||
if link.ID == "" {
|
||||
_, role, err := state.ACLRoleGetByName(nil, link.Name)
|
||||
|
@ -502,6 +526,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.
|
|||
return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName)
|
||||
}
|
||||
}
|
||||
token.ServiceIdentities = dedupeServiceIdentities(token.ServiceIdentities)
|
||||
|
||||
if token.Rules != "" {
|
||||
return fmt.Errorf("Rules cannot be specified for this token")
|
||||
|
@ -540,6 +565,51 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.
|
|||
return nil
|
||||
}
|
||||
|
||||
func validateBindingRuleBindName(bindType, bindName string, availableFields []string) (bool, error) {
|
||||
if bindType == "" || bindName == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
fakeVarMap := make(map[string]string)
|
||||
for _, v := range availableFields {
|
||||
fakeVarMap[v] = "fake"
|
||||
}
|
||||
|
||||
_, valid, err := computeBindingRuleBindName(bindType, bindName, fakeVarMap)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return valid, nil
|
||||
}
|
||||
|
||||
// computeBindingRuleBindName processes the HIL for the provided bind type+name
|
||||
// using the verified fields.
|
||||
//
|
||||
// - If the HIL is invalid ("", false, AN_ERROR) is returned.
|
||||
// - If the computed name is not valid for the type ("INVALID_NAME", false, nil) is returned.
|
||||
// - If the computed name is valid for the type ("VALID_NAME", true, nil) is returned.
|
||||
func computeBindingRuleBindName(bindType, bindName string, verifiedFields map[string]string) (string, bool, error) {
|
||||
bindName, err := InterpolateHIL(bindName, verifiedFields)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
valid := false
|
||||
|
||||
switch bindType {
|
||||
case structs.BindingRuleBindTypeService:
|
||||
valid = isValidServiceIdentityName(bindName)
|
||||
|
||||
case structs.BindingRuleBindTypeRole:
|
||||
valid = validRoleName.MatchString(bindName)
|
||||
|
||||
default:
|
||||
return "", false, fmt.Errorf("unknown binding rule bind type: %s", bindType)
|
||||
}
|
||||
|
||||
return bindName, valid, nil
|
||||
}
|
||||
|
||||
// isValidServiceIdentityName returns true if the provided name can be used as
|
||||
// an ACLServiceIdentity ServiceName. This is more restrictive than standard
|
||||
// catalog registration, which basically takes the view that "everything is
|
||||
|
@ -652,7 +722,7 @@ func (a *ACL) TokenList(args *structs.ACLTokenListRequest, reply *structs.ACLTok
|
|||
|
||||
return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy, args.Role)
|
||||
index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy, args.Role, args.AuthMethod)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1252,17 +1322,11 @@ func (a *ACL) RoleSet(args *structs.ACLRoleSetRequest, reply *structs.ACLRole) e
|
|||
if svcid.ServiceName == "" {
|
||||
return fmt.Errorf("Service identity is missing the service name field on this role")
|
||||
}
|
||||
// TODO(rb): ugh if a local token gets a role that has a service
|
||||
// identity that has datacenters set, we won't be anble to enforce this
|
||||
// next blob here. This makes me lean more towards nuking ServiceIdentity.Datacenters again
|
||||
//
|
||||
// if token.Local && len(svcid.Datacenters) > 0 {
|
||||
// return fmt.Errorf("Service identity %q cannot specify a list of datacenters on a local token", svcid.ServiceName)
|
||||
// }
|
||||
if !isValidServiceIdentityName(svcid.ServiceName) {
|
||||
return fmt.Errorf("Service identity %q has an invalid name. Only alphanumeric characters, '-' and '_' are allowed", svcid.ServiceName)
|
||||
}
|
||||
}
|
||||
role.ServiceIdentities = dedupeServiceIdentities(role.ServiceIdentities)
|
||||
|
||||
// calculate the hash for this role
|
||||
role.SetHash(true)
|
||||
|
@ -1412,3 +1476,577 @@ func (a *ACL) RoleResolve(args *structs.ACLRoleBatchGetRequest, reply *structs.A
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
var errAuthMethodsRequireTokenReplication = errors.New("Token replication is required for auth methods to function")
|
||||
|
||||
func (a *ACL) BindingRuleRead(args *structs.ACLBindingRuleGetRequest, reply *structs.ACLBindingRuleResponse) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.BindingRuleRead", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule, err := a.srv.ResolveToken(args.Token); err != nil {
|
||||
return err
|
||||
} else if rule == nil || !rule.ACLRead() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, rule, err := state.ACLBindingRuleGetByID(ws, args.BindingRuleID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index, reply.BindingRule = index, rule
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (a *ACL) BindingRuleSet(args *structs.ACLBindingRuleSetRequest, reply *structs.ACLBindingRule) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.BindingRuleSet", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
defer metrics.MeasureSince([]string{"acl", "bindingrule", "upsert"}, time.Now())
|
||||
|
||||
// Verify token is permitted to modify ACLs
|
||||
if rule, err := a.srv.ResolveToken(args.Token); err != nil {
|
||||
return err
|
||||
} else if rule == nil || !rule.ACLWrite() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
rule := &args.BindingRule
|
||||
state := a.srv.fsm.State()
|
||||
|
||||
if rule.ID == "" {
|
||||
// with no binding rule ID one will be generated
|
||||
var err error
|
||||
|
||||
rule.ID, err = lib.GenerateUUID(a.srv.checkBindingRuleUUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := uuid.ParseUUID(rule.ID); err != nil {
|
||||
return fmt.Errorf("Binding Rule ID invalid UUID")
|
||||
}
|
||||
|
||||
// Verify the role exists
|
||||
_, existing, err := state.ACLBindingRuleGetByID(nil, rule.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("acl binding rule lookup failed: %v", err)
|
||||
} else if existing == nil {
|
||||
return fmt.Errorf("cannot find binding rule %s", rule.ID)
|
||||
}
|
||||
|
||||
if rule.AuthMethod == "" {
|
||||
rule.AuthMethod = existing.AuthMethod
|
||||
} else if existing.AuthMethod != rule.AuthMethod {
|
||||
return fmt.Errorf("the AuthMethod field of an Binding Rule is immutable")
|
||||
}
|
||||
}
|
||||
|
||||
if rule.AuthMethod == "" {
|
||||
return fmt.Errorf("Invalid Binding Rule: no AuthMethod is set")
|
||||
}
|
||||
|
||||
methodIdx, method, err := state.ACLAuthMethodGetByName(nil, rule.AuthMethod)
|
||||
if err != nil {
|
||||
return fmt.Errorf("acl auth method lookup failed: %v", err)
|
||||
} else if method == nil {
|
||||
return fmt.Errorf("cannot find auth method with name %q", rule.AuthMethod)
|
||||
}
|
||||
validator, err := a.srv.loadAuthMethodValidator(methodIdx, method)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule.Selector != "" {
|
||||
selectableVars := validator.MakeFieldMapSelectable(map[string]string{})
|
||||
_, err := bexpr.CreateEvaluatorForType(rule.Selector, nil, selectableVars)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Binding Rule: Selector is invalid: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if rule.BindType == "" {
|
||||
return fmt.Errorf("Invalid Binding Rule: no BindType is set")
|
||||
}
|
||||
|
||||
if rule.BindName == "" {
|
||||
return fmt.Errorf("Invalid Binding Rule: no BindName is set")
|
||||
}
|
||||
|
||||
switch rule.BindType {
|
||||
case structs.BindingRuleBindTypeService:
|
||||
case structs.BindingRuleBindTypeRole:
|
||||
default:
|
||||
return fmt.Errorf("Invalid Binding Rule: unknown BindType %q", rule.BindType)
|
||||
}
|
||||
|
||||
if valid, err := validateBindingRuleBindName(rule.BindType, rule.BindName, validator.AvailableFields()); err != nil {
|
||||
return fmt.Errorf("Invalid Binding Rule: invalid BindName: %v", err)
|
||||
} else if !valid {
|
||||
return fmt.Errorf("Invalid Binding Rule: invalid BindName")
|
||||
}
|
||||
|
||||
req := &structs.ACLBindingRuleBatchSetRequest{
|
||||
BindingRules: structs.ACLBindingRules{rule},
|
||||
}
|
||||
|
||||
resp, err := a.srv.raftApply(structs.ACLBindingRuleSetRequestType, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to apply binding rule upsert request: %v", err)
|
||||
}
|
||||
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
if _, rule, err := a.srv.fsm.State().ACLBindingRuleGetByID(nil, rule.ID); err == nil && rule != nil {
|
||||
*reply = *rule
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ACL) BindingRuleDelete(args *structs.ACLBindingRuleDeleteRequest, reply *bool) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.BindingRuleDelete", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
defer metrics.MeasureSince([]string{"acl", "bindingrule", "delete"}, time.Now())
|
||||
|
||||
// Verify token is permitted to modify ACLs
|
||||
if rule, err := a.srv.ResolveToken(args.Token); err != nil {
|
||||
return err
|
||||
} else if rule == nil || !rule.ACLWrite() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
_, rule, err := a.srv.fsm.State().ACLBindingRuleGetByID(nil, args.BindingRuleID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := structs.ACLBindingRuleBatchDeleteRequest{
|
||||
BindingRuleIDs: []string{args.BindingRuleID},
|
||||
}
|
||||
|
||||
resp, err := a.srv.raftApply(structs.ACLBindingRuleDeleteRequestType, &req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to apply binding rule delete request: %v", err)
|
||||
}
|
||||
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
*reply = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ACL) BindingRuleList(args *structs.ACLBindingRuleListRequest, reply *structs.ACLBindingRuleListResponse) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.BindingRuleList", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule, err := a.srv.ResolveToken(args.Token); err != nil {
|
||||
return err
|
||||
} else if rule == nil || !rule.ACLRead() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, rules, err := state.ACLBindingRuleList(ws, args.AuthMethod)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index, reply.BindingRules = index, rules
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (a *ACL) AuthMethodRead(args *structs.ACLAuthMethodGetRequest, reply *structs.ACLAuthMethodResponse) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.AuthMethodRead", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule, err := a.srv.ResolveToken(args.Token); err != nil {
|
||||
return err
|
||||
} else if rule == nil || !rule.ACLRead() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, method, err := state.ACLAuthMethodGetByName(ws, args.AuthMethodName)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index, reply.AuthMethod = index, method
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (a *ACL) AuthMethodSet(args *structs.ACLAuthMethodSetRequest, reply *structs.ACLAuthMethod) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.AuthMethodSet", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
defer metrics.MeasureSince([]string{"acl", "authmethod", "upsert"}, time.Now())
|
||||
|
||||
// Verify token is permitted to modify ACLs
|
||||
if rule, err := a.srv.ResolveToken(args.Token); err != nil {
|
||||
return err
|
||||
} else if rule == nil || !rule.ACLWrite() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
method := &args.AuthMethod
|
||||
state := a.srv.fsm.State()
|
||||
|
||||
// ensure a name is set
|
||||
if method.Name == "" {
|
||||
return fmt.Errorf("Invalid Auth Method: no Name is set")
|
||||
}
|
||||
if !validAuthMethod.MatchString(method.Name) {
|
||||
return fmt.Errorf("Invalid Auth Method: invalid Name. Only alphanumeric characters, '-' and '_' are allowed")
|
||||
}
|
||||
|
||||
// Check to see if the method exists first.
|
||||
_, existing, err := state.ACLAuthMethodGetByName(nil, method.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("acl auth method lookup failed: %v", err)
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
if method.Type == "" {
|
||||
method.Type = existing.Type
|
||||
} else if existing.Type != method.Type {
|
||||
return fmt.Errorf("the Type field of an Auth Method is immutable")
|
||||
}
|
||||
}
|
||||
|
||||
if !authmethod.IsRegisteredType(method.Type) {
|
||||
return fmt.Errorf("Invalid Auth Method: Type should be one of: %v", authmethod.Types())
|
||||
}
|
||||
|
||||
// Instantiate a validator but do not cache it yet. This will validate the
|
||||
// configuration.
|
||||
if _, err := authmethod.NewValidator(method); err != nil {
|
||||
return fmt.Errorf("Invalid Auth Method: %v", err)
|
||||
}
|
||||
|
||||
req := &structs.ACLAuthMethodBatchSetRequest{
|
||||
AuthMethods: structs.ACLAuthMethods{method},
|
||||
}
|
||||
|
||||
resp, err := a.srv.raftApply(structs.ACLAuthMethodSetRequestType, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to apply auth method upsert request: %v", err)
|
||||
}
|
||||
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
if _, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, method.Name); err == nil && method != nil {
|
||||
*reply = *method
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ACL) AuthMethodDelete(args *structs.ACLAuthMethodDeleteRequest, reply *bool) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.AuthMethodDelete", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
defer metrics.MeasureSince([]string{"acl", "authmethod", "delete"}, time.Now())
|
||||
|
||||
// Verify token is permitted to modify ACLs
|
||||
if rule, err := a.srv.ResolveToken(args.Token); err != nil {
|
||||
return err
|
||||
} else if rule == nil || !rule.ACLWrite() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
_, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, args.AuthMethodName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if method == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := structs.ACLAuthMethodBatchDeleteRequest{
|
||||
AuthMethodNames: []string{args.AuthMethodName},
|
||||
}
|
||||
|
||||
resp, err := a.srv.raftApply(structs.ACLAuthMethodDeleteRequestType, &req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to apply auth method delete request: %v", err)
|
||||
}
|
||||
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
*reply = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ACL) AuthMethodList(args *structs.ACLAuthMethodListRequest, reply *structs.ACLAuthMethodListResponse) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.AuthMethodList", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule, err := a.srv.ResolveToken(args.Token); err != nil {
|
||||
return err
|
||||
} else if rule == nil || !rule.ACLRead() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, methods, err := state.ACLAuthMethodList(ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var stubs structs.ACLAuthMethodListStubs
|
||||
for _, method := range methods {
|
||||
stubs = append(stubs, method.Stub())
|
||||
}
|
||||
|
||||
reply.Index, reply.AuthMethods = index, stubs
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if args.Token != "" { // This shouldn't happen.
|
||||
return errors.New("do not provide a token when logging in")
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.Login", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
defer metrics.MeasureSince([]string{"acl", "login"}, time.Now())
|
||||
|
||||
auth := args.Auth
|
||||
|
||||
// 1. take args.Data.AuthMethod to get an AuthMethod Validator
|
||||
idx, method, err := a.srv.fsm.State().ACLAuthMethodGetByName(nil, auth.AuthMethod)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if method == nil {
|
||||
return acl.ErrNotFound
|
||||
}
|
||||
|
||||
validator, err := a.srv.loadAuthMethodValidator(idx, method)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Send args.Data.BearerToken to method validator and get back a fields map
|
||||
verifiedFields, err := validator.ValidateLogin(auth.BearerToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3. send map through role bindings
|
||||
serviceIdentities, roleLinks, err := a.srv.evaluateRoleBindings(validator, verifiedFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(serviceIdentities) == 0 && len(roleLinks) == 0 {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
description := "token created via login"
|
||||
loginMeta, err := encodeLoginMeta(auth.Meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if loginMeta != "" {
|
||||
description += ": " + loginMeta
|
||||
}
|
||||
|
||||
// 4. create token
|
||||
createReq := structs.ACLTokenSetRequest{
|
||||
Datacenter: args.Datacenter,
|
||||
ACLToken: structs.ACLToken{
|
||||
Description: description,
|
||||
Local: true,
|
||||
AuthMethod: auth.AuthMethod,
|
||||
ServiceIdentities: serviceIdentities,
|
||||
Roles: roleLinks,
|
||||
},
|
||||
WriteRequest: args.WriteRequest,
|
||||
}
|
||||
|
||||
// 5. return token information like a TokenCreate would
|
||||
return a.tokenSetInternal(&createReq, reply, true)
|
||||
}
|
||||
|
||||
func encodeLoginMeta(meta map[string]string) (string, error) {
|
||||
if len(meta) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
d, err := json.Marshal(meta)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(d), nil
|
||||
}
|
||||
|
||||
func (a *ACL) Logout(args *structs.ACLLogoutRequest, reply *bool) error {
|
||||
if err := a.aclPreCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.srv.LocalTokensEnabled() {
|
||||
return errAuthMethodsRequireTokenReplication
|
||||
}
|
||||
|
||||
if args.Token == "" {
|
||||
return acl.ErrNotFound
|
||||
}
|
||||
|
||||
if done, err := a.srv.forward("ACL.Logout", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
defer metrics.MeasureSince([]string{"acl", "logout"}, time.Now())
|
||||
|
||||
_, token, err := a.srv.fsm.State().ACLTokenGetBySecret(nil, args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
} else if token == nil {
|
||||
return acl.ErrNotFound
|
||||
|
||||
} else if token.AuthMethod == "" {
|
||||
// Can't "logout" of a token that wasn't a result of login.
|
||||
return acl.ErrPermissionDenied
|
||||
|
||||
} else if !a.srv.InACLDatacenter() && !token.Local {
|
||||
// global token writes must be forwarded to the primary DC
|
||||
args.Datacenter = a.srv.config.ACLDatacenter
|
||||
return a.srv.forwardDC("ACL.Logout", a.srv.config.ACLDatacenter, args, reply)
|
||||
}
|
||||
|
||||
// No need to check expiration time because it's being deleted.
|
||||
|
||||
req := &structs.ACLTokenBatchDeleteRequest{
|
||||
TokenIDs: []string{token.AccessorID},
|
||||
}
|
||||
|
||||
resp, err := a.srv.raftApply(structs.ACLTokenDeleteRequestType, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to apply token delete request: %v", err)
|
||||
}
|
||||
|
||||
// Purge the identity from the cache to prevent using the previous definition of the identity
|
||||
if token != nil {
|
||||
a.srv.acls.cache.RemoveIdentity(token.SecretID)
|
||||
}
|
||||
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
*reply = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -255,7 +255,7 @@ func (a *ACL) List(args *structs.DCSpecificRequest,
|
|||
return a.srv.blockingQuery(&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, tokens, err := state.ACLTokenList(ws, false, true, "", "")
|
||||
index, tokens, err := state.ACLTokenList(ws, false, true, "", "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -138,7 +138,7 @@ func reconcileLegacyACLs(local, remote structs.ACLs, lastRemoteIndex uint64) str
|
|||
|
||||
// FetchLocalACLs returns the ACLs in the local state store.
|
||||
func (s *Server) fetchLocalLegacyACLs() (structs.ACLs, error) {
|
||||
_, local, err := s.fsm.State().ACLTokenList(nil, false, true, "", "")
|
||||
_, local, err := s.fsm.State().ACLTokenList(nil, false, true, "", "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -396,11 +396,11 @@ func TestACLReplication_LegacyTokens(t *testing.T) {
|
|||
}
|
||||
|
||||
checkSame := func() error {
|
||||
index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "", "")
|
||||
index, remote, err := s1.fsm.State().ACLTokenList(nil, true, true, "", "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "")
|
||||
_, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -351,9 +351,9 @@ func TestACLReplication_Tokens(t *testing.T) {
|
|||
|
||||
checkSame := func(t *retry.R) {
|
||||
// only account for global tokens - local tokens shouldn't be replicated
|
||||
index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "")
|
||||
index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "")
|
||||
require.NoError(t, err)
|
||||
_, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "", "")
|
||||
_, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "", "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, local, len(remote))
|
||||
|
@ -444,7 +444,7 @@ func TestACLReplication_Tokens(t *testing.T) {
|
|||
})
|
||||
|
||||
// verify dc2 local tokens didn't get blown away
|
||||
_, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "", "")
|
||||
_, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "", "", "")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, local, 50)
|
||||
|
||||
|
@ -779,10 +779,10 @@ func TestACLReplication_AllTypes(t *testing.T) {
|
|||
|
||||
checkSameTokens := func(t *retry.R) {
|
||||
// only account for global tokens - local tokens shouldn't be replicated
|
||||
index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "")
|
||||
index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "")
|
||||
require.NoError(t, err)
|
||||
// Query for all of them, so that we can prove that no globals snuck in.
|
||||
_, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "")
|
||||
_, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, remote, len(local))
|
||||
|
|
|
@ -34,7 +34,7 @@ func (r *aclTokenReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (i
|
|||
func (r *aclTokenReplicator) FetchLocal(srv *Server) (int, uint64, error) {
|
||||
r.local = nil
|
||||
|
||||
idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "")
|
||||
idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "", "")
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
|
|
@ -73,6 +73,17 @@ func (s *Server) checkRoleUUID(id string) (bool, error) {
|
|||
return !structs.ACLIDReserved(id), nil
|
||||
}
|
||||
|
||||
func (s *Server) checkBindingRuleUUID(id string) (bool, error) {
|
||||
state := s.fsm.State()
|
||||
if _, rule, err := state.ACLBindingRuleGetByID(nil, id); err != nil {
|
||||
return false, err
|
||||
} else if rule != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return !structs.ACLIDReserved(id), nil
|
||||
}
|
||||
|
||||
func (s *Server) updateACLAdvertisement() {
|
||||
// One thing to note is that once in new ACL mode the server will
|
||||
// never transition to legacy ACL mode. This is not currently a
|
||||
|
|
|
@ -2,7 +2,6 @@ package consul
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -12,6 +11,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/sdk/testutil/retry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -139,6 +139,26 @@ func testIdentityForToken(token string) (bool, structs.ACLIdentity, error) {
|
|||
},
|
||||
},
|
||||
}, nil
|
||||
case "found-synthetic-policy-1":
|
||||
return true, &structs.ACLToken{
|
||||
AccessorID: "f6c5a5fb-4da4-422b-9abf-2c942813fc71",
|
||||
SecretID: "55cb7d69-2bea-42c3-a68f-2a1443d2abbc",
|
||||
ServiceIdentities: []*structs.ACLServiceIdentity{
|
||||
&structs.ACLServiceIdentity{
|
||||
ServiceName: "service1",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
case "found-synthetic-policy-2":
|
||||
return true, &structs.ACLToken{
|
||||
AccessorID: "7c87dfad-be37-446e-8305-299585677cb5",
|
||||
SecretID: "dfca9676-ac80-453a-837b-4c0cf923473c",
|
||||
ServiceIdentities: []*structs.ACLServiceIdentity{
|
||||
&structs.ACLServiceIdentity{
|
||||
ServiceName: "service2",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
case "acl-ro":
|
||||
return true, &structs.ACLToken{
|
||||
AccessorID: "435a75af-1763-4980-89f4-f0951dda53b4",
|
||||
|
@ -430,57 +450,87 @@ type ACLResolverTestDelegate struct {
|
|||
roleCached bool
|
||||
}
|
||||
|
||||
func (d *ACLResolverTestDelegate) Reset() {
|
||||
d.tokenCached = false
|
||||
d.policyCached = false
|
||||
d.roleCached = false
|
||||
}
|
||||
|
||||
var errRPC = fmt.Errorf("Induced RPC Error")
|
||||
|
||||
func (d *ACLResolverTestDelegate) defaultTokenReadFn(errAfterCached error) func(*structs.ACLTokenGetRequest, *structs.ACLTokenResponse) error {
|
||||
return func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error {
|
||||
if !d.tokenCached {
|
||||
_, token, _ := testIdentityForToken(args.TokenID)
|
||||
reply.Token = token.(*structs.ACLToken)
|
||||
|
||||
err := d.plainTokenReadFn(args, reply)
|
||||
d.tokenCached = true
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
return errAfterCached
|
||||
}
|
||||
}
|
||||
|
||||
func (d *ACLResolverTestDelegate) plainTokenReadFn(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error {
|
||||
_, token, err := testIdentityForToken(args.TokenID)
|
||||
if token != nil {
|
||||
reply.Token = token.(*structs.ACLToken)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *ACLResolverTestDelegate) defaultPolicyResolveFn(errAfterCached error) func(*structs.ACLPolicyBatchGetRequest, *structs.ACLPolicyBatchResponse) error {
|
||||
return func(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error {
|
||||
if !d.policyCached {
|
||||
for _, policyID := range args.PolicyIDs {
|
||||
_, policy, _ := testPolicyForID(policyID)
|
||||
if policy != nil {
|
||||
reply.Policies = append(reply.Policies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
err := d.plainPolicyResolveFn(args, reply)
|
||||
d.policyCached = true
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
return errAfterCached
|
||||
}
|
||||
}
|
||||
|
||||
func (d *ACLResolverTestDelegate) plainPolicyResolveFn(args *structs.ACLPolicyBatchGetRequest, reply *structs.ACLPolicyBatchResponse) error {
|
||||
// TODO: if we were being super correct about it, we'd verify the token first
|
||||
// TODO: and possibly return a not-found or permission-denied here
|
||||
|
||||
for _, policyID := range args.PolicyIDs {
|
||||
_, policy, _ := testPolicyForID(policyID)
|
||||
if policy != nil {
|
||||
reply.Policies = append(reply.Policies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *ACLResolverTestDelegate) defaultRoleResolveFn(errAfterCached error) func(*structs.ACLRoleBatchGetRequest, *structs.ACLRoleBatchResponse) error {
|
||||
return func(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error {
|
||||
if !d.roleCached {
|
||||
for _, roleID := range args.RoleIDs {
|
||||
_, role, _ := testRoleForID(roleID)
|
||||
if role != nil {
|
||||
reply.Roles = append(reply.Roles, role)
|
||||
}
|
||||
}
|
||||
|
||||
err := d.plainRoleResolveFn(args, reply)
|
||||
d.roleCached = true
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
return errAfterCached
|
||||
}
|
||||
}
|
||||
|
||||
// plainRoleResolveFn tries to follow the normal logic of ACL.RoleResolve using
|
||||
// the test fixtures.
|
||||
func (d *ACLResolverTestDelegate) plainRoleResolveFn(args *structs.ACLRoleBatchGetRequest, reply *structs.ACLRoleBatchResponse) error {
|
||||
// TODO: if we were being super correct about it, we'd verify the token first
|
||||
// TODO: and possibly return a not-found or permission-denied here
|
||||
|
||||
for _, roleID := range args.RoleIDs {
|
||||
_, role, _ := testRoleForID(roleID)
|
||||
if role != nil {
|
||||
reply.Roles = append(reply.Roles, role)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *ACLResolverTestDelegate) ACLsEnabled() bool {
|
||||
return d.enabled
|
||||
}
|
||||
|
@ -549,7 +599,7 @@ func newTestACLResolver(t *testing.T, delegate ACLResolverDelegate, cb func(*ACL
|
|||
config.ACLDownPolicy = "extend-cache"
|
||||
rconf := &ACLResolverConfig{
|
||||
Config: config,
|
||||
Logger: log.New(os.Stdout, t.Name()+" - ", log.LstdFlags|log.Lmicroseconds),
|
||||
Logger: testutil.TestLoggerWithName(t, t.Name()),
|
||||
CacheConfig: &structs.ACLCachesConfig{
|
||||
Identities: 4,
|
||||
Policies: 4,
|
||||
|
@ -1058,7 +1108,7 @@ func TestACLResolver_DownPolicy(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.NotNil(t, authz2)
|
||||
// testing pointer equality - these will be the same object because it is cached.
|
||||
require.True(t, authz == authz2)
|
||||
require.True(t, authz == authz2, "\n[1]={%+v} != \n[2]={%+v}", authz, authz2)
|
||||
require.True(t, authz2.NodeWrite("foo", nil))
|
||||
})
|
||||
|
||||
|
@ -1445,6 +1495,23 @@ func TestACLResolver_Client(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestACLResolver_Client_TokensPoliciesAndRoles(t *testing.T) {
|
||||
t.Parallel()
|
||||
delegate := &ACLResolverTestDelegate{
|
||||
enabled: true,
|
||||
datacenter: "dc1",
|
||||
legacy: false,
|
||||
localTokens: false,
|
||||
localPolicies: false,
|
||||
localRoles: false,
|
||||
}
|
||||
delegate.tokenReadFn = delegate.plainTokenReadFn
|
||||
delegate.policyResolveFn = delegate.plainPolicyResolveFn
|
||||
delegate.roleResolveFn = delegate.plainRoleResolveFn
|
||||
|
||||
testACLResolver_variousTokens(t, delegate)
|
||||
}
|
||||
|
||||
func TestACLResolver_LocalTokensPoliciesAndRoles(t *testing.T) {
|
||||
t.Parallel()
|
||||
delegate := &ACLResolverTestDelegate{
|
||||
|
@ -1470,31 +1537,43 @@ func TestACLResolver_LocalPoliciesAndRoles(t *testing.T) {
|
|||
localTokens: false,
|
||||
localPolicies: true,
|
||||
localRoles: true,
|
||||
tokenReadFn: func(args *structs.ACLTokenGetRequest, reply *structs.ACLTokenResponse) error {
|
||||
_, token, err := testIdentityForToken(args.TokenID)
|
||||
|
||||
if token != nil {
|
||||
reply.Token = token.(*structs.ACLToken)
|
||||
}
|
||||
return err
|
||||
},
|
||||
}
|
||||
delegate.tokenReadFn = delegate.plainTokenReadFn
|
||||
|
||||
testACLResolver_variousTokens(t, delegate)
|
||||
}
|
||||
|
||||
func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelegate) {
|
||||
t.Helper()
|
||||
r := newTestACLResolver(t, delegate, nil)
|
||||
r := newTestACLResolver(t, delegate, func(config *ACLResolverConfig) {
|
||||
config.Config.ACLTokenTTL = 600 * time.Second
|
||||
config.Config.ACLPolicyTTL = 30 * time.Millisecond
|
||||
config.Config.ACLRoleTTL = 30 * time.Millisecond
|
||||
config.Config.ACLDownPolicy = "extend-cache"
|
||||
})
|
||||
reset := func() {
|
||||
// prevent subtest bleedover
|
||||
r.cache.Purge()
|
||||
delegate.Reset()
|
||||
}
|
||||
|
||||
t.Run("Missing Identity", func(t *testing.T) {
|
||||
runTwiceAndReset := func(name string, f func(t *testing.T)) {
|
||||
t.Helper()
|
||||
defer reset() // reset the stateful resolve AND blow away the cache
|
||||
|
||||
t.Run(name+" (no-cache)", f)
|
||||
delegate.Reset() // allow the stateful resolve functions to reset
|
||||
t.Run(name+" (cached)", f)
|
||||
}
|
||||
|
||||
runTwiceAndReset("Missing Identity", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("doesn't exist")
|
||||
require.Nil(t, authz)
|
||||
require.Error(t, err)
|
||||
require.True(t, acl.IsErrNotFound(err))
|
||||
})
|
||||
|
||||
t.Run("Missing Policy", func(t *testing.T) {
|
||||
runTwiceAndReset("Missing Policy", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("missing-policy")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, authz)
|
||||
|
@ -1502,7 +1581,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega
|
|||
require.False(t, authz.NodeWrite("foo", nil))
|
||||
})
|
||||
|
||||
t.Run("Missing Role", func(t *testing.T) {
|
||||
runTwiceAndReset("Missing Role", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("missing-role")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, authz)
|
||||
|
@ -1510,7 +1589,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega
|
|||
require.False(t, authz.NodeWrite("foo", nil))
|
||||
})
|
||||
|
||||
t.Run("Missing Policy on Role", func(t *testing.T) {
|
||||
runTwiceAndReset("Missing Policy on Role", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("missing-policy-on-role")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, authz)
|
||||
|
@ -1518,7 +1597,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega
|
|||
require.False(t, authz.NodeWrite("foo", nil))
|
||||
})
|
||||
|
||||
t.Run("Normal with Policy", func(t *testing.T) {
|
||||
runTwiceAndReset("Normal with Policy", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("found")
|
||||
require.NotNil(t, authz)
|
||||
require.NoError(t, err)
|
||||
|
@ -1526,7 +1605,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega
|
|||
require.True(t, authz.NodeWrite("foo", nil))
|
||||
})
|
||||
|
||||
t.Run("Normal with Role", func(t *testing.T) {
|
||||
runTwiceAndReset("Normal with Role", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("found-role")
|
||||
require.NotNil(t, authz)
|
||||
require.NoError(t, err)
|
||||
|
@ -1534,7 +1613,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega
|
|||
require.True(t, authz.NodeWrite("foo", nil))
|
||||
})
|
||||
|
||||
t.Run("Normal with Policy and Role", func(t *testing.T) {
|
||||
runTwiceAndReset("Normal with Policy and Role", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("found-policy-and-role")
|
||||
require.NotNil(t, authz)
|
||||
require.NoError(t, err)
|
||||
|
@ -1543,7 +1622,41 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega
|
|||
require.True(t, authz.ServiceRead("bar"))
|
||||
})
|
||||
|
||||
t.Run("Anonymous", func(t *testing.T) {
|
||||
runTwiceAndReset("Synthetic Policies Independently Cache", func(t *testing.T) {
|
||||
// We resolve both of these tokens in the same cache session
|
||||
// to verify that the keys for caching synthetic policies don't bleed
|
||||
// over between each other.
|
||||
{
|
||||
authz, err := r.ResolveToken("found-synthetic-policy-1")
|
||||
require.NotNil(t, authz)
|
||||
require.NoError(t, err)
|
||||
// spot check some random perms
|
||||
require.False(t, authz.ACLRead())
|
||||
require.False(t, authz.NodeWrite("foo", nil))
|
||||
// ensure we didn't bleed over to the other synthetic policy
|
||||
require.False(t, authz.ServiceWrite("service2", nil))
|
||||
// check our own synthetic policy
|
||||
require.True(t, authz.ServiceWrite("service1", nil))
|
||||
require.True(t, authz.ServiceRead("literally-anything"))
|
||||
require.True(t, authz.NodeRead("any-node"))
|
||||
}
|
||||
{
|
||||
authz, err := r.ResolveToken("found-synthetic-policy-2")
|
||||
require.NotNil(t, authz)
|
||||
require.NoError(t, err)
|
||||
// spot check some random perms
|
||||
require.False(t, authz.ACLRead())
|
||||
require.False(t, authz.NodeWrite("foo", nil))
|
||||
// ensure we didn't bleed over to the other synthetic policy
|
||||
require.False(t, authz.ServiceWrite("service1", nil))
|
||||
// check our own synthetic policy
|
||||
require.True(t, authz.ServiceWrite("service2", nil))
|
||||
require.True(t, authz.ServiceRead("literally-anything"))
|
||||
require.True(t, authz.NodeRead("any-node"))
|
||||
}
|
||||
})
|
||||
|
||||
runTwiceAndReset("Anonymous", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("")
|
||||
require.NotNil(t, authz)
|
||||
require.NoError(t, err)
|
||||
|
@ -1551,7 +1664,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega
|
|||
require.True(t, authz.NodeWrite("foo", nil))
|
||||
})
|
||||
|
||||
t.Run("legacy-management", func(t *testing.T) {
|
||||
runTwiceAndReset("legacy-management", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("legacy-management")
|
||||
require.NotNil(t, authz)
|
||||
require.NoError(t, err)
|
||||
|
@ -1559,7 +1672,7 @@ func testACLResolver_variousTokens(t *testing.T, delegate *ACLResolverTestDelega
|
|||
require.True(t, authz.KeyRead("foo"))
|
||||
})
|
||||
|
||||
t.Run("legacy-client", func(t *testing.T) {
|
||||
runTwiceAndReset("legacy-client", func(t *testing.T) {
|
||||
authz, err := r.ResolveToken("legacy-client")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, authz)
|
||||
|
|
|
@ -51,7 +51,7 @@ func testACLTokenReap_Primary(t *testing.T, local, global bool) {
|
|||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
acl := ACL{s1}
|
||||
acl := ACL{srv: s1}
|
||||
|
||||
masterTokenAccessorID, err := retrieveTestTokenAccessorForSecret(codec, "root", "dc1", "root")
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
package authmethod
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
type ValidatorFactory func(method *structs.ACLAuthMethod) (Validator, error)
|
||||
|
||||
type Validator interface {
|
||||
// Name returns the name of the auth method backing this validator.
|
||||
Name() string
|
||||
|
||||
// ValidateLogin takes raw user-provided auth method metadata and ensures
|
||||
// it is sane, provably correct, and currently valid. Relevant identifying
|
||||
// data is extracted and returned for immediate use by the role binding
|
||||
// process.
|
||||
//
|
||||
// Depending upon the method, it may make sense to use these calls to
|
||||
// continue to extend the life of the underlying token.
|
||||
//
|
||||
// Returns auth method specific metadata suitable for the Role Binding
|
||||
// process.
|
||||
ValidateLogin(loginToken string) (map[string]string, error)
|
||||
|
||||
// AvailableFields returns a slice of all fields that are returned as a
|
||||
// result of ValidateLogin. These are valid fields for use in any
|
||||
// BindingRule tied to this auth method.
|
||||
AvailableFields() []string
|
||||
|
||||
// MakeFieldMapSelectable converts a field map as returned by ValidateLogin
|
||||
// into a structure suitable for selection with a binding rule.
|
||||
MakeFieldMapSelectable(fieldMap map[string]string) interface{}
|
||||
}
|
||||
|
||||
var (
|
||||
typesMu sync.RWMutex
|
||||
types = make(map[string]ValidatorFactory)
|
||||
)
|
||||
|
||||
// Register makes an auth method with the given type available for use. If
|
||||
// Register is called twice with the same name or if validator is nil, it
|
||||
// panics.
|
||||
func Register(name string, factory ValidatorFactory) {
|
||||
typesMu.Lock()
|
||||
defer typesMu.Unlock()
|
||||
if factory == nil {
|
||||
panic("authmethod: Register factory is nil for type " + name)
|
||||
}
|
||||
if _, dup := types[name]; dup {
|
||||
panic("authmethod: Register called twice for type " + name)
|
||||
}
|
||||
types[name] = factory
|
||||
}
|
||||
|
||||
func IsRegisteredType(typeName string) bool {
|
||||
typesMu.RLock()
|
||||
_, ok := types[typeName]
|
||||
typesMu.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// NewValidator instantiates a new Validator for the given auth method
|
||||
// configuration. If no auth method is registered with the provided type an
|
||||
// error is returned.
|
||||
func NewValidator(method *structs.ACLAuthMethod) (Validator, error) {
|
||||
typesMu.RLock()
|
||||
factory, ok := types[method.Type]
|
||||
typesMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no auth method registered with type: %s", method.Type)
|
||||
}
|
||||
|
||||
return factory(method)
|
||||
}
|
||||
|
||||
// Types returns a sorted list of the names of the registered types.
|
||||
func Types() []string {
|
||||
typesMu.RLock()
|
||||
defer typesMu.RUnlock()
|
||||
var list []string
|
||||
for name := range types {
|
||||
list = append(list, name)
|
||||
}
|
||||
sort.Strings(list)
|
||||
return list
|
||||
}
|
||||
|
||||
// ParseConfig parses the config block for a auth method.
|
||||
func ParseConfig(rawConfig map[string]interface{}, out interface{}) error {
|
||||
decodeConf := &mapstructure.DecoderConfig{
|
||||
Result: out,
|
||||
WeaklyTypedInput: true,
|
||||
ErrorUnused: true,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(decodeConf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := decoder.Decode(rawConfig); err != nil {
|
||||
return fmt.Errorf("error decoding config: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
package kubeauth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
authv1 "k8s.io/api/authentication/v1"
|
||||
client_metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
k8s "k8s.io/client-go/kubernetes"
|
||||
client_authv1 "k8s.io/client-go/kubernetes/typed/authentication/v1"
|
||||
client_corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
|
||||
client_rest "k8s.io/client-go/rest"
|
||||
cert "k8s.io/client-go/util/cert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// register this as an available auth method type
|
||||
authmethod.Register("kubernetes", func(method *structs.ACLAuthMethod) (authmethod.Validator, error) {
|
||||
v, err := NewValidator(method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
})
|
||||
}
|
||||
|
||||
const (
|
||||
serviceAccountNamespaceField = "serviceaccount.namespace"
|
||||
serviceAccountNameField = "serviceaccount.name"
|
||||
serviceAccountUIDField = "serviceaccount.uid"
|
||||
|
||||
serviceAccountServiceNameAnnotation = "consul.hashicorp.com/service-name"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
// Host must be a host string, a host:port pair, or a URL to the base of
|
||||
// the Kubernetes API server.
|
||||
Host string `json:",omitempty"`
|
||||
|
||||
// PEM encoded CA cert for use by the TLS client used to talk with the
|
||||
// Kubernetes API. Every line must end with a newline: \n
|
||||
CACert string `json:",omitempty"`
|
||||
|
||||
// A service account JWT used to access the TokenReview API to validate
|
||||
// other JWTs during login. It also must be able to read ServiceAccount
|
||||
// annotations.
|
||||
ServiceAccountJWT string `json:",omitempty"`
|
||||
}
|
||||
|
||||
// Validator is the wrapper around the relevant portions of the Kubernetes API
|
||||
// that also conforms to the authmethod.Validator interface.
|
||||
type Validator struct {
|
||||
name string
|
||||
config *Config
|
||||
saGetter client_corev1.ServiceAccountsGetter
|
||||
trGetter client_authv1.TokenReviewsGetter
|
||||
}
|
||||
|
||||
func NewValidator(method *structs.ACLAuthMethod) (*Validator, error) {
|
||||
if method.Type != "kubernetes" {
|
||||
return nil, fmt.Errorf("%q is not a kubernetes auth method", method.Name)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := authmethod.ParseConfig(method.Config, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.Host == "" {
|
||||
return nil, fmt.Errorf("Config.Host is required")
|
||||
}
|
||||
|
||||
if config.CACert == "" {
|
||||
return nil, fmt.Errorf("Config.CACert is required")
|
||||
}
|
||||
if _, err := cert.ParseCertsPEM([]byte(config.CACert)); err != nil {
|
||||
return nil, fmt.Errorf("error parsing kubernetes ca cert: %v", err)
|
||||
}
|
||||
|
||||
// This is the bearer token we give the apiserver to use the API.
|
||||
if config.ServiceAccountJWT == "" {
|
||||
return nil, fmt.Errorf("Config.ServiceAccountJWT is required")
|
||||
}
|
||||
if _, err := jwt.ParseSigned(config.ServiceAccountJWT); err != nil {
|
||||
return nil, fmt.Errorf("Config.ServiceAccountJWT is not a valid JWT: %v", err)
|
||||
}
|
||||
|
||||
transport := cleanhttp.DefaultTransport()
|
||||
client, err := k8s.NewForConfig(&client_rest.Config{
|
||||
Host: config.Host,
|
||||
BearerToken: config.ServiceAccountJWT,
|
||||
Dial: transport.DialContext,
|
||||
TLSClientConfig: client_rest.TLSClientConfig{
|
||||
CAData: []byte(config.CACert),
|
||||
},
|
||||
ContentConfig: client_rest.ContentConfig{
|
||||
ContentType: "application/json",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Validator{
|
||||
name: method.Name,
|
||||
config: &config,
|
||||
saGetter: client.CoreV1(),
|
||||
trGetter: client.AuthenticationV1(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *Validator) Name() string { return v.name }
|
||||
|
||||
func (v *Validator) ValidateLogin(loginToken string) (map[string]string, error) {
|
||||
if _, err := jwt.ParseSigned(loginToken); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse and validate JWT: %v", err)
|
||||
}
|
||||
|
||||
// Check TokenReview for the bulk of the work.
|
||||
trResp, err := v.trGetter.TokenReviews().Create(&authv1.TokenReview{
|
||||
Spec: authv1.TokenReviewSpec{
|
||||
Token: loginToken,
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if trResp.Status.Error != "" {
|
||||
return nil, fmt.Errorf("lookup failed: %s", trResp.Status.Error)
|
||||
}
|
||||
|
||||
if !trResp.Status.Authenticated {
|
||||
return nil, errors.New("lookup failed: service account jwt not valid")
|
||||
}
|
||||
|
||||
// The username is of format: system:serviceaccount:(NAMESPACE):(SERVICEACCOUNT)
|
||||
parts := strings.Split(trResp.Status.User.Username, ":")
|
||||
if len(parts) != 4 {
|
||||
return nil, errors.New("lookup failed: unexpected username format")
|
||||
}
|
||||
|
||||
// Validate the user that comes back from token review is a service account
|
||||
if parts[0] != "system" || parts[1] != "serviceaccount" {
|
||||
return nil, errors.New("lookup failed: username returned is not a service account")
|
||||
}
|
||||
|
||||
var (
|
||||
saNamespace = parts[2]
|
||||
saName = parts[3]
|
||||
saUID = string(trResp.Status.User.UID)
|
||||
)
|
||||
|
||||
// Check to see if there is an override name on the ServiceAccount object.
|
||||
sa, err := v.saGetter.ServiceAccounts(saNamespace).Get(saName, client_metav1.GetOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("annotation lookup failed: %v", err)
|
||||
}
|
||||
|
||||
annotations := sa.GetObjectMeta().GetAnnotations()
|
||||
if serviceNameOverride, ok := annotations[serviceAccountServiceNameAnnotation]; ok {
|
||||
saName = serviceNameOverride
|
||||
}
|
||||
|
||||
return map[string]string{
|
||||
serviceAccountNamespaceField: saNamespace,
|
||||
serviceAccountNameField: saName,
|
||||
serviceAccountUIDField: saUID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *Validator) AvailableFields() []string {
|
||||
return []string{
|
||||
serviceAccountNamespaceField,
|
||||
serviceAccountNameField,
|
||||
serviceAccountUIDField,
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Validator) MakeFieldMapSelectable(fieldMap map[string]string) interface{} {
|
||||
return &k8sFieldDetails{
|
||||
ServiceAccount: k8sFieldDetailsServiceAccount{
|
||||
Namespace: fieldMap[serviceAccountNamespaceField],
|
||||
Name: fieldMap[serviceAccountNameField],
|
||||
UID: fieldMap[serviceAccountUIDField],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type k8sFieldDetails struct {
|
||||
ServiceAccount k8sFieldDetailsServiceAccount `bexpr:"serviceaccount"`
|
||||
}
|
||||
|
||||
type k8sFieldDetailsServiceAccount struct {
|
||||
Namespace string `bexpr:"namespace"`
|
||||
Name string `bexpr:"name"`
|
||||
UID string `bexpr:"uid"`
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
package kubeauth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateLogin(t *testing.T) {
|
||||
testSrv := StartTestAPIServer(t)
|
||||
defer testSrv.Stop()
|
||||
|
||||
testSrv.AuthorizeJWT(goodJWT_A)
|
||||
testSrv.SetAllowedServiceAccount(
|
||||
"default",
|
||||
"demo",
|
||||
"76091af4-4b56-11e9-ac4b-708b11801cbe",
|
||||
"",
|
||||
goodJWT_B,
|
||||
)
|
||||
|
||||
method := &structs.ACLAuthMethod{
|
||||
Name: "test-k8s",
|
||||
Description: "k8s test",
|
||||
Type: "kubernetes",
|
||||
Config: map[string]interface{}{
|
||||
"Host": testSrv.Addr(),
|
||||
"CACert": testSrv.CACert(),
|
||||
"ServiceAccountJWT": goodJWT_A,
|
||||
},
|
||||
}
|
||||
validator, err := NewValidator(method)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("invalid bearer token", func(t *testing.T) {
|
||||
_, err := validator.ValidateLogin("invalid")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("valid bearer token", func(t *testing.T) {
|
||||
fields, err := validator.ValidateLogin(goodJWT_B)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, map[string]string{
|
||||
"serviceaccount.namespace": "default",
|
||||
"serviceaccount.name": "demo",
|
||||
"serviceaccount.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe",
|
||||
}, fields)
|
||||
})
|
||||
|
||||
// annotate the account
|
||||
testSrv.SetAllowedServiceAccount(
|
||||
"default",
|
||||
"demo",
|
||||
"76091af4-4b56-11e9-ac4b-708b11801cbe",
|
||||
"alternate-name",
|
||||
goodJWT_B,
|
||||
)
|
||||
|
||||
t.Run("valid bearer token with annotation", func(t *testing.T) {
|
||||
fields, err := validator.ValidateLogin(goodJWT_B)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, map[string]string{
|
||||
"serviceaccount.namespace": "default",
|
||||
"serviceaccount.name": "alternate-name",
|
||||
"serviceaccount.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe",
|
||||
}, fields)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewValidator(t *testing.T) {
|
||||
ca := connect.TestCA(t, nil)
|
||||
|
||||
type AM = *structs.ACLAuthMethod
|
||||
|
||||
makeAuthMethod := func(f func(method AM)) *structs.ACLAuthMethod {
|
||||
method := &structs.ACLAuthMethod{
|
||||
Name: "test-k8s",
|
||||
Description: "k8s test",
|
||||
Type: "kubernetes",
|
||||
Config: map[string]interface{}{
|
||||
"Host": "https://abc:8443",
|
||||
"CACert": ca.RootCert,
|
||||
"ServiceAccountJWT": goodJWT_A,
|
||||
},
|
||||
}
|
||||
if f != nil {
|
||||
f(method)
|
||||
}
|
||||
return method
|
||||
}
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
method *structs.ACLAuthMethod
|
||||
ok bool
|
||||
}{
|
||||
// bad
|
||||
{"wrong type", makeAuthMethod(func(method AM) {
|
||||
method.Type = "invalid"
|
||||
}), false},
|
||||
{"extra config", makeAuthMethod(func(method AM) {
|
||||
method.Config["extra"] = "config"
|
||||
}), false},
|
||||
{"wrong type of config", makeAuthMethod(func(method AM) {
|
||||
method.Config["Host"] = []int{12345}
|
||||
}), false},
|
||||
{"missing host", makeAuthMethod(func(method AM) {
|
||||
delete(method.Config, "Host")
|
||||
}), false},
|
||||
{"missing ca cert", makeAuthMethod(func(method AM) {
|
||||
delete(method.Config, "CACert")
|
||||
}), false},
|
||||
{"invalid ca cert", makeAuthMethod(func(method AM) {
|
||||
method.Config["CACert"] = "invalid"
|
||||
}), false},
|
||||
{"invalid jwt", makeAuthMethod(func(method AM) {
|
||||
method.Config["ServiceAccountJWT"] = "invalid"
|
||||
}), false},
|
||||
{"garbage host", makeAuthMethod(func(method AM) {
|
||||
method.Config["Host"] = "://:12345"
|
||||
}), false},
|
||||
// good
|
||||
{"normal", makeAuthMethod(nil), true},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
v, err := NewValidator(test.method)
|
||||
if test.ok {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, v)
|
||||
} else {
|
||||
require.NotNil(t, err)
|
||||
require.Nil(t, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 'default/admin'
|
||||
const goodJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA"
|
||||
|
||||
// 'default/demo'
|
||||
const goodJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ"
|
|
@ -0,0 +1,532 @@
|
|||
package kubeauth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
authv1 "k8s.io/api/authentication/v1"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
)
|
||||
|
||||
// TestAPIServer is a way to mock the Kubernetes API server as it is used by
|
||||
// the consul kubernetes auth method.
|
||||
//
|
||||
// - POST /apis/authentication.k8s.io/v1/tokenreviews
|
||||
// - GET /api/v1/namespaces/<NAMESPACE>/serviceaccounts/<NAME>
|
||||
//
|
||||
type TestAPIServer struct {
|
||||
t *testing.T
|
||||
srv *httptest.Server
|
||||
caCert string
|
||||
|
||||
mu sync.Mutex
|
||||
authorizedJWT string // token review and sa read
|
||||
allowedServiceAccountJWT string // general service account
|
||||
replyStatus *authv1.TokenReview // general service account
|
||||
replyRead *corev1.ServiceAccount // general service account
|
||||
}
|
||||
|
||||
// StartTestAPIServer creates a disposable TestAPIServer and binds it to a
|
||||
// random free port.
|
||||
func StartTestAPIServer(t *testing.T) *TestAPIServer {
|
||||
s := &TestAPIServer{t: t}
|
||||
|
||||
s.srv = httptest.NewTLSServer(s)
|
||||
|
||||
bs := s.srv.TLS.Certificates[0].Certificate[0]
|
||||
|
||||
var buf bytes.Buffer
|
||||
require.NoError(t, pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs}))
|
||||
s.caCert = buf.String()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// AuthorizeJWT whitelists the given JWT as able to use the API server.
|
||||
func (s *TestAPIServer) AuthorizeJWT(jwt string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.authorizedJWT = jwt
|
||||
}
|
||||
|
||||
// SetAllowedServiceAccount configures the singular known Service Account
|
||||
// installed in this API server. If any of namespace/name/uid/jwt are empty
|
||||
// it removes anything previously configured.
|
||||
//
|
||||
// It is up to the caller to ensure that the provided JWT matches the other
|
||||
// data.
|
||||
func (s *TestAPIServer) SetAllowedServiceAccount(
|
||||
namespace, name, uid, overrideAnnotation, jwt string,
|
||||
) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if namespace == "" || name == "" || uid == "" || jwt == "" {
|
||||
s.allowedServiceAccountJWT = ""
|
||||
s.replyStatus = nil
|
||||
s.replyRead = nil
|
||||
return
|
||||
}
|
||||
|
||||
s.allowedServiceAccountJWT = jwt
|
||||
s.replyRead = createReadServiceAccountFound(namespace, name, uid, overrideAnnotation, jwt)
|
||||
s.replyStatus = createTokenReviewFound(namespace, name, uid, jwt)
|
||||
}
|
||||
|
||||
// Stop stops the running TestAPIServer.
|
||||
func (s *TestAPIServer) Stop() {
|
||||
s.srv.Close()
|
||||
}
|
||||
|
||||
// Addr returns the current base URL for the running webserver.
|
||||
func (s *TestAPIServer) Addr() string { return s.srv.URL }
|
||||
|
||||
// CACert returns the pem-encoded CA certificate used by the HTTPS server.
|
||||
func (s *TestAPIServer) CACert() string { return s.caCert }
|
||||
|
||||
var readServiceAccountPathRE = regexp.MustCompile("^/api/v1/namespaces/([^/]+)/serviceaccounts/([^/]+)$")
|
||||
|
||||
func (s *TestAPIServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
w.Header().Set("content-type", "application/json")
|
||||
|
||||
if req.URL.Path == "/apis/authentication.k8s.io/v1/tokenreviews" {
|
||||
s.handleTokenReview(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
if m := readServiceAccountPathRE.FindStringSubmatch(req.URL.Path); m != nil {
|
||||
namespace, err := url.QueryUnescape(m[1])
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
name, err := url.QueryUnescape(m[2])
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.handleReadServiceAccount(namespace, name, w, req)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, out interface{}) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(out)
|
||||
}
|
||||
|
||||
func (s *TestAPIServer) handleTokenReview(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != "POST" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if auth, anon := s.isAuthenticated(req); !auth {
|
||||
var out interface{}
|
||||
if anon {
|
||||
out = createTokenReviewForbidden_NoAuthz()
|
||||
} else {
|
||||
out = createTokenReviewForbidden("default", "fake-account")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
if err := writeJSON(w, out); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if req.Body == nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var trReq authv1.TokenReview
|
||||
if err := json.Unmarshal(b, &trReq); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
reviewingJWT := trReq.Spec.Token
|
||||
|
||||
var out interface{}
|
||||
if s.replyStatus == nil || reviewingJWT != s.allowedServiceAccountJWT {
|
||||
out = createTokenReviewNotFound(reviewingJWT)
|
||||
} else {
|
||||
out = s.replyStatus
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
|
||||
if err := writeJSON(w, out); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TestAPIServer) handleReadServiceAccount(
|
||||
namespace, name string,
|
||||
w http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
if req.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var out interface{}
|
||||
if auth, anon := s.isAuthenticated(req); !auth {
|
||||
if anon {
|
||||
out = createReadServiceAccountForbidden_NoAuthz()
|
||||
} else {
|
||||
out = createReadServiceAccountForbidden(namespace, name)
|
||||
}
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
} else if s.replyRead == nil {
|
||||
out = createReadServiceAccountNotFound(namespace, name)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
} else if s.replyRead.Namespace != namespace || s.replyRead.Name != name {
|
||||
out = createReadServiceAccountNotFound(namespace, name)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
} else {
|
||||
out = s.replyRead
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
if err := writeJSON(w, out); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TestAPIServer) isAuthenticated(req *http.Request) (auth, anonymous bool) {
|
||||
authz := req.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authz, "Bearer ") {
|
||||
return false, true
|
||||
}
|
||||
jwt := strings.TrimPrefix(authz, "Bearer ")
|
||||
|
||||
return s.authorizedJWT == jwt, false
|
||||
}
|
||||
|
||||
func createTokenReviewForbidden_NoAuthz() *metav1.Status {
|
||||
/*
|
||||
STATUS: 403
|
||||
{
|
||||
"kind": "Status",
|
||||
"apiVersion": "v1",
|
||||
"metadata": {},
|
||||
"status": "Failure",
|
||||
"message": "tokenreviews.authentication.k8s.io is forbidden: User \"system:anonymous\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" at the cluster scope",
|
||||
"reason": "Forbidden",
|
||||
"details": {
|
||||
"group": "authentication.k8s.io",
|
||||
"kind": "tokenreviews"
|
||||
},
|
||||
"code": 403
|
||||
}
|
||||
*/
|
||||
return createStatus(
|
||||
metav1.StatusFailure,
|
||||
"tokenreviews.authentication.k8s.io is forbidden: User \"system:anonymous\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" in the cluster scope",
|
||||
metav1.StatusReasonForbidden,
|
||||
&metav1.StatusDetails{
|
||||
Group: "authentication.k8s.io",
|
||||
Kind: "tokenreviews",
|
||||
},
|
||||
403,
|
||||
)
|
||||
}
|
||||
|
||||
func createTokenReviewForbidden(namespace, name string) *metav1.Status {
|
||||
/*
|
||||
STATUS: 403
|
||||
{
|
||||
"kind": "Status",
|
||||
"apiVersion": "v1",
|
||||
"metadata": {},
|
||||
"status": "Failure",
|
||||
"message": "tokenreviews.authentication.k8s.io is forbidden: User \"system:serviceaccount:default:admin\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" at the cluster scope",
|
||||
"reason": "Forbidden",
|
||||
"details": {
|
||||
"group": "authentication.k8s.io",
|
||||
"kind": "tokenreviews"
|
||||
},
|
||||
"code": 403
|
||||
}
|
||||
*/
|
||||
return createStatus(
|
||||
metav1.StatusFailure,
|
||||
"tokenreviews.authentication.k8s.io is forbidden: User \"system:serviceaccount:"+namespace+":"+name+"\" cannot create resource \"tokenreviews\" in API group \"authentication.k8s.io\" in the cluster scope",
|
||||
metav1.StatusReasonForbidden,
|
||||
&metav1.StatusDetails{
|
||||
Group: "authentication.k8s.io",
|
||||
Kind: "tokenreviews",
|
||||
},
|
||||
403,
|
||||
)
|
||||
}
|
||||
|
||||
func createTokenReviewNotFound(jwt string) *authv1.TokenReview {
|
||||
/*
|
||||
STATUS: 201
|
||||
{
|
||||
"kind": "TokenReview",
|
||||
"apiVersion": "authentication.k8s.io/v1",
|
||||
"metadata": {
|
||||
"creationTimestamp": null
|
||||
},
|
||||
"spec": {
|
||||
"token": "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImZha2UtdG9rZW4tano2YnYiLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZmFrZSIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6IjgxYTY1Mjg2LTU3YzEtMTFlOS1iYzJhLTQ4ZTZjOGI4ZWNiNSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmZha2UifQ.DqjUXe34SzCP4NCwbhqV9EuksfzmTSLhJzkE_URyufeGJDn-Gw0_JS-_KmxZSdAO0XXNzB1tJNM1NCVW-V6YbThnPUw5WY4V2J6U1W72c2dzNBx_ipBxGBZ632ZnpViIRu6tL2guT36lWa8YnMDF_OY8sHhl_3kJ6MRxNxY41vAuf45mohi3gri46Kpzc3pf1g6PJ-0oogvUsZ2nBFv1mIdciGBV0zejMKc5Bnxur1L-hEQ9EgZrJ7o0yQRCWYgam_yo_M38EsB8b-suTzQJMA-pRgApOb9dHIV6YAE_b3g_pGkJjrPYzV4IJC1CiPfdz1SAjm7e0ARXtZmaoPltjQ"
|
||||
},
|
||||
"status": {
|
||||
"user": {},
|
||||
"error": "[invalid bearer token, Token has been invalidated]"
|
||||
}
|
||||
}
|
||||
*/
|
||||
return &authv1.TokenReview{
|
||||
TypeMeta: metav1.TypeMeta{
|
||||
Kind: "TokenReview",
|
||||
APIVersion: "authentication.k8s.io/v1",
|
||||
},
|
||||
ObjectMeta: metav1.ObjectMeta{},
|
||||
Spec: authv1.TokenReviewSpec{
|
||||
Token: jwt,
|
||||
},
|
||||
Status: authv1.TokenReviewStatus{
|
||||
User: authv1.UserInfo{},
|
||||
Error: "[invalid bearer token, Token has been invalidated]",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createTokenReviewFound(namespace, name, uid, jwt string) *authv1.TokenReview {
|
||||
/*
|
||||
STATUS: 201
|
||||
{
|
||||
"kind": "TokenReview",
|
||||
"apiVersion": "authentication.k8s.io/v1",
|
||||
"metadata": {
|
||||
"creationTimestamp": null
|
||||
},
|
||||
"spec": {
|
||||
"token": "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4tbTljdm4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6IjlmZjUxZmY0LTU1N2UtMTFlOS05Njg3LTQ4ZTZjOGI4ZWNiNSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.UJEphtrN261gy9WCl4ZKjm2PRDLDkc3Xg9VcDGfzyroOqFQ6sog5dVAb9voc5Nc0-H5b1yGwxDViEMucwKvZpA5pi7VEx_OskK-KTWXSmafM0Xg_AvzpU9Ed5TSRno-OhXaAraxdjXoC4myh1ay2DMeHUusJg_ibqcYJrWx-6MO1bH_ObORtAKhoST_8fzkqNAlZmsQ87FinQvYN5mzDXYukl-eeRdBgQUBkWvEb-Ju6cc0-QE4sUQ4IH_fs0fUyX_xc0om0SZGWLP909FTz4V8LxV8kr6L7irxROiS1jn3Fvyc9ur1PamVf3JOPPrOyfmKbaGRiWJM32b3buQw7cg"
|
||||
},
|
||||
"status": {
|
||||
"authenticated": true,
|
||||
"user": {
|
||||
"username": "system:serviceaccount:default:demo",
|
||||
"uid": "9ff51ff4-557e-11e9-9687-48e6c8b8ecb5",
|
||||
"groups": [
|
||||
"system:serviceaccounts",
|
||||
"system:serviceaccounts:default",
|
||||
"system:authenticated"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
return &authv1.TokenReview{
|
||||
TypeMeta: metav1.TypeMeta{
|
||||
Kind: "TokenReview",
|
||||
APIVersion: "authentication.k8s.io/v1",
|
||||
},
|
||||
ObjectMeta: metav1.ObjectMeta{},
|
||||
Spec: authv1.TokenReviewSpec{
|
||||
Token: jwt,
|
||||
},
|
||||
Status: authv1.TokenReviewStatus{
|
||||
Authenticated: true,
|
||||
User: authv1.UserInfo{
|
||||
Username: "system:serviceaccount:" + namespace + ":" + name,
|
||||
UID: uid,
|
||||
Groups: []string{
|
||||
"system:serviceaccounts",
|
||||
"system:serviceaccounts:default",
|
||||
"system:authenticated",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createReadServiceAccountForbidden(namespace, name string) *metav1.Status {
|
||||
/*
|
||||
STATUS: 403
|
||||
{
|
||||
"kind": "Status",
|
||||
"apiVersion": "v1",
|
||||
"metadata": {},
|
||||
"status": "Failure",
|
||||
"message": "serviceaccounts \"demo\" is forbidden: User \"system:serviceaccount:default:admin\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"",
|
||||
"reason": "Forbidden",
|
||||
"details": {
|
||||
"name": "demo",
|
||||
"kind": "serviceaccounts"
|
||||
},
|
||||
"code": 403
|
||||
}
|
||||
*/
|
||||
return createStatus(
|
||||
metav1.StatusFailure,
|
||||
"serviceaccounts \""+name+"\" is forbidden: User \"system:serviceaccount:"+namespace+":"+name+"\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \""+namespace+"\"",
|
||||
metav1.StatusReasonForbidden,
|
||||
&metav1.StatusDetails{
|
||||
Kind: "serviceaccounts",
|
||||
Name: name,
|
||||
},
|
||||
403,
|
||||
)
|
||||
}
|
||||
|
||||
func createReadServiceAccountForbidden_NoAuthz() *metav1.Status {
|
||||
// missing bearer token header 403
|
||||
/*
|
||||
{
|
||||
"kind": "Status",
|
||||
"apiVersion": "v1",
|
||||
"metadata": {},
|
||||
"status": "Failure",
|
||||
"message": "serviceaccounts \"demo\" is forbidden: User \"system:anonymous\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"",
|
||||
"reason": "Forbidden",
|
||||
"details": {
|
||||
"name": "demo",
|
||||
"kind": "serviceaccounts"
|
||||
},
|
||||
"code": 403
|
||||
}
|
||||
*/
|
||||
return createStatus(
|
||||
metav1.StatusFailure,
|
||||
"serviceaccounts \"PLACEHOLDER\" is forbidden: User \"system:anonymous\" cannot get resource \"serviceaccounts\" in API group \"\" in the namespace \"default\"",
|
||||
metav1.StatusReasonForbidden,
|
||||
&metav1.StatusDetails{
|
||||
Kind: "serviceaccounts",
|
||||
Name: "PLACEHOLDER",
|
||||
},
|
||||
403,
|
||||
)
|
||||
}
|
||||
|
||||
func createReadServiceAccountNotFound(namespace, name string) *metav1.Status {
|
||||
/*
|
||||
STATUS: 404
|
||||
{
|
||||
"kind": "Status",
|
||||
"apiVersion": "v1",
|
||||
"metadata": {},
|
||||
"status": "Failure",
|
||||
"message": "serviceaccounts \"demo\" not found",
|
||||
"reason": "NotFound",
|
||||
"details": {
|
||||
"name": "demo",
|
||||
"kind": "serviceaccounts"
|
||||
},
|
||||
"code": 404
|
||||
}
|
||||
*/
|
||||
return createStatus(
|
||||
metav1.StatusFailure,
|
||||
"serviceaccounts \""+name+"\" not found",
|
||||
metav1.StatusReasonNotFound,
|
||||
&metav1.StatusDetails{
|
||||
Kind: "serviceaccounts",
|
||||
Name: name,
|
||||
},
|
||||
404,
|
||||
)
|
||||
}
|
||||
|
||||
func createReadServiceAccountFound(namespace, name, uid, overrideAnnotation, jwt string) *corev1.ServiceAccount {
|
||||
/*
|
||||
STATUS: 200
|
||||
{
|
||||
"kind": "ServiceAccount",
|
||||
"apiVersion": "v1",
|
||||
"metadata": {
|
||||
"name": "demo",
|
||||
"namespace": "default",
|
||||
"selfLink": "/api/v1/namespaces/default/serviceaccounts/demo",
|
||||
"uid": "9ff51ff4-557e-11e9-9687-48e6c8b8ecb5",
|
||||
"resourceVersion": "2101",
|
||||
"creationTimestamp": "2019-04-02T19:36:34Z",
|
||||
"annotations": {
|
||||
"consul.hashicorp.com/service-name": "actual",
|
||||
"kubectl.kubernetes.io/last-applied-configuration": "{\"apiVersion\":\"v1\",\"kind\":\"ServiceAccount\",\"metadata\":{\"annotations\":{\"consul.hashicorp.com/service-name\":\"actual\"},\"name\":\"demo\",\"namespace\":\"default\"}}\n"
|
||||
}
|
||||
},
|
||||
"secrets": [
|
||||
{
|
||||
"name": "demo-token-m9cvn"
|
||||
}
|
||||
]
|
||||
}
|
||||
*/
|
||||
sa := &corev1.ServiceAccount{
|
||||
TypeMeta: metav1.TypeMeta{
|
||||
Kind: "ServiceAccount",
|
||||
APIVersion: "v1",
|
||||
},
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: name,
|
||||
Namespace: namespace,
|
||||
SelfLink: "/api/v1/namespaces/" + namespace + "/serviceaccounts/" + name,
|
||||
UID: types.UID(uid),
|
||||
ResourceVersion: "123",
|
||||
CreationTimestamp: metav1.Time{Time: time.Now()},
|
||||
},
|
||||
Secrets: []corev1.ObjectReference{
|
||||
corev1.ObjectReference{
|
||||
Name: name + "-token-m9cvn",
|
||||
},
|
||||
},
|
||||
}
|
||||
if overrideAnnotation != "" {
|
||||
sa.ObjectMeta.Annotations = map[string]string{
|
||||
"consul.hashicorp.com/service-name": overrideAnnotation,
|
||||
}
|
||||
}
|
||||
|
||||
return sa
|
||||
}
|
||||
|
||||
func createStatus(status, message string, reason metav1.StatusReason, details *metav1.StatusDetails, code int32) *metav1.Status {
|
||||
return &metav1.Status{
|
||||
TypeMeta: metav1.TypeMeta{
|
||||
Kind: "Status",
|
||||
APIVersion: "v1",
|
||||
},
|
||||
ListMeta: metav1.ListMeta{},
|
||||
Status: status,
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
Details: details,
|
||||
Code: code,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,166 @@
|
|||
package testauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
func init() {
|
||||
authmethod.Register("testing", newValidator)
|
||||
}
|
||||
|
||||
var (
|
||||
tokenDatabaseMu sync.Mutex
|
||||
tokenDatabase map[string]map[string]map[string]string // session => token => fieldmap
|
||||
)
|
||||
|
||||
func StartSession() string {
|
||||
sessionID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return sessionID
|
||||
}
|
||||
|
||||
func ResetSession(sessionID string) {
|
||||
tokenDatabaseMu.Lock()
|
||||
defer tokenDatabaseMu.Unlock()
|
||||
if tokenDatabase != nil {
|
||||
delete(tokenDatabase, sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func InstallSessionToken(sessionID string, token string, namespace, name, uid string) {
|
||||
fields := map[string]string{
|
||||
serviceAccountNamespaceField: namespace,
|
||||
serviceAccountNameField: name,
|
||||
serviceAccountUIDField: uid,
|
||||
}
|
||||
|
||||
tokenDatabaseMu.Lock()
|
||||
defer tokenDatabaseMu.Unlock()
|
||||
if tokenDatabase == nil {
|
||||
tokenDatabase = make(map[string]map[string]map[string]string)
|
||||
}
|
||||
sdb, ok := tokenDatabase[sessionID]
|
||||
if !ok {
|
||||
sdb = make(map[string]map[string]string)
|
||||
tokenDatabase[sessionID] = sdb
|
||||
}
|
||||
sdb[token] = fields
|
||||
}
|
||||
|
||||
func GetSessionToken(sessionID string, token string) (map[string]string, bool) {
|
||||
tokenDatabaseMu.Lock()
|
||||
defer tokenDatabaseMu.Unlock()
|
||||
if tokenDatabase == nil {
|
||||
return nil, false
|
||||
}
|
||||
sdb, ok := tokenDatabase[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
fields, ok := sdb[token]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fmCopy := make(map[string]string)
|
||||
for k, v := range fields {
|
||||
fmCopy[k] = v
|
||||
}
|
||||
|
||||
return fmCopy, true
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
SessionID string // unique identifier for this set of tokens in the database
|
||||
}
|
||||
|
||||
func newValidator(method *structs.ACLAuthMethod) (authmethod.Validator, error) {
|
||||
if method.Type != "testing" {
|
||||
return nil, fmt.Errorf("%q is not a testing auth method", method.Name)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := authmethod.ParseConfig(method.Config, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.SessionID == "" {
|
||||
// If you don't explicitly create one, we create a random one but you
|
||||
// won't have access to it. Useful if you are testing everything EXCEPT
|
||||
// ValidateToken().
|
||||
config.SessionID = StartSession()
|
||||
}
|
||||
|
||||
return &Validator{
|
||||
name: method.Name,
|
||||
config: &config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Validator struct {
|
||||
name string
|
||||
config *Config
|
||||
}
|
||||
|
||||
func (v *Validator) Name() string { return v.name }
|
||||
|
||||
// ValidateLogin takes raw user-provided auth method metadata and ensures it is
|
||||
// sane, provably correct, and currently valid. Relevant identifying data is
|
||||
// extracted and returned for immediate use by the role binding process.
|
||||
//
|
||||
// Depending upon the method, it may make sense to use these calls to continue
|
||||
// to extend the life of the underlying token.
|
||||
//
|
||||
// Returns auth method specific metadata suitable for the Role Binding process.
|
||||
func (v *Validator) ValidateLogin(loginToken string) (map[string]string, error) {
|
||||
fields, valid := GetSessionToken(v.config.SessionID, loginToken)
|
||||
if !valid {
|
||||
return nil, acl.ErrNotFound
|
||||
}
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
func (v *Validator) AvailableFields() []string { return availableFields }
|
||||
|
||||
const (
|
||||
serviceAccountNamespaceField = "serviceaccount.namespace"
|
||||
serviceAccountNameField = "serviceaccount.name"
|
||||
serviceAccountUIDField = "serviceaccount.uid"
|
||||
)
|
||||
|
||||
var availableFields = []string{
|
||||
serviceAccountNamespaceField,
|
||||
serviceAccountNameField,
|
||||
serviceAccountUIDField,
|
||||
}
|
||||
|
||||
// MakeFieldMapSelectable converts a field map as returned by ValidateLogin
|
||||
// into a structure suitable for selection with a binding rule.
|
||||
func (v *Validator) MakeFieldMapSelectable(fieldMap map[string]string) interface{} {
|
||||
return &selectableVars{
|
||||
ServiceAccount: selectableServiceAccount{
|
||||
Namespace: fieldMap[serviceAccountNamespaceField],
|
||||
Name: fieldMap[serviceAccountNameField],
|
||||
UID: fieldMap[serviceAccountUIDField],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type selectableVars struct {
|
||||
ServiceAccount selectableServiceAccount `bexpr:"serviceaccount"`
|
||||
}
|
||||
|
||||
type selectableServiceAccount struct {
|
||||
Namespace string `bexpr:"namespace"`
|
||||
Name string `bexpr:"name"`
|
||||
UID string `bexpr:"uid"`
|
||||
}
|
|
@ -32,6 +32,10 @@ func init() {
|
|||
registerCommand(structs.ConfigEntryRequestType, (*FSM).applyConfigEntryOperation)
|
||||
registerCommand(structs.ACLRoleSetRequestType, (*FSM).applyACLRoleSetOperation)
|
||||
registerCommand(structs.ACLRoleDeleteRequestType, (*FSM).applyACLRoleDeleteOperation)
|
||||
registerCommand(structs.ACLBindingRuleSetRequestType, (*FSM).applyACLBindingRuleSetOperation)
|
||||
registerCommand(structs.ACLBindingRuleDeleteRequestType, (*FSM).applyACLBindingRuleDeleteOperation)
|
||||
registerCommand(structs.ACLAuthMethodSetRequestType, (*FSM).applyACLAuthMethodSetOperation)
|
||||
registerCommand(structs.ACLAuthMethodDeleteRequestType, (*FSM).applyACLAuthMethodDeleteOperation)
|
||||
}
|
||||
|
||||
func (c *FSM) applyRegister(buf []byte, index uint64) interface{} {
|
||||
|
@ -476,3 +480,47 @@ func (c *FSM) applyACLRoleDeleteOperation(buf []byte, index uint64) interface{}
|
|||
|
||||
return c.state.ACLRoleBatchDelete(index, req.RoleIDs)
|
||||
}
|
||||
|
||||
func (c *FSM) applyACLBindingRuleSetOperation(buf []byte, index uint64) interface{} {
|
||||
var req structs.ACLBindingRuleBatchSetRequest
|
||||
if err := structs.Decode(buf, &req); err != nil {
|
||||
panic(fmt.Errorf("failed to decode request: %v", err))
|
||||
}
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "bindingrule"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: "upsert"}})
|
||||
|
||||
return c.state.ACLBindingRuleBatchSet(index, req.BindingRules)
|
||||
}
|
||||
|
||||
func (c *FSM) applyACLBindingRuleDeleteOperation(buf []byte, index uint64) interface{} {
|
||||
var req structs.ACLBindingRuleBatchDeleteRequest
|
||||
if err := structs.Decode(buf, &req); err != nil {
|
||||
panic(fmt.Errorf("failed to decode request: %v", err))
|
||||
}
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "bindingrule"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: "delete"}})
|
||||
|
||||
return c.state.ACLBindingRuleBatchDelete(index, req.BindingRuleIDs)
|
||||
}
|
||||
|
||||
func (c *FSM) applyACLAuthMethodSetOperation(buf []byte, index uint64) interface{} {
|
||||
var req structs.ACLAuthMethodBatchSetRequest
|
||||
if err := structs.Decode(buf, &req); err != nil {
|
||||
panic(fmt.Errorf("failed to decode request: %v", err))
|
||||
}
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "authmethod"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: "upsert"}})
|
||||
|
||||
return c.state.ACLAuthMethodBatchSet(index, req.AuthMethods)
|
||||
}
|
||||
|
||||
func (c *FSM) applyACLAuthMethodDeleteOperation(buf []byte, index uint64) interface{} {
|
||||
var req structs.ACLAuthMethodBatchDeleteRequest
|
||||
if err := structs.Decode(buf, &req); err != nil {
|
||||
panic(fmt.Errorf("failed to decode request: %v", err))
|
||||
}
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "acl", "authmethod"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: "delete"}})
|
||||
|
||||
return c.state.ACLAuthMethodBatchDelete(index, req.AuthMethodNames)
|
||||
}
|
||||
|
|
|
@ -29,6 +29,8 @@ func init() {
|
|||
registerRestorer(structs.ACLPolicySetRequestType, restorePolicy)
|
||||
registerRestorer(structs.ConfigEntryRequestType, restoreConfigEntry)
|
||||
registerRestorer(structs.ACLRoleSetRequestType, restoreRole)
|
||||
registerRestorer(structs.ACLBindingRuleSetRequestType, restoreBindingRule)
|
||||
registerRestorer(structs.ACLAuthMethodSetRequestType, restoreAuthMethod)
|
||||
}
|
||||
|
||||
func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error {
|
||||
|
@ -218,6 +220,34 @@ func (s *snapshot) persistACLs(sink raft.SnapshotSink,
|
|||
}
|
||||
}
|
||||
|
||||
rules, err := s.state.ACLBindingRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for rule := rules.Next(); rule != nil; rule = rules.Next() {
|
||||
if _, err := sink.Write([]byte{byte(structs.ACLBindingRuleSetRequestType)}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := encoder.Encode(rule.(*structs.ACLBindingRule)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
methods, err := s.state.ACLAuthMethods()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for method := methods.Next(); method != nil; method = rules.Next() {
|
||||
if _, err := sink.Write([]byte{byte(structs.ACLAuthMethodSetRequestType)}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := encoder.Encode(method.(*structs.ACLAuthMethod)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -626,3 +656,19 @@ func restoreRole(header *snapshotHeader, restore *state.Restore, decoder *codec.
|
|||
}
|
||||
return restore.ACLRole(&req)
|
||||
}
|
||||
|
||||
func restoreBindingRule(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
|
||||
var req structs.ACLBindingRule
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
return restore.ACLBindingRule(&req)
|
||||
}
|
||||
|
||||
func restoreAuthMethod(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
|
||||
var req structs.ACLAuthMethod
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
return restore.ACLAuthMethod(&req)
|
||||
}
|
||||
|
|
|
@ -86,7 +86,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
|||
})
|
||||
session := &structs.Session{ID: generateUUID(), Node: "foo"}
|
||||
fsm.state.SessionCreate(9, session)
|
||||
policy := structs.ACLPolicy{
|
||||
policy := &structs.ACLPolicy{
|
||||
ID: structs.ACLPolicyGlobalManagementID,
|
||||
Name: "global-management",
|
||||
Description: "Builtin Policy that grants unlimited access",
|
||||
|
@ -94,7 +94,20 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
|||
Syntax: acl.SyntaxCurrent,
|
||||
}
|
||||
policy.SetHash(true)
|
||||
require.NoError(fsm.state.ACLPolicySet(1, &policy))
|
||||
require.NoError(fsm.state.ACLPolicySet(1, policy))
|
||||
|
||||
role := &structs.ACLRole{
|
||||
ID: "86dedd19-8fae-4594-8294-4e6948a81f9a",
|
||||
Name: "some-role",
|
||||
Description: "test snapshot role",
|
||||
ServiceIdentities: []*structs.ACLServiceIdentity{
|
||||
&structs.ACLServiceIdentity{
|
||||
ServiceName: "example",
|
||||
},
|
||||
},
|
||||
}
|
||||
role.SetHash(true)
|
||||
require.NoError(fsm.state.ACLRoleSet(1, role))
|
||||
|
||||
token := &structs.ACLToken{
|
||||
AccessorID: "30fca056-9fbb-4455-b94a-bf0e2bc575d6",
|
||||
|
@ -112,6 +125,26 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
|||
}
|
||||
require.NoError(fsm.state.ACLBootstrap(10, 0, token, false))
|
||||
|
||||
method := &structs.ACLAuthMethod{
|
||||
Name: "some-method",
|
||||
Type: "testing",
|
||||
Description: "test snapshot auth method",
|
||||
Config: map[string]interface{}{
|
||||
"SessionID": "952ebfa8-2a42-46f0-bcd3-fd98a842000e",
|
||||
},
|
||||
}
|
||||
require.NoError(fsm.state.ACLAuthMethodSet(1, method))
|
||||
|
||||
bindingRule := &structs.ACLBindingRule{
|
||||
ID: "85184c52-5997-4a84-9817-5945f2632a17",
|
||||
Description: "test snapshot binding rule",
|
||||
AuthMethod: "some-method",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
BindType: structs.BindingRuleBindTypeService,
|
||||
BindName: "${serviceaccount.name}",
|
||||
}
|
||||
require.NoError(fsm.state.ACLBindingRuleSet(1, bindingRule))
|
||||
|
||||
fsm.state.KVSSet(11, &structs.DirEntry{
|
||||
Key: "/remove",
|
||||
Value: []byte("foo"),
|
||||
|
@ -314,21 +347,40 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
|||
t.Fatalf("bad index: %d", idx)
|
||||
}
|
||||
|
||||
// Verify ACL Token is restored
|
||||
_, a, err := fsm2.state.ACLTokenGetByAccessor(nil, token.AccessorID)
|
||||
// Verify ACL Binding Rule is restored
|
||||
_, bindingRule2, err := fsm2.state.ACLBindingRuleGetByID(nil, bindingRule.ID)
|
||||
require.NoError(err)
|
||||
require.Equal(token.AccessorID, a.AccessorID)
|
||||
require.Equal(token.ModifyIndex, a.ModifyIndex)
|
||||
require.Equal(bindingRule, bindingRule2)
|
||||
|
||||
// Verify ACL Auth Method is restored
|
||||
_, method2, err := fsm2.state.ACLAuthMethodGetByName(nil, method.Name)
|
||||
require.NoError(err)
|
||||
require.Equal(method, method2)
|
||||
|
||||
// Verify ACL Token is restored
|
||||
_, token2, err := fsm2.state.ACLTokenGetByAccessor(nil, token.AccessorID)
|
||||
require.NoError(err)
|
||||
{
|
||||
// time.Time is tricky to compare generically when it takes a ser/deserialization round trip.
|
||||
require.True(token.CreateTime.Equal(token2.CreateTime))
|
||||
token2.CreateTime = token.CreateTime
|
||||
}
|
||||
require.Equal(token, token2)
|
||||
|
||||
// Verify the acl-token-bootstrap index was restored
|
||||
canBootstrap, index, err := fsm2.state.CanBootstrapACLToken()
|
||||
require.False(canBootstrap)
|
||||
require.True(index > 0)
|
||||
|
||||
// Verify ACL Role is restored
|
||||
_, role2, err := fsm2.state.ACLRoleGetByID(nil, role.ID)
|
||||
require.NoError(err)
|
||||
require.Equal(role, role2)
|
||||
|
||||
// Verify ACL Policy is restored
|
||||
_, policy2, err := fsm2.state.ACLPolicyGetByID(nil, structs.ACLPolicyGlobalManagementID)
|
||||
require.NoError(err)
|
||||
require.Equal(policy.Name, policy2.Name)
|
||||
require.Equal(policy, policy2)
|
||||
|
||||
// Verify tombstones are restored
|
||||
func() {
|
||||
|
|
|
@ -427,6 +427,10 @@ func (s *Server) initializeACLs(upgrade bool) error {
|
|||
// leader.
|
||||
s.acls.cache.Purge()
|
||||
|
||||
// Purge the auth method validators since they could've changed while we
|
||||
// were not leader.
|
||||
s.purgeAuthMethodValidators()
|
||||
|
||||
// Remove any token affected by CVE-2019-8336
|
||||
if !s.InACLDatacenter() {
|
||||
_, token, err := s.fsm.State().ACLTokenGetBySecret(nil, redactedToken)
|
||||
|
|
|
@ -115,6 +115,9 @@ type Server struct {
|
|||
aclTokenReapLock sync.RWMutex
|
||||
aclTokenReapEnabled bool
|
||||
|
||||
aclAuthMethodValidators map[string]*authMethodValidatorEntry
|
||||
aclAuthMethodValidatorLock sync.RWMutex
|
||||
|
||||
// DEPRECATED (ACL-Legacy-Compat) - only needed while we support both
|
||||
// useNewACLs is used to determine whether we can use new ACLs or not
|
||||
useNewACLs int32
|
||||
|
|
|
@ -240,12 +240,20 @@ func tokensTableSchema() *memdb.TableSchema {
|
|||
Indexer: &TokenPoliciesIndex{},
|
||||
},
|
||||
"roles": &memdb.IndexSchema{
|
||||
Name: "roles",
|
||||
// Need to allow missing for the anonymous token
|
||||
Name: "roles",
|
||||
AllowMissing: true,
|
||||
Unique: false,
|
||||
Indexer: &TokenRolesIndex{},
|
||||
},
|
||||
"authmethod": &memdb.IndexSchema{
|
||||
Name: "authmethod",
|
||||
AllowMissing: true,
|
||||
Unique: false,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "AuthMethod",
|
||||
Lowercase: false,
|
||||
},
|
||||
},
|
||||
"local": &memdb.IndexSchema{
|
||||
Name: "local",
|
||||
AllowMissing: false,
|
||||
|
@ -349,10 +357,54 @@ func rolesTableSchema() *memdb.TableSchema {
|
|||
}
|
||||
}
|
||||
|
||||
func bindingRulesTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: "acl-binding-rules",
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.UUIDFieldIndex{
|
||||
Field: "ID",
|
||||
},
|
||||
},
|
||||
"authmethod": &memdb.IndexSchema{
|
||||
Name: "authmethod",
|
||||
AllowMissing: false,
|
||||
Unique: false,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "AuthMethod",
|
||||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func authMethodsTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: "acl-auth-methods",
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "Name",
|
||||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
registerSchema(tokensTableSchema)
|
||||
registerSchema(policiesTableSchema)
|
||||
registerSchema(rolesTableSchema)
|
||||
registerSchema(bindingRulesTableSchema)
|
||||
registerSchema(authMethodsTableSchema)
|
||||
}
|
||||
|
||||
// ACLTokens is used when saving a snapshot
|
||||
|
@ -417,6 +469,46 @@ func (s *Restore) ACLRole(role *structs.ACLRole) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ACLBindingRules is used when saving a snapshot
|
||||
func (s *Snapshot) ACLBindingRules() (memdb.ResultIterator, error) {
|
||||
iter, err := s.tx.Get("acl-binding-rules", "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return iter, nil
|
||||
}
|
||||
|
||||
func (s *Restore) ACLBindingRule(rule *structs.ACLBindingRule) error {
|
||||
if err := s.tx.Insert("acl-binding-rules", rule); err != nil {
|
||||
return fmt.Errorf("failed restoring acl binding rule: %s", err)
|
||||
}
|
||||
|
||||
if err := indexUpdateMaxTxn(s.tx, rule.ModifyIndex, "acl-binding-rules"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ACLAuthMethods is used when saving a snapshot
|
||||
func (s *Snapshot) ACLAuthMethods() (memdb.ResultIterator, error) {
|
||||
iter, err := s.tx.Get("acl-auth-methods", "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return iter, nil
|
||||
}
|
||||
|
||||
func (s *Restore) ACLAuthMethod(method *structs.ACLAuthMethod) error {
|
||||
if err := s.tx.Insert("acl-auth-methods", method); err != nil {
|
||||
return fmt.Errorf("failed restoring acl auth method: %s", err)
|
||||
}
|
||||
|
||||
if err := indexUpdateMaxTxn(s.tx, method.ModifyIndex, "acl-auth-methods"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ACLBootstrap is used to perform a one-time ACL bootstrap operation on a
|
||||
// cluster to get the first management token.
|
||||
func (s *Store) ACLBootstrap(idx, resetIndex uint64, token *structs.ACLToken, legacy bool) error {
|
||||
|
@ -789,6 +881,15 @@ func (s *Store) aclTokenSetTxn(tx *memdb.Txn, idx uint64, token *structs.ACLToke
|
|||
return err
|
||||
}
|
||||
|
||||
if token.AuthMethod != "" {
|
||||
method, err := s.getAuthMethodWithTxn(tx, nil, token.AuthMethod, "id")
|
||||
if err != nil {
|
||||
return err
|
||||
} else if method == nil {
|
||||
return fmt.Errorf("No such auth method with Name: %s", token.AuthMethod)
|
||||
}
|
||||
}
|
||||
|
||||
for _, svcid := range token.ServiceIdentities {
|
||||
if svcid.ServiceName == "" {
|
||||
return fmt.Errorf("Encountered a Token with an empty service identity name in the state store")
|
||||
|
@ -890,7 +991,7 @@ func (s *Store) aclTokenGetTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index st
|
|||
}
|
||||
|
||||
// ACLTokenList is used to list out all of the ACLs in the state store.
|
||||
func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role string) (uint64, structs.ACLTokens, error) {
|
||||
func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role, methodName string) (uint64, structs.ACLTokens, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
|
@ -901,57 +1002,53 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role
|
|||
// to false but for defaulted structs (zero values for both) we want it to list out
|
||||
// all tokens so our checks just ensure that global == local
|
||||
|
||||
if policy != "" && role != "" {
|
||||
return 0, nil, fmt.Errorf("cannot filter by role and policy at the same time")
|
||||
}
|
||||
needLocalityFilter := false
|
||||
if policy == "" && role == "" && methodName == "" {
|
||||
if global == local {
|
||||
iter, err = tx.Get("acl-tokens", "id")
|
||||
} else if global {
|
||||
iter, err = tx.Get("acl-tokens", "local", false)
|
||||
} else {
|
||||
iter, err = tx.Get("acl-tokens", "local", true)
|
||||
}
|
||||
|
||||
if policy != "" {
|
||||
} else if policy != "" && role == "" && methodName == "" {
|
||||
iter, err = tx.Get("acl-tokens", "policies", policy)
|
||||
if err == nil && global != local {
|
||||
iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool {
|
||||
token, ok := raw.(*structs.ACLToken)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
needLocalityFilter = true
|
||||
|
||||
if global && !token.Local {
|
||||
return false
|
||||
} else if local && token.Local {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
} else if role != "" {
|
||||
} else if policy == "" && role != "" && methodName == "" {
|
||||
iter, err = tx.Get("acl-tokens", "roles", role)
|
||||
if err == nil && global != local {
|
||||
iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool {
|
||||
token, ok := raw.(*structs.ACLToken)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
needLocalityFilter = true
|
||||
|
||||
if global && !token.Local {
|
||||
return false
|
||||
} else if local && token.Local {
|
||||
return false
|
||||
}
|
||||
} else if policy == "" && role == "" && methodName != "" {
|
||||
iter, err = tx.Get("acl-tokens", "authmethod", methodName)
|
||||
needLocalityFilter = true
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
} else if global == local {
|
||||
iter, err = tx.Get("acl-tokens", "id")
|
||||
} else if global {
|
||||
iter, err = tx.Get("acl-tokens", "local", false)
|
||||
} else {
|
||||
iter, err = tx.Get("acl-tokens", "local", true)
|
||||
return 0, nil, fmt.Errorf("can only filter by one of policy, role, or methodName at a time")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed acl token lookup: %v", err)
|
||||
}
|
||||
|
||||
if needLocalityFilter && global != local {
|
||||
iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool {
|
||||
token, ok := raw.(*structs.ACLToken)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
if global && !token.Local {
|
||||
return false
|
||||
} else if local && token.Local {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
var result structs.ACLTokens
|
||||
|
@ -1114,6 +1211,35 @@ func (s *Store) aclTokenDeleteTxn(tx *memdb.Txn, idx uint64, value, index string
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) aclTokenDeleteAllForAuthMethodTxn(tx *memdb.Txn, idx uint64, methodName string) error {
|
||||
// collect them all
|
||||
iter, err := tx.Get("acl-tokens", "authmethod", methodName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed acl token lookup: %v", err)
|
||||
}
|
||||
|
||||
var tokens structs.ACLTokens
|
||||
for raw := iter.Next(); raw != nil; raw = iter.Next() {
|
||||
token := raw.(*structs.ACLToken)
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if len(tokens) > 0 {
|
||||
// delete them all
|
||||
for _, token := range tokens {
|
||||
if err := tx.Delete("acl-tokens", token); err != nil {
|
||||
return fmt.Errorf("failed deleting acl token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-tokens"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLPolicyBatchSet(idx uint64, policies structs.ACLPolicies) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
@ -1437,7 +1563,7 @@ func (s *Store) ACLRoleBatchGet(ws memdb.WatchSet, ids []string) (uint64, struct
|
|||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
roles := make(structs.ACLRoles, 0)
|
||||
roles := make(structs.ACLRoles, 0, len(ids))
|
||||
for _, rid := range ids {
|
||||
role, err := s.getRoleWithTxn(tx, ws, rid, "id")
|
||||
if err != nil {
|
||||
|
@ -1579,3 +1705,384 @@ func (s *Store) aclRoleDeleteTxn(tx *memdb.Txn, idx uint64, value, index string)
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLBindingRuleBatchSet(idx uint64, rules structs.ACLBindingRules) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
for _, rule := range rules {
|
||||
if err := s.aclBindingRuleSetTxn(tx, idx, rule); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLBindingRuleSet(idx uint64, rule *structs.ACLBindingRule) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.aclBindingRuleSetTxn(tx, idx, rule); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) aclBindingRuleSetTxn(tx *memdb.Txn, idx uint64, rule *structs.ACLBindingRule) error {
|
||||
// Check that the ID and AuthMethod are set
|
||||
if rule.ID == "" {
|
||||
return ErrMissingACLBindingRuleID
|
||||
} else if rule.AuthMethod == "" {
|
||||
return ErrMissingACLBindingRuleAuthMethod
|
||||
}
|
||||
|
||||
existing, err := tx.First("acl-binding-rules", "id", rule.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed acl binding rule lookup: %v", err)
|
||||
}
|
||||
|
||||
// Set the indexes
|
||||
if existing != nil {
|
||||
rule.CreateIndex = existing.(*structs.ACLBindingRule).CreateIndex
|
||||
rule.ModifyIndex = idx
|
||||
} else {
|
||||
rule.CreateIndex = idx
|
||||
rule.ModifyIndex = idx
|
||||
}
|
||||
|
||||
if method, err := tx.First("acl-auth-methods", "id", rule.AuthMethod); err != nil {
|
||||
return fmt.Errorf("failed acl auth method lookup: %v", err)
|
||||
} else if method == nil {
|
||||
return fmt.Errorf("failed inserting acl binding rule: auth method not found")
|
||||
}
|
||||
|
||||
if err := tx.Insert("acl-binding-rules", rule); err != nil {
|
||||
return fmt.Errorf("failed inserting acl binding rule: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLBindingRuleGetByID(ws memdb.WatchSet, id string) (uint64, *structs.ACLBindingRule, error) {
|
||||
return s.aclBindingRuleGet(ws, id, "id")
|
||||
}
|
||||
|
||||
func (s *Store) aclBindingRuleGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLBindingRule, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
watchCh, rawRule, err := tx.FirstWatch("acl-binding-rules", index, value)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed acl binding rule lookup: %v", err)
|
||||
}
|
||||
ws.Add(watchCh)
|
||||
|
||||
var rule *structs.ACLBindingRule
|
||||
if rawRule != nil {
|
||||
rule = rawRule.(*structs.ACLBindingRule)
|
||||
}
|
||||
|
||||
idx := maxIndexTxn(tx, "acl-binding-rules")
|
||||
|
||||
return idx, rule, nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLBindingRuleList(ws memdb.WatchSet, methodName string) (uint64, structs.ACLBindingRules, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
var (
|
||||
iter memdb.ResultIterator
|
||||
err error
|
||||
)
|
||||
if methodName != "" {
|
||||
iter, err = tx.Get("acl-binding-rules", "authmethod", methodName)
|
||||
} else {
|
||||
iter, err = tx.Get("acl-binding-rules", "id")
|
||||
}
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed acl binding rule lookup: %v", err)
|
||||
}
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
var result structs.ACLBindingRules
|
||||
for raw := iter.Next(); raw != nil; raw = iter.Next() {
|
||||
rule := raw.(*structs.ACLBindingRule)
|
||||
result = append(result, rule)
|
||||
}
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexTxn(tx, "acl-binding-rules")
|
||||
|
||||
return idx, result, nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLBindingRuleDeleteByID(idx uint64, id string) error {
|
||||
return s.aclBindingRuleDelete(idx, id, "id")
|
||||
}
|
||||
|
||||
func (s *Store) ACLBindingRuleBatchDelete(idx uint64, bindingRuleIDs []string) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
for _, bindingRuleID := range bindingRuleIDs {
|
||||
s.aclBindingRuleDeleteTxn(tx, idx, bindingRuleID, "id")
|
||||
}
|
||||
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %v", err)
|
||||
}
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) aclBindingRuleDelete(idx uint64, value, index string) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.aclBindingRuleDeleteTxn(tx, idx, value, index); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %v", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) aclBindingRuleDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error {
|
||||
// Look up the existing binding rule
|
||||
rawRule, err := tx.First("acl-binding-rules", index, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed acl binding rule lookup: %v", err)
|
||||
}
|
||||
|
||||
if rawRule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rule := rawRule.(*structs.ACLBindingRule)
|
||||
|
||||
if err := tx.Delete("acl-binding-rules", rule); err != nil {
|
||||
return fmt.Errorf("failed deleting acl binding rule: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) aclBindingRuleDeleteAllForAuthMethodTxn(tx *memdb.Txn, idx uint64, methodName string) error {
|
||||
// collect them all
|
||||
iter, err := tx.Get("acl-binding-rules", "authmethod", methodName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed acl binding rule lookup: %v", err)
|
||||
}
|
||||
|
||||
var rules structs.ACLBindingRules
|
||||
for raw := iter.Next(); raw != nil; raw = iter.Next() {
|
||||
rule := raw.(*structs.ACLBindingRule)
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
if len(rules) > 0 {
|
||||
// delete them all
|
||||
for _, rule := range rules {
|
||||
if err := tx.Delete("acl-binding-rules", rule); err != nil {
|
||||
return fmt.Errorf("failed deleting acl binding rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-binding-rules"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLAuthMethodBatchSet(idx uint64, methods structs.ACLAuthMethods) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
for _, method := range methods {
|
||||
// this is only used when doing batch insertions for upgrades and replication. Therefore
|
||||
// we take whatever those said.
|
||||
if err := s.aclAuthMethodSetTxn(tx, idx, method); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLAuthMethodSet(idx uint64, method *structs.ACLAuthMethod) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.aclAuthMethodSetTxn(tx, idx, method); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) aclAuthMethodSetTxn(tx *memdb.Txn, idx uint64, method *structs.ACLAuthMethod) error {
|
||||
// Check that the Name and Type are set
|
||||
if method.Name == "" {
|
||||
return ErrMissingACLAuthMethodName
|
||||
} else if method.Type == "" {
|
||||
return ErrMissingACLAuthMethodType
|
||||
}
|
||||
|
||||
existing, err := tx.First("acl-auth-methods", "id", method.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed acl auth method lookup: %v", err)
|
||||
}
|
||||
|
||||
// Set the indexes
|
||||
if existing != nil {
|
||||
method.CreateIndex = existing.(*structs.ACLAuthMethod).CreateIndex
|
||||
method.ModifyIndex = idx
|
||||
} else {
|
||||
method.CreateIndex = idx
|
||||
method.ModifyIndex = idx
|
||||
}
|
||||
|
||||
if err := tx.Insert("acl-auth-methods", method); err != nil {
|
||||
return fmt.Errorf("failed inserting acl auth method: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLAuthMethodGetByName(ws memdb.WatchSet, name string) (uint64, *structs.ACLAuthMethod, error) {
|
||||
return s.aclAuthMethodGet(ws, name, "id")
|
||||
}
|
||||
|
||||
func (s *Store) aclAuthMethodGet(ws memdb.WatchSet, value, index string) (uint64, *structs.ACLAuthMethod, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
method, err := s.getAuthMethodWithTxn(tx, ws, value, index)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
idx := maxIndexTxn(tx, "acl-auth-methods")
|
||||
|
||||
return idx, method, nil
|
||||
}
|
||||
|
||||
func (s *Store) getAuthMethodWithTxn(tx *memdb.Txn, ws memdb.WatchSet, value, index string) (*structs.ACLAuthMethod, error) {
|
||||
watchCh, rawMethod, err := tx.FirstWatch("acl-auth-methods", index, value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed acl auth method lookup: %v", err)
|
||||
}
|
||||
ws.Add(watchCh)
|
||||
|
||||
if rawMethod != nil {
|
||||
return rawMethod.(*structs.ACLAuthMethod), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLAuthMethodList(ws memdb.WatchSet) (uint64, structs.ACLAuthMethods, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
iter, err := tx.Get("acl-auth-methods", "id")
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed acl auth method lookup: %v", err)
|
||||
}
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
var result structs.ACLAuthMethods
|
||||
for raw := iter.Next(); raw != nil; raw = iter.Next() {
|
||||
method := raw.(*structs.ACLAuthMethod)
|
||||
result = append(result, method)
|
||||
}
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexTxn(tx, "acl-auth-methods")
|
||||
|
||||
return idx, result, nil
|
||||
}
|
||||
|
||||
func (s *Store) ACLAuthMethodDeleteByName(idx uint64, name string) error {
|
||||
return s.aclAuthMethodDelete(idx, name, "id")
|
||||
}
|
||||
|
||||
func (s *Store) ACLAuthMethodBatchDelete(idx uint64, names []string) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
for _, name := range names {
|
||||
s.aclAuthMethodDeleteTxn(tx, idx, name, "id")
|
||||
}
|
||||
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %v", err)
|
||||
}
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) aclAuthMethodDelete(idx uint64, value, index string) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.aclAuthMethodDeleteTxn(tx, idx, value, index); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := indexUpdateMaxTxn(tx, idx, "acl-auth-methods"); err != nil {
|
||||
return fmt.Errorf("failed updating index: %v", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) aclAuthMethodDeleteTxn(tx *memdb.Txn, idx uint64, value, index string) error {
|
||||
// Look up the existing method
|
||||
rawMethod, err := tx.First("acl-auth-methods", index, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed acl auth method lookup: %v", err)
|
||||
}
|
||||
|
||||
if rawMethod == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
method := rawMethod.(*structs.ACLAuthMethod)
|
||||
|
||||
if err := s.aclBindingRuleDeleteAllForAuthMethodTxn(tx, idx, method.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.aclTokenDeleteAllForAuthMethodTxn(tx, idx, method.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Delete("acl-auth-methods", method); err != nil {
|
||||
return fmt.Errorf("failed deleting acl auth method: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -37,17 +37,29 @@ var (
|
|||
// policy with an empty Name.
|
||||
ErrMissingACLPolicyName = errors.New("Missing ACL Policy Name")
|
||||
|
||||
// ErrMissingACLRoleID is returned when an role set is called on
|
||||
// ErrMissingACLRoleID is returned when a role set is called on
|
||||
// a role with an empty ID.
|
||||
ErrMissingACLRoleID = errors.New("Missing ACL Role ID")
|
||||
|
||||
// ErrMissingACLRoleName is returned when an role set is called on
|
||||
// ErrMissingACLRoleName is returned when a role set is called on
|
||||
// a role with an empty Name.
|
||||
ErrMissingACLRoleName = errors.New("Missing ACL Role Name")
|
||||
|
||||
// ErrInvalidACLRoleName is returned when an role set is called on
|
||||
// a role with an invalid Name.
|
||||
ErrInvalidACLRoleName = errors.New("Invalid ACL Role Name")
|
||||
// ErrMissingACLBindingRuleID is returned when a binding rule set
|
||||
// is called on a binding rule with an empty ID.
|
||||
ErrMissingACLBindingRuleID = errors.New("Missing ACL Binding Rule ID")
|
||||
|
||||
// ErrMissingACLBindingRuleAuthMethod is returned when a binding rule set
|
||||
// is called on a binding rule with an empty AuthMethod.
|
||||
ErrMissingACLBindingRuleAuthMethod = errors.New("Missing ACL Binding Rule Auth Method")
|
||||
|
||||
// ErrMissingACLAuthMethodName is returned when an auth method set is
|
||||
// called on an auth method with an empty Name.
|
||||
ErrMissingACLAuthMethodName = errors.New("Missing ACL Auth Method Name")
|
||||
|
||||
// ErrMissingACLAuthMethodType is returned when an auth method set is
|
||||
// called on an auth method with an empty Type.
|
||||
ErrMissingACLAuthMethodType = errors.New("Missing ACL Auth Method Type")
|
||||
|
||||
// ErrMissingQueryID is returned when a Query set is called on
|
||||
// a Query with an empty ID.
|
||||
|
|
|
@ -6,10 +6,13 @@ import (
|
|||
"net"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/hashicorp/hil"
|
||||
"github.com/hashicorp/hil/ast"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
)
|
||||
|
||||
|
@ -322,3 +325,42 @@ func ServersGetACLMode(members []serf.Member, leader string, datacenter string)
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
// InterpolateHIL processes the string as if it were HIL and interpolates only
|
||||
// the provided string->string map as possible variables.
|
||||
func InterpolateHIL(s string, vars map[string]string) (string, error) {
|
||||
if strings.Index(s, "${") == -1 {
|
||||
// Skip going to the trouble of parsing something that has no HIL.
|
||||
return s, nil
|
||||
}
|
||||
|
||||
tree, err := hil.Parse(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
vm := make(map[string]ast.Variable)
|
||||
for k, v := range vars {
|
||||
vm[k] = ast.Variable{
|
||||
Type: ast.TypeString,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
|
||||
config := &hil.EvalConfig{
|
||||
GlobalScope: &ast.BasicScope{
|
||||
VarMap: vm,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := hil.Eval(tree, config)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if result.Type != hil.TypeString {
|
||||
return "", fmt.Errorf("generated unexpected hil type: %s", result.Type)
|
||||
}
|
||||
|
||||
return result.Value.(string), nil
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetPrivateIP(t *testing.T) {
|
||||
|
@ -403,3 +404,133 @@ func TestServersMeetMinimumVersion(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterpolateHIL(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
in string
|
||||
vars map[string]string
|
||||
exp string
|
||||
ok bool
|
||||
}{
|
||||
// valid HIL
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
map[string]string{},
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"no vars",
|
||||
"nothing",
|
||||
map[string]string{},
|
||||
"nothing",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"just var",
|
||||
"${item}",
|
||||
map[string]string{"item": "value"},
|
||||
"value",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"var in middle",
|
||||
"before ${item}after",
|
||||
map[string]string{"item": "value"},
|
||||
"before valueafter",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"two vars",
|
||||
"before ${item}after ${more}",
|
||||
map[string]string{"item": "value", "more": "xyz"},
|
||||
"before valueafter xyz",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"missing map val",
|
||||
"${item}",
|
||||
map[string]string{"item": ""},
|
||||
"",
|
||||
true,
|
||||
},
|
||||
// "weird" HIL, but not technically invalid
|
||||
{
|
||||
"just end",
|
||||
"}",
|
||||
map[string]string{},
|
||||
"}",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"var without start",
|
||||
" item }",
|
||||
map[string]string{"item": "value"},
|
||||
" item }",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"two vars missing second start",
|
||||
"before ${ item }after more }",
|
||||
map[string]string{"item": "value", "more": "xyz"},
|
||||
"before valueafter more }",
|
||||
true,
|
||||
},
|
||||
// invalid HIL
|
||||
{
|
||||
"just start",
|
||||
"${",
|
||||
map[string]string{},
|
||||
"",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"backwards",
|
||||
"}${",
|
||||
map[string]string{},
|
||||
"",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"no varname",
|
||||
"${}",
|
||||
map[string]string{},
|
||||
"",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing map key",
|
||||
"${item}",
|
||||
map[string]string{},
|
||||
"",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"var without end",
|
||||
"${ item ",
|
||||
map[string]string{"item": "value"},
|
||||
"",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"two vars missing first end",
|
||||
"before ${ item after ${ more }",
|
||||
map[string]string{"item": "value", "more": "xyz"},
|
||||
"",
|
||||
false,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
out, err := InterpolateHIL(test.in, test.vars)
|
||||
if test.ok {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.exp, out)
|
||||
} else {
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, out, "")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,8 @@ func init() {
|
|||
registerEndpoint("/v1/acl/info/", []string{"GET"}, (*HTTPServer).ACLGet)
|
||||
registerEndpoint("/v1/acl/clone/", []string{"PUT"}, (*HTTPServer).ACLClone)
|
||||
registerEndpoint("/v1/acl/list", []string{"GET"}, (*HTTPServer).ACLList)
|
||||
registerEndpoint("/v1/acl/login", []string{"POST"}, (*HTTPServer).ACLLogin)
|
||||
registerEndpoint("/v1/acl/logout", []string{"POST"}, (*HTTPServer).ACLLogout)
|
||||
registerEndpoint("/v1/acl/replication", []string{"GET"}, (*HTTPServer).ACLReplicationStatus)
|
||||
registerEndpoint("/v1/acl/policies", []string{"GET"}, (*HTTPServer).ACLPolicyList)
|
||||
registerEndpoint("/v1/acl/policy", []string{"PUT"}, (*HTTPServer).ACLPolicyCreate)
|
||||
|
@ -18,6 +20,12 @@ func init() {
|
|||
registerEndpoint("/v1/acl/role", []string{"PUT"}, (*HTTPServer).ACLRoleCreate)
|
||||
registerEndpoint("/v1/acl/role/name/", []string{"GET"}, (*HTTPServer).ACLRoleReadByName)
|
||||
registerEndpoint("/v1/acl/role/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLRoleCRUD)
|
||||
registerEndpoint("/v1/acl/binding-rules", []string{"GET"}, (*HTTPServer).ACLBindingRuleList)
|
||||
registerEndpoint("/v1/acl/binding-rule", []string{"PUT"}, (*HTTPServer).ACLBindingRuleCreate)
|
||||
registerEndpoint("/v1/acl/binding-rule/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLBindingRuleCRUD)
|
||||
registerEndpoint("/v1/acl/auth-methods", []string{"GET"}, (*HTTPServer).ACLAuthMethodList)
|
||||
registerEndpoint("/v1/acl/auth-method", []string{"PUT"}, (*HTTPServer).ACLAuthMethodCreate)
|
||||
registerEndpoint("/v1/acl/auth-method/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).ACLAuthMethodCRUD)
|
||||
registerEndpoint("/v1/acl/rules/translate", []string{"POST"}, (*HTTPServer).ACLRulesTranslate)
|
||||
registerEndpoint("/v1/acl/rules/translate/", []string{"GET"}, (*HTTPServer).ACLRulesTranslateLegacyToken)
|
||||
registerEndpoint("/v1/acl/tokens", []string{"GET"}, (*HTTPServer).ACLTokenList)
|
||||
|
|
|
@ -186,9 +186,12 @@ func (s *ACLServiceIdentity) SyntheticPolicy() *ACLPolicy {
|
|||
rules := fmt.Sprintf(aclPolicyTemplateServiceIdentity, s.ServiceName, s.ServiceName)
|
||||
|
||||
hasher := fnv.New128a()
|
||||
hashID := fmt.Sprintf("%x", hasher.Sum([]byte(rules)))
|
||||
|
||||
policy := &ACLPolicy{}
|
||||
policy.ID = fmt.Sprintf("%x", hasher.Sum([]byte(rules)))
|
||||
policy.Name = fmt.Sprintf("synthetic-policy-%s", policy.ID)
|
||||
policy.ID = hashID
|
||||
policy.Name = fmt.Sprintf("synthetic-policy-%s", hashID)
|
||||
policy.Description = "synthetic policy"
|
||||
policy.Rules = rules
|
||||
policy.Syntax = acl.SyntaxCurrent
|
||||
policy.Datacenters = s.Datacenters
|
||||
|
@ -234,6 +237,9 @@ type ACLToken struct {
|
|||
// to the ACL datacenter and replicated to others.
|
||||
Local bool
|
||||
|
||||
// AuthMethod is the name of the auth method used to create this token.
|
||||
AuthMethod string `json:",omitempty"`
|
||||
|
||||
// ExpirationTime represents the point after which a token should be
|
||||
// considered revoked and is eligible for destruction. The zero value
|
||||
// represents NO expiration.
|
||||
|
@ -309,7 +315,11 @@ func (t *ACLToken) PolicyIDs() []string {
|
|||
}
|
||||
|
||||
func (t *ACLToken) RoleIDs() []string {
|
||||
var ids []string
|
||||
if len(t.Roles) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(t.Roles))
|
||||
for _, link := range t.Roles {
|
||||
ids = append(ids, link.ID)
|
||||
}
|
||||
|
@ -345,7 +355,8 @@ func (t *ACLToken) UsesNonLegacyFields() bool {
|
|||
len(t.Roles) > 0 ||
|
||||
t.Type == "" ||
|
||||
t.HasExpirationTime() ||
|
||||
t.ExpirationTTL != 0
|
||||
t.ExpirationTTL != 0 ||
|
||||
t.AuthMethod != ""
|
||||
}
|
||||
|
||||
func (t *ACLToken) EmbeddedPolicy() *ACLPolicy {
|
||||
|
@ -428,7 +439,7 @@ func (t *ACLToken) SetHash(force bool) []byte {
|
|||
|
||||
func (t *ACLToken) EstimateSize() int {
|
||||
// 41 = 16 (RaftIndex) + 8 (Hash) + 8 (ExpirationTime) + 8 (CreateTime) + 1 (Local)
|
||||
size := 41 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules)
|
||||
size := 41 + len(t.AccessorID) + len(t.SecretID) + len(t.Description) + len(t.Type) + len(t.Rules) + len(t.AuthMethod)
|
||||
for _, link := range t.Policies {
|
||||
size += len(link.ID) + len(link.Name)
|
||||
}
|
||||
|
@ -451,6 +462,7 @@ type ACLTokenListStub struct {
|
|||
Roles []ACLTokenRoleLink `json:",omitempty"`
|
||||
ServiceIdentities []*ACLServiceIdentity `json:",omitempty"`
|
||||
Local bool
|
||||
AuthMethod string `json:",omitempty"`
|
||||
ExpirationTime *time.Time `json:",omitempty"`
|
||||
CreateTime time.Time `json:",omitempty"`
|
||||
Hash []byte
|
||||
|
@ -469,6 +481,7 @@ func (token *ACLToken) Stub() *ACLTokenListStub {
|
|||
Roles: token.Roles,
|
||||
ServiceIdentities: token.ServiceIdentities,
|
||||
Local: token.Local,
|
||||
AuthMethod: token.AuthMethod,
|
||||
ExpirationTime: token.ExpirationTime,
|
||||
CreateTime: token.CreateTime,
|
||||
Hash: token.Hash,
|
||||
|
@ -722,8 +735,6 @@ type ACLRole struct {
|
|||
ID string
|
||||
|
||||
// Name is the unique name to reference the role by.
|
||||
//
|
||||
// Validated with structs.isValidRoleName()
|
||||
Name string
|
||||
|
||||
// Description is a human readable description (Optional)
|
||||
|
@ -819,6 +830,136 @@ func (r *ACLRole) EstimateSize() int {
|
|||
return size
|
||||
}
|
||||
|
||||
const (
|
||||
// BindingRuleBindTypeService is the binding rule bind type that
|
||||
// assigns a Service Identity to the token that is created using the value
|
||||
// of the computed BindName as the ServiceName like:
|
||||
//
|
||||
// &ACLToken{
|
||||
// ...other fields...
|
||||
// ServiceIdentities: []*ACLServiceIdentity{
|
||||
// &ACLServiceIdentity{
|
||||
// ServiceName: "<computed BindName>",
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
BindingRuleBindTypeService = "service"
|
||||
|
||||
// BindingRuleBindTypeRole is the binding rule bind type that only allows
|
||||
// the binding rule to function if a role with the given name (BindName)
|
||||
// exists at login-time. If it does the token that is created is directly
|
||||
// linked to that role like:
|
||||
//
|
||||
// &ACLToken{
|
||||
// ...other fields...
|
||||
// Roles: []ACLTokenRoleLink{
|
||||
// { Name: "<computed BindName>" }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// If it does not exist at login-time the rule is ignored.
|
||||
BindingRuleBindTypeRole = "role"
|
||||
)
|
||||
|
||||
type ACLBindingRule struct {
|
||||
// ID is the internal UUID associated with the binding rule
|
||||
ID string
|
||||
|
||||
// Description is a human readable description (Optional)
|
||||
Description string
|
||||
|
||||
// AuthMethod is the name of the auth method for which this rule applies.
|
||||
AuthMethod string
|
||||
|
||||
// Selector is an expression that matches against verified identity
|
||||
// attributes returned from the auth method during login.
|
||||
Selector string
|
||||
|
||||
// BindType adjusts how this binding rule is applied at login time. The
|
||||
// valid values are:
|
||||
//
|
||||
// - BindingRuleBindTypeService = "service"
|
||||
// - BindingRuleBindTypeRole = "role"
|
||||
BindType string
|
||||
|
||||
// BindName is the target of the binding. Can be lightly templated using
|
||||
// HIL ${foo} syntax from available field names. How it is used depends
|
||||
// upon the BindType.
|
||||
BindName string
|
||||
|
||||
// Embedded Raft Metadata
|
||||
RaftIndex `hash:"ignore"`
|
||||
}
|
||||
|
||||
func (r *ACLBindingRule) Clone() *ACLBindingRule {
|
||||
r2 := *r
|
||||
return &r2
|
||||
}
|
||||
|
||||
type ACLBindingRules []*ACLBindingRule
|
||||
|
||||
func (rules ACLBindingRules) Sort() {
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
return rules[i].ID < rules[j].ID
|
||||
})
|
||||
}
|
||||
|
||||
type ACLAuthMethodListStub struct {
|
||||
Name string
|
||||
Description string
|
||||
Type string
|
||||
CreateIndex uint64
|
||||
ModifyIndex uint64
|
||||
}
|
||||
|
||||
func (p *ACLAuthMethod) Stub() *ACLAuthMethodListStub {
|
||||
return &ACLAuthMethodListStub{
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Type: p.Type,
|
||||
CreateIndex: p.CreateIndex,
|
||||
ModifyIndex: p.ModifyIndex,
|
||||
}
|
||||
}
|
||||
|
||||
type ACLAuthMethods []*ACLAuthMethod
|
||||
type ACLAuthMethodListStubs []*ACLAuthMethodListStub
|
||||
|
||||
func (methods ACLAuthMethods) Sort() {
|
||||
sort.Slice(methods, func(i, j int) bool {
|
||||
return methods[i].Name < methods[j].Name
|
||||
})
|
||||
}
|
||||
|
||||
func (methods ACLAuthMethodListStubs) Sort() {
|
||||
sort.Slice(methods, func(i, j int) bool {
|
||||
return methods[i].Name < methods[j].Name
|
||||
})
|
||||
}
|
||||
|
||||
type ACLAuthMethod struct {
|
||||
// Name is a unique identifier for this specific auth method.
|
||||
//
|
||||
// Immutable once set and only settable during create.
|
||||
Name string
|
||||
|
||||
// Type is the type of the auth method this is.
|
||||
//
|
||||
// Immutable once set and only settable during create.
|
||||
Type string
|
||||
|
||||
// Description is just an optional bunch of explanatory text.
|
||||
Description string
|
||||
|
||||
// Configuration is arbitrary configuration for the auth method. This
|
||||
// should only contain primitive values and containers (such as lists and
|
||||
// maps).
|
||||
Config map[string]interface{}
|
||||
|
||||
// Embedded Raft Metadata
|
||||
RaftIndex `hash:"ignore"`
|
||||
}
|
||||
|
||||
type ACLReplicationType string
|
||||
|
||||
const (
|
||||
|
@ -898,6 +1039,7 @@ type ACLTokenListRequest struct {
|
|||
IncludeGlobal bool // Whether global tokens should be included
|
||||
Policy string // Policy filter
|
||||
Role string // Role filter
|
||||
AuthMethod string // Auth Method filter
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
QueryOptions
|
||||
}
|
||||
|
@ -1068,7 +1210,7 @@ func cloneStringSlice(s []string) []string {
|
|||
|
||||
// ACLRoleSetRequest is used at the RPC layer for creation and update requests
|
||||
type ACLRoleSetRequest struct {
|
||||
Role ACLRole // The policy to upsert
|
||||
Role ACLRole // The role to upsert
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
WriteRequest
|
||||
}
|
||||
|
@ -1154,3 +1296,161 @@ type ACLRoleBatchSetRequest struct {
|
|||
type ACLRoleBatchDeleteRequest struct {
|
||||
RoleIDs []string
|
||||
}
|
||||
|
||||
// ACLBindingRuleSetRequest is used at the RPC layer for creation and update requests
|
||||
type ACLBindingRuleSetRequest struct {
|
||||
BindingRule ACLBindingRule // The rule to upsert
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
WriteRequest
|
||||
}
|
||||
|
||||
func (r *ACLBindingRuleSetRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
||||
// ACLBindingRuleDeleteRequest is used at the RPC layer deletion requests
|
||||
type ACLBindingRuleDeleteRequest struct {
|
||||
BindingRuleID string // id of the rule to delete
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
WriteRequest
|
||||
}
|
||||
|
||||
func (r *ACLBindingRuleDeleteRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
||||
// ACLBindingRuleGetRequest is used at the RPC layer to perform rule read operations
|
||||
type ACLBindingRuleGetRequest struct {
|
||||
BindingRuleID string // id used for the rule lookup
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
QueryOptions
|
||||
}
|
||||
|
||||
func (r *ACLBindingRuleGetRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
||||
// ACLBindingRuleListRequest is used at the RPC layer to request a listing of rules
|
||||
type ACLBindingRuleListRequest struct {
|
||||
AuthMethod string // optional filter
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
QueryOptions
|
||||
}
|
||||
|
||||
func (r *ACLBindingRuleListRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
||||
type ACLBindingRuleListResponse struct {
|
||||
BindingRules ACLBindingRules
|
||||
QueryMeta
|
||||
}
|
||||
|
||||
// ACLBindingRuleResponse returns a single binding + metadata
|
||||
type ACLBindingRuleResponse struct {
|
||||
BindingRule *ACLBindingRule
|
||||
QueryMeta
|
||||
}
|
||||
|
||||
// ACLBindingRuleBatchSetRequest is used at the Raft layer for batching
|
||||
// multiple rule creations and updates
|
||||
type ACLBindingRuleBatchSetRequest struct {
|
||||
BindingRules ACLBindingRules
|
||||
}
|
||||
|
||||
// ACLBindingRuleBatchDeleteRequest is used at the Raft layer for batching
|
||||
// multiple rule deletions
|
||||
type ACLBindingRuleBatchDeleteRequest struct {
|
||||
BindingRuleIDs []string
|
||||
}
|
||||
|
||||
// ACLAuthMethodSetRequest is used at the RPC layer for creation and update requests
|
||||
type ACLAuthMethodSetRequest struct {
|
||||
AuthMethod ACLAuthMethod // The auth method to upsert
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
WriteRequest
|
||||
}
|
||||
|
||||
func (r *ACLAuthMethodSetRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
||||
// ACLAuthMethodDeleteRequest is used at the RPC layer deletion requests
|
||||
type ACLAuthMethodDeleteRequest struct {
|
||||
AuthMethodName string // name of the auth method to delete
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
WriteRequest
|
||||
}
|
||||
|
||||
func (r *ACLAuthMethodDeleteRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
||||
// ACLAuthMethodGetRequest is used at the RPC layer to perform rule read operations
|
||||
type ACLAuthMethodGetRequest struct {
|
||||
AuthMethodName string // name used for the auth method lookup
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
QueryOptions
|
||||
}
|
||||
|
||||
func (r *ACLAuthMethodGetRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
||||
// ACLAuthMethodListRequest is used at the RPC layer to request a listing of auth methods
|
||||
type ACLAuthMethodListRequest struct {
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
QueryOptions
|
||||
}
|
||||
|
||||
func (r *ACLAuthMethodListRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
||||
type ACLAuthMethodListResponse struct {
|
||||
AuthMethods ACLAuthMethodListStubs
|
||||
QueryMeta
|
||||
}
|
||||
|
||||
// ACLAuthMethodResponse returns a single auth method + metadata
|
||||
type ACLAuthMethodResponse struct {
|
||||
AuthMethod *ACLAuthMethod
|
||||
QueryMeta
|
||||
}
|
||||
|
||||
// ACLAuthMethodBatchSetRequest is used at the Raft layer for batching
|
||||
// multiple auth method creations and updates
|
||||
type ACLAuthMethodBatchSetRequest struct {
|
||||
AuthMethods ACLAuthMethods
|
||||
}
|
||||
|
||||
// ACLAuthMethodBatchDeleteRequest is used at the Raft layer for batching
|
||||
// multiple auth method deletions
|
||||
type ACLAuthMethodBatchDeleteRequest struct {
|
||||
AuthMethodNames []string
|
||||
}
|
||||
|
||||
type ACLLoginParams struct {
|
||||
AuthMethod string
|
||||
BearerToken string
|
||||
Meta map[string]string `json:",omitempty"`
|
||||
}
|
||||
|
||||
type ACLLoginRequest struct {
|
||||
Auth *ACLLoginParams
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
WriteRequest
|
||||
}
|
||||
|
||||
func (r *ACLLoginRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
||||
type ACLLogoutRequest struct {
|
||||
Datacenter string // The datacenter to perform the request within
|
||||
WriteRequest
|
||||
}
|
||||
|
||||
func (r *ACLLogoutRequest) RequestDatacenter() string {
|
||||
return r.Datacenter
|
||||
}
|
||||
|
|
|
@ -60,7 +60,6 @@ func (e *AuthorizerCacheEntry) Age() time.Duration {
|
|||
return time.Since(e.CacheTime)
|
||||
}
|
||||
|
||||
// RoleCacheEntry is the payload for by by-id and by-name caches.
|
||||
type RoleCacheEntry struct {
|
||||
Role *ACLRole
|
||||
CacheTime time.Time
|
||||
|
@ -173,7 +172,7 @@ func (c *ACLCaches) GetAuthorizer(id string) *AuthorizerCacheEntry {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetRoleByID fetches a role from the cache by id and returns it
|
||||
// GetRole fetches a role from the cache by id and returns it
|
||||
func (c *ACLCaches) GetRole(roleID string) *RoleCacheEntry {
|
||||
if c == nil || c.roles == nil {
|
||||
return nil
|
||||
|
@ -228,9 +227,11 @@ func (c *ACLCaches) PutAuthorizerWithTTL(id string, authorizer acl.Authorizer, t
|
|||
}
|
||||
|
||||
func (c *ACLCaches) PutRole(roleID string, role *ACLRole) {
|
||||
if c != nil && c.roles != nil {
|
||||
c.roles.Add(roleID, &RoleCacheEntry{Role: role, CacheTime: time.Now()})
|
||||
if c == nil || c.roles == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.roles.Add(roleID, &RoleCacheEntry{Role: role, CacheTime: time.Now()})
|
||||
}
|
||||
|
||||
func (c *ACLCaches) RemoveIdentity(id string) {
|
||||
|
@ -246,7 +247,7 @@ func (c *ACLCaches) RemovePolicy(policyID string) {
|
|||
}
|
||||
|
||||
func (c *ACLCaches) RemoveRole(roleID string) {
|
||||
if c != nil && c.roles != nil && roleID != "" {
|
||||
if c != nil && c.roles != nil {
|
||||
c.roles.Remove(roleID)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -113,6 +113,7 @@ func TestStructs_ACLCaches(t *testing.T) {
|
|||
require.NotNil(t, cache)
|
||||
|
||||
cache.PutRole("foo", &ACLRole{})
|
||||
|
||||
entry := cache.GetRole("foo")
|
||||
require.NotNil(t, entry)
|
||||
require.NotNil(t, entry.Role)
|
||||
|
|
|
@ -188,12 +188,13 @@ node_prefix "" {
|
|||
expect := &ACLPolicy{
|
||||
Syntax: acl.SyntaxCurrent,
|
||||
Datacenters: test.datacenters,
|
||||
Description: "synthetic policy",
|
||||
Rules: test.expectRules,
|
||||
}
|
||||
|
||||
got := svcid.SyntheticPolicy()
|
||||
require.NotEmpty(t, got.ID)
|
||||
require.Equal(t, got.Name, "synthetic-policy-"+got.ID)
|
||||
require.True(t, strings.HasPrefix(got.Name, "synthetic-policy-"))
|
||||
// strip irrelevant fields before equality
|
||||
got.ID = ""
|
||||
got.Name = ""
|
||||
|
|
|
@ -33,31 +33,35 @@ type RaftIndex struct {
|
|||
// These are serialized between Consul servers and stored in Consul snapshots,
|
||||
// so entries must only ever be added.
|
||||
const (
|
||||
RegisterRequestType MessageType = 0
|
||||
DeregisterRequestType = 1
|
||||
KVSRequestType = 2
|
||||
SessionRequestType = 3
|
||||
ACLRequestType = 4 // DEPRECATED (ACL-Legacy-Compat)
|
||||
TombstoneRequestType = 5
|
||||
CoordinateBatchUpdateType = 6
|
||||
PreparedQueryRequestType = 7
|
||||
TxnRequestType = 8
|
||||
AutopilotRequestType = 9
|
||||
AreaRequestType = 10
|
||||
ACLBootstrapRequestType = 11
|
||||
IntentionRequestType = 12
|
||||
ConnectCARequestType = 13
|
||||
ConnectCAProviderStateType = 14
|
||||
ConnectCAConfigType = 15 // FSM snapshots only.
|
||||
IndexRequestType = 16 // FSM snapshots only.
|
||||
ACLTokenSetRequestType = 17
|
||||
ACLTokenDeleteRequestType = 18
|
||||
ACLPolicySetRequestType = 19
|
||||
ACLPolicyDeleteRequestType = 20
|
||||
ConnectCALeafRequestType = 21
|
||||
ConfigEntryRequestType = 22
|
||||
ACLRoleSetRequestType = 23
|
||||
ACLRoleDeleteRequestType = 24
|
||||
RegisterRequestType MessageType = 0
|
||||
DeregisterRequestType = 1
|
||||
KVSRequestType = 2
|
||||
SessionRequestType = 3
|
||||
ACLRequestType = 4 // DEPRECATED (ACL-Legacy-Compat)
|
||||
TombstoneRequestType = 5
|
||||
CoordinateBatchUpdateType = 6
|
||||
PreparedQueryRequestType = 7
|
||||
TxnRequestType = 8
|
||||
AutopilotRequestType = 9
|
||||
AreaRequestType = 10
|
||||
ACLBootstrapRequestType = 11
|
||||
IntentionRequestType = 12
|
||||
ConnectCARequestType = 13
|
||||
ConnectCAProviderStateType = 14
|
||||
ConnectCAConfigType = 15 // FSM snapshots only.
|
||||
IndexRequestType = 16 // FSM snapshots only.
|
||||
ACLTokenSetRequestType = 17
|
||||
ACLTokenDeleteRequestType = 18
|
||||
ACLPolicySetRequestType = 19
|
||||
ACLPolicyDeleteRequestType = 20
|
||||
ConnectCALeafRequestType = 21
|
||||
ConfigEntryRequestType = 22
|
||||
ACLRoleSetRequestType = 23
|
||||
ACLRoleDeleteRequestType = 24
|
||||
ACLBindingRuleSetRequestType = 25
|
||||
ACLBindingRuleDeleteRequestType = 26
|
||||
ACLAuthMethodSetRequestType = 27
|
||||
ACLAuthMethodDeleteRequestType = 28
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
364
api/acl.go
364
api/acl.go
|
@ -6,6 +6,8 @@ import (
|
|||
"io/ioutil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -132,6 +134,96 @@ type ACLRole struct {
|
|||
ModifyIndex uint64
|
||||
}
|
||||
|
||||
// BindingRuleBindType is the type of binding rule mechanism used.
|
||||
type BindingRuleBindType string
|
||||
|
||||
const (
|
||||
// BindingRuleBindTypeService binds to a service identity with the given name.
|
||||
BindingRuleBindTypeService BindingRuleBindType = "service"
|
||||
|
||||
// BindingRuleBindTypeRole binds to pre-existing roles with the given name.
|
||||
BindingRuleBindTypeRole BindingRuleBindType = "role"
|
||||
)
|
||||
|
||||
type ACLBindingRule struct {
|
||||
ID string
|
||||
Description string
|
||||
AuthMethod string
|
||||
Selector string
|
||||
BindType BindingRuleBindType
|
||||
BindName string
|
||||
|
||||
CreateIndex uint64
|
||||
ModifyIndex uint64
|
||||
}
|
||||
|
||||
type ACLAuthMethod struct {
|
||||
Name string
|
||||
Type string
|
||||
Description string
|
||||
|
||||
// Configuration is arbitrary configuration for the auth method. This
|
||||
// should only contain primitive values and containers (such as lists and
|
||||
// maps).
|
||||
Config map[string]interface{}
|
||||
|
||||
CreateIndex uint64
|
||||
ModifyIndex uint64
|
||||
}
|
||||
|
||||
type ACLAuthMethodListEntry struct {
|
||||
Name string
|
||||
Type string
|
||||
Description string
|
||||
CreateIndex uint64
|
||||
ModifyIndex uint64
|
||||
}
|
||||
|
||||
// ParseKubernetesAuthMethodConfig takes a raw config map and returns a parsed
|
||||
// KubernetesAuthMethodConfig.
|
||||
func ParseKubernetesAuthMethodConfig(raw map[string]interface{}) (*KubernetesAuthMethodConfig, error) {
|
||||
var config KubernetesAuthMethodConfig
|
||||
decodeConf := &mapstructure.DecoderConfig{
|
||||
Result: &config,
|
||||
WeaklyTypedInput: true,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(decodeConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := decoder.Decode(raw); err != nil {
|
||||
return nil, fmt.Errorf("error decoding config: %s", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// KubernetesAuthMethodConfig is the config for the built-in Consul auth method
|
||||
// for Kubernetes.
|
||||
type KubernetesAuthMethodConfig struct {
|
||||
Host string `json:",omitempty"`
|
||||
CACert string `json:",omitempty"`
|
||||
ServiceAccountJWT string `json:",omitempty"`
|
||||
}
|
||||
|
||||
// RenderToConfig converts this into a map[string]interface{} suitable for use
|
||||
// in the ACLAuthMethod.Config field.
|
||||
func (c *KubernetesAuthMethodConfig) RenderToConfig() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"Host": c.Host,
|
||||
"CACert": c.CACert,
|
||||
"ServiceAccountJWT": c.ServiceAccountJWT,
|
||||
}
|
||||
}
|
||||
|
||||
type ACLLoginParams struct {
|
||||
AuthMethod string
|
||||
BearerToken string
|
||||
Meta map[string]string `json:",omitempty"`
|
||||
}
|
||||
|
||||
// ACL can be used to query the ACL endpoints
|
||||
type ACL struct {
|
||||
c *Client
|
||||
|
@ -498,7 +590,7 @@ func (a *ACL) PolicyCreate(policy *ACLPolicy, q *WriteOptions) (*ACLPolicy, *Wri
|
|||
// existing policy ID
|
||||
func (a *ACL) PolicyUpdate(policy *ACLPolicy, q *WriteOptions) (*ACLPolicy, *WriteMeta, error) {
|
||||
if policy.ID == "" {
|
||||
return nil, nil, fmt.Errorf("Must specify an ID in Policy Creation")
|
||||
return nil, nil, fmt.Errorf("Must specify an ID in Policy Update")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("PUT", "/v1/acl/policy/"+policy.ID)
|
||||
|
@ -654,7 +746,7 @@ func (a *ACL) RoleCreate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta,
|
|||
// existing role ID
|
||||
func (a *ACL) RoleUpdate(role *ACLRole, q *WriteOptions) (*ACLRole, *WriteMeta, error) {
|
||||
if role.ID == "" {
|
||||
return nil, nil, fmt.Errorf("Must specify an ID in Role Creation")
|
||||
return nil, nil, fmt.Errorf("Must specify an ID in Role Update")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("PUT", "/v1/acl/role/"+role.ID)
|
||||
|
@ -763,3 +855,271 @@ func (a *ACL) RoleList(q *QueryOptions) ([]*ACLRole, *QueryMeta, error) {
|
|||
}
|
||||
return entries, qm, nil
|
||||
}
|
||||
|
||||
// AuthMethodCreate will create a new auth method.
|
||||
func (a *ACL) AuthMethodCreate(method *ACLAuthMethod, q *WriteOptions) (*ACLAuthMethod, *WriteMeta, error) {
|
||||
if method.Name == "" {
|
||||
return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Creation")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("PUT", "/v1/acl/auth-method")
|
||||
r.setWriteOptions(q)
|
||||
r.obj = method
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
var out ACLAuthMethod
|
||||
if err := decodeBody(resp, &out); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &out, wm, nil
|
||||
}
|
||||
|
||||
// AuthMethodUpdate updates an auth method.
|
||||
func (a *ACL) AuthMethodUpdate(method *ACLAuthMethod, q *WriteOptions) (*ACLAuthMethod, *WriteMeta, error) {
|
||||
if method.Name == "" {
|
||||
return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Update")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("PUT", "/v1/acl/auth-method/"+url.QueryEscape(method.Name))
|
||||
r.setWriteOptions(q)
|
||||
r.obj = method
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
var out ACLAuthMethod
|
||||
if err := decodeBody(resp, &out); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &out, wm, nil
|
||||
}
|
||||
|
||||
// AuthMethodDelete deletes an auth method given its Name.
|
||||
func (a *ACL) AuthMethodDelete(methodName string, q *WriteOptions) (*WriteMeta, error) {
|
||||
if methodName == "" {
|
||||
return nil, fmt.Errorf("Must specify a Name in Auth Method Delete")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("DELETE", "/v1/acl/auth-method/"+url.QueryEscape(methodName))
|
||||
r.setWriteOptions(q)
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
return wm, nil
|
||||
}
|
||||
|
||||
// AuthMethodRead retrieves the auth method. Returns nil if not found.
|
||||
func (a *ACL) AuthMethodRead(methodName string, q *QueryOptions) (*ACLAuthMethod, *QueryMeta, error) {
|
||||
if methodName == "" {
|
||||
return nil, nil, fmt.Errorf("Must specify a Name in Auth Method Read")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("GET", "/v1/acl/auth-method/"+url.QueryEscape(methodName))
|
||||
r.setQueryOptions(q)
|
||||
found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
qm := &QueryMeta{}
|
||||
parseQueryMeta(resp, qm)
|
||||
qm.RequestTime = rtt
|
||||
|
||||
if !found {
|
||||
return nil, qm, nil
|
||||
}
|
||||
|
||||
var out ACLAuthMethod
|
||||
if err := decodeBody(resp, &out); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &out, qm, nil
|
||||
}
|
||||
|
||||
// AuthMethodList retrieves a listing of all auth methods. The listing does not
|
||||
// include some metadata for the auth method as those should be retrieved by
|
||||
// subsequent calls to AuthMethodRead.
|
||||
func (a *ACL) AuthMethodList(q *QueryOptions) ([]*ACLAuthMethodListEntry, *QueryMeta, error) {
|
||||
r := a.c.newRequest("GET", "/v1/acl/auth-methods")
|
||||
r.setQueryOptions(q)
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
qm := &QueryMeta{}
|
||||
parseQueryMeta(resp, qm)
|
||||
qm.RequestTime = rtt
|
||||
|
||||
var entries []*ACLAuthMethodListEntry
|
||||
if err := decodeBody(resp, &entries); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return entries, qm, nil
|
||||
}
|
||||
|
||||
// BindingRuleCreate will create a new binding rule. It is not allowed for the
|
||||
// binding rule parameter's ID field to be set as this will be generated by
|
||||
// Consul while processing the request.
|
||||
func (a *ACL) BindingRuleCreate(rule *ACLBindingRule, q *WriteOptions) (*ACLBindingRule, *WriteMeta, error) {
|
||||
if rule.ID != "" {
|
||||
return nil, nil, fmt.Errorf("Cannot specify an ID in Binding Rule Creation")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("PUT", "/v1/acl/binding-rule")
|
||||
r.setWriteOptions(q)
|
||||
r.obj = rule
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
var out ACLBindingRule
|
||||
if err := decodeBody(resp, &out); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &out, wm, nil
|
||||
}
|
||||
|
||||
// BindingRuleUpdate updates a binding rule. The ID field of the role binding
|
||||
// rule parameter must be set to an existing binding rule ID.
|
||||
func (a *ACL) BindingRuleUpdate(rule *ACLBindingRule, q *WriteOptions) (*ACLBindingRule, *WriteMeta, error) {
|
||||
if rule.ID == "" {
|
||||
return nil, nil, fmt.Errorf("Must specify an ID in Binding Rule Update")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("PUT", "/v1/acl/binding-rule/"+rule.ID)
|
||||
r.setWriteOptions(q)
|
||||
r.obj = rule
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
var out ACLBindingRule
|
||||
if err := decodeBody(resp, &out); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &out, wm, nil
|
||||
}
|
||||
|
||||
// BindingRuleDelete deletes a binding rule given its ID.
|
||||
func (a *ACL) BindingRuleDelete(bindingRuleID string, q *WriteOptions) (*WriteMeta, error) {
|
||||
r := a.c.newRequest("DELETE", "/v1/acl/binding-rule/"+bindingRuleID)
|
||||
r.setWriteOptions(q)
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
return wm, nil
|
||||
}
|
||||
|
||||
// BindingRuleRead retrieves the binding rule details. Returns nil if not found.
|
||||
func (a *ACL) BindingRuleRead(bindingRuleID string, q *QueryOptions) (*ACLBindingRule, *QueryMeta, error) {
|
||||
r := a.c.newRequest("GET", "/v1/acl/binding-rule/"+bindingRuleID)
|
||||
r.setQueryOptions(q)
|
||||
found, rtt, resp, err := requireNotFoundOrOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
qm := &QueryMeta{}
|
||||
parseQueryMeta(resp, qm)
|
||||
qm.RequestTime = rtt
|
||||
|
||||
if !found {
|
||||
return nil, qm, nil
|
||||
}
|
||||
|
||||
var out ACLBindingRule
|
||||
if err := decodeBody(resp, &out); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &out, qm, nil
|
||||
}
|
||||
|
||||
// BindingRuleList retrieves a listing of all binding rules.
|
||||
func (a *ACL) BindingRuleList(methodName string, q *QueryOptions) ([]*ACLBindingRule, *QueryMeta, error) {
|
||||
r := a.c.newRequest("GET", "/v1/acl/binding-rules")
|
||||
if methodName != "" {
|
||||
r.params.Set("authmethod", methodName)
|
||||
}
|
||||
r.setQueryOptions(q)
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
qm := &QueryMeta{}
|
||||
parseQueryMeta(resp, qm)
|
||||
qm.RequestTime = rtt
|
||||
|
||||
var entries []*ACLBindingRule
|
||||
if err := decodeBody(resp, &entries); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return entries, qm, nil
|
||||
}
|
||||
|
||||
// Login is used to exchange auth method credentials for a newly-minted Consul Token.
|
||||
func (a *ACL) Login(auth *ACLLoginParams, q *WriteOptions) (*ACLToken, *WriteMeta, error) {
|
||||
r := a.c.newRequest("POST", "/v1/acl/login")
|
||||
r.setWriteOptions(q)
|
||||
r.obj = auth
|
||||
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
var out ACLToken
|
||||
if err := decodeBody(resp, &out); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &out, wm, nil
|
||||
}
|
||||
|
||||
// Logout is used to destroy a Consul Token created via Login().
|
||||
func (a *ACL) Logout(q *WriteOptions) (*WriteMeta, error) {
|
||||
r := a.c.newRequest("POST", "/v1/acl/logout")
|
||||
r.setWriteOptions(q)
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
return wm, nil
|
||||
}
|
||||
|
|
28
api/api.go
28
api/api.go
|
@ -30,6 +30,10 @@ const (
|
|||
// the HTTP token.
|
||||
HTTPTokenEnvName = "CONSUL_HTTP_TOKEN"
|
||||
|
||||
// HTTPTokenFileEnvName defines an environment variable name which sets
|
||||
// the HTTP token file.
|
||||
HTTPTokenFileEnvName = "CONSUL_HTTP_TOKEN_FILE"
|
||||
|
||||
// HTTPAuthEnvName defines an environment variable name which sets
|
||||
// the HTTP authentication header.
|
||||
HTTPAuthEnvName = "CONSUL_HTTP_AUTH"
|
||||
|
@ -280,6 +284,10 @@ type Config struct {
|
|||
// which overrides the agent's default token.
|
||||
Token string
|
||||
|
||||
// TokenFile is a file containing the current token to use for this client.
|
||||
// If provided it is read once at startup and never again.
|
||||
TokenFile string
|
||||
|
||||
TLSConfig TLSConfig
|
||||
}
|
||||
|
||||
|
@ -343,6 +351,10 @@ func defaultConfig(transportFn func() *http.Transport) *Config {
|
|||
config.Address = addr
|
||||
}
|
||||
|
||||
if tokenFile := os.Getenv(HTTPTokenFileEnvName); tokenFile != "" {
|
||||
config.TokenFile = tokenFile
|
||||
}
|
||||
|
||||
if token := os.Getenv(HTTPTokenEnvName); token != "" {
|
||||
config.Token = token
|
||||
}
|
||||
|
@ -449,6 +461,7 @@ func (c *Config) GenerateEnv() []string {
|
|||
env = append(env,
|
||||
fmt.Sprintf("%s=%s", HTTPAddrEnvName, c.Address),
|
||||
fmt.Sprintf("%s=%s", HTTPTokenEnvName, c.Token),
|
||||
fmt.Sprintf("%s=%s", HTTPTokenFileEnvName, c.TokenFile),
|
||||
fmt.Sprintf("%s=%t", HTTPSSLEnvName, c.Scheme == "https"),
|
||||
fmt.Sprintf("%s=%s", HTTPCAFile, c.TLSConfig.CAFile),
|
||||
fmt.Sprintf("%s=%s", HTTPCAPath, c.TLSConfig.CAPath),
|
||||
|
@ -541,6 +554,19 @@ func NewClient(config *Config) (*Client, error) {
|
|||
config.Address = parts[1]
|
||||
}
|
||||
|
||||
// If the TokenFile is set, always use that, even if a Token is configured.
|
||||
// This is because when TokenFile is set it is read into the Token field.
|
||||
// We want any derived clients to have to re-read the token file.
|
||||
if config.TokenFile != "" {
|
||||
data, err := ioutil.ReadFile(config.TokenFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error loading token file: %s", err)
|
||||
}
|
||||
|
||||
if token := strings.TrimSpace(string(data)); token != "" {
|
||||
config.Token = token
|
||||
}
|
||||
}
|
||||
if config.Token == "" {
|
||||
config.Token = defConfig.Token
|
||||
}
|
||||
|
@ -820,6 +846,8 @@ func (c *Client) write(endpoint string, in, out interface{}, q *WriteOptions) (*
|
|||
}
|
||||
|
||||
// parseQueryMeta is used to help parse query meta-data
|
||||
//
|
||||
// TODO(rb): bug? the error from this function is never handled
|
||||
func parseQueryMeta(resp *http.Response, q *QueryMeta) error {
|
||||
header := resp.Header
|
||||
|
||||
|
|
|
@ -875,9 +875,10 @@ func TestAPI_GenerateEnv(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
c := &Config{
|
||||
Address: "127.0.0.1:8500",
|
||||
Token: "test",
|
||||
Scheme: "http",
|
||||
Address: "127.0.0.1:8500",
|
||||
Token: "test",
|
||||
TokenFile: "test.file",
|
||||
Scheme: "http",
|
||||
TLSConfig: TLSConfig{
|
||||
CAFile: "",
|
||||
CAPath: "",
|
||||
|
@ -891,6 +892,7 @@ func TestAPI_GenerateEnv(t *testing.T) {
|
|||
expected := []string{
|
||||
"CONSUL_HTTP_ADDR=127.0.0.1:8500",
|
||||
"CONSUL_HTTP_TOKEN=test",
|
||||
"CONSUL_HTTP_TOKEN_FILE=test.file",
|
||||
"CONSUL_HTTP_SSL=false",
|
||||
"CONSUL_CACERT=",
|
||||
"CONSUL_CAPATH=",
|
||||
|
@ -908,9 +910,10 @@ func TestAPI_GenerateEnvHTTPS(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
c := &Config{
|
||||
Address: "127.0.0.1:8500",
|
||||
Token: "test",
|
||||
Scheme: "https",
|
||||
Address: "127.0.0.1:8500",
|
||||
Token: "test",
|
||||
TokenFile: "test.file",
|
||||
Scheme: "https",
|
||||
TLSConfig: TLSConfig{
|
||||
CAFile: "/var/consul/ca.crt",
|
||||
CAPath: "/var/consul/ca.dir",
|
||||
|
@ -928,6 +931,7 @@ func TestAPI_GenerateEnvHTTPS(t *testing.T) {
|
|||
expected := []string{
|
||||
"CONSUL_HTTP_ADDR=127.0.0.1:8500",
|
||||
"CONSUL_HTTP_TOKEN=test",
|
||||
"CONSUL_HTTP_TOKEN_FILE=test.file",
|
||||
"CONSUL_HTTP_SSL=true",
|
||||
"CONSUL_CACERT=/var/consul/ca.crt",
|
||||
"CONSUL_CAPATH=/var/consul/ca.dir",
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package acl
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
|
@ -23,20 +24,26 @@ func PrintToken(token *api.ACLToken, ui cli.Ui, showMeta bool) {
|
|||
ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex))
|
||||
ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex))
|
||||
}
|
||||
ui.Info(fmt.Sprintf("Policies:"))
|
||||
for _, policy := range token.Policies {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name))
|
||||
if len(token.Policies) > 0 {
|
||||
ui.Info(fmt.Sprintf("Policies:"))
|
||||
for _, policy := range token.Policies {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name))
|
||||
}
|
||||
}
|
||||
ui.Info(fmt.Sprintf("Roles:"))
|
||||
for _, role := range token.Roles {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name))
|
||||
if len(token.Roles) > 0 {
|
||||
ui.Info(fmt.Sprintf("Roles:"))
|
||||
for _, role := range token.Roles {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name))
|
||||
}
|
||||
}
|
||||
ui.Info(fmt.Sprintf("Service Identities:"))
|
||||
for _, svcid := range token.ServiceIdentities {
|
||||
if len(svcid.Datacenters) > 0 {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", ")))
|
||||
} else {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName))
|
||||
if len(token.ServiceIdentities) > 0 {
|
||||
ui.Info(fmt.Sprintf("Service Identities:"))
|
||||
for _, svcid := range token.ServiceIdentities {
|
||||
if len(svcid.Datacenters) > 0 {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", ")))
|
||||
} else {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName))
|
||||
}
|
||||
}
|
||||
}
|
||||
if token.Rules != "" {
|
||||
|
@ -59,20 +66,26 @@ func PrintTokenListEntry(token *api.ACLTokenListEntry, ui cli.Ui, showMeta bool)
|
|||
ui.Info(fmt.Sprintf("Create Index: %d", token.CreateIndex))
|
||||
ui.Info(fmt.Sprintf("Modify Index: %d", token.ModifyIndex))
|
||||
}
|
||||
ui.Info(fmt.Sprintf("Policies:"))
|
||||
for _, policy := range token.Policies {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name))
|
||||
if len(token.Policies) > 0 {
|
||||
ui.Info(fmt.Sprintf("Policies:"))
|
||||
for _, policy := range token.Policies {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name))
|
||||
}
|
||||
}
|
||||
ui.Info(fmt.Sprintf("Roles:"))
|
||||
for _, role := range token.Roles {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name))
|
||||
if len(token.Roles) > 0 {
|
||||
ui.Info(fmt.Sprintf("Roles:"))
|
||||
for _, role := range token.Roles {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", role.ID, role.Name))
|
||||
}
|
||||
}
|
||||
ui.Info(fmt.Sprintf("Service Identities:"))
|
||||
for _, svcid := range token.ServiceIdentities {
|
||||
if len(svcid.Datacenters) > 0 {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", ")))
|
||||
} else {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName))
|
||||
if len(token.ServiceIdentities) > 0 {
|
||||
ui.Info(fmt.Sprintf("Service Identities:"))
|
||||
for _, svcid := range token.ServiceIdentities {
|
||||
if len(svcid.Datacenters) > 0 {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", ")))
|
||||
} else {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -112,16 +125,20 @@ func PrintRole(role *api.ACLRole, ui cli.Ui, showMeta bool) {
|
|||
ui.Info(fmt.Sprintf("Create Index: %d", role.CreateIndex))
|
||||
ui.Info(fmt.Sprintf("Modify Index: %d", role.ModifyIndex))
|
||||
}
|
||||
ui.Info(fmt.Sprintf("Policies:"))
|
||||
for _, policy := range role.Policies {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name))
|
||||
if len(role.Policies) > 0 {
|
||||
ui.Info(fmt.Sprintf("Policies:"))
|
||||
for _, policy := range role.Policies {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name))
|
||||
}
|
||||
}
|
||||
ui.Info(fmt.Sprintf("Service Identities:"))
|
||||
for _, svcid := range role.ServiceIdentities {
|
||||
if len(svcid.Datacenters) > 0 {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", ")))
|
||||
} else {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName))
|
||||
if len(role.ServiceIdentities) > 0 {
|
||||
ui.Info(fmt.Sprintf("Service Identities:"))
|
||||
for _, svcid := range role.ServiceIdentities {
|
||||
if len(svcid.Datacenters) > 0 {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", ")))
|
||||
} else {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -135,18 +152,74 @@ func PrintRoleListEntry(role *api.ACLRole, ui cli.Ui, showMeta bool) {
|
|||
ui.Info(fmt.Sprintf(" Create Index: %d", role.CreateIndex))
|
||||
ui.Info(fmt.Sprintf(" Modify Index: %d", role.ModifyIndex))
|
||||
}
|
||||
ui.Info(fmt.Sprintf(" Policies:"))
|
||||
for _, policy := range role.Policies {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name))
|
||||
}
|
||||
ui.Info(fmt.Sprintf(" Service Identities:"))
|
||||
for _, svcid := range role.ServiceIdentities {
|
||||
if len(svcid.Datacenters) > 0 {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", ")))
|
||||
} else {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName))
|
||||
if len(role.Policies) > 0 {
|
||||
ui.Info(fmt.Sprintf(" Policies:"))
|
||||
for _, policy := range role.Policies {
|
||||
ui.Info(fmt.Sprintf(" %s - %s", policy.ID, policy.Name))
|
||||
}
|
||||
}
|
||||
if len(role.ServiceIdentities) > 0 {
|
||||
ui.Info(fmt.Sprintf(" Service Identities:"))
|
||||
for _, svcid := range role.ServiceIdentities {
|
||||
if len(svcid.Datacenters) > 0 {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: %s)", svcid.ServiceName, strings.Join(svcid.Datacenters, ", ")))
|
||||
} else {
|
||||
ui.Info(fmt.Sprintf(" %s (Datacenters: all)", svcid.ServiceName))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func PrintAuthMethod(method *api.ACLAuthMethod, ui cli.Ui, showMeta bool) {
|
||||
ui.Info(fmt.Sprintf("Name: %s", method.Name))
|
||||
ui.Info(fmt.Sprintf("Type: %s", method.Type))
|
||||
ui.Info(fmt.Sprintf("Description: %s", method.Description))
|
||||
if showMeta {
|
||||
ui.Info(fmt.Sprintf("Create Index: %d", method.CreateIndex))
|
||||
ui.Info(fmt.Sprintf("Modify Index: %d", method.ModifyIndex))
|
||||
}
|
||||
ui.Info(fmt.Sprintf("Config:"))
|
||||
output, err := json.MarshalIndent(method.Config, "", " ")
|
||||
if err != nil {
|
||||
ui.Error(fmt.Sprintf("Error formatting auth method configuration: %s", err))
|
||||
}
|
||||
ui.Output(string(output))
|
||||
}
|
||||
|
||||
func PrintAuthMethodListEntry(method *api.ACLAuthMethodListEntry, ui cli.Ui, showMeta bool) {
|
||||
ui.Info(fmt.Sprintf("%s:", method.Name))
|
||||
ui.Info(fmt.Sprintf(" Type: %s", method.Type))
|
||||
ui.Info(fmt.Sprintf(" Description: %s", method.Description))
|
||||
if showMeta {
|
||||
ui.Info(fmt.Sprintf(" Create Index: %d", method.CreateIndex))
|
||||
ui.Info(fmt.Sprintf(" Modify Index: %d", method.ModifyIndex))
|
||||
}
|
||||
}
|
||||
|
||||
func PrintBindingRule(rule *api.ACLBindingRule, ui cli.Ui, showMeta bool) {
|
||||
ui.Info(fmt.Sprintf("ID: %s", rule.ID))
|
||||
ui.Info(fmt.Sprintf("AuthMethod: %s", rule.AuthMethod))
|
||||
ui.Info(fmt.Sprintf("Description: %s", rule.Description))
|
||||
ui.Info(fmt.Sprintf("BindType: %s", rule.BindType))
|
||||
ui.Info(fmt.Sprintf("BindName: %s", rule.BindName))
|
||||
ui.Info(fmt.Sprintf("Selector: %s", rule.Selector))
|
||||
if showMeta {
|
||||
ui.Info(fmt.Sprintf("Create Index: %d", rule.CreateIndex))
|
||||
ui.Info(fmt.Sprintf("Modify Index: %d", rule.ModifyIndex))
|
||||
}
|
||||
}
|
||||
|
||||
func PrintBindingRuleListEntry(rule *api.ACLBindingRule, ui cli.Ui, showMeta bool) {
|
||||
ui.Info(fmt.Sprintf("%s:", rule.ID))
|
||||
ui.Info(fmt.Sprintf(" AuthMethod: %s", rule.AuthMethod))
|
||||
ui.Info(fmt.Sprintf(" Description: %s", rule.Description))
|
||||
ui.Info(fmt.Sprintf(" BindType: %s", rule.BindType))
|
||||
ui.Info(fmt.Sprintf(" BindName: %s", rule.BindName))
|
||||
ui.Info(fmt.Sprintf(" Selector: %s", rule.Selector))
|
||||
if showMeta {
|
||||
ui.Info(fmt.Sprintf(" Create Index: %d", rule.CreateIndex))
|
||||
ui.Info(fmt.Sprintf(" Modify Index: %d", rule.ModifyIndex))
|
||||
}
|
||||
}
|
||||
|
||||
func GetTokenIDFromPartial(client *api.Client, partialID string) (string, error) {
|
||||
|
@ -309,6 +382,34 @@ func GetRoleIDByName(client *api.Client, name string) (string, error) {
|
|||
return "", fmt.Errorf("No such role with name %s", name)
|
||||
}
|
||||
|
||||
func GetBindingRuleIDFromPartial(client *api.Client, partialID string) (string, error) {
|
||||
// the full UUID string was given
|
||||
if len(partialID) == 36 {
|
||||
return partialID, nil
|
||||
}
|
||||
|
||||
rules, _, err := client.ACL().BindingRuleList("", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ruleID := ""
|
||||
for _, rule := range rules {
|
||||
if strings.HasPrefix(rule.ID, partialID) {
|
||||
if ruleID != "" {
|
||||
return "", fmt.Errorf("Partial rule ID is not unique")
|
||||
}
|
||||
ruleID = rule.ID
|
||||
}
|
||||
}
|
||||
|
||||
if ruleID == "" {
|
||||
return "", fmt.Errorf("No such rule ID with prefix: %s", partialID)
|
||||
}
|
||||
|
||||
return ruleID, nil
|
||||
}
|
||||
|
||||
func ExtractServiceIdentities(serviceIdents []string) ([]*api.ACLServiceIdentity, error) {
|
||||
var out []*api.ACLServiceIdentity
|
||||
for _, svcidRaw := range serviceIdents {
|
||||
|
@ -329,3 +430,27 @@ func ExtractServiceIdentities(serviceIdents []string) ([]*api.ACLServiceIdentity
|
|||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// TestKubernetesJWT_A is a valid service account jwt extracted from a minikube setup.
|
||||
//
|
||||
// {
|
||||
// "iss": "kubernetes/serviceaccount",
|
||||
// "kubernetes.io/serviceaccount/namespace": "default",
|
||||
// "kubernetes.io/serviceaccount/secret.name": "admin-token-qlz42",
|
||||
// "kubernetes.io/serviceaccount/service-account.name": "admin",
|
||||
// "kubernetes.io/serviceaccount/service-account.uid": "738bc251-6532-11e9-b67f-48e6c8b8ecb5",
|
||||
// "sub": "system:serviceaccount:default:admin"
|
||||
// }
|
||||
const TestKubernetesJWT_A = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImFkbWluLXRva2VuLXFsejQyIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQubmFtZSI6ImFkbWluIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZXJ2aWNlLWFjY291bnQudWlkIjoiNzM4YmMyNTEtNjUzMi0xMWU5LWI2N2YtNDhlNmM4YjhlY2I1Iiwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6YWRtaW4ifQ.ixMlnWrAG7NVuTTKu8cdcYfM7gweS3jlKaEsIBNGOVEjPE7rtXtgMkAwjQTdYR08_0QBjkgzy5fQC5ZNyglSwONJ-bPaXGvhoH1cTnRi1dz9H_63CfqOCvQP1sbdkMeRxNTGVAyWZT76rXoCUIfHP4LY2I8aab0KN9FTIcgZRF0XPTtT70UwGIrSmRpxW38zjiy2ymWL01cc5VWGhJqVysmWmYk3wNp0h5N57H_MOrz4apQR4pKaamzskzjLxO55gpbmZFC76qWuUdexAR7DT2fpbHLOw90atN_NlLMY-VrXyW3-Ei5EhYaVreMB9PSpKwkrA4jULITohV-sxpa1LA"
|
||||
|
||||
// TestKubernetesJWT_B is a valid service account jwt extracted from a minikube setup.
|
||||
//
|
||||
// {
|
||||
// "iss": "kubernetes/serviceaccount",
|
||||
// "kubernetes.io/serviceaccount/namespace": "default",
|
||||
// "kubernetes.io/serviceaccount/secret.name": "demo-token-kmb9n",
|
||||
// "kubernetes.io/serviceaccount/service-account.name": "demo",
|
||||
// "kubernetes.io/serviceaccount/service-account.uid": "76091af4-4b56-11e9-ac4b-708b11801cbe",
|
||||
// "sub": "system:serviceaccount:default:demo"
|
||||
// }
|
||||
const TestKubernetesJWT_B = "eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJkZWZhdWx0Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6ImRlbW8tdG9rZW4ta21iOW4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC5uYW1lIjoiZGVtbyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50LnVpZCI6Ijc2MDkxYWY0LTRiNTYtMTFlOS1hYzRiLTcwOGIxMTgwMWNiZSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0OmRlbW8ifQ.ZiAHjijBAOsKdum0Aix6lgtkLkGo9_Tu87dWQ5Zfwnn3r2FejEWDAnftTft1MqqnMzivZ9Wyyki5ZjQRmTAtnMPJuHC-iivqY4Wh4S6QWCJ1SivBv5tMZR79t5t8mE7R1-OHwst46spru1pps9wt9jsA04d3LpV0eeKYgdPTVaQKklxTm397kIMUugA6yINIBQ3Rh8eQqBgNwEmL4iqyYubzHLVkGkoP9MJikFI05vfRiHtYr-piXz6JFDzXMQj9rW6xtMmrBSn79ChbyvC5nz-Nj2rJPnHsb_0rDUbmXY5PpnMhBpdSH-CbZ4j8jsiib6DtaGJhVZeEQ1GjsFAZwQ"
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
package authmethod
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New() *cmd {
|
||||
return &cmd{}
|
||||
}
|
||||
|
||||
type cmd struct{}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
return cli.RunResultHelp
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Manage Consul's ACL Auth Methods"
|
||||
const help = `
|
||||
Usage: consul acl auth-method <subcommand> [options] [args]
|
||||
|
||||
This command has subcommands for managing Consul's ACL Auth Methods.
|
||||
Here are some simple examples, and more detailed examples are available in
|
||||
the subcommands or the documentation.
|
||||
|
||||
Create a new auth method:
|
||||
|
||||
$ consul acl auth-method create -type "kubernetes" \
|
||||
-name "my-k8s" \
|
||||
-description "This is an example kube auth method" \
|
||||
-kubernetes-host "https://apiserver.example.com:8443" \
|
||||
-kubernetes-ca-file /path/to/kube.ca.crt \
|
||||
-kubernetes-service-account-jwt "JWT_CONTENTS"
|
||||
|
||||
List all auth methods:
|
||||
|
||||
$ consul acl auth-method list
|
||||
|
||||
Update all editable fields of the auth method:
|
||||
|
||||
$ consul acl auth-method update -name "my-k8s" \
|
||||
-description "new description" \
|
||||
-kubernetes-host "https://new-apiserver.example.com:8443" \
|
||||
-kubernetes-ca-file /path/to/new-kube.ca.crt \
|
||||
-kubernetes-service-account-jwt "NEW_JWT_CONTENTS"
|
||||
|
||||
Read an auth method:
|
||||
|
||||
$ consul acl auth-method read -name my-k8s
|
||||
|
||||
Delete an auth method:
|
||||
|
||||
$ consul acl auth-method delete -name my-k8s
|
||||
|
||||
For more examples, ask for subcommand help or view the documentation.
|
||||
`
|
|
@ -0,0 +1,186 @@
|
|||
package authmethodcreate
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/hashicorp/consul/command/helpers"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
authMethodType string
|
||||
name string
|
||||
description string
|
||||
|
||||
k8sHost string
|
||||
k8sCACert string
|
||||
k8sServiceAccountJWT string
|
||||
|
||||
showMeta bool
|
||||
|
||||
testStdin io.Reader
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.BoolVar(
|
||||
&c.showMeta,
|
||||
"meta",
|
||||
false,
|
||||
"Indicates that auth method metadata such "+
|
||||
"as the content hash and raft indices should be shown for each entry.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.authMethodType,
|
||||
"type",
|
||||
"",
|
||||
"The new auth method's type. This flag is required.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.name,
|
||||
"name",
|
||||
"",
|
||||
"The new auth method's name. This flag is required.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.description,
|
||||
"description",
|
||||
"",
|
||||
"A description of the auth method.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.k8sHost,
|
||||
"kubernetes-host",
|
||||
"",
|
||||
"Address of the Kubernetes API server. This flag is required for type=kubernetes.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.k8sCACert,
|
||||
"kubernetes-ca-cert",
|
||||
"",
|
||||
"PEM encoded CA cert for use by the TLS client used to talk with the "+
|
||||
"Kubernetes API. May be prefixed with '@' to indicate that the "+
|
||||
"value is a file path to load the cert from. "+
|
||||
"This flag is required for type=kubernetes.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.k8sServiceAccountJWT,
|
||||
"kubernetes-service-account-jwt",
|
||||
"",
|
||||
"A kubernetes service account JWT used to access the TokenReview API to "+
|
||||
"validate other JWTs during login. "+
|
||||
"This flag is required for type=kubernetes.",
|
||||
)
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.authMethodType == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-type' flag"))
|
||||
c.UI.Error(c.Help())
|
||||
return 1
|
||||
} else if c.name == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-name' flag"))
|
||||
c.UI.Error(c.Help())
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
newAuthMethod := &api.ACLAuthMethod{
|
||||
Type: c.authMethodType,
|
||||
Name: c.name,
|
||||
Description: c.description,
|
||||
}
|
||||
|
||||
if c.authMethodType == "kubernetes" {
|
||||
if c.k8sHost == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-host' flag"))
|
||||
return 1
|
||||
} else if c.k8sCACert == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-ca-cert' flag"))
|
||||
return 1
|
||||
} else if c.k8sServiceAccountJWT == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-service-account-jwt' flag"))
|
||||
return 1
|
||||
}
|
||||
|
||||
c.k8sCACert, err = helpers.LoadDataSource(c.k8sCACert, c.testStdin)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Invalid '-kubernetes-ca-cert' value: %v", err))
|
||||
return 1
|
||||
} else if c.k8sCACert == "" {
|
||||
c.UI.Error(fmt.Sprintf("Kubernetes CA Cert is empty"))
|
||||
return 1
|
||||
}
|
||||
|
||||
newAuthMethod.Config = map[string]interface{}{
|
||||
"Host": c.k8sHost,
|
||||
"CACert": c.k8sCACert,
|
||||
"ServiceAccountJWT": c.k8sServiceAccountJWT,
|
||||
}
|
||||
}
|
||||
|
||||
method, _, err := client.ACL().AuthMethodCreate(newAuthMethod, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Failed to create new auth method: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
acl.PrintAuthMethod(method, c.UI, c.showMeta)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Create an ACL Auth Method"
|
||||
|
||||
const help = `
|
||||
Usage: consul acl auth-method create -name NAME -type TYPE [options]
|
||||
|
||||
Create a new auth method:
|
||||
|
||||
$ consul acl auth-method create -type "kubernetes" \
|
||||
-name "my-k8s" \
|
||||
-description "This is an example kube method" \
|
||||
-kubernetes-host "https://apiserver.example.com:8443" \
|
||||
-kubernetes-ca-file /path/to/kube.ca.crt \
|
||||
-kubernetes-service-account-jwt "JWT_CONTENTS"
|
||||
`
|
|
@ -0,0 +1,226 @@
|
|||
package authmethodcreate
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestAuthMethodCreateCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMethodCreateCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
t.Run("type required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-type' flag")
|
||||
})
|
||||
|
||||
t.Run("name required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-type=testing",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-name' flag")
|
||||
})
|
||||
|
||||
t.Run("invalid type", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-type=invalid",
|
||||
"-name=my-method",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Invalid Auth Method: Type should be one of")
|
||||
})
|
||||
|
||||
t.Run("create testing", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-type=testing",
|
||||
"-name=test",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMethodCreateCommand_k8s(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
t.Run("k8s host required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-type=kubernetes",
|
||||
"-name=k8s",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-host' flag")
|
||||
})
|
||||
|
||||
t.Run("k8s ca cert required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-type=kubernetes",
|
||||
"-name=k8s",
|
||||
"-kubernetes-host=https://foo.internal:8443",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-ca-cert' flag")
|
||||
})
|
||||
|
||||
ca := connect.TestCA(t, nil)
|
||||
|
||||
t.Run("k8s jwt required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-type=kubernetes",
|
||||
"-name=k8s",
|
||||
"-kubernetes-host=https://foo.internal:8443",
|
||||
"-kubernetes-ca-cert", ca.RootCert,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-service-account-jwt' flag")
|
||||
})
|
||||
|
||||
t.Run("create k8s", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-type=kubernetes",
|
||||
"-name=k8s",
|
||||
"-kubernetes-host", "https://foo.internal:8443",
|
||||
"-kubernetes-ca-cert", ca.RootCert,
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_A,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
})
|
||||
|
||||
caFile := filepath.Join(testDir, "ca.crt")
|
||||
require.NoError(t, ioutil.WriteFile(caFile, []byte(ca.RootCert), 0600))
|
||||
|
||||
t.Run("create k8s with cert file", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-type=kubernetes",
|
||||
"-name=k8s",
|
||||
"-kubernetes-host", "https://foo.internal:8443",
|
||||
"-kubernetes-ca-cert", "@" + caFile,
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_A,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
})
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package authmethoddelete
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
name string
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.name,
|
||||
"name",
|
||||
"",
|
||||
"The name of the auth method to delete.",
|
||||
)
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.name == "" {
|
||||
c.UI.Error(fmt.Sprintf("Must specify the -name parameter"))
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
if _, err := client.ACL().AuthMethodDelete(c.name, nil); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error deleting auth method %q: %v", c.name, err))
|
||||
return 1
|
||||
}
|
||||
|
||||
c.UI.Info(fmt.Sprintf("Auth method %q deleted successfully", c.name))
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Delete an ACL Auth Method"
|
||||
const help = `
|
||||
Usage: consul acl auth-method delete -name NAME [options]
|
||||
|
||||
Delete an auth method:
|
||||
|
||||
$ consul acl auth-method delete -name "my-auth-method"
|
||||
`
|
|
@ -0,0 +1,131 @@
|
|||
package authmethoddelete
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestAuthMethodDeleteCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMethodDeleteCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
t.Run("name required", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Must specify the -name parameter")
|
||||
})
|
||||
|
||||
t.Run("delete notfound", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=notfound",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
output := ui.OutputWriter.String()
|
||||
require.Contains(t, output, fmt.Sprintf("deleted successfully"))
|
||||
require.Contains(t, output, "notfound")
|
||||
})
|
||||
|
||||
createAuthMethod := func(t *testing.T) string {
|
||||
id, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
methodName := "test-" + id
|
||||
|
||||
_, _, err = client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: methodName,
|
||||
Type: "testing",
|
||||
Description: "test",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return methodName
|
||||
}
|
||||
|
||||
t.Run("delete works", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=" + name,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
output := ui.OutputWriter.String()
|
||||
require.Contains(t, output, fmt.Sprintf("deleted successfully"))
|
||||
require.Contains(t, output, name)
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, method)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
package authmethodlist
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
showMeta bool
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.BoolVar(
|
||||
&c.showMeta,
|
||||
"meta",
|
||||
false,
|
||||
"Indicates that auth method metadata such "+
|
||||
"as the content hash and raft indices should be shown for each entry.",
|
||||
)
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
methods, _, err := client.ACL().AuthMethodList(nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Failed to retrieve the auth method list: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
for _, method := range methods {
|
||||
acl.PrintAuthMethodListEntry(method, c.UI, c.showMeta)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Lists ACL Auth Methods"
|
||||
const help = `
|
||||
Usage: consul acl auth-method list [options]
|
||||
|
||||
List all auth methods:
|
||||
|
||||
$ consul acl auth-method list
|
||||
`
|
|
@ -0,0 +1,109 @@
|
|||
package authmethodlist
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestAuthMethodListCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMethodListCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
t.Run("found none", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.OutputWriter.String())
|
||||
})
|
||||
|
||||
client := a.Client()
|
||||
|
||||
createAuthMethod := func(t *testing.T) string {
|
||||
id, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
methodName := "test-" + id
|
||||
|
||||
_, _, err = client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: methodName,
|
||||
Type: "testing",
|
||||
Description: "test",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return methodName
|
||||
}
|
||||
|
||||
var methodNames []string
|
||||
for i := 0; i < 5; i++ {
|
||||
methodName := createAuthMethod(t)
|
||||
methodNames = append(methodNames, methodName)
|
||||
}
|
||||
|
||||
t.Run("found some", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
output := ui.OutputWriter.String()
|
||||
|
||||
for _, methodName := range methodNames {
|
||||
require.Contains(t, output, methodName)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package authmethodread
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
name string
|
||||
|
||||
showMeta bool
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.BoolVar(
|
||||
&c.showMeta,
|
||||
"meta",
|
||||
false,
|
||||
"Indicates that auth method metadata such "+
|
||||
"as the content hash and raft indices should be shown for each entry.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.name,
|
||||
"name",
|
||||
"",
|
||||
"The name of the auth method to read.",
|
||||
)
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.name == "" {
|
||||
c.UI.Error(fmt.Sprintf("Must specify the -name parameter"))
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(c.name, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error reading auth method %q: %v", c.name, err))
|
||||
return 1
|
||||
} else if method == nil {
|
||||
c.UI.Error(fmt.Sprintf("Auth method not found with name %q", c.name))
|
||||
return 1
|
||||
}
|
||||
acl.PrintAuthMethod(method, c.UI, c.showMeta)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Read an ACL Auth Method"
|
||||
const help = `
|
||||
Usage: consul acl auth-method read -name NAME [options]
|
||||
|
||||
Read an auth method:
|
||||
|
||||
$ consul acl auth-method read -name my-auth-method
|
||||
`
|
|
@ -0,0 +1,118 @@
|
|||
package authmethodread
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestAuthMethodReadCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMethodReadCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
t.Run("name required", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Must specify the -name parameter")
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=notfound",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name")
|
||||
})
|
||||
|
||||
createAuthMethod := func(t *testing.T) string {
|
||||
id, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
methodName := "test-" + id
|
||||
|
||||
_, _, err = client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: methodName,
|
||||
Type: "testing",
|
||||
Description: "test",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return methodName
|
||||
}
|
||||
|
||||
t.Run("read by name", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=" + name,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
output := ui.OutputWriter.String()
|
||||
require.Contains(t, output, name)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,220 @@
|
|||
package authmethodupdate
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/hashicorp/consul/command/helpers"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
name string
|
||||
|
||||
description string
|
||||
|
||||
k8sHost string
|
||||
k8sCACert string
|
||||
k8sServiceAccountJWT string
|
||||
|
||||
noMerge bool
|
||||
showMeta bool
|
||||
|
||||
testStdin io.Reader
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.BoolVar(
|
||||
&c.showMeta,
|
||||
"meta",
|
||||
false,
|
||||
"Indicates that auth method metadata such "+
|
||||
"as the content hash and raft indices should be shown for each entry.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.name,
|
||||
"name",
|
||||
"",
|
||||
"The auth method name.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.description,
|
||||
"description",
|
||||
"",
|
||||
"A description of the auth method.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.k8sHost,
|
||||
"kubernetes-host",
|
||||
"",
|
||||
"Address of the Kubernetes API server. This flag is required for type=kubernetes.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.k8sCACert,
|
||||
"kubernetes-ca-cert",
|
||||
"",
|
||||
"PEM encoded CA cert for use by the TLS client used to talk with the "+
|
||||
"Kubernetes API. May be prefixed with '@' to indicate that the "+
|
||||
"value is a file path to load the cert from. "+
|
||||
"This flag is required for type=kubernetes.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.k8sServiceAccountJWT,
|
||||
"kubernetes-service-account-jwt",
|
||||
"",
|
||||
"A kubernetes service account JWT used to access the TokenReview API to "+
|
||||
"validate other JWTs during login. "+
|
||||
"This flag is required for type=kubernetes.",
|
||||
)
|
||||
|
||||
c.flags.BoolVar(&c.noMerge, "no-merge", false, "Do not merge the current auth method "+
|
||||
"information with what is provided to the command. Instead overwrite all fields "+
|
||||
"with the exception of the name which is immutable.")
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.name == "" {
|
||||
c.UI.Error(fmt.Sprintf("Cannot update an auth method without specifying the -name parameter"))
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Regardless of merge, we need to fetch the prior immutable fields first.
|
||||
currentAuthMethod, _, err := client.ACL().AuthMethodRead(c.name, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error when retrieving current auth method: %v", err))
|
||||
return 1
|
||||
} else if currentAuthMethod == nil {
|
||||
c.UI.Error(fmt.Sprintf("Auth method not found with name %q", c.name))
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.k8sCACert != "" {
|
||||
c.k8sCACert, err = helpers.LoadDataSource(c.k8sCACert, c.testStdin)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Invalid '-kubernetes-ca-cert' value: %v", err))
|
||||
return 1
|
||||
} else if c.k8sCACert == "" {
|
||||
c.UI.Error(fmt.Sprintf("Kubernetes CA Cert is empty"))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
var method *api.ACLAuthMethod
|
||||
if c.noMerge {
|
||||
method = &api.ACLAuthMethod{
|
||||
Name: currentAuthMethod.Name,
|
||||
Type: currentAuthMethod.Type,
|
||||
Description: c.description,
|
||||
}
|
||||
|
||||
if currentAuthMethod.Type == "kubernetes" {
|
||||
if c.k8sHost == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-host' flag"))
|
||||
return 1
|
||||
} else if c.k8sCACert == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-ca-cert' flag"))
|
||||
return 1
|
||||
} else if c.k8sServiceAccountJWT == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-kubernetes-service-account-jwt' flag"))
|
||||
return 1
|
||||
}
|
||||
|
||||
method.Config = map[string]interface{}{
|
||||
"Host": c.k8sHost,
|
||||
"CACert": c.k8sCACert,
|
||||
"ServiceAccountJWT": c.k8sServiceAccountJWT,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
methodCopy := *currentAuthMethod
|
||||
method = &methodCopy
|
||||
|
||||
if c.description != "" {
|
||||
method.Description = c.description
|
||||
}
|
||||
if method.Config == nil {
|
||||
method.Config = make(map[string]interface{})
|
||||
}
|
||||
if currentAuthMethod.Type == "kubernetes" {
|
||||
if c.k8sHost != "" {
|
||||
method.Config["Host"] = c.k8sHost
|
||||
}
|
||||
if c.k8sCACert != "" {
|
||||
method.Config["CACert"] = c.k8sCACert
|
||||
}
|
||||
if c.k8sServiceAccountJWT != "" {
|
||||
method.Config["ServiceAccountJWT"] = c.k8sServiceAccountJWT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
method, _, err = client.ACL().AuthMethodUpdate(method, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error updating auth method %q: %v", c.name, err))
|
||||
return 1
|
||||
}
|
||||
|
||||
c.UI.Info(fmt.Sprintf("Auth method updated successfully"))
|
||||
acl.PrintAuthMethod(method, c.UI, c.showMeta)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Update an ACL Auth Method"
|
||||
const help = `
|
||||
Usage: consul acl auth-method update -name NAME [options]
|
||||
|
||||
Updates an auth method. By default it will merge the auth method
|
||||
information with its current state so that you do not have to provide all
|
||||
parameters. This behavior can be disabled by passing -no-merge.
|
||||
|
||||
Update all editable fields of the auth method:
|
||||
|
||||
$ consul acl auth-method update -name "my-k8s" \
|
||||
-description "new description" \
|
||||
-kubernetes-host "https://new-apiserver.example.com:8443" \
|
||||
-kubernetes-ca-file /path/to/new-kube.ca.crt \
|
||||
-kubernetes-service-account-jwt "NEW_JWT_CONTENTS"
|
||||
`
|
|
@ -0,0 +1,647 @@
|
|||
package authmethodupdate
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestAuthMethodUpdateCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMethodUpdateCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
t.Run("update without name", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Cannot update an auth method without specifying the -name parameter")
|
||||
})
|
||||
|
||||
t.Run("update nonexistent method", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=test",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name")
|
||||
})
|
||||
|
||||
createAuthMethod := func(t *testing.T) string {
|
||||
id, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
methodName := "test-" + id
|
||||
|
||||
_, _, err = client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: methodName,
|
||||
Type: "testing",
|
||||
Description: "test",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return methodName
|
||||
}
|
||||
|
||||
t.Run("update all fields", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, method)
|
||||
require.Equal(t, "updated description", method.Description)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMethodUpdateCommand_noMerge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
t.Run("update without name", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Cannot update an auth method without specifying the -name parameter")
|
||||
})
|
||||
|
||||
t.Run("update nonexistent method", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-name=test",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Auth method not found with name")
|
||||
})
|
||||
|
||||
createAuthMethod := func(t *testing.T) string {
|
||||
id, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
methodName := "test-" + id
|
||||
|
||||
_, _, err = client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: methodName,
|
||||
Type: "testing",
|
||||
Description: "test",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return methodName
|
||||
}
|
||||
|
||||
t.Run("update all fields", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, method)
|
||||
require.Equal(t, "updated description", method.Description)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMethodUpdateCommand_k8s(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
ca := connect.TestCA(t, nil)
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
|
||||
createAuthMethod := func(t *testing.T) string {
|
||||
id, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
methodName := "k8s-" + id
|
||||
|
||||
_, _, err = client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: methodName,
|
||||
Type: "kubernetes",
|
||||
Description: "test",
|
||||
Config: map[string]interface{}{
|
||||
"Host": "https://foo.internal:8443",
|
||||
"CACert": ca.RootCert,
|
||||
"ServiceAccountJWT": acl.TestKubernetesJWT_A,
|
||||
},
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return methodName
|
||||
}
|
||||
|
||||
t.Run("update all fields", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-host", "https://foo-new.internal:8443",
|
||||
"-kubernetes-ca-cert", ca2.RootCert,
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, method)
|
||||
require.Equal(t, "updated description", method.Description)
|
||||
|
||||
config, err := api.ParseKubernetesAuthMethodConfig(method.Config)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://foo-new.internal:8443", config.Host)
|
||||
require.Equal(t, ca2.RootCert, config.CACert)
|
||||
require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT)
|
||||
})
|
||||
|
||||
ca2File := filepath.Join(testDir, "ca2.crt")
|
||||
require.NoError(t, ioutil.WriteFile(ca2File, []byte(ca2.RootCert), 0600))
|
||||
|
||||
t.Run("update all fields with cert file", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-host", "https://foo-new.internal:8443",
|
||||
"-kubernetes-ca-cert", "@" + ca2File,
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, method)
|
||||
require.Equal(t, "updated description", method.Description)
|
||||
|
||||
config, err := api.ParseKubernetesAuthMethodConfig(method.Config)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "https://foo-new.internal:8443", config.Host)
|
||||
require.Equal(t, ca2.RootCert, config.CACert)
|
||||
require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT)
|
||||
})
|
||||
|
||||
t.Run("update all fields but k8s host", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-ca-cert", ca2.RootCert,
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, method)
|
||||
require.Equal(t, "updated description", method.Description)
|
||||
|
||||
config, err := api.ParseKubernetesAuthMethodConfig(method.Config)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "https://foo.internal:8443", config.Host)
|
||||
require.Equal(t, ca2.RootCert, config.CACert)
|
||||
require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT)
|
||||
})
|
||||
|
||||
t.Run("update all fields but k8s ca cert", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-host", "https://foo-new.internal:8443",
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, method)
|
||||
require.Equal(t, "updated description", method.Description)
|
||||
|
||||
config, err := api.ParseKubernetesAuthMethodConfig(method.Config)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "https://foo-new.internal:8443", config.Host)
|
||||
require.Equal(t, ca.RootCert, config.CACert)
|
||||
require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT)
|
||||
})
|
||||
|
||||
t.Run("update all fields but k8s jwt", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-host", "https://foo-new.internal:8443",
|
||||
"-kubernetes-ca-cert", ca2.RootCert,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, method)
|
||||
require.Equal(t, "updated description", method.Description)
|
||||
|
||||
config, err := api.ParseKubernetesAuthMethodConfig(method.Config)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "https://foo-new.internal:8443", config.Host)
|
||||
require.Equal(t, ca2.RootCert, config.CACert)
|
||||
require.Equal(t, acl.TestKubernetesJWT_A, config.ServiceAccountJWT)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMethodUpdateCommand_k8s_noMerge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
ca := connect.TestCA(t, nil)
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
|
||||
createAuthMethod := func(t *testing.T) string {
|
||||
id, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
methodName := "k8s-" + id
|
||||
|
||||
_, _, err = client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: methodName,
|
||||
Type: "kubernetes",
|
||||
Description: "test",
|
||||
Config: map[string]interface{}{
|
||||
"Host": "https://foo.internal:8443",
|
||||
"CACert": ca.RootCert,
|
||||
"ServiceAccountJWT": acl.TestKubernetesJWT_A,
|
||||
},
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return methodName
|
||||
}
|
||||
|
||||
t.Run("update missing k8s host", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-ca-cert", ca2.RootCert,
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-host' flag")
|
||||
})
|
||||
|
||||
t.Run("update missing k8s ca cert", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-host", "https://foo-new.internal:8443",
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-ca-cert' flag")
|
||||
})
|
||||
|
||||
t.Run("update missing k8s jwt", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-host", "https://foo-new.internal:8443",
|
||||
"-kubernetes-ca-cert", ca2.RootCert,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-kubernetes-service-account-jwt' flag")
|
||||
})
|
||||
|
||||
t.Run("update all fields", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-host", "https://foo-new.internal:8443",
|
||||
"-kubernetes-ca-cert", ca2.RootCert,
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, method)
|
||||
require.Equal(t, "updated description", method.Description)
|
||||
|
||||
config, err := api.ParseKubernetesAuthMethodConfig(method.Config)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "https://foo-new.internal:8443", config.Host)
|
||||
require.Equal(t, ca2.RootCert, config.CACert)
|
||||
require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT)
|
||||
})
|
||||
|
||||
ca2File := filepath.Join(testDir, "ca2.crt")
|
||||
require.NoError(t, ioutil.WriteFile(ca2File, []byte(ca2.RootCert), 0600))
|
||||
|
||||
t.Run("update all fields with cert file", func(t *testing.T) {
|
||||
name := createAuthMethod(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-name=" + name,
|
||||
"-description", "updated description",
|
||||
"-kubernetes-host", "https://foo-new.internal:8443",
|
||||
"-kubernetes-ca-cert", "@" + ca2File,
|
||||
"-kubernetes-service-account-jwt", acl.TestKubernetesJWT_B,
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
method, _, err := client.ACL().AuthMethodRead(
|
||||
name,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, method)
|
||||
require.Equal(t, "updated description", method.Description)
|
||||
|
||||
config, err := api.ParseKubernetesAuthMethodConfig(method.Config)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "https://foo-new.internal:8443", config.Host)
|
||||
require.Equal(t, ca2.RootCert, config.CACert)
|
||||
require.Equal(t, acl.TestKubernetesJWT_B, config.ServiceAccountJWT)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package bindingrule
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New() *cmd {
|
||||
return &cmd{}
|
||||
}
|
||||
|
||||
type cmd struct{}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
return cli.RunResultHelp
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Manage Consul's ACL Binding Rules"
|
||||
const help = `
|
||||
Usage: consul acl binding-rule <subcommand> [options] [args]
|
||||
|
||||
This command has subcommands for managing Consul's ACL Binding Rules. Here
|
||||
are some simple examples, and more detailed examples are available in the
|
||||
subcommands or the documentation.
|
||||
|
||||
Create a new binding rule:
|
||||
|
||||
$ consul acl binding-rule create \
|
||||
-method=minikube \
|
||||
-bind-type=service \
|
||||
-bind-name='k8s-${serviceaccount.name}' \
|
||||
-selector='serviceaccount.namespace==default and serviceaccount.name==web'
|
||||
|
||||
List all binding rules:
|
||||
|
||||
$ consul acl binding-rule list
|
||||
|
||||
Update a binding rule:
|
||||
|
||||
$ consul acl binding-rule update -id=43cb72df-9c6f-4315-ac8a-01a9d98155ef \
|
||||
-bind-name='k8s-${serviceaccount.name}'
|
||||
|
||||
Read a binding rule:
|
||||
|
||||
$ consul acl binding-rule read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba
|
||||
|
||||
Delete a binding rule:
|
||||
|
||||
$ consul acl binding-rule delete -id b6b856da-5193-4e78-845a-7d61ca8371ba
|
||||
|
||||
For more examples, ask for subcommand help or view the documentation.
|
||||
`
|
|
@ -0,0 +1,148 @@
|
|||
package bindingrulecreate
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
authMethodName string
|
||||
description string
|
||||
selector string
|
||||
bindType string
|
||||
bindName string
|
||||
|
||||
showMeta bool
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.BoolVar(
|
||||
&c.showMeta,
|
||||
"meta",
|
||||
false,
|
||||
"Indicates that binding rule metadata such "+
|
||||
"as the content hash and raft indices should be shown for each entry.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.authMethodName,
|
||||
"method",
|
||||
"",
|
||||
"The auth method's name for which this binding rule applies. "+
|
||||
"This flag is required.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.description,
|
||||
"description",
|
||||
"",
|
||||
"A description of the binding rule.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.selector,
|
||||
"selector",
|
||||
"",
|
||||
"Selector is an expression that matches against verified identity "+
|
||||
"attributes returned from the auth method during login.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.bindType,
|
||||
"bind-type",
|
||||
string(api.BindingRuleBindTypeService),
|
||||
"Type of binding to perform (\"service\" or \"role\").",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.bindName,
|
||||
"bind-name",
|
||||
"",
|
||||
"Name to bind on match. Can use ${var} interpolation. "+
|
||||
"This flag is required.",
|
||||
)
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.authMethodName == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-method' flag"))
|
||||
c.UI.Error(c.Help())
|
||||
return 1
|
||||
} else if c.bindType == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-bind-type' flag"))
|
||||
c.UI.Error(c.Help())
|
||||
return 1
|
||||
} else if c.bindName == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-bind-name' flag"))
|
||||
c.UI.Error(c.Help())
|
||||
return 1
|
||||
}
|
||||
|
||||
newRule := &api.ACLBindingRule{
|
||||
Description: c.description,
|
||||
AuthMethod: c.authMethodName,
|
||||
BindType: api.BindingRuleBindType(c.bindType),
|
||||
BindName: c.bindName,
|
||||
Selector: c.selector,
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleCreate(newRule, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Failed to create new binding rule: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
acl.PrintBindingRule(rule, c.UI, c.showMeta)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Create an ACL Binding Rule"
|
||||
|
||||
const help = `
|
||||
Usage: consul acl binding-rule create [options]
|
||||
|
||||
Create a new binding rule:
|
||||
|
||||
$ consul acl binding-rule create \
|
||||
-method=minikube \
|
||||
-bind-type=service \
|
||||
-bind-name='k8s-${serviceaccount.name}' \
|
||||
-selector='serviceaccount.namespace==default and serviceaccount.name==web'
|
||||
`
|
|
@ -0,0 +1,178 @@
|
|||
package bindingrulecreate
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestBindingRuleCreateCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindingRuleCreateCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
// create an auth method in advance
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("method is required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-method' flag")
|
||||
})
|
||||
|
||||
t.Run("bind type required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-bind-type=",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-type' flag")
|
||||
})
|
||||
|
||||
t.Run("bind name required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-bind-type=service",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-name' flag")
|
||||
})
|
||||
|
||||
t.Run("must use roughly valid selector", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-bind-type=service",
|
||||
"-bind-name=demo",
|
||||
"-selector", "foo",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid")
|
||||
})
|
||||
|
||||
t.Run("create it with no selector", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-bind-type=service",
|
||||
"-bind-name=demo",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
})
|
||||
|
||||
t.Run("create it with a match selector", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-bind-type=service",
|
||||
"-bind-name=demo",
|
||||
"-selector", "serviceaccount.namespace==default and serviceaccount.name==vault",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
})
|
||||
|
||||
t.Run("create it with type role", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-bind-type=role",
|
||||
"-bind-name=demo",
|
||||
"-selector", "serviceaccount.namespace==default and serviceaccount.name==vault",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
})
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package bindingruledelete
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
ruleID string
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.ruleID,
|
||||
"id",
|
||||
"",
|
||||
"The ID of the binding rule to delete. "+
|
||||
"It may be specified as a unique ID prefix but will error if the prefix "+
|
||||
"matches multiple binding rule IDs",
|
||||
)
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.ruleID == "" {
|
||||
c.UI.Error(fmt.Sprintf("Must specify the -id parameter"))
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
if _, err := client.ACL().BindingRuleDelete(ruleID, nil); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error deleting binding rule %q: %v", ruleID, err))
|
||||
return 1
|
||||
}
|
||||
|
||||
c.UI.Info(fmt.Sprintf("Binding rule %q deleted successfully", ruleID))
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Delete an ACL Binding Rule"
|
||||
const help = `
|
||||
Usage: consul acl binding-rule delete -id ID [options]
|
||||
|
||||
Deletes an ACL binding rule by providing the ID or a unique ID prefix.
|
||||
|
||||
Delete by prefix:
|
||||
|
||||
$ consul acl binding-rule delete -id b6b85
|
||||
|
||||
Delete by full ID:
|
||||
|
||||
$ consul acl binding-rule delete -id b6b856da-5193-4e78-845a-7d61ca8371ba
|
||||
`
|
|
@ -0,0 +1,187 @@
|
|||
package bindingruledelete
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestBindingRuleDeleteCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindingRuleDeleteCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
// create an auth method in advance
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
createRule := func(t *testing.T) string {
|
||||
rule, _, err := client.ACL().BindingRuleCreate(
|
||||
&api.ACLBindingRule{
|
||||
AuthMethod: "test",
|
||||
Description: "test rule",
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "test-${serviceaccount.name}",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return rule.ID
|
||||
}
|
||||
|
||||
createDupe := func(t *testing.T) string {
|
||||
for {
|
||||
// Check for 1-char duplicates.
|
||||
rules, _, err := client.ACL().BindingRuleList(
|
||||
"test",
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
m := make(map[byte]struct{})
|
||||
for _, rule := range rules {
|
||||
c := rule.ID[0]
|
||||
|
||||
if _, ok := m[c]; ok {
|
||||
return string(c)
|
||||
}
|
||||
m[c] = struct{}{}
|
||||
}
|
||||
|
||||
_ = createRule(t)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("id required", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter")
|
||||
})
|
||||
|
||||
t.Run("delete works", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
output := ui.OutputWriter.String()
|
||||
require.Contains(t, output, fmt.Sprintf("deleted successfully"))
|
||||
require.Contains(t, output, id)
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, rule)
|
||||
})
|
||||
|
||||
t.Run("delete works via prefixes", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id[0:5],
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
output := ui.OutputWriter.String()
|
||||
require.Contains(t, output, fmt.Sprintf("deleted successfully"))
|
||||
require.Contains(t, output, id)
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, rule)
|
||||
})
|
||||
|
||||
t.Run("delete fails when prefix matches more than one rule", func(t *testing.T) {
|
||||
prefix := createDupe(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id=" + prefix,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID")
|
||||
})
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package bindingrulelist
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
authMethodName string
|
||||
|
||||
showMeta bool
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.BoolVar(
|
||||
&c.showMeta,
|
||||
"meta",
|
||||
false,
|
||||
"Indicates that binding rule metadata such "+
|
||||
"as the content hash and raft indices should be shown for each entry.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.authMethodName,
|
||||
"method",
|
||||
"",
|
||||
"Only show rules linked to the auth method with the given name.",
|
||||
)
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
rules, _, err := client.ACL().BindingRuleList(c.authMethodName, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Failed to retrieve the binding rule list: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
acl.PrintBindingRuleListEntry(rule, c.UI, c.showMeta)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Lists ACL Binding Rules"
|
||||
const help = `
|
||||
Usage: consul acl binding-rule list [options]
|
||||
|
||||
Lists all the ACL binding rules.
|
||||
|
||||
Show all:
|
||||
|
||||
$ consul acl binding-rule list
|
||||
|
||||
Show all for a specific auth method:
|
||||
|
||||
$ consul acl binding-rule list -method="my-method"
|
||||
`
|
|
@ -0,0 +1,167 @@
|
|||
package bindingrulelist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestBindingRuleListCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindingRuleListCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "test-1",
|
||||
Type: "testing",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "test-2",
|
||||
Type: "testing",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
createRule := func(t *testing.T, methodName, description string) string {
|
||||
rule, _, err := client.ACL().BindingRuleCreate(
|
||||
&api.ACLBindingRule{
|
||||
AuthMethod: methodName,
|
||||
Description: description,
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "test-${serviceaccount.name}",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return rule.ID
|
||||
}
|
||||
|
||||
var ruleIDs []string
|
||||
for i := 0; i < 10; i++ {
|
||||
name := fmt.Sprintf("test-rule-%d", i)
|
||||
|
||||
var methodName string
|
||||
if i%2 == 0 {
|
||||
methodName = "test-1"
|
||||
} else {
|
||||
methodName = "test-2"
|
||||
}
|
||||
|
||||
id := createRule(t, methodName, name)
|
||||
|
||||
ruleIDs = append(ruleIDs, id)
|
||||
}
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
output := ui.OutputWriter.String()
|
||||
|
||||
for i, v := range ruleIDs {
|
||||
require.Contains(t, output, fmt.Sprintf("test-rule-%d", i))
|
||||
require.Contains(t, output, v)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by method 1", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test-1",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
output := ui.OutputWriter.String()
|
||||
|
||||
for i, v := range ruleIDs {
|
||||
if i%2 == 0 {
|
||||
require.Contains(t, output, fmt.Sprintf("test-rule-%d", i))
|
||||
require.Contains(t, output, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by method 2", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test-2",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
output := ui.OutputWriter.String()
|
||||
|
||||
for i, v := range ruleIDs {
|
||||
if i%2 == 1 {
|
||||
require.Contains(t, output, fmt.Sprintf("test-rule-%d", i))
|
||||
require.Contains(t, output, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package bindingruleread
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
ruleID string
|
||||
|
||||
showMeta bool
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.BoolVar(
|
||||
&c.showMeta,
|
||||
"meta",
|
||||
false,
|
||||
"Indicates that binding rule metadata such "+
|
||||
"as the content hash and raft indices should be shown for each entry.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.ruleID,
|
||||
"id",
|
||||
"",
|
||||
"The ID of the binding rule to read. "+
|
||||
"It may be specified as a unique ID prefix but will error if the prefix "+
|
||||
"matches multiple binding rule IDs",
|
||||
)
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.ruleID == "" {
|
||||
c.UI.Error(fmt.Sprintf("Must specify the -id parameter."))
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(ruleID, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error reading binding rule %q: %v", ruleID, err))
|
||||
return 1
|
||||
} else if rule == nil {
|
||||
c.UI.Error(fmt.Sprintf("Binding rule not found with ID %q", ruleID))
|
||||
return 1
|
||||
}
|
||||
|
||||
acl.PrintBindingRule(rule, c.UI, c.showMeta)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Read an ACL Binding Rule"
|
||||
const help = `
|
||||
Usage: consul acl binding-rule read -id ID [options]
|
||||
|
||||
This command will retrieve and print out the details of a single binding
|
||||
rule.
|
||||
|
||||
Read a binding rule:
|
||||
|
||||
$ consul acl binding-rule read -id fdabbcb5-9de5-4b1a-961f-77214ae88cba
|
||||
`
|
|
@ -0,0 +1,152 @@
|
|||
package bindingruleread
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestBindingRuleReadCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindingRuleReadCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
// create an auth method in advance
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
createRule := func(t *testing.T) string {
|
||||
rule, _, err := client.ACL().BindingRuleCreate(
|
||||
&api.ACLBindingRule{
|
||||
AuthMethod: "test",
|
||||
Description: "test rule",
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "test-${serviceaccount.name}",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return rule.ID
|
||||
}
|
||||
|
||||
t.Run("id required", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter")
|
||||
})
|
||||
|
||||
t.Run("read by id not found", func(t *testing.T) {
|
||||
fakeID, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id=" + fakeID,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID")
|
||||
})
|
||||
|
||||
t.Run("read by id", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id=" + id,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
output := ui.OutputWriter.String()
|
||||
require.Contains(t, output, fmt.Sprintf("test rule"))
|
||||
require.Contains(t, output, id)
|
||||
})
|
||||
|
||||
t.Run("read by id prefix", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id=" + id[0:5],
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0)
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
output := ui.OutputWriter.String()
|
||||
require.Contains(t, output, fmt.Sprintf("test rule"))
|
||||
require.Contains(t, output, id)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,212 @@
|
|||
package bindingruleupdate
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
ruleID string
|
||||
|
||||
description string
|
||||
selector string
|
||||
bindType string
|
||||
bindName string
|
||||
|
||||
noMerge bool
|
||||
showMeta bool
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
c.flags.BoolVar(
|
||||
&c.showMeta,
|
||||
"meta",
|
||||
false,
|
||||
"Indicates that binding rule metadata such "+
|
||||
"as the content hash and raft indices should be shown for each entry.",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.ruleID,
|
||||
"id",
|
||||
"",
|
||||
"The ID of the binding rule to update. "+
|
||||
"It may be specified as a unique ID prefix but will error if the prefix "+
|
||||
"matches multiple binding rule IDs",
|
||||
)
|
||||
|
||||
c.flags.StringVar(
|
||||
&c.description,
|
||||
"description",
|
||||
"",
|
||||
"A description of the binding rule.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.selector,
|
||||
"selector",
|
||||
"",
|
||||
"Selector is an expression that matches against verified identity "+
|
||||
"attributes returned from the auth method during login.",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.bindType,
|
||||
"bind-type",
|
||||
string(api.BindingRuleBindTypeService),
|
||||
"Type of binding to perform (\"service\" or \"role\").",
|
||||
)
|
||||
c.flags.StringVar(
|
||||
&c.bindName,
|
||||
"bind-name",
|
||||
"",
|
||||
"Name to bind on match. Can use ${var} interpolation. "+
|
||||
"This flag is required.",
|
||||
)
|
||||
|
||||
c.flags.BoolVar(
|
||||
&c.noMerge,
|
||||
"no-merge",
|
||||
false,
|
||||
"Do not merge the current binding rule "+
|
||||
"information with what is provided to the command. Instead overwrite all fields "+
|
||||
"with the exception of the binding rule ID which is immutable.",
|
||||
)
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.ruleID == "" {
|
||||
c.UI.Error(fmt.Sprintf("Cannot update a binding rule without specifying the -id parameter"))
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
ruleID, err := acl.GetBindingRuleIDFromPartial(client, c.ruleID)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error determining binding rule ID: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Read the current binding rule in both cases so we can fail better if not found.
|
||||
currentRule, _, err := client.ACL().BindingRuleRead(ruleID, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error when retrieving current binding rule: %v", err))
|
||||
return 1
|
||||
} else if currentRule == nil {
|
||||
c.UI.Error(fmt.Sprintf("Binding rule not found with ID %q", ruleID))
|
||||
return 1
|
||||
}
|
||||
|
||||
var rule *api.ACLBindingRule
|
||||
if c.noMerge {
|
||||
if c.bindType == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-bind-type' flag"))
|
||||
c.UI.Error(c.Help())
|
||||
return 1
|
||||
} else if c.bindName == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-bind-name' flag"))
|
||||
c.UI.Error(c.Help())
|
||||
return 1
|
||||
}
|
||||
|
||||
rule = &api.ACLBindingRule{
|
||||
ID: ruleID,
|
||||
AuthMethod: currentRule.AuthMethod, // immutable
|
||||
Description: c.description,
|
||||
BindType: api.BindingRuleBindType(c.bindType),
|
||||
BindName: c.bindName,
|
||||
Selector: c.selector,
|
||||
}
|
||||
|
||||
} else {
|
||||
rule = currentRule
|
||||
|
||||
if c.description != "" {
|
||||
rule.Description = c.description
|
||||
}
|
||||
if c.bindType != "" {
|
||||
rule.BindType = api.BindingRuleBindType(c.bindType)
|
||||
}
|
||||
if c.bindName != "" {
|
||||
rule.BindName = c.bindName
|
||||
}
|
||||
if isFlagSet(c.flags, "selector") {
|
||||
rule.Selector = c.selector // empty is valid
|
||||
}
|
||||
}
|
||||
|
||||
rule, _, err = client.ACL().BindingRuleUpdate(rule, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error updating binding rule %q: %v", ruleID, err))
|
||||
return 1
|
||||
}
|
||||
|
||||
c.UI.Info(fmt.Sprintf("Binding rule updated successfully"))
|
||||
acl.PrintBindingRule(rule, c.UI, c.showMeta)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
func isFlagSet(flags *flag.FlagSet, name string) bool {
|
||||
found := false
|
||||
flags.Visit(func(f *flag.Flag) {
|
||||
if f.Name == name {
|
||||
found = true
|
||||
}
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
const synopsis = "Update an ACL Binding Rule"
|
||||
const help = `
|
||||
Usage: consul acl binding-rule update -id ID [options]
|
||||
|
||||
Updates a binding rule. By default it will merge the binding rule
|
||||
information with its current state so that you do not have to provide all
|
||||
parameters. This behavior can be disabled by passing -no-merge.
|
||||
|
||||
Update all editable fields of the binding rule:
|
||||
|
||||
$ consul acl binding-rule update \
|
||||
-id=43cb72df-9c6f-4315-ac8a-01a9d98155ef \
|
||||
-description="new description" \
|
||||
-bind-type=role \
|
||||
-bind-name='k8s-${serviceaccount.name}' \
|
||||
-selector='serviceaccount.namespace==default and serviceaccount.name==web'
|
||||
`
|
|
@ -0,0 +1,768 @@
|
|||
package bindingruleupdate
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
// activate testing auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
)
|
||||
|
||||
func TestBindingRuleUpdateCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindingRuleUpdateCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
// create an auth method in advance
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
deleteRules := func(t *testing.T) {
|
||||
rules, _, err := client.ACL().BindingRuleList(
|
||||
"test",
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, rule := range rules {
|
||||
_, err := client.ACL().BindingRuleDelete(
|
||||
rule.ID,
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("rule id required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Cannot update a binding rule without specifying the -id parameter")
|
||||
})
|
||||
|
||||
t.Run("rule id partial matches nothing", func(t *testing.T) {
|
||||
fakeID, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-id=" + fakeID[0:5],
|
||||
"-token=root",
|
||||
"-description=test rule edited",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID")
|
||||
})
|
||||
|
||||
t.Run("rule id exact match doesn't exist", func(t *testing.T) {
|
||||
fakeID, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-id=" + fakeID,
|
||||
"-token=root",
|
||||
"-description=test rule edited",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID")
|
||||
})
|
||||
|
||||
createRule := func(t *testing.T) string {
|
||||
rule, _, err := client.ACL().BindingRuleCreate(
|
||||
&api.ACLBindingRule{
|
||||
AuthMethod: "test",
|
||||
Description: "test rule",
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "test-${serviceaccount.name}",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return rule.ID
|
||||
}
|
||||
|
||||
createDupe := func(t *testing.T) string {
|
||||
for {
|
||||
// Check for 1-char duplicates.
|
||||
rules, _, err := client.ACL().BindingRuleList(
|
||||
"test",
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
m := make(map[byte]struct{})
|
||||
for _, rule := range rules {
|
||||
c := rule.ID[0]
|
||||
|
||||
if _, ok := m[c]; ok {
|
||||
return string(c)
|
||||
}
|
||||
m[c] = struct{}{}
|
||||
}
|
||||
|
||||
_ = createRule(t)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("rule id partial matches multiple", func(t *testing.T) {
|
||||
prefix := createDupe(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-id=" + prefix,
|
||||
"-token=root",
|
||||
"-description=test rule edited",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID")
|
||||
})
|
||||
|
||||
t.Run("must use roughly valid selector", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id,
|
||||
"-selector", "foo",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid")
|
||||
})
|
||||
|
||||
t.Run("update all fields", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id,
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "role",
|
||||
"-bind-name=role-updated",
|
||||
"-selector=serviceaccount.namespace==alt and serviceaccount.name==demo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule edited", rule.Description)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType)
|
||||
require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector)
|
||||
})
|
||||
|
||||
t.Run("update all fields - partial", func(t *testing.T) {
|
||||
deleteRules(t) // reset since we created a bunch that might be dupes
|
||||
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id[0:5],
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "role",
|
||||
"-bind-name=role-updated",
|
||||
"-selector=serviceaccount.namespace==alt and serviceaccount.name==demo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule edited", rule.Description)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType)
|
||||
require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector)
|
||||
})
|
||||
|
||||
t.Run("update all fields but description", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id,
|
||||
"-bind-type", "role",
|
||||
"-bind-name=role-updated",
|
||||
"-selector=serviceaccount.namespace==alt and serviceaccount.name==demo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule", rule.Description)
|
||||
require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector)
|
||||
})
|
||||
|
||||
t.Run("update all fields but bind name", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id,
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "role",
|
||||
"-selector=serviceaccount.namespace==alt and serviceaccount.name==demo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule edited", rule.Description)
|
||||
require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType)
|
||||
require.Equal(t, "test-${serviceaccount.name}", rule.BindName)
|
||||
require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector)
|
||||
})
|
||||
|
||||
t.Run("update all fields but must exist", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id,
|
||||
"-description=test rule edited",
|
||||
"-bind-name=role-updated",
|
||||
"-selector=serviceaccount.namespace==alt and serviceaccount.name==demo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule edited", rule.Description)
|
||||
require.Equal(t, api.BindingRuleBindTypeService, rule.BindType)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector)
|
||||
})
|
||||
|
||||
t.Run("update all fields but selector", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id,
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "role",
|
||||
"-bind-name=role-updated",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule edited", rule.Description)
|
||||
require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Equal(t, "serviceaccount.namespace==default", rule.Selector)
|
||||
})
|
||||
|
||||
t.Run("update all fields clear selector", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-id", id,
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "role",
|
||||
"-bind-name=role-updated",
|
||||
"-selector=",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule edited", rule.Description)
|
||||
require.Equal(t, api.BindingRuleBindTypeRole, rule.BindType)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Empty(t, rule.Selector)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBindingRuleUpdateCommand_noMerge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
// create an auth method in advance
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
deleteRules := func(t *testing.T) {
|
||||
rules, _, err := client.ACL().BindingRuleList(
|
||||
"test",
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, rule := range rules {
|
||||
_, err := client.ACL().BindingRuleDelete(
|
||||
rule.ID,
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("rule id required", func(t *testing.T) {
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Cannot update a binding rule without specifying the -id parameter")
|
||||
})
|
||||
|
||||
t.Run("rule id partial matches nothing", func(t *testing.T) {
|
||||
fakeID, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-id=" + fakeID[0:5],
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-description=test rule edited",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID")
|
||||
})
|
||||
|
||||
t.Run("rule id exact match doesn't exist", func(t *testing.T) {
|
||||
fakeID, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-id=" + fakeID,
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-description=test rule edited",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Binding rule not found with ID")
|
||||
})
|
||||
|
||||
createRule := func(t *testing.T) string {
|
||||
rule, _, err := client.ACL().BindingRuleCreate(
|
||||
&api.ACLBindingRule{
|
||||
AuthMethod: "test",
|
||||
Description: "test rule",
|
||||
BindType: api.BindingRuleBindTypeRole,
|
||||
BindName: "test-${serviceaccount.name}",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return rule.ID
|
||||
}
|
||||
|
||||
createDupe := func(t *testing.T) string {
|
||||
for {
|
||||
// Check for 1-char duplicates.
|
||||
rules, _, err := client.ACL().BindingRuleList(
|
||||
"test",
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
m := make(map[byte]struct{})
|
||||
for _, rule := range rules {
|
||||
c := rule.ID[0]
|
||||
|
||||
if _, ok := m[c]; ok {
|
||||
return string(c)
|
||||
}
|
||||
m[c] = struct{}{}
|
||||
}
|
||||
|
||||
_ = createRule(t)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("rule id partial matches multiple", func(t *testing.T) {
|
||||
prefix := createDupe(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-id=" + prefix,
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-description=test rule edited",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Error determining binding rule ID")
|
||||
})
|
||||
|
||||
t.Run("must use roughly valid selector", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-id", id,
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "service",
|
||||
"-bind-name=role-updated",
|
||||
"-selector", "foo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Selector is invalid")
|
||||
})
|
||||
|
||||
t.Run("update all fields", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-id", id,
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "service",
|
||||
"-bind-name=role-updated",
|
||||
"-selector=serviceaccount.namespace==alt and serviceaccount.name==demo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule edited", rule.Description)
|
||||
require.Equal(t, api.BindingRuleBindTypeService, rule.BindType)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector)
|
||||
})
|
||||
|
||||
t.Run("update all fields - partial", func(t *testing.T) {
|
||||
deleteRules(t) // reset since we created a bunch that might be dupes
|
||||
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-id", id[0:5],
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "service",
|
||||
"-bind-name=role-updated",
|
||||
"-selector=serviceaccount.namespace==alt and serviceaccount.name==demo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule edited", rule.Description)
|
||||
require.Equal(t, api.BindingRuleBindTypeService, rule.BindType)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector)
|
||||
})
|
||||
|
||||
t.Run("update all fields but description", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-id", id,
|
||||
"-bind-type", "service",
|
||||
"-bind-name=role-updated",
|
||||
"-selector=serviceaccount.namespace==alt and serviceaccount.name==demo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Empty(t, rule.Description)
|
||||
require.Equal(t, api.BindingRuleBindTypeService, rule.BindType)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Equal(t, "serviceaccount.namespace==alt and serviceaccount.name==demo", rule.Selector)
|
||||
})
|
||||
|
||||
t.Run("missing bind name", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-id=" + id,
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "role",
|
||||
"-selector=serviceaccount.namespace==alt and serviceaccount.name==demo",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bind-name' flag")
|
||||
})
|
||||
|
||||
t.Run("update all fields but selector", func(t *testing.T) {
|
||||
id := createRule(t)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-no-merge",
|
||||
"-id", id,
|
||||
"-description=test rule edited",
|
||||
"-bind-type", "service",
|
||||
"-bind-name=role-updated",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
|
||||
rule, _, err := client.ACL().BindingRuleRead(
|
||||
id,
|
||||
&api.QueryOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
require.Equal(t, "test rule edited", rule.Description)
|
||||
require.Equal(t, api.BindingRuleBindTypeService, rule.BindType)
|
||||
require.Equal(t, "role-updated", rule.BindName)
|
||||
require.Empty(t, rule.Selector)
|
||||
})
|
||||
}
|
|
@ -21,7 +21,8 @@ type cmd struct {
|
|||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
roleID string
|
||||
roleID string
|
||||
roleName string
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
|
@ -29,6 +30,7 @@ func (c *cmd) init() {
|
|||
c.flags.StringVar(&c.roleID, "id", "", "The ID of the role to delete. "+
|
||||
"It may be specified as a unique ID prefix but will error if the prefix "+
|
||||
"matches multiple role IDs")
|
||||
c.flags.StringVar(&c.roleName, "name", "", "The name of the role to delete.")
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
|
@ -40,8 +42,8 @@ func (c *cmd) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
if c.roleID == "" {
|
||||
c.UI.Error(fmt.Sprintf("Must specify the -id parameter"))
|
||||
if c.roleID == "" && c.roleName == "" {
|
||||
c.UI.Error(fmt.Sprintf("Must specify the -id or -name parameters"))
|
||||
return 1
|
||||
}
|
||||
|
||||
|
@ -51,7 +53,12 @@ func (c *cmd) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
roleID, err := acl.GetRoleIDFromPartial(client, c.roleID)
|
||||
var roleID string
|
||||
if c.roleID != "" {
|
||||
roleID, err = acl.GetRoleIDFromPartial(client, c.roleID)
|
||||
} else {
|
||||
roleID, err = acl.GetRoleIDByName(client, c.roleName)
|
||||
}
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error determining role ID: %v", err))
|
||||
return 1
|
||||
|
|
|
@ -45,7 +45,7 @@ func TestRoleDeleteCommand(t *testing.T) {
|
|||
|
||||
client := a.Client()
|
||||
|
||||
t.Run("id required", func(t *testing.T) {
|
||||
t.Run("id or name required", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
|
@ -56,7 +56,7 @@ func TestRoleDeleteCommand(t *testing.T) {
|
|||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id parameter")
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Must specify the -id or -name parameters")
|
||||
})
|
||||
|
||||
t.Run("delete works", func(t *testing.T) {
|
||||
|
|
|
@ -3,6 +3,18 @@ package command
|
|||
import (
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
aclagent "github.com/hashicorp/consul/command/acl/agenttokens"
|
||||
aclam "github.com/hashicorp/consul/command/acl/authmethod"
|
||||
aclamcreate "github.com/hashicorp/consul/command/acl/authmethod/create"
|
||||
aclamdelete "github.com/hashicorp/consul/command/acl/authmethod/delete"
|
||||
aclamlist "github.com/hashicorp/consul/command/acl/authmethod/list"
|
||||
aclamread "github.com/hashicorp/consul/command/acl/authmethod/read"
|
||||
aclamupdate "github.com/hashicorp/consul/command/acl/authmethod/update"
|
||||
aclbr "github.com/hashicorp/consul/command/acl/bindingrule"
|
||||
aclbrcreate "github.com/hashicorp/consul/command/acl/bindingrule/create"
|
||||
aclbrdelete "github.com/hashicorp/consul/command/acl/bindingrule/delete"
|
||||
aclbrlist "github.com/hashicorp/consul/command/acl/bindingrule/list"
|
||||
aclbrread "github.com/hashicorp/consul/command/acl/bindingrule/read"
|
||||
aclbrupdate "github.com/hashicorp/consul/command/acl/bindingrule/update"
|
||||
aclbootstrap "github.com/hashicorp/consul/command/acl/bootstrap"
|
||||
aclpolicy "github.com/hashicorp/consul/command/acl/policy"
|
||||
aclpcreate "github.com/hashicorp/consul/command/acl/policy/create"
|
||||
|
@ -57,6 +69,8 @@ import (
|
|||
kvput "github.com/hashicorp/consul/command/kv/put"
|
||||
"github.com/hashicorp/consul/command/leave"
|
||||
"github.com/hashicorp/consul/command/lock"
|
||||
login "github.com/hashicorp/consul/command/login"
|
||||
logout "github.com/hashicorp/consul/command/logout"
|
||||
"github.com/hashicorp/consul/command/maint"
|
||||
"github.com/hashicorp/consul/command/members"
|
||||
"github.com/hashicorp/consul/command/monitor"
|
||||
|
@ -118,6 +132,18 @@ func init() {
|
|||
Register("acl role read", func(ui cli.Ui) (cli.Command, error) { return aclrread.New(ui), nil })
|
||||
Register("acl role update", func(ui cli.Ui) (cli.Command, error) { return aclrupdate.New(ui), nil })
|
||||
Register("acl role delete", func(ui cli.Ui) (cli.Command, error) { return aclrdelete.New(ui), nil })
|
||||
Register("acl auth-method", func(cli.Ui) (cli.Command, error) { return aclam.New(), nil })
|
||||
Register("acl auth-method create", func(ui cli.Ui) (cli.Command, error) { return aclamcreate.New(ui), nil })
|
||||
Register("acl auth-method list", func(ui cli.Ui) (cli.Command, error) { return aclamlist.New(ui), nil })
|
||||
Register("acl auth-method read", func(ui cli.Ui) (cli.Command, error) { return aclamread.New(ui), nil })
|
||||
Register("acl auth-method update", func(ui cli.Ui) (cli.Command, error) { return aclamupdate.New(ui), nil })
|
||||
Register("acl auth-method delete", func(ui cli.Ui) (cli.Command, error) { return aclamdelete.New(ui), nil })
|
||||
Register("acl binding-rule", func(cli.Ui) (cli.Command, error) { return aclbr.New(), nil })
|
||||
Register("acl binding-rule create", func(ui cli.Ui) (cli.Command, error) { return aclbrcreate.New(ui), nil })
|
||||
Register("acl binding-rule list", func(ui cli.Ui) (cli.Command, error) { return aclbrlist.New(ui), nil })
|
||||
Register("acl binding-rule read", func(ui cli.Ui) (cli.Command, error) { return aclbrread.New(ui), nil })
|
||||
Register("acl binding-rule update", func(ui cli.Ui) (cli.Command, error) { return aclbrupdate.New(ui), nil })
|
||||
Register("acl binding-rule delete", func(ui cli.Ui) (cli.Command, error) { return aclbrdelete.New(ui), nil })
|
||||
Register("agent", func(ui cli.Ui) (cli.Command, error) {
|
||||
return agent.New(ui, rev, ver, verPre, verHuman, make(chan struct{})), nil
|
||||
})
|
||||
|
@ -153,6 +179,8 @@ func init() {
|
|||
Register("kv put", func(ui cli.Ui) (cli.Command, error) { return kvput.New(ui), nil })
|
||||
Register("leave", func(ui cli.Ui) (cli.Command, error) { return leave.New(ui), nil })
|
||||
Register("lock", func(ui cli.Ui) (cli.Command, error) { return lock.New(ui), nil })
|
||||
Register("login", func(ui cli.Ui) (cli.Command, error) { return login.New(ui), nil })
|
||||
Register("logout", func(ui cli.Ui) (cli.Command, error) { return logout.New(ui), nil })
|
||||
Register("maint", func(ui cli.Ui) (cli.Command, error) { return maint.New(ui), nil })
|
||||
Register("members", func(ui cli.Ui) (cli.Command, error) { return members.New(ui), nil })
|
||||
Register("monitor", func(ui cli.Ui) (cli.Command, error) { return monitor.New(ui, MakeShutdownCh()), nil })
|
||||
|
|
|
@ -104,7 +104,7 @@ func (c *cmd) Run(args []string) int {
|
|||
// enabled.
|
||||
c.grpcAddr = "localhost:8502"
|
||||
}
|
||||
if c.http.Token() == "" {
|
||||
if c.http.Token() == "" && c.http.TokenFile() == "" {
|
||||
// Extra check needed since CONSUL_HTTP_TOKEN has not been consulted yet but
|
||||
// calling SetToken with empty will force that to override the
|
||||
if proxyToken := os.Getenv(proxyAgent.EnvProxyToken); proxyToken != "" {
|
||||
|
|
|
@ -129,7 +129,7 @@ func (c *cmd) Run(args []string) int {
|
|||
if c.sidecarFor == "" {
|
||||
c.sidecarFor = os.Getenv(proxyAgent.EnvSidecarFor)
|
||||
}
|
||||
if c.http.Token() == "" {
|
||||
if c.http.Token() == "" && c.http.TokenFile() == "" {
|
||||
c.http.SetToken(os.Getenv(proxyAgent.EnvProxyToken))
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ package flags
|
|||
|
||||
import (
|
||||
"flag"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/api"
|
||||
)
|
||||
|
@ -10,6 +12,7 @@ type HTTPFlags struct {
|
|||
// client api flags
|
||||
address StringValue
|
||||
token StringValue
|
||||
tokenFile StringValue
|
||||
caFile StringValue
|
||||
caPath StringValue
|
||||
certFile StringValue
|
||||
|
@ -33,6 +36,10 @@ func (f *HTTPFlags) ClientFlags() *flag.FlagSet {
|
|||
"ACL token to use in the request. This can also be specified via the "+
|
||||
"CONSUL_HTTP_TOKEN environment variable. If unspecified, the query will "+
|
||||
"default to the token of the Consul agent at the HTTP address.")
|
||||
fs.Var(&f.tokenFile, "token-file",
|
||||
"File containing the ACL token to use in the request instead of one specified "+
|
||||
"via the -token argument or CONSUL_HTTP_TOKEN environment variable. "+
|
||||
"This can also be specified via the CONSUL_HTTP_TOKEN_FILE environment variable.")
|
||||
fs.Var(&f.caFile, "ca-file",
|
||||
"Path to a CA file to use for TLS when communicating with Consul. This "+
|
||||
"can also be specified via the CONSUL_CACERT environment variable.")
|
||||
|
@ -88,6 +95,28 @@ func (f *HTTPFlags) SetToken(v string) error {
|
|||
return f.token.Set(v)
|
||||
}
|
||||
|
||||
func (f *HTTPFlags) TokenFile() string {
|
||||
return f.tokenFile.String()
|
||||
}
|
||||
|
||||
func (f *HTTPFlags) SetTokenFile(v string) error {
|
||||
return f.tokenFile.Set(v)
|
||||
}
|
||||
|
||||
func (f *HTTPFlags) ReadTokenFile() (string, error) {
|
||||
tokenFile := f.tokenFile.String()
|
||||
if tokenFile == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(tokenFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
func (f *HTTPFlags) APIClient() (*api.Client, error) {
|
||||
c := api.DefaultConfig()
|
||||
|
||||
|
@ -99,6 +128,7 @@ func (f *HTTPFlags) APIClient() (*api.Client, error) {
|
|||
func (f *HTTPFlags) MergeOntoConfig(c *api.Config) {
|
||||
f.address.Merge(&c.Address)
|
||||
f.token.Merge(&c.Token)
|
||||
f.tokenFile.Merge(&c.TokenFile)
|
||||
f.caFile.Merge(&c.TLSConfig.CAFile)
|
||||
f.caPath.Merge(&c.TLSConfig.CAPath)
|
||||
f.certFile.Merge(&c.TLSConfig.CertFile)
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
package login
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/hashicorp/consul/lib/file"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
|
||||
shutdownCh <-chan struct{}
|
||||
|
||||
bearerToken string
|
||||
|
||||
// flags
|
||||
authMethodName string
|
||||
bearerTokenFile string
|
||||
tokenSinkFile string
|
||||
meta map[string]string
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
|
||||
c.flags.StringVar(&c.authMethodName, "method", "",
|
||||
"Name of the auth method to login to.")
|
||||
|
||||
c.flags.StringVar(&c.bearerTokenFile, "bearer-token-file", "",
|
||||
"Path to a file containing a secret bearer token to use with this auth method.")
|
||||
|
||||
c.flags.StringVar(&c.tokenSinkFile, "token-sink-file", "",
|
||||
"The most recent token's SecretID is kept up to date in this file.")
|
||||
|
||||
c.flags.Var((*flags.FlagMapValue)(&c.meta), "meta",
|
||||
"Metadata to set on the token, formatted as key=value. This flag "+
|
||||
"may be specified multiple times to set multiple meta fields.")
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
if len(c.flags.Args()) > 0 {
|
||||
c.UI.Error(fmt.Sprintf("Should have no non-flag arguments."))
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.authMethodName == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-method' flag"))
|
||||
return 1
|
||||
}
|
||||
if c.tokenSinkFile == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-token-sink-file' flag"))
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.bearerTokenFile == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-bearer-token-file' flag"))
|
||||
return 1
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(c.bearerTokenFile)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
c.bearerToken = strings.TrimSpace(string(data))
|
||||
|
||||
if c.bearerToken == "" {
|
||||
c.UI.Error(fmt.Sprintf("No bearer token found in %s", c.bearerTokenFile))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Ensure that we don't try to use a token when performing a login
|
||||
// operation.
|
||||
c.http.SetToken("")
|
||||
c.http.SetTokenFile("")
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Do the login.
|
||||
req := &api.ACLLoginParams{
|
||||
AuthMethod: c.authMethodName,
|
||||
BearerToken: c.bearerToken,
|
||||
Meta: c.meta,
|
||||
}
|
||||
tok, _, err := client.ACL().Login(req, nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error logging in: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
if err := c.writeToSink(tok); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error writing token to file sink: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) writeToSink(tok *api.ACLToken) error {
|
||||
payload := []byte(tok.SecretID)
|
||||
return file.WriteAtomicWithPerms(c.tokenSinkFile, payload, 0600)
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Login to Consul using an Auth Method"
|
||||
|
||||
const help = `
|
||||
Usage: consul login [options]
|
||||
|
||||
The login command will exchange the provided third party credentials with the
|
||||
requested auth method for a newly minted Consul ACL Token. The companion
|
||||
command 'consul logout' should be used to destroy any tokens created this way
|
||||
to avoid a resource leak.
|
||||
`
|
|
@ -0,0 +1,321 @@
|
|||
package login
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/kubeauth"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLoginCommand_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
t.Run("method is required", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-method' flag")
|
||||
})
|
||||
|
||||
tokenSinkFile := filepath.Join(testDir, "test.token")
|
||||
|
||||
t.Run("token-sink-file is required", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-token-sink-file' flag")
|
||||
})
|
||||
|
||||
t.Run("bearer-token-file is required", func(t *testing.T) {
|
||||
defer os.Remove(tokenSinkFile)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bearer-token-file' flag")
|
||||
})
|
||||
|
||||
bearerTokenFile := filepath.Join(testDir, "bearer.token")
|
||||
|
||||
t.Run("bearer-token-file is empty", func(t *testing.T) {
|
||||
defer os.Remove(tokenSinkFile)
|
||||
|
||||
require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(""), 0600))
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
"-bearer-token-file", bearerTokenFile,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "No bearer token found in")
|
||||
})
|
||||
|
||||
require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte("demo-token"), 0600))
|
||||
|
||||
t.Run("try login with no method configured", func(t *testing.T) {
|
||||
defer os.Remove(tokenSinkFile)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
"-bearer-token-file", bearerTokenFile,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)")
|
||||
})
|
||||
|
||||
testSessionID := testauth.StartSession()
|
||||
defer testauth.ResetSession(testSessionID)
|
||||
|
||||
testauth.InstallSessionToken(
|
||||
testSessionID,
|
||||
"demo-token",
|
||||
"default", "demo", "76091af4-4b56-11e9-ac4b-708b11801cbe",
|
||||
)
|
||||
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
Config: map[string]interface{}{
|
||||
"SessionID": testSessionID,
|
||||
},
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("try login with method configured but no binding rules", func(t *testing.T) {
|
||||
defer os.Remove(tokenSinkFile)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
"-bearer-token-file", bearerTokenFile,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, 1, code, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)")
|
||||
})
|
||||
|
||||
{
|
||||
_, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{
|
||||
AuthMethod: "test",
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "${serviceaccount.name}",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("try login with method configured and binding rules", func(t *testing.T) {
|
||||
defer os.Remove(tokenSinkFile)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
"-bearer-token-file", bearerTokenFile,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.OutputWriter.String())
|
||||
|
||||
raw, err := ioutil.ReadFile(tokenSinkFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := strings.TrimSpace(string(raw))
|
||||
require.Len(t, token, 36, "must be a valid uid: %s", token)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoginCommand_k8s(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
tokenSinkFile := filepath.Join(testDir, "test.token")
|
||||
bearerTokenFile := filepath.Join(testDir, "bearer.token")
|
||||
|
||||
// the "B" jwt will be the one being reviewed
|
||||
require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(acl.TestKubernetesJWT_B), 0600))
|
||||
|
||||
// spin up a fake api server
|
||||
testSrv := kubeauth.StartTestAPIServer(t)
|
||||
defer testSrv.Stop()
|
||||
|
||||
testSrv.AuthorizeJWT(acl.TestKubernetesJWT_A)
|
||||
testSrv.SetAllowedServiceAccount(
|
||||
"default",
|
||||
"demo",
|
||||
"76091af4-4b56-11e9-ac4b-708b11801cbe",
|
||||
"",
|
||||
acl.TestKubernetesJWT_B,
|
||||
)
|
||||
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "k8s",
|
||||
Type: "kubernetes",
|
||||
Config: map[string]interface{}{
|
||||
"Host": testSrv.Addr(),
|
||||
"CACert": testSrv.CACert(),
|
||||
// the "A" jwt will be the one with token review privs
|
||||
"ServiceAccountJWT": acl.TestKubernetesJWT_A,
|
||||
},
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
{
|
||||
_, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{
|
||||
AuthMethod: "k8s",
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "${serviceaccount.name}",
|
||||
Selector: "serviceaccount.namespace==default",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("try login with method configured and binding rules", func(t *testing.T) {
|
||||
defer os.Remove(tokenSinkFile)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=k8s",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
"-bearer-token-file", bearerTokenFile,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.OutputWriter.String())
|
||||
|
||||
raw, err := ioutil.ReadFile(tokenSinkFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := strings.TrimSpace(string(raw))
|
||||
require.Len(t, token, 36, "must be a valid uid: %s", token)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
package logout
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
c := &cmd{UI: ui}
|
||||
c.init()
|
||||
return c
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
UI cli.Ui
|
||||
flags *flag.FlagSet
|
||||
http *flags.HTTPFlags
|
||||
help string
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
if err := c.flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
if len(c.flags.Args()) > 0 {
|
||||
c.UI.Error(fmt.Sprintf("Should have no non-flag arguments."))
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.http.APIClient()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
if _, err := client.ACL().Logout(nil); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error destroying token: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *cmd) Synopsis() string {
|
||||
return synopsis
|
||||
}
|
||||
|
||||
func (c *cmd) Help() string {
|
||||
return flags.Usage(c.help, nil)
|
||||
}
|
||||
|
||||
const synopsis = "Destroy a Consul Token created with Login"
|
||||
|
||||
const help = `
|
||||
Usage: consul logout [options]
|
||||
|
||||
The logout command will destroy the provided token if it was created from
|
||||
'consul login'.
|
||||
`
|
|
@ -0,0 +1,299 @@
|
|||
package logout
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/kubeauth"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLogout_noTabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
|
||||
t.Fatal("help has tabs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogoutCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
t.Run("no token specified", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)")
|
||||
})
|
||||
|
||||
t.Run("logout of deleted token", func(t *testing.T) {
|
||||
fakeID, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=" + fakeID,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)")
|
||||
})
|
||||
|
||||
plainToken, _, err := client.ACL().TokenCreate(
|
||||
&api.ACLToken{Description: "test"},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("logout of ordinary token", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=" + plainToken.SecretID,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)")
|
||||
})
|
||||
|
||||
testSessionID := testauth.StartSession()
|
||||
defer testauth.ResetSession(testSessionID)
|
||||
|
||||
testauth.InstallSessionToken(
|
||||
testSessionID,
|
||||
"demo-token",
|
||||
"default", "demo", "76091af4-4b56-11e9-ac4b-708b11801cbe",
|
||||
)
|
||||
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "test",
|
||||
Type: "testing",
|
||||
Config: map[string]interface{}{
|
||||
"SessionID": testSessionID,
|
||||
},
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
{
|
||||
_, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{
|
||||
AuthMethod: "test",
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "${serviceaccount.name}",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
var loginTokenSecret string
|
||||
{
|
||||
tok, _, err := client.ACL().Login(&api.ACLLoginParams{
|
||||
AuthMethod: "test",
|
||||
BearerToken: "demo-token",
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
loginTokenSecret = tok.SecretID
|
||||
}
|
||||
|
||||
t.Run("logout of login token", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=" + loginTokenSecret,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogoutCommand_k8s(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, t.Name(), `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
a.Agent.LogWriter = logger.NewLogWriter(512)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
t.Run("no token specified", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)")
|
||||
})
|
||||
|
||||
t.Run("logout of deleted token", func(t *testing.T) {
|
||||
fakeID, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=" + fakeID,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "403 (ACL not found)")
|
||||
})
|
||||
|
||||
plainToken, _, err := client.ACL().TokenCreate(
|
||||
&api.ACLToken{Description: "test"},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("logout of ordinary token", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=" + plainToken.SecretID,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "403 (Permission denied)")
|
||||
})
|
||||
|
||||
// go to the trouble of creating a login token
|
||||
// require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(acl.TestKubernetesJWT_B), 0600))
|
||||
|
||||
// spin up a fake api server
|
||||
testSrv := kubeauth.StartTestAPIServer(t)
|
||||
defer testSrv.Stop()
|
||||
|
||||
testSrv.AuthorizeJWT(acl.TestKubernetesJWT_A)
|
||||
testSrv.SetAllowedServiceAccount(
|
||||
"default",
|
||||
"demo",
|
||||
"76091af4-4b56-11e9-ac4b-708b11801cbe",
|
||||
"",
|
||||
acl.TestKubernetesJWT_B,
|
||||
)
|
||||
|
||||
{
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "k8s",
|
||||
Type: "kubernetes",
|
||||
Config: map[string]interface{}{
|
||||
"Host": testSrv.Addr(),
|
||||
"CACert": testSrv.CACert(),
|
||||
// the "A" jwt will be the one with token review privs
|
||||
"ServiceAccountJWT": acl.TestKubernetesJWT_A,
|
||||
},
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
{
|
||||
_, _, err := client.ACL().BindingRuleCreate(&api.ACLBindingRule{
|
||||
AuthMethod: "k8s",
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "${serviceaccount.name}",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
var loginTokenSecret string
|
||||
{
|
||||
tok, _, err := client.ACL().Login(&api.ACLLoginParams{
|
||||
AuthMethod: "k8s",
|
||||
BearerToken: acl.TestKubernetesJWT_B,
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
loginTokenSecret = tok.SecretID
|
||||
}
|
||||
|
||||
t.Run("logout of login token", func(t *testing.T) {
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=" + loginTokenSecret,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 0, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
})
|
||||
}
|
|
@ -87,6 +87,14 @@ func (c *cmd) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
token := c.http.Token()
|
||||
if tokenFromFile, err := c.http.ReadTokenFile(); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error loading token file: %s", err))
|
||||
return 1
|
||||
} else if tokenFromFile != "" {
|
||||
token = tokenFromFile
|
||||
}
|
||||
|
||||
// Compile the watch parameters
|
||||
params := make(map[string]interface{})
|
||||
if c.watchType != "" {
|
||||
|
@ -95,8 +103,8 @@ func (c *cmd) Run(args []string) int {
|
|||
if c.http.Datacenter() != "" {
|
||||
params["datacenter"] = c.http.Datacenter()
|
||||
}
|
||||
if c.http.Token() != "" {
|
||||
params["token"] = c.http.Token()
|
||||
if token != "" {
|
||||
params["token"] = token
|
||||
}
|
||||
if c.key != "" {
|
||||
params["key"] = c.key
|
||||
|
|
6
go.mod
6
go.mod
|
@ -123,7 +123,9 @@ require (
|
|||
gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect
|
||||
gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528 // indirect
|
||||
gopkg.in/ory-am/dockertest.v3 v3.3.4 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.3.1
|
||||
gotest.tools v2.2.0+incompatible // indirect
|
||||
k8s.io/api v0.0.0-20190118113203-912cbe2bfef3 // indirect
|
||||
k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512 // indirect
|
||||
k8s.io/api v0.0.0-20190325185214-7544f9db76f6
|
||||
k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841
|
||||
k8s.io/client-go v8.0.0+incompatible
|
||||
)
|
||||
|
|
10
go.sum
10
go.sum
|
@ -383,6 +383,8 @@ gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528 h1:/saqWwm73dLmuzbNhe92F0QsZ/
|
|||
gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
|
||||
gopkg.in/ory-am/dockertest.v3 v3.3.4 h1:oen8RiwxVNxtQ1pRoV4e4jqh6UjNsOuIZ1NXns6jdcw=
|
||||
gopkg.in/ory-am/dockertest.v3 v3.3.4/go.mod h1:s9mmoLkaGeAh97qygnNj4xWkiN7e1SKekYC6CovU+ek=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
|
||||
|
@ -391,10 +393,10 @@ gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
|
|||
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
|
||||
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
k8s.io/api v0.0.0-20180806132203-61b11ee65332/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA=
|
||||
k8s.io/api v0.0.0-20190118113203-912cbe2bfef3 h1:lV0+KGoNkvZOt4zGT4H83hQrzWMt/US/LSz4z4+BQS4=
|
||||
k8s.io/api v0.0.0-20190118113203-912cbe2bfef3/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA=
|
||||
k8s.io/api v0.0.0-20190325185214-7544f9db76f6 h1:9MWtbqhwTyDvF4cS1qAhxDb9Mi8taXiAu+5nEacl7gY=
|
||||
k8s.io/api v0.0.0-20190325185214-7544f9db76f6/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA=
|
||||
k8s.io/apimachinery v0.0.0-20180821005732-488889b0007f/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0=
|
||||
k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512 h1:/Z1m/6oEN6hE2SzWP4BHW2yATeUrBRr+1GxNf1Ny58Y=
|
||||
k8s.io/apimachinery v0.0.0-20180904193909-def12e63c512/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0=
|
||||
k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841 h1:Q4RZrHNtlC/mSdC1sTrcZ5RchC/9vxLVj57pWiCBKv4=
|
||||
k8s.io/apimachinery v0.0.0-20190223001710-c182ff3b9841/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0=
|
||||
k8s.io/client-go v8.0.0+incompatible h1:tTI4hRmb1DRMl4fG6Vclfdi6nTM82oIrTT7HfitmxC4=
|
||||
k8s.io/client-go v8.0.0+incompatible/go.mod h1:7vJpHMYJwNQCWgzmNV+VYUl1zCObLyodBc8nIyt8L5s=
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC
|
||||
2898 / PKCS #5 v2.0.
|
||||
|
||||
A key derivation function is useful when encrypting data based on a password
|
||||
or any other not-fully-random data. It uses a pseudorandom function to derive
|
||||
a secure encryption key based on the password.
|
||||
|
||||
While v2.0 of the standard defines only one pseudorandom function to use,
|
||||
HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved
|
||||
Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To
|
||||
choose, you can pass the `New` functions from the different SHA packages to
|
||||
pbkdf2.Key.
|
||||
*/
|
||||
package pbkdf2 // import "golang.org/x/crypto/pbkdf2"
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"hash"
|
||||
)
|
||||
|
||||
// Key derives a key from the password, salt and iteration count, returning a
|
||||
// []byte of length keylen that can be used as cryptographic key. The key is
|
||||
// derived based on the method described as PBKDF2 with the HMAC variant using
|
||||
// the supplied hash function.
|
||||
//
|
||||
// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you
|
||||
// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by
|
||||
// doing:
|
||||
//
|
||||
// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New)
|
||||
//
|
||||
// Remember to get a good random salt. At least 8 bytes is recommended by the
|
||||
// RFC.
|
||||
//
|
||||
// Using a higher iteration count will increase the cost of an exhaustive
|
||||
// search but will also make derivation proportionally slower.
|
||||
func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte {
|
||||
prf := hmac.New(h, password)
|
||||
hashLen := prf.Size()
|
||||
numBlocks := (keyLen + hashLen - 1) / hashLen
|
||||
|
||||
var buf [4]byte
|
||||
dk := make([]byte, 0, numBlocks*hashLen)
|
||||
U := make([]byte, hashLen)
|
||||
for block := 1; block <= numBlocks; block++ {
|
||||
// N.B.: || means concatenation, ^ means XOR
|
||||
// for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter
|
||||
// U_1 = PRF(password, salt || uint(i))
|
||||
prf.Reset()
|
||||
prf.Write(salt)
|
||||
buf[0] = byte(block >> 24)
|
||||
buf[1] = byte(block >> 16)
|
||||
buf[2] = byte(block >> 8)
|
||||
buf[3] = byte(block)
|
||||
prf.Write(buf[:4])
|
||||
dk = prf.Sum(dk)
|
||||
T := dk[len(dk)-hashLen:]
|
||||
copy(U, T)
|
||||
|
||||
// U_n = PRF(password, U_(n-1))
|
||||
for n := 2; n <= iter; n++ {
|
||||
prf.Reset()
|
||||
prf.Write(U)
|
||||
U = U[:0]
|
||||
U = prf.Sum(U)
|
||||
for x := range U {
|
||||
T[x] ^= U[x]
|
||||
}
|
||||
}
|
||||
}
|
||||
return dk[:keyLen]
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
'|Ę&{tÄU|gGę(ěŹCy=+¨śňcű:u:/pś#~žü["±4¤!nŮAŞDK<Šuf˙hĹażÂ:şü¸ˇ´B/ŁŘ¤ą¤ň_<C588>hÎŰSăT*wĚxĽŻťą-ç|ťŕŔÓ<C594>ŃÄäóĚ㣗A$$â6ŁÁâG)8nĎpűĆˡ3ĚšśoďĎvŽB–3ż]xÝ“Ó2l§G•|qRŢŻ
ö2
5R–Ó×Ç$´ń˝YčˇŢÝ™l‘Ë«yAI"ŰŚ<C5B0>®íĂ»ąĽkÄ|Kĺţ[9ĆâŇĺ=°ú˙źń|@S•3ó#ćťx?ľV„,ľ‚SĆÝőśwPíogŇ6&V6 ©D.dBŠ7
|
|
@ -0,0 +1,7 @@
|
|||
*~
|
||||
.*.swp
|
||||
*.out
|
||||
*.test
|
||||
*.pem
|
||||
*.cov
|
||||
jose-util/jose-util
|
|
@ -0,0 +1,46 @@
|
|||
language: go
|
||||
|
||||
sudo: false
|
||||
|
||||
matrix:
|
||||
fast_finish: true
|
||||
allow_failures:
|
||||
- go: tip
|
||||
|
||||
go:
|
||||
- '1.7.x'
|
||||
- '1.8.x'
|
||||
- '1.9.x'
|
||||
- '1.10.x'
|
||||
- '1.11.x'
|
||||
|
||||
go_import_path: gopkg.in/square/go-jose.v2
|
||||
|
||||
before_script:
|
||||
- export PATH=$HOME/.local/bin:$PATH
|
||||
|
||||
before_install:
|
||||
# Install encrypted gitcookies to get around bandwidth-limits
|
||||
# that is causing Travis-CI builds to fail. For more info, see
|
||||
# https://github.com/golang/go/issues/12933
|
||||
- openssl aes-256-cbc -K $encrypted_1528c3c2cafd_key -iv $encrypted_1528c3c2cafd_iv -in .gitcookies.sh.enc -out .gitcookies.sh -d || true
|
||||
- bash .gitcookies.sh || true
|
||||
- go get github.com/wadey/gocovmerge
|
||||
- go get github.com/mattn/goveralls
|
||||
- go get github.com/stretchr/testify/assert
|
||||
- go get golang.org/x/tools/cmd/cover || true
|
||||
- go get code.google.com/p/go.tools/cmd/cover || true
|
||||
- pip install cram --user
|
||||
|
||||
script:
|
||||
- go test . -v -covermode=count -coverprofile=profile.cov
|
||||
- go test ./cipher -v -covermode=count -coverprofile=cipher/profile.cov
|
||||
- go test ./jwt -v -covermode=count -coverprofile=jwt/profile.cov
|
||||
- go test ./json -v # no coverage for forked encoding/json package
|
||||
- cd jose-util && go build && PATH=$PWD:$PATH cram -v jose-util.t
|
||||
- cd ..
|
||||
|
||||
after_success:
|
||||
- gocovmerge *.cov */*.cov > merged.coverprofile
|
||||
- $HOME/gopath/bin/goveralls -coverprofile merged.coverprofile -service=travis-ci
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
Serious about security
|
||||
======================
|
||||
|
||||
Square recognizes the important contributions the security research community
|
||||
can make. We therefore encourage reporting security issues with the code
|
||||
contained in this repository.
|
||||
|
||||
If you believe you have discovered a security vulnerability, please follow the
|
||||
guidelines at <https://bugcrowd.com/squareopensource>.
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# Contributing
|
||||
|
||||
If you would like to contribute code to go-jose you can do so through GitHub by
|
||||
forking the repository and sending a pull request.
|
||||
|
||||
When submitting code, please make every effort to follow existing conventions
|
||||
and style in order to keep the code as readable as possible. Please also make
|
||||
sure all tests pass by running `go test`, and format your code with `go fmt`.
|
||||
We also recommend using `golint` and `errcheck`.
|
||||
|
||||
Before your code can be accepted into the project you must also sign the
|
||||
[Individual Contributor License Agreement][1].
|
||||
|
||||
[1]: https://spreadsheets.google.com/spreadsheet/viewform?formkey=dDViT2xzUHAwRkI3X3k5Z0lQM091OGc6MQ&ndplr=1
|
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,118 @@
|
|||
# Go JOSE
|
||||
|
||||
[![godoc](http://img.shields.io/badge/godoc-version_1-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v1)
|
||||
[![godoc](http://img.shields.io/badge/godoc-version_2-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v2)
|
||||
[![license](http://img.shields.io/badge/license-apache_2.0-blue.svg?style=flat)](https://raw.githubusercontent.com/square/go-jose/master/LICENSE)
|
||||
[![build](https://travis-ci.org/square/go-jose.svg?branch=v2)](https://travis-ci.org/square/go-jose)
|
||||
[![coverage](https://coveralls.io/repos/github/square/go-jose/badge.svg?branch=v2)](https://coveralls.io/r/square/go-jose)
|
||||
|
||||
Package jose aims to provide an implementation of the Javascript Object Signing
|
||||
and Encryption set of standards. This includes support for JSON Web Encryption,
|
||||
JSON Web Signature, and JSON Web Token standards.
|
||||
|
||||
**Disclaimer**: This library contains encryption software that is subject to
|
||||
the U.S. Export Administration Regulations. You may not export, re-export,
|
||||
transfer or download this code or any part of it in violation of any United
|
||||
States law, directive or regulation. In particular this software may not be
|
||||
exported or re-exported in any form or on any media to Iran, North Sudan,
|
||||
Syria, Cuba, or North Korea, or to denied persons or entities mentioned on any
|
||||
US maintained blocked list.
|
||||
|
||||
## Overview
|
||||
|
||||
The implementation follows the
|
||||
[JSON Web Encryption](http://dx.doi.org/10.17487/RFC7516) (RFC 7516),
|
||||
[JSON Web Signature](http://dx.doi.org/10.17487/RFC7515) (RFC 7515), and
|
||||
[JSON Web Token](http://dx.doi.org/10.17487/RFC7519) (RFC 7519).
|
||||
Tables of supported algorithms are shown below. The library supports both
|
||||
the compact and full serialization formats, and has optional support for
|
||||
multiple recipients. It also comes with a small command-line utility
|
||||
([`jose-util`](https://github.com/square/go-jose/tree/v2/jose-util))
|
||||
for dealing with JOSE messages in a shell.
|
||||
|
||||
**Note**: We use a forked version of the `encoding/json` package from the Go
|
||||
standard library which uses case-sensitive matching for member names (instead
|
||||
of [case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html)).
|
||||
This is to avoid differences in interpretation of messages between go-jose and
|
||||
libraries in other languages.
|
||||
|
||||
### Versions
|
||||
|
||||
We use [gopkg.in](https://gopkg.in) for versioning.
|
||||
|
||||
[Version 2](https://gopkg.in/square/go-jose.v2)
|
||||
([branch](https://github.com/square/go-jose/tree/v2),
|
||||
[doc](https://godoc.org/gopkg.in/square/go-jose.v2)) is the current version:
|
||||
|
||||
import "gopkg.in/square/go-jose.v2"
|
||||
|
||||
The old `v1` branch ([go-jose.v1](https://gopkg.in/square/go-jose.v1)) will
|
||||
still receive backported bug fixes and security fixes, but otherwise
|
||||
development is frozen. All new feature development takes place on the `v2`
|
||||
branch. Version 2 also contains additional sub-packages such as the
|
||||
[jwt](https://godoc.org/gopkg.in/square/go-jose.v2/jwt) implementation
|
||||
contributed by [@shaxbee](https://github.com/shaxbee).
|
||||
|
||||
### Supported algorithms
|
||||
|
||||
See below for a table of supported algorithms. Algorithm identifiers match
|
||||
the names in the [JSON Web Algorithms](http://dx.doi.org/10.17487/RFC7518)
|
||||
standard where possible. The Godoc reference has a list of constants.
|
||||
|
||||
Key encryption | Algorithm identifier(s)
|
||||
:------------------------- | :------------------------------
|
||||
RSA-PKCS#1v1.5 | RSA1_5
|
||||
RSA-OAEP | RSA-OAEP, RSA-OAEP-256
|
||||
AES key wrap | A128KW, A192KW, A256KW
|
||||
AES-GCM key wrap | A128GCMKW, A192GCMKW, A256GCMKW
|
||||
ECDH-ES + AES key wrap | ECDH-ES+A128KW, ECDH-ES+A192KW, ECDH-ES+A256KW
|
||||
ECDH-ES (direct) | ECDH-ES<sup>1</sup>
|
||||
Direct encryption | dir<sup>1</sup>
|
||||
|
||||
<sup>1. Not supported in multi-recipient mode</sup>
|
||||
|
||||
Signing / MAC | Algorithm identifier(s)
|
||||
:------------------------- | :------------------------------
|
||||
RSASSA-PKCS#1v1.5 | RS256, RS384, RS512
|
||||
RSASSA-PSS | PS256, PS384, PS512
|
||||
HMAC | HS256, HS384, HS512
|
||||
ECDSA | ES256, ES384, ES512
|
||||
Ed25519 | EdDSA<sup>2</sup>
|
||||
|
||||
<sup>2. Only available in version 2 of the package</sup>
|
||||
|
||||
Content encryption | Algorithm identifier(s)
|
||||
:------------------------- | :------------------------------
|
||||
AES-CBC+HMAC | A128CBC-HS256, A192CBC-HS384, A256CBC-HS512
|
||||
AES-GCM | A128GCM, A192GCM, A256GCM
|
||||
|
||||
Compression | Algorithm identifiers(s)
|
||||
:------------------------- | -------------------------------
|
||||
DEFLATE (RFC 1951) | DEF
|
||||
|
||||
### Supported key types
|
||||
|
||||
See below for a table of supported key types. These are understood by the
|
||||
library, and can be passed to corresponding functions such as `NewEncrypter` or
|
||||
`NewSigner`. Each of these keys can also be wrapped in a JWK if desired, which
|
||||
allows attaching a key id.
|
||||
|
||||
Algorithm(s) | Corresponding types
|
||||
:------------------------- | -------------------------------
|
||||
RSA | *[rsa.PublicKey](http://golang.org/pkg/crypto/rsa/#PublicKey), *[rsa.PrivateKey](http://golang.org/pkg/crypto/rsa/#PrivateKey)
|
||||
ECDH, ECDSA | *[ecdsa.PublicKey](http://golang.org/pkg/crypto/ecdsa/#PublicKey), *[ecdsa.PrivateKey](http://golang.org/pkg/crypto/ecdsa/#PrivateKey)
|
||||
EdDSA<sup>1</sup> | [ed25519.PublicKey](https://godoc.org/golang.org/x/crypto/ed25519#PublicKey), [ed25519.PrivateKey](https://godoc.org/golang.org/x/crypto/ed25519#PrivateKey)
|
||||
AES, HMAC | []byte
|
||||
|
||||
<sup>1. Only available in version 2 of the package</sup>
|
||||
|
||||
## Examples
|
||||
|
||||
[![godoc](http://img.shields.io/badge/godoc-version_1-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v1)
|
||||
[![godoc](http://img.shields.io/badge/godoc-version_2-blue.svg?style=flat)](https://godoc.org/gopkg.in/square/go-jose.v2)
|
||||
|
||||
Examples can be found in the Godoc
|
||||
reference for this package. The
|
||||
[`jose-util`](https://github.com/square/go-jose/tree/v2/jose-util)
|
||||
subdirectory also contains a small command-line utility which might be useful
|
||||
as an example.
|
|
@ -0,0 +1,592 @@
|
|||
/*-
|
||||
* Copyright 2014 Square Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package jose
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
"gopkg.in/square/go-jose.v2/cipher"
|
||||
"gopkg.in/square/go-jose.v2/json"
|
||||
)
|
||||
|
||||
// A generic RSA-based encrypter/verifier
|
||||
type rsaEncrypterVerifier struct {
|
||||
publicKey *rsa.PublicKey
|
||||
}
|
||||
|
||||
// A generic RSA-based decrypter/signer
|
||||
type rsaDecrypterSigner struct {
|
||||
privateKey *rsa.PrivateKey
|
||||
}
|
||||
|
||||
// A generic EC-based encrypter/verifier
|
||||
type ecEncrypterVerifier struct {
|
||||
publicKey *ecdsa.PublicKey
|
||||
}
|
||||
|
||||
type edEncrypterVerifier struct {
|
||||
publicKey ed25519.PublicKey
|
||||
}
|
||||
|
||||
// A key generator for ECDH-ES
|
||||
type ecKeyGenerator struct {
|
||||
size int
|
||||
algID string
|
||||
publicKey *ecdsa.PublicKey
|
||||
}
|
||||
|
||||
// A generic EC-based decrypter/signer
|
||||
type ecDecrypterSigner struct {
|
||||
privateKey *ecdsa.PrivateKey
|
||||
}
|
||||
|
||||
type edDecrypterSigner struct {
|
||||
privateKey ed25519.PrivateKey
|
||||
}
|
||||
|
||||
// newRSARecipient creates recipientKeyInfo based on the given key.
|
||||
func newRSARecipient(keyAlg KeyAlgorithm, publicKey *rsa.PublicKey) (recipientKeyInfo, error) {
|
||||
// Verify that key management algorithm is supported by this encrypter
|
||||
switch keyAlg {
|
||||
case RSA1_5, RSA_OAEP, RSA_OAEP_256:
|
||||
default:
|
||||
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
if publicKey == nil {
|
||||
return recipientKeyInfo{}, errors.New("invalid public key")
|
||||
}
|
||||
|
||||
return recipientKeyInfo{
|
||||
keyAlg: keyAlg,
|
||||
keyEncrypter: &rsaEncrypterVerifier{
|
||||
publicKey: publicKey,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newRSASigner creates a recipientSigInfo based on the given key.
|
||||
func newRSASigner(sigAlg SignatureAlgorithm, privateKey *rsa.PrivateKey) (recipientSigInfo, error) {
|
||||
// Verify that key management algorithm is supported by this encrypter
|
||||
switch sigAlg {
|
||||
case RS256, RS384, RS512, PS256, PS384, PS512:
|
||||
default:
|
||||
return recipientSigInfo{}, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
if privateKey == nil {
|
||||
return recipientSigInfo{}, errors.New("invalid private key")
|
||||
}
|
||||
|
||||
return recipientSigInfo{
|
||||
sigAlg: sigAlg,
|
||||
publicKey: staticPublicKey(&JSONWebKey{
|
||||
Key: privateKey.Public(),
|
||||
}),
|
||||
signer: &rsaDecrypterSigner{
|
||||
privateKey: privateKey,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newEd25519Signer(sigAlg SignatureAlgorithm, privateKey ed25519.PrivateKey) (recipientSigInfo, error) {
|
||||
if sigAlg != EdDSA {
|
||||
return recipientSigInfo{}, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
if privateKey == nil {
|
||||
return recipientSigInfo{}, errors.New("invalid private key")
|
||||
}
|
||||
return recipientSigInfo{
|
||||
sigAlg: sigAlg,
|
||||
publicKey: staticPublicKey(&JSONWebKey{
|
||||
Key: privateKey.Public(),
|
||||
}),
|
||||
signer: &edDecrypterSigner{
|
||||
privateKey: privateKey,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newECDHRecipient creates recipientKeyInfo based on the given key.
|
||||
func newECDHRecipient(keyAlg KeyAlgorithm, publicKey *ecdsa.PublicKey) (recipientKeyInfo, error) {
|
||||
// Verify that key management algorithm is supported by this encrypter
|
||||
switch keyAlg {
|
||||
case ECDH_ES, ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW:
|
||||
default:
|
||||
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
if publicKey == nil || !publicKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) {
|
||||
return recipientKeyInfo{}, errors.New("invalid public key")
|
||||
}
|
||||
|
||||
return recipientKeyInfo{
|
||||
keyAlg: keyAlg,
|
||||
keyEncrypter: &ecEncrypterVerifier{
|
||||
publicKey: publicKey,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newECDSASigner creates a recipientSigInfo based on the given key.
|
||||
func newECDSASigner(sigAlg SignatureAlgorithm, privateKey *ecdsa.PrivateKey) (recipientSigInfo, error) {
|
||||
// Verify that key management algorithm is supported by this encrypter
|
||||
switch sigAlg {
|
||||
case ES256, ES384, ES512:
|
||||
default:
|
||||
return recipientSigInfo{}, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
if privateKey == nil {
|
||||
return recipientSigInfo{}, errors.New("invalid private key")
|
||||
}
|
||||
|
||||
return recipientSigInfo{
|
||||
sigAlg: sigAlg,
|
||||
publicKey: staticPublicKey(&JSONWebKey{
|
||||
Key: privateKey.Public(),
|
||||
}),
|
||||
signer: &ecDecrypterSigner{
|
||||
privateKey: privateKey,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encrypt the given payload and update the object.
|
||||
func (ctx rsaEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
|
||||
encryptedKey, err := ctx.encrypt(cek, alg)
|
||||
if err != nil {
|
||||
return recipientInfo{}, err
|
||||
}
|
||||
|
||||
return recipientInfo{
|
||||
encryptedKey: encryptedKey,
|
||||
header: &rawHeader{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encrypt the given payload. Based on the key encryption algorithm,
|
||||
// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256).
|
||||
func (ctx rsaEncrypterVerifier) encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) {
|
||||
switch alg {
|
||||
case RSA1_5:
|
||||
return rsa.EncryptPKCS1v15(RandReader, ctx.publicKey, cek)
|
||||
case RSA_OAEP:
|
||||
return rsa.EncryptOAEP(sha1.New(), RandReader, ctx.publicKey, cek, []byte{})
|
||||
case RSA_OAEP_256:
|
||||
return rsa.EncryptOAEP(sha256.New(), RandReader, ctx.publicKey, cek, []byte{})
|
||||
}
|
||||
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
// Decrypt the given payload and return the content encryption key.
|
||||
func (ctx rsaDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
|
||||
return ctx.decrypt(recipient.encryptedKey, headers.getAlgorithm(), generator)
|
||||
}
|
||||
|
||||
// Decrypt the given payload. Based on the key encryption algorithm,
|
||||
// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256).
|
||||
func (ctx rsaDecrypterSigner) decrypt(jek []byte, alg KeyAlgorithm, generator keyGenerator) ([]byte, error) {
|
||||
// Note: The random reader on decrypt operations is only used for blinding,
|
||||
// so stubbing is meanlingless (hence the direct use of rand.Reader).
|
||||
switch alg {
|
||||
case RSA1_5:
|
||||
defer func() {
|
||||
// DecryptPKCS1v15SessionKey sometimes panics on an invalid payload
|
||||
// because of an index out of bounds error, which we want to ignore.
|
||||
// This has been fixed in Go 1.3.1 (released 2014/08/13), the recover()
|
||||
// only exists for preventing crashes with unpatched versions.
|
||||
// See: https://groups.google.com/forum/#!topic/golang-dev/7ihX6Y6kx9k
|
||||
// See: https://code.google.com/p/go/source/detail?r=58ee390ff31602edb66af41ed10901ec95904d33
|
||||
_ = recover()
|
||||
}()
|
||||
|
||||
// Perform some input validation.
|
||||
keyBytes := ctx.privateKey.PublicKey.N.BitLen() / 8
|
||||
if keyBytes != len(jek) {
|
||||
// Input size is incorrect, the encrypted payload should always match
|
||||
// the size of the public modulus (e.g. using a 2048 bit key will
|
||||
// produce 256 bytes of output). Reject this since it's invalid input.
|
||||
return nil, ErrCryptoFailure
|
||||
}
|
||||
|
||||
cek, _, err := generator.genKey()
|
||||
if err != nil {
|
||||
return nil, ErrCryptoFailure
|
||||
}
|
||||
|
||||
// When decrypting an RSA-PKCS1v1.5 payload, we must take precautions to
|
||||
// prevent chosen-ciphertext attacks as described in RFC 3218, "Preventing
|
||||
// the Million Message Attack on Cryptographic Message Syntax". We are
|
||||
// therefore deliberately ignoring errors here.
|
||||
_ = rsa.DecryptPKCS1v15SessionKey(rand.Reader, ctx.privateKey, jek, cek)
|
||||
|
||||
return cek, nil
|
||||
case RSA_OAEP:
|
||||
// Use rand.Reader for RSA blinding
|
||||
return rsa.DecryptOAEP(sha1.New(), rand.Reader, ctx.privateKey, jek, []byte{})
|
||||
case RSA_OAEP_256:
|
||||
// Use rand.Reader for RSA blinding
|
||||
return rsa.DecryptOAEP(sha256.New(), rand.Reader, ctx.privateKey, jek, []byte{})
|
||||
}
|
||||
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
// Sign the given payload
|
||||
func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
|
||||
var hash crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case RS256, PS256:
|
||||
hash = crypto.SHA256
|
||||
case RS384, PS384:
|
||||
hash = crypto.SHA384
|
||||
case RS512, PS512:
|
||||
hash = crypto.SHA512
|
||||
default:
|
||||
return Signature{}, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
hasher := hash.New()
|
||||
|
||||
// According to documentation, Write() on hash never fails
|
||||
_, _ = hasher.Write(payload)
|
||||
hashed := hasher.Sum(nil)
|
||||
|
||||
var out []byte
|
||||
var err error
|
||||
|
||||
switch alg {
|
||||
case RS256, RS384, RS512:
|
||||
out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed)
|
||||
case PS256, PS384, PS512:
|
||||
out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return Signature{}, err
|
||||
}
|
||||
|
||||
return Signature{
|
||||
Signature: out,
|
||||
protected: &rawHeader{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify the given payload
|
||||
func (ctx rsaEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
|
||||
var hash crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case RS256, PS256:
|
||||
hash = crypto.SHA256
|
||||
case RS384, PS384:
|
||||
hash = crypto.SHA384
|
||||
case RS512, PS512:
|
||||
hash = crypto.SHA512
|
||||
default:
|
||||
return ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
hasher := hash.New()
|
||||
|
||||
// According to documentation, Write() on hash never fails
|
||||
_, _ = hasher.Write(payload)
|
||||
hashed := hasher.Sum(nil)
|
||||
|
||||
switch alg {
|
||||
case RS256, RS384, RS512:
|
||||
return rsa.VerifyPKCS1v15(ctx.publicKey, hash, hashed, signature)
|
||||
case PS256, PS384, PS512:
|
||||
return rsa.VerifyPSS(ctx.publicKey, hash, hashed, signature, nil)
|
||||
}
|
||||
|
||||
return ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
// Encrypt the given payload and update the object.
|
||||
func (ctx ecEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
|
||||
switch alg {
|
||||
case ECDH_ES:
|
||||
// ECDH-ES mode doesn't wrap a key, the shared secret is used directly as the key.
|
||||
return recipientInfo{
|
||||
header: &rawHeader{},
|
||||
}, nil
|
||||
case ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW:
|
||||
default:
|
||||
return recipientInfo{}, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
generator := ecKeyGenerator{
|
||||
algID: string(alg),
|
||||
publicKey: ctx.publicKey,
|
||||
}
|
||||
|
||||
switch alg {
|
||||
case ECDH_ES_A128KW:
|
||||
generator.size = 16
|
||||
case ECDH_ES_A192KW:
|
||||
generator.size = 24
|
||||
case ECDH_ES_A256KW:
|
||||
generator.size = 32
|
||||
}
|
||||
|
||||
kek, header, err := generator.genKey()
|
||||
if err != nil {
|
||||
return recipientInfo{}, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(kek)
|
||||
if err != nil {
|
||||
return recipientInfo{}, err
|
||||
}
|
||||
|
||||
jek, err := josecipher.KeyWrap(block, cek)
|
||||
if err != nil {
|
||||
return recipientInfo{}, err
|
||||
}
|
||||
|
||||
return recipientInfo{
|
||||
encryptedKey: jek,
|
||||
header: &header,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get key size for EC key generator
|
||||
func (ctx ecKeyGenerator) keySize() int {
|
||||
return ctx.size
|
||||
}
|
||||
|
||||
// Get a content encryption key for ECDH-ES
|
||||
func (ctx ecKeyGenerator) genKey() ([]byte, rawHeader, error) {
|
||||
priv, err := ecdsa.GenerateKey(ctx.publicKey.Curve, RandReader)
|
||||
if err != nil {
|
||||
return nil, rawHeader{}, err
|
||||
}
|
||||
|
||||
out := josecipher.DeriveECDHES(ctx.algID, []byte{}, []byte{}, priv, ctx.publicKey, ctx.size)
|
||||
|
||||
b, err := json.Marshal(&JSONWebKey{
|
||||
Key: &priv.PublicKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
headers := rawHeader{
|
||||
headerEPK: makeRawMessage(b),
|
||||
}
|
||||
|
||||
return out, headers, nil
|
||||
}
|
||||
|
||||
// Decrypt the given payload and return the content encryption key.
|
||||
func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
|
||||
epk, err := headers.getEPK()
|
||||
if err != nil {
|
||||
return nil, errors.New("square/go-jose: invalid epk header")
|
||||
}
|
||||
if epk == nil {
|
||||
return nil, errors.New("square/go-jose: missing epk header")
|
||||
}
|
||||
|
||||
publicKey, ok := epk.Key.(*ecdsa.PublicKey)
|
||||
if publicKey == nil || !ok {
|
||||
return nil, errors.New("square/go-jose: invalid epk header")
|
||||
}
|
||||
|
||||
if !ctx.privateKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) {
|
||||
return nil, errors.New("square/go-jose: invalid public key in epk header")
|
||||
}
|
||||
|
||||
apuData, err := headers.getAPU()
|
||||
if err != nil {
|
||||
return nil, errors.New("square/go-jose: invalid apu header")
|
||||
}
|
||||
apvData, err := headers.getAPV()
|
||||
if err != nil {
|
||||
return nil, errors.New("square/go-jose: invalid apv header")
|
||||
}
|
||||
|
||||
deriveKey := func(algID string, size int) []byte {
|
||||
return josecipher.DeriveECDHES(algID, apuData.bytes(), apvData.bytes(), ctx.privateKey, publicKey, size)
|
||||
}
|
||||
|
||||
var keySize int
|
||||
|
||||
algorithm := headers.getAlgorithm()
|
||||
switch algorithm {
|
||||
case ECDH_ES:
|
||||
// ECDH-ES uses direct key agreement, no key unwrapping necessary.
|
||||
return deriveKey(string(headers.getEncryption()), generator.keySize()), nil
|
||||
case ECDH_ES_A128KW:
|
||||
keySize = 16
|
||||
case ECDH_ES_A192KW:
|
||||
keySize = 24
|
||||
case ECDH_ES_A256KW:
|
||||
keySize = 32
|
||||
default:
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
key := deriveKey(string(algorithm), keySize)
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return josecipher.KeyUnwrap(block, recipient.encryptedKey)
|
||||
}
|
||||
|
||||
func (ctx edDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
|
||||
if alg != EdDSA {
|
||||
return Signature{}, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
sig, err := ctx.privateKey.Sign(RandReader, payload, crypto.Hash(0))
|
||||
if err != nil {
|
||||
return Signature{}, err
|
||||
}
|
||||
|
||||
return Signature{
|
||||
Signature: sig,
|
||||
protected: &rawHeader{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ctx edEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
|
||||
if alg != EdDSA {
|
||||
return ErrUnsupportedAlgorithm
|
||||
}
|
||||
ok := ed25519.Verify(ctx.publicKey, payload, signature)
|
||||
if !ok {
|
||||
return errors.New("square/go-jose: ed25519 signature failed to verify")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign the given payload
|
||||
func (ctx ecDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
|
||||
var expectedBitSize int
|
||||
var hash crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case ES256:
|
||||
expectedBitSize = 256
|
||||
hash = crypto.SHA256
|
||||
case ES384:
|
||||
expectedBitSize = 384
|
||||
hash = crypto.SHA384
|
||||
case ES512:
|
||||
expectedBitSize = 521
|
||||
hash = crypto.SHA512
|
||||
}
|
||||
|
||||
curveBits := ctx.privateKey.Curve.Params().BitSize
|
||||
if expectedBitSize != curveBits {
|
||||
return Signature{}, fmt.Errorf("square/go-jose: expected %d bit key, got %d bits instead", expectedBitSize, curveBits)
|
||||
}
|
||||
|
||||
hasher := hash.New()
|
||||
|
||||
// According to documentation, Write() on hash never fails
|
||||
_, _ = hasher.Write(payload)
|
||||
hashed := hasher.Sum(nil)
|
||||
|
||||
r, s, err := ecdsa.Sign(RandReader, ctx.privateKey, hashed)
|
||||
if err != nil {
|
||||
return Signature{}, err
|
||||
}
|
||||
|
||||
keyBytes := curveBits / 8
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes++
|
||||
}
|
||||
|
||||
// We serialize the outputs (r and s) into big-endian byte arrays and pad
|
||||
// them with zeros on the left to make sure the sizes work out. Both arrays
|
||||
// must be keyBytes long, and the output must be 2*keyBytes long.
|
||||
rBytes := r.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := s.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
|
||||
out := append(rBytesPadded, sBytesPadded...)
|
||||
|
||||
return Signature{
|
||||
Signature: out,
|
||||
protected: &rawHeader{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify the given payload
|
||||
func (ctx ecEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
|
||||
var keySize int
|
||||
var hash crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case ES256:
|
||||
keySize = 32
|
||||
hash = crypto.SHA256
|
||||
case ES384:
|
||||
keySize = 48
|
||||
hash = crypto.SHA384
|
||||
case ES512:
|
||||
keySize = 66
|
||||
hash = crypto.SHA512
|
||||
default:
|
||||
return ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
if len(signature) != 2*keySize {
|
||||
return fmt.Errorf("square/go-jose: invalid signature size, have %d bytes, wanted %d", len(signature), 2*keySize)
|
||||
}
|
||||
|
||||
hasher := hash.New()
|
||||
|
||||
// According to documentation, Write() on hash never fails
|
||||
_, _ = hasher.Write(payload)
|
||||
hashed := hasher.Sum(nil)
|
||||
|
||||
r := big.NewInt(0).SetBytes(signature[:keySize])
|
||||
s := big.NewInt(0).SetBytes(signature[keySize:])
|
||||
|
||||
match := ecdsa.Verify(ctx.publicKey, hashed, r, s)
|
||||
if !match {
|
||||
return errors.New("square/go-jose: ecdsa signature failed to verify")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,196 @@
|
|||
/*-
|
||||
* Copyright 2014 Square Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package josecipher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash"
|
||||
)
|
||||
|
||||
const (
|
||||
nonceBytes = 16
|
||||
)
|
||||
|
||||
// NewCBCHMAC instantiates a new AEAD based on CBC+HMAC.
|
||||
func NewCBCHMAC(key []byte, newBlockCipher func([]byte) (cipher.Block, error)) (cipher.AEAD, error) {
|
||||
keySize := len(key) / 2
|
||||
integrityKey := key[:keySize]
|
||||
encryptionKey := key[keySize:]
|
||||
|
||||
blockCipher, err := newBlockCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var hash func() hash.Hash
|
||||
switch keySize {
|
||||
case 16:
|
||||
hash = sha256.New
|
||||
case 24:
|
||||
hash = sha512.New384
|
||||
case 32:
|
||||
hash = sha512.New
|
||||
}
|
||||
|
||||
return &cbcAEAD{
|
||||
hash: hash,
|
||||
blockCipher: blockCipher,
|
||||
authtagBytes: keySize,
|
||||
integrityKey: integrityKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// An AEAD based on CBC+HMAC
|
||||
type cbcAEAD struct {
|
||||
hash func() hash.Hash
|
||||
authtagBytes int
|
||||
integrityKey []byte
|
||||
blockCipher cipher.Block
|
||||
}
|
||||
|
||||
func (ctx *cbcAEAD) NonceSize() int {
|
||||
return nonceBytes
|
||||
}
|
||||
|
||||
func (ctx *cbcAEAD) Overhead() int {
|
||||
// Maximum overhead is block size (for padding) plus auth tag length, where
|
||||
// the length of the auth tag is equivalent to the key size.
|
||||
return ctx.blockCipher.BlockSize() + ctx.authtagBytes
|
||||
}
|
||||
|
||||
// Seal encrypts and authenticates the plaintext.
|
||||
func (ctx *cbcAEAD) Seal(dst, nonce, plaintext, data []byte) []byte {
|
||||
// Output buffer -- must take care not to mangle plaintext input.
|
||||
ciphertext := make([]byte, uint64(len(plaintext))+uint64(ctx.Overhead()))[:len(plaintext)]
|
||||
copy(ciphertext, plaintext)
|
||||
ciphertext = padBuffer(ciphertext, ctx.blockCipher.BlockSize())
|
||||
|
||||
cbc := cipher.NewCBCEncrypter(ctx.blockCipher, nonce)
|
||||
|
||||
cbc.CryptBlocks(ciphertext, ciphertext)
|
||||
authtag := ctx.computeAuthTag(data, nonce, ciphertext)
|
||||
|
||||
ret, out := resize(dst, uint64(len(dst))+uint64(len(ciphertext))+uint64(len(authtag)))
|
||||
copy(out, ciphertext)
|
||||
copy(out[len(ciphertext):], authtag)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Open decrypts and authenticates the ciphertext.
|
||||
func (ctx *cbcAEAD) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
|
||||
if len(ciphertext) < ctx.authtagBytes {
|
||||
return nil, errors.New("square/go-jose: invalid ciphertext (too short)")
|
||||
}
|
||||
|
||||
offset := len(ciphertext) - ctx.authtagBytes
|
||||
expectedTag := ctx.computeAuthTag(data, nonce, ciphertext[:offset])
|
||||
match := subtle.ConstantTimeCompare(expectedTag, ciphertext[offset:])
|
||||
if match != 1 {
|
||||
return nil, errors.New("square/go-jose: invalid ciphertext (auth tag mismatch)")
|
||||
}
|
||||
|
||||
cbc := cipher.NewCBCDecrypter(ctx.blockCipher, nonce)
|
||||
|
||||
// Make copy of ciphertext buffer, don't want to modify in place
|
||||
buffer := append([]byte{}, []byte(ciphertext[:offset])...)
|
||||
|
||||
if len(buffer)%ctx.blockCipher.BlockSize() > 0 {
|
||||
return nil, errors.New("square/go-jose: invalid ciphertext (invalid length)")
|
||||
}
|
||||
|
||||
cbc.CryptBlocks(buffer, buffer)
|
||||
|
||||
// Remove padding
|
||||
plaintext, err := unpadBuffer(buffer, ctx.blockCipher.BlockSize())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, out := resize(dst, uint64(len(dst))+uint64(len(plaintext)))
|
||||
copy(out, plaintext)
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Compute an authentication tag
|
||||
func (ctx *cbcAEAD) computeAuthTag(aad, nonce, ciphertext []byte) []byte {
|
||||
buffer := make([]byte, uint64(len(aad))+uint64(len(nonce))+uint64(len(ciphertext))+8)
|
||||
n := 0
|
||||
n += copy(buffer, aad)
|
||||
n += copy(buffer[n:], nonce)
|
||||
n += copy(buffer[n:], ciphertext)
|
||||
binary.BigEndian.PutUint64(buffer[n:], uint64(len(aad))*8)
|
||||
|
||||
// According to documentation, Write() on hash.Hash never fails.
|
||||
hmac := hmac.New(ctx.hash, ctx.integrityKey)
|
||||
_, _ = hmac.Write(buffer)
|
||||
|
||||
return hmac.Sum(nil)[:ctx.authtagBytes]
|
||||
}
|
||||
|
||||
// resize ensures the the given slice has a capacity of at least n bytes.
|
||||
// If the capacity of the slice is less than n, a new slice is allocated
|
||||
// and the existing data will be copied.
|
||||
func resize(in []byte, n uint64) (head, tail []byte) {
|
||||
if uint64(cap(in)) >= n {
|
||||
head = in[:n]
|
||||
} else {
|
||||
head = make([]byte, n)
|
||||
copy(head, in)
|
||||
}
|
||||
|
||||
tail = head[len(in):]
|
||||
return
|
||||
}
|
||||
|
||||
// Apply padding
|
||||
func padBuffer(buffer []byte, blockSize int) []byte {
|
||||
missing := blockSize - (len(buffer) % blockSize)
|
||||
ret, out := resize(buffer, uint64(len(buffer))+uint64(missing))
|
||||
padding := bytes.Repeat([]byte{byte(missing)}, missing)
|
||||
copy(out, padding)
|
||||
return ret
|
||||
}
|
||||
|
||||
// Remove padding
|
||||
func unpadBuffer(buffer []byte, blockSize int) ([]byte, error) {
|
||||
if len(buffer)%blockSize != 0 {
|
||||
return nil, errors.New("square/go-jose: invalid padding")
|
||||
}
|
||||
|
||||
last := buffer[len(buffer)-1]
|
||||
count := int(last)
|
||||
|
||||
if count == 0 || count > blockSize || count > len(buffer) {
|
||||
return nil, errors.New("square/go-jose: invalid padding")
|
||||
}
|
||||
|
||||
padding := bytes.Repeat([]byte{last}, count)
|
||||
if !bytes.HasSuffix(buffer, padding) {
|
||||
return nil, errors.New("square/go-jose: invalid padding")
|
||||
}
|
||||
|
||||
return buffer[:len(buffer)-count], nil
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
/*-
|
||||
* Copyright 2014 Square Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package josecipher
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
"hash"
|
||||
"io"
|
||||
)
|
||||
|
||||
type concatKDF struct {
|
||||
z, info []byte
|
||||
i uint32
|
||||
cache []byte
|
||||
hasher hash.Hash
|
||||
}
|
||||
|
||||
// NewConcatKDF builds a KDF reader based on the given inputs.
|
||||
func NewConcatKDF(hash crypto.Hash, z, algID, ptyUInfo, ptyVInfo, supPubInfo, supPrivInfo []byte) io.Reader {
|
||||
buffer := make([]byte, uint64(len(algID))+uint64(len(ptyUInfo))+uint64(len(ptyVInfo))+uint64(len(supPubInfo))+uint64(len(supPrivInfo)))
|
||||
n := 0
|
||||
n += copy(buffer, algID)
|
||||
n += copy(buffer[n:], ptyUInfo)
|
||||
n += copy(buffer[n:], ptyVInfo)
|
||||
n += copy(buffer[n:], supPubInfo)
|
||||
copy(buffer[n:], supPrivInfo)
|
||||
|
||||
hasher := hash.New()
|
||||
|
||||
return &concatKDF{
|
||||
z: z,
|
||||
info: buffer,
|
||||
hasher: hasher,
|
||||
cache: []byte{},
|
||||
i: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *concatKDF) Read(out []byte) (int, error) {
|
||||
copied := copy(out, ctx.cache)
|
||||
ctx.cache = ctx.cache[copied:]
|
||||
|
||||
for copied < len(out) {
|
||||
ctx.hasher.Reset()
|
||||
|
||||
// Write on a hash.Hash never fails
|
||||
_ = binary.Write(ctx.hasher, binary.BigEndian, ctx.i)
|
||||
_, _ = ctx.hasher.Write(ctx.z)
|
||||
_, _ = ctx.hasher.Write(ctx.info)
|
||||
|
||||
hash := ctx.hasher.Sum(nil)
|
||||
chunkCopied := copy(out[copied:], hash)
|
||||
copied += chunkCopied
|
||||
ctx.cache = hash[chunkCopied:]
|
||||
|
||||
ctx.i++
|
||||
}
|
||||
|
||||
return copied, nil
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/*-
|
||||
* Copyright 2014 Square Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package josecipher
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// DeriveECDHES derives a shared encryption key using ECDH/ConcatKDF as described in JWE/JWA.
|
||||
// It is an error to call this function with a private/public key that are not on the same
|
||||
// curve. Callers must ensure that the keys are valid before calling this function. Output
|
||||
// size may be at most 1<<16 bytes (64 KiB).
|
||||
func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, size int) []byte {
|
||||
if size > 1<<16 {
|
||||
panic("ECDH-ES output size too large, must be less than or equal to 1<<16")
|
||||
}
|
||||
|
||||
// algId, partyUInfo, partyVInfo inputs must be prefixed with the length
|
||||
algID := lengthPrefixed([]byte(alg))
|
||||
ptyUInfo := lengthPrefixed(apuData)
|
||||
ptyVInfo := lengthPrefixed(apvData)
|
||||
|
||||
// suppPubInfo is the encoded length of the output size in bits
|
||||
supPubInfo := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(supPubInfo, uint32(size)*8)
|
||||
|
||||
if !priv.PublicKey.Curve.IsOnCurve(pub.X, pub.Y) {
|
||||
panic("public key not on same curve as private key")
|
||||
}
|
||||
|
||||
z, _ := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
|
||||
reader := NewConcatKDF(crypto.SHA256, z.Bytes(), algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
|
||||
|
||||
key := make([]byte, size)
|
||||
|
||||
// Read on the KDF will never fail
|
||||
_, _ = reader.Read(key)
|
||||
return key
|
||||
}
|
||||
|
||||
func lengthPrefixed(data []byte) []byte {
|
||||
out := make([]byte, len(data)+4)
|
||||
binary.BigEndian.PutUint32(out, uint32(len(data)))
|
||||
copy(out[4:], data)
|
||||
return out
|
||||
}
|
|
@ -0,0 +1,109 @@
|
|||
/*-
|
||||
* Copyright 2014 Square Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package josecipher
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var defaultIV = []byte{0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6}
|
||||
|
||||
// KeyWrap implements NIST key wrapping; it wraps a content encryption key (cek) with the given block cipher.
|
||||
func KeyWrap(block cipher.Block, cek []byte) ([]byte, error) {
|
||||
if len(cek)%8 != 0 {
|
||||
return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks")
|
||||
}
|
||||
|
||||
n := len(cek) / 8
|
||||
r := make([][]byte, n)
|
||||
|
||||
for i := range r {
|
||||
r[i] = make([]byte, 8)
|
||||
copy(r[i], cek[i*8:])
|
||||
}
|
||||
|
||||
buffer := make([]byte, 16)
|
||||
tBytes := make([]byte, 8)
|
||||
copy(buffer, defaultIV)
|
||||
|
||||
for t := 0; t < 6*n; t++ {
|
||||
copy(buffer[8:], r[t%n])
|
||||
|
||||
block.Encrypt(buffer, buffer)
|
||||
|
||||
binary.BigEndian.PutUint64(tBytes, uint64(t+1))
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
buffer[i] = buffer[i] ^ tBytes[i]
|
||||
}
|
||||
copy(r[t%n], buffer[8:])
|
||||
}
|
||||
|
||||
out := make([]byte, (n+1)*8)
|
||||
copy(out, buffer[:8])
|
||||
for i := range r {
|
||||
copy(out[(i+1)*8:], r[i])
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// KeyUnwrap implements NIST key unwrapping; it unwraps a content encryption key (cek) with the given block cipher.
|
||||
func KeyUnwrap(block cipher.Block, ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext)%8 != 0 {
|
||||
return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks")
|
||||
}
|
||||
|
||||
n := (len(ciphertext) / 8) - 1
|
||||
r := make([][]byte, n)
|
||||
|
||||
for i := range r {
|
||||
r[i] = make([]byte, 8)
|
||||
copy(r[i], ciphertext[(i+1)*8:])
|
||||
}
|
||||
|
||||
buffer := make([]byte, 16)
|
||||
tBytes := make([]byte, 8)
|
||||
copy(buffer[:8], ciphertext[:8])
|
||||
|
||||
for t := 6*n - 1; t >= 0; t-- {
|
||||
binary.BigEndian.PutUint64(tBytes, uint64(t+1))
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
buffer[i] = buffer[i] ^ tBytes[i]
|
||||
}
|
||||
copy(buffer[8:], r[t%n])
|
||||
|
||||
block.Decrypt(buffer, buffer)
|
||||
|
||||
copy(r[t%n], buffer[8:])
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare(buffer[:8], defaultIV) == 0 {
|
||||
return nil, errors.New("square/go-jose: failed to unwrap key")
|
||||
}
|
||||
|
||||
out := make([]byte, n*8)
|
||||
for i := range r {
|
||||
copy(out[i*8:], r[i])
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
|
@ -0,0 +1,535 @@
|
|||
/*-
|
||||
* Copyright 2014 Square Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package jose
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"gopkg.in/square/go-jose.v2/json"
|
||||
)
|
||||
|
||||
// Encrypter represents an encrypter which produces an encrypted JWE object.
|
||||
type Encrypter interface {
|
||||
Encrypt(plaintext []byte) (*JSONWebEncryption, error)
|
||||
EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error)
|
||||
Options() EncrypterOptions
|
||||
}
|
||||
|
||||
// A generic content cipher
|
||||
type contentCipher interface {
|
||||
keySize() int
|
||||
encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error)
|
||||
decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error)
|
||||
}
|
||||
|
||||
// A key generator (for generating/getting a CEK)
|
||||
type keyGenerator interface {
|
||||
keySize() int
|
||||
genKey() ([]byte, rawHeader, error)
|
||||
}
|
||||
|
||||
// A generic key encrypter
|
||||
type keyEncrypter interface {
|
||||
encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key
|
||||
}
|
||||
|
||||
// A generic key decrypter
|
||||
type keyDecrypter interface {
|
||||
decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key
|
||||
}
|
||||
|
||||
// A generic encrypter based on the given key encrypter and content cipher.
|
||||
type genericEncrypter struct {
|
||||
contentAlg ContentEncryption
|
||||
compressionAlg CompressionAlgorithm
|
||||
cipher contentCipher
|
||||
recipients []recipientKeyInfo
|
||||
keyGenerator keyGenerator
|
||||
extraHeaders map[HeaderKey]interface{}
|
||||
}
|
||||
|
||||
type recipientKeyInfo struct {
|
||||
keyID string
|
||||
keyAlg KeyAlgorithm
|
||||
keyEncrypter keyEncrypter
|
||||
}
|
||||
|
||||
// EncrypterOptions represents options that can be set on new encrypters.
|
||||
type EncrypterOptions struct {
|
||||
Compression CompressionAlgorithm
|
||||
|
||||
// Optional map of additional keys to be inserted into the protected header
|
||||
// of a JWS object. Some specifications which make use of JWS like to insert
|
||||
// additional values here. All values must be JSON-serializable.
|
||||
ExtraHeaders map[HeaderKey]interface{}
|
||||
}
|
||||
|
||||
// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
|
||||
// if necessary. It returns itself and so can be used in a fluent style.
|
||||
func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
|
||||
if eo.ExtraHeaders == nil {
|
||||
eo.ExtraHeaders = map[HeaderKey]interface{}{}
|
||||
}
|
||||
eo.ExtraHeaders[k] = v
|
||||
return eo
|
||||
}
|
||||
|
||||
// WithContentType adds a content type ("cty") header and returns the updated
|
||||
// EncrypterOptions.
|
||||
func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions {
|
||||
return eo.WithHeader(HeaderContentType, contentType)
|
||||
}
|
||||
|
||||
// WithType adds a type ("typ") header and returns the updated EncrypterOptions.
|
||||
func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
|
||||
return eo.WithHeader(HeaderType, typ)
|
||||
}
|
||||
|
||||
// Recipient represents an algorithm/key to encrypt messages to.
|
||||
//
|
||||
// PBES2Count and PBES2Salt correspond with the "p2c" and "p2s" headers used
|
||||
// on the password-based encryption algorithms PBES2-HS256+A128KW,
|
||||
// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided a safe
|
||||
// default of 100000 will be used for the count and a 128-bit random salt will
|
||||
// be generated.
|
||||
type Recipient struct {
|
||||
Algorithm KeyAlgorithm
|
||||
Key interface{}
|
||||
KeyID string
|
||||
PBES2Count int
|
||||
PBES2Salt []byte
|
||||
}
|
||||
|
||||
// NewEncrypter creates an appropriate encrypter based on the key type
|
||||
func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) {
|
||||
encrypter := &genericEncrypter{
|
||||
contentAlg: enc,
|
||||
recipients: []recipientKeyInfo{},
|
||||
cipher: getContentCipher(enc),
|
||||
}
|
||||
if opts != nil {
|
||||
encrypter.compressionAlg = opts.Compression
|
||||
encrypter.extraHeaders = opts.ExtraHeaders
|
||||
}
|
||||
|
||||
if encrypter.cipher == nil {
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
var keyID string
|
||||
var rawKey interface{}
|
||||
switch encryptionKey := rcpt.Key.(type) {
|
||||
case JSONWebKey:
|
||||
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
|
||||
case *JSONWebKey:
|
||||
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
|
||||
default:
|
||||
rawKey = encryptionKey
|
||||
}
|
||||
|
||||
switch rcpt.Algorithm {
|
||||
case DIRECT:
|
||||
// Direct encryption mode must be treated differently
|
||||
if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
|
||||
return nil, ErrUnsupportedKeyType
|
||||
}
|
||||
if encrypter.cipher.keySize() != len(rawKey.([]byte)) {
|
||||
return nil, ErrInvalidKeySize
|
||||
}
|
||||
encrypter.keyGenerator = staticKeyGenerator{
|
||||
key: rawKey.([]byte),
|
||||
}
|
||||
recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte))
|
||||
recipientInfo.keyID = keyID
|
||||
if rcpt.KeyID != "" {
|
||||
recipientInfo.keyID = rcpt.KeyID
|
||||
}
|
||||
encrypter.recipients = []recipientKeyInfo{recipientInfo}
|
||||
return encrypter, nil
|
||||
case ECDH_ES:
|
||||
// ECDH-ES (w/o key wrapping) is similar to DIRECT mode
|
||||
typeOf := reflect.TypeOf(rawKey)
|
||||
if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
|
||||
return nil, ErrUnsupportedKeyType
|
||||
}
|
||||
encrypter.keyGenerator = ecKeyGenerator{
|
||||
size: encrypter.cipher.keySize(),
|
||||
algID: string(enc),
|
||||
publicKey: rawKey.(*ecdsa.PublicKey),
|
||||
}
|
||||
recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey))
|
||||
recipientInfo.keyID = keyID
|
||||
if rcpt.KeyID != "" {
|
||||
recipientInfo.keyID = rcpt.KeyID
|
||||
}
|
||||
encrypter.recipients = []recipientKeyInfo{recipientInfo}
|
||||
return encrypter, nil
|
||||
default:
|
||||
// Can just add a standard recipient
|
||||
encrypter.keyGenerator = randomKeyGenerator{
|
||||
size: encrypter.cipher.keySize(),
|
||||
}
|
||||
err := encrypter.addRecipient(rcpt)
|
||||
return encrypter, err
|
||||
}
|
||||
}
|
||||
|
||||
// NewMultiEncrypter creates a multi-encrypter based on the given parameters
|
||||
func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) {
|
||||
cipher := getContentCipher(enc)
|
||||
|
||||
if cipher == nil {
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
}
|
||||
if rcpts == nil || len(rcpts) == 0 {
|
||||
return nil, fmt.Errorf("square/go-jose: recipients is nil or empty")
|
||||
}
|
||||
|
||||
encrypter := &genericEncrypter{
|
||||
contentAlg: enc,
|
||||
recipients: []recipientKeyInfo{},
|
||||
cipher: cipher,
|
||||
keyGenerator: randomKeyGenerator{
|
||||
size: cipher.keySize(),
|
||||
},
|
||||
}
|
||||
|
||||
if opts != nil {
|
||||
encrypter.compressionAlg = opts.Compression
|
||||
}
|
||||
|
||||
for _, recipient := range rcpts {
|
||||
err := encrypter.addRecipient(recipient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return encrypter, nil
|
||||
}
|
||||
|
||||
func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) {
|
||||
var recipientInfo recipientKeyInfo
|
||||
|
||||
switch recipient.Algorithm {
|
||||
case DIRECT, ECDH_ES:
|
||||
return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm)
|
||||
}
|
||||
|
||||
recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key)
|
||||
if recipient.KeyID != "" {
|
||||
recipientInfo.keyID = recipient.KeyID
|
||||
}
|
||||
|
||||
switch recipient.Algorithm {
|
||||
case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
|
||||
if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok {
|
||||
sr.p2c = recipient.PBES2Count
|
||||
sr.p2s = recipient.PBES2Salt
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
ctx.recipients = append(ctx.recipients, recipientInfo)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) {
|
||||
switch encryptionKey := encryptionKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
return newRSARecipient(alg, encryptionKey)
|
||||
case *ecdsa.PublicKey:
|
||||
return newECDHRecipient(alg, encryptionKey)
|
||||
case []byte:
|
||||
return newSymmetricRecipient(alg, encryptionKey)
|
||||
case string:
|
||||
return newSymmetricRecipient(alg, []byte(encryptionKey))
|
||||
case *JSONWebKey:
|
||||
recipient, err := makeJWERecipient(alg, encryptionKey.Key)
|
||||
recipient.keyID = encryptionKey.KeyID
|
||||
return recipient, err
|
||||
default:
|
||||
return recipientKeyInfo{}, ErrUnsupportedKeyType
|
||||
}
|
||||
}
|
||||
|
||||
// newDecrypter creates an appropriate decrypter based on the key type
|
||||
func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
|
||||
switch decryptionKey := decryptionKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &rsaDecrypterSigner{
|
||||
privateKey: decryptionKey,
|
||||
}, nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return &ecDecrypterSigner{
|
||||
privateKey: decryptionKey,
|
||||
}, nil
|
||||
case []byte:
|
||||
return &symmetricKeyCipher{
|
||||
key: decryptionKey,
|
||||
}, nil
|
||||
case string:
|
||||
return &symmetricKeyCipher{
|
||||
key: []byte(decryptionKey),
|
||||
}, nil
|
||||
case JSONWebKey:
|
||||
return newDecrypter(decryptionKey.Key)
|
||||
case *JSONWebKey:
|
||||
return newDecrypter(decryptionKey.Key)
|
||||
default:
|
||||
return nil, ErrUnsupportedKeyType
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation of encrypt method producing a JWE object.
|
||||
func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
|
||||
return ctx.EncryptWithAuthData(plaintext, nil)
|
||||
}
|
||||
|
||||
// Implementation of encrypt method producing a JWE object.
|
||||
func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) {
|
||||
obj := &JSONWebEncryption{}
|
||||
obj.aad = aad
|
||||
|
||||
obj.protected = &rawHeader{}
|
||||
err := obj.protected.set(headerEncryption, ctx.contentAlg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
obj.recipients = make([]recipientInfo, len(ctx.recipients))
|
||||
|
||||
if len(ctx.recipients) == 0 {
|
||||
return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to")
|
||||
}
|
||||
|
||||
cek, headers, err := ctx.keyGenerator.genKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
obj.protected.merge(&headers)
|
||||
|
||||
for i, info := range ctx.recipients {
|
||||
recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = recipient.header.set(headerAlgorithm, info.keyAlg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if info.keyID != "" {
|
||||
err = recipient.header.set(headerKeyID, info.keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
obj.recipients[i] = recipient
|
||||
}
|
||||
|
||||
if len(ctx.recipients) == 1 {
|
||||
// Move per-recipient headers into main protected header if there's
|
||||
// only a single recipient.
|
||||
obj.protected.merge(obj.recipients[0].header)
|
||||
obj.recipients[0].header = nil
|
||||
}
|
||||
|
||||
if ctx.compressionAlg != NONE {
|
||||
plaintext, err = compress(ctx.compressionAlg, plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = obj.protected.set(headerCompression, ctx.compressionAlg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range ctx.extraHeaders {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
(*obj.protected)[k] = makeRawMessage(b)
|
||||
}
|
||||
|
||||
authData := obj.computeAuthData()
|
||||
parts, err := ctx.cipher.encrypt(cek, authData, plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
obj.iv = parts.iv
|
||||
obj.ciphertext = parts.ciphertext
|
||||
obj.tag = parts.tag
|
||||
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func (ctx *genericEncrypter) Options() EncrypterOptions {
|
||||
return EncrypterOptions{
|
||||
Compression: ctx.compressionAlg,
|
||||
ExtraHeaders: ctx.extraHeaders,
|
||||
}
|
||||
}
|
||||
|
||||
// Decrypt and validate the object and return the plaintext. Note that this
|
||||
// function does not support multi-recipient, if you desire multi-recipient
|
||||
// decryption use DecryptMulti instead.
|
||||
func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
|
||||
headers := obj.mergedHeaders(nil)
|
||||
|
||||
if len(obj.recipients) > 1 {
|
||||
return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one")
|
||||
}
|
||||
|
||||
critical, err := headers.getCritical()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("square/go-jose: invalid crit header")
|
||||
}
|
||||
|
||||
if len(critical) > 0 {
|
||||
return nil, fmt.Errorf("square/go-jose: unsupported crit header")
|
||||
}
|
||||
|
||||
decrypter, err := newDecrypter(decryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cipher := getContentCipher(headers.getEncryption())
|
||||
if cipher == nil {
|
||||
return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption()))
|
||||
}
|
||||
|
||||
generator := randomKeyGenerator{
|
||||
size: cipher.keySize(),
|
||||
}
|
||||
|
||||
parts := &aeadParts{
|
||||
iv: obj.iv,
|
||||
ciphertext: obj.ciphertext,
|
||||
tag: obj.tag,
|
||||
}
|
||||
|
||||
authData := obj.computeAuthData()
|
||||
|
||||
var plaintext []byte
|
||||
recipient := obj.recipients[0]
|
||||
recipientHeaders := obj.mergedHeaders(&recipient)
|
||||
|
||||
cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
|
||||
if err == nil {
|
||||
// Found a valid CEK -- let's try to decrypt.
|
||||
plaintext, err = cipher.decrypt(cek, authData, parts)
|
||||
}
|
||||
|
||||
if plaintext == nil {
|
||||
return nil, ErrCryptoFailure
|
||||
}
|
||||
|
||||
// The "zip" header parameter may only be present in the protected header.
|
||||
if comp := obj.protected.getCompression(); comp != "" {
|
||||
plaintext, err = decompress(comp, plaintext)
|
||||
}
|
||||
|
||||
return plaintext, err
|
||||
}
|
||||
|
||||
// DecryptMulti decrypts and validates the object and returns the plaintexts,
|
||||
// with support for multiple recipients. It returns the index of the recipient
|
||||
// for which the decryption was successful, the merged headers for that recipient,
|
||||
// and the plaintext.
|
||||
func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
|
||||
globalHeaders := obj.mergedHeaders(nil)
|
||||
|
||||
critical, err := globalHeaders.getCritical()
|
||||
if err != nil {
|
||||
return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header")
|
||||
}
|
||||
|
||||
if len(critical) > 0 {
|
||||
return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header")
|
||||
}
|
||||
|
||||
decrypter, err := newDecrypter(decryptionKey)
|
||||
if err != nil {
|
||||
return -1, Header{}, nil, err
|
||||
}
|
||||
|
||||
encryption := globalHeaders.getEncryption()
|
||||
cipher := getContentCipher(encryption)
|
||||
if cipher == nil {
|
||||
return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption))
|
||||
}
|
||||
|
||||
generator := randomKeyGenerator{
|
||||
size: cipher.keySize(),
|
||||
}
|
||||
|
||||
parts := &aeadParts{
|
||||
iv: obj.iv,
|
||||
ciphertext: obj.ciphertext,
|
||||
tag: obj.tag,
|
||||
}
|
||||
|
||||
authData := obj.computeAuthData()
|
||||
|
||||
index := -1
|
||||
var plaintext []byte
|
||||
var headers rawHeader
|
||||
|
||||
for i, recipient := range obj.recipients {
|
||||
recipientHeaders := obj.mergedHeaders(&recipient)
|
||||
|
||||
cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
|
||||
if err == nil {
|
||||
// Found a valid CEK -- let's try to decrypt.
|
||||
plaintext, err = cipher.decrypt(cek, authData, parts)
|
||||
if err == nil {
|
||||
index = i
|
||||
headers = recipientHeaders
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if plaintext == nil || err != nil {
|
||||
return -1, Header{}, nil, ErrCryptoFailure
|
||||
}
|
||||
|
||||
// The "zip" header parameter may only be present in the protected header.
|
||||
if comp := obj.protected.getCompression(); comp != "" {
|
||||
plaintext, err = decompress(comp, plaintext)
|
||||
}
|
||||
|
||||
sanitized, err := headers.sanitized()
|
||||
if err != nil {
|
||||
return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err)
|
||||
}
|
||||
|
||||
return index, sanitized, plaintext, err
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*-
|
||||
* Copyright 2014 Square Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
|
||||
Package jose aims to provide an implementation of the Javascript Object Signing
|
||||
and Encryption set of standards. It implements encryption and signing based on
|
||||
the JSON Web Encryption and JSON Web Signature standards, with optional JSON
|
||||
Web Token support available in a sub-package. The library supports both the
|
||||
compact and full serialization formats, and has optional support for multiple
|
||||
recipients.
|
||||
|
||||
*/
|
||||
package jose
|
|
@ -0,0 +1,179 @@
|
|||
/*-
|
||||
* Copyright 2014 Square Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package jose
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/big"
|
||||
"regexp"
|
||||
|
||||
"gopkg.in/square/go-jose.v2/json"
|
||||
)
|
||||
|
||||
var stripWhitespaceRegex = regexp.MustCompile("\\s")
|
||||
|
||||
// Helper function to serialize known-good objects.
|
||||
// Precondition: value is not a nil pointer.
|
||||
func mustSerializeJSON(value interface{}) []byte {
|
||||
out, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// We never want to serialize the top-level value "null," since it's not a
|
||||
// valid JOSE message. But if a caller passes in a nil pointer to this method,
|
||||
// MarshalJSON will happily serialize it as the top-level value "null". If
|
||||
// that value is then embedded in another operation, for instance by being
|
||||
// base64-encoded and fed as input to a signing algorithm
|
||||
// (https://github.com/square/go-jose/issues/22), the result will be
|
||||
// incorrect. Because this method is intended for known-good objects, and a nil
|
||||
// pointer is not a known-good object, we are free to panic in this case.
|
||||
// Note: It's not possible to directly check whether the data pointed at by an
|
||||
// interface is a nil pointer, so we do this hacky workaround.
|
||||
// https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I
|
||||
if string(out) == "null" {
|
||||
panic("Tried to serialize a nil pointer.")
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Strip all newlines and whitespace
|
||||
func stripWhitespace(data string) string {
|
||||
return stripWhitespaceRegex.ReplaceAllString(data, "")
|
||||
}
|
||||
|
||||
// Perform compression based on algorithm
|
||||
func compress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
|
||||
switch algorithm {
|
||||
case DEFLATE:
|
||||
return deflate(input)
|
||||
default:
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
}
|
||||
}
|
||||
|
||||
// Perform decompression based on algorithm
|
||||
func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
|
||||
switch algorithm {
|
||||
case DEFLATE:
|
||||
return inflate(input)
|
||||
default:
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
}
|
||||
}
|
||||
|
||||
// Compress with DEFLATE
|
||||
func deflate(input []byte) ([]byte, error) {
|
||||
output := new(bytes.Buffer)
|
||||
|
||||
// Writing to byte buffer, err is always nil
|
||||
writer, _ := flate.NewWriter(output, 1)
|
||||
_, _ = io.Copy(writer, bytes.NewBuffer(input))
|
||||
|
||||
err := writer.Close()
|
||||
return output.Bytes(), err
|
||||
}
|
||||
|
||||
// Decompress with DEFLATE
|
||||
func inflate(input []byte) ([]byte, error) {
|
||||
output := new(bytes.Buffer)
|
||||
reader := flate.NewReader(bytes.NewBuffer(input))
|
||||
|
||||
_, err := io.Copy(output, reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = reader.Close()
|
||||
return output.Bytes(), err
|
||||
}
|
||||
|
||||
// byteBuffer represents a slice of bytes that can be serialized to url-safe base64.
|
||||
type byteBuffer struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func newBuffer(data []byte) *byteBuffer {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
return &byteBuffer{
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func newFixedSizeBuffer(data []byte, length int) *byteBuffer {
|
||||
if len(data) > length {
|
||||
panic("square/go-jose: invalid call to newFixedSizeBuffer (len(data) > length)")
|
||||
}
|
||||
pad := make([]byte, length-len(data))
|
||||
return newBuffer(append(pad, data...))
|
||||
}
|
||||
|
||||
func newBufferFromInt(num uint64) *byteBuffer {
|
||||
data := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(data, num)
|
||||
return newBuffer(bytes.TrimLeft(data, "\x00"))
|
||||
}
|
||||
|
||||
func (b *byteBuffer) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(b.base64())
|
||||
}
|
||||
|
||||
func (b *byteBuffer) UnmarshalJSON(data []byte) error {
|
||||
var encoded string
|
||||
err := json.Unmarshal(data, &encoded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if encoded == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*b = *newBuffer(decoded)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *byteBuffer) base64() string {
|
||||
return base64.RawURLEncoding.EncodeToString(b.data)
|
||||
}
|
||||
|
||||
func (b *byteBuffer) bytes() []byte {
|
||||
// Handling nil here allows us to transparently handle nil slices when serializing.
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
return b.data
|
||||
}
|
||||
|
||||
func (b byteBuffer) bigInt() *big.Int {
|
||||
return new(big.Int).SetBytes(b.data)
|
||||
}
|
||||
|
||||
func (b byteBuffer) toInt() int {
|
||||
return int(b.bigInt().Int64())
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2012 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,13 @@
|
|||
# Safe JSON
|
||||
|
||||
This repository contains a fork of the `encoding/json` package from Go 1.6.
|
||||
|
||||
The following changes were made:
|
||||
|
||||
* Object deserialization uses case-sensitive member name matching instead of
|
||||
[case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html).
|
||||
This is to avoid differences in the interpretation of JOSE messages between
|
||||
go-jose and libraries written in other languages.
|
||||
* When deserializing a JSON object, we check for duplicate keys and reject the
|
||||
input whenever we detect a duplicate. Rather than trying to work with malformed
|
||||
data, we prefer to reject it right away.
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,141 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import "bytes"
|
||||
|
||||
// Compact appends to dst the JSON-encoded src with
|
||||
// insignificant space characters elided.
|
||||
func Compact(dst *bytes.Buffer, src []byte) error {
|
||||
return compact(dst, src, false)
|
||||
}
|
||||
|
||||
func compact(dst *bytes.Buffer, src []byte, escape bool) error {
|
||||
origLen := dst.Len()
|
||||
var scan scanner
|
||||
scan.reset()
|
||||
start := 0
|
||||
for i, c := range src {
|
||||
if escape && (c == '<' || c == '>' || c == '&') {
|
||||
if start < i {
|
||||
dst.Write(src[start:i])
|
||||
}
|
||||
dst.WriteString(`\u00`)
|
||||
dst.WriteByte(hex[c>>4])
|
||||
dst.WriteByte(hex[c&0xF])
|
||||
start = i + 1
|
||||
}
|
||||
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
|
||||
if c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
|
||||
if start < i {
|
||||
dst.Write(src[start:i])
|
||||
}
|
||||
dst.WriteString(`\u202`)
|
||||
dst.WriteByte(hex[src[i+2]&0xF])
|
||||
start = i + 3
|
||||
}
|
||||
v := scan.step(&scan, c)
|
||||
if v >= scanSkipSpace {
|
||||
if v == scanError {
|
||||
break
|
||||
}
|
||||
if start < i {
|
||||
dst.Write(src[start:i])
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
dst.Truncate(origLen)
|
||||
return scan.err
|
||||
}
|
||||
if start < len(src) {
|
||||
dst.Write(src[start:])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newline(dst *bytes.Buffer, prefix, indent string, depth int) {
|
||||
dst.WriteByte('\n')
|
||||
dst.WriteString(prefix)
|
||||
for i := 0; i < depth; i++ {
|
||||
dst.WriteString(indent)
|
||||
}
|
||||
}
|
||||
|
||||
// Indent appends to dst an indented form of the JSON-encoded src.
|
||||
// Each element in a JSON object or array begins on a new,
|
||||
// indented line beginning with prefix followed by one or more
|
||||
// copies of indent according to the indentation nesting.
|
||||
// The data appended to dst does not begin with the prefix nor
|
||||
// any indentation, to make it easier to embed inside other formatted JSON data.
|
||||
// Although leading space characters (space, tab, carriage return, newline)
|
||||
// at the beginning of src are dropped, trailing space characters
|
||||
// at the end of src are preserved and copied to dst.
|
||||
// For example, if src has no trailing spaces, neither will dst;
|
||||
// if src ends in a trailing newline, so will dst.
|
||||
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
|
||||
origLen := dst.Len()
|
||||
var scan scanner
|
||||
scan.reset()
|
||||
needIndent := false
|
||||
depth := 0
|
||||
for _, c := range src {
|
||||
scan.bytes++
|
||||
v := scan.step(&scan, c)
|
||||
if v == scanSkipSpace {
|
||||
continue
|
||||
}
|
||||
if v == scanError {
|
||||
break
|
||||
}
|
||||
if needIndent && v != scanEndObject && v != scanEndArray {
|
||||
needIndent = false
|
||||
depth++
|
||||
newline(dst, prefix, indent, depth)
|
||||
}
|
||||
|
||||
// Emit semantically uninteresting bytes
|
||||
// (in particular, punctuation in strings) unmodified.
|
||||
if v == scanContinue {
|
||||
dst.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add spacing around real punctuation.
|
||||
switch c {
|
||||
case '{', '[':
|
||||
// delay indent so that empty object and array are formatted as {} and [].
|
||||
needIndent = true
|
||||
dst.WriteByte(c)
|
||||
|
||||
case ',':
|
||||
dst.WriteByte(c)
|
||||
newline(dst, prefix, indent, depth)
|
||||
|
||||
case ':':
|
||||
dst.WriteByte(c)
|
||||
dst.WriteByte(' ')
|
||||
|
||||
case '}', ']':
|
||||
if needIndent {
|
||||
// suppress indent in empty object/array
|
||||
needIndent = false
|
||||
} else {
|
||||
depth--
|
||||
newline(dst, prefix, indent, depth)
|
||||
}
|
||||
dst.WriteByte(c)
|
||||
|
||||
default:
|
||||
dst.WriteByte(c)
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
dst.Truncate(origLen)
|
||||
return scan.err
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,623 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
// JSON value parser state machine.
|
||||
// Just about at the limit of what is reasonable to write by hand.
|
||||
// Some parts are a bit tedious, but overall it nicely factors out the
|
||||
// otherwise common code from the multiple scanning functions
|
||||
// in this package (Compact, Indent, checkValid, nextValue, etc).
|
||||
//
|
||||
// This file starts with two simple examples using the scanner
|
||||
// before diving into the scanner itself.
|
||||
|
||||
import "strconv"
|
||||
|
||||
// checkValid verifies that data is valid JSON-encoded data.
|
||||
// scan is passed in for use by checkValid to avoid an allocation.
|
||||
func checkValid(data []byte, scan *scanner) error {
|
||||
scan.reset()
|
||||
for _, c := range data {
|
||||
scan.bytes++
|
||||
if scan.step(scan, c) == scanError {
|
||||
return scan.err
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
return scan.err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// nextValue splits data after the next whole JSON value,
|
||||
// returning that value and the bytes that follow it as separate slices.
|
||||
// scan is passed in for use by nextValue to avoid an allocation.
|
||||
func nextValue(data []byte, scan *scanner) (value, rest []byte, err error) {
|
||||
scan.reset()
|
||||
for i, c := range data {
|
||||
v := scan.step(scan, c)
|
||||
if v >= scanEndObject {
|
||||
switch v {
|
||||
// probe the scanner with a space to determine whether we will
|
||||
// get scanEnd on the next character. Otherwise, if the next character
|
||||
// is not a space, scanEndTop allocates a needless error.
|
||||
case scanEndObject, scanEndArray:
|
||||
if scan.step(scan, ' ') == scanEnd {
|
||||
return data[:i+1], data[i+1:], nil
|
||||
}
|
||||
case scanError:
|
||||
return nil, nil, scan.err
|
||||
case scanEnd:
|
||||
return data[:i], data[i:], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
return nil, nil, scan.err
|
||||
}
|
||||
return data, nil, nil
|
||||
}
|
||||
|
||||
// A SyntaxError is a description of a JSON syntax error.
|
||||
type SyntaxError struct {
|
||||
msg string // description of error
|
||||
Offset int64 // error occurred after reading Offset bytes
|
||||
}
|
||||
|
||||
func (e *SyntaxError) Error() string { return e.msg }
|
||||
|
||||
// A scanner is a JSON scanning state machine.
|
||||
// Callers call scan.reset() and then pass bytes in one at a time
|
||||
// by calling scan.step(&scan, c) for each byte.
|
||||
// The return value, referred to as an opcode, tells the
|
||||
// caller about significant parsing events like beginning
|
||||
// and ending literals, objects, and arrays, so that the
|
||||
// caller can follow along if it wishes.
|
||||
// The return value scanEnd indicates that a single top-level
|
||||
// JSON value has been completed, *before* the byte that
|
||||
// just got passed in. (The indication must be delayed in order
|
||||
// to recognize the end of numbers: is 123 a whole value or
|
||||
// the beginning of 12345e+6?).
|
||||
type scanner struct {
|
||||
// The step is a func to be called to execute the next transition.
|
||||
// Also tried using an integer constant and a single func
|
||||
// with a switch, but using the func directly was 10% faster
|
||||
// on a 64-bit Mac Mini, and it's nicer to read.
|
||||
step func(*scanner, byte) int
|
||||
|
||||
// Reached end of top-level value.
|
||||
endTop bool
|
||||
|
||||
// Stack of what we're in the middle of - array values, object keys, object values.
|
||||
parseState []int
|
||||
|
||||
// Error that happened, if any.
|
||||
err error
|
||||
|
||||
// 1-byte redo (see undo method)
|
||||
redo bool
|
||||
redoCode int
|
||||
redoState func(*scanner, byte) int
|
||||
|
||||
// total bytes consumed, updated by decoder.Decode
|
||||
bytes int64
|
||||
}
|
||||
|
||||
// These values are returned by the state transition functions
|
||||
// assigned to scanner.state and the method scanner.eof.
|
||||
// They give details about the current state of the scan that
|
||||
// callers might be interested to know about.
|
||||
// It is okay to ignore the return value of any particular
|
||||
// call to scanner.state: if one call returns scanError,
|
||||
// every subsequent call will return scanError too.
|
||||
const (
|
||||
// Continue.
|
||||
scanContinue = iota // uninteresting byte
|
||||
scanBeginLiteral // end implied by next result != scanContinue
|
||||
scanBeginObject // begin object
|
||||
scanObjectKey // just finished object key (string)
|
||||
scanObjectValue // just finished non-last object value
|
||||
scanEndObject // end object (implies scanObjectValue if possible)
|
||||
scanBeginArray // begin array
|
||||
scanArrayValue // just finished array value
|
||||
scanEndArray // end array (implies scanArrayValue if possible)
|
||||
scanSkipSpace // space byte; can skip; known to be last "continue" result
|
||||
|
||||
// Stop.
|
||||
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
|
||||
scanError // hit an error, scanner.err.
|
||||
)
|
||||
|
||||
// These values are stored in the parseState stack.
|
||||
// They give the current state of a composite value
|
||||
// being scanned. If the parser is inside a nested value
|
||||
// the parseState describes the nested state, outermost at entry 0.
|
||||
const (
|
||||
parseObjectKey = iota // parsing object key (before colon)
|
||||
parseObjectValue // parsing object value (after colon)
|
||||
parseArrayValue // parsing array value
|
||||
)
|
||||
|
||||
// reset prepares the scanner for use.
|
||||
// It must be called before calling s.step.
|
||||
func (s *scanner) reset() {
|
||||
s.step = stateBeginValue
|
||||
s.parseState = s.parseState[0:0]
|
||||
s.err = nil
|
||||
s.redo = false
|
||||
s.endTop = false
|
||||
}
|
||||
|
||||
// eof tells the scanner that the end of input has been reached.
|
||||
// It returns a scan status just as s.step does.
|
||||
func (s *scanner) eof() int {
|
||||
if s.err != nil {
|
||||
return scanError
|
||||
}
|
||||
if s.endTop {
|
||||
return scanEnd
|
||||
}
|
||||
s.step(s, ' ')
|
||||
if s.endTop {
|
||||
return scanEnd
|
||||
}
|
||||
if s.err == nil {
|
||||
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
|
||||
}
|
||||
return scanError
|
||||
}
|
||||
|
||||
// pushParseState pushes a new parse state p onto the parse stack.
|
||||
func (s *scanner) pushParseState(p int) {
|
||||
s.parseState = append(s.parseState, p)
|
||||
}
|
||||
|
||||
// popParseState pops a parse state (already obtained) off the stack
|
||||
// and updates s.step accordingly.
|
||||
func (s *scanner) popParseState() {
|
||||
n := len(s.parseState) - 1
|
||||
s.parseState = s.parseState[0:n]
|
||||
s.redo = false
|
||||
if n == 0 {
|
||||
s.step = stateEndTop
|
||||
s.endTop = true
|
||||
} else {
|
||||
s.step = stateEndValue
|
||||
}
|
||||
}
|
||||
|
||||
func isSpace(c byte) bool {
|
||||
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
|
||||
}
|
||||
|
||||
// stateBeginValueOrEmpty is the state after reading `[`.
|
||||
func stateBeginValueOrEmpty(s *scanner, c byte) int {
|
||||
if c <= ' ' && isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
if c == ']' {
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
return stateBeginValue(s, c)
|
||||
}
|
||||
|
||||
// stateBeginValue is the state at the beginning of the input.
|
||||
func stateBeginValue(s *scanner, c byte) int {
|
||||
if c <= ' ' && isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
switch c {
|
||||
case '{':
|
||||
s.step = stateBeginStringOrEmpty
|
||||
s.pushParseState(parseObjectKey)
|
||||
return scanBeginObject
|
||||
case '[':
|
||||
s.step = stateBeginValueOrEmpty
|
||||
s.pushParseState(parseArrayValue)
|
||||
return scanBeginArray
|
||||
case '"':
|
||||
s.step = stateInString
|
||||
return scanBeginLiteral
|
||||
case '-':
|
||||
s.step = stateNeg
|
||||
return scanBeginLiteral
|
||||
case '0': // beginning of 0.123
|
||||
s.step = state0
|
||||
return scanBeginLiteral
|
||||
case 't': // beginning of true
|
||||
s.step = stateT
|
||||
return scanBeginLiteral
|
||||
case 'f': // beginning of false
|
||||
s.step = stateF
|
||||
return scanBeginLiteral
|
||||
case 'n': // beginning of null
|
||||
s.step = stateN
|
||||
return scanBeginLiteral
|
||||
}
|
||||
if '1' <= c && c <= '9' { // beginning of 1234.5
|
||||
s.step = state1
|
||||
return scanBeginLiteral
|
||||
}
|
||||
return s.error(c, "looking for beginning of value")
|
||||
}
|
||||
|
||||
// stateBeginStringOrEmpty is the state after reading `{`.
|
||||
func stateBeginStringOrEmpty(s *scanner, c byte) int {
|
||||
if c <= ' ' && isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
if c == '}' {
|
||||
n := len(s.parseState)
|
||||
s.parseState[n-1] = parseObjectValue
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
return stateBeginString(s, c)
|
||||
}
|
||||
|
||||
// stateBeginString is the state after reading `{"key": value,`.
|
||||
func stateBeginString(s *scanner, c byte) int {
|
||||
if c <= ' ' && isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
if c == '"' {
|
||||
s.step = stateInString
|
||||
return scanBeginLiteral
|
||||
}
|
||||
return s.error(c, "looking for beginning of object key string")
|
||||
}
|
||||
|
||||
// stateEndValue is the state after completing a value,
|
||||
// such as after reading `{}` or `true` or `["x"`.
|
||||
func stateEndValue(s *scanner, c byte) int {
|
||||
n := len(s.parseState)
|
||||
if n == 0 {
|
||||
// Completed top-level before the current byte.
|
||||
s.step = stateEndTop
|
||||
s.endTop = true
|
||||
return stateEndTop(s, c)
|
||||
}
|
||||
if c <= ' ' && isSpace(c) {
|
||||
s.step = stateEndValue
|
||||
return scanSkipSpace
|
||||
}
|
||||
ps := s.parseState[n-1]
|
||||
switch ps {
|
||||
case parseObjectKey:
|
||||
if c == ':' {
|
||||
s.parseState[n-1] = parseObjectValue
|
||||
s.step = stateBeginValue
|
||||
return scanObjectKey
|
||||
}
|
||||
return s.error(c, "after object key")
|
||||
case parseObjectValue:
|
||||
if c == ',' {
|
||||
s.parseState[n-1] = parseObjectKey
|
||||
s.step = stateBeginString
|
||||
return scanObjectValue
|
||||
}
|
||||
if c == '}' {
|
||||
s.popParseState()
|
||||
return scanEndObject
|
||||
}
|
||||
return s.error(c, "after object key:value pair")
|
||||
case parseArrayValue:
|
||||
if c == ',' {
|
||||
s.step = stateBeginValue
|
||||
return scanArrayValue
|
||||
}
|
||||
if c == ']' {
|
||||
s.popParseState()
|
||||
return scanEndArray
|
||||
}
|
||||
return s.error(c, "after array element")
|
||||
}
|
||||
return s.error(c, "")
|
||||
}
|
||||
|
||||
// stateEndTop is the state after finishing the top-level value,
|
||||
// such as after reading `{}` or `[1,2,3]`.
|
||||
// Only space characters should be seen now.
|
||||
func stateEndTop(s *scanner, c byte) int {
|
||||
if c != ' ' && c != '\t' && c != '\r' && c != '\n' {
|
||||
// Complain about non-space byte on next call.
|
||||
s.error(c, "after top-level value")
|
||||
}
|
||||
return scanEnd
|
||||
}
|
||||
|
||||
// stateInString is the state after reading `"`.
|
||||
func stateInString(s *scanner, c byte) int {
|
||||
if c == '"' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
if c == '\\' {
|
||||
s.step = stateInStringEsc
|
||||
return scanContinue
|
||||
}
|
||||
if c < 0x20 {
|
||||
return s.error(c, "in string literal")
|
||||
}
|
||||
return scanContinue
|
||||
}
|
||||
|
||||
// stateInStringEsc is the state after reading `"\` during a quoted string.
|
||||
func stateInStringEsc(s *scanner, c byte) int {
|
||||
switch c {
|
||||
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
|
||||
s.step = stateInString
|
||||
return scanContinue
|
||||
case 'u':
|
||||
s.step = stateInStringEscU
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in string escape code")
|
||||
}
|
||||
|
||||
// stateInStringEscU is the state after reading `"\u` during a quoted string.
|
||||
func stateInStringEscU(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInStringEscU1
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
|
||||
func stateInStringEscU1(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInStringEscU12
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
|
||||
func stateInStringEscU12(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInStringEscU123
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
|
||||
func stateInStringEscU123(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInString
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateNeg is the state after reading `-` during a number.
|
||||
func stateNeg(s *scanner, c byte) int {
|
||||
if c == '0' {
|
||||
s.step = state0
|
||||
return scanContinue
|
||||
}
|
||||
if '1' <= c && c <= '9' {
|
||||
s.step = state1
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in numeric literal")
|
||||
}
|
||||
|
||||
// state1 is the state after reading a non-zero integer during a number,
|
||||
// such as after reading `1` or `100` but not `0`.
|
||||
func state1(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
s.step = state1
|
||||
return scanContinue
|
||||
}
|
||||
return state0(s, c)
|
||||
}
|
||||
|
||||
// state0 is the state after reading `0` during a number.
|
||||
func state0(s *scanner, c byte) int {
|
||||
if c == '.' {
|
||||
s.step = stateDot
|
||||
return scanContinue
|
||||
}
|
||||
if c == 'e' || c == 'E' {
|
||||
s.step = stateE
|
||||
return scanContinue
|
||||
}
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
|
||||
// stateDot is the state after reading the integer and decimal point in a number,
|
||||
// such as after reading `1.`.
|
||||
func stateDot(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
s.step = stateDot0
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "after decimal point in numeric literal")
|
||||
}
|
||||
|
||||
// stateDot0 is the state after reading the integer, decimal point, and subsequent
|
||||
// digits of a number, such as after reading `3.14`.
|
||||
func stateDot0(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
return scanContinue
|
||||
}
|
||||
if c == 'e' || c == 'E' {
|
||||
s.step = stateE
|
||||
return scanContinue
|
||||
}
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
|
||||
// stateE is the state after reading the mantissa and e in a number,
|
||||
// such as after reading `314e` or `0.314e`.
|
||||
func stateE(s *scanner, c byte) int {
|
||||
if c == '+' || c == '-' {
|
||||
s.step = stateESign
|
||||
return scanContinue
|
||||
}
|
||||
return stateESign(s, c)
|
||||
}
|
||||
|
||||
// stateESign is the state after reading the mantissa, e, and sign in a number,
|
||||
// such as after reading `314e-` or `0.314e+`.
|
||||
func stateESign(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
s.step = stateE0
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in exponent of numeric literal")
|
||||
}
|
||||
|
||||
// stateE0 is the state after reading the mantissa, e, optional sign,
|
||||
// and at least one digit of the exponent in a number,
|
||||
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
|
||||
func stateE0(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
return scanContinue
|
||||
}
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
|
||||
// stateT is the state after reading `t`.
|
||||
func stateT(s *scanner, c byte) int {
|
||||
if c == 'r' {
|
||||
s.step = stateTr
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal true (expecting 'r')")
|
||||
}
|
||||
|
||||
// stateTr is the state after reading `tr`.
|
||||
func stateTr(s *scanner, c byte) int {
|
||||
if c == 'u' {
|
||||
s.step = stateTru
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal true (expecting 'u')")
|
||||
}
|
||||
|
||||
// stateTru is the state after reading `tru`.
|
||||
func stateTru(s *scanner, c byte) int {
|
||||
if c == 'e' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal true (expecting 'e')")
|
||||
}
|
||||
|
||||
// stateF is the state after reading `f`.
|
||||
func stateF(s *scanner, c byte) int {
|
||||
if c == 'a' {
|
||||
s.step = stateFa
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 'a')")
|
||||
}
|
||||
|
||||
// stateFa is the state after reading `fa`.
|
||||
func stateFa(s *scanner, c byte) int {
|
||||
if c == 'l' {
|
||||
s.step = stateFal
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 'l')")
|
||||
}
|
||||
|
||||
// stateFal is the state after reading `fal`.
|
||||
func stateFal(s *scanner, c byte) int {
|
||||
if c == 's' {
|
||||
s.step = stateFals
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 's')")
|
||||
}
|
||||
|
||||
// stateFals is the state after reading `fals`.
|
||||
func stateFals(s *scanner, c byte) int {
|
||||
if c == 'e' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 'e')")
|
||||
}
|
||||
|
||||
// stateN is the state after reading `n`.
|
||||
func stateN(s *scanner, c byte) int {
|
||||
if c == 'u' {
|
||||
s.step = stateNu
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal null (expecting 'u')")
|
||||
}
|
||||
|
||||
// stateNu is the state after reading `nu`.
|
||||
func stateNu(s *scanner, c byte) int {
|
||||
if c == 'l' {
|
||||
s.step = stateNul
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal null (expecting 'l')")
|
||||
}
|
||||
|
||||
// stateNul is the state after reading `nul`.
|
||||
func stateNul(s *scanner, c byte) int {
|
||||
if c == 'l' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal null (expecting 'l')")
|
||||
}
|
||||
|
||||
// stateError is the state after reaching a syntax error,
|
||||
// such as after reading `[1}` or `5.1.2`.
|
||||
func stateError(s *scanner, c byte) int {
|
||||
return scanError
|
||||
}
|
||||
|
||||
// error records an error and switches to the error state.
|
||||
func (s *scanner) error(c byte, context string) int {
|
||||
s.step = stateError
|
||||
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
|
||||
return scanError
|
||||
}
|
||||
|
||||
// quoteChar formats c as a quoted character literal
|
||||
func quoteChar(c byte) string {
|
||||
// special cases - different from quoted strings
|
||||
if c == '\'' {
|
||||
return `'\''`
|
||||
}
|
||||
if c == '"' {
|
||||
return `'"'`
|
||||
}
|
||||
|
||||
// use quoted string with different quotation marks
|
||||
s := strconv.Quote(string(c))
|
||||
return "'" + s[1:len(s)-1] + "'"
|
||||
}
|
||||
|
||||
// undo causes the scanner to return scanCode from the next state transition.
|
||||
// This gives callers a simple 1-byte undo mechanism.
|
||||
func (s *scanner) undo(scanCode int) {
|
||||
if s.redo {
|
||||
panic("json: invalid use of scanner")
|
||||
}
|
||||
s.redoCode = scanCode
|
||||
s.redoState = s.step
|
||||
s.step = stateRedo
|
||||
s.redo = true
|
||||
}
|
||||
|
||||
// stateRedo helps implement the scanner's 1-byte undo.
|
||||
func stateRedo(s *scanner, c byte) int {
|
||||
s.redo = false
|
||||
s.step = s.redoState
|
||||
return s.redoCode
|
||||
}
|
|
@ -0,0 +1,480 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// A Decoder reads and decodes JSON objects from an input stream.
|
||||
type Decoder struct {
|
||||
r io.Reader
|
||||
buf []byte
|
||||
d decodeState
|
||||
scanp int // start of unread data in buf
|
||||
scan scanner
|
||||
err error
|
||||
|
||||
tokenState int
|
||||
tokenStack []int
|
||||
}
|
||||
|
||||
// NewDecoder returns a new decoder that reads from r.
|
||||
//
|
||||
// The decoder introduces its own buffering and may
|
||||
// read data from r beyond the JSON values requested.
|
||||
func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{r: r}
|
||||
}
|
||||
|
||||
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
||||
// Number instead of as a float64.
|
||||
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
|
||||
|
||||
// Decode reads the next JSON-encoded value from its
|
||||
// input and stores it in the value pointed to by v.
|
||||
//
|
||||
// See the documentation for Unmarshal for details about
|
||||
// the conversion of JSON into a Go value.
|
||||
func (dec *Decoder) Decode(v interface{}) error {
|
||||
if dec.err != nil {
|
||||
return dec.err
|
||||
}
|
||||
|
||||
if err := dec.tokenPrepareForDecode(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !dec.tokenValueAllowed() {
|
||||
return &SyntaxError{msg: "not at beginning of value"}
|
||||
}
|
||||
|
||||
// Read whole value into buffer.
|
||||
n, err := dec.readValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
|
||||
dec.scanp += n
|
||||
|
||||
// Don't save err from unmarshal into dec.err:
|
||||
// the connection is still usable since we read a complete JSON
|
||||
// object from it before the error happened.
|
||||
err = dec.d.unmarshal(v)
|
||||
|
||||
// fixup token streaming state
|
||||
dec.tokenValueEnd()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Buffered returns a reader of the data remaining in the Decoder's
|
||||
// buffer. The reader is valid until the next call to Decode.
|
||||
func (dec *Decoder) Buffered() io.Reader {
|
||||
return bytes.NewReader(dec.buf[dec.scanp:])
|
||||
}
|
||||
|
||||
// readValue reads a JSON value into dec.buf.
|
||||
// It returns the length of the encoding.
|
||||
func (dec *Decoder) readValue() (int, error) {
|
||||
dec.scan.reset()
|
||||
|
||||
scanp := dec.scanp
|
||||
var err error
|
||||
Input:
|
||||
for {
|
||||
// Look in the buffer for a new value.
|
||||
for i, c := range dec.buf[scanp:] {
|
||||
dec.scan.bytes++
|
||||
v := dec.scan.step(&dec.scan, c)
|
||||
if v == scanEnd {
|
||||
scanp += i
|
||||
break Input
|
||||
}
|
||||
// scanEnd is delayed one byte.
|
||||
// We might block trying to get that byte from src,
|
||||
// so instead invent a space byte.
|
||||
if (v == scanEndObject || v == scanEndArray) && dec.scan.step(&dec.scan, ' ') == scanEnd {
|
||||
scanp += i + 1
|
||||
break Input
|
||||
}
|
||||
if v == scanError {
|
||||
dec.err = dec.scan.err
|
||||
return 0, dec.scan.err
|
||||
}
|
||||
}
|
||||
scanp = len(dec.buf)
|
||||
|
||||
// Did the last read have an error?
|
||||
// Delayed until now to allow buffer scan.
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
if dec.scan.step(&dec.scan, ' ') == scanEnd {
|
||||
break Input
|
||||
}
|
||||
if nonSpace(dec.buf) {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
dec.err = err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n := scanp - dec.scanp
|
||||
err = dec.refill()
|
||||
scanp = dec.scanp + n
|
||||
}
|
||||
return scanp - dec.scanp, nil
|
||||
}
|
||||
|
||||
func (dec *Decoder) refill() error {
|
||||
// Make room to read more into the buffer.
|
||||
// First slide down data already consumed.
|
||||
if dec.scanp > 0 {
|
||||
n := copy(dec.buf, dec.buf[dec.scanp:])
|
||||
dec.buf = dec.buf[:n]
|
||||
dec.scanp = 0
|
||||
}
|
||||
|
||||
// Grow buffer if not large enough.
|
||||
const minRead = 512
|
||||
if cap(dec.buf)-len(dec.buf) < minRead {
|
||||
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
|
||||
copy(newBuf, dec.buf)
|
||||
dec.buf = newBuf
|
||||
}
|
||||
|
||||
// Read. Delay error for next iteration (after scan).
|
||||
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
|
||||
dec.buf = dec.buf[0 : len(dec.buf)+n]
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func nonSpace(b []byte) bool {
|
||||
for _, c := range b {
|
||||
if !isSpace(c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// An Encoder writes JSON objects to an output stream.
|
||||
type Encoder struct {
|
||||
w io.Writer
|
||||
err error
|
||||
}
|
||||
|
||||
// NewEncoder returns a new encoder that writes to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{w: w}
|
||||
}
|
||||
|
||||
// Encode writes the JSON encoding of v to the stream,
|
||||
// followed by a newline character.
|
||||
//
|
||||
// See the documentation for Marshal for details about the
|
||||
// conversion of Go values to JSON.
|
||||
func (enc *Encoder) Encode(v interface{}) error {
|
||||
if enc.err != nil {
|
||||
return enc.err
|
||||
}
|
||||
e := newEncodeState()
|
||||
err := e.marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Terminate each value with a newline.
|
||||
// This makes the output look a little nicer
|
||||
// when debugging, and some kind of space
|
||||
// is required if the encoded value was a number,
|
||||
// so that the reader knows there aren't more
|
||||
// digits coming.
|
||||
e.WriteByte('\n')
|
||||
|
||||
if _, err = enc.w.Write(e.Bytes()); err != nil {
|
||||
enc.err = err
|
||||
}
|
||||
encodeStatePool.Put(e)
|
||||
return err
|
||||
}
|
||||
|
||||
// RawMessage is a raw encoded JSON object.
|
||||
// It implements Marshaler and Unmarshaler and can
|
||||
// be used to delay JSON decoding or precompute a JSON encoding.
|
||||
type RawMessage []byte
|
||||
|
||||
// MarshalJSON returns *m as the JSON encoding of m.
|
||||
func (m *RawMessage) MarshalJSON() ([]byte, error) {
|
||||
return *m, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON sets *m to a copy of data.
|
||||
func (m *RawMessage) UnmarshalJSON(data []byte) error {
|
||||
if m == nil {
|
||||
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
|
||||
}
|
||||
*m = append((*m)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ Marshaler = (*RawMessage)(nil)
|
||||
var _ Unmarshaler = (*RawMessage)(nil)
|
||||
|
||||
// A Token holds a value of one of these types:
|
||||
//
|
||||
// Delim, for the four JSON delimiters [ ] { }
|
||||
// bool, for JSON booleans
|
||||
// float64, for JSON numbers
|
||||
// Number, for JSON numbers
|
||||
// string, for JSON string literals
|
||||
// nil, for JSON null
|
||||
//
|
||||
type Token interface{}
|
||||
|
||||
const (
|
||||
tokenTopValue = iota
|
||||
tokenArrayStart
|
||||
tokenArrayValue
|
||||
tokenArrayComma
|
||||
tokenObjectStart
|
||||
tokenObjectKey
|
||||
tokenObjectColon
|
||||
tokenObjectValue
|
||||
tokenObjectComma
|
||||
)
|
||||
|
||||
// advance tokenstate from a separator state to a value state
|
||||
func (dec *Decoder) tokenPrepareForDecode() error {
|
||||
// Note: Not calling peek before switch, to avoid
|
||||
// putting peek into the standard Decode path.
|
||||
// peek is only called when using the Token API.
|
||||
switch dec.tokenState {
|
||||
case tokenArrayComma:
|
||||
c, err := dec.peek()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c != ',' {
|
||||
return &SyntaxError{"expected comma after array element", 0}
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenArrayValue
|
||||
case tokenObjectColon:
|
||||
c, err := dec.peek()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c != ':' {
|
||||
return &SyntaxError{"expected colon after object key", 0}
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenObjectValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dec *Decoder) tokenValueAllowed() bool {
|
||||
switch dec.tokenState {
|
||||
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (dec *Decoder) tokenValueEnd() {
|
||||
switch dec.tokenState {
|
||||
case tokenArrayStart, tokenArrayValue:
|
||||
dec.tokenState = tokenArrayComma
|
||||
case tokenObjectValue:
|
||||
dec.tokenState = tokenObjectComma
|
||||
}
|
||||
}
|
||||
|
||||
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
|
||||
type Delim rune
|
||||
|
||||
func (d Delim) String() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// Token returns the next JSON token in the input stream.
|
||||
// At the end of the input stream, Token returns nil, io.EOF.
|
||||
//
|
||||
// Token guarantees that the delimiters [ ] { } it returns are
|
||||
// properly nested and matched: if Token encounters an unexpected
|
||||
// delimiter in the input, it will return an error.
|
||||
//
|
||||
// The input stream consists of basic JSON values—bool, string,
|
||||
// number, and null—along with delimiters [ ] { } of type Delim
|
||||
// to mark the start and end of arrays and objects.
|
||||
// Commas and colons are elided.
|
||||
func (dec *Decoder) Token() (Token, error) {
|
||||
for {
|
||||
c, err := dec.peek()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch c {
|
||||
case '[':
|
||||
if !dec.tokenValueAllowed() {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
|
||||
dec.tokenState = tokenArrayStart
|
||||
return Delim('['), nil
|
||||
|
||||
case ']':
|
||||
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
|
||||
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
|
||||
dec.tokenValueEnd()
|
||||
return Delim(']'), nil
|
||||
|
||||
case '{':
|
||||
if !dec.tokenValueAllowed() {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
|
||||
dec.tokenState = tokenObjectStart
|
||||
return Delim('{'), nil
|
||||
|
||||
case '}':
|
||||
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
|
||||
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
|
||||
dec.tokenValueEnd()
|
||||
return Delim('}'), nil
|
||||
|
||||
case ':':
|
||||
if dec.tokenState != tokenObjectColon {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenObjectValue
|
||||
continue
|
||||
|
||||
case ',':
|
||||
if dec.tokenState == tokenArrayComma {
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenArrayValue
|
||||
continue
|
||||
}
|
||||
if dec.tokenState == tokenObjectComma {
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenObjectKey
|
||||
continue
|
||||
}
|
||||
return dec.tokenError(c)
|
||||
|
||||
case '"':
|
||||
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
|
||||
var x string
|
||||
old := dec.tokenState
|
||||
dec.tokenState = tokenTopValue
|
||||
err := dec.Decode(&x)
|
||||
dec.tokenState = old
|
||||
if err != nil {
|
||||
clearOffset(err)
|
||||
return nil, err
|
||||
}
|
||||
dec.tokenState = tokenObjectColon
|
||||
return x, nil
|
||||
}
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
if !dec.tokenValueAllowed() {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
var x interface{}
|
||||
if err := dec.Decode(&x); err != nil {
|
||||
clearOffset(err)
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func clearOffset(err error) {
|
||||
if s, ok := err.(*SyntaxError); ok {
|
||||
s.Offset = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (dec *Decoder) tokenError(c byte) (Token, error) {
|
||||
var context string
|
||||
switch dec.tokenState {
|
||||
case tokenTopValue:
|
||||
context = " looking for beginning of value"
|
||||
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
|
||||
context = " looking for beginning of value"
|
||||
case tokenArrayComma:
|
||||
context = " after array element"
|
||||
case tokenObjectKey:
|
||||
context = " looking for beginning of object key string"
|
||||
case tokenObjectColon:
|
||||
context = " after object key"
|
||||
case tokenObjectComma:
|
||||
context = " after object key:value pair"
|
||||
}
|
||||
return nil, &SyntaxError{"invalid character " + quoteChar(c) + " " + context, 0}
|
||||
}
|
||||
|
||||
// More reports whether there is another element in the
|
||||
// current array or object being parsed.
|
||||
func (dec *Decoder) More() bool {
|
||||
c, err := dec.peek()
|
||||
return err == nil && c != ']' && c != '}'
|
||||
}
|
||||
|
||||
func (dec *Decoder) peek() (byte, error) {
|
||||
var err error
|
||||
for {
|
||||
for i := dec.scanp; i < len(dec.buf); i++ {
|
||||
c := dec.buf[i]
|
||||
if isSpace(c) {
|
||||
continue
|
||||
}
|
||||
dec.scanp = i
|
||||
return c, nil
|
||||
}
|
||||
// buffer has been scanned, now report any error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = dec.refill()
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
TODO
|
||||
|
||||
// EncodeToken writes the given JSON token to the stream.
|
||||
// It returns an error if the delimiters [ ] { } are not properly used.
|
||||
//
|
||||
// EncodeToken does not call Flush, because usually it is part of
|
||||
// a larger operation such as Encode, and those will call Flush when finished.
|
||||
// Callers that create an Encoder and then invoke EncodeToken directly,
|
||||
// without using Encode, need to call Flush when finished to ensure that
|
||||
// the JSON is written to the underlying writer.
|
||||
func (e *Encoder) EncodeToken(t Token) error {
|
||||
...
|
||||
}
|
||||
|
||||
*/
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// tagOptions is the string following a comma in a struct field's "json"
|
||||
// tag, or the empty string. It does not include the leading comma.
|
||||
type tagOptions string
|
||||
|
||||
// parseTag splits a struct field's json tag into its name and
|
||||
// comma-separated options.
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
if idx := strings.Index(tag, ","); idx != -1 {
|
||||
return tag[:idx], tagOptions(tag[idx+1:])
|
||||
}
|
||||
return tag, tagOptions("")
|
||||
}
|
||||
|
||||
// Contains reports whether a comma-separated list of options
|
||||
// contains a particular substr flag. substr must be surrounded by a
|
||||
// string boundary or commas.
|
||||
func (o tagOptions) Contains(optionName string) bool {
|
||||
if len(o) == 0 {
|
||||
return false
|
||||
}
|
||||
s := string(o)
|
||||
for s != "" {
|
||||
var next string
|
||||
i := strings.Index(s, ",")
|
||||
if i >= 0 {
|
||||
s, next = s[:i], s[i+1:]
|
||||
}
|
||||
if s == optionName {
|
||||
return true
|
||||
}
|
||||
s = next
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,294 @@
|
|||
/*-
|
||||
* Copyright 2014 Square Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package jose
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/square/go-jose.v2/json"
|
||||
)
|
||||
|
||||
// rawJSONWebEncryption represents a raw JWE JSON object. Used for parsing/serializing.
|
||||
type rawJSONWebEncryption struct {
|
||||
Protected *byteBuffer `json:"protected,omitempty"`
|
||||
Unprotected *rawHeader `json:"unprotected,omitempty"`
|
||||
Header *rawHeader `json:"header,omitempty"`
|
||||
Recipients []rawRecipientInfo `json:"recipients,omitempty"`
|
||||
Aad *byteBuffer `json:"aad,omitempty"`
|
||||
EncryptedKey *byteBuffer `json:"encrypted_key,omitempty"`
|
||||
Iv *byteBuffer `json:"iv,omitempty"`
|
||||
Ciphertext *byteBuffer `json:"ciphertext,omitempty"`
|
||||
Tag *byteBuffer `json:"tag,omitempty"`
|
||||
}
|
||||
|
||||
// rawRecipientInfo represents a raw JWE Per-Recipient header JSON object. Used for parsing/serializing.
|
||||
type rawRecipientInfo struct {
|
||||
Header *rawHeader `json:"header,omitempty"`
|
||||
EncryptedKey string `json:"encrypted_key,omitempty"`
|
||||
}
|
||||
|
||||
// JSONWebEncryption represents an encrypted JWE object after parsing.
|
||||
type JSONWebEncryption struct {
|
||||
Header Header
|
||||
protected, unprotected *rawHeader
|
||||
recipients []recipientInfo
|
||||
aad, iv, ciphertext, tag []byte
|
||||
original *rawJSONWebEncryption
|
||||
}
|
||||
|
||||
// recipientInfo represents a raw JWE Per-Recipient header JSON object after parsing.
|
||||
type recipientInfo struct {
|
||||
header *rawHeader
|
||||
encryptedKey []byte
|
||||
}
|
||||
|
||||
// GetAuthData retrieves the (optional) authenticated data attached to the object.
|
||||
func (obj JSONWebEncryption) GetAuthData() []byte {
|
||||
if obj.aad != nil {
|
||||
out := make([]byte, len(obj.aad))
|
||||
copy(out, obj.aad)
|
||||
return out
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the merged header values
|
||||
func (obj JSONWebEncryption) mergedHeaders(recipient *recipientInfo) rawHeader {
|
||||
out := rawHeader{}
|
||||
out.merge(obj.protected)
|
||||
out.merge(obj.unprotected)
|
||||
|
||||
if recipient != nil {
|
||||
out.merge(recipient.header)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Get the additional authenticated data from a JWE object.
|
||||
func (obj JSONWebEncryption) computeAuthData() []byte {
|
||||
var protected string
|
||||
|
||||
if obj.original != nil && obj.original.Protected != nil {
|
||||
protected = obj.original.Protected.base64()
|
||||
} else if obj.protected != nil {
|
||||
protected = base64.RawURLEncoding.EncodeToString(mustSerializeJSON((obj.protected)))
|
||||
} else {
|
||||
protected = ""
|
||||
}
|
||||
|
||||
output := []byte(protected)
|
||||
if obj.aad != nil {
|
||||
output = append(output, '.')
|
||||
output = append(output, []byte(base64.RawURLEncoding.EncodeToString(obj.aad))...)
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
// ParseEncrypted parses an encrypted message in compact or full serialization format.
|
||||
func ParseEncrypted(input string) (*JSONWebEncryption, error) {
|
||||
input = stripWhitespace(input)
|
||||
if strings.HasPrefix(input, "{") {
|
||||
return parseEncryptedFull(input)
|
||||
}
|
||||
|
||||
return parseEncryptedCompact(input)
|
||||
}
|
||||
|
||||
// parseEncryptedFull parses a message in compact format.
|
||||
func parseEncryptedFull(input string) (*JSONWebEncryption, error) {
|
||||
var parsed rawJSONWebEncryption
|
||||
err := json.Unmarshal([]byte(input), &parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return parsed.sanitized()
|
||||
}
|
||||
|
||||
// sanitized produces a cleaned-up JWE object from the raw JSON.
|
||||
func (parsed *rawJSONWebEncryption) sanitized() (*JSONWebEncryption, error) {
|
||||
obj := &JSONWebEncryption{
|
||||
original: parsed,
|
||||
unprotected: parsed.Unprotected,
|
||||
}
|
||||
|
||||
// Check that there is not a nonce in the unprotected headers
|
||||
if parsed.Unprotected != nil {
|
||||
if nonce := parsed.Unprotected.getNonce(); nonce != "" {
|
||||
return nil, ErrUnprotectedNonce
|
||||
}
|
||||
}
|
||||
if parsed.Header != nil {
|
||||
if nonce := parsed.Header.getNonce(); nonce != "" {
|
||||
return nil, ErrUnprotectedNonce
|
||||
}
|
||||
}
|
||||
|
||||
if parsed.Protected != nil && len(parsed.Protected.bytes()) > 0 {
|
||||
err := json.Unmarshal(parsed.Protected.bytes(), &obj.protected)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("square/go-jose: invalid protected header: %s, %s", err, parsed.Protected.base64())
|
||||
}
|
||||
}
|
||||
|
||||
// Note: this must be called _after_ we parse the protected header,
|
||||
// otherwise fields from the protected header will not get picked up.
|
||||
var err error
|
||||
mergedHeaders := obj.mergedHeaders(nil)
|
||||
obj.Header, err = mergedHeaders.sanitized()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("square/go-jose: cannot sanitize merged headers: %v (%v)", err, mergedHeaders)
|
||||
}
|
||||
|
||||
if len(parsed.Recipients) == 0 {
|
||||
obj.recipients = []recipientInfo{
|
||||
{
|
||||
header: parsed.Header,
|
||||
encryptedKey: parsed.EncryptedKey.bytes(),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
obj.recipients = make([]recipientInfo, len(parsed.Recipients))
|
||||
for r := range parsed.Recipients {
|
||||
encryptedKey, err := base64.RawURLEncoding.DecodeString(parsed.Recipients[r].EncryptedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check that there is not a nonce in the unprotected header
|
||||
if parsed.Recipients[r].Header != nil && parsed.Recipients[r].Header.getNonce() != "" {
|
||||
return nil, ErrUnprotectedNonce
|
||||
}
|
||||
|
||||
obj.recipients[r].header = parsed.Recipients[r].Header
|
||||
obj.recipients[r].encryptedKey = encryptedKey
|
||||
}
|
||||
}
|
||||
|
||||
for _, recipient := range obj.recipients {
|
||||
headers := obj.mergedHeaders(&recipient)
|
||||
if headers.getAlgorithm() == "" || headers.getEncryption() == "" {
|
||||
return nil, fmt.Errorf("square/go-jose: message is missing alg/enc headers")
|
||||
}
|
||||
}
|
||||
|
||||
obj.iv = parsed.Iv.bytes()
|
||||
obj.ciphertext = parsed.Ciphertext.bytes()
|
||||
obj.tag = parsed.Tag.bytes()
|
||||
obj.aad = parsed.Aad.bytes()
|
||||
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
// parseEncryptedCompact parses a message in compact format.
|
||||
func parseEncryptedCompact(input string) (*JSONWebEncryption, error) {
|
||||
parts := strings.Split(input, ".")
|
||||
if len(parts) != 5 {
|
||||
return nil, fmt.Errorf("square/go-jose: compact JWE format must have five parts")
|
||||
}
|
||||
|
||||
rawProtected, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptedKey, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iv, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext, err := base64.RawURLEncoding.DecodeString(parts[3])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tag, err := base64.RawURLEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
raw := &rawJSONWebEncryption{
|
||||
Protected: newBuffer(rawProtected),
|
||||
EncryptedKey: newBuffer(encryptedKey),
|
||||
Iv: newBuffer(iv),
|
||||
Ciphertext: newBuffer(ciphertext),
|
||||
Tag: newBuffer(tag),
|
||||
}
|
||||
|
||||
return raw.sanitized()
|
||||
}
|
||||
|
||||
// CompactSerialize serializes an object using the compact serialization format.
|
||||
func (obj JSONWebEncryption) CompactSerialize() (string, error) {
|
||||
if len(obj.recipients) != 1 || obj.unprotected != nil ||
|
||||
obj.protected == nil || obj.recipients[0].header != nil {
|
||||
return "", ErrNotSupported
|
||||
}
|
||||
|
||||
serializedProtected := mustSerializeJSON(obj.protected)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s.%s.%s.%s.%s",
|
||||
base64.RawURLEncoding.EncodeToString(serializedProtected),
|
||||
base64.RawURLEncoding.EncodeToString(obj.recipients[0].encryptedKey),
|
||||
base64.RawURLEncoding.EncodeToString(obj.iv),
|
||||
base64.RawURLEncoding.EncodeToString(obj.ciphertext),
|
||||
base64.RawURLEncoding.EncodeToString(obj.tag)), nil
|
||||
}
|
||||
|
||||
// FullSerialize serializes an object using the full JSON serialization format.
|
||||
func (obj JSONWebEncryption) FullSerialize() string {
|
||||
raw := rawJSONWebEncryption{
|
||||
Unprotected: obj.unprotected,
|
||||
Iv: newBuffer(obj.iv),
|
||||
Ciphertext: newBuffer(obj.ciphertext),
|
||||
EncryptedKey: newBuffer(obj.recipients[0].encryptedKey),
|
||||
Tag: newBuffer(obj.tag),
|
||||
Aad: newBuffer(obj.aad),
|
||||
Recipients: []rawRecipientInfo{},
|
||||
}
|
||||
|
||||
if len(obj.recipients) > 1 {
|
||||
for _, recipient := range obj.recipients {
|
||||
info := rawRecipientInfo{
|
||||
Header: recipient.header,
|
||||
EncryptedKey: base64.RawURLEncoding.EncodeToString(recipient.encryptedKey),
|
||||
}
|
||||
raw.Recipients = append(raw.Recipients, info)
|
||||
}
|
||||
} else {
|
||||
// Use flattened serialization
|
||||
raw.Header = obj.recipients[0].header
|
||||
raw.EncryptedKey = newBuffer(obj.recipients[0].encryptedKey)
|
||||
}
|
||||
|
||||
if obj.protected != nil {
|
||||
raw.Protected = newBuffer(mustSerializeJSON(obj.protected))
|
||||
}
|
||||
|
||||
return string(mustSerializeJSON(raw))
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue