mirror of https://github.com/status-im/consul.git
port oss changes (#11736)
This commit is contained in:
parent
e246defb6c
commit
ce326b6074
|
@ -151,7 +151,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
|
|||
|
||||
if args.Op == structs.SessionCreate && args.Session.TTL != "" {
|
||||
// If we created a session with a TTL, reset the expiration timer
|
||||
s.srv.resetSessionTimer(args.Session.ID, &args.Session)
|
||||
s.srv.resetSessionTimer(&args.Session)
|
||||
} else if args.Op == structs.SessionDestroy {
|
||||
// If we destroyed a session, it might potentially have a TTL,
|
||||
// and we need to clear the timer
|
||||
|
@ -308,7 +308,7 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest,
|
|||
|
||||
// Reset the session TTL timer.
|
||||
reply.Sessions = structs.Sessions{session}
|
||||
if err := s.srv.resetSessionTimer(args.SessionID, session); err != nil {
|
||||
if err := s.srv.resetSessionTimer(session); err != nil {
|
||||
s.logger.Error("Session renew failed", "error", err)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -47,13 +47,12 @@ func (s *Server) initializeSessionTimers() error {
|
|||
// Scan all sessions and reset their timer
|
||||
state := s.fsm.State()
|
||||
|
||||
// TODO(partitions): track all session timers in all partitions
|
||||
_, sessions, err := state.SessionList(nil, structs.WildcardEnterpriseMetaInDefaultPartition())
|
||||
_, sessions, err := state.SessionListAll(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, session := range sessions {
|
||||
if err := s.resetSessionTimer(session.ID, session); err != nil {
|
||||
if err := s.resetSessionTimer(session); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -63,20 +62,7 @@ func (s *Server) initializeSessionTimers() error {
|
|||
// resetSessionTimer is used to renew the TTL of a session.
|
||||
// This can be used for new sessions and existing ones. A session
|
||||
// will be faulted in if not given.
|
||||
func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
|
||||
// Fault the session in if not given
|
||||
if session == nil {
|
||||
state := s.fsm.State()
|
||||
_, s, err := state.SessionGet(nil, id, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s == nil {
|
||||
return fmt.Errorf("Session '%s' not found", id)
|
||||
}
|
||||
session = s
|
||||
}
|
||||
|
||||
func (s *Server) resetSessionTimer(session *structs.Session) error {
|
||||
// Bail if the session has no TTL, fast-path some common inputs
|
||||
switch session.TTL {
|
||||
case "", "0", "0s", "0m", "0h":
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/hashicorp/consul/sdk/testutil/retry"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
)
|
||||
|
||||
func generateUUID() (ret string) {
|
||||
|
@ -59,50 +59,6 @@ func TestInitializeSessionTimers(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestResetSessionTimer_Fault(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("too slow for testing.Short")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Should not exist
|
||||
err := s1.resetSessionTimer(generateUUID(), nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "not found") {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Create a session
|
||||
state := s1.fsm.State()
|
||||
if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
session := &structs.Session{
|
||||
ID: generateUUID(),
|
||||
Node: "foo",
|
||||
TTL: "10s",
|
||||
}
|
||||
if err := state.SessionCreate(100, session); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Reset the session timer
|
||||
err = s1.resetSessionTimer(session.ID, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check that we have a timer
|
||||
if s1.sessionTimers.Get(session.ID) == nil {
|
||||
t.Fatalf("missing session timer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSessionTimer_NoTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("too slow for testing.Short")
|
||||
|
@ -130,7 +86,7 @@ func TestResetSessionTimer_NoTTL(t *testing.T) {
|
|||
}
|
||||
|
||||
// Reset the session timer
|
||||
err := s1.resetSessionTimer(session.ID, session)
|
||||
err := s1.resetSessionTimer(session)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -155,7 +111,7 @@ func TestResetSessionTimer_InvalidTTL(t *testing.T) {
|
|||
}
|
||||
|
||||
// Reset the session timer
|
||||
err := s1.resetSessionTimer(session.ID, session)
|
||||
err := s1.resetSessionTimer(session)
|
||||
if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
|
@ -187,3 +187,7 @@ func (s *Store) SessionList(ws memdb.WatchSet, entMeta *structs.EnterpriseMeta)
|
|||
func maxIndexTxnSessions(tx *memdb.Txn, _ *structs.EnterpriseMeta) uint64 {
|
||||
return maxIndexTxn(tx, tableSessions)
|
||||
}
|
||||
|
||||
func (s *Store) SessionListAll(ws memdb.WatchSet) (uint64, structs.Sessions, error) {
|
||||
return s.SessionList(ws, nil)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue