Merge pull request #821 from fraenkel/sessionttlmin

Support SesionTTLMin configuration
This commit is contained in:
Ryan Uber 2015-03-27 11:24:22 -07:00
commit cf668aba06
11 changed files with 114 additions and 116 deletions

View File

@ -261,6 +261,9 @@ func (a *Agent) consulConfig() *consul.Config {
if a.config.ACLDownPolicy != "" { if a.config.ACLDownPolicy != "" {
base.ACLDownPolicy = a.config.ACLDownPolicy base.ACLDownPolicy = a.config.ACLDownPolicy
} }
if a.config.SessionTTLMinRaw != "" {
base.SessionTTLMin = a.config.SessionTTLMin
}
// Format the build string // Format the build string
revision := a.config.Revision revision := a.config.Revision

View File

@ -363,6 +363,10 @@ type Config struct {
// UnixSockets is a map of socket configuration data // UnixSockets is a map of socket configuration data
UnixSockets UnixSocketConfig `mapstructure:"unix_sockets"` UnixSockets UnixSocketConfig `mapstructure:"unix_sockets"`
// Minimum Session TTL
SessionTTLMin time.Duration `mapstructure:"-"`
SessionTTLMinRaw string `mapstructure:"session_ttl_min"`
} }
// UnixSocketPermissions contains information about a unix socket, and // UnixSocketPermissions contains information about a unix socket, and
@ -609,6 +613,14 @@ func DecodeConfig(r io.Reader) (*Config, error) {
result.DNSRecursors = append(result.DNSRecursors, result.DNSRecursor) result.DNSRecursors = append(result.DNSRecursors, result.DNSRecursor)
} }
if raw := result.SessionTTLMinRaw; raw != "" {
dur, err := time.ParseDuration(raw)
if err != nil {
return nil, fmt.Errorf("Session TTL Min invalid: %v", err)
}
result.SessionTTLMin = dur
}
return &result, nil return &result, nil
} }
@ -970,7 +982,10 @@ func MergeConfig(a, b *Config) *Config {
if b.AtlasJoin { if b.AtlasJoin {
result.AtlasJoin = true result.AtlasJoin = true
} }
if b.SessionTTLMinRaw != "" {
result.SessionTTLMin = b.SessionTTLMin
result.SessionTTLMinRaw = b.SessionTTLMinRaw
}
if len(b.HTTPAPIResponseHeaders) != 0 { if len(b.HTTPAPIResponseHeaders) != 0 {
if result.HTTPAPIResponseHeaders == nil { if result.HTTPAPIResponseHeaders == nil {
result.HTTPAPIResponseHeaders = make(map[string]string) result.HTTPAPIResponseHeaders = make(map[string]string)

View File

@ -653,6 +653,17 @@ func TestDecodeConfig(t *testing.T) {
if !config.AtlasJoin { if !config.AtlasJoin {
t.Fatalf("bad: %#v", config) t.Fatalf("bad: %#v", config)
} }
// SessionTTLMin
input = `{"session_ttl_min": "5s"}`
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
if err != nil {
t.Fatalf("err: %s", err)
}
if config.SessionTTLMin != 5*time.Second {
t.Fatalf("bad: %s %#v", config.SessionTTLMin.String(), config)
}
} }
func TestDecodeConfig_invalidKeys(t *testing.T) { func TestDecodeConfig_invalidKeys(t *testing.T) {
@ -1120,6 +1131,8 @@ func TestMergeConfig(t *testing.T) {
AtlasToken: "123456789", AtlasToken: "123456789",
AtlasACLToken: "abcdefgh", AtlasACLToken: "abcdefgh",
AtlasJoin: true, AtlasJoin: true,
SessionTTLMinRaw: "1000s",
SessionTTLMin: 1000 * time.Second,
} }
c := MergeConfig(a, b) c := MergeConfig(a, b)

View File

@ -521,7 +521,11 @@ func getIndex(t *testing.T, resp *httptest.ResponseRecorder) uint64 {
} }
func httpTest(t *testing.T, f func(srv *HTTPServer)) { func httpTest(t *testing.T, f func(srv *HTTPServer)) {
dir, srv := makeHTTPServer(t) httpTestWithConfig(t, f, nil)
}
func httpTestWithConfig(t *testing.T, f func(srv *HTTPServer), cb func(c *Config)) {
dir, srv := makeHTTPServerWithConfig(t, cb)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
defer srv.Shutdown() defer srv.Shutdown()
defer srv.agent.Shutdown() defer srv.agent.Shutdown()

View File

@ -53,21 +53,6 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err))) resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err)))
return nil, nil return nil, nil
} }
if args.Session.TTL != "" {
ttl, err := time.ParseDuration(args.Session.TTL)
if err != nil {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL decode failed: %v", err)))
return nil, nil
}
if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL '%s', must be between [%v-%v]", args.Session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)))
return nil, nil
}
}
} }
// Create the session, get the ID // Create the session, get the ID

View File

@ -3,12 +3,13 @@ package agent
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/hashicorp/consul/consul"
"github.com/hashicorp/consul/consul/structs"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/consul"
"github.com/hashicorp/consul/consul/structs"
) )
func TestSessionCreate(t *testing.T) { func TestSessionCreate(t *testing.T) {
@ -215,9 +216,20 @@ func TestSessionDestroy(t *testing.T) {
} }
func TestSessionTTL(t *testing.T) { func TestSessionTTL(t *testing.T) {
httpTest(t, func(srv *HTTPServer) { // use the minimum legal ttl
TTL := "10s" // use the minimum legal ttl testSessionTTL(t, 10*time.Second, nil)
ttl := 10 * time.Second }
func TestSessionTTLConfig(t *testing.T) {
testSessionTTL(t, 1*time.Second, func(c *Config) {
c.SessionTTLMinRaw = "1s"
c.SessionTTLMin = 1 * time.Second
})
}
func testSessionTTL(t *testing.T, ttl time.Duration, cb func(c *Config)) {
httpTestWithConfig(t, func(srv *HTTPServer) {
TTL := ttl.String()
id := makeTestSessionTTL(t, srv, TTL) id := makeTestSessionTTL(t, srv, TTL)
@ -252,85 +264,7 @@ func TestSessionTTL(t *testing.T) {
if len(respObj) != 0 { if len(respObj) != 0 {
t.Fatalf("session '%s' should have been destroyed", id) t.Fatalf("session '%s' should have been destroyed", id)
} }
}) }, cb)
}
func TestSessionBadTTL(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
badTTL := "10z"
// Create Session with illegal TTL
body := bytes.NewBuffer(nil)
enc := json.NewEncoder(body)
raw := map[string]interface{}{
"TTL": badTTL,
}
enc.Encode(raw)
req, err := http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp := httptest.NewRecorder()
obj, err := srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
// less than SessionTTLMin
body = bytes.NewBuffer(nil)
enc = json.NewEncoder(body)
raw = map[string]interface{}{
"TTL": "5s",
}
enc.Encode(raw)
req, err = http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = httptest.NewRecorder()
obj, err = srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
// more than SessionTTLMax
body = bytes.NewBuffer(nil)
enc = json.NewEncoder(body)
raw = map[string]interface{}{
"TTL": "4000s",
}
enc.Encode(raw)
req, err = http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = httptest.NewRecorder()
obj, err = srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
})
} }
func TestSessionTTLRenew(t *testing.T) { func TestSessionTTLRenew(t *testing.T) {

View File

@ -181,6 +181,9 @@ type Config struct {
// to reduce overhead. It is unlikely a user would ever need to tune this. // to reduce overhead. It is unlikely a user would ever need to tune this.
TombstoneTTLGranularity time.Duration TombstoneTTLGranularity time.Duration
// Minimum Session TTL
SessionTTLMin time.Duration
// ServerUp callback can be used to trigger a notification that // ServerUp callback can be used to trigger a notification that
// a Consul server is now up and known about. // a Consul server is now up and known about.
ServerUp func() ServerUp func()
@ -241,6 +244,7 @@ func DefaultConfig() *Config {
ACLDownPolicy: "extend-cache", ACLDownPolicy: "extend-cache",
TombstoneTTL: 15 * time.Minute, TombstoneTTL: 15 * time.Minute,
TombstoneTTLGranularity: 30 * time.Second, TombstoneTTLGranularity: 30 * time.Second,
SessionTTLMin: 10 * time.Second,
} }
// Increase our reap interval to 3 days instead of 24h. // Increase our reap interval to 3 days instead of 24h.

View File

@ -47,9 +47,9 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err) return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err)
} }
if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) { if ttl != 0 && (ttl < s.srv.config.SessionTTLMin || ttl > structs.SessionTTLMax) {
return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]", return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]",
ttl, structs.SessionTTLMin, structs.SessionTTLMax) ttl, s.srv.config.SessionTTLMin, structs.SessionTTLMax)
} }
} }

View File

@ -483,3 +483,56 @@ func TestSessionEndpoint_NodeSessions(t *testing.T) {
} }
} }
} }
func TestSessionEndpoint_Apply_BadTTL(t *testing.T) {
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
client := rpcClient(t, s1)
defer client.Close()
testutil.WaitForLeader(t, client.Call, "dc1")
arg := structs.SessionRequest{
Datacenter: "dc1",
Op: structs.SessionCreate,
Session: structs.Session{
Node: "foo",
Name: "my-session",
},
}
// Session with illegal TTL
arg.Session.TTL = "10z"
var out string
err := client.Call("Session.Apply", &arg, &out)
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "Session TTL '10z' invalid: time: unknown unit z in duration 10z" {
t.Fatalf("incorrect error message: %s", err.Error())
}
// less than SessionTTLMin
arg.Session.TTL = "5s"
err = client.Call("Session.Apply", &arg, &out)
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "Invalid Session TTL '5000000000', must be between [10s=1h0m0s]" {
t.Fatalf("incorrect error message: %s", err.Error())
}
// more than SessionTTLMax
arg.Session.TTL = "4000s"
err = client.Call("Session.Apply", &arg, &out)
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "Invalid Session TTL '4000000000000', must be between [10s=1h0m0s]" {
t.Fatalf("incorrect error message: %s", err.Error())
}
}

View File

@ -1625,18 +1625,6 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error
return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior) return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior)
} }
if session.TTL != "" {
ttl, err := time.ParseDuration(session.TTL)
if err != nil {
return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err)
}
if ttl != 0 && (ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax) {
return fmt.Errorf("Invalid Session TTL '%s', must be between [%v-%v]",
session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)
}
}
// Assign the create index // Assign the create index
session.CreateIndex = index session.CreateIndex = index

View File

@ -391,7 +391,6 @@ const (
) )
const ( const (
SessionTTLMin = 10 * time.Second
SessionTTLMax = 3600 * time.Second SessionTTLMax = 3600 * time.Second
SessionTTLMultiplier = 2 SessionTTLMultiplier = 2
) )