mirror of https://github.com/status-im/consul.git
port oss changes (#11736)
This commit is contained in:
parent
e246defb6c
commit
ce326b6074
|
@ -151,7 +151,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
|
||||||
|
|
||||||
if args.Op == structs.SessionCreate && args.Session.TTL != "" {
|
if args.Op == structs.SessionCreate && args.Session.TTL != "" {
|
||||||
// If we created a session with a TTL, reset the expiration timer
|
// If we created a session with a TTL, reset the expiration timer
|
||||||
s.srv.resetSessionTimer(args.Session.ID, &args.Session)
|
s.srv.resetSessionTimer(&args.Session)
|
||||||
} else if args.Op == structs.SessionDestroy {
|
} else if args.Op == structs.SessionDestroy {
|
||||||
// If we destroyed a session, it might potentially have a TTL,
|
// If we destroyed a session, it might potentially have a TTL,
|
||||||
// and we need to clear the timer
|
// and we need to clear the timer
|
||||||
|
@ -308,7 +308,7 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest,
|
||||||
|
|
||||||
// Reset the session TTL timer.
|
// Reset the session TTL timer.
|
||||||
reply.Sessions = structs.Sessions{session}
|
reply.Sessions = structs.Sessions{session}
|
||||||
if err := s.srv.resetSessionTimer(args.SessionID, session); err != nil {
|
if err := s.srv.resetSessionTimer(session); err != nil {
|
||||||
s.logger.Error("Session renew failed", "error", err)
|
s.logger.Error("Session renew failed", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,13 +47,12 @@ func (s *Server) initializeSessionTimers() error {
|
||||||
// Scan all sessions and reset their timer
|
// Scan all sessions and reset their timer
|
||||||
state := s.fsm.State()
|
state := s.fsm.State()
|
||||||
|
|
||||||
// TODO(partitions): track all session timers in all partitions
|
_, sessions, err := state.SessionListAll(nil)
|
||||||
_, sessions, err := state.SessionList(nil, structs.WildcardEnterpriseMetaInDefaultPartition())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, session := range sessions {
|
for _, session := range sessions {
|
||||||
if err := s.resetSessionTimer(session.ID, session); err != nil {
|
if err := s.resetSessionTimer(session); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -63,20 +62,7 @@ func (s *Server) initializeSessionTimers() error {
|
||||||
// resetSessionTimer is used to renew the TTL of a session.
|
// resetSessionTimer is used to renew the TTL of a session.
|
||||||
// This can be used for new sessions and existing ones. A session
|
// This can be used for new sessions and existing ones. A session
|
||||||
// will be faulted in if not given.
|
// will be faulted in if not given.
|
||||||
func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
|
func (s *Server) resetSessionTimer(session *structs.Session) error {
|
||||||
// Fault the session in if not given
|
|
||||||
if session == nil {
|
|
||||||
state := s.fsm.State()
|
|
||||||
_, s, err := state.SessionGet(nil, id, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s == nil {
|
|
||||||
return fmt.Errorf("Session '%s' not found", id)
|
|
||||||
}
|
|
||||||
session = s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bail if the session has no TTL, fast-path some common inputs
|
// Bail if the session has no TTL, fast-path some common inputs
|
||||||
switch session.TTL {
|
switch session.TTL {
|
||||||
case "", "0", "0s", "0m", "0h":
|
case "", "0", "0s", "0m", "0h":
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"github.com/hashicorp/consul/sdk/testutil/retry"
|
"github.com/hashicorp/consul/sdk/testutil/retry"
|
||||||
"github.com/hashicorp/consul/testrpc"
|
"github.com/hashicorp/consul/testrpc"
|
||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func generateUUID() (ret string) {
|
func generateUUID() (ret string) {
|
||||||
|
@ -59,50 +59,6 @@ func TestInitializeSessionTimers(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResetSessionTimer_Fault(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("too slow for testing.Short")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Parallel()
|
|
||||||
dir1, s1 := testServer(t)
|
|
||||||
defer os.RemoveAll(dir1)
|
|
||||||
defer s1.Shutdown()
|
|
||||||
|
|
||||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
|
||||||
|
|
||||||
// Should not exist
|
|
||||||
err := s1.resetSessionTimer(generateUUID(), nil)
|
|
||||||
if err == nil || !strings.Contains(err.Error(), "not found") {
|
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a session
|
|
||||||
state := s1.fsm.State()
|
|
||||||
if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
|
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
session := &structs.Session{
|
|
||||||
ID: generateUUID(),
|
|
||||||
Node: "foo",
|
|
||||||
TTL: "10s",
|
|
||||||
}
|
|
||||||
if err := state.SessionCreate(100, session); err != nil {
|
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset the session timer
|
|
||||||
err = s1.resetSessionTimer(session.ID, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that we have a timer
|
|
||||||
if s1.sessionTimers.Get(session.ID) == nil {
|
|
||||||
t.Fatalf("missing session timer")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResetSessionTimer_NoTTL(t *testing.T) {
|
func TestResetSessionTimer_NoTTL(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("too slow for testing.Short")
|
t.Skip("too slow for testing.Short")
|
||||||
|
@ -130,7 +86,7 @@ func TestResetSessionTimer_NoTTL(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset the session timer
|
// Reset the session timer
|
||||||
err := s1.resetSessionTimer(session.ID, session)
|
err := s1.resetSessionTimer(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -155,7 +111,7 @@ func TestResetSessionTimer_InvalidTTL(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset the session timer
|
// Reset the session timer
|
||||||
err := s1.resetSessionTimer(session.ID, session)
|
err := s1.resetSessionTimer(session)
|
||||||
if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") {
|
if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -187,3 +187,7 @@ func (s *Store) SessionList(ws memdb.WatchSet, entMeta *structs.EnterpriseMeta)
|
||||||
func maxIndexTxnSessions(tx *memdb.Txn, _ *structs.EnterpriseMeta) uint64 {
|
func maxIndexTxnSessions(tx *memdb.Txn, _ *structs.EnterpriseMeta) uint64 {
|
||||||
return maxIndexTxn(tx, tableSessions)
|
return maxIndexTxn(tx, tableSessions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) SessionListAll(ws memdb.WatchSet) (uint64, structs.Sessions, error) {
|
||||||
|
return s.SessionList(ws, nil)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue