diff --git a/agent/consul/acl.go b/agent/consul/acl.go index 31dba80174..078dc40267 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -1380,9 +1380,12 @@ func (f *aclFilter) filterDatacenterCheckServiceNodes(datacenterNodes *map[strin *datacenterNodes = out } -// filterSessions is used to filter a set of sessions based on ACLs. -func (f *aclFilter) filterSessions(sessions *structs.Sessions) { +// filterSessions is used to filter a set of sessions based on ACLs. Returns +// true if any elements were removed. +func (f *aclFilter) filterSessions(sessions *structs.Sessions) bool { s := *sessions + + var removed bool for i := 0; i < len(s); i++ { session := s[i] @@ -1392,11 +1395,13 @@ func (f *aclFilter) filterSessions(sessions *structs.Sessions) { if f.allowSession(session.Node, &entCtx) { continue } + removed = true f.logger.Debug("dropping session from result due to ACLs", "session", session.ID) s = append(s[:i], s[i+1:]...) i-- } *sessions = s + return removed } // filterCoordinates is used to filter nodes in a coordinate dump based on ACL @@ -1852,7 +1857,7 @@ func filterACLWithAuthorizer(logger hclog.Logger, authorizer acl.Authorizer, sub filt.filterServices(v.Services, &v.EnterpriseMeta) case *structs.IndexedSessions: - filt.filterSessions(&v.Sessions) + v.QueryMeta.ResultsFilteredByACLs = filt.filterSessions(&v.Sessions) case *structs.IndexedPreparedQueries: filt.filterPreparedQueries(&v.Queries) diff --git a/agent/consul/acl_test.go b/agent/consul/acl_test.go index f0b18cf034..eaf43690eb 100644 --- a/agent/consul/acl_test.go +++ b/agent/consul/acl_test.go @@ -2796,29 +2796,57 @@ func TestACL_filterCoordinates(t *testing.T) { func TestACL_filterSessions(t *testing.T) { t.Parallel() - // Create a session list. - sessions := structs.Sessions{ - &structs.Session{ - Node: "foo", - }, - &structs.Session{ - Node: "bar", - }, + + logger := hclog.NewNullLogger() + + makeList := func() *structs.IndexedSessions { + return &structs.IndexedSessions{ + Sessions: structs.Sessions{ + {Node: "foo"}, + {Node: "bar"}, + }, + } } - // Try permissive filtering. - filt := newACLFilter(acl.AllowAll(), nil) - filt.filterSessions(&sessions) - if len(sessions) != 2 { - t.Fatalf("bad: %#v", sessions) - } + t.Run("all allowed", func(t *testing.T) { + require := require.New(t) - // Try restrictive filtering - filt = newACLFilter(acl.DenyAll(), nil) - filt.filterSessions(&sessions) - if len(sessions) != 0 { - t.Fatalf("bad: %#v", sessions) - } + list := makeList() + filterACLWithAuthorizer(logger, acl.AllowAll(), list) + + require.Len(list.Sessions, 2) + require.False(list.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") + }) + + t.Run("just one node's sessions allowed", func(t *testing.T) { + require := require.New(t) + + policy, err := acl.NewPolicyFromSource(` + session "foo" { + policy = "read" + } + `, acl.SyntaxLegacy, nil, nil) + require.NoError(err) + + authz, err := acl.NewPolicyAuthorizerWithDefaults(acl.DenyAll(), []*acl.Policy{policy}, nil) + require.NoError(err) + + list := makeList() + filterACLWithAuthorizer(logger, authz, list) + + require.Len(list.Sessions, 1) + require.True(list.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") + }) + + t.Run("denied", func(t *testing.T) { + require := require.New(t) + + list := makeList() + filterACLWithAuthorizer(logger, acl.DenyAll(), list) + + require.Empty(list.Sessions) + require.True(list.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") + }) } func TestACL_filterNodeDump(t *testing.T) { diff --git a/agent/consul/session_endpoint_test.go b/agent/consul/session_endpoint_test.go index 61551b7e83..58fe1d7872 100644 --- a/agent/consul/session_endpoint_test.go +++ b/agent/consul/session_endpoint_test.go @@ -6,6 +6,7 @@ import ( "time" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/stretchr/testify/require" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" @@ -377,6 +378,7 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) { } t.Parallel() + dir1, s1 := testServerWithConfig(t, func(c *Config) { c.PrimaryDatacenter = "dc1" c.ACLsEnabled = true @@ -391,12 +393,17 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1", testrpc.WithToken("root")) - rules := ` -session "foo" { - policy = "read" -} -` - token := createToken(t, codec, rules) + deniedToken := createTokenWithPolicyName(t, codec, "denied", ` + session "foo" { + policy = "deny" + } + `, "root") + + allowedToken := createTokenWithPolicyName(t, codec, "allowed", ` + session "foo" { + policy = "read" + } + `, "root") // Create a node and a session. s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) @@ -409,95 +416,94 @@ session "foo" { WriteRequest: structs.WriteRequest{Token: "root"}, } var out string - if err := msgpackrpc.CallWithCodec(codec, "Session.Apply", &arg, &out); err != nil { - t.Fatalf("err: %v", err) - } + err := msgpackrpc.CallWithCodec(codec, "Session.Apply", &arg, &out) + require.NoError(t, err) - // Perform all the read operations, and make sure everything is empty. - getR := structs.SessionSpecificRequest{ - Datacenter: "dc1", - SessionID: out, - } - { - var sessions structs.IndexedSessions - if err := msgpackrpc.CallWithCodec(codec, "Session.Get", &getR, &sessions); err != nil { - t.Fatalf("err: %v", err) + t.Run("Get", func(t *testing.T) { + require := require.New(t) + + req := &structs.SessionSpecificRequest{ + Datacenter: "dc1", + SessionID: out, } - if len(sessions.Sessions) != 0 { - t.Fatalf("bad: %v", sessions.Sessions) - } - } - listR := structs.DCSpecificRequest{ - Datacenter: "dc1", - } - { + req.Token = deniedToken + + // ACL-restricted results filtered out. var sessions structs.IndexedSessions - if err := msgpackrpc.CallWithCodec(codec, "Session.List", &listR, &sessions); err != nil { - t.Fatalf("err: %v", err) - } - if len(sessions.Sessions) != 0 { - t.Fatalf("bad: %v", sessions.Sessions) - } - } - nodeR := structs.NodeSpecificRequest{ - Datacenter: "dc1", - Node: "foo", - } - { - var sessions structs.IndexedSessions - if err := msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", &nodeR, &sessions); err != nil { - t.Fatalf("err: %v", err) - } - if len(sessions.Sessions) != 0 { - t.Fatalf("bad: %v", sessions.Sessions) - } - } + err := msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions) + require.NoError(err) + require.Empty(sessions.Sessions) + require.True(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") - // Finally, supply the token and make sure the reads are allowed. - getR.Token = token - { - var sessions structs.IndexedSessions - if err := msgpackrpc.CallWithCodec(codec, "Session.Get", &getR, &sessions); err != nil { - t.Fatalf("err: %v", err) - } - if len(sessions.Sessions) != 1 { - t.Fatalf("bad: %v", sessions.Sessions) - } - } - listR.Token = token - { - var sessions structs.IndexedSessions - if err := msgpackrpc.CallWithCodec(codec, "Session.List", &listR, &sessions); err != nil { - t.Fatalf("err: %v", err) - } - if len(sessions.Sessions) != 1 { - t.Fatalf("bad: %v", sessions.Sessions) - } - } - nodeR.Token = token - { - var sessions structs.IndexedSessions - if err := msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", &nodeR, &sessions); err != nil { - t.Fatalf("err: %v", err) - } - if len(sessions.Sessions) != 1 { - t.Fatalf("bad: %v", sessions.Sessions) - } - } + // ACL-restricted results included. + req.Token = allowedToken - // Try to get a session that doesn't exist to make sure that's handled - // correctly by the filter (it will get passed a nil slice). - getR.SessionID = "adf4238a-882b-9ddc-4a9d-5b6758e4159e" - { + err = msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions) + require.NoError(err) + require.Len(sessions.Sessions, 1) + require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") + + // Try to get a session that doesn't exist to make sure that's handled + // correctly by the filter (it will get passed a nil slice). + req.SessionID = "adf4238a-882b-9ddc-4a9d-5b6758e4159e" + + err = msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions) + require.NoError(err) + require.Empty(sessions.Sessions) + require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") + }) + + t.Run("List", func(t *testing.T) { + require := require.New(t) + + req := &structs.DCSpecificRequest{ + Datacenter: "dc1", + } + req.Token = deniedToken + + // ACL-restricted results filtered out. var sessions structs.IndexedSessions - if err := msgpackrpc.CallWithCodec(codec, "Session.Get", &getR, &sessions); err != nil { - t.Fatalf("err: %v", err) + + err := msgpackrpc.CallWithCodec(codec, "Session.List", req, &sessions) + require.NoError(err) + require.Empty(sessions.Sessions) + require.True(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") + + // ACL-restricted results included. + req.Token = allowedToken + + err = msgpackrpc.CallWithCodec(codec, "Session.List", req, &sessions) + require.NoError(err) + require.Len(sessions.Sessions, 1) + require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") + }) + + t.Run("NodeSessions", func(t *testing.T) { + require := require.New(t) + + req := &structs.NodeSpecificRequest{ + Datacenter: "dc1", + Node: "foo", } - if len(sessions.Sessions) != 0 { - t.Fatalf("bad: %v", sessions.Sessions) - } - } + req.Token = deniedToken + + // ACL-restricted results filtered out. + var sessions structs.IndexedSessions + + err := msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", req, &sessions) + require.NoError(err) + require.Empty(sessions.Sessions) + require.True(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") + + // ACL-restricted results included. + req.Token = allowedToken + + err = msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", req, &sessions) + require.NoError(err) + require.Len(sessions.Sessions, 1) + require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") + }) } func TestSession_ApplyTimers(t *testing.T) {