mirror of https://github.com/status-im/consul.git
auto_encrypt: use server-port (#6287)
AutoEncrypt needs the server-port because it wants to talk via RPC. Information from gossip might not be available at that point and thats why the server-port is being used.
This commit is contained in:
parent
59150281c5
commit
3e46352ccb
|
@ -4,7 +4,6 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -19,7 +18,7 @@ const (
|
|||
retryJitterWindow = 30 * time.Second
|
||||
)
|
||||
|
||||
func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) {
|
||||
func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) {
|
||||
errFn := func(err error) (*structs.SignedResponse, string, error) {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -82,7 +81,7 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, toke
|
|||
// Translate host to net.TCPAddr to make life easier for
|
||||
// RPCInsecure.
|
||||
for _, s := range servers {
|
||||
ips, port, err := resolveAddr(s, defaultPort, c.logger)
|
||||
ips, err := resolveAddr(s, c.logger)
|
||||
if err != nil {
|
||||
c.logger.Printf("[WARN] agent: AutoEncrypt resolveAddr failed: %v", err)
|
||||
continue
|
||||
|
@ -114,29 +113,26 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, toke
|
|||
}
|
||||
}
|
||||
|
||||
// resolveAddr is used to resolve the host into IPs, port, and error.
|
||||
// If no port is given, use the default
|
||||
func resolveAddr(rawHost string, defaultPort int, logger *log.Logger) ([]net.IP, int, error) {
|
||||
host, splitPort, err := net.SplitHostPort(rawHost)
|
||||
if err != nil && err.Error() != fmt.Sprintf("address %s: missing port in address", rawHost) {
|
||||
return nil, defaultPort, err
|
||||
}
|
||||
func missingPortError(host string, err error) bool {
|
||||
return err != nil && err.Error() == fmt.Sprintf("address %s: missing port in address", host)
|
||||
}
|
||||
|
||||
// SplitHostPort returns empty host and splitPort on missingPort err,
|
||||
// so those are set to defaults
|
||||
var port int
|
||||
// resolveAddr is used to resolve the host into IPs and error.
|
||||
func resolveAddr(rawHost string, logger *log.Logger) ([]net.IP, error) {
|
||||
host, _, err := net.SplitHostPort(rawHost)
|
||||
if err != nil {
|
||||
host = rawHost
|
||||
port = defaultPort
|
||||
} else {
|
||||
port, err = strconv.Atoi(splitPort)
|
||||
if err != nil {
|
||||
port = defaultPort
|
||||
// In case we encounter this error, we proceed with the
|
||||
// rawHost. This is fine since -start-join and -retry-join
|
||||
// take only hosts anyways and this is an expected case.
|
||||
if missingPortError(rawHost, err) {
|
||||
host = rawHost
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
return []net.IP{ip}, port, nil
|
||||
return []net.IP{ip}, nil
|
||||
}
|
||||
|
||||
// First try TCP so we have the best chance for the largest list of
|
||||
|
@ -145,7 +141,7 @@ func resolveAddr(rawHost string, defaultPort int, logger *log.Logger) ([]net.IP,
|
|||
if ips, err := tcpLookupIP(host, logger); err != nil {
|
||||
logger.Printf("[DEBUG] agent: TCP-first lookup failed for '%s', falling back to UDP: %s", host, err)
|
||||
} else if len(ips) > 0 {
|
||||
return ips, port, nil
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// If TCP didn't yield anything then use the normal Go resolver which
|
||||
|
@ -153,9 +149,9 @@ func resolveAddr(rawHost string, defaultPort int, logger *log.Logger) ([]net.IP,
|
|||
// indicates it was truncated.
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return nil, port, err
|
||||
return nil, err
|
||||
}
|
||||
return ips, port, nil
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// tcpLookupIP is a helper to initiate a TCP-based DNS lookup for the given host.
|
||||
|
|
|
@ -10,71 +10,70 @@ import (
|
|||
|
||||
func TestAutoEncrypt_resolveAddr(t *testing.T) {
|
||||
type args struct {
|
||||
rawHost string
|
||||
defaultPort int
|
||||
logger *log.Logger
|
||||
rawHost string
|
||||
logger *log.Logger
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
ips []net.IP
|
||||
port int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "host without port",
|
||||
args: args{
|
||||
"127.0.0.1",
|
||||
8300,
|
||||
log.New(os.Stderr, "", log.LstdFlags),
|
||||
},
|
||||
ips: []net.IP{net.IPv4(127, 0, 0, 1)},
|
||||
port: 8300,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "host with port",
|
||||
args: args{
|
||||
"127.0.0.1:1234",
|
||||
8300,
|
||||
log.New(os.Stderr, "", log.LstdFlags),
|
||||
},
|
||||
ips: []net.IP{net.IPv4(127, 0, 0, 1)},
|
||||
port: 1234,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "host with broken port",
|
||||
args: args{
|
||||
"127.0.0.1:xyz",
|
||||
8300,
|
||||
log.New(os.Stderr, "", log.LstdFlags),
|
||||
},
|
||||
ips: []net.IP{net.IPv4(127, 0, 0, 1)},
|
||||
port: 8300,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "not an address",
|
||||
args: args{
|
||||
"abc",
|
||||
8300,
|
||||
log.New(os.Stderr, "", log.LstdFlags),
|
||||
},
|
||||
ips: nil,
|
||||
port: 8300,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ips, port, err := resolveAddr(tt.args.rawHost, tt.args.defaultPort, tt.args.logger)
|
||||
ips, err := resolveAddr(tt.args.rawHost, tt.args.logger)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("resolveAddr error: %v, wantErr: %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
require.Equal(t, tt.ips, ips)
|
||||
require.Equal(t, tt.port, port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoEncrypt_missingPortError(t *testing.T) {
|
||||
host := "127.0.0.1"
|
||||
_, _, err := net.SplitHostPort(host)
|
||||
require.True(t, missingPortError(host, err))
|
||||
|
||||
host = "127.0.0.1:1234"
|
||||
_, _, err = net.SplitHostPort(host)
|
||||
require.False(t, missingPortError(host, err))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue