diff --git a/command/agent/agent.go b/command/agent/agent.go index 9bcc8be6c6..c741240ac7 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -261,6 +261,9 @@ func (a *Agent) consulConfig() *consul.Config { if a.config.ACLDownPolicy != "" { base.ACLDownPolicy = a.config.ACLDownPolicy } + if a.config.SessionTTLMinRaw != "" { + base.SessionTTLMin = a.config.SessionTTLMin + } // Format the build string revision := a.config.Revision diff --git a/command/agent/config.go b/command/agent/config.go index bd849d5fd8..48b347c332 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -363,6 +363,10 @@ type Config struct { // UnixSockets is a map of socket configuration data UnixSockets UnixSocketConfig `mapstructure:"unix_sockets"` + + // Minimum Session TTL + SessionTTLMin time.Duration `mapstructure:"-"` + SessionTTLMinRaw string `mapstructure:"session_ttl_min"` } // UnixSocketPermissions contains information about a unix socket, and @@ -609,6 +613,14 @@ func DecodeConfig(r io.Reader) (*Config, error) { result.DNSRecursors = append(result.DNSRecursors, result.DNSRecursor) } + if raw := result.SessionTTLMinRaw; raw != "" { + dur, err := time.ParseDuration(raw) + if err != nil { + return nil, fmt.Errorf("Session TTL Min invalid: %v", err) + } + result.SessionTTLMin = dur + } + return &result, nil } @@ -970,7 +982,10 @@ func MergeConfig(a, b *Config) *Config { if b.AtlasJoin { result.AtlasJoin = true } - + if b.SessionTTLMinRaw != "" { + result.SessionTTLMin = b.SessionTTLMin + result.SessionTTLMinRaw = b.SessionTTLMinRaw + } if len(b.HTTPAPIResponseHeaders) != 0 { if result.HTTPAPIResponseHeaders == nil { result.HTTPAPIResponseHeaders = make(map[string]string) diff --git a/command/agent/config_test.go b/command/agent/config_test.go index b6cdac2055..945e2152d4 100644 --- a/command/agent/config_test.go +++ b/command/agent/config_test.go @@ -653,6 +653,17 @@ func TestDecodeConfig(t *testing.T) { if !config.AtlasJoin { t.Fatalf("bad: %#v", config) } + + // SessionTTLMin + input = `{"session_ttl_min": "5s"}` + config, err = DecodeConfig(bytes.NewReader([]byte(input))) + if err != nil { + t.Fatalf("err: %s", err) + } + + if config.SessionTTLMin != 5*time.Second { + t.Fatalf("bad: %s %#v", config.SessionTTLMin.String(), config) + } } func TestDecodeConfig_invalidKeys(t *testing.T) { @@ -1120,6 +1131,8 @@ func TestMergeConfig(t *testing.T) { AtlasToken: "123456789", AtlasACLToken: "abcdefgh", AtlasJoin: true, + SessionTTLMinRaw: "1000s", + SessionTTLMin: 1000 * time.Second, } c := MergeConfig(a, b) diff --git a/command/agent/http_test.go b/command/agent/http_test.go index a6ec471d5f..1faa4ac992 100644 --- a/command/agent/http_test.go +++ b/command/agent/http_test.go @@ -521,7 +521,11 @@ func getIndex(t *testing.T, resp *httptest.ResponseRecorder) uint64 { } func httpTest(t *testing.T, f func(srv *HTTPServer)) { - dir, srv := makeHTTPServer(t) + httpTestWithConfig(t, f, nil) +} + +func httpTestWithConfig(t *testing.T, f func(srv *HTTPServer), cb func(c *Config)) { + dir, srv := makeHTTPServerWithConfig(t, cb) defer os.RemoveAll(dir) defer srv.Shutdown() defer srv.agent.Shutdown() diff --git a/command/agent/session_endpoint.go b/command/agent/session_endpoint.go index 25d48634cf..0750de7d27 100644 --- a/command/agent/session_endpoint.go +++ b/command/agent/session_endpoint.go @@ -53,21 +53,6 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request) resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err))) return nil, nil } - - if args.Session.TTL != "" { - ttl, err := time.ParseDuration(args.Session.TTL) - if err != nil { - resp.WriteHeader(400) - resp.Write([]byte(fmt.Sprintf("Request TTL decode failed: %v", err))) - return nil, nil - } - - if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) { - resp.WriteHeader(400) - resp.Write([]byte(fmt.Sprintf("Request TTL '%s', must be between [%v-%v]", args.Session.TTL, structs.SessionTTLMin, structs.SessionTTLMax))) - return nil, nil - } - } } // Create the session, get the ID diff --git a/command/agent/session_endpoint_test.go b/command/agent/session_endpoint_test.go index edfa074402..960fdeec18 100644 --- a/command/agent/session_endpoint_test.go +++ b/command/agent/session_endpoint_test.go @@ -3,12 +3,13 @@ package agent import ( "bytes" "encoding/json" - "github.com/hashicorp/consul/consul" - "github.com/hashicorp/consul/consul/structs" "net/http" "net/http/httptest" "testing" "time" + + "github.com/hashicorp/consul/consul" + "github.com/hashicorp/consul/consul/structs" ) func TestSessionCreate(t *testing.T) { @@ -215,9 +216,20 @@ func TestSessionDestroy(t *testing.T) { } func TestSessionTTL(t *testing.T) { - httpTest(t, func(srv *HTTPServer) { - TTL := "10s" // use the minimum legal ttl - ttl := 10 * time.Second + // use the minimum legal ttl + testSessionTTL(t, 10*time.Second, nil) +} + +func TestSessionTTLConfig(t *testing.T) { + testSessionTTL(t, 1*time.Second, func(c *Config) { + c.SessionTTLMinRaw = "1s" + c.SessionTTLMin = 1 * time.Second + }) +} + +func testSessionTTL(t *testing.T, ttl time.Duration, cb func(c *Config)) { + httpTestWithConfig(t, func(srv *HTTPServer) { + TTL := ttl.String() id := makeTestSessionTTL(t, srv, TTL) @@ -252,85 +264,7 @@ func TestSessionTTL(t *testing.T) { if len(respObj) != 0 { t.Fatalf("session '%s' should have been destroyed", id) } - }) -} - -func TestSessionBadTTL(t *testing.T) { - httpTest(t, func(srv *HTTPServer) { - badTTL := "10z" - - // Create Session with illegal TTL - body := bytes.NewBuffer(nil) - enc := json.NewEncoder(body) - raw := map[string]interface{}{ - "TTL": badTTL, - } - enc.Encode(raw) - - req, err := http.NewRequest("PUT", "/v1/session/create", body) - if err != nil { - t.Fatalf("err: %v", err) - } - resp := httptest.NewRecorder() - obj, err := srv.SessionCreate(resp, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("illegal TTL '%s' allowed", badTTL) - } - if resp.Code != 400 { - t.Fatalf("Bad response code, should be 400") - } - - // less than SessionTTLMin - body = bytes.NewBuffer(nil) - enc = json.NewEncoder(body) - raw = map[string]interface{}{ - "TTL": "5s", - } - enc.Encode(raw) - - req, err = http.NewRequest("PUT", "/v1/session/create", body) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = httptest.NewRecorder() - obj, err = srv.SessionCreate(resp, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("illegal TTL '%s' allowed", badTTL) - } - if resp.Code != 400 { - t.Fatalf("Bad response code, should be 400") - } - - // more than SessionTTLMax - body = bytes.NewBuffer(nil) - enc = json.NewEncoder(body) - raw = map[string]interface{}{ - "TTL": "4000s", - } - enc.Encode(raw) - - req, err = http.NewRequest("PUT", "/v1/session/create", body) - if err != nil { - t.Fatalf("err: %v", err) - } - resp = httptest.NewRecorder() - obj, err = srv.SessionCreate(resp, req) - if err != nil { - t.Fatalf("err: %v", err) - } - if obj != nil { - t.Fatalf("illegal TTL '%s' allowed", badTTL) - } - if resp.Code != 400 { - t.Fatalf("Bad response code, should be 400") - } - }) + }, cb) } func TestSessionTTLRenew(t *testing.T) { diff --git a/consul/config.go b/consul/config.go index 4e834b178c..0b45c97c65 100644 --- a/consul/config.go +++ b/consul/config.go @@ -181,6 +181,9 @@ type Config struct { // to reduce overhead. It is unlikely a user would ever need to tune this. TombstoneTTLGranularity time.Duration + // Minimum Session TTL + SessionTTLMin time.Duration + // ServerUp callback can be used to trigger a notification that // a Consul server is now up and known about. ServerUp func() @@ -241,6 +244,7 @@ func DefaultConfig() *Config { ACLDownPolicy: "extend-cache", TombstoneTTL: 15 * time.Minute, TombstoneTTLGranularity: 30 * time.Second, + SessionTTLMin: 10 * time.Second, } // Increase our reap interval to 3 days instead of 24h. diff --git a/consul/session_endpoint.go b/consul/session_endpoint.go index af4c395048..2d6fe05ada 100644 --- a/consul/session_endpoint.go +++ b/consul/session_endpoint.go @@ -47,9 +47,9 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err) } - if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) { + if ttl != 0 && (ttl < s.srv.config.SessionTTLMin || ttl > structs.SessionTTLMax) { return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]", - ttl, structs.SessionTTLMin, structs.SessionTTLMax) + ttl, s.srv.config.SessionTTLMin, structs.SessionTTLMax) } } diff --git a/consul/session_endpoint_test.go b/consul/session_endpoint_test.go index fa721ada93..267d6f1946 100644 --- a/consul/session_endpoint_test.go +++ b/consul/session_endpoint_test.go @@ -483,3 +483,56 @@ func TestSessionEndpoint_NodeSessions(t *testing.T) { } } } + +func TestSessionEndpoint_Apply_BadTTL(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + client := rpcClient(t, s1) + defer client.Close() + + testutil.WaitForLeader(t, client.Call, "dc1") + + arg := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + Node: "foo", + Name: "my-session", + }, + } + + // Session with illegal TTL + arg.Session.TTL = "10z" + + var out string + err := client.Call("Session.Apply", &arg, &out) + if err == nil { + t.Fatal("expected error") + } + if err.Error() != "Session TTL '10z' invalid: time: unknown unit z in duration 10z" { + t.Fatalf("incorrect error message: %s", err.Error()) + } + + // less than SessionTTLMin + arg.Session.TTL = "5s" + + err = client.Call("Session.Apply", &arg, &out) + if err == nil { + t.Fatal("expected error") + } + if err.Error() != "Invalid Session TTL '5000000000', must be between [10s=1h0m0s]" { + t.Fatalf("incorrect error message: %s", err.Error()) + } + + // more than SessionTTLMax + arg.Session.TTL = "4000s" + + err = client.Call("Session.Apply", &arg, &out) + if err == nil { + t.Fatal("expected error") + } + if err.Error() != "Invalid Session TTL '4000000000000', must be between [10s=1h0m0s]" { + t.Fatalf("incorrect error message: %s", err.Error()) + } +} diff --git a/consul/state_store.go b/consul/state_store.go index d8783bf189..074261a1b9 100644 --- a/consul/state_store.go +++ b/consul/state_store.go @@ -1625,18 +1625,6 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior) } - if session.TTL != "" { - ttl, err := time.ParseDuration(session.TTL) - if err != nil { - return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err) - } - - if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) { - return fmt.Errorf("Invalid Session TTL '%s', must be between [%v-%v]", - session.TTL, structs.SessionTTLMin, structs.SessionTTLMax) - } - } - // Assign the create index session.CreateIndex = index diff --git a/consul/structs/structs.go b/consul/structs/structs.go index 9ebaefb33f..b95711bc44 100644 --- a/consul/structs/structs.go +++ b/consul/structs/structs.go @@ -391,7 +391,6 @@ const ( ) const ( - SessionTTLMin = 10 * time.Second SessionTTLMax = 3600 * time.Second SessionTTLMultiplier = 2 )