mirror of https://github.com/status-im/consul.git
Add HTTP endpoints for config entry management (#5718)
This commit is contained in:
parent
f67e12eb6f
commit
aba54cec55
|
@ -0,0 +1,128 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// Config switches on the different CRUD operations for config entries.
|
||||
func (s *HTTPServer) Config(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
return s.configGet(resp, req)
|
||||
|
||||
case "DELETE":
|
||||
return s.configDelete(resp, req)
|
||||
|
||||
default:
|
||||
return nil, MethodNotAllowedError{req.Method, []string{"GET", "DELETE"}}
|
||||
}
|
||||
}
|
||||
|
||||
// configGet gets either a specific config entry, or lists all config entries
|
||||
// of a kind if no name is provided.
|
||||
func (s *HTTPServer) configGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
var args structs.ConfigEntryQuery
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
pathArgs := strings.SplitN(strings.TrimPrefix(req.URL.Path, "/v1/config/"), "/", 2)
|
||||
|
||||
switch len(pathArgs) {
|
||||
case 2:
|
||||
// Both kind/name provided.
|
||||
args.Kind = pathArgs[0]
|
||||
args.Name = pathArgs[1]
|
||||
|
||||
var reply structs.ConfigEntryResponse
|
||||
if err := s.agent.RPC("ConfigEntry.Get", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if reply.Entry == nil {
|
||||
return nil, fmt.Errorf("Config entry not found for %q / %q", pathArgs[0], pathArgs[1])
|
||||
}
|
||||
|
||||
return reply.Entry, nil
|
||||
case 1:
|
||||
// Only kind provided, list entries.
|
||||
args.Kind = pathArgs[0]
|
||||
|
||||
var reply structs.IndexedConfigEntries
|
||||
if err := s.agent.RPC("ConfigEntry.List", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reply.Entries, nil
|
||||
default:
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(resp, "Must provide either a kind or both kind and name")
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// configDelete deletes the given config entry.
|
||||
func (s *HTTPServer) configDelete(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
var args structs.ConfigEntryRequest
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
pathArgs := strings.SplitN(strings.TrimPrefix(req.URL.Path, "/v1/config/"), "/", 2)
|
||||
|
||||
if len(pathArgs) != 2 {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(resp, "Must provide both a kind and name to delete")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
entry, err := structs.MakeConfigEntry(pathArgs[0], pathArgs[1])
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "%v", err)
|
||||
return nil, nil
|
||||
}
|
||||
args.Entry = entry
|
||||
|
||||
var reply struct{}
|
||||
if err := s.agent.RPC("ConfigEntry.Delete", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// ConfigCreate applies the given config entry update.
|
||||
func (s *HTTPServer) ConfigApply(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
args := structs.ConfigEntryRequest{
|
||||
Op: structs.ConfigEntryUpsert,
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
|
||||
var raw map[string]interface{}
|
||||
if err := decodeBody(req, &raw, nil); err != nil {
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
|
||||
}
|
||||
|
||||
if entry, err := structs.DecodeConfigEntry(raw); err == nil {
|
||||
args.Entry = entry
|
||||
} else {
|
||||
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
|
||||
}
|
||||
|
||||
// Check for cas value
|
||||
if casStr := req.URL.Query().Get("cas"); casStr != "" {
|
||||
casVal, err := strconv.ParseUint(casStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args.Op = structs.ConfigEntryUpsertCAS
|
||||
args.Entry.GetRaftIndex().ModifyIndex = casVal
|
||||
}
|
||||
|
||||
var reply struct{}
|
||||
return nil, s.agent.RPC("ConfigEntry.Apply", &args, &reply)
|
||||
}
|
|
@ -0,0 +1,325 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfig_Get(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := NewTestAgent(t, t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
||||
|
||||
// Create some config entries.
|
||||
reqs := []structs.ConfigEntryRequest{
|
||||
{
|
||||
Datacenter: "dc1",
|
||||
Entry: &structs.ServiceConfigEntry{
|
||||
Name: "foo",
|
||||
},
|
||||
},
|
||||
{
|
||||
Datacenter: "dc1",
|
||||
Entry: &structs.ServiceConfigEntry{
|
||||
Name: "bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
Datacenter: "dc1",
|
||||
Entry: &structs.ProxyConfigEntry{
|
||||
Name: structs.ProxyConfigGlobal,
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bar": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, req := range reqs {
|
||||
var out struct{}
|
||||
require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &out))
|
||||
}
|
||||
|
||||
t.Run("get a single service entry", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/config/service-defaults/foo", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.Config(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
value := obj.(structs.ConfigEntry)
|
||||
require.Equal(t, structs.ServiceDefaults, value.GetKind())
|
||||
entry := value.(*structs.ServiceConfigEntry)
|
||||
require.Equal(t, entry.Name, "foo")
|
||||
})
|
||||
t.Run("list both service entries", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/config/service-defaults", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.Config(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
value := obj.([]structs.ConfigEntry)
|
||||
require.Len(t, value, 2)
|
||||
require.Equal(t, value[0].(*structs.ServiceConfigEntry).Name, "bar")
|
||||
require.Equal(t, value[1].(*structs.ServiceConfigEntry).Name, "foo")
|
||||
})
|
||||
t.Run("get global proxy config", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/config/proxy-defaults/global", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.Config(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
value := obj.(structs.ConfigEntry)
|
||||
require.Equal(t, value.GetKind(), structs.ProxyDefaults)
|
||||
entry := value.(*structs.ProxyConfigEntry)
|
||||
require.Equal(t, structs.ProxyConfigGlobal, entry.Name)
|
||||
require.Contains(t, entry.Config, "foo")
|
||||
require.Equal(t, "bar", entry.Config["foo"])
|
||||
})
|
||||
t.Run("error on no arguments", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/v1/config/", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.Config(resp, req)
|
||||
require.Error(t, errors.New("Must provide either a kind or both kind and name"), err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_Delete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t, t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
||||
|
||||
// Create some config entries.
|
||||
reqs := []structs.ConfigEntryRequest{
|
||||
{
|
||||
Datacenter: "dc1",
|
||||
Entry: &structs.ServiceConfigEntry{
|
||||
Name: "foo",
|
||||
},
|
||||
},
|
||||
{
|
||||
Datacenter: "dc1",
|
||||
Entry: &structs.ServiceConfigEntry{
|
||||
Name: "bar",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, req := range reqs {
|
||||
var out struct{}
|
||||
require.NoError(a.RPC("ConfigEntry.Apply", &req, &out))
|
||||
}
|
||||
|
||||
// Delete an entry.
|
||||
{
|
||||
req, _ := http.NewRequest("DELETE", "/v1/config/service-defaults/bar", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.Config(resp, req)
|
||||
require.NoError(err)
|
||||
}
|
||||
// Get the remaining entry.
|
||||
{
|
||||
args := structs.ConfigEntryQuery{
|
||||
Kind: structs.ServiceDefaults,
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var out structs.IndexedConfigEntries
|
||||
require.NoError(a.RPC("ConfigEntry.List", &args, &out))
|
||||
require.Equal(structs.ServiceDefaults, out.Kind)
|
||||
require.Len(out.Entries, 1)
|
||||
entry := out.Entries[0].(*structs.ServiceConfigEntry)
|
||||
require.Equal(entry.Name, "foo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Apply(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t, t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
||||
|
||||
// Create some config entries.
|
||||
body := bytes.NewBuffer([]byte(`
|
||||
{
|
||||
"Kind": "service-defaults",
|
||||
"Name": "foo",
|
||||
"Protocol": "tcp"
|
||||
}`))
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/config", body)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ConfigApply(resp, req)
|
||||
require.NoError(err)
|
||||
if resp.Code != 200 {
|
||||
t.Fatalf(resp.Body.String())
|
||||
}
|
||||
|
||||
// Get the remaining entry.
|
||||
{
|
||||
args := structs.ConfigEntryQuery{
|
||||
Kind: structs.ServiceDefaults,
|
||||
Name: "foo",
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var out structs.ConfigEntryResponse
|
||||
require.NoError(a.RPC("ConfigEntry.Get", &args, &out))
|
||||
require.NotNil(out.Entry)
|
||||
entry := out.Entry.(*structs.ServiceConfigEntry)
|
||||
require.Equal(entry.Name, "foo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Apply_CAS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t, t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
||||
|
||||
// Create some config entries.
|
||||
body := bytes.NewBuffer([]byte(`
|
||||
{
|
||||
"Kind": "service-defaults",
|
||||
"Name": "foo",
|
||||
"Protocol": "tcp"
|
||||
}`))
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/config", body)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ConfigApply(resp, req)
|
||||
require.NoError(err)
|
||||
if resp.Code != 200 {
|
||||
t.Fatalf(resp.Body.String())
|
||||
}
|
||||
|
||||
// Get the entry remaining entry.
|
||||
args := structs.ConfigEntryQuery{
|
||||
Kind: structs.ServiceDefaults,
|
||||
Name: "foo",
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
|
||||
out := &structs.ConfigEntryResponse{}
|
||||
require.NoError(a.RPC("ConfigEntry.Get", &args, out))
|
||||
require.NotNil(out.Entry)
|
||||
entry := out.Entry.(*structs.ServiceConfigEntry)
|
||||
|
||||
req, _ = http.NewRequest("PUT", "/v1/config?cas=0", body)
|
||||
resp = httptest.NewRecorder()
|
||||
_, err = a.srv.ConfigApply(resp, req)
|
||||
require.Error(err)
|
||||
require.EqualValues(200, resp.Code, resp.Body.String())
|
||||
|
||||
body = bytes.NewBuffer([]byte(`
|
||||
{
|
||||
"Kind": "service-defaults",
|
||||
"Name": "foo",
|
||||
"Protocol": "udp"
|
||||
}
|
||||
`))
|
||||
|
||||
req, _ = http.NewRequest("PUT", fmt.Sprintf("/v1/config?cas=%d", entry.GetRaftIndex().ModifyIndex), body)
|
||||
resp = httptest.NewRecorder()
|
||||
_, err = a.srv.ConfigApply(resp, req)
|
||||
require.NoError(err)
|
||||
require.EqualValues(200, resp.Code, resp.Body.String())
|
||||
|
||||
// Get the entry remaining entry.
|
||||
args = structs.ConfigEntryQuery{
|
||||
Kind: structs.ServiceDefaults,
|
||||
Name: "foo",
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
|
||||
out = &structs.ConfigEntryResponse{}
|
||||
require.NoError(a.RPC("ConfigEntry.Get", &args, out))
|
||||
require.NotNil(out.Entry)
|
||||
newEntry := out.Entry.(*structs.ServiceConfigEntry)
|
||||
require.NotEqual(entry.GetRaftIndex(), newEntry.GetRaftIndex())
|
||||
}
|
||||
|
||||
func TestConfig_Apply_Decoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := NewTestAgent(t, t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
||||
|
||||
t.Run("No Kind", func(t *testing.T) {
|
||||
body := bytes.NewBuffer([]byte(
|
||||
`{
|
||||
"Name": "foo",
|
||||
"Protocol": "tcp"
|
||||
}`))
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/config", body)
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
_, err := a.srv.ConfigApply(resp, req)
|
||||
require.Error(t, err)
|
||||
badReq, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "Request decoding failed: Payload does not contain a kind/Kind key at the top level", badReq.Reason)
|
||||
})
|
||||
|
||||
t.Run("Kind Not String", func(t *testing.T) {
|
||||
body := bytes.NewBuffer([]byte(
|
||||
`{
|
||||
"Kind": 123,
|
||||
"Name": "foo",
|
||||
"Protocol": "tcp"
|
||||
}`))
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/config", body)
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
_, err := a.srv.ConfigApply(resp, req)
|
||||
require.Error(t, err)
|
||||
badReq, ok := err.(BadRequestError)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "Request decoding failed: Kind value in payload is not a string", badReq.Reason)
|
||||
})
|
||||
|
||||
t.Run("Lowercase kind", func(t *testing.T) {
|
||||
body := bytes.NewBuffer([]byte(
|
||||
`{
|
||||
"kind": "service-defaults",
|
||||
"Name": "foo",
|
||||
"Protocol": "tcp"
|
||||
}`))
|
||||
|
||||
req, _ := http.NewRequest("PUT", "/v1/config", body)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ConfigApply(resp, req)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 200, resp.Code, resp.Body.String())
|
||||
|
||||
// Get the remaining entry.
|
||||
{
|
||||
args := structs.ConfigEntryQuery{
|
||||
Kind: structs.ServiceDefaults,
|
||||
Name: "foo",
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var out structs.ConfigEntryResponse
|
||||
require.NoError(t, a.RPC("ConfigEntry.Get", &args, &out))
|
||||
require.NotNil(t, out.Entry)
|
||||
entry := out.Entry.(*structs.ServiceConfigEntry)
|
||||
require.Equal(t, entry.Name, "foo")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -52,7 +52,7 @@ func (c *ConfigEntry) Apply(args *structs.ConfigEntryRequest, reply *struct{}) e
|
|||
}
|
||||
|
||||
// Get returns a single config entry by Kind/Name.
|
||||
func (c *ConfigEntry) Get(args *structs.ConfigEntryQuery, reply *structs.IndexedConfigEntries) error {
|
||||
func (c *ConfigEntry) Get(args *structs.ConfigEntryQuery, reply *structs.ConfigEntryResponse) error {
|
||||
if done, err := c.srv.forward("ConfigEntry.Get", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
@ -87,8 +87,7 @@ func (c *ConfigEntry) Get(args *structs.ConfigEntryQuery, reply *structs.Indexed
|
|||
return nil
|
||||
}
|
||||
|
||||
reply.Kind = args.Kind
|
||||
reply.Entries = []structs.ConfigEntry{entry}
|
||||
reply.Entry = entry
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
|
@ -153,10 +153,10 @@ func TestConfigEntry_Get(t *testing.T) {
|
|||
Name: "foo",
|
||||
Datacenter: s1.config.Datacenter,
|
||||
}
|
||||
var out structs.IndexedConfigEntries
|
||||
var out structs.ConfigEntryResponse
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out))
|
||||
|
||||
serviceConf, ok := out.Entries[0].(*structs.ServiceConfigEntry)
|
||||
serviceConf, ok := out.Entry.(*structs.ServiceConfigEntry)
|
||||
require.True(ok)
|
||||
require.Equal("foo", serviceConf.Name)
|
||||
require.Equal(structs.ServiceDefaults, serviceConf.Kind)
|
||||
|
@ -218,7 +218,7 @@ operator = "read"
|
|||
Datacenter: s1.config.Datacenter,
|
||||
QueryOptions: structs.QueryOptions{Token: id},
|
||||
}
|
||||
var out structs.IndexedConfigEntries
|
||||
var out structs.ConfigEntryResponse
|
||||
err := msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out)
|
||||
if !acl.IsErrPermissionDenied(err) {
|
||||
t.Fatalf("err: %v", err)
|
||||
|
@ -228,7 +228,7 @@ operator = "read"
|
|||
args.Name = "foo"
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out))
|
||||
|
||||
serviceConf, ok := out.Entries[0].(*structs.ServiceConfigEntry)
|
||||
serviceConf, ok := out.Entry.(*structs.ServiceConfigEntry)
|
||||
require.True(ok)
|
||||
require.Equal("foo", serviceConf.Name)
|
||||
require.Equal(structs.ServiceDefaults, serviceConf.Kind)
|
||||
|
|
|
@ -446,6 +446,14 @@ func (c *FSM) applyConfigEntryOperation(buf []byte, index uint64) interface{} {
|
|||
}
|
||||
|
||||
switch req.Op {
|
||||
case structs.ConfigEntryUpsertCAS:
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "config_entry", req.Entry.GetKind()}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: "upsert"}})
|
||||
updated, err := c.state.EnsureConfigEntryCAS(index, req.Entry.GetRaftIndex().ModifyIndex, req.Entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return updated
|
||||
case structs.ConfigEntryUpsert:
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "config_entry", req.Entry.GetKind()}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: "upsert"}})
|
||||
|
|
|
@ -1394,14 +1394,6 @@ func TestFSM_ConfigEntry(t *testing.T) {
|
|||
require.NoError(err)
|
||||
entry.RaftIndex.CreateIndex = 1
|
||||
entry.RaftIndex.ModifyIndex = 1
|
||||
|
||||
proxyConf, ok := config.(*structs.ProxyConfigEntry)
|
||||
require.True(ok)
|
||||
|
||||
// Read the map[string]interface{} back out.
|
||||
value, _ := proxyConf.Config["foo"].([]uint8)
|
||||
proxyConf.Config["foo"] = structs.Uint8ToString(value)
|
||||
|
||||
require.Equal(entry, config)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -899,14 +899,18 @@ func (s *Server) bootstrapConfigEntries(entries []structs.ConfigEntry) error {
|
|||
|
||||
state := s.fsm.State()
|
||||
for _, entry := range entries {
|
||||
// avoid a round trip through Raft if we know the CAS is going to fail
|
||||
_, existing, err := state.ConfigEntry(nil, entry.GetKind(), entry.GetName())
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to determine whether the configuration for %q / %q already exists: %v", entry.GetKind(), entry.GetName(), err)
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
// ensure the ModifyIndex is set to 0 for the CAS request
|
||||
entry.GetRaftIndex().ModifyIndex = 0
|
||||
|
||||
req := structs.ConfigEntryRequest{
|
||||
Op: structs.ConfigEntryUpsert,
|
||||
Op: structs.ConfigEntryUpsertCAS,
|
||||
Datacenter: s.config.Datacenter,
|
||||
Entry: entry,
|
||||
}
|
||||
|
|
|
@ -1213,9 +1213,7 @@ func TestLeader_ConfigEntryBootstrap(t *testing.T) {
|
|||
Kind: structs.ProxyDefaults,
|
||||
Name: structs.ProxyConfigGlobal,
|
||||
Config: map[string]interface{}{
|
||||
// these are made a []uint8 and a int64 to allow the Equals test to pass
|
||||
// otherwise it will fail complaining about data types
|
||||
"foo": []uint8("bar"),
|
||||
"foo": "bar",
|
||||
"bar": int64(1),
|
||||
},
|
||||
}
|
||||
|
|
|
@ -980,7 +980,7 @@ func TestServer_Reload(t *testing.T) {
|
|||
Config: map[string]interface{}{
|
||||
// these are made a []uint8 and a int64 to allow the Equals test to pass
|
||||
// otherwise it will fail complaining about data types
|
||||
"foo": []uint8("bar"),
|
||||
"foo": "bar",
|
||||
"bar": int64(1),
|
||||
},
|
||||
}
|
||||
|
|
|
@ -69,6 +69,8 @@ func init() {
|
|||
registerEndpoint("/v1/catalog/services", []string{"GET"}, (*HTTPServer).CatalogServices)
|
||||
registerEndpoint("/v1/catalog/service/", []string{"GET"}, (*HTTPServer).CatalogServiceNodes)
|
||||
registerEndpoint("/v1/catalog/node/", []string{"GET"}, (*HTTPServer).CatalogNodeServices)
|
||||
registerEndpoint("/v1/config/", []string{"GET", "DELETE"}, (*HTTPServer).Config)
|
||||
registerEndpoint("/v1/config", []string{"PUT"}, (*HTTPServer).ConfigApply)
|
||||
registerEndpoint("/v1/connect/ca/configuration", []string{"GET", "PUT"}, (*HTTPServer).ConnectCAConfiguration)
|
||||
registerEndpoint("/v1/connect/ca/roots", []string{"GET"}, (*HTTPServer).ConnectCARoots)
|
||||
registerEndpoint("/v1/connect/intentions", []string{"GET", "POST"}, (*HTTPServer).IntentionEndpoint)
|
||||
|
|
|
@ -7,8 +7,10 @@ import (
|
|||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/go-msgpack/codec"
|
||||
"github.com/mitchellh/hashstructure"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -20,7 +22,8 @@ const (
|
|||
DefaultServiceProtocol = "tcp"
|
||||
)
|
||||
|
||||
// ConfigEntry is the
|
||||
// ConfigEntry is the interface for centralized configuration stored in Raft.
|
||||
// Currently only service-defaults and proxy-defaults are supported.
|
||||
type ConfigEntry interface {
|
||||
GetKind() string
|
||||
GetName() string
|
||||
|
@ -159,11 +162,101 @@ func (e *ProxyConfigEntry) GetRaftIndex() *RaftIndex {
|
|||
return &e.RaftIndex
|
||||
}
|
||||
|
||||
func (e *ProxyConfigEntry) MarshalBinary() (data []byte, err error) {
|
||||
// We mainly want to implement the BinaryMarshaller interface so that
|
||||
// we can fixup some msgpack types to coerce them into JSON compatible
|
||||
// values. No special encoding needs to be done - we just simply msgpack
|
||||
// encode the struct which requires a type alias to prevent recursively
|
||||
// calling this function.
|
||||
|
||||
type alias ProxyConfigEntry
|
||||
|
||||
a := alias(*e)
|
||||
|
||||
// bs will grow if needed but allocate enough to avoid reallocation in common
|
||||
// case.
|
||||
bs := make([]byte, 128)
|
||||
enc := codec.NewEncoderBytes(&bs, msgpackHandle)
|
||||
err = enc.Encode(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
func (e *ProxyConfigEntry) UnmarshalBinary(data []byte) error {
|
||||
// The goal here is to add a post-decoding operation to
|
||||
// decoding of a ProxyConfigEntry. The cleanest way I could
|
||||
// find to do so was to implement the BinaryMarshaller interface
|
||||
// and use a type alias to do the original round of decoding,
|
||||
// followed by a MapWalk of the Config to coerce everything
|
||||
// into JSON compatible types.
|
||||
type alias ProxyConfigEntry
|
||||
|
||||
var a alias
|
||||
dec := codec.NewDecoderBytes(data, msgpackHandle)
|
||||
if err := dec.Decode(&a); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*e = ProxyConfigEntry(a)
|
||||
|
||||
config, err := lib.MapWalk(e.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.Config = config
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecodeConfigEntry can be used to decode a ConfigEntry from a raw map value.
|
||||
// Currently its used in the HTTP API to decode ConfigEntry structs coming from
|
||||
// JSON. Unlike some of our custom binary encodings we don't have a preamble including
|
||||
// the kind so we will not have a concrete type to decode into. In those cases we must
|
||||
// first decode into a map[string]interface{} and then call this function to decode
|
||||
// into a concrete type.
|
||||
func DecodeConfigEntry(raw map[string]interface{}) (ConfigEntry, error) {
|
||||
var entry ConfigEntry
|
||||
|
||||
kindVal, ok := raw["Kind"]
|
||||
if !ok {
|
||||
kindVal, ok = raw["kind"]
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Payload does not contain a kind/Kind key at the top level")
|
||||
}
|
||||
|
||||
if kindStr, ok := kindVal.(string); ok {
|
||||
newEntry, err := MakeConfigEntry(kindStr, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry = newEntry
|
||||
} else {
|
||||
return nil, fmt.Errorf("Kind value in payload is not a string")
|
||||
}
|
||||
|
||||
decodeConf := &mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
Result: &entry,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(decodeConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entry, decoder.Decode(raw)
|
||||
}
|
||||
|
||||
type ConfigEntryOp string
|
||||
|
||||
const (
|
||||
ConfigEntryUpsert ConfigEntryOp = "upsert"
|
||||
ConfigEntryDelete ConfigEntryOp = "delete"
|
||||
ConfigEntryUpsert ConfigEntryOp = "upsert"
|
||||
ConfigEntryUpsertCAS ConfigEntryOp = "upsert-cas"
|
||||
ConfigEntryDelete ConfigEntryOp = "delete"
|
||||
)
|
||||
|
||||
// ConfigEntryRequest is used when creating/updating/deleting a ConfigEntry.
|
||||
|
@ -297,3 +390,53 @@ type ServiceConfigResponse struct {
|
|||
|
||||
QueryMeta
|
||||
}
|
||||
|
||||
// ConfigEntryResponse returns a single ConfigEntry
|
||||
type ConfigEntryResponse struct {
|
||||
Entry ConfigEntry
|
||||
QueryMeta
|
||||
}
|
||||
|
||||
func (c *ConfigEntryResponse) MarshalBinary() (data []byte, err error) {
|
||||
// bs will grow if needed but allocate enough to avoid reallocation in common
|
||||
// case.
|
||||
bs := make([]byte, 128)
|
||||
enc := codec.NewEncoderBytes(&bs, msgpackHandle)
|
||||
|
||||
if err := enc.Encode(c.Entry.GetKind()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := enc.Encode(c.Entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := enc.Encode(c.QueryMeta); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
func (c *ConfigEntryResponse) UnmarshalBinary(data []byte) error {
|
||||
dec := codec.NewDecoderBytes(data, msgpackHandle)
|
||||
|
||||
var kind string
|
||||
if err := dec.Decode(&kind); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
entry, err := MakeConfigEntry(kind, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dec.Decode(entry); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Entry = entry
|
||||
|
||||
if err := dec.Decode(&c.QueryMeta); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,211 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
ServiceDefaults string = "service-defaults"
|
||||
ProxyDefaults string = "proxy-defaults"
|
||||
ProxyConfigGlobal string = "global"
|
||||
)
|
||||
|
||||
type ConfigEntry interface {
|
||||
GetKind() string
|
||||
GetName() string
|
||||
}
|
||||
|
||||
type ConnectConfiguration struct {
|
||||
SidecarProxy bool
|
||||
}
|
||||
|
||||
type ServiceConfigEntry struct {
|
||||
Kind string
|
||||
Name string
|
||||
Protocol string
|
||||
Connect ConnectConfiguration
|
||||
CreateIndex uint64
|
||||
ModifyIndex uint64
|
||||
}
|
||||
|
||||
func (s *ServiceConfigEntry) GetKind() string {
|
||||
return s.Kind
|
||||
}
|
||||
|
||||
func (s *ServiceConfigEntry) GetName() string {
|
||||
return s.Name
|
||||
}
|
||||
|
||||
type ProxyConfigEntry struct {
|
||||
Kind string
|
||||
Name string
|
||||
Config map[string]interface{}
|
||||
CreateIndex uint64
|
||||
ModifyIndex uint64
|
||||
}
|
||||
|
||||
func (p *ProxyConfigEntry) GetKind() string {
|
||||
return p.Kind
|
||||
}
|
||||
|
||||
func (p *ProxyConfigEntry) GetName() string {
|
||||
return p.Name
|
||||
}
|
||||
|
||||
type rawEntryListResponse struct {
|
||||
kind string
|
||||
Entries []map[string]interface{}
|
||||
}
|
||||
|
||||
func makeConfigEntry(kind, name string) (ConfigEntry, error) {
|
||||
switch kind {
|
||||
case ServiceDefaults:
|
||||
return &ServiceConfigEntry{Name: name}, nil
|
||||
case ProxyDefaults:
|
||||
return &ProxyConfigEntry{Name: name}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid config entry kind: %s", kind)
|
||||
}
|
||||
}
|
||||
|
||||
func DecodeConfigEntry(raw map[string]interface{}) (ConfigEntry, error) {
|
||||
var entry ConfigEntry
|
||||
|
||||
kindVal, ok := raw["Kind"]
|
||||
if !ok {
|
||||
kindVal, ok = raw["kind"]
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Payload does not contain a kind/Kind key at the top level")
|
||||
}
|
||||
|
||||
if kindStr, ok := kindVal.(string); ok {
|
||||
newEntry, err := makeConfigEntry(kindStr, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry = newEntry
|
||||
} else {
|
||||
return nil, fmt.Errorf("Kind value in payload is not a string")
|
||||
}
|
||||
|
||||
decodeConf := &mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
Result: &entry,
|
||||
WeaklyTypedInput: true,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(decodeConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entry, decoder.Decode(raw)
|
||||
}
|
||||
|
||||
// Config can be used to query the Config endpoints
|
||||
type ConfigEntries struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// Config returns a handle to the Config endpoints
|
||||
func (c *Client) ConfigEntries() *ConfigEntries {
|
||||
return &ConfigEntries{c}
|
||||
}
|
||||
|
||||
func (conf *ConfigEntries) Get(kind string, name string, q *QueryOptions) (ConfigEntry, *QueryMeta, error) {
|
||||
if kind == "" || name == "" {
|
||||
return nil, nil, fmt.Errorf("Both kind and name parameters must not be empty")
|
||||
}
|
||||
|
||||
entry, err := makeConfigEntry(kind, name)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
r := conf.c.newRequest("GET", fmt.Sprintf("/v1/config/%s/%s", kind, name))
|
||||
r.setQueryOptions(q)
|
||||
rtt, resp, err := requireOK(conf.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
qm := &QueryMeta{}
|
||||
parseQueryMeta(resp, qm)
|
||||
qm.RequestTime = rtt
|
||||
|
||||
if err := decodeBody(resp, entry); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return entry, qm, nil
|
||||
}
|
||||
|
||||
func (conf *ConfigEntries) List(kind string, q *QueryOptions) ([]ConfigEntry, *QueryMeta, error) {
|
||||
if kind == "" {
|
||||
return nil, nil, fmt.Errorf("The kind parameter must not be empty")
|
||||
}
|
||||
|
||||
r := conf.c.newRequest("GET", fmt.Sprintf("/v1/config/%s", kind))
|
||||
r.setQueryOptions(q)
|
||||
rtt, resp, err := requireOK(conf.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
qm := &QueryMeta{}
|
||||
parseQueryMeta(resp, qm)
|
||||
qm.RequestTime = rtt
|
||||
|
||||
var raw []map[string]interface{}
|
||||
if err := decodeBody(resp, &raw); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var entries []ConfigEntry
|
||||
for _, rawEntry := range raw {
|
||||
entry, err := DecodeConfigEntry(rawEntry)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, qm, nil
|
||||
}
|
||||
|
||||
func (conf *ConfigEntries) Set(entry ConfigEntry, w *WriteOptions) (*WriteMeta, error) {
|
||||
r := conf.c.newRequest("PUT", "/v1/config")
|
||||
r.setWriteOptions(w)
|
||||
r.obj = entry
|
||||
rtt, resp, err := requireOK(conf.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
return wm, nil
|
||||
}
|
||||
|
||||
func (conf *ConfigEntries) Delete(kind string, name string, w *WriteOptions) (*WriteMeta, error) {
|
||||
if kind == "" || name == "" {
|
||||
return nil, fmt.Errorf("Both kind and name parameters must not be empty")
|
||||
}
|
||||
|
||||
r := conf.c.newRequest("DELETE", fmt.Sprintf("/v1/config/%s/%s", kind, name))
|
||||
r.setWriteOptions(w)
|
||||
rtt, resp, err := requireOK(conf.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
return wm, nil
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPI_ConfigEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
c, s := makeClient(t)
|
||||
defer s.Stop()
|
||||
|
||||
config_entries := c.ConfigEntries()
|
||||
|
||||
t.Run("Proxy Defaults", func(t *testing.T) {
|
||||
global_proxy := &ProxyConfigEntry{
|
||||
Kind: ProxyDefaults,
|
||||
Name: ProxyConfigGlobal,
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bar": 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
// set it
|
||||
wm, err := config_entries.Set(global_proxy, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, wm)
|
||||
require.NotEqual(t, 0, wm.RequestTime)
|
||||
|
||||
// get it
|
||||
entry, qm, err := config_entries.Get(ProxyDefaults, ProxyConfigGlobal, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, qm)
|
||||
require.NotEqual(t, 0, qm.RequestTime)
|
||||
|
||||
// verify it
|
||||
readProxy, ok := entry.(*ProxyConfigEntry)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, global_proxy.Kind, readProxy.Kind)
|
||||
require.Equal(t, global_proxy.Name, readProxy.Name)
|
||||
require.Equal(t, global_proxy.Config, readProxy.Config)
|
||||
|
||||
// update it
|
||||
global_proxy.Config["baz"] = true
|
||||
wm, err = config_entries.Set(global_proxy, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, wm)
|
||||
require.NotEqual(t, 0, wm.RequestTime)
|
||||
|
||||
// list it
|
||||
entries, qm, err := config_entries.List(ProxyDefaults, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, qm)
|
||||
require.NotEqual(t, 0, qm.RequestTime)
|
||||
require.Len(t, entries, 1)
|
||||
readProxy, ok = entries[0].(*ProxyConfigEntry)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, global_proxy.Kind, readProxy.Kind)
|
||||
require.Equal(t, global_proxy.Name, readProxy.Name)
|
||||
require.Equal(t, global_proxy.Config, readProxy.Config)
|
||||
|
||||
// delete it
|
||||
wm, err = config_entries.Delete(ProxyDefaults, ProxyConfigGlobal, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, wm)
|
||||
require.NotEqual(t, 0, wm.RequestTime)
|
||||
|
||||
entry, qm, err = config_entries.Get(ProxyDefaults, ProxyConfigGlobal, nil)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Service Defaults", func(t *testing.T) {
|
||||
service := &ServiceConfigEntry{
|
||||
Kind: ServiceDefaults,
|
||||
Name: "foo",
|
||||
Protocol: "udp",
|
||||
}
|
||||
|
||||
service2 := &ServiceConfigEntry{
|
||||
Kind: ServiceDefaults,
|
||||
Name: "bar",
|
||||
Protocol: "tcp",
|
||||
}
|
||||
|
||||
// set it
|
||||
wm, err := config_entries.Set(service, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, wm)
|
||||
require.NotEqual(t, 0, wm.RequestTime)
|
||||
|
||||
// also set the second one
|
||||
wm, err = config_entries.Set(service2, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, wm)
|
||||
require.NotEqual(t, 0, wm.RequestTime)
|
||||
|
||||
// get it
|
||||
entry, qm, err := config_entries.Get(ServiceDefaults, "foo", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, qm)
|
||||
require.NotEqual(t, 0, qm.RequestTime)
|
||||
|
||||
// verify it
|
||||
readService, ok := entry.(*ServiceConfigEntry)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, service.Kind, readService.Kind)
|
||||
require.Equal(t, service.Name, readService.Name)
|
||||
require.Equal(t, service.Protocol, readService.Protocol)
|
||||
|
||||
// update it
|
||||
service.Protocol = "tcp"
|
||||
wm, err = config_entries.Set(service, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, wm)
|
||||
require.NotEqual(t, 0, wm.RequestTime)
|
||||
|
||||
// list them
|
||||
entries, qm, err := config_entries.List(ServiceDefaults, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, qm)
|
||||
require.NotEqual(t, 0, qm.RequestTime)
|
||||
require.Len(t, entries, 2)
|
||||
|
||||
for _, entry = range entries {
|
||||
switch entry.GetName() {
|
||||
case "foo":
|
||||
// this also verfies that the update value was persisted and
|
||||
// the updated values are seen
|
||||
readService, ok = entry.(*ServiceConfigEntry)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, service.Kind, readService.Kind)
|
||||
require.Equal(t, service.Name, readService.Name)
|
||||
require.Equal(t, service.Protocol, readService.Protocol)
|
||||
case "bar":
|
||||
readService, ok = entry.(*ServiceConfigEntry)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, service2.Kind, readService.Kind)
|
||||
require.Equal(t, service2.Name, readService.Name)
|
||||
require.Equal(t, service2.Protocol, readService.Protocol)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// delete it
|
||||
wm, err = config_entries.Delete(ServiceDefaults, "foo", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, wm)
|
||||
require.NotEqual(t, 0, wm.RequestTime)
|
||||
|
||||
// verify deletion
|
||||
entry, qm, err = config_entries.Get(ServiceDefaults, "foo", nil)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,168 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/mitchellh/copystructure"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/mitchellh/reflectwalk"
|
||||
)
|
||||
|
||||
// MapWalk will traverse through the supplied input which should be a
|
||||
// map[string]interface{} (or something compatible that we can coerce
|
||||
// to a map[string]interface{}) and from it create a new map[string]interface{}
|
||||
// with all internal values coerced to JSON compatible types. i.e. a []uint8
|
||||
// can be converted (in most cases) to a string so it will not be base64 encoded
|
||||
// when output in JSON
|
||||
func MapWalk(input interface{}) (map[string]interface{}, error) {
|
||||
mapCopyRaw, err := copystructure.Copy(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapCopy, ok := mapCopyRaw.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal error: input to MapWalk is not a map[string]interface{}")
|
||||
}
|
||||
|
||||
if err := reflectwalk.Walk(mapCopy, &mapWalker{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mapCopy, nil
|
||||
}
|
||||
|
||||
var typMapIfaceIface = reflect.TypeOf(map[interface{}]interface{}{})
|
||||
|
||||
// mapWalker implements interfaces for the reflectwalk package
|
||||
// (github.com/mitchellh/reflectwalk) that can be used to automatically
|
||||
// make a JSON compatible map safe for JSON usage. This is currently
|
||||
// targeted at the map[string]interface{}
|
||||
//
|
||||
// Most of the implementation here is just keeping track of where we are
|
||||
// in the reflectwalk process, so that we can replace values. The key logic
|
||||
// is in Slice() and SliceElem().
|
||||
//
|
||||
// In particular we're looking to replace two cases the msgpack codec causes:
|
||||
//
|
||||
// 1.) String values get turned into byte slices. JSON will base64-encode
|
||||
// this and we don't want that, so we convert them back to strings.
|
||||
//
|
||||
// 2.) Nested maps turn into map[interface{}]interface{}. JSON cannot
|
||||
// encode this, so we need to turn it back into map[string]interface{}.
|
||||
//
|
||||
type mapWalker struct {
|
||||
lastValue reflect.Value // lastValue of map, required for replacement
|
||||
loc, lastLoc reflectwalk.Location // locations
|
||||
cs []reflect.Value // container stack
|
||||
csKey []reflect.Value // container keys (maps) stack
|
||||
csData interface{} // current container data
|
||||
sliceIndex []int // slice index stack (one for each slice in cs)
|
||||
}
|
||||
|
||||
func (w *mapWalker) Enter(loc reflectwalk.Location) error {
|
||||
w.lastLoc = w.loc
|
||||
w.loc = loc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mapWalker) Exit(loc reflectwalk.Location) error {
|
||||
w.loc = reflectwalk.None
|
||||
w.lastLoc = reflectwalk.None
|
||||
|
||||
switch loc {
|
||||
case reflectwalk.Map:
|
||||
w.cs = w.cs[:len(w.cs)-1]
|
||||
case reflectwalk.MapValue:
|
||||
w.csKey = w.csKey[:len(w.csKey)-1]
|
||||
case reflectwalk.Slice:
|
||||
// Split any values that need to be split
|
||||
w.cs = w.cs[:len(w.cs)-1]
|
||||
case reflectwalk.SliceElem:
|
||||
w.csKey = w.csKey[:len(w.csKey)-1]
|
||||
w.sliceIndex = w.sliceIndex[:len(w.sliceIndex)-1]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mapWalker) Map(m reflect.Value) error {
|
||||
w.cs = append(w.cs, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mapWalker) MapElem(m, k, v reflect.Value) error {
|
||||
w.csData = k
|
||||
w.csKey = append(w.csKey, k)
|
||||
|
||||
w.lastValue = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mapWalker) Slice(v reflect.Value) error {
|
||||
// If we find a []byte slice, it is an HCL-string converted to []byte.
|
||||
// Convert it back to a Go string and replace the value so that JSON
|
||||
// doesn't base64-encode it.
|
||||
if v.Type() == reflect.TypeOf([]byte{}) {
|
||||
resultVal := reflect.ValueOf(string(v.Interface().([]byte)))
|
||||
switch w.lastLoc {
|
||||
case reflectwalk.MapKey:
|
||||
m := w.cs[len(w.cs)-1]
|
||||
|
||||
// Delete the old value
|
||||
var zero reflect.Value
|
||||
m.SetMapIndex(w.csData.(reflect.Value), zero)
|
||||
|
||||
// Set the new key with the existing value
|
||||
m.SetMapIndex(resultVal, w.lastValue)
|
||||
|
||||
// Set the key to be the new key
|
||||
w.csData = resultVal
|
||||
case reflectwalk.MapValue:
|
||||
// If we're in a map, then the only way to set a map value is
|
||||
// to set it directly.
|
||||
m := w.cs[len(w.cs)-1]
|
||||
mk := w.csData.(reflect.Value)
|
||||
m.SetMapIndex(mk, resultVal)
|
||||
case reflectwalk.Slice:
|
||||
s := w.cs[len(w.cs)-1]
|
||||
s.Index(w.sliceIndex[len(w.sliceIndex)-1]).Set(resultVal)
|
||||
default:
|
||||
return fmt.Errorf("cannot convert []byte")
|
||||
}
|
||||
}
|
||||
|
||||
w.cs = append(w.cs, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mapWalker) SliceElem(i int, elem reflect.Value) error {
|
||||
w.csKey = append(w.csKey, reflect.ValueOf(i))
|
||||
w.sliceIndex = append(w.sliceIndex, i)
|
||||
|
||||
// We're looking specifically for map[interface{}]interface{}, but the
|
||||
// values in a slice are wrapped up in interface{} so we need to unwrap
|
||||
// that first. Therefore, we do three checks: 1.) is it valid? so we
|
||||
// don't panic, 2.) is it an interface{}? so we can unwrap it and 3.)
|
||||
// after unwrapping the interface do we have the map we expect?
|
||||
if !elem.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if elem.Kind() != reflect.Interface {
|
||||
return nil
|
||||
}
|
||||
|
||||
if inner := elem.Elem(); inner.Type() == typMapIfaceIface {
|
||||
// map[interface{}]interface{}, attempt to weakly decode into string keys
|
||||
var target map[string]interface{}
|
||||
if err := mapstructure.WeakDecode(inner.Interface(), &target); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elem.Set(reflect.ValueOf(target))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMapWalk(t *testing.T) {
|
||||
t.Parallel()
|
||||
type tcase struct {
|
||||
input interface{}
|
||||
expected interface{}
|
||||
unexpected bool
|
||||
err bool
|
||||
}
|
||||
|
||||
cases := map[string]tcase{
|
||||
// basically tests that []uint8 gets turned into
|
||||
// a string
|
||||
"simple": tcase{
|
||||
input: map[string]interface{}{
|
||||
"foo": []uint8("bar"),
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
// ensures that it was actually converted and not
|
||||
// just the require.Equal masking the underlying
|
||||
// type differences
|
||||
"uint8 conversion": tcase{
|
||||
input: map[string]interface{}{
|
||||
"foo": []uint8("bar"),
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": []uint8("bar"),
|
||||
},
|
||||
unexpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tcase := range cases {
|
||||
name := name
|
||||
tcase := tcase
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual, err := MapWalk(tcase.input)
|
||||
if tcase.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
if tcase.unexpected {
|
||||
require.NotEqual(t, tcase.expected, actual)
|
||||
} else {
|
||||
require.Equal(t, tcase.expected, actual)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue