mirror of
https://github.com/status-im/consul.git
synced 2025-01-25 21:19:12 +00:00
commit
29afa881f4
@ -188,6 +188,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) {
|
||||
|
||||
s.mux.HandleFunc("/v1/session/create", s.wrap(s.SessionCreate))
|
||||
s.mux.HandleFunc("/v1/session/destroy/", s.wrap(s.SessionDestroy))
|
||||
s.mux.HandleFunc("/v1/session/renew/", s.wrap(s.SessionRenew))
|
||||
s.mux.HandleFunc("/v1/session/info/", s.wrap(s.SessionGet))
|
||||
s.mux.HandleFunc("/v1/session/node/", s.wrap(s.SessionsForNode))
|
||||
s.mux.HandleFunc("/v1/session/list", s.wrap(s.SessionList))
|
||||
|
@ -40,6 +40,7 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
|
||||
Checks: []string{consul.SerfCheckID},
|
||||
LockDelay: 15 * time.Second,
|
||||
Behavior: structs.SessionKeysRelease,
|
||||
TTL: "",
|
||||
},
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
@ -51,6 +52,21 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
|
||||
resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if args.Session.TTL != "" {
|
||||
ttl, err := time.ParseDuration(args.Session.TTL)
|
||||
if err != nil {
|
||||
resp.WriteHeader(400)
|
||||
resp.Write([]byte(fmt.Sprintf("Request TTL decode failed: %v", err)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax {
|
||||
resp.WriteHeader(400)
|
||||
resp.Write([]byte(fmt.Sprintf("Request TTL '%s', must be between [%v-%v]", args.Session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)))
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the session, get the ID
|
||||
@ -130,6 +146,39 @@ func (s *HTTPServer) SessionDestroy(resp http.ResponseWriter, req *http.Request)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// SessionRenew is used to renew the TTL on an existing TTL session
|
||||
func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Mandate a PUT request
|
||||
if req.Method != "PUT" {
|
||||
resp.WriteHeader(405)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
args := structs.SessionSpecificRequest{}
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Pull out the session id
|
||||
args.Session = strings.TrimPrefix(req.URL.Path, "/v1/session/renew/")
|
||||
if args.Session == "" {
|
||||
resp.WriteHeader(400)
|
||||
resp.Write([]byte("Missing session"))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var out structs.IndexedSessions
|
||||
if err := s.agent.RPC("Session.Renew", &args, &out); err != nil {
|
||||
return nil, err
|
||||
} else if out.Sessions == nil {
|
||||
resp.WriteHeader(404)
|
||||
resp.Write([]byte(fmt.Sprintf("Session id '%s' not found", args.Session)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return out.Sessions, nil
|
||||
}
|
||||
|
||||
// SessionGet is used to get info for a particular session
|
||||
func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
args := structs.SessionSpecificRequest{}
|
||||
|
@ -176,6 +176,28 @@ func makeTestSessionDelete(t *testing.T, srv *HTTPServer) string {
|
||||
return sessResp.ID
|
||||
}
|
||||
|
||||
func makeTestSessionTTL(t *testing.T, srv *HTTPServer, ttl string) string {
|
||||
// Create Session with TTL
|
||||
body := bytes.NewBuffer(nil)
|
||||
enc := json.NewEncoder(body)
|
||||
raw := map[string]interface{}{
|
||||
"TTL": ttl,
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err := http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
sessResp := obj.(sessionCreateResponse)
|
||||
return sessResp.ID
|
||||
}
|
||||
|
||||
func TestSessionDestroy(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
id := makeTestSession(t, srv)
|
||||
@ -192,6 +214,206 @@ func TestSessionDestroy(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionTTL(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
TTL := "10s" // use the minimum legal ttl
|
||||
ttl := 10 * time.Second
|
||||
|
||||
id := makeTestSessionTTL(t, srv, TTL)
|
||||
|
||||
req, err := http.NewRequest("GET",
|
||||
"/v1/session/info/"+id, nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.SessionGet(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
respObj, ok := obj.(structs.Sessions)
|
||||
if !ok {
|
||||
t.Fatalf("should work")
|
||||
}
|
||||
if len(respObj) != 1 {
|
||||
t.Fatalf("bad: %v", respObj)
|
||||
}
|
||||
if respObj[0].TTL != TTL {
|
||||
t.Fatalf("Incorrect TTL: %s", respObj[0].TTL)
|
||||
}
|
||||
|
||||
time.Sleep(ttl*structs.SessionTTLMultiplier + ttl)
|
||||
|
||||
req, err = http.NewRequest("GET",
|
||||
"/v1/session/info/"+id, nil)
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionGet(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
respObj, ok = obj.(structs.Sessions)
|
||||
if len(respObj) != 0 {
|
||||
t.Fatalf("session '%s' should have been destroyed", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionBadTTL(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
badTTL := "10z"
|
||||
|
||||
// Create Session with illegal TTL
|
||||
body := bytes.NewBuffer(nil)
|
||||
enc := json.NewEncoder(body)
|
||||
raw := map[string]interface{}{
|
||||
"TTL": badTTL,
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err := http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if obj != nil {
|
||||
t.Fatalf("illegal TTL '%s' allowed", badTTL)
|
||||
}
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("Bad response code, should be 400")
|
||||
}
|
||||
|
||||
// less than SessionTTLMin
|
||||
body = bytes.NewBuffer(nil)
|
||||
enc = json.NewEncoder(body)
|
||||
raw = map[string]interface{}{
|
||||
"TTL": "5s",
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err = http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if obj != nil {
|
||||
t.Fatalf("illegal TTL '%s' allowed", badTTL)
|
||||
}
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("Bad response code, should be 400")
|
||||
}
|
||||
|
||||
// more than SessionTTLMax
|
||||
body = bytes.NewBuffer(nil)
|
||||
enc = json.NewEncoder(body)
|
||||
raw = map[string]interface{}{
|
||||
"TTL": "4000s",
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err = http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if obj != nil {
|
||||
t.Fatalf("illegal TTL '%s' allowed", badTTL)
|
||||
}
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("Bad response code, should be 400")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionTTLRenew(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
TTL := "10s" // use the minimum legal ttl
|
||||
ttl := 10 * time.Second
|
||||
|
||||
id := makeTestSessionTTL(t, srv, TTL)
|
||||
|
||||
req, err := http.NewRequest("GET",
|
||||
"/v1/session/info/"+id, nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.SessionGet(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
respObj, ok := obj.(structs.Sessions)
|
||||
if !ok {
|
||||
t.Fatalf("should work")
|
||||
}
|
||||
if len(respObj) != 1 {
|
||||
t.Fatalf("bad: %v", respObj)
|
||||
}
|
||||
if respObj[0].TTL != TTL {
|
||||
t.Fatalf("Incorrect TTL: %s", respObj[0].TTL)
|
||||
}
|
||||
|
||||
// Sleep to consume some time before renew
|
||||
time.Sleep(ttl * (structs.SessionTTLMultiplier / 2))
|
||||
|
||||
req, err = http.NewRequest("PUT",
|
||||
"/v1/session/renew/"+id, nil)
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionRenew(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
respObj, ok = obj.(structs.Sessions)
|
||||
if !ok {
|
||||
t.Fatalf("should work")
|
||||
}
|
||||
if len(respObj) != 1 {
|
||||
t.Fatalf("bad: %v", respObj)
|
||||
}
|
||||
|
||||
// Sleep for ttl * TTL Multiplier
|
||||
time.Sleep(ttl * structs.SessionTTLMultiplier)
|
||||
|
||||
req, err = http.NewRequest("GET",
|
||||
"/v1/session/info/"+id, nil)
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionGet(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
respObj, ok = obj.(structs.Sessions)
|
||||
if !ok {
|
||||
t.Fatalf("session '%s' should have renewed", id)
|
||||
}
|
||||
if len(respObj) != 1 {
|
||||
t.Fatalf("session '%s' should have renewed", id)
|
||||
}
|
||||
|
||||
// now wait for timeout and expect session to get destroyed
|
||||
time.Sleep(ttl * structs.SessionTTLMultiplier)
|
||||
|
||||
req, err = http.NewRequest("GET",
|
||||
"/v1/session/info/"+id, nil)
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionGet(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
respObj, ok = obj.(structs.Sessions)
|
||||
if !ok {
|
||||
t.Fatalf("session '%s' should have destroyed", id)
|
||||
}
|
||||
if len(respObj) != 0 {
|
||||
t.Fatalf("session '%s' should have destroyed", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionGet(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
id := makeTestSession(t, srv)
|
||||
|
@ -61,6 +61,13 @@ func (s *Server) leaderLoop(stopCh chan struct{}) {
|
||||
s.logger.Printf("[ERR] consul: ACL initialization failed: %v", err)
|
||||
}
|
||||
|
||||
// Setup Session Timers if we are the leader and need to
|
||||
if err := s.initializeSessionTimers(); err != nil {
|
||||
s.logger.Printf("[ERR] consul: Session Timers initialization failed: %v", err)
|
||||
}
|
||||
// clear the session timers if we are no longer leader and exit the leaderLoop
|
||||
defer s.clearAllSessionTimers()
|
||||
|
||||
// Reconcile channel is only used once initial reconcile
|
||||
// has succeeded
|
||||
var reconcileCh chan serf.Member
|
||||
|
@ -370,6 +370,9 @@ func TestLeader_LeftLeader(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if leader == nil {
|
||||
t.Fatalf("Should have a leader")
|
||||
}
|
||||
leader.Leave()
|
||||
leader.Shutdown()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
@ -128,6 +128,12 @@ type Server struct {
|
||||
// which SHOULD only consist of Consul servers
|
||||
serfWAN *serf.Serf
|
||||
|
||||
// sessionTimers track the expiration time of each Session that has
|
||||
// a TTL. On expiration, a SessionDestroy event will occur, and
|
||||
// destroy the session via standard session destory processing
|
||||
sessionTimers map[string]*time.Timer
|
||||
sessionTimersLock sync.RWMutex
|
||||
|
||||
shutdown bool
|
||||
shutdownCh chan struct{}
|
||||
shutdownLock sync.Mutex
|
||||
|
@ -36,6 +36,16 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
|
||||
default:
|
||||
return fmt.Errorf("Invalid Behavior setting '%s'", args.Session.Behavior)
|
||||
}
|
||||
if args.Session.TTL != "" {
|
||||
ttl, err := time.ParseDuration(args.Session.TTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err)
|
||||
}
|
||||
|
||||
if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax {
|
||||
return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]", ttl, structs.SessionTTLMin, structs.SessionTTLMax)
|
||||
}
|
||||
}
|
||||
|
||||
// If this is a create, we must generate the Session ID. This must
|
||||
// be done prior to appending to the raft log, because the ID is not
|
||||
@ -63,6 +73,13 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
|
||||
s.srv.logger.Printf("[ERR] consul.session: Apply failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if args.Op == structs.SessionCreate && args.Session.TTL != "" {
|
||||
s.srv.resetSessionTimer(args.Session.ID, nil)
|
||||
} else if args.Op == structs.SessionDestroy && args.Session.TTL != "" {
|
||||
s.srv.clearSessionTimer(args.Session.ID)
|
||||
}
|
||||
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
@ -133,3 +150,24 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest,
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// Renew is used to renew the TTL on a single session
|
||||
func (s *Session) Renew(args *structs.SessionSpecificRequest,
|
||||
reply *structs.IndexedSessions) error {
|
||||
if done, err := s.srv.forward("Session.Renew", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the local state
|
||||
state := s.srv.fsm.State()
|
||||
// Get the session, from local state
|
||||
index, session, err := state.SessionGet(args.Session)
|
||||
reply.Index = index
|
||||
if session != nil {
|
||||
reply.Sessions = structs.Sessions{session}
|
||||
// reset the session TTL timer
|
||||
err = s.srv.resetSessionTimer(args.Session, session)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSessionEndpoint_Apply(t *testing.T) {
|
||||
@ -223,6 +224,161 @@ func TestSessionEndpoint_List(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionEndpoint_Renew(t *testing.T) {
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
client := rpcClient(t, s1)
|
||||
defer client.Close()
|
||||
|
||||
testutil.WaitForLeader(t, client.Call, "dc1")
|
||||
TTL := "10s" // the minimum allowed ttl
|
||||
ttl := 10 * time.Second
|
||||
|
||||
s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
|
||||
ids := []string{}
|
||||
for i := 0; i < 5; i++ {
|
||||
arg := structs.SessionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.SessionCreate,
|
||||
Session: structs.Session{
|
||||
Node: "foo",
|
||||
TTL: TTL,
|
||||
},
|
||||
}
|
||||
var out string
|
||||
if err := client.Call("Session.Apply", &arg, &out); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ids = append(ids, out)
|
||||
}
|
||||
|
||||
getR := structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
|
||||
var sessions structs.IndexedSessions
|
||||
if err := client.Call("Session.List", &getR, &sessions); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if sessions.Index == 0 {
|
||||
t.Fatalf("Bad: %v", sessions)
|
||||
}
|
||||
if len(sessions.Sessions) != 5 {
|
||||
t.Fatalf("Bad: %v", sessions.Sessions)
|
||||
}
|
||||
for i := 0; i < len(sessions.Sessions); i++ {
|
||||
s := sessions.Sessions[i]
|
||||
if !strContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.TTL != TTL {
|
||||
t.Fatalf("bad session TTL: %s %v", s.TTL, s)
|
||||
}
|
||||
t.Logf("Created session '%s'", s.ID)
|
||||
}
|
||||
|
||||
// Sleep for time shorter than internal destroy ttl
|
||||
time.Sleep(ttl * structs.SessionTTLMultiplier / 2)
|
||||
|
||||
// renew 3 out of 5 sessions
|
||||
for i := 0; i < 3; i++ {
|
||||
renewR := structs.SessionSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
Session: ids[i],
|
||||
}
|
||||
var session structs.IndexedSessions
|
||||
if err := client.Call("Session.Renew", &renewR, &session); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if session.Index == 0 {
|
||||
t.Fatalf("Bad: %v", session)
|
||||
}
|
||||
if len(session.Sessions) != 1 {
|
||||
t.Fatalf("Bad: %v", session.Sessions)
|
||||
}
|
||||
|
||||
s := session.Sessions[0]
|
||||
if !strContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
|
||||
t.Logf("Renewed session '%s'", s.ID)
|
||||
}
|
||||
|
||||
// now sleep for 2/3 the internal destroy TTL time for renewed sessions
|
||||
// which is more than the internal destroy TTL time for the non-renewed sessions
|
||||
time.Sleep((ttl * structs.SessionTTLMultiplier) * 2.0 / 3.0)
|
||||
|
||||
var sessionsL1 structs.IndexedSessions
|
||||
if err := client.Call("Session.List", &getR, &sessionsL1); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if sessionsL1.Index == 0 {
|
||||
t.Fatalf("Bad: %v", sessionsL1)
|
||||
}
|
||||
|
||||
t.Logf("Expect 2 sessions to be destroyed")
|
||||
|
||||
for i := 0; i < len(sessionsL1.Sessions); i++ {
|
||||
s := sessionsL1.Sessions[i]
|
||||
if !strContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.TTL != TTL {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if i > 2 {
|
||||
t.Errorf("session '%s' should be destroyed", s.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(sessionsL1.Sessions) > 3 {
|
||||
t.Fatalf("Bad: %v", sessionsL1.Sessions)
|
||||
}
|
||||
|
||||
// now sleep again for ttl*2 - no sessions should still be alive
|
||||
time.Sleep(ttl * structs.SessionTTLMultiplier)
|
||||
|
||||
var sessionsL2 structs.IndexedSessions
|
||||
if err := client.Call("Session.List", &getR, &sessionsL2); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if sessionsL2.Index == 0 {
|
||||
t.Fatalf("Bad: %v", sessionsL2)
|
||||
}
|
||||
if len(sessionsL2.Sessions) != 0 {
|
||||
for i := 0; i < len(sessionsL2.Sessions); i++ {
|
||||
s := sessionsL2.Sessions[i]
|
||||
if !strContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.TTL != TTL {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
t.Errorf("session '%s' should be destroyed", s.ID)
|
||||
}
|
||||
|
||||
t.Fatalf("Bad: %v", sessionsL2.Sessions)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionEndpoint_NodeSessions(t *testing.T) {
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
|
106
consul/session_ttl.go
Normal file
106
consul/session_ttl.go
Normal file
@ -0,0 +1,106 @@
|
||||
package consul
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Server) initializeSessionTimers() error {
|
||||
s.sessionTimersLock.Lock()
|
||||
s.sessionTimers = make(map[string]*time.Timer)
|
||||
s.sessionTimersLock.Unlock()
|
||||
|
||||
// walk the TTL index and resetSessionTimer for each non-zero TTL
|
||||
state := s.fsm.State()
|
||||
_, sessions, err := state.SessionListTTL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, session := range sessions {
|
||||
err := s.resetSessionTimer(session.ID, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// invalidate the session when timer expires, called by AfterFunc
|
||||
func (s *Server) invalidateSession(id string) {
|
||||
args := structs.SessionRequest{
|
||||
Datacenter: s.config.Datacenter,
|
||||
Op: structs.SessionDestroy,
|
||||
}
|
||||
args.Session.ID = id
|
||||
|
||||
// Apply the update to destroy the session
|
||||
_, err := s.raftApply(structs.SessionRequestType, args)
|
||||
if err != nil {
|
||||
s.logger.Printf("[ERR] consul.session: Apply failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
|
||||
if session == nil {
|
||||
var err error
|
||||
|
||||
// find the session
|
||||
state := s.fsm.State()
|
||||
_, session, err = state.SessionGet(id)
|
||||
if err != nil || session == nil {
|
||||
return fmt.Errorf("Could not find session for '%s'\n", id)
|
||||
}
|
||||
}
|
||||
|
||||
if session.TTL == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
ttl, err := time.ParseDuration(session.TTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err)
|
||||
}
|
||||
if ttl == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.sessionTimersLock.Lock()
|
||||
if s.sessionTimers == nil {
|
||||
s.sessionTimers = make(map[string]*time.Timer)
|
||||
}
|
||||
defer s.sessionTimersLock.Unlock()
|
||||
if t := s.sessionTimers[id]; t != nil {
|
||||
// TBD may modify the session's active TTL based on load here
|
||||
t.Reset(ttl * structs.SessionTTLMultiplier)
|
||||
} else {
|
||||
s.sessionTimers[session.ID] = time.AfterFunc(ttl*structs.SessionTTLMultiplier, func() {
|
||||
s.invalidateSession(session.ID)
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) clearSessionTimer(id string) error {
|
||||
s.sessionTimersLock.Lock()
|
||||
defer s.sessionTimersLock.Unlock()
|
||||
if s.sessionTimers[id] != nil {
|
||||
// stop the session timer and delete from the map
|
||||
s.sessionTimers[id].Stop()
|
||||
delete(s.sessionTimers, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) clearAllSessionTimers() error {
|
||||
s.sessionTimersLock.Lock()
|
||||
defer s.sessionTimersLock.Unlock()
|
||||
|
||||
// stop all timers and clear out the map
|
||||
for _, t := range s.sessionTimers {
|
||||
t.Stop()
|
||||
}
|
||||
s.sessionTimers = nil
|
||||
return nil
|
||||
}
|
168
consul/session_ttl_test.go
Normal file
168
consul/session_ttl_test.go
Normal file
@ -0,0 +1,168 @@
|
||||
package consul
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
)
|
||||
|
||||
func TestServer_sessionTTL(t *testing.T) {
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
|
||||
dir2, s2 := testServerDCBootstrap(t, "dc1", false)
|
||||
defer os.RemoveAll(dir2)
|
||||
defer s2.Shutdown()
|
||||
|
||||
dir3, s3 := testServerDCBootstrap(t, "dc1", false)
|
||||
defer os.RemoveAll(dir3)
|
||||
defer s3.Shutdown()
|
||||
servers := []*Server{s1, s2, s3}
|
||||
|
||||
// Try to join
|
||||
addr := fmt.Sprintf("127.0.0.1:%d",
|
||||
s1.config.SerfLANConfig.MemberlistConfig.BindPort)
|
||||
if _, err := s2.JoinLAN([]string{addr}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if _, err := s3.JoinLAN([]string{addr}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
for _, s := range servers {
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
peers, _ := s.raftPeers.Peers()
|
||||
return len(peers) == 3, nil
|
||||
}, func(err error) {
|
||||
t.Fatalf("should have 3 peers")
|
||||
})
|
||||
}
|
||||
|
||||
// Find the leader
|
||||
var leader *Server
|
||||
for _, s := range servers {
|
||||
// check that s.sessionTimers is empty
|
||||
if len(s.sessionTimers) != 0 {
|
||||
t.Fatalf("should have no sessionTimers")
|
||||
}
|
||||
// find the leader too
|
||||
if s.IsLeader() {
|
||||
leader = s
|
||||
}
|
||||
}
|
||||
|
||||
if leader == nil {
|
||||
t.Fatalf("Should have a leader")
|
||||
}
|
||||
|
||||
client := rpcClient(t, leader)
|
||||
defer client.Close()
|
||||
|
||||
leader.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
|
||||
|
||||
// create a TTL session
|
||||
arg := structs.SessionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.SessionCreate,
|
||||
Session: structs.Session{
|
||||
Node: "foo",
|
||||
TTL: "10s",
|
||||
},
|
||||
}
|
||||
var id1 string
|
||||
if err := client.Call("Session.Apply", &arg, &id1); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// check that leader.sessionTimers has the session id in it
|
||||
// means initializeSessionTimers was called and resetSessionTimer was called
|
||||
if len(leader.sessionTimers) == 0 || leader.sessionTimers[id1] == nil {
|
||||
t.Fatalf("sessionTimers not initialized and does not contain session timer for session")
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
leader.Leave()
|
||||
leader.Shutdown()
|
||||
|
||||
// leader.sessionTimers should be empty due to clearAllSessionTimers getting called
|
||||
if len(leader.sessionTimers) != 0 {
|
||||
t.Fatalf("session timers should be empty on the shutdown leader")
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var remain *Server
|
||||
for _, s := range servers {
|
||||
if s == leader {
|
||||
continue
|
||||
}
|
||||
remain = s
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
peers, _ := s.raftPeers.Peers()
|
||||
return len(peers) == 2, errors.New(fmt.Sprintf("%v", peers))
|
||||
}, func(err error) {
|
||||
t.Fatalf("should have 2 peers: %v", err)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify the old leader is deregistered
|
||||
state := remain.fsm.State()
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
_, found, _ := state.GetNode(leader.config.NodeName)
|
||||
return !found, nil
|
||||
}, func(err error) {
|
||||
t.Fatalf("leader should be deregistered")
|
||||
})
|
||||
|
||||
// Find the new leader
|
||||
leader = nil
|
||||
for _, s := range servers {
|
||||
// find the leader too
|
||||
if s.IsLeader() {
|
||||
leader = s
|
||||
}
|
||||
}
|
||||
|
||||
if leader == nil {
|
||||
t.Fatalf("Should have a new leader")
|
||||
}
|
||||
|
||||
// check that new leader.sessionTimers has the session id in it
|
||||
if len(leader.sessionTimers) == 0 || leader.sessionTimers[id1] == nil {
|
||||
t.Fatalf("sessionTimers not initialized and does not contain session timer for session")
|
||||
}
|
||||
|
||||
// create another TTL session with the same parameters
|
||||
var id2 string
|
||||
if err := client.Call("Session.Apply", &arg, &id2); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(leader.sessionTimers) != 2 {
|
||||
t.Fatalf("sessionTimes length should be 2")
|
||||
}
|
||||
|
||||
// destroy the via invalidateSession as if on TTL expiry
|
||||
leader.invalidateSession(id2)
|
||||
|
||||
if len(leader.sessionTimers) != 1 {
|
||||
t.Fatalf("sessionTimers length should 1")
|
||||
}
|
||||
|
||||
// destroy the id2 session (test clearSessionTimer)
|
||||
arg.Op = structs.SessionDestroy
|
||||
arg.Session.ID = id2
|
||||
if err := client.Call("Session.Apply", &arg, &id2); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(leader.sessionTimers) != 0 {
|
||||
t.Fatalf("sessionTimers length should be 0")
|
||||
}
|
||||
}
|
@ -294,6 +294,10 @@ func (s *StateStore) initialize() error {
|
||||
AllowBlank: true,
|
||||
Fields: []string{"Node"},
|
||||
},
|
||||
"ttl": &MDBIndex{
|
||||
AllowBlank: true,
|
||||
Fields: []string{"TTL"},
|
||||
},
|
||||
},
|
||||
Decoder: func(buf []byte) interface{} {
|
||||
out := new(structs.Session)
|
||||
@ -369,6 +373,7 @@ func (s *StateStore) initialize() error {
|
||||
"KVSListKeys": MDBTables{s.kvsTable},
|
||||
"SessionGet": MDBTables{s.sessionTable},
|
||||
"SessionList": MDBTables{s.sessionTable},
|
||||
"SessionListTTL": MDBTables{s.sessionTable},
|
||||
"NodeSessions": MDBTables{s.sessionTable},
|
||||
"ACLGet": MDBTables{s.aclTable},
|
||||
"ACLList": MDBTables{s.aclTable},
|
||||
@ -1336,6 +1341,17 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error
|
||||
return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior)
|
||||
}
|
||||
|
||||
if session.TTL != "" {
|
||||
ttl, err := time.ParseDuration(session.TTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err)
|
||||
}
|
||||
|
||||
if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax {
|
||||
return fmt.Errorf("Invalid Session TTL '%s', must be between [%v-%v]", session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign the create index
|
||||
session.CreateIndex = index
|
||||
|
||||
@ -1445,6 +1461,16 @@ func (s *StateStore) SessionList() (uint64, []*structs.Session, error) {
|
||||
return idx, out, err
|
||||
}
|
||||
|
||||
// SessionListTTL is used to list all the open ttl sessions
|
||||
func (s *StateStore) SessionListTTL() (uint64, []*structs.Session, error) {
|
||||
idx, res, err := s.sessionTable.Get("ttl")
|
||||
out := make([]*structs.Session, len(res))
|
||||
for i, raw := range res {
|
||||
out[i] = raw.(*structs.Session)
|
||||
}
|
||||
return idx, out, err
|
||||
}
|
||||
|
||||
// NodeSessions is used to list all the open sessions for a node
|
||||
func (s *StateStore) NodeSessions(node string) (uint64, []*structs.Session, error) {
|
||||
idx, res, err := s.sessionTable.Get("node", node)
|
||||
|
@ -703,13 +703,17 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
if ok, err := store.KVSLock(18, d); err != nil || !ok {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
session = &structs.Session{ID: generateUUID(), Node: "baz", TTL: "60s"}
|
||||
if err := store.SessionCreate(19, session); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
a1 := &structs.ACL{
|
||||
ID: generateUUID(),
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
}
|
||||
if err := store.ACLSet(19, a1); err != nil {
|
||||
if err := store.ACLSet(20, a1); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
@ -718,7 +722,7 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
}
|
||||
if err := store.ACLSet(20, a2); err != nil {
|
||||
if err := store.ACLSet(21, a2); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
@ -730,7 +734,7 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
defer snap.Close()
|
||||
|
||||
// Check the last nodes
|
||||
if idx := snap.LastIndex(); idx != 20 {
|
||||
if idx := snap.LastIndex(); idx != 21 {
|
||||
t.Fatalf("bad: %v", idx)
|
||||
}
|
||||
|
||||
@ -785,15 +789,25 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
t.Fatalf("missing KVS entries!")
|
||||
}
|
||||
|
||||
// Check there are 2 sessions
|
||||
// Check there are 3 sessions
|
||||
sessions, err := snap.SessionList()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
if len(sessions) != 3 {
|
||||
t.Fatalf("missing sessions")
|
||||
}
|
||||
|
||||
ttls := 0
|
||||
for _, session := range sessions {
|
||||
if session.TTL != "" {
|
||||
ttls++
|
||||
}
|
||||
}
|
||||
if ttls != 1 {
|
||||
t.Fatalf("Wrong number of sessions with TTL")
|
||||
}
|
||||
|
||||
// Check for an acl
|
||||
acls, err := snap.ACLList()
|
||||
if err != nil {
|
||||
@ -804,13 +818,13 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
}
|
||||
|
||||
// Make some changes!
|
||||
if err := store.EnsureService(21, "foo", &structs.NodeService{"db", "db", []string{"slave"}, 8000}); err != nil {
|
||||
if err := store.EnsureService(22, "foo", &structs.NodeService{"db", "db", []string{"slave"}, 8000}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if err := store.EnsureService(22, "bar", &structs.NodeService{"db", "db", []string{"master"}, 8000}); err != nil {
|
||||
if err := store.EnsureService(23, "bar", &structs.NodeService{"db", "db", []string{"master"}, 8000}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if err := store.EnsureNode(23, structs.Node{"baz", "127.0.0.3"}); err != nil {
|
||||
if err := store.EnsureNode(24, structs.Node{"baz", "127.0.0.3"}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
checkAfter := &structs.HealthCheck{
|
||||
@ -820,16 +834,16 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
Status: structs.HealthCritical,
|
||||
ServiceID: "db",
|
||||
}
|
||||
if err := store.EnsureCheck(24, checkAfter); err != nil {
|
||||
if err := store.EnsureCheck(26, checkAfter); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if err := store.KVSDelete(25, "/web/b"); err != nil {
|
||||
if err := store.KVSDelete(26, "/web/b"); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Nuke an ACL
|
||||
if err := store.ACLDelete(26, a1.ID); err != nil {
|
||||
if err := store.ACLDelete(27, a1.ID); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
@ -883,12 +897,12 @@ func TestStoreSnapshot(t *testing.T) {
|
||||
t.Fatalf("missing KVS entries!")
|
||||
}
|
||||
|
||||
// Check there are 2 sessions
|
||||
// Check there are 3 sessions
|
||||
sessions, err = snap.SessionList()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
if len(sessions) != 3 {
|
||||
t.Fatalf("missing sessions")
|
||||
}
|
||||
|
||||
|
@ -385,6 +385,12 @@ const (
|
||||
SessionKeysDelete = "delete"
|
||||
)
|
||||
|
||||
const (
|
||||
SessionTTLMin = 10 * time.Second
|
||||
SessionTTLMax = 3600 * time.Second
|
||||
SessionTTLMultiplier = 2
|
||||
)
|
||||
|
||||
// Session is used to represent an open session in the KV store.
|
||||
// This issued to associate node checks with acquired locks.
|
||||
type Session struct {
|
||||
@ -395,6 +401,7 @@ type Session struct {
|
||||
Checks []string
|
||||
LockDelay time.Duration
|
||||
Behavior SessionBehavior // What to do when session is invalidated
|
||||
TTL string
|
||||
}
|
||||
type Sessions []*Session
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user