mirror of https://github.com/status-im/consul.git
Allow cancelling startup when performing auto-config (#8157)
Co-authored-by: Daniel Nephin <dnephin@hashicorp.com>
This commit is contained in:
parent
1fa4d570d2
commit
0736c42b72
|
@ -257,8 +257,6 @@ type Agent struct {
|
||||||
shutdownCh chan struct{}
|
shutdownCh chan struct{}
|
||||||
shutdownLock sync.Mutex
|
shutdownLock sync.Mutex
|
||||||
|
|
||||||
InterruptStartCh chan struct{}
|
|
||||||
|
|
||||||
// joinLANNotifier is called after a successful JoinLAN.
|
// joinLANNotifier is called after a successful JoinLAN.
|
||||||
joinLANNotifier notifier
|
joinLANNotifier notifier
|
||||||
|
|
||||||
|
@ -427,7 +425,6 @@ func New(options ...AgentOption) (*Agent, error) {
|
||||||
joinLANNotifier: &systemd.Notifier{},
|
joinLANNotifier: &systemd.Notifier{},
|
||||||
retryJoinCh: make(chan error),
|
retryJoinCh: make(chan error),
|
||||||
shutdownCh: make(chan struct{}),
|
shutdownCh: make(chan struct{}),
|
||||||
InterruptStartCh: make(chan struct{}),
|
|
||||||
endpoints: make(map[string]string),
|
endpoints: make(map[string]string),
|
||||||
tokens: new(token.Store),
|
tokens: new(token.Store),
|
||||||
logger: flat.logger,
|
logger: flat.logger,
|
||||||
|
@ -599,13 +596,13 @@ func LocalConfig(cfg *config.RuntimeConfig) local.Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start verifies its configuration and runs an agent's various subprocesses.
|
// Start verifies its configuration and runs an agent's various subprocesses.
|
||||||
func (a *Agent) Start() error {
|
func (a *Agent) Start(ctx context.Context) error {
|
||||||
a.stateLock.Lock()
|
a.stateLock.Lock()
|
||||||
defer a.stateLock.Unlock()
|
defer a.stateLock.Unlock()
|
||||||
|
|
||||||
// This needs to be done early on as it will potentially alter the configuration
|
// This needs to be done early on as it will potentially alter the configuration
|
||||||
// and then how other bits are brought up
|
// and then how other bits are brought up
|
||||||
c, err := a.autoConf.InitialConfiguration(&lib.StopChannelContext{StopCh: a.shutdownCh})
|
c, err := a.autoConf.InitialConfiguration(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -709,7 +706,7 @@ func (a *Agent) Start() error {
|
||||||
a.registerCache()
|
a.registerCache()
|
||||||
|
|
||||||
if a.config.AutoEncryptTLS && !a.config.ServerMode {
|
if a.config.AutoEncryptTLS && !a.config.ServerMode {
|
||||||
reply, err := a.setupClientAutoEncrypt()
|
reply, err := a.setupClientAutoEncrypt(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("AutoEncrypt failed: %s", err)
|
return fmt.Errorf("AutoEncrypt failed: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -822,7 +819,7 @@ func (a *Agent) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) setupClientAutoEncrypt() (*structs.SignedResponse, error) {
|
func (a *Agent) setupClientAutoEncrypt(ctx context.Context) (*structs.SignedResponse, error) {
|
||||||
client := a.delegate.(*consul.Client)
|
client := a.delegate.(*consul.Client)
|
||||||
|
|
||||||
addrs := a.config.StartJoinAddrsLAN
|
addrs := a.config.StartJoinAddrsLAN
|
||||||
|
@ -832,7 +829,7 @@ func (a *Agent) setupClientAutoEncrypt() (*structs.SignedResponse, error) {
|
||||||
}
|
}
|
||||||
addrs = append(addrs, retryJoinAddrs(disco, retryJoinSerfVariant, "LAN", a.config.RetryJoinLAN, a.logger)...)
|
addrs = append(addrs, retryJoinAddrs(disco, retryJoinSerfVariant, "LAN", a.config.RetryJoinLAN, a.logger)...)
|
||||||
|
|
||||||
reply, priv, err := client.RequestAutoEncryptCerts(addrs, a.config.ServerPort, a.tokens.AgentToken(), a.InterruptStartCh)
|
reply, priv, err := client.RequestAutoEncryptCerts(ctx, addrs, a.config.ServerPort, a.tokens.AgentToken())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -961,7 +958,8 @@ func (a *Agent) setupClientAutoEncryptWatching(rootsReq *structs.DCSpecificReque
|
||||||
// check auto encrypt client cert expiration
|
// check auto encrypt client cert expiration
|
||||||
if a.tlsConfigurator.AutoEncryptCertExpired() {
|
if a.tlsConfigurator.AutoEncryptCertExpired() {
|
||||||
autoLogger.Debug("client certificate expired.")
|
autoLogger.Debug("client certificate expired.")
|
||||||
reply, err := a.setupClientAutoEncrypt()
|
// Background because the context is mainly useful when the agent is first starting up.
|
||||||
|
reply, err := a.setupClientAutoEncrypt(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
autoLogger.Error("client certificate expired, failed to renew", "error", err)
|
autoLogger.Error("client certificate expired, failed to renew", "error", err)
|
||||||
// in case of an error, try again in one minute
|
// in case of an error, try again in one minute
|
||||||
|
|
|
@ -262,6 +262,7 @@ func (ac *AutoConfig) InitialConfiguration(ctx context.Context) (*config.Runtime
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ready {
|
if !ready {
|
||||||
|
ac.logger.Info("retrieving initial agent auto configuration remotely")
|
||||||
if err := ac.getInitialConfiguration(ctx); err != nil {
|
if err := ac.getInitialConfiguration(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -447,7 +448,7 @@ func (ac *AutoConfig) getInitialConfigurationOnce(ctx context.Context) (bool, er
|
||||||
return false, ctx.Err()
|
return false, ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
ac.logger.Debug("Making Cluster.AutoConfig RPC", "addr", addr.String())
|
ac.logger.Debug("making Cluster.AutoConfig RPC", "addr", addr.String())
|
||||||
if err = ac.directRPC.RPC(ac.config.Datacenter, ac.config.NodeName, &addr, "Cluster.AutoConfig", &request, &reply); err != nil {
|
if err = ac.directRPC.RPC(ac.config.Datacenter, ac.config.NodeName, &addr, "Cluster.AutoConfig", &request, &reply); err != nil {
|
||||||
ac.logger.Error("AutoConfig RPC failed", "addr", addr.String(), "error", err)
|
ac.logger.Error("AutoConfig RPC failed", "addr", addr.String(), "error", err)
|
||||||
continue
|
continue
|
||||||
|
@ -457,7 +458,7 @@ func (ac *AutoConfig) getInitialConfigurationOnce(ctx context.Context) (bool, er
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
return false, ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// getInitialConfiguration implements a loop to retry calls to getInitialConfigurationOnce.
|
// getInitialConfiguration implements a loop to retry calls to getInitialConfigurationOnce.
|
||||||
|
@ -469,11 +470,16 @@ func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-wait:
|
case <-wait:
|
||||||
if done, err := ac.getInitialConfigurationOnce(ctx); done {
|
done, err := ac.getInitialConfigurationOnce(ctx)
|
||||||
|
if done {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
ac.logger.Error(err.Error())
|
||||||
|
}
|
||||||
wait = ac.waiter.Failed()
|
wait = ac.waiter.Failed()
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
ac.logger.Info("interrupted during initial auto configuration", "err", ctx.Err())
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1923,6 +1923,15 @@ func (b *Builder) autoConfigVal(raw AutoConfigRaw) AutoConfig {
|
||||||
|
|
||||||
val.Enabled = b.boolValWithDefault(raw.Enabled, false)
|
val.Enabled = b.boolValWithDefault(raw.Enabled, false)
|
||||||
val.IntroToken = b.stringVal(raw.IntroToken)
|
val.IntroToken = b.stringVal(raw.IntroToken)
|
||||||
|
|
||||||
|
// default the IntroToken to the env variable if specified.
|
||||||
|
if envToken := os.Getenv("CONSUL_INTRO_TOKEN"); envToken != "" {
|
||||||
|
if val.IntroToken != "" {
|
||||||
|
b.warn("Both auto_config.intro_token and the CONSUL_INTRO_TOKEN environment variable are set. Using the value from the environment variable")
|
||||||
|
}
|
||||||
|
|
||||||
|
val.IntroToken = envToken
|
||||||
|
}
|
||||||
val.IntroTokenFile = b.stringVal(raw.IntroTokenFile)
|
val.IntroTokenFile = b.stringVal(raw.IntroTokenFile)
|
||||||
// These can be go-discover values and so don't have to resolve fully yet
|
// These can be go-discover values and so don't have to resolve fully yet
|
||||||
val.ServerAddresses = b.expandAllOptionalAddrs("auto_config.server_addresses", raw.ServerAddresses)
|
val.ServerAddresses = b.expandAllOptionalAddrs("auto_config.server_addresses", raw.ServerAddresses)
|
||||||
|
@ -1995,9 +2004,9 @@ func (b *Builder) validateAutoConfig(rt RuntimeConfig) error {
|
||||||
|
|
||||||
// When both are set we will prefer the given value over the file.
|
// When both are set we will prefer the given value over the file.
|
||||||
if autoconf.IntroToken != "" && autoconf.IntroTokenFile != "" {
|
if autoconf.IntroToken != "" && autoconf.IntroTokenFile != "" {
|
||||||
b.warn("auto_config.intro_token and auto_config.intro_token_file are both set. Using the value of auto_config.intro_token")
|
b.warn("Both an intro token and intro token file are set. The intro token will be used instead of the file")
|
||||||
} else if autoconf.IntroToken == "" && autoconf.IntroTokenFile == "" {
|
} else if autoconf.IntroToken == "" && autoconf.IntroTokenFile == "" {
|
||||||
return fmt.Errorf("one of auto_config.intro_token or auto_config.intro_token_file must be set to enable auto_config")
|
return fmt.Errorf("One of auto_config.intro_token, auto_config.intro_token_file or the CONSUL_INTRO_TOKEN environment variable must be set to enable auto_config")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(autoconf.ServerAddresses) == 0 {
|
if len(autoconf.ServerAddresses) == 0 {
|
||||||
|
|
|
@ -3847,7 +3847,7 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
||||||
"server_addresses": ["198.18.0.1"]
|
"server_addresses": ["198.18.0.1"]
|
||||||
}
|
}
|
||||||
}`},
|
}`},
|
||||||
err: "one of auto_config.intro_token or auto_config.intro_token_file must be set to enable auto_config",
|
err: "One of auto_config.intro_token, auto_config.intro_token_file or the CONSUL_INTRO_TOKEN environment variable must be set to enable auto_config",
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -3899,7 +3899,7 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
||||||
}`},
|
}`},
|
||||||
warns: []string{
|
warns: []string{
|
||||||
"Cannot parse ip \"invalid\" from auto_config.ip_sans",
|
"Cannot parse ip \"invalid\" from auto_config.ip_sans",
|
||||||
"auto_config.intro_token and auto_config.intro_token_file are both set. Using the value of auto_config.intro_token",
|
"Both an intro token and intro token file are set. The intro token will be used instead of the file",
|
||||||
},
|
},
|
||||||
patch: func(rt *RuntimeConfig) {
|
patch: func(rt *RuntimeConfig) {
|
||||||
rt.AutoConfig.Enabled = true
|
rt.AutoConfig.Enabled = true
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package consul
|
package consul
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -18,7 +19,7 @@ const (
|
||||||
retryJitterWindow = 30 * time.Second
|
retryJitterWindow = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) {
|
func (c *Client) RequestAutoEncryptCerts(ctx context.Context, servers []string, port int, token string) (*structs.SignedResponse, string, error) {
|
||||||
errFn := func(err error) (*structs.SignedResponse, string, error) {
|
errFn := func(err error) (*structs.SignedResponse, string, error) {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
@ -92,8 +93,8 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin
|
||||||
attempts := 0
|
attempts := 0
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-interruptCh:
|
case <-ctx.Done():
|
||||||
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted"))
|
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted: %w", ctx.Err()))
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,8 +125,8 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin
|
||||||
select {
|
select {
|
||||||
case <-time.After(interval):
|
case <-time.After(interval):
|
||||||
continue
|
continue
|
||||||
case <-interruptCh:
|
case <-ctx.Done():
|
||||||
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted"))
|
return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted: %w", ctx.Err()))
|
||||||
case <-c.shutdownCh:
|
case <-c.shutdownCh:
|
||||||
return errFn(fmt.Errorf("aborting AutoEncrypt because shutting down"))
|
return errFn(fmt.Errorf("aborting AutoEncrypt because shutting down"))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package consul
|
package consul
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -90,11 +91,14 @@ func TestAutoEncrypt_RequestAutoEncryptCerts(t *testing.T) {
|
||||||
servers := []string{"localhost"}
|
servers := []string{"localhost"}
|
||||||
port := 8301
|
port := 8301
|
||||||
token := ""
|
token := ""
|
||||||
interruptCh := make(chan struct{})
|
|
||||||
|
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(75*time.Millisecond))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
doneCh := make(chan struct{})
|
doneCh := make(chan struct{})
|
||||||
var err error
|
var err error
|
||||||
go func() {
|
go func() {
|
||||||
_, _, err = c1.RequestAutoEncryptCerts(servers, port, token, interruptCh)
|
_, _, err = c1.RequestAutoEncryptCerts(ctx, servers, port, token)
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
|
@ -104,9 +108,8 @@ func TestAutoEncrypt_RequestAutoEncryptCerts(t *testing.T) {
|
||||||
// in the setup phase before entering the for loop in
|
// in the setup phase before entering the for loop in
|
||||||
// RequestAutoEncryptCerts.
|
// RequestAutoEncryptCerts.
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
case <-time.After(50 * time.Millisecond):
|
case <-ctx.Done():
|
||||||
// this is the happy case since auto encrypt is in its loop to
|
// this is the happy case since auto encrypt is in its loop to
|
||||||
// try to request certs.
|
// try to request certs.
|
||||||
interruptCh <- struct{}{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
@ -258,7 +259,7 @@ func (a *TestAgent) Start(t *testing.T) (err error) {
|
||||||
|
|
||||||
id := string(a.Config.NodeID)
|
id := string(a.Config.NodeID)
|
||||||
|
|
||||||
if err := agent.Start(); err != nil {
|
if err := agent.Start(context.Background()); err != nil {
|
||||||
cleanupTmpDir()
|
cleanupTmpDir()
|
||||||
agent.ShutdownAgent()
|
agent.ShutdownAgent()
|
||||||
agent.ShutdownEndpoints()
|
agent.ShutdownEndpoints()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
@ -212,16 +213,17 @@ func (c *cmd) run(args []string) int {
|
||||||
|
|
||||||
// wait for signal
|
// wait for signal
|
||||||
signalCh := make(chan os.Signal, 10)
|
signalCh := make(chan os.Signal, 10)
|
||||||
stopCh := make(chan struct{})
|
|
||||||
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGPIPE)
|
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGPIPE)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
var sig os.Signal
|
var sig os.Signal
|
||||||
select {
|
select {
|
||||||
case s := <-signalCh:
|
case s := <-signalCh:
|
||||||
sig = s
|
sig = s
|
||||||
case <-stopCh:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -235,18 +237,16 @@ func (c *cmd) run(args []string) int {
|
||||||
|
|
||||||
default:
|
default:
|
||||||
c.logger.Info("Caught", "signal", sig)
|
c.logger.Info("Caught", "signal", sig)
|
||||||
agent.InterruptStartCh <- struct{}{}
|
cancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = agent.Start()
|
err = agent.Start(ctx)
|
||||||
signal.Stop(signalCh)
|
signal.Stop(signalCh)
|
||||||
select {
|
cancel()
|
||||||
case stopCh <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("Error starting agent", "error", err)
|
c.logger.Error("Error starting agent", "error", err)
|
||||||
return 1
|
return 1
|
||||||
|
|
Loading…
Reference in New Issue