mirror of https://github.com/status-im/consul.git
parent
cff872749d
commit
013bcefe5c
|
@ -6,7 +6,6 @@ package client
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -21,7 +20,6 @@ import (
|
||||||
|
|
||||||
"github.com/hashicorp/go-cleanhttp"
|
"github.com/hashicorp/go-cleanhttp"
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/go-rootcerts"
|
|
||||||
|
|
||||||
"github.com/hashicorp/consul/api"
|
"github.com/hashicorp/consul/api"
|
||||||
)
|
)
|
||||||
|
@ -73,14 +71,6 @@ const (
|
||||||
// whether or not to disable certificate checking.
|
// whether or not to disable certificate checking.
|
||||||
HTTPSSLVerifyEnvName = "CONSUL_HTTP_SSL_VERIFY"
|
HTTPSSLVerifyEnvName = "CONSUL_HTTP_SSL_VERIFY"
|
||||||
|
|
||||||
// GRPCCAFileEnvName defines an environment variable name which sets the
|
|
||||||
// CA file to use for talking to Consul gRPC over TLS.
|
|
||||||
GRPCCAFileEnvName = "CONSUL_GRPC_CACERT"
|
|
||||||
|
|
||||||
// GRPCCAPathEnvName defines an environment variable name which sets the
|
|
||||||
// path to a directory of CA certs to use for talking to Consul gRPC over TLS.
|
|
||||||
GRPCCAPathEnvName = "CONSUL_GRPC_CAPATH"
|
|
||||||
|
|
||||||
// HTTPNamespaceEnvVar defines an environment variable name which sets
|
// HTTPNamespaceEnvVar defines an environment variable name which sets
|
||||||
// the HTTP Namespace to be used by default. This can still be overridden.
|
// the HTTP Namespace to be used by default. This can still be overridden.
|
||||||
HTTPNamespaceEnvName = "CONSUL_NAMESPACE"
|
HTTPNamespaceEnvName = "CONSUL_NAMESPACE"
|
||||||
|
@ -538,60 +528,6 @@ func defaultConfig(logger hclog.Logger, transportFn func() *http.Transport) *Con
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSConfig is used to generate a TLSClientConfig that's useful for talking to
|
|
||||||
// Consul using TLS.
|
|
||||||
func SetupTLSConfig(tlsConfig *TLSConfig) (*tls.Config, error) {
|
|
||||||
tlsClientConfig := &tls.Config{
|
|
||||||
InsecureSkipVerify: tlsConfig.InsecureSkipVerify,
|
|
||||||
}
|
|
||||||
|
|
||||||
if tlsConfig.Address != "" {
|
|
||||||
server := tlsConfig.Address
|
|
||||||
hasPort := strings.LastIndex(server, ":") > strings.LastIndex(server, "]")
|
|
||||||
if hasPort {
|
|
||||||
var err error
|
|
||||||
server, _, err = net.SplitHostPort(server)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tlsClientConfig.ServerName = server
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tlsConfig.CertPEM) != 0 && len(tlsConfig.KeyPEM) != 0 {
|
|
||||||
tlsCert, err := tls.X509KeyPair(tlsConfig.CertPEM, tlsConfig.KeyPEM)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
|
|
||||||
} else if len(tlsConfig.CertPEM) != 0 || len(tlsConfig.KeyPEM) != 0 {
|
|
||||||
return nil, fmt.Errorf("both client cert and client key must be provided")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tlsConfig.CertFile != "" && tlsConfig.KeyFile != "" {
|
|
||||||
tlsCert, err := tls.LoadX509KeyPair(tlsConfig.CertFile, tlsConfig.KeyFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
|
|
||||||
} else if tlsConfig.CertFile != "" || tlsConfig.KeyFile != "" {
|
|
||||||
return nil, fmt.Errorf("both client cert and client key must be provided")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tlsConfig.CAFile != "" || tlsConfig.CAPath != "" || len(tlsConfig.CAPem) != 0 {
|
|
||||||
rootConfig := &rootcerts.Config{
|
|
||||||
CAFile: tlsConfig.CAFile,
|
|
||||||
CAPath: tlsConfig.CAPath,
|
|
||||||
CACertificate: tlsConfig.CAPem,
|
|
||||||
}
|
|
||||||
if err := rootcerts.ConfigureTLS(tlsClientConfig, rootConfig); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tlsClientConfig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) GenerateEnv() []string {
|
func (c *Config) GenerateEnv() []string {
|
||||||
env := make([]string, 0, 10)
|
env := make([]string, 0, 10)
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/proto-public/pbresource"
|
"github.com/hashicorp/consul/proto-public/pbresource"
|
||||||
|
@ -31,9 +32,34 @@ func NewGRPCClient(config *GRPCConfig) (*GRPCClient, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func dial(c *GRPCConfig) (*grpc.ClientConn, error) {
|
func dial(c *GRPCConfig) (*grpc.ClientConn, error) {
|
||||||
// TODO: decide if we use TLS mode based on the config
|
err := checkCertificates(c)
|
||||||
dialOpts := []grpc.DialOption{
|
if err != nil {
|
||||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
return nil, err
|
||||||
}
|
}
|
||||||
|
var dialOpts []grpc.DialOption
|
||||||
|
if c.GRPCTLS {
|
||||||
|
tlsConfig, err := SetupTLSConfig(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to setup tls config when tried to establish grpc call: %w", err)
|
||||||
|
}
|
||||||
|
dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||||
|
} else {
|
||||||
|
dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
}
|
||||||
|
|
||||||
return grpc.Dial(c.Address, dialOpts...)
|
return grpc.Dial(c.Address, dialOpts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkCertificates(c *GRPCConfig) error {
|
||||||
|
if c.GRPCTLS {
|
||||||
|
certFileEmpty := c.CertFile == ""
|
||||||
|
keyFileEmpty := c.CertFile == ""
|
||||||
|
|
||||||
|
// both files need to be empty or both files need to be provided
|
||||||
|
if certFileEmpty != keyFileEmpty {
|
||||||
|
return fmt.Errorf("you have to provide client certificate file and key file at the same time " +
|
||||||
|
"if you intend to communicate in TLS/SSL mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -9,23 +9,27 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent"
|
"github.com/hashicorp/consul/agent"
|
||||||
"github.com/hashicorp/consul/internal/resource/demo"
|
"github.com/hashicorp/consul/internal/resource/demo"
|
||||||
"github.com/hashicorp/consul/proto-public/pbresource"
|
"github.com/hashicorp/consul/proto-public/pbresource"
|
||||||
"github.com/hashicorp/consul/proto/private/prototest"
|
"github.com/hashicorp/consul/proto/private/prototest"
|
||||||
|
"github.com/hashicorp/consul/sdk/freeport"
|
||||||
"github.com/hashicorp/consul/sdk/testutil"
|
"github.com/hashicorp/consul/sdk/testutil"
|
||||||
"github.com/hashicorp/consul/testrpc"
|
"github.com/hashicorp/consul/testrpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestResourceRead(t *testing.T) {
|
func TestResourceRead(t *testing.T) {
|
||||||
t.Parallel()
|
availablePort := freeport.GetOne(t)
|
||||||
|
a := agent.NewTestAgent(t, fmt.Sprintf("ports { grpc = %d }", availablePort))
|
||||||
a := agent.NewTestAgent(t, "ports { grpc = 8502 }")
|
|
||||||
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
||||||
grpcConfig := GetDefaultGRPCConfig()
|
grpcConfig, err := LoadGRPCConfig(&GRPCConfig{Address: fmt.Sprintf("127.0.0.1:%d", availablePort)})
|
||||||
|
require.NoError(t, err)
|
||||||
gRPCClient, err := NewGRPCClient(grpcConfig)
|
gRPCClient, err := NewGRPCClient(grpcConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
a.Shutdown()
|
a.Shutdown()
|
||||||
|
@ -49,3 +53,95 @@ func TestResourceRead(t *testing.T) {
|
||||||
prototest.AssertDeepEqual(t, writeRsp.Resource, readRsp.Resource)
|
prototest.AssertDeepEqual(t, writeRsp.Resource, readRsp.Resource)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResourceReadInTLS(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
requireClientCert bool
|
||||||
|
grpcConfig func() (*GRPCConfig, int)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Test with CertFile, KeyFile and CAFile",
|
||||||
|
requireClientCert: true,
|
||||||
|
grpcConfig: func() (*GRPCConfig, int) {
|
||||||
|
availablePort := freeport.GetOne(t)
|
||||||
|
return &GRPCConfig{
|
||||||
|
Address: fmt.Sprintf("127.0.0.1:%d", availablePort),
|
||||||
|
GRPCTLS: true,
|
||||||
|
CertFile: "../../../test/client_certs/client.crt",
|
||||||
|
KeyFile: "../../../test/client_certs/client.key",
|
||||||
|
CAFile: "../../../test/client_certs/rootca.crt",
|
||||||
|
}, availablePort
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test without CAFile",
|
||||||
|
requireClientCert: true,
|
||||||
|
grpcConfig: func() (*GRPCConfig, int) {
|
||||||
|
availablePort := freeport.GetOne(t)
|
||||||
|
return &GRPCConfig{
|
||||||
|
Address: fmt.Sprintf("127.0.0.1:%d", availablePort),
|
||||||
|
GRPCTLS: true,
|
||||||
|
GRPCTLSVerify: false,
|
||||||
|
CertFile: "../../../test/client_certs/client.crt",
|
||||||
|
KeyFile: "../../../test/client_certs/client.key",
|
||||||
|
}, availablePort
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test without client certificates",
|
||||||
|
requireClientCert: false,
|
||||||
|
grpcConfig: func() (*GRPCConfig, int) {
|
||||||
|
availablePort := freeport.GetOne(t)
|
||||||
|
return &GRPCConfig{
|
||||||
|
Address: fmt.Sprintf("127.0.0.1:%d", availablePort),
|
||||||
|
GRPCTLS: true,
|
||||||
|
}, availablePort
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
grpcClientConfig, availablePort := tt.grpcConfig()
|
||||||
|
a := agent.StartTestAgent(t, agent.TestAgent{
|
||||||
|
HCL: fmt.Sprintf(`
|
||||||
|
ports { grpc_tls = %d }
|
||||||
|
enable_agent_tls_for_checks = true
|
||||||
|
tls {
|
||||||
|
defaults {
|
||||||
|
verify_incoming = %t
|
||||||
|
key_file = "../../../test/client_certs/server.key"
|
||||||
|
cert_file = "../../../test/client_certs/server.crt"
|
||||||
|
ca_file = "../../../test/client_certs/rootca.crt"
|
||||||
|
}
|
||||||
|
}`, availablePort, tt.requireClientCert),
|
||||||
|
UseGRPCTLS: true,
|
||||||
|
})
|
||||||
|
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
||||||
|
grpcConfig, err := LoadGRPCConfig(grpcClientConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
gRPCClient, err := NewGRPCClient(grpcConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
a.Shutdown()
|
||||||
|
gRPCClient.Conn.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
v2Artist, err := demo.GenerateV2Artist()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = gRPCClient.Client.Read(context.Background(), &pbresource.ReadRequest{Id: v2Artist.Id})
|
||||||
|
require.Equal(t, codes.NotFound.String(), status.Code(err).String())
|
||||||
|
|
||||||
|
writeRsp, err := gRPCClient.Client.Write(testutil.TestContext(t), &pbresource.WriteRequest{Resource: v2Artist})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readRsp, err := gRPCClient.Client.Read(context.Background(), &pbresource.ReadRequest{Id: v2Artist.Id})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, proto.Equal(readRsp.Resource.Id.Type, demo.TypeV2Artist), true)
|
||||||
|
prototest.AssertDeepEqual(t, writeRsp.Resource, readRsp.Resource)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -4,40 +4,132 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// GRPCAddrEnvName defines an environment variable name which sets the gRPC
|
// GRPCAddrEnvName defines an environment variable name which sets the gRPC
|
||||||
// server address for the consul CLI.
|
// server address for the consul CLI.
|
||||||
GRPCAddrEnvName = "CONSUL_GRPC_ADDR"
|
GRPCAddrEnvName = "CONSUL_GRPC_ADDR"
|
||||||
|
|
||||||
|
// GRPCTLSEnvName defines an environment variable name which sets the gRPC
|
||||||
|
// communication mode. Default is false in plaintext mode.
|
||||||
|
GRPCTLSEnvName = "CONSUL_GRPC_TLS"
|
||||||
|
|
||||||
|
// GRPCTLSVerifyEnvName defines an environment variable name which sets
|
||||||
|
// whether to disable certificate checking.
|
||||||
|
GRPCTLSVerifyEnvName = "CONSUL_GRPC_TLS_VERIFY"
|
||||||
|
|
||||||
|
// GRPCClientCertEnvName defines an environment variable name which sets the
|
||||||
|
// client cert file to use for talking to Consul over TLS.
|
||||||
|
GRPCClientCertEnvName = "CONSUL_GRPC_CLIENT_CERT"
|
||||||
|
|
||||||
|
// GRPCClientKeyEnvName defines an environment variable name which sets the
|
||||||
|
// client key file to use for talking to Consul over TLS.
|
||||||
|
GRPCClientKeyEnvName = "CONSUL_GRPC_CLIENT_KEY"
|
||||||
|
|
||||||
|
// GRPCCAFileEnvName defines an environment variable name which sets the
|
||||||
|
// CA file to use for talking to Consul gRPC over TLS.
|
||||||
|
GRPCCAFileEnvName = "CONSUL_GRPC_CACERT"
|
||||||
|
|
||||||
|
// GRPCCAPathEnvName defines an environment variable name which sets the
|
||||||
|
// path to a directory of CA certs to use for talking to Consul gRPC over TLS.
|
||||||
|
GRPCCAPathEnvName = "CONSUL_GRPC_CAPATH"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GRPCConfig struct {
|
type GRPCConfig struct {
|
||||||
|
// Address is the optional address of the Consul server in format of host:port.
|
||||||
|
// It doesn't include schema
|
||||||
Address string
|
Address string
|
||||||
|
|
||||||
|
// GRPCTLS is the optional boolean flag to determine the communication protocol
|
||||||
|
GRPCTLS bool
|
||||||
|
|
||||||
|
// GRPCTLSVerify is the optional boolean flag to disable certificate checking.
|
||||||
|
// Set to false only if you want to skip server verification
|
||||||
|
GRPCTLSVerify bool
|
||||||
|
|
||||||
|
// CertFile is the optional path to the certificate for Consul
|
||||||
|
// communication. If this is set then you need to also set KeyFile.
|
||||||
|
CertFile string
|
||||||
|
|
||||||
|
// KeyFile is the optional path to the private key for Consul communication.
|
||||||
|
// If this is set then you need to also set CertFile.
|
||||||
|
KeyFile string
|
||||||
|
|
||||||
|
// CAFile is the optional path to the CA certificate used for Consul
|
||||||
|
// communication, defaults to the system bundle if not specified.
|
||||||
|
CAFile string
|
||||||
|
|
||||||
|
// CAPath is the optional path to a directory of CA certificates to use for
|
||||||
|
// Consul communication, defaults to the system bundle if not specified.
|
||||||
|
CAPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDefaultGRPCConfig() *GRPCConfig {
|
func GetDefaultGRPCConfig() *GRPCConfig {
|
||||||
return &GRPCConfig{
|
return &GRPCConfig{
|
||||||
Address: "localhost:8502",
|
Address: "127.0.0.1:8502",
|
||||||
|
GRPCTLSVerify: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadGRPCConfig(defaultConfig *GRPCConfig) *GRPCConfig {
|
func LoadGRPCConfig(defaultConfig *GRPCConfig) (*GRPCConfig, error) {
|
||||||
if defaultConfig == nil {
|
if defaultConfig == nil {
|
||||||
defaultConfig = GetDefaultGRPCConfig()
|
defaultConfig = GetDefaultGRPCConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
overwrittenConfig := loadEnvToDefaultConfig(defaultConfig)
|
overwrittenConfig, err := loadEnvToDefaultConfig(defaultConfig)
|
||||||
|
if err != nil {
|
||||||
return overwrittenConfig
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
func loadEnvToDefaultConfig(config *GRPCConfig) *GRPCConfig {
|
|
||||||
|
|
||||||
if addr := os.Getenv(GRPCAddrEnvName); addr != "" {
|
|
||||||
config.Address = addr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return overwrittenConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadEnvToDefaultConfig(config *GRPCConfig) (*GRPCConfig, error) {
|
||||||
|
if addr := os.Getenv(GRPCAddrEnvName); addr != "" {
|
||||||
|
if strings.HasPrefix(strings.ToLower(addr), "https://") {
|
||||||
|
config.GRPCTLS = true
|
||||||
|
}
|
||||||
|
config.Address = removeSchemaFromGRPCAddress(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsMode := os.Getenv(GRPCTLSEnvName); tlsMode != "" {
|
||||||
|
doTLS, err := strconv.ParseBool(tlsMode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse %s: %w", GRPCTLSEnvName, err)
|
||||||
|
}
|
||||||
|
if doTLS {
|
||||||
|
config.GRPCTLS = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := os.Getenv(GRPCTLSVerifyEnvName); v != "" {
|
||||||
|
doVerify, err := strconv.ParseBool(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse %s: %w", GRPCTLSVerifyEnvName, err)
|
||||||
|
}
|
||||||
|
config.GRPCTLSVerify = doVerify
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := os.Getenv(GRPCClientCertEnvName); v != "" {
|
||||||
|
config.CertFile = v
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := os.Getenv(GRPCClientKeyEnvName); v != "" {
|
||||||
|
config.KeyFile = v
|
||||||
|
}
|
||||||
|
|
||||||
|
if caFile := os.Getenv(GRPCCAFileEnvName); caFile != "" {
|
||||||
|
config.CAFile = caFile
|
||||||
|
}
|
||||||
|
|
||||||
|
if caPath := os.Getenv(GRPCCAPathEnvName); caPath != "" {
|
||||||
|
config.CAPath = caPath
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
// Copyright (c) HashiCorp, Inc.
|
||||||
|
// SPDX-License-Identifier: BUSL-1.1
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadGRPCConfig(t *testing.T) {
|
||||||
|
t.Run("Default Config", func(t *testing.T) {
|
||||||
|
// Test when defaultConfig is nil
|
||||||
|
config, err := LoadGRPCConfig(nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, GetDefaultGRPCConfig(), config)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test when environment variables are set
|
||||||
|
t.Run("Env Overwritten", func(t *testing.T) {
|
||||||
|
// Mock environment variables
|
||||||
|
t.Setenv(GRPCAddrEnvName, "localhost:8500")
|
||||||
|
t.Setenv(GRPCTLSEnvName, "true")
|
||||||
|
t.Setenv(GRPCTLSVerifyEnvName, "false")
|
||||||
|
t.Setenv(GRPCClientCertEnvName, "/path/to/client.crt")
|
||||||
|
t.Setenv(GRPCClientKeyEnvName, "/path/to/client.key")
|
||||||
|
t.Setenv(GRPCCAFileEnvName, "/path/to/ca.crt")
|
||||||
|
t.Setenv(GRPCCAPathEnvName, "/path/to/cacerts")
|
||||||
|
|
||||||
|
// Load and validate the configuration
|
||||||
|
config, err := LoadGRPCConfig(nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
expectedConfig := &GRPCConfig{
|
||||||
|
Address: "localhost:8500",
|
||||||
|
GRPCTLS: true,
|
||||||
|
GRPCTLSVerify: false,
|
||||||
|
CertFile: "/path/to/client.crt",
|
||||||
|
KeyFile: "/path/to/client.key",
|
||||||
|
CAFile: "/path/to/ca.crt",
|
||||||
|
CAPath: "/path/to/cacerts",
|
||||||
|
}
|
||||||
|
assert.Equal(t, expectedConfig, config)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test when there's an error parsing a boolean value from an environment variable
|
||||||
|
t.Run("Error Parsing Bool", func(t *testing.T) {
|
||||||
|
// Mock environment variable with an invalid boolean value
|
||||||
|
t.Setenv(GRPCTLSEnvName, "invalid_boolean_value")
|
||||||
|
|
||||||
|
// Load and expect an error
|
||||||
|
config, err := LoadGRPCConfig(nil)
|
||||||
|
assert.Error(t, err, "failed to parse CONSUL_GRPC_TLS: strconv.ParseBool: parsing \"invalid_boolean_value\": invalid syntax")
|
||||||
|
assert.Nil(t, config)
|
||||||
|
})
|
||||||
|
}
|
|
@ -5,54 +5,60 @@ package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GRPCFlags struct {
|
type GRPCFlags struct {
|
||||||
address StringValue
|
address TValue[string]
|
||||||
|
grpcTLS TValue[bool]
|
||||||
|
certFile TValue[string]
|
||||||
|
keyFile TValue[string]
|
||||||
|
caFile TValue[string]
|
||||||
|
caPath TValue[string]
|
||||||
}
|
}
|
||||||
|
|
||||||
// mergeFlagsIntoGRPCConfig merges flag values into grpc config
|
// MergeFlagsIntoGRPCConfig merges flag values into grpc config
|
||||||
// caller has to parse the CLI args before loading them into flag values
|
// caller has to parse the CLI args before loading them into flag values
|
||||||
func (f *GRPCFlags) mergeFlagsIntoGRPCConfig(c *GRPCConfig) {
|
// The flags take precedence over the environment values
|
||||||
|
func (f *GRPCFlags) MergeFlagsIntoGRPCConfig(c *GRPCConfig) {
|
||||||
|
if strings.HasPrefix(strings.ToLower(f.address.String()), "https://") {
|
||||||
|
c.GRPCTLS = true
|
||||||
|
}
|
||||||
|
f.address.Set(removeSchemaFromGRPCAddress(f.address.String()))
|
||||||
f.address.Merge(&c.Address)
|
f.address.Merge(&c.Address)
|
||||||
|
// won't overwrite the value if it's false
|
||||||
|
if f.grpcTLS.v != nil && *f.grpcTLS.v {
|
||||||
|
f.grpcTLS.Merge(&c.GRPCTLS)
|
||||||
|
}
|
||||||
|
f.certFile.Merge(&c.CertFile)
|
||||||
|
f.keyFile.Merge(&c.KeyFile)
|
||||||
|
f.caFile.Merge(&c.CAFile)
|
||||||
|
f.caPath.Merge(&c.CAPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// merge the client flags into command line flags then parse command line flags
|
||||||
func (f *GRPCFlags) ClientFlags() *flag.FlagSet {
|
func (f *GRPCFlags) ClientFlags() *flag.FlagSet {
|
||||||
fs := flag.NewFlagSet("", flag.ContinueOnError)
|
fs := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
fs.Var(&f.address, "grpc-addr",
|
fs.Var(&f.address, "grpc-addr",
|
||||||
"The `address` and port of the Consul GRPC agent. The value can be an IP "+
|
"The `address` and `port` of the Consul GRPC agent. The value can be an IP "+
|
||||||
"address or DNS address, but it must also include the port. This can "+
|
"address or DNS address, but it must also include the port. This can also be specified "+
|
||||||
"also be specified via the CONSUL_GRPC_ADDR environment variable. The "+
|
"via the CONSUL_GRPC_ADDR environment variable. The default value is "+
|
||||||
"default value is 127.0.0.1:8502. It supports TLS communication "+
|
"127.0.0.1:8502. If you intend to communicate in TLS mode, you have to either "+
|
||||||
"by setting the environment variable CONSUL_GRPC_TLS=true.")
|
"include https:// schema in the address, use grpc-tls flag or set environment variable "+
|
||||||
|
"CONSUL_GRPC_TLS = true, otherwise it uses plaintext mode")
|
||||||
|
fs.Var(&f.caFile, "grpc-tls",
|
||||||
|
"Set to true if you aim to communicate in TLS mode in the GRPC call.")
|
||||||
|
fs.Var(&f.certFile, "client-cert",
|
||||||
|
"Path to a client cert file to use for TLS when 'verify_incoming' is enabled. This "+
|
||||||
|
"can also be specified via the CONSUL_GRPC_CLIENT_CERT environment variable.")
|
||||||
|
fs.Var(&f.keyFile, "client-key",
|
||||||
|
"Path to a client key file to use for TLS when 'verify_incoming' is enabled. This "+
|
||||||
|
"can also be specified via the CONSUL_GRPC_CLIENT_KEY environment variable.")
|
||||||
|
fs.Var(&f.caFile, "ca-file",
|
||||||
|
"Path to a CA file to use for TLS when communicating with Consul. This "+
|
||||||
|
"can also be specified via the CONSUL_CACERT environment variable.")
|
||||||
|
fs.Var(&f.caPath, "ca-path",
|
||||||
|
"Path to a directory of CA certificates to use for TLS when communicating "+
|
||||||
|
"with Consul. This can also be specified via the CONSUL_CAPATH environment variable.")
|
||||||
return fs
|
return fs
|
||||||
}
|
}
|
||||||
|
|
||||||
type StringValue struct {
|
|
||||||
v *string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set implements the flag.Value interface.
|
|
||||||
func (s *StringValue) Set(v string) error {
|
|
||||||
if s.v == nil {
|
|
||||||
s.v = new(string)
|
|
||||||
}
|
|
||||||
*(s.v) = v
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// String implements the flag.Value interface.
|
|
||||||
func (s *StringValue) String() string {
|
|
||||||
var current string
|
|
||||||
if s.v != nil {
|
|
||||||
current = *(s.v)
|
|
||||||
}
|
|
||||||
return current
|
|
||||||
}
|
|
||||||
|
|
||||||
// Merge will overlay this value if it has been set.
|
|
||||||
func (s *StringValue) Merge(onto *string) {
|
|
||||||
if s.v != nil {
|
|
||||||
*onto = *(s.v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
// Copyright (c) HashiCorp, Inc.
|
||||||
|
// SPDX-License-Identifier: BUSL-1.1
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeFlagsIntoGRPCConfig(t *testing.T) {
|
||||||
|
t.Run("MergeFlagsIntoGRPCConfig", func(t *testing.T) {
|
||||||
|
// Setup GRPCFlags with some flag values
|
||||||
|
flags := &GRPCFlags{
|
||||||
|
address: TValue[string]{v: stringPointer("https://example.com:8502")},
|
||||||
|
grpcTLS: TValue[bool]{v: boolPointer(true)},
|
||||||
|
certFile: TValue[string]{v: stringPointer("/path/to/client.crt")},
|
||||||
|
keyFile: TValue[string]{v: stringPointer("/path/to/client.key")},
|
||||||
|
caFile: TValue[string]{v: stringPointer("/path/to/ca.crt")},
|
||||||
|
caPath: TValue[string]{v: stringPointer("/path/to/cacerts")},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup GRPCConfig with some initial values
|
||||||
|
config := &GRPCConfig{
|
||||||
|
Address: "localhost:8500",
|
||||||
|
GRPCTLS: false,
|
||||||
|
GRPCTLSVerify: true,
|
||||||
|
CertFile: "/path/to/default/client.crt",
|
||||||
|
KeyFile: "/path/to/default/client.key",
|
||||||
|
CAFile: "/path/to/default/ca.crt",
|
||||||
|
CAPath: "/path/to/default/cacerts",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call MergeFlagsIntoGRPCConfig to merge flag values into the config
|
||||||
|
flags.MergeFlagsIntoGRPCConfig(config)
|
||||||
|
|
||||||
|
// Validate the merged config
|
||||||
|
expectedConfig := &GRPCConfig{
|
||||||
|
Address: "example.com:8502",
|
||||||
|
GRPCTLS: true,
|
||||||
|
GRPCTLSVerify: true,
|
||||||
|
CertFile: "/path/to/client.crt",
|
||||||
|
KeyFile: "/path/to/client.key",
|
||||||
|
CAFile: "/path/to/ca.crt",
|
||||||
|
CAPath: "/path/to/cacerts",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expectedConfig, config)
|
||||||
|
})
|
||||||
|
t.Run("MergeFlagsIntoGRPCConfig: allow empty values", func(t *testing.T) {
|
||||||
|
// Setup GRPCFlags with some flag values
|
||||||
|
flags := &GRPCFlags{
|
||||||
|
address: TValue[string]{v: stringPointer("http://example.com:8502")},
|
||||||
|
grpcTLS: TValue[bool]{},
|
||||||
|
certFile: TValue[string]{v: stringPointer("/path/to/client.crt")},
|
||||||
|
keyFile: TValue[string]{v: stringPointer("/path/to/client.key")},
|
||||||
|
caFile: TValue[string]{v: stringPointer("/path/to/ca.crt")},
|
||||||
|
caPath: TValue[string]{v: stringPointer("/path/to/cacerts")},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup GRPCConfig with some initial values
|
||||||
|
config := &GRPCConfig{
|
||||||
|
Address: "localhost:8500",
|
||||||
|
GRPCTLSVerify: true,
|
||||||
|
CertFile: "/path/to/default/client.crt",
|
||||||
|
KeyFile: "/path/to/default/client.key",
|
||||||
|
CAFile: "/path/to/default/ca.crt",
|
||||||
|
CAPath: "/path/to/default/cacerts",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call MergeFlagsIntoGRPCConfig to merge flag values into the config
|
||||||
|
flags.MergeFlagsIntoGRPCConfig(config)
|
||||||
|
|
||||||
|
// Validate the merged config
|
||||||
|
expectedConfig := &GRPCConfig{
|
||||||
|
Address: "example.com:8502",
|
||||||
|
GRPCTLS: false,
|
||||||
|
GRPCTLSVerify: true,
|
||||||
|
CertFile: "/path/to/client.crt",
|
||||||
|
KeyFile: "/path/to/client.key",
|
||||||
|
CAFile: "/path/to/ca.crt",
|
||||||
|
CAPath: "/path/to/cacerts",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expectedConfig, config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Utility function to convert string to string pointer
|
||||||
|
func stringPointer(s string) *string {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Utility function to convert bool to bool pointer
|
||||||
|
func boolPointer(b bool) *bool {
|
||||||
|
return &b
|
||||||
|
}
|
|
@ -0,0 +1,93 @@
|
||||||
|
// Copyright (c) HashiCorp, Inc.
|
||||||
|
// SPDX-License-Identifier: BUSL-1.1
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-rootcerts"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tls.Config is used to establish communication in TLS mode
|
||||||
|
func SetupTLSConfig(c *GRPCConfig) (*tls.Config, error) {
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: !c.GRPCTLSVerify,
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.CertFile != "" && c.KeyFile != "" {
|
||||||
|
tlsCert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{tlsCert}
|
||||||
|
}
|
||||||
|
|
||||||
|
var caConfig *rootcerts.Config
|
||||||
|
if c.CAFile != "" || c.CAPath != "" {
|
||||||
|
caConfig = &rootcerts.Config{
|
||||||
|
CAFile: c.CAFile,
|
||||||
|
CAPath: c.CAPath,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// load system CA certs if user doesn't provide any
|
||||||
|
if err := rootcerts.ConfigureTLS(tlsConfig, caConfig); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeSchemaFromGRPCAddress(addr string) string {
|
||||||
|
// Parse as host:port with option http prefix
|
||||||
|
grpcAddr := strings.TrimPrefix(addr, "http://")
|
||||||
|
grpcAddr = strings.TrimPrefix(grpcAddr, "https://")
|
||||||
|
return grpcAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
type TValue[T string | bool] struct {
|
||||||
|
v *T
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set implements the flag.Value interface.
|
||||||
|
func (t *TValue[T]) Set(v string) error {
|
||||||
|
if t.v == nil {
|
||||||
|
t.v = new(T)
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
// have to use interface{}(t.v) to do type assertion
|
||||||
|
switch interface{}(t.v).(type) {
|
||||||
|
case *string:
|
||||||
|
// have to use interface{}(t.v).(*string) to assert t.v as *string
|
||||||
|
*(interface{}(t.v).(*string)) = v
|
||||||
|
case *bool:
|
||||||
|
// have to use interface{}(t.v).(*bool) to assert t.v as *bool
|
||||||
|
*(interface{}(t.v).(*bool)), err = strconv.ParseBool(v)
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unsupported type %T", t.v)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements the flag.Value interface.
|
||||||
|
func (t *TValue[T]) String() string {
|
||||||
|
var current T
|
||||||
|
if t.v != nil {
|
||||||
|
current = *(t.v)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v", current)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge will overlay this value if it has been set.
|
||||||
|
func (t *TValue[T]) Merge(onto *T) error {
|
||||||
|
if onto == nil {
|
||||||
|
return fmt.Errorf("onto is nil")
|
||||||
|
}
|
||||||
|
if t.v != nil {
|
||||||
|
*onto = *(t.v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
// Copyright (c) HashiCorp, Inc.
|
||||||
|
// SPDX-License-Identifier: BUSL-1.1
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTValue(t *testing.T) {
|
||||||
|
t.Run("String: set", func(t *testing.T) {
|
||||||
|
var tv TValue[string]
|
||||||
|
|
||||||
|
err := tv.Set("testString")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, *tv.v, "testString")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("String: merge", func(t *testing.T) {
|
||||||
|
var tv TValue[string]
|
||||||
|
var onto string
|
||||||
|
|
||||||
|
testStr := "testString"
|
||||||
|
tv.v = &testStr
|
||||||
|
tv.Merge(&onto)
|
||||||
|
|
||||||
|
assert.Equal(t, onto, "testString")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("String: merge nil", func(t *testing.T) {
|
||||||
|
var tv TValue[string]
|
||||||
|
var onto *string = nil
|
||||||
|
|
||||||
|
testStr := "testString"
|
||||||
|
tv.v = &testStr
|
||||||
|
err := tv.Merge(onto)
|
||||||
|
|
||||||
|
assert.Equal(t, err.Error(), "onto is nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Get string", func(t *testing.T) {
|
||||||
|
var tv TValue[string]
|
||||||
|
testStr := "testString"
|
||||||
|
tv.v = &testStr
|
||||||
|
assert.Equal(t, tv.String(), "testString")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Bool: set", func(t *testing.T) {
|
||||||
|
var tv TValue[bool]
|
||||||
|
|
||||||
|
err := tv.Set("true")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, *tv.v, true)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Bool: merge", func(t *testing.T) {
|
||||||
|
var tv TValue[bool]
|
||||||
|
var onto bool
|
||||||
|
|
||||||
|
testBool := true
|
||||||
|
tv.v = &testBool
|
||||||
|
tv.Merge(&onto)
|
||||||
|
|
||||||
|
assert.Equal(t, onto, true)
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in New Issue