diff --git a/consul/acl.go b/consul/acl.go index abc011054c..66bea46180 100644 --- a/consul/acl.go +++ b/consul/acl.go @@ -192,3 +192,105 @@ func (s *Server) useACLPolicy(id, authDC string, cached *aclCacheEntry, p *struc s.aclCache.Add(id, cached) return compiled, nil } + +// discoveryFilter is used to determine if we should return a given node +// or service based on the ACL passed in. +func (s *Server) discoveryFilter(node, service string, acl acl.ACL) bool { + if acl == nil { + return true + } + + // Filter service discovery ACLs + if service != "" && service != ConsulServiceID && !acl.ServiceRead(service) { + s.logger.Printf("[DEBUG] consul: reading service '%s' denied due to ACLs", service) + return false + } + + // Filtering passed + return true +} + +// applyDiscoveryACLs is used to filter results from our service catalog based +// on the configured rules for the request ACL. Nodes or services which do +// not match the ACL rules will be dropped from the result. +func (s *Server) applyDiscoveryACLs(token string, subj interface{}) error { + // Get the ACL from the token + acl, err := s.resolveToken(token) + if err != nil { + return err + } + + // Fast path if ACLs are not enabled + if acl == nil { + return nil + } + + filt := func(service string) bool { + // Don't filter the "consul" service or empty service names + if service == "" || service == ConsulServiceID { + return true + } + + // Check the ACL + if !acl.ServiceRead(service) { + s.logger.Printf("[DEBUG] consul: reading service '%s' denied due to ACLs", service) + return false + } + return true + } + + switch v := subj.(type) { + // Filter health checks + case *structs.IndexedHealthChecks: + for i := 0; i < len(v.HealthChecks); i++ { + hc := v.HealthChecks[i] + if filt(hc.ServiceName) { + continue + } + v.HealthChecks = append(v.HealthChecks[:i], v.HealthChecks[i+1:]...) + i-- + } + + // Filter services + case *structs.IndexedServices: + for svc, _ := range v.Services { + if filt(svc) { + continue + } + delete(v.Services, svc) + } + + // Filter service nodes + case *structs.IndexedServiceNodes: + for i := 0; i < len(v.ServiceNodes); i++ { + node := v.ServiceNodes[i] + if filt(node.ServiceName) { + continue + } + v.ServiceNodes = append(v.ServiceNodes[:i], v.ServiceNodes[i+1:]...) + i-- + } + + // Filter node services + case *structs.IndexedNodeServices: + for svc, _ := range v.NodeServices.Services { + if filt(svc) { + continue + } + delete(v.NodeServices.Services, svc) + } + + // Filter check service nodes + case *structs.IndexedCheckServiceNodes: + for i := 0; i < len(v.Nodes); i++ { + cs := v.Nodes[i] + if filt(cs.Service.Service) { + continue + } + v.Nodes = append(v.Nodes[:i], v.Nodes[i+1:]...) + i-- + } + } + + return nil +} diff --git a/consul/catalog_endpoint.go b/consul/catalog_endpoint.go index b39c19f2bd..3e05755c6f 100644 --- a/consul/catalog_endpoint.go +++ b/consul/catalog_endpoint.go @@ -126,7 +126,7 @@ func (c *Catalog) ListNodes(args *structs.DCSpecificRequest, reply *structs.Inde state.QueryTables("Nodes"), func() error { reply.Index, reply.Nodes = state.Nodes() - return nil + return c.srv.applyDiscoveryACLs(args.Token, reply) }) } @@ -143,7 +143,7 @@ func (c *Catalog) ListServices(args *structs.DCSpecificRequest, reply *structs.I state.QueryTables("Services"), func() error { reply.Index, reply.Services = state.Services() - return nil + return c.srv.applyDiscoveryACLs(args.Token, reply) }) } @@ -169,7 +169,7 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru } else { reply.Index, reply.ServiceNodes = state.ServiceNodes(args.ServiceName) } - return nil + return c.srv.applyDiscoveryACLs(args.Token, reply) }) // Provide some metrics @@ -203,6 +203,6 @@ func (c *Catalog) NodeServices(args *structs.NodeSpecificRequest, reply *structs state.QueryTables("NodeServices"), func() error { reply.Index, reply.NodeServices = state.NodeServices(args.Node) - return nil + return c.srv.applyDiscoveryACLs(args.Token, reply) }) } diff --git a/consul/health_endpoint.go b/consul/health_endpoint.go index e6db8c99ad..93e7a44f2d 100644 --- a/consul/health_endpoint.go +++ b/consul/health_endpoint.go @@ -43,7 +43,7 @@ func (h *Health) NodeChecks(args *structs.NodeSpecificRequest, state.QueryTables("NodeChecks"), func() error { reply.Index, reply.HealthChecks = state.NodeChecks(args.Node) - return nil + return h.srv.applyDiscoveryACLs(args.Token, reply) }) } @@ -67,7 +67,7 @@ func (h *Health) ServiceChecks(args *structs.ServiceSpecificRequest, state.QueryTables("ServiceChecks"), func() error { reply.Index, reply.HealthChecks = state.ServiceChecks(args.ServiceName) - return nil + return h.srv.applyDiscoveryACLs(args.Token, reply) }) } @@ -93,7 +93,7 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc } else { reply.Index, reply.Nodes = state.CheckServiceNodes(args.ServiceName) } - return nil + return h.srv.applyDiscoveryACLs(args.Token, reply) }) // Provide some metrics diff --git a/consul/internal_endpoint.go b/consul/internal_endpoint.go index 3032e0b031..4cacd96290 100644 --- a/consul/internal_endpoint.go +++ b/consul/internal_endpoint.go @@ -12,39 +12,39 @@ type Internal struct { srv *Server } -// ChecksInState is used to get all the checks in a given state +// NodeInfo is used to retrieve information about a specific node. func (m *Internal) NodeInfo(args *structs.NodeSpecificRequest, reply *structs.IndexedNodeDump) error { if done, err := m.srv.forward("Internal.NodeInfo", args, args, reply); done { return err } - // Get the state specific checks + // Get the node info state := m.srv.fsm.State() return m.srv.blockingRPC(&args.QueryOptions, &reply.QueryMeta, state.QueryTables("NodeInfo"), func() error { reply.Index, reply.Dump = state.NodeInfo(args.Node) - return nil + return m.srv.applyDiscoveryACLs(args.Token, reply) }) } -// ChecksInState is used to get all the checks in a given state +// NodeDump is used to generate information about all of the nodes. func (m *Internal) NodeDump(args *structs.DCSpecificRequest, reply *structs.IndexedNodeDump) error { if done, err := m.srv.forward("Internal.NodeDump", args, args, reply); done { return err } - // Get the state specific checks + // Get all the node info state := m.srv.fsm.State() return m.srv.blockingRPC(&args.QueryOptions, &reply.QueryMeta, state.QueryTables("NodeDump"), func() error { reply.Index, reply.Dump = state.NodeDump() - return nil + return m.srv.applyDiscoveryACLs(args.Token, reply) }) }