mirror of https://github.com/status-im/consul.git
Allow cancelling startup when performing auto-config (#8157)
Co-authored-by: Daniel Nephin <dnephin@hashicorp.com>
This commit is contained in:
parent
1fa4d570d2
commit
0736c42b72
|
@ -257,8 +257,6 @@ type Agent struct {
|
|||
shutdownCh chan struct{}
|
||||
shutdownLock sync.Mutex
|
||||
|
||||
InterruptStartCh chan struct{}
|
||||
|
||||
// joinLANNotifier is called after a successful JoinLAN.
|
||||
joinLANNotifier notifier
|
||||
|
||||
|
@ -414,23 +412,22 @@ func New(options ...AgentOption) (*Agent, error) {
|
|||
|
||||
// Create most of the agent
|
||||
a := Agent{
|
||||
checkReapAfter: make(map[structs.CheckID]time.Duration),
|
||||
checkMonitors: make(map[structs.CheckID]*checks.CheckMonitor),
|
||||
checkTTLs: make(map[structs.CheckID]*checks.CheckTTL),
|
||||
checkHTTPs: make(map[structs.CheckID]*checks.CheckHTTP),
|
||||
checkTCPs: make(map[structs.CheckID]*checks.CheckTCP),
|
||||
checkGRPCs: make(map[structs.CheckID]*checks.CheckGRPC),
|
||||
checkDockers: make(map[structs.CheckID]*checks.CheckDocker),
|
||||
checkAliases: make(map[structs.CheckID]*checks.CheckAlias),
|
||||
eventCh: make(chan serf.UserEvent, 1024),
|
||||
eventBuf: make([]*UserEvent, 256),
|
||||
joinLANNotifier: &systemd.Notifier{},
|
||||
retryJoinCh: make(chan error),
|
||||
shutdownCh: make(chan struct{}),
|
||||
InterruptStartCh: make(chan struct{}),
|
||||
endpoints: make(map[string]string),
|
||||
tokens: new(token.Store),
|
||||
logger: flat.logger,
|
||||
checkReapAfter: make(map[structs.CheckID]time.Duration),
|
||||
checkMonitors: make(map[structs.CheckID]*checks.CheckMonitor),
|
||||
checkTTLs: make(map[structs.CheckID]*checks.CheckTTL),
|
||||
checkHTTPs: make(map[structs.CheckID]*checks.CheckHTTP),
|
||||
checkTCPs: make(map[structs.CheckID]*checks.CheckTCP),
|
||||
checkGRPCs: make(map[structs.CheckID]*checks.CheckGRPC),
|
||||
checkDockers: make(map[structs.CheckID]*checks.CheckDocker),
|
||||
checkAliases: make(map[structs.CheckID]*checks.CheckAlias),
|
||||
eventCh: make(chan serf.UserEvent, 1024),
|
||||
eventBuf: make([]*UserEvent, 256),
|
||||
joinLANNotifier: &systemd.Notifier{},
|
||||
retryJoinCh: make(chan error),
|
||||
shutdownCh: make(chan struct{}),
|
||||
endpoints: make(map[string]string),
|
||||
tokens: new(token.Store),
|
||||
logger: flat.logger,
|
||||
}
|
||||
|
||||
// parse the configuration and handle the error/warnings
|
||||
|
@ -599,13 +596,13 @@ func LocalConfig(cfg *config.RuntimeConfig) local.Config {
|
|||
}
|
||||
|
||||
// Start verifies its configuration and runs an agent's various subprocesses.
|
||||
func (a *Agent) Start() error {
|
||||
func (a *Agent) Start(ctx context.Context) error {
|
||||
a.stateLock.Lock()
|
||||
defer a.stateLock.Unlock()
|
||||
|
||||
// This needs to be done early on as it will potentially alter the configuration
|
||||
// and then how other bits are brought up
|
||||
c, err := a.autoConf.InitialConfiguration(&lib.StopChannelContext{StopCh: a.shutdownCh})
|
||||
c, err := a.autoConf.InitialConfiguration(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -709,7 +706,7 @@ func (a *Agent) Start() error {
|
|||
a.registerCache()
|
||||
|
||||
if a.config.AutoEncryptTLS && !a.config.ServerMode {
|
||||
reply, err := a.setupClientAutoEncrypt()
|
||||
reply, err := a.setupClientAutoEncrypt(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AutoEncrypt failed: %s", err)
|
||||
}
|
||||
|
@ -822,7 +819,7 @@ func (a *Agent) Start() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) setupClientAutoEncrypt() (*structs.SignedResponse, error) {
|
||||
func (a *Agent) setupClientAutoEncrypt(ctx context.Context) (*structs.SignedResponse, error) {
|
||||
client := a.delegate.(*consul.Client)
|
||||
|
||||
addrs := a.config.StartJoinAddrsLAN
|
||||
|
@ -832,7 +829,7 @@ func (a *Agent) setupClientAutoEncrypt() (*structs.SignedResponse, error) {
|
|||
}
|
||||
addrs = append(addrs, retryJoinAddrs(disco, retryJoinSerfVariant, "LAN", a.config.RetryJoinLAN, a.logger)...)
|
||||
|
||||
reply, priv, err := client.RequestAutoEncryptCerts(addrs, a.config.ServerPort, a.tokens.AgentToken(), a.InterruptStartCh)
|
||||
reply, priv, err := client.RequestAutoEncryptCerts(ctx, addrs, a.config.ServerPort, a.tokens.AgentToken())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -961,7 +958,8 @@ func (a *Agent) setupClientAutoEncryptWatching(rootsReq *structs.DCSpecificReque
|
|||
// check auto encrypt client cert expiration
|
||||
if a.tlsConfigurator.AutoEncryptCertExpired() {
|
||||
autoLogger.Debug("client certificate expired.")
|
||||
reply, err := a.setupClientAutoEncrypt()
|
||||
// Background because the context is mainly useful when the agent is first starting up.
|
||||
reply, err := a.setupClientAutoEncrypt(context.Background())
|
||||
if err != nil {
|
||||
autoLogger.Error("client certificate expired, failed to renew", "error", err)
|
||||
// in case of an error, try again in one minute
|
||||
|
|
|
@ -262,6 +262,7 @@ func (ac *AutoConfig) InitialConfiguration(ctx context.Context) (*config.Runtime
|
|||
}
|
||||
|
||||
if !ready {
|
||||
ac.logger.Info("retrieving initial agent auto configuration remotely")
|
||||
if err := ac.getInitialConfiguration(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -447,7 +448,7 @@ func (ac *AutoConfig) getInitialConfigurationOnce(ctx context.Context) (bool, er
|
|||
return false, ctx.Err()
|
||||
}
|
||||
|
||||
ac.logger.Debug("Making Cluster.AutoConfig RPC", "addr", addr.String())
|
||||
ac.logger.Debug("making Cluster.AutoConfig RPC", "addr", addr.String())
|
||||
if err = ac.directRPC.RPC(ac.config.Datacenter, ac.config.NodeName, &addr, "Cluster.AutoConfig", &request, &reply); err != nil {
|
||||
ac.logger.Error("AutoConfig RPC failed", "addr", addr.String(), "error", err)
|
||||
continue
|
||||
|
@ -457,7 +458,7 @@ func (ac *AutoConfig) getInitialConfigurationOnce(ctx context.Context) (bool, er
|
|||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return false, ctx.Err()
|
||||
}
|
||||
|
||||
// getInitialConfiguration implements a loop to retry calls to getInitialConfigurationOnce.
|
||||
|
@ -469,11 +470,16 @@ func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) error {
|
|||
for {
|
||||
select {
|
||||
case <-wait:
|
||||
if done, err := ac.getInitialConfigurationOnce(ctx); done {
|
||||
done, err := ac.getInitialConfigurationOnce(ctx)
|
||||
if done {
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
ac.logger.Error(err.Error())
|
||||
}
|
||||
wait = ac.waiter.Failed()
|
||||
case <-ctx.Done():
|
||||
ac.logger.Info("interrupted during initial auto configuration", "err", ctx.Err())
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1923,6 +1923,15 @@ func (b *Builder) autoConfigVal(raw AutoConfigRaw) AutoConfig {
|
|||
|
||||
val.Enabled = b.boolValWithDefault(raw.Enabled, false)
|
||||
val.IntroToken = b.stringVal(raw.IntroToken)
|
||||
|
||||
// default the IntroToken to the env variable if specified.
|
||||
if envToken := os.Getenv("CONSUL_INTRO_TOKEN"); envToken != "" {
|
||||
if val.IntroToken != "" {
|
||||
b.warn("Both auto_config.intro_token and the CONSUL_INTRO_TOKEN environment variable are set. Using the value from the environment variable")
|
||||
}
|
||||
|
||||
val.IntroToken = envToken
|
||||
}
|
||||
val.IntroTokenFile = b.stringVal(raw.IntroTokenFile)
|
||||
// These can be go-discover values and so don't have to resolve fully yet
|
||||
val.ServerAddresses = b.expandAllOptionalAddrs("auto_config.server_addresses", raw.ServerAddresses)
|
||||
|
@ -1995,9 +2004,9 @@ func (b *Builder) validateAutoConfig(rt RuntimeConfig) error {
|
|||
|
||||
// When both are set we will prefer the given value over the file.
|
||||
if autoconf.IntroToken != "" && autoconf.IntroTokenFile != "" {
|
||||
b.warn("auto_config.intro_token and auto_config.intro_token_file are both set. Using the value of auto_config.intro_token")
|
||||
b.warn("Both an intro token and intro token file are set. The intro token will be used instead of the file")
|
||||
} else if autoconf.IntroToken == "" && autoconf.IntroTokenFile == "" {
|
||||
return fmt.Errorf("one of auto_config.intro_token or auto_config.intro_token_file must be set to enable auto_config")
|
||||
return fmt.Errorf("One of auto_config.intro_token, auto_config.intro_token_file or the CONSUL_INTRO_TOKEN environment variable must be set to enable auto_config")
|
||||
}
|
||||
|
||||
if len(autoconf.ServerAddresses) == 0 {
|
||||
|
|
|
@ -3847,7 +3847,7 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
|||
"server_addresses": ["198.18.0.1"]
|
||||
}
|
||||
}`},
|
||||
err: "one of auto_config.intro_token or auto_config.intro_token_file must be set to enable auto_config",
|
||||
err: "One of auto_config.intro_token, auto_config.intro_token_file or the CONSUL_INTRO_TOKEN environment variable must be set to enable auto_config",
|
||||
},
|
||||
|
||||
{
|
||||
|
@ -3899,7 +3899,7 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
|||
}`},
|
||||
warns: []string{
|
||||
"Cannot parse ip \"invalid\" from auto_config.ip_sans",
|
||||
"auto_config.intro_token and auto_config.intro_token_file are both set. Using the value of auto_config.intro_token",
|
||||
"Both an intro token and intro token file are set. The intro token will be used instead of the file",
|
||||
},
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.AutoConfig.Enabled = true
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
@ -18,7 +19,7 @@ const (
|
|||
retryJitterWindow = 30 * time.Second
|
||||
)
|
||||
|
||||
func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) {
|
||||
func (c *Client) RequestAutoEncryptCerts(ctx context.Context, servers []string, port int, token string) (*structs.SignedResponse, string, error) {
|
||||
errFn := func(err error) (*structs.SignedResponse, string, error) {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -92,8 +93,8 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin
|
|||
attempts := 0
|
||||
for {
|
||||
select {
|
||||
case <-interruptCh:
|
||||
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted"))
|
||||
case <-ctx.Done():
|
||||
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted: %w", ctx.Err()))
|
||||
default:
|
||||
}
|
||||
|
||||
|
@ -124,8 +125,8 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin
|
|||
select {
|
||||
case <-time.After(interval):
|
||||
continue
|
||||
case <-interruptCh:
|
||||
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted"))
|
||||
case <-ctx.Done():
|
||||
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted: %w", ctx.Err()))
|
||||
case <-c.shutdownCh:
|
||||
return errFn(fmt.Errorf("aborting AutoEncrypt because shutting down"))
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
@ -90,11 +91,14 @@ func TestAutoEncrypt_RequestAutoEncryptCerts(t *testing.T) {
|
|||
servers := []string{"localhost"}
|
||||
port := 8301
|
||||
token := ""
|
||||
interruptCh := make(chan struct{})
|
||||
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(75*time.Millisecond))
|
||||
defer cancel()
|
||||
|
||||
doneCh := make(chan struct{})
|
||||
var err error
|
||||
go func() {
|
||||
_, _, err = c1.RequestAutoEncryptCerts(servers, port, token, interruptCh)
|
||||
_, _, err = c1.RequestAutoEncryptCerts(ctx, servers, port, token)
|
||||
close(doneCh)
|
||||
}()
|
||||
select {
|
||||
|
@ -104,9 +108,8 @@ func TestAutoEncrypt_RequestAutoEncryptCerts(t *testing.T) {
|
|||
// in the setup phase before entering the for loop in
|
||||
// RequestAutoEncryptCerts.
|
||||
require.NoError(t, err)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
case <-ctx.Done():
|
||||
// this is the happy case since auto encrypt is in its loop to
|
||||
// try to request certs.
|
||||
interruptCh <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package agent
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/x509"
|
||||
|
@ -258,7 +259,7 @@ func (a *TestAgent) Start(t *testing.T) (err error) {
|
|||
|
||||
id := string(a.Config.NodeID)
|
||||
|
||||
if err := agent.Start(); err != nil {
|
||||
if err := agent.Start(context.Background()); err != nil {
|
||||
cleanupTmpDir()
|
||||
agent.ShutdownAgent()
|
||||
agent.ShutdownEndpoints()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -212,16 +213,17 @@ func (c *cmd) run(args []string) int {
|
|||
|
||||
// wait for signal
|
||||
signalCh := make(chan os.Signal, 10)
|
||||
stopCh := make(chan struct{})
|
||||
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGPIPE)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
for {
|
||||
var sig os.Signal
|
||||
select {
|
||||
case s := <-signalCh:
|
||||
sig = s
|
||||
case <-stopCh:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -235,18 +237,16 @@ func (c *cmd) run(args []string) int {
|
|||
|
||||
default:
|
||||
c.logger.Info("Caught", "signal", sig)
|
||||
agent.InterruptStartCh <- struct{}{}
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = agent.Start()
|
||||
err = agent.Start(ctx)
|
||||
signal.Stop(signalCh)
|
||||
select {
|
||||
case stopCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
c.logger.Error("Error starting agent", "error", err)
|
||||
return 1
|
||||
|
|
Loading…
Reference in New Issue