consul/command/exec/exec.go

701 lines
18 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package exec
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"os"
"path"
"regexp"
"strconv"
"strings"
"time"
"unicode"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/command/flags"
"github.com/mitchellh/cli"
)
func New(ui cli.Ui, shutdownCh <-chan struct{}) *cmd {
c := &cmd{UI: ui, shutdownCh: shutdownCh}
c.init()
return c
}
type cmd struct {
UI cli.Ui
flags *flag.FlagSet
http *flags.HTTPFlags
help string
shutdownCh <-chan struct{}
conf rExecConf
apiclient *api.Client
sessionID string
stopCh chan struct{}
}
func (c *cmd) init() {
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
c.flags.StringVar(&c.conf.node, "node", "",
"Regular expression to filter on node names.")
c.flags.StringVar(&c.conf.service, "service", "",
"Regular expression to filter on service instances.")
c.flags.StringVar(&c.conf.tag, "tag", "",
"Regular expression to filter on service tags. Must be used with -service.")
c.flags.StringVar(&c.conf.prefix, "prefix", rExecPrefix,
"Prefix in the KV store to use for request data.")
c.flags.BoolVar(&c.conf.shell, "shell", true,
"Use a shell to run the command.")
c.flags.DurationVar(&c.conf.wait, "wait", rExecQuietWait,
"Period to wait with no responses before terminating execution.")
c.flags.DurationVar(&c.conf.replWait, "wait-repl", rExecReplicationWait,
"Period to wait for replication before firing event. This is an optimization to allow stale reads to be performed.")
c.flags.BoolVar(&c.conf.verbose, "verbose", false,
"Enables verbose output.")
c.http = &flags.HTTPFlags{}
flags.Merge(c.flags, c.http.ClientFlags())
flags.Merge(c.flags, c.http.ServerFlags())
c.help = flags.Usage(help, c.flags)
}
func (c *cmd) Run(args []string) int {
if err := c.flags.Parse(args); err != nil {
return 1
}
// Join the commands to execute
c.conf.cmd = strings.Join(c.flags.Args(), " ")
// If there is no command, read stdin for a script input
if c.conf.cmd == "-" {
if !c.conf.shell {
c.UI.Error("Cannot configure -shell=false when reading from stdin")
return 1
}
c.conf.cmd = ""
var buf bytes.Buffer
_, err := io.Copy(&buf, os.Stdin)
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to read stdin: %v", err))
c.UI.Error("")
c.UI.Error(c.Help())
return 1
}
c.conf.script = buf.Bytes()
} else if !c.conf.shell {
c.conf.cmd = ""
c.conf.args = c.flags.Args()
}
// Ensure we have a command or script
if c.conf.cmd == "" && len(c.conf.script) == 0 && len(c.conf.args) == 0 {
c.UI.Error("Must specify a command to execute")
c.UI.Error("")
c.UI.Error(c.Help())
return 1
}
// Validate the configuration
if err := c.conf.validate(); err != nil {
c.UI.Error(err.Error())
return 1
}
// Create and test the HTTP client
client, err := c.http.APIClient()
if err != nil {
c.UI.Error(fmt.Sprintf("Error connecting to Consul agent: %s", err))
return 1
}
info, err := client.Agent().Self()
if err != nil {
c.UI.Error(fmt.Sprintf("Error querying Consul agent: %s", err))
return 1
}
c.apiclient = client
// Check if this is a foreign datacenter
if c.http.Datacenter() != "" && c.http.Datacenter() != info["Config"]["Datacenter"] {
if c.conf.verbose {
c.UI.Info("Remote exec in foreign datacenter, using Session TTL")
}
c.conf.foreignDC = true
c.conf.localDC = info["Config"]["Datacenter"].(string)
c.conf.localNode = info["Config"]["NodeName"].(string)
}
// Create the job spec
spec, err := c.makeRExecSpec()
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to create job spec: %s", err))
return 1
}
// Create a session for this
c.sessionID, err = c.createSession()
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to create session: %s", err))
return 1
}
defer c.destroySession()
if c.conf.verbose {
c.UI.Info(fmt.Sprintf("Created remote execution session: %s", c.sessionID))
}
// Upload the payload
if err := c.uploadPayload(spec); err != nil {
c.UI.Error(fmt.Sprintf("Failed to create job file: %s", err))
return 1
}
defer c.destroyData()
if c.conf.verbose {
c.UI.Info(fmt.Sprintf("Uploaded remote execution spec"))
}
// Wait for replication. This is done so that when the event is
// received, the job file can be read using a stale read. If the
// stale read fails, we expect a consistent read to be done, so
// largely this is a heuristic.
select {
case <-time.After(c.conf.replWait):
case <-c.shutdownCh:
return 1
}
// Fire the event
id, err := c.fireEvent()
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to fire event: %s", err))
return 1
}
if c.conf.verbose {
c.UI.Info(fmt.Sprintf("Fired remote execution event: %s", id))
}
// Wait for the job to finish now
return c.waitForJob()
}
func (c *cmd) Synopsis() string {
return synopsis
}
func (c *cmd) Help() string {
return c.help
}
const synopsis = "Executes a command on Consul nodes"
const help = `
Usage: consul exec [options] [-|command...]
Evaluates a command on remote Consul nodes. The nodes responding can
be filtered using regular expressions on node name, service, and tag
definitions. If a command is '-', stdin will be read until EOF
and used as a script input.
`
// waitForJob is used to poll for results and wait until the job is terminated
func (c *cmd) waitForJob() int {
// Although the session destroy is already deferred, we do it again here,
// because invalidation of the session before destroyData() ensures there is
// no race condition allowing an agent to upload data (the acquire will fail).
defer c.destroySession()
start := time.Now()
ackCh := make(chan rExecAck, 128)
heartCh := make(chan rExecHeart, 128)
outputCh := make(chan rExecOutput, 128)
exitCh := make(chan rExecExit, 128)
doneCh := make(chan struct{})
errCh := make(chan struct{}, 1)
defer close(doneCh)
go c.streamResults(doneCh, ackCh, heartCh, outputCh, exitCh, errCh)
target := &TargetedUI{UI: c.UI}
var ackCount, exitCount, badExit int
OUTER:
for {
// Determine wait time. We provide a larger window if we know about
// nodes which are still working.
waitIntv := c.conf.wait
if ackCount > exitCount {
waitIntv *= 2
}
select {
case e := <-ackCh:
ackCount++
if c.conf.verbose {
target.Target = e.Node
target.Info("acknowledged")
}
case h := <-heartCh:
if c.conf.verbose {
target.Target = h.Node
target.Info("heartbeat received")
}
case e := <-outputCh:
target.Target = e.Node
target.Output(string(e.Output))
case e := <-exitCh:
exitCount++
target.Target = e.Node
target.Info(fmt.Sprintf("finished with exit code %d", e.Code))
if e.Code != 0 {
badExit++
}
case <-time.After(waitIntv):
c.UI.Info(fmt.Sprintf("%d / %d node(s) completed / acknowledged", exitCount, ackCount))
if c.conf.verbose {
c.UI.Info(fmt.Sprintf("Completed in %0.2f seconds",
float64(time.Since(start))/float64(time.Second)))
}
if exitCount < ackCount {
badExit++
}
break OUTER
case <-errCh:
return 1
case <-c.shutdownCh:
return 1
}
}
if badExit > 0 {
return 2
}
return 0
}
// streamResults is used to perform blocking queries against the KV endpoint and stream in
// notice of various events into waitForJob
func (c *cmd) streamResults(doneCh chan struct{}, ackCh chan rExecAck, heartCh chan rExecHeart,
outputCh chan rExecOutput, exitCh chan rExecExit, errCh chan struct{}) {
kv := c.apiclient.KV()
opts := api.QueryOptions{WaitTime: c.conf.wait}
dir := path.Join(c.conf.prefix, c.sessionID) + "/"
seen := make(map[string]struct{})
for {
// Check if we've been signaled to exit
select {
case <-doneCh:
return
default:
}
// Block on waiting for new keys
keys, qm, err := kv.Keys(dir, "", &opts)
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to read results: %s", err))
goto ERR_EXIT
}
// Fast-path the no-change case
if qm.LastIndex == opts.WaitIndex {
continue
}
opts.WaitIndex = qm.LastIndex
// Handle each key
for _, key := range keys {
// Ignore if we've seen it
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
// Trim the directory
full := key
key = strings.TrimPrefix(key, dir)
// Handle the key type
switch {
case key == rExecFileName:
continue
case strings.HasSuffix(key, rExecAckSuffix):
ackCh <- rExecAck{Node: strings.TrimSuffix(key, rExecAckSuffix)}
case strings.HasSuffix(key, rExecExitSuffix):
pair, _, err := kv.Get(full, nil)
if err != nil || pair == nil {
c.UI.Error(fmt.Sprintf("Failed to read key '%s': %v", full, err))
continue
}
code, err := strconv.ParseInt(string(pair.Value), 10, 32)
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to parse exit code '%s': %v", pair.Value, err))
continue
}
exitCh <- rExecExit{
Node: strings.TrimSuffix(key, rExecExitSuffix),
Code: int(code),
}
case strings.LastIndex(key, rExecOutputDivider) != -1:
pair, _, err := kv.Get(full, nil)
if err != nil || pair == nil {
c.UI.Error(fmt.Sprintf("Failed to read key '%s': %v", full, err))
continue
}
idx := strings.LastIndex(key, rExecOutputDivider)
node := key[:idx]
if len(pair.Value) == 0 {
heartCh <- rExecHeart{Node: node}
} else {
outputCh <- rExecOutput{Node: node, Output: pair.Value}
}
default:
c.UI.Error(fmt.Sprintf("Unknown key '%s', ignoring.", key))
}
}
}
ERR_EXIT:
select {
case errCh <- struct{}{}:
default:
}
}
// validate checks that the configuration is reasonable
func (conf *rExecConf) validate() error {
// Validate the filters
if conf.node != "" {
if _, err := regexp.Compile(conf.node); err != nil {
return fmt.Errorf("Failed to compile node filter regexp: %v", err)
}
}
if conf.service != "" {
if _, err := regexp.Compile(conf.service); err != nil {
return fmt.Errorf("Failed to compile service filter regexp: %v", err)
}
}
if conf.tag != "" {
if _, err := regexp.Compile(conf.tag); err != nil {
return fmt.Errorf("Failed to compile tag filter regexp: %v", err)
}
}
if conf.tag != "" && conf.service == "" {
return fmt.Errorf("Cannot provide tag filter without service filter.")
}
return nil
}
// createSession is used to create a new session for this command
func (c *cmd) createSession() (string, error) {
var id string
var err error
if c.conf.foreignDC {
id, err = c.createSessionForeign()
} else {
id, err = c.createSessionLocal()
}
if err == nil {
c.stopCh = make(chan struct{})
go c.renewSession(id, c.stopCh)
}
return id, err
}
// createSessionLocal is used to create a new session in a local datacenter
// This is simpler since we can use the local agent to create the session.
func (c *cmd) createSessionLocal() (string, error) {
session := c.apiclient.Session()
se := api.SessionEntry{
Name: "Remote Exec",
Behavior: api.SessionBehaviorDelete,
TTL: rExecTTL,
}
id, _, err := session.Create(&se, nil)
return id, err
}
// createSessionLocal is used to create a new session in a foreign datacenter
// This is more complex since the local agent cannot be used to create
// a session, and we must associate with a node in the remote datacenter.
func (c *cmd) createSessionForeign() (string, error) {
// Look for a remote node to bind to
health := c.apiclient.Health()
services, _, err := health.Service("consul", "", true, nil)
if err != nil {
return "", fmt.Errorf("Failed to find Consul server in remote datacenter: %v", err)
}
if len(services) == 0 {
return "", fmt.Errorf("Failed to find Consul server in remote datacenter")
}
node := services[0].Node.Node
if c.conf.verbose {
c.UI.Info(fmt.Sprintf("Binding session to remote node %s@%s", node, c.http.Datacenter()))
}
session := c.apiclient.Session()
se := api.SessionEntry{
Name: fmt.Sprintf("Remote Exec via %s@%s", c.conf.localNode, c.conf.localDC),
Node: node,
Checks: []string{},
Behavior: api.SessionBehaviorDelete,
TTL: rExecTTL,
}
id, _, err := session.CreateNoChecks(&se, nil)
return id, err
}
// renewSession is a long running routine that periodically renews
// the session TTL. This is used for foreign sessions where we depend
// on TTLs.
func (c *cmd) renewSession(id string, stopCh chan struct{}) {
session := c.apiclient.Session()
for {
select {
case <-time.After(rExecRenewInterval):
_, _, err := session.Renew(id, nil)
if err != nil {
c.UI.Error(fmt.Sprintf("Session renew failed: %v", err))
return
}
case <-stopCh:
return
}
}
}
// destroySession is used to destroy the associated session
func (c *cmd) destroySession() error {
// Stop the session renew if any
if c.stopCh != nil {
close(c.stopCh)
c.stopCh = nil
}
// Destroy the session explicitly
session := c.apiclient.Session()
_, err := session.Destroy(c.sessionID, nil)
return err
}
// makeRExecSpec creates a serialized job specification
// that can be uploaded which will be parsed by agents to
// determine what to do.
func (c *cmd) makeRExecSpec() ([]byte, error) {
spec := &rExecSpec{
Command: c.conf.cmd,
Args: c.conf.args,
Script: c.conf.script,
Wait: c.conf.wait,
}
return json.Marshal(spec)
}
// uploadPayload is used to upload the request payload
func (c *cmd) uploadPayload(payload []byte) error {
kv := c.apiclient.KV()
pair := api.KVPair{
Key: path.Join(c.conf.prefix, c.sessionID, rExecFileName),
Value: payload,
Session: c.sessionID,
}
ok, _, err := kv.Acquire(&pair, nil)
if err != nil {
return err
}
if !ok {
return fmt.Errorf("failed to acquire key %s", pair.Key)
}
return nil
}
// destroyData is used to nuke all the data associated with
// this remote exec. We just do a recursive delete of our
// data directory.
func (c *cmd) destroyData() error {
kv := c.apiclient.KV()
dir := path.Join(c.conf.prefix, c.sessionID)
_, err := kv.DeleteTree(dir, nil)
return err
}
// fireEvent is used to fire the event that will notify nodes
// about the remote execution. Returns the event ID or error
func (c *cmd) fireEvent() (string, error) {
// Create the user event payload
msg := &rExecEvent{
Prefix: c.conf.prefix,
Session: c.sessionID,
}
buf, err := json.Marshal(msg)
if err != nil {
return "", err
}
// Format the user event
event := c.apiclient.Event()
params := &api.UserEvent{
Name: "_rexec",
Payload: buf,
NodeFilter: c.conf.node,
ServiceFilter: c.conf.service,
TagFilter: c.conf.tag,
}
// Fire the event
id, _, err := event.Fire(params, nil)
return id, err
}
const (
// rExecPrefix is the prefix in the KV store used to
// store the remote exec data
rExecPrefix = "_rexec"
// rExecFileName is the name of the file we append to
// the path, e.g. _rexec/session_id/job
rExecFileName = "job"
// rExecAck is the suffix added to an ack path
rExecAckSuffix = "/ack"
// rExecAck is the suffix added to an exit code
rExecExitSuffix = "/exit"
// rExecOutputDivider is used to namespace the output
rExecOutputDivider = "/out/"
// rExecReplicationWait is how long we wait for replication
rExecReplicationWait = 200 * time.Millisecond
// rExecQuietWait is how long we wait for no responses
// before assuming the job is done.
rExecQuietWait = 2 * time.Second
// rExecTTL is how long we default the session TTL to
rExecTTL = "15s"
// rExecRenewInterval is how often we renew the session TTL
// when doing an exec in a foreign DC.
rExecRenewInterval = 5 * time.Second
)
// rExecConf is used to pass around configuration
type rExecConf struct {
prefix string
shell bool
foreignDC bool
localDC string
localNode string
node string
service string
tag string
wait time.Duration
replWait time.Duration
cmd string
args []string
script []byte
verbose bool
}
// rExecEvent is the event we broadcast using a user-event
type rExecEvent struct {
Prefix string
Session string
}
// rExecSpec is the file we upload to specify the parameters
// of the remote execution.
type rExecSpec struct {
// Command is a single command to run directly in the shell
Command string `json:",omitempty"`
// Args is the list of arguments to run the subprocess directly
Args []string `json:",omitempty"`
// Script should be spilled to a file and executed
Script []byte `json:",omitempty"`
// Wait is how long we are waiting on a quiet period to terminate
Wait time.Duration
}
// rExecAck is used to transmit an acknowledgement
type rExecAck struct {
Node string
}
// rExecHeart is used to transmit a heartbeat
type rExecHeart struct {
Node string
}
// rExecOutput is used to transmit a chunk of output
type rExecOutput struct {
Node string
Output []byte
}
// rExecExit is used to transmit an exit code
type rExecExit struct {
Node string
Code int
}
// TargetedUI is a UI that wraps another UI implementation and modifies
// the output to indicate a specific target. Specifically, all Say output
// is prefixed with the target name. Message output is not prefixed but
// is offset by the length of the target so that output is lined up properly
// with Say output. Machine-readable output has the proper target set.
type TargetedUI struct {
Target string
UI cli.Ui
}
func (u *TargetedUI) Ask(query string) (string, error) {
return u.UI.Ask(u.prefixLines(true, query))
}
func (u *TargetedUI) Info(message string) {
u.UI.Info(u.prefixLines(true, message))
}
func (u *TargetedUI) Output(message string) {
u.UI.Output(u.prefixLines(false, message))
}
func (u *TargetedUI) Error(message string) {
u.UI.Error(u.prefixLines(true, message))
}
func (u *TargetedUI) prefixLines(arrow bool, message string) string {
arrowText := "==>"
if !arrow {
arrowText = strings.Repeat(" ", len(arrowText))
}
var result bytes.Buffer
for _, line := range strings.Split(message, "\n") {
result.WriteString(fmt.Sprintf("%s %s: %s\n", arrowText, u.Target, line))
}
return strings.TrimRightFunc(result.String(), unicode.IsSpace)
}