Adds fine-grained watches to session endpoints.

This commit is contained in:
James Phillips 2017-01-24 10:08:14 -08:00
parent 8b7977ccb3
commit 1d39ddbd4b
No known key found for this signature in database
GPG Key ID: 77183E682AC5FC11
7 changed files with 167 additions and 107 deletions

View File

@ -500,7 +500,7 @@ func TestFSM_SnapshotRestore(t *testing.T) {
} }
// Verify session is restored // Verify session is restored
idx, s, err := fsm2.state.SessionGet(session.ID) idx, s, err := fsm2.state.SessionGet(nil, session.ID)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -875,7 +875,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) {
// Get the session // Get the session
id := resp.(string) id := resp.(string)
_, session, err := fsm.state.SessionGet(id) _, session, err := fsm.state.SessionGet(nil, id)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -911,7 +911,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) {
t.Fatalf("resp: %v", resp) t.Fatalf("resp: %v", resp)
} }
_, session, err = fsm.state.SessionGet(id) _, session, err = fsm.state.SessionGet(nil, id)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/armon/go-metrics" "github.com/armon/go-metrics"
"github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
) )
@ -39,7 +40,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
switch args.Op { switch args.Op {
case structs.SessionDestroy: case structs.SessionDestroy:
state := s.srv.fsm.State() state := s.srv.fsm.State()
_, existing, err := state.SessionGet(args.Session.ID) _, existing, err := state.SessionGet(nil, args.Session.ID)
if err != nil { if err != nil {
return fmt.Errorf("Unknown session %q", args.Session.ID) return fmt.Errorf("Unknown session %q", args.Session.ID)
} }
@ -94,7 +95,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
s.srv.logger.Printf("[ERR] consul.session: UUID generation failed: %v", err) s.srv.logger.Printf("[ERR] consul.session: UUID generation failed: %v", err)
return err return err
} }
_, sess, err := state.SessionGet(args.Session.ID) _, sess, err := state.SessionGet(nil, args.Session.ID)
if err != nil { if err != nil {
s.srv.logger.Printf("[ERR] consul.session: Session lookup failed: %v", err) s.srv.logger.Printf("[ERR] consul.session: Session lookup failed: %v", err)
return err return err
@ -141,12 +142,11 @@ func (s *Session) Get(args *structs.SessionSpecificRequest,
// Get the local state // Get the local state
state := s.srv.fsm.State() state := s.srv.fsm.State()
return s.srv.blockingRPC( return s.srv.blockingQuery(
&args.QueryOptions, &args.QueryOptions,
&reply.QueryMeta, &reply.QueryMeta,
state.GetQueryWatch("SessionGet"), func(ws memdb.WatchSet) error {
func() error { index, session, err := state.SessionGet(ws, args.Session)
index, session, err := state.SessionGet(args.Session)
if err != nil { if err != nil {
return err return err
} }
@ -173,12 +173,11 @@ func (s *Session) List(args *structs.DCSpecificRequest,
// Get the local state // Get the local state
state := s.srv.fsm.State() state := s.srv.fsm.State()
return s.srv.blockingRPC( return s.srv.blockingQuery(
&args.QueryOptions, &args.QueryOptions,
&reply.QueryMeta, &reply.QueryMeta,
state.GetQueryWatch("SessionList"), func(ws memdb.WatchSet) error {
func() error { index, sessions, err := state.SessionList(ws)
index, sessions, err := state.SessionList()
if err != nil { if err != nil {
return err return err
} }
@ -200,12 +199,11 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest,
// Get the local state // Get the local state
state := s.srv.fsm.State() state := s.srv.fsm.State()
return s.srv.blockingRPC( return s.srv.blockingQuery(
&args.QueryOptions, &args.QueryOptions,
&reply.QueryMeta, &reply.QueryMeta,
state.GetQueryWatch("NodeSessions"), func(ws memdb.WatchSet) error {
func() error { index, sessions, err := state.NodeSessions(ws, args.Node)
index, sessions, err := state.NodeSessions(args.Node)
if err != nil { if err != nil {
return err return err
} }
@ -228,7 +226,7 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest,
// Get the session, from local state. // Get the session, from local state.
state := s.srv.fsm.State() state := s.srv.fsm.State()
index, session, err := state.SessionGet(args.Session) index, session, err := state.SessionGet(nil, args.Session)
if err != nil { if err != nil {
return err return err
} }

View File

@ -40,7 +40,7 @@ func TestSession_Apply(t *testing.T) {
// Verify // Verify
state := s1.fsm.State() state := s1.fsm.State()
_, s, err := state.SessionGet(out) _, s, err := state.SessionGet(nil, out)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -62,7 +62,7 @@ func TestSession_Apply(t *testing.T) {
} }
// Verify // Verify
_, s, err = state.SessionGet(id) _, s, err = state.SessionGet(nil, id)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -100,7 +100,7 @@ func TestSession_DeleteApply(t *testing.T) {
// Verify // Verify
state := s1.fsm.State() state := s1.fsm.State()
_, s, err := state.SessionGet(out) _, s, err := state.SessionGet(nil, out)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -125,7 +125,7 @@ func TestSession_DeleteApply(t *testing.T) {
} }
// Verify // Verify
_, s, err = state.SessionGet(id) _, s, err = state.SessionGet(nil, id)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -22,7 +22,7 @@ const (
func (s *Server) initializeSessionTimers() error { func (s *Server) initializeSessionTimers() error {
// Scan all sessions and reset their timer // Scan all sessions and reset their timer
state := s.fsm.State() state := s.fsm.State()
_, sessions, err := state.SessionList() _, sessions, err := state.SessionList(nil)
if err != nil { if err != nil {
return err return err
} }
@ -41,7 +41,7 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
// Fault the session in if not given // Fault the session in if not given
if session == nil { if session == nil {
state := s.fsm.State() state := s.fsm.State()
_, s, err := state.SessionGet(id) _, s, err := state.SessionGet(nil, id)
if err != nil { if err != nil {
return err return err
} }

View File

@ -225,7 +225,7 @@ func TestInvalidateSession(t *testing.T) {
s1.invalidateSession(session.ID) s1.invalidateSession(session.ID)
// Check it is gone // Check it is gone
_, sess, err := state.SessionGet(session.ID) _, sess, err := state.SessionGet(nil, session.ID)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -145,18 +145,19 @@ func (s *StateStore) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.S
} }
// SessionGet is used to retrieve an active session from the state store. // SessionGet is used to retrieve an active session from the state store.
func (s *StateStore) SessionGet(sessionID string) (uint64, *structs.Session, error) { func (s *StateStore) SessionGet(ws memdb.WatchSet, sessionID string) (uint64, *structs.Session, error) {
tx := s.db.Txn(false) tx := s.db.Txn(false)
defer tx.Abort() defer tx.Abort()
// Get the table index. // Get the table index.
idx := maxIndexTxn(tx, s.getWatchTables("SessionGet")...) idx := maxIndexTxn(tx, "sessions")
// Look up the session by its ID // Look up the session by its ID
session, err := tx.First("sessions", "id", sessionID) watchCh, session, err := tx.FirstWatch("sessions", "id", sessionID)
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("failed session lookup: %s", err) return 0, nil, fmt.Errorf("failed session lookup: %s", err)
} }
ws.Add(watchCh)
if session != nil { if session != nil {
return idx, session.(*structs.Session), nil return idx, session.(*structs.Session), nil
} }
@ -164,18 +165,19 @@ func (s *StateStore) SessionGet(sessionID string) (uint64, *structs.Session, err
} }
// SessionList returns a slice containing all of the active sessions. // SessionList returns a slice containing all of the active sessions.
func (s *StateStore) SessionList() (uint64, structs.Sessions, error) { func (s *StateStore) SessionList(ws memdb.WatchSet) (uint64, structs.Sessions, error) {
tx := s.db.Txn(false) tx := s.db.Txn(false)
defer tx.Abort() defer tx.Abort()
// Get the table index. // Get the table index.
idx := maxIndexTxn(tx, s.getWatchTables("SessionList")...) idx := maxIndexTxn(tx, "sessions")
// Query all of the active sessions. // Query all of the active sessions.
sessions, err := tx.Get("sessions", "id") sessions, err := tx.Get("sessions", "id")
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("failed session lookup: %s", err) return 0, nil, fmt.Errorf("failed session lookup: %s", err)
} }
ws.Add(sessions.WatchCh())
// Go over the sessions and create a slice of them. // Go over the sessions and create a slice of them.
var result structs.Sessions var result structs.Sessions
@ -188,18 +190,19 @@ func (s *StateStore) SessionList() (uint64, structs.Sessions, error) {
// NodeSessions returns a set of active sessions associated // NodeSessions returns a set of active sessions associated
// with the given node ID. The returned index is the highest // with the given node ID. The returned index is the highest
// index seen from the result set. // index seen from the result set.
func (s *StateStore) NodeSessions(nodeID string) (uint64, structs.Sessions, error) { func (s *StateStore) NodeSessions(ws memdb.WatchSet, nodeID string) (uint64, structs.Sessions, error) {
tx := s.db.Txn(false) tx := s.db.Txn(false)
defer tx.Abort() defer tx.Abort()
// Get the table index. // Get the table index.
idx := maxIndexTxn(tx, s.getWatchTables("NodeSessions")...) idx := maxIndexTxn(tx, "sessions")
// Get all of the sessions which belong to the node // Get all of the sessions which belong to the node
sessions, err := tx.Get("sessions", "node", nodeID) sessions, err := tx.Get("sessions", "node", nodeID)
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("failed session lookup: %s", err) return 0, nil, fmt.Errorf("failed session lookup: %s", err)
} }
ws.Add(sessions.WatchCh())
// Go over all of the sessions and return them as a slice // Go over all of the sessions and return them as a slice
var result structs.Sessions var result structs.Sessions

View File

@ -9,13 +9,15 @@ import (
"github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/consul/types" "github.com/hashicorp/consul/types"
"github.com/hashicorp/go-memdb"
) )
func TestStateStore_SessionCreate_SessionGet(t *testing.T) { func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
s := testStateStore(t) s := testStateStore(t)
// SessionGet returns nil if the session doesn't exist // SessionGet returns nil if the session doesn't exist
idx, session, err := s.SessionGet(testUUID()) ws := memdb.NewWatchSet()
idx, session, err := s.SessionGet(ws, testUUID())
if session != nil || err != nil { if session != nil || err != nil {
t.Fatalf("expected (nil, nil), got: (%#v, %#v)", session, err) t.Fatalf("expected (nil, nil), got: (%#v, %#v)", session, err)
} }
@ -49,6 +51,9 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
if idx := s.maxIndex("sessions"); idx != 0 { if idx := s.maxIndex("sessions"); idx != 0 {
t.Fatalf("bad index: %d", idx) t.Fatalf("bad index: %d", idx)
} }
if watchFired(ws) {
t.Fatalf("bad")
}
// Valid session is able to register // Valid session is able to register
testRegisterNode(t, s, 1, "node1") testRegisterNode(t, s, 1, "node1")
@ -62,9 +67,13 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
if idx := s.maxIndex("sessions"); idx != 2 { if idx := s.maxIndex("sessions"); idx != 2 {
t.Fatalf("bad index: %s", err) t.Fatalf("bad index: %s", err)
} }
if !watchFired(ws) {
t.Fatalf("bad")
}
// Retrieve the session again // Retrieve the session again
idx, session, err = s.SessionGet(sess.ID) ws = memdb.NewWatchSet()
idx, session, err = s.SessionGet(ws, sess.ID)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -104,12 +113,19 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
if err == nil || !strings.Contains(err.Error(), structs.HealthCritical) { if err == nil || !strings.Contains(err.Error(), structs.HealthCritical) {
t.Fatalf("expected critical state error, got: %#v", err) t.Fatalf("expected critical state error, got: %#v", err)
} }
if watchFired(ws) {
t.Fatalf("bad")
}
// Registering with a healthy check succeeds // Registering with a healthy check succeeds (doesn't hit the watch since
// we are looking at the old session).
testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing) testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing)
if err := s.SessionCreate(5, sess); err != nil { if err := s.SessionCreate(5, sess); err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
if watchFired(ws) {
t.Fatalf("bad")
}
// Register a session against two checks. // Register a session against two checks.
testRegisterCheck(t, s, 5, "node1", "", "check2", structs.HealthPassing) testRegisterCheck(t, s, 5, "node1", "", "check2", structs.HealthPassing)
@ -159,7 +175,7 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
} }
// Pulling a nonexistent session gives the table index. // Pulling a nonexistent session gives the table index.
idx, session, err = s.SessionGet(testUUID()) idx, session, err = s.SessionGet(nil, testUUID())
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -175,7 +191,8 @@ func TegstStateStore_SessionList(t *testing.T) {
s := testStateStore(t) s := testStateStore(t)
// Listing when no sessions exist returns nil // Listing when no sessions exist returns nil
idx, res, err := s.SessionList() ws := memdb.NewWatchSet()
idx, res, err := s.SessionList(ws)
if idx != 0 || res != nil || err != nil { if idx != 0 || res != nil || err != nil {
t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
} }
@ -208,9 +225,12 @@ func TegstStateStore_SessionList(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
} }
if !watchFired(ws) {
t.Fatalf("bad")
}
// List out all of the sessions // List out all of the sessions
idx, sessionList, err := s.SessionList() idx, sessionList, err := s.SessionList(nil)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -226,7 +246,8 @@ func TestStateStore_NodeSessions(t *testing.T) {
s := testStateStore(t) s := testStateStore(t)
// Listing sessions with no results returns nil // Listing sessions with no results returns nil
idx, res, err := s.NodeSessions("node1") ws := memdb.NewWatchSet()
idx, res, err := s.NodeSessions(ws, "node1")
if idx != 0 || res != nil || err != nil { if idx != 0 || res != nil || err != nil {
t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
} }
@ -261,10 +282,14 @@ func TestStateStore_NodeSessions(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
} }
if !watchFired(ws) {
t.Fatalf("bad")
}
// Query all of the sessions associated with a specific // Query all of the sessions associated with a specific
// node in the state store. // node in the state store.
idx, res, err = s.NodeSessions("node1") ws1 := memdb.NewWatchSet()
idx, res, err = s.NodeSessions(ws1, "node1")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -275,7 +300,8 @@ func TestStateStore_NodeSessions(t *testing.T) {
t.Fatalf("bad index: %d", idx) t.Fatalf("bad index: %d", idx)
} }
idx, res, err = s.NodeSessions("node2") ws2 := memdb.NewWatchSet()
idx, res, err = s.NodeSessions(ws2, "node2")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -285,6 +311,17 @@ func TestStateStore_NodeSessions(t *testing.T) {
if idx != 6 { if idx != 6 {
t.Fatalf("bad index: %d", idx) t.Fatalf("bad index: %d", idx)
} }
// Destroying a session on node1 should not affect node2's watch.
if err := s.SessionDestroy(100, sessions1[0].ID); err != nil {
t.Fatalf("err: %s", err)
}
if !watchFired(ws1) {
t.Fatalf("bad")
}
if watchFired(ws2) {
t.Fatalf("bad")
}
} }
func TestStateStore_SessionDestroy(t *testing.T) { func TestStateStore_SessionDestroy(t *testing.T) {
@ -418,7 +455,7 @@ func TestStateStore_Session_Snapshot_Restore(t *testing.T) {
// Read the restored sessions back out and verify that they // Read the restored sessions back out and verify that they
// match. // match.
idx, res, err := s.SessionList() idx, res, err := s.SessionList(nil)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -520,17 +557,21 @@ func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
// Delete the node and make sure the watches fire. // Delete the node and make sure the watch fires.
verifyWatch(t, s.getTableWatch("sessions"), func() { ws := memdb.NewWatchSet()
verifyWatch(t, s.getTableWatch("nodes"), func() { idx, s2, err := s.SessionGet(ws, session.ID)
if err := s.DeleteNode(15, "foo"); err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
}) if err := s.DeleteNode(15, "foo"); err != nil {
}) t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil. // Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID) idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -571,19 +612,21 @@ func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
// Delete the service and make sure the watches fire. // Delete the service and make sure the watch fires.
verifyWatch(t, s.getTableWatch("sessions"), func() { ws := memdb.NewWatchSet()
verifyWatch(t, s.getTableWatch("services"), func() { idx, s2, err := s.SessionGet(ws, session.ID)
verifyWatch(t, s.getTableWatch("checks"), func() { if err != nil {
if err := s.DeleteService(15, "foo", "api"); err != nil { t.Fatalf("err: %v", err)
t.Fatalf("err: %v", err) }
} if err := s.DeleteService(15, "foo", "api"); err != nil {
}) t.Fatalf("err: %v", err)
}) }
}) if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil. // Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID) idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -620,17 +663,21 @@ func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) {
} }
// Invalidate the check and make sure the watches fire. // Invalidate the check and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() { ws := memdb.NewWatchSet()
verifyWatch(t, s.getTableWatch("checks"), func() { idx, s2, err := s.SessionGet(ws, session.ID)
check.Status = structs.HealthCritical if err != nil {
if err := s.EnsureCheck(15, check); err != nil { t.Fatalf("err: %v", err)
t.Fatalf("err: %v", err) }
} check.Status = structs.HealthCritical
}) if err := s.EnsureCheck(15, check); err != nil {
}) t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil. // Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID) idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -667,16 +714,20 @@ func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) {
} }
// Delete the check and make sure the watches fire. // Delete the check and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() { ws := memdb.NewWatchSet()
verifyWatch(t, s.getTableWatch("checks"), func() { idx, s2, err := s.SessionGet(ws, session.ID)
if err := s.DeleteCheck(15, "foo", "bar"); err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
}) if err := s.DeleteCheck(15, "foo", "bar"); err != nil {
}) t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil. // Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID) idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -731,18 +782,20 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) {
} }
// Delete the node and make sure the watches fire. // Delete the node and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() { ws := memdb.NewWatchSet()
verifyWatch(t, s.getTableWatch("nodes"), func() { idx, s2, err := s.SessionGet(ws, session.ID)
verifyWatch(t, s.GetKVSWatch("/f"), func() { if err != nil {
if err := s.DeleteNode(6, "foo"); err != nil { t.Fatalf("err: %v", err)
t.Fatalf("err: %v", err) }
} if err := s.DeleteNode(6, "foo"); err != nil {
}) t.Fatalf("err: %v", err)
}) }
}) if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil. // Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID) idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -811,18 +864,20 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) {
} }
// Delete the node and make sure the watches fire. // Delete the node and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() { ws := memdb.NewWatchSet()
verifyWatch(t, s.getTableWatch("nodes"), func() { idx, s2, err := s.SessionGet(ws, session.ID)
verifyWatch(t, s.GetKVSWatch("/b"), func() { if err != nil {
if err := s.DeleteNode(6, "foo"); err != nil { t.Fatalf("err: %v", err)
t.Fatalf("err: %v", err) }
} if err := s.DeleteNode(6, "foo"); err != nil {
}) t.Fatalf("err: %v", err)
}) }
}) if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil. // Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID) idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -877,16 +932,20 @@ func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) {
} }
// Invalidate the session and make sure the watches fire. // Invalidate the session and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() { ws := memdb.NewWatchSet()
verifyWatch(t, s.getTableWatch("prepared-queries"), func() { idx, s2, err := s.SessionGet(ws, session.ID)
if err := s.SessionDestroy(5, session.ID); err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
}) if err := s.SessionDestroy(5, session.ID); err != nil {
}) t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Make sure the session is gone. // Make sure the session is gone.
idx, s2, err := s.SessionGet(session.ID) idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }