diff --git a/agent/consul/auto_encrypt.go b/agent/consul/auto_encrypt.go index 8d5a2eabb8..3acf15a614 100644 --- a/agent/consul/auto_encrypt.go +++ b/agent/consul/auto_encrypt.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "net" + "strconv" "strings" "time" @@ -18,7 +19,7 @@ const ( retryJitterWindow = 30 * time.Second ) -func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) { +func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) { errFn := func(err error) (*structs.SignedResponse, string, error) { return nil, "", err } @@ -81,11 +82,12 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin // Translate host to net.TCPAddr to make life easier for // RPCInsecure. for _, s := range servers { - ips, err := resolveAddr(s, c.logger) + ips, port, err := resolveAddr(s, defaultPort, c.logger) if err != nil { c.logger.Printf("[WARN] agent: AutoEncrypt resolveAddr failed: %v", err) continue } + for _, ip := range ips { addr := net.TCPAddr{IP: ip, Port: port} @@ -112,16 +114,29 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin } } -// resolveAddr is used to resolve the address into an address, -// port, and error. If no port is given, use the default -func resolveAddr(rawHost string, logger *log.Logger) ([]net.IP, error) { - host, _, err := net.SplitHostPort(rawHost) - if err != nil && err.Error() != "missing port in address" { - return nil, err +// resolveAddr is used to resolve the host into IPs, port, and error. +// If no port is given, use the default +func resolveAddr(rawHost string, defaultPort int, logger *log.Logger) ([]net.IP, int, error) { + host, splitPort, err := net.SplitHostPort(rawHost) + if err != nil && err.Error() != fmt.Sprintf("address %s: missing port in address", rawHost) { + return nil, defaultPort, err + } + + // SplitHostPort returns empty host and splitPort on missingPort err, + // so those are set to defaults + var port int + if err != nil { + host = rawHost + port = defaultPort + } else { + port, err = strconv.Atoi(splitPort) + if err != nil { + port = defaultPort + } } if ip := net.ParseIP(host); ip != nil { - return []net.IP{ip}, nil + return []net.IP{ip}, port, nil } // First try TCP so we have the best chance for the largest list of @@ -130,13 +145,17 @@ func resolveAddr(rawHost string, logger *log.Logger) ([]net.IP, error) { if ips, err := tcpLookupIP(host, logger); err != nil { logger.Printf("[DEBUG] agent: TCP-first lookup failed for '%s', falling back to UDP: %s", host, err) } else if len(ips) > 0 { - return ips, nil + return ips, port, nil } // If TCP didn't yield anything then use the normal Go resolver which // will try UDP, then might possibly try TCP again if the UDP response // indicates it was truncated. - return net.LookupIP(host) + ips, err := net.LookupIP(host) + if err != nil { + return nil, port, err + } + return ips, port, nil } // tcpLookupIP is a helper to initiate a TCP-based DNS lookup for the given host. diff --git a/agent/consul/auto_encrypt_test.go b/agent/consul/auto_encrypt_test.go new file mode 100644 index 0000000000..2a4daa012b --- /dev/null +++ b/agent/consul/auto_encrypt_test.go @@ -0,0 +1,80 @@ +package consul + +import ( + "github.com/stretchr/testify/require" + "log" + "net" + "os" + "testing" +) + +func TestAutoEncrypt_resolveAddr(t *testing.T) { + type args struct { + rawHost string + defaultPort int + logger *log.Logger + } + tests := []struct { + name string + args args + ips []net.IP + port int + wantErr bool + }{ + { + name: "host without port", + args: args{ + "127.0.0.1", + 8300, + log.New(os.Stderr, "", log.LstdFlags), + }, + ips: []net.IP{net.IPv4(127, 0, 0, 1)}, + port: 8300, + wantErr: false, + }, + { + name: "host with port", + args: args{ + "127.0.0.1:1234", + 8300, + log.New(os.Stderr, "", log.LstdFlags), + }, + ips: []net.IP{net.IPv4(127, 0, 0, 1)}, + port: 1234, + wantErr: false, + }, + { + name: "host with broken port", + args: args{ + "127.0.0.1:xyz", + 8300, + log.New(os.Stderr, "", log.LstdFlags), + }, + ips: []net.IP{net.IPv4(127, 0, 0, 1)}, + port: 8300, + wantErr: false, + }, + { + name: "not an address", + args: args{ + "abc", + 8300, + log.New(os.Stderr, "", log.LstdFlags), + }, + ips: nil, + port: 8300, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ips, port, err := resolveAddr(tt.args.rawHost, tt.args.defaultPort, tt.args.logger) + if (err != nil) != tt.wantErr { + t.Errorf("resolveAddr error: %v, wantErr: %v", err, tt.wantErr) + return + } + require.Equal(t, tt.ips, ips) + require.Equal(t, tt.port, port) + }) + } +}