mirror of https://github.com/status-im/consul.git
Support SesionTTLMin configuration
- Allow setting SessionTTLMin - Validate on the Server
This commit is contained in:
parent
6e261622f8
commit
8c26836783
|
@ -261,6 +261,9 @@ func (a *Agent) consulConfig() *consul.Config {
|
|||
if a.config.ACLDownPolicy != "" {
|
||||
base.ACLDownPolicy = a.config.ACLDownPolicy
|
||||
}
|
||||
if a.config.SessionTTLMinRaw != "" {
|
||||
base.SessionTTLMin = a.config.SessionTTLMin
|
||||
}
|
||||
|
||||
// Format the build string
|
||||
revision := a.config.Revision
|
||||
|
|
|
@ -363,6 +363,10 @@ type Config struct {
|
|||
|
||||
// UnixSockets is a map of socket configuration data
|
||||
UnixSockets UnixSocketConfig `mapstructure:"unix_sockets"`
|
||||
|
||||
// Minimum Session TTL
|
||||
SessionTTLMin time.Duration `mapstructure:"-"`
|
||||
SessionTTLMinRaw string `mapstructure:"session_ttl_min"`
|
||||
}
|
||||
|
||||
// UnixSocketPermissions contains information about a unix socket, and
|
||||
|
@ -609,6 +613,14 @@ func DecodeConfig(r io.Reader) (*Config, error) {
|
|||
result.DNSRecursors = append(result.DNSRecursors, result.DNSRecursor)
|
||||
}
|
||||
|
||||
if raw := result.SessionTTLMinRaw; raw != "" {
|
||||
dur, err := time.ParseDuration(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Session TTL Min invalid: %v", err)
|
||||
}
|
||||
result.SessionTTLMin = dur
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
|
@ -970,7 +982,10 @@ func MergeConfig(a, b *Config) *Config {
|
|||
if b.AtlasJoin {
|
||||
result.AtlasJoin = true
|
||||
}
|
||||
|
||||
if b.SessionTTLMinRaw != "" {
|
||||
result.SessionTTLMin = b.SessionTTLMin
|
||||
result.SessionTTLMinRaw = b.SessionTTLMinRaw
|
||||
}
|
||||
if len(b.HTTPAPIResponseHeaders) != 0 {
|
||||
if result.HTTPAPIResponseHeaders == nil {
|
||||
result.HTTPAPIResponseHeaders = make(map[string]string)
|
||||
|
|
|
@ -653,6 +653,17 @@ func TestDecodeConfig(t *testing.T) {
|
|||
if !config.AtlasJoin {
|
||||
t.Fatalf("bad: %#v", config)
|
||||
}
|
||||
|
||||
// SessionTTLMin
|
||||
input = `{"session_ttl_min": "5s"}`
|
||||
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if config.SessionTTLMin != 5*time.Second {
|
||||
t.Fatalf("bad: %s %#v", config.SessionTTLMin.String(), config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeConfig_invalidKeys(t *testing.T) {
|
||||
|
@ -1120,6 +1131,8 @@ func TestMergeConfig(t *testing.T) {
|
|||
AtlasToken: "123456789",
|
||||
AtlasACLToken: "abcdefgh",
|
||||
AtlasJoin: true,
|
||||
SessionTTLMinRaw: "1000s",
|
||||
SessionTTLMin: 1000 * time.Second,
|
||||
}
|
||||
|
||||
c := MergeConfig(a, b)
|
||||
|
|
|
@ -521,7 +521,11 @@ func getIndex(t *testing.T, resp *httptest.ResponseRecorder) uint64 {
|
|||
}
|
||||
|
||||
func httpTest(t *testing.T, f func(srv *HTTPServer)) {
|
||||
dir, srv := makeHTTPServer(t)
|
||||
httpTestWithConfig(t, f, nil)
|
||||
}
|
||||
|
||||
func httpTestWithConfig(t *testing.T, f func(srv *HTTPServer), cb func(c *Config)) {
|
||||
dir, srv := makeHTTPServerWithConfig(t, cb)
|
||||
defer os.RemoveAll(dir)
|
||||
defer srv.Shutdown()
|
||||
defer srv.agent.Shutdown()
|
||||
|
|
|
@ -53,21 +53,6 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
|
|||
resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if args.Session.TTL != "" {
|
||||
ttl, err := time.ParseDuration(args.Session.TTL)
|
||||
if err != nil {
|
||||
resp.WriteHeader(400)
|
||||
resp.Write([]byte(fmt.Sprintf("Request TTL decode failed: %v", err)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
|
||||
resp.WriteHeader(400)
|
||||
resp.Write([]byte(fmt.Sprintf("Request TTL '%s', must be between [%v-%v]", args.Session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)))
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the session, get the ID
|
||||
|
|
|
@ -3,12 +3,13 @@ package agent
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/hashicorp/consul/consul"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
)
|
||||
|
||||
func TestSessionCreate(t *testing.T) {
|
||||
|
@ -215,9 +216,20 @@ func TestSessionDestroy(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionTTL(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
TTL := "10s" // use the minimum legal ttl
|
||||
ttl := 10 * time.Second
|
||||
// use the minimum legal ttl
|
||||
testSessionTTL(t, 10*time.Second, nil)
|
||||
}
|
||||
|
||||
func TestSessionTTLConfig(t *testing.T) {
|
||||
testSessionTTL(t, 1*time.Second, func(c *Config) {
|
||||
c.SessionTTLMinRaw = "1s"
|
||||
c.SessionTTLMin = 1 * time.Second
|
||||
})
|
||||
}
|
||||
|
||||
func testSessionTTL(t *testing.T, ttl time.Duration, cb func(c *Config)) {
|
||||
httpTestWithConfig(t, func(srv *HTTPServer) {
|
||||
TTL := ttl.String()
|
||||
|
||||
id := makeTestSessionTTL(t, srv, TTL)
|
||||
|
||||
|
@ -252,85 +264,7 @@ func TestSessionTTL(t *testing.T) {
|
|||
if len(respObj) != 0 {
|
||||
t.Fatalf("session '%s' should have been destroyed", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionBadTTL(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
badTTL := "10z"
|
||||
|
||||
// Create Session with illegal TTL
|
||||
body := bytes.NewBuffer(nil)
|
||||
enc := json.NewEncoder(body)
|
||||
raw := map[string]interface{}{
|
||||
"TTL": badTTL,
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err := http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if obj != nil {
|
||||
t.Fatalf("illegal TTL '%s' allowed", badTTL)
|
||||
}
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("Bad response code, should be 400")
|
||||
}
|
||||
|
||||
// less than SessionTTLMin
|
||||
body = bytes.NewBuffer(nil)
|
||||
enc = json.NewEncoder(body)
|
||||
raw = map[string]interface{}{
|
||||
"TTL": "5s",
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err = http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if obj != nil {
|
||||
t.Fatalf("illegal TTL '%s' allowed", badTTL)
|
||||
}
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("Bad response code, should be 400")
|
||||
}
|
||||
|
||||
// more than SessionTTLMax
|
||||
body = bytes.NewBuffer(nil)
|
||||
enc = json.NewEncoder(body)
|
||||
raw = map[string]interface{}{
|
||||
"TTL": "4000s",
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err = http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if obj != nil {
|
||||
t.Fatalf("illegal TTL '%s' allowed", badTTL)
|
||||
}
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("Bad response code, should be 400")
|
||||
}
|
||||
})
|
||||
}, cb)
|
||||
}
|
||||
|
||||
func TestSessionTTLRenew(t *testing.T) {
|
||||
|
|
|
@ -181,6 +181,9 @@ type Config struct {
|
|||
// to reduce overhead. It is unlikely a user would ever need to tune this.
|
||||
TombstoneTTLGranularity time.Duration
|
||||
|
||||
// Minimum Session TTL
|
||||
SessionTTLMin time.Duration
|
||||
|
||||
// ServerUp callback can be used to trigger a notification that
|
||||
// a Consul server is now up and known about.
|
||||
ServerUp func()
|
||||
|
@ -241,6 +244,7 @@ func DefaultConfig() *Config {
|
|||
ACLDownPolicy: "extend-cache",
|
||||
TombstoneTTL: 15 * time.Minute,
|
||||
TombstoneTTLGranularity: 30 * time.Second,
|
||||
SessionTTLMin: 10 * time.Second,
|
||||
}
|
||||
|
||||
// Increase our reap interval to 3 days instead of 24h.
|
||||
|
|
|
@ -47,9 +47,9 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
|
|||
return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err)
|
||||
}
|
||||
|
||||
if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
|
||||
if ttl != 0 && (ttl < s.srv.config.SessionTTLMin || ttl > structs.SessionTTLMax) {
|
||||
return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]",
|
||||
ttl, structs.SessionTTLMin, structs.SessionTTLMax)
|
||||
ttl, s.srv.config.SessionTTLMin, structs.SessionTTLMax)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -483,3 +483,56 @@ func TestSessionEndpoint_NodeSessions(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionEndpoint_Apply_BadTTL(t *testing.T) {
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
client := rpcClient(t, s1)
|
||||
defer client.Close()
|
||||
|
||||
testutil.WaitForLeader(t, client.Call, "dc1")
|
||||
|
||||
arg := structs.SessionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.SessionCreate,
|
||||
Session: structs.Session{
|
||||
Node: "foo",
|
||||
Name: "my-session",
|
||||
},
|
||||
}
|
||||
|
||||
// Session with illegal TTL
|
||||
arg.Session.TTL = "10z"
|
||||
|
||||
var out string
|
||||
err := client.Call("Session.Apply", &arg, &out)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if err.Error() != "Session TTL '10z' invalid: time: unknown unit z in duration 10z" {
|
||||
t.Fatalf("incorrect error message: %s", err.Error())
|
||||
}
|
||||
|
||||
// less than SessionTTLMin
|
||||
arg.Session.TTL = "5s"
|
||||
|
||||
err = client.Call("Session.Apply", &arg, &out)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if err.Error() != "Invalid Session TTL '5000000000', must be between [10s=1h0m0s]" {
|
||||
t.Fatalf("incorrect error message: %s", err.Error())
|
||||
}
|
||||
|
||||
// more than SessionTTLMax
|
||||
arg.Session.TTL = "4000s"
|
||||
|
||||
err = client.Call("Session.Apply", &arg, &out)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if err.Error() != "Invalid Session TTL '4000000000000', must be between [10s=1h0m0s]" {
|
||||
t.Fatalf("incorrect error message: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1625,18 +1625,6 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error
|
|||
return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior)
|
||||
}
|
||||
|
||||
if session.TTL != "" {
|
||||
ttl, err := time.ParseDuration(session.TTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err)
|
||||
}
|
||||
|
||||
if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
|
||||
return fmt.Errorf("Invalid Session TTL '%s', must be between [%v-%v]",
|
||||
session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign the create index
|
||||
session.CreateIndex = index
|
||||
|
||||
|
|
|
@ -391,7 +391,6 @@ const (
|
|||
)
|
||||
|
||||
const (
|
||||
SessionTTLMin = 10 * time.Second
|
||||
SessionTTLMax = 3600 * time.Second
|
||||
SessionTTLMultiplier = 2
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue