Support SesionTTLMin configuration

- Allow setting SessionTTLMin
- Validate on the Server
This commit is contained in:
Michael Fraenkel 2015-03-26 22:30:04 -07:00
parent 6e261622f8
commit 8c26836783
11 changed files with 114 additions and 116 deletions

View File

@ -261,6 +261,9 @@ func (a *Agent) consulConfig() *consul.Config {
if a.config.ACLDownPolicy != "" {
base.ACLDownPolicy = a.config.ACLDownPolicy
}
if a.config.SessionTTLMinRaw != "" {
base.SessionTTLMin = a.config.SessionTTLMin
}
// Format the build string
revision := a.config.Revision

View File

@ -363,6 +363,10 @@ type Config struct {
// UnixSockets is a map of socket configuration data
UnixSockets UnixSocketConfig `mapstructure:"unix_sockets"`
// Minimum Session TTL
SessionTTLMin time.Duration `mapstructure:"-"`
SessionTTLMinRaw string `mapstructure:"session_ttl_min"`
}
// UnixSocketPermissions contains information about a unix socket, and
@ -609,6 +613,14 @@ func DecodeConfig(r io.Reader) (*Config, error) {
result.DNSRecursors = append(result.DNSRecursors, result.DNSRecursor)
}
if raw := result.SessionTTLMinRaw; raw != "" {
dur, err := time.ParseDuration(raw)
if err != nil {
return nil, fmt.Errorf("Session TTL Min invalid: %v", err)
}
result.SessionTTLMin = dur
}
return &result, nil
}
@ -970,7 +982,10 @@ func MergeConfig(a, b *Config) *Config {
if b.AtlasJoin {
result.AtlasJoin = true
}
if b.SessionTTLMinRaw != "" {
result.SessionTTLMin = b.SessionTTLMin
result.SessionTTLMinRaw = b.SessionTTLMinRaw
}
if len(b.HTTPAPIResponseHeaders) != 0 {
if result.HTTPAPIResponseHeaders == nil {
result.HTTPAPIResponseHeaders = make(map[string]string)

View File

@ -653,6 +653,17 @@ func TestDecodeConfig(t *testing.T) {
if !config.AtlasJoin {
t.Fatalf("bad: %#v", config)
}
// SessionTTLMin
input = `{"session_ttl_min": "5s"}`
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
if err != nil {
t.Fatalf("err: %s", err)
}
if config.SessionTTLMin != 5*time.Second {
t.Fatalf("bad: %s %#v", config.SessionTTLMin.String(), config)
}
}
func TestDecodeConfig_invalidKeys(t *testing.T) {
@ -1120,6 +1131,8 @@ func TestMergeConfig(t *testing.T) {
AtlasToken: "123456789",
AtlasACLToken: "abcdefgh",
AtlasJoin: true,
SessionTTLMinRaw: "1000s",
SessionTTLMin: 1000 * time.Second,
}
c := MergeConfig(a, b)

View File

@ -521,7 +521,11 @@ func getIndex(t *testing.T, resp *httptest.ResponseRecorder) uint64 {
}
func httpTest(t *testing.T, f func(srv *HTTPServer)) {
dir, srv := makeHTTPServer(t)
httpTestWithConfig(t, f, nil)
}
func httpTestWithConfig(t *testing.T, f func(srv *HTTPServer), cb func(c *Config)) {
dir, srv := makeHTTPServerWithConfig(t, cb)
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer srv.agent.Shutdown()

View File

@ -53,21 +53,6 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err)))
return nil, nil
}
if args.Session.TTL != "" {
ttl, err := time.ParseDuration(args.Session.TTL)
if err != nil {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL decode failed: %v", err)))
return nil, nil
}
if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL '%s', must be between [%v-%v]", args.Session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)))
return nil, nil
}
}
}
// Create the session, get the ID

View File

@ -3,12 +3,13 @@ package agent
import (
"bytes"
"encoding/json"
"github.com/hashicorp/consul/consul"
"github.com/hashicorp/consul/consul/structs"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/hashicorp/consul/consul"
"github.com/hashicorp/consul/consul/structs"
)
func TestSessionCreate(t *testing.T) {
@ -215,9 +216,20 @@ func TestSessionDestroy(t *testing.T) {
}
func TestSessionTTL(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
TTL := "10s" // use the minimum legal ttl
ttl := 10 * time.Second
// use the minimum legal ttl
testSessionTTL(t, 10*time.Second, nil)
}
func TestSessionTTLConfig(t *testing.T) {
testSessionTTL(t, 1*time.Second, func(c *Config) {
c.SessionTTLMinRaw = "1s"
c.SessionTTLMin = 1 * time.Second
})
}
func testSessionTTL(t *testing.T, ttl time.Duration, cb func(c *Config)) {
httpTestWithConfig(t, func(srv *HTTPServer) {
TTL := ttl.String()
id := makeTestSessionTTL(t, srv, TTL)
@ -252,85 +264,7 @@ func TestSessionTTL(t *testing.T) {
if len(respObj) != 0 {
t.Fatalf("session '%s' should have been destroyed", id)
}
})
}
func TestSessionBadTTL(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
badTTL := "10z"
// Create Session with illegal TTL
body := bytes.NewBuffer(nil)
enc := json.NewEncoder(body)
raw := map[string]interface{}{
"TTL": badTTL,
}
enc.Encode(raw)
req, err := http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp := httptest.NewRecorder()
obj, err := srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
// less than SessionTTLMin
body = bytes.NewBuffer(nil)
enc = json.NewEncoder(body)
raw = map[string]interface{}{
"TTL": "5s",
}
enc.Encode(raw)
req, err = http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = httptest.NewRecorder()
obj, err = srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
// more than SessionTTLMax
body = bytes.NewBuffer(nil)
enc = json.NewEncoder(body)
raw = map[string]interface{}{
"TTL": "4000s",
}
enc.Encode(raw)
req, err = http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = httptest.NewRecorder()
obj, err = srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
})
}, cb)
}
func TestSessionTTLRenew(t *testing.T) {

View File

@ -181,6 +181,9 @@ type Config struct {
// to reduce overhead. It is unlikely a user would ever need to tune this.
TombstoneTTLGranularity time.Duration
// Minimum Session TTL
SessionTTLMin time.Duration
// ServerUp callback can be used to trigger a notification that
// a Consul server is now up and known about.
ServerUp func()
@ -241,6 +244,7 @@ func DefaultConfig() *Config {
ACLDownPolicy: "extend-cache",
TombstoneTTL: 15 * time.Minute,
TombstoneTTLGranularity: 30 * time.Second,
SessionTTLMin: 10 * time.Second,
}
// Increase our reap interval to 3 days instead of 24h.

View File

@ -47,9 +47,9 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err)
}
if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
if ttl != 0 && (ttl < s.srv.config.SessionTTLMin || ttl > structs.SessionTTLMax) {
return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]",
ttl, structs.SessionTTLMin, structs.SessionTTLMax)
ttl, s.srv.config.SessionTTLMin, structs.SessionTTLMax)
}
}

View File

@ -483,3 +483,56 @@ func TestSessionEndpoint_NodeSessions(t *testing.T) {
}
}
}
func TestSessionEndpoint_Apply_BadTTL(t *testing.T) {
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
client := rpcClient(t, s1)
defer client.Close()
testutil.WaitForLeader(t, client.Call, "dc1")
arg := structs.SessionRequest{
Datacenter: "dc1",
Op: structs.SessionCreate,
Session: structs.Session{
Node: "foo",
Name: "my-session",
},
}
// Session with illegal TTL
arg.Session.TTL = "10z"
var out string
err := client.Call("Session.Apply", &arg, &out)
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "Session TTL '10z' invalid: time: unknown unit z in duration 10z" {
t.Fatalf("incorrect error message: %s", err.Error())
}
// less than SessionTTLMin
arg.Session.TTL = "5s"
err = client.Call("Session.Apply", &arg, &out)
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "Invalid Session TTL '5000000000', must be between [10s=1h0m0s]" {
t.Fatalf("incorrect error message: %s", err.Error())
}
// more than SessionTTLMax
arg.Session.TTL = "4000s"
err = client.Call("Session.Apply", &arg, &out)
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "Invalid Session TTL '4000000000000', must be between [10s=1h0m0s]" {
t.Fatalf("incorrect error message: %s", err.Error())
}
}

View File

@ -1625,18 +1625,6 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error
return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior)
}
if session.TTL != "" {
ttl, err := time.ParseDuration(session.TTL)
if err != nil {
return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err)
}
if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
return fmt.Errorf("Invalid Session TTL '%s', must be between [%v-%v]",
session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)
}
}
// Assign the create index
session.CreateIndex = index

View File

@ -391,7 +391,6 @@ const (
)
const (
SessionTTLMin = 10 * time.Second
SessionTTLMax = 3600 * time.Second
SessionTTLMultiplier = 2
)