diff --git a/consul/state/state_store.go b/consul/state/state_store.go index 4e53983f4e..64bcea756c 100644 --- a/consul/state/state_store.go +++ b/consul/state/state_store.go @@ -1101,15 +1101,40 @@ func (s *StateStore) SessionList() (uint64, []*structs.Session, error) { var lindex uint64 for session := sessions.Next(); session != nil; session = sessions.Next() { sess := session.(*structs.Session) + result = append(result, sess) + + // Compute the highest index + if sess.ModifyIndex > lindex { + lindex = sess.ModifyIndex + } + } + return lindex, result, nil +} + +// NodeSessions returns a set of active sessions associated +// with the given node ID. The returned index is the highest +// index seen from the result set. +func (s *StateStore) NodeSessions(nodeID string) (uint64, []*structs.Session, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get all of the sessions which belong to the node + sessions, err := tx.Get("sessions", "node", nodeID) + if err != nil { + return 0, nil, fmt.Errorf("failed session lookup: %s", err) + } + + // Go over all of the sessions and return them as a slice + var result []*structs.Session + var lindex uint64 + for session := sessions.Next(); session != nil; session = sessions.Next() { + sess := session.(*structs.Session) + result = append(result, sess) // Compute the highest index if sess.ModifyIndex > lindex { lindex = sess.ModifyIndex } - - // Add the session to the result - result = append(result, sess) } - return lindex, result, nil } diff --git a/consul/state/state_store_test.go b/consul/state/state_store_test.go index 37b03c25c3..57963ec8de 100644 --- a/consul/state/state_store_test.go +++ b/consul/state/state_store_test.go @@ -1278,6 +1278,11 @@ func TestStateStore_SessionCreate(t *testing.T) { func TestStateStore_ListSessions(t *testing.T) { s := testStateStore(t) + // Listing when no sessions exist returns nil + if idx, res, err := s.SessionList(); idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + // Register some nodes testRegisterNode(t, s, 1, "node1") testRegisterNode(t, s, 2, "node2") @@ -1319,3 +1324,61 @@ func TestStateStore_ListSessions(t *testing.T) { t.Fatalf("bad: %#v", sessions) } } + +func TestStateStore_NodeSessions(t *testing.T) { + s := testStateStore(t) + + // Listing sessions with no results returns nil + if idx, res, err := s.NodeSessions("node1"); idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Create the nodes + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + + // Register some sessions with the nodes + sessions1 := []*structs.Session{ + &structs.Session{ + ID: "session1", + Node: "node1", + }, + &structs.Session{ + ID: "session2", + Node: "node1", + }, + } + sessions2 := []*structs.Session{ + &structs.Session{ + ID: "session3", + Node: "node2", + }, + &structs.Session{ + ID: "session4", + Node: "node2", + }, + } + for i, sess := range append(sessions1, sessions2...) { + if err := s.SessionCreate(uint64(3+i), sess); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Query all of the sessions associated with a specific + // node in the state store. + idx, result, err := s.NodeSessions("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Check that the index was properly filtered based + // on the provided node ID. + if idx != 4 { + t.Fatalf("bad index: %s", err) + } + + // Check that the returned sessions match. + if !reflect.DeepEqual(result, sessions1) { + t.Fatalf("bad: %#v", result) + } +}