grpc client in tls mode (#19680)

* client in tls mode
This commit is contained in:
wangxinyi7 2023-12-19 10:04:55 -08:00 committed by GitHub
parent cff872749d
commit 013bcefe5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 593 additions and 120 deletions

View File

@ -6,7 +6,6 @@ package client
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -21,7 +20,6 @@ import (
"github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-rootcerts"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
) )
@ -73,14 +71,6 @@ const (
// whether or not to disable certificate checking. // whether or not to disable certificate checking.
HTTPSSLVerifyEnvName = "CONSUL_HTTP_SSL_VERIFY" HTTPSSLVerifyEnvName = "CONSUL_HTTP_SSL_VERIFY"
// GRPCCAFileEnvName defines an environment variable name which sets the
// CA file to use for talking to Consul gRPC over TLS.
GRPCCAFileEnvName = "CONSUL_GRPC_CACERT"
// GRPCCAPathEnvName defines an environment variable name which sets the
// path to a directory of CA certs to use for talking to Consul gRPC over TLS.
GRPCCAPathEnvName = "CONSUL_GRPC_CAPATH"
// HTTPNamespaceEnvVar defines an environment variable name which sets // HTTPNamespaceEnvVar defines an environment variable name which sets
// the HTTP Namespace to be used by default. This can still be overridden. // the HTTP Namespace to be used by default. This can still be overridden.
HTTPNamespaceEnvName = "CONSUL_NAMESPACE" HTTPNamespaceEnvName = "CONSUL_NAMESPACE"
@ -538,60 +528,6 @@ func defaultConfig(logger hclog.Logger, transportFn func() *http.Transport) *Con
return config return config
} }
// TLSConfig is used to generate a TLSClientConfig that's useful for talking to
// Consul using TLS.
func SetupTLSConfig(tlsConfig *TLSConfig) (*tls.Config, error) {
tlsClientConfig := &tls.Config{
InsecureSkipVerify: tlsConfig.InsecureSkipVerify,
}
if tlsConfig.Address != "" {
server := tlsConfig.Address
hasPort := strings.LastIndex(server, ":") > strings.LastIndex(server, "]")
if hasPort {
var err error
server, _, err = net.SplitHostPort(server)
if err != nil {
return nil, err
}
}
tlsClientConfig.ServerName = server
}
if len(tlsConfig.CertPEM) != 0 && len(tlsConfig.KeyPEM) != 0 {
tlsCert, err := tls.X509KeyPair(tlsConfig.CertPEM, tlsConfig.KeyPEM)
if err != nil {
return nil, err
}
tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
} else if len(tlsConfig.CertPEM) != 0 || len(tlsConfig.KeyPEM) != 0 {
return nil, fmt.Errorf("both client cert and client key must be provided")
}
if tlsConfig.CertFile != "" && tlsConfig.KeyFile != "" {
tlsCert, err := tls.LoadX509KeyPair(tlsConfig.CertFile, tlsConfig.KeyFile)
if err != nil {
return nil, err
}
tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
} else if tlsConfig.CertFile != "" || tlsConfig.KeyFile != "" {
return nil, fmt.Errorf("both client cert and client key must be provided")
}
if tlsConfig.CAFile != "" || tlsConfig.CAPath != "" || len(tlsConfig.CAPem) != 0 {
rootConfig := &rootcerts.Config{
CAFile: tlsConfig.CAFile,
CAPath: tlsConfig.CAPath,
CACertificate: tlsConfig.CAPem,
}
if err := rootcerts.ConfigureTLS(tlsClientConfig, rootConfig); err != nil {
return nil, err
}
}
return tlsClientConfig, nil
}
func (c *Config) GenerateEnv() []string { func (c *Config) GenerateEnv() []string {
env := make([]string, 0, 10) env := make([]string, 0, 10)

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
@ -31,9 +32,34 @@ func NewGRPCClient(config *GRPCConfig) (*GRPCClient, error) {
} }
func dial(c *GRPCConfig) (*grpc.ClientConn, error) { func dial(c *GRPCConfig) (*grpc.ClientConn, error) {
// TODO: decide if we use TLS mode based on the config err := checkCertificates(c)
dialOpts := []grpc.DialOption{ if err != nil {
grpc.WithTransportCredentials(insecure.NewCredentials()), return nil, err
} }
var dialOpts []grpc.DialOption
if c.GRPCTLS {
tlsConfig, err := SetupTLSConfig(c)
if err != nil {
return nil, fmt.Errorf("failed to setup tls config when tried to establish grpc call: %w", err)
}
dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
} else {
dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
return grpc.Dial(c.Address, dialOpts...) return grpc.Dial(c.Address, dialOpts...)
} }
func checkCertificates(c *GRPCConfig) error {
if c.GRPCTLS {
certFileEmpty := c.CertFile == ""
keyFileEmpty := c.CertFile == ""
// both files need to be empty or both files need to be provided
if certFileEmpty != keyFileEmpty {
return fmt.Errorf("you have to provide client certificate file and key file at the same time " +
"if you intend to communicate in TLS/SSL mode")
}
}
return nil
}

View File

@ -9,23 +9,27 @@ import (
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/hashicorp/consul/agent" "github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/internal/resource/demo" "github.com/hashicorp/consul/internal/resource/demo"
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
"github.com/hashicorp/consul/proto/private/prototest" "github.com/hashicorp/consul/proto/private/prototest"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
) )
func TestResourceRead(t *testing.T) { func TestResourceRead(t *testing.T) {
t.Parallel() availablePort := freeport.GetOne(t)
a := agent.NewTestAgent(t, fmt.Sprintf("ports { grpc = %d }", availablePort))
a := agent.NewTestAgent(t, "ports { grpc = 8502 }")
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
grpcConfig := GetDefaultGRPCConfig() grpcConfig, err := LoadGRPCConfig(&GRPCConfig{Address: fmt.Sprintf("127.0.0.1:%d", availablePort)})
require.NoError(t, err)
gRPCClient, err := NewGRPCClient(grpcConfig) gRPCClient, err := NewGRPCClient(grpcConfig)
require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
a.Shutdown() a.Shutdown()
@ -49,3 +53,95 @@ func TestResourceRead(t *testing.T) {
prototest.AssertDeepEqual(t, writeRsp.Resource, readRsp.Resource) prototest.AssertDeepEqual(t, writeRsp.Resource, readRsp.Resource)
}) })
} }
func TestResourceReadInTLS(t *testing.T) {
tests := []struct {
name string
requireClientCert bool
grpcConfig func() (*GRPCConfig, int)
}{
{
name: "Test with CertFile, KeyFile and CAFile",
requireClientCert: true,
grpcConfig: func() (*GRPCConfig, int) {
availablePort := freeport.GetOne(t)
return &GRPCConfig{
Address: fmt.Sprintf("127.0.0.1:%d", availablePort),
GRPCTLS: true,
CertFile: "../../../test/client_certs/client.crt",
KeyFile: "../../../test/client_certs/client.key",
CAFile: "../../../test/client_certs/rootca.crt",
}, availablePort
},
},
{
name: "Test without CAFile",
requireClientCert: true,
grpcConfig: func() (*GRPCConfig, int) {
availablePort := freeport.GetOne(t)
return &GRPCConfig{
Address: fmt.Sprintf("127.0.0.1:%d", availablePort),
GRPCTLS: true,
GRPCTLSVerify: false,
CertFile: "../../../test/client_certs/client.crt",
KeyFile: "../../../test/client_certs/client.key",
}, availablePort
},
},
{
name: "Test without client certificates",
requireClientCert: false,
grpcConfig: func() (*GRPCConfig, int) {
availablePort := freeport.GetOne(t)
return &GRPCConfig{
Address: fmt.Sprintf("127.0.0.1:%d", availablePort),
GRPCTLS: true,
}, availablePort
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
grpcClientConfig, availablePort := tt.grpcConfig()
a := agent.StartTestAgent(t, agent.TestAgent{
HCL: fmt.Sprintf(`
ports { grpc_tls = %d }
enable_agent_tls_for_checks = true
tls {
defaults {
verify_incoming = %t
key_file = "../../../test/client_certs/server.key"
cert_file = "../../../test/client_certs/server.crt"
ca_file = "../../../test/client_certs/rootca.crt"
}
}`, availablePort, tt.requireClientCert),
UseGRPCTLS: true,
})
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
grpcConfig, err := LoadGRPCConfig(grpcClientConfig)
require.NoError(t, err)
gRPCClient, err := NewGRPCClient(grpcConfig)
require.NoError(t, err)
t.Cleanup(func() {
a.Shutdown()
gRPCClient.Conn.Close()
})
v2Artist, err := demo.GenerateV2Artist()
require.NoError(t, err)
_, err = gRPCClient.Client.Read(context.Background(), &pbresource.ReadRequest{Id: v2Artist.Id})
require.Equal(t, codes.NotFound.String(), status.Code(err).String())
writeRsp, err := gRPCClient.Client.Write(testutil.TestContext(t), &pbresource.WriteRequest{Resource: v2Artist})
require.NoError(t, err)
readRsp, err := gRPCClient.Client.Read(context.Background(), &pbresource.ReadRequest{Id: v2Artist.Id})
require.NoError(t, err)
require.Equal(t, proto.Equal(readRsp.Resource.Id.Type, demo.TypeV2Artist), true)
prototest.AssertDeepEqual(t, writeRsp.Resource, readRsp.Resource)
})
}
}

View File

@ -4,40 +4,132 @@
package client package client
import ( import (
"fmt"
"os" "os"
"strconv"
"strings"
) )
const ( const (
// GRPCAddrEnvName defines an environment variable name which sets the gRPC // GRPCAddrEnvName defines an environment variable name which sets the gRPC
// server address for the consul CLI. // server address for the consul CLI.
GRPCAddrEnvName = "CONSUL_GRPC_ADDR" GRPCAddrEnvName = "CONSUL_GRPC_ADDR"
// GRPCTLSEnvName defines an environment variable name which sets the gRPC
// communication mode. Default is false in plaintext mode.
GRPCTLSEnvName = "CONSUL_GRPC_TLS"
// GRPCTLSVerifyEnvName defines an environment variable name which sets
// whether to disable certificate checking.
GRPCTLSVerifyEnvName = "CONSUL_GRPC_TLS_VERIFY"
// GRPCClientCertEnvName defines an environment variable name which sets the
// client cert file to use for talking to Consul over TLS.
GRPCClientCertEnvName = "CONSUL_GRPC_CLIENT_CERT"
// GRPCClientKeyEnvName defines an environment variable name which sets the
// client key file to use for talking to Consul over TLS.
GRPCClientKeyEnvName = "CONSUL_GRPC_CLIENT_KEY"
// GRPCCAFileEnvName defines an environment variable name which sets the
// CA file to use for talking to Consul gRPC over TLS.
GRPCCAFileEnvName = "CONSUL_GRPC_CACERT"
// GRPCCAPathEnvName defines an environment variable name which sets the
// path to a directory of CA certs to use for talking to Consul gRPC over TLS.
GRPCCAPathEnvName = "CONSUL_GRPC_CAPATH"
) )
type GRPCConfig struct { type GRPCConfig struct {
// Address is the optional address of the Consul server in format of host:port.
// It doesn't include schema
Address string Address string
// GRPCTLS is the optional boolean flag to determine the communication protocol
GRPCTLS bool
// GRPCTLSVerify is the optional boolean flag to disable certificate checking.
// Set to false only if you want to skip server verification
GRPCTLSVerify bool
// CertFile is the optional path to the certificate for Consul
// communication. If this is set then you need to also set KeyFile.
CertFile string
// KeyFile is the optional path to the private key for Consul communication.
// If this is set then you need to also set CertFile.
KeyFile string
// CAFile is the optional path to the CA certificate used for Consul
// communication, defaults to the system bundle if not specified.
CAFile string
// CAPath is the optional path to a directory of CA certificates to use for
// Consul communication, defaults to the system bundle if not specified.
CAPath string
} }
func GetDefaultGRPCConfig() *GRPCConfig { func GetDefaultGRPCConfig() *GRPCConfig {
return &GRPCConfig{ return &GRPCConfig{
Address: "localhost:8502", Address: "127.0.0.1:8502",
GRPCTLSVerify: false,
} }
} }
func LoadGRPCConfig(defaultConfig *GRPCConfig) *GRPCConfig { func LoadGRPCConfig(defaultConfig *GRPCConfig) (*GRPCConfig, error) {
if defaultConfig == nil { if defaultConfig == nil {
defaultConfig = GetDefaultGRPCConfig() defaultConfig = GetDefaultGRPCConfig()
} }
overwrittenConfig := loadEnvToDefaultConfig(defaultConfig) overwrittenConfig, err := loadEnvToDefaultConfig(defaultConfig)
if err != nil {
return overwrittenConfig return nil, err
}
func loadEnvToDefaultConfig(config *GRPCConfig) *GRPCConfig {
if addr := os.Getenv(GRPCAddrEnvName); addr != "" {
config.Address = addr
} }
return config return overwrittenConfig, nil
}
func loadEnvToDefaultConfig(config *GRPCConfig) (*GRPCConfig, error) {
if addr := os.Getenv(GRPCAddrEnvName); addr != "" {
if strings.HasPrefix(strings.ToLower(addr), "https://") {
config.GRPCTLS = true
}
config.Address = removeSchemaFromGRPCAddress(addr)
}
if tlsMode := os.Getenv(GRPCTLSEnvName); tlsMode != "" {
doTLS, err := strconv.ParseBool(tlsMode)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", GRPCTLSEnvName, err)
}
if doTLS {
config.GRPCTLS = true
}
}
if v := os.Getenv(GRPCTLSVerifyEnvName); v != "" {
doVerify, err := strconv.ParseBool(v)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", GRPCTLSVerifyEnvName, err)
}
config.GRPCTLSVerify = doVerify
}
if v := os.Getenv(GRPCClientCertEnvName); v != "" {
config.CertFile = v
}
if v := os.Getenv(GRPCClientKeyEnvName); v != "" {
config.KeyFile = v
}
if caFile := os.Getenv(GRPCCAFileEnvName); caFile != "" {
config.CAFile = caFile
}
if caPath := os.Getenv(GRPCCAPathEnvName); caPath != "" {
config.CAPath = caPath
}
return config, nil
} }

View File

@ -0,0 +1,56 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package client
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestLoadGRPCConfig(t *testing.T) {
t.Run("Default Config", func(t *testing.T) {
// Test when defaultConfig is nil
config, err := LoadGRPCConfig(nil)
assert.NoError(t, err)
assert.Equal(t, GetDefaultGRPCConfig(), config)
})
// Test when environment variables are set
t.Run("Env Overwritten", func(t *testing.T) {
// Mock environment variables
t.Setenv(GRPCAddrEnvName, "localhost:8500")
t.Setenv(GRPCTLSEnvName, "true")
t.Setenv(GRPCTLSVerifyEnvName, "false")
t.Setenv(GRPCClientCertEnvName, "/path/to/client.crt")
t.Setenv(GRPCClientKeyEnvName, "/path/to/client.key")
t.Setenv(GRPCCAFileEnvName, "/path/to/ca.crt")
t.Setenv(GRPCCAPathEnvName, "/path/to/cacerts")
// Load and validate the configuration
config, err := LoadGRPCConfig(nil)
assert.NoError(t, err)
expectedConfig := &GRPCConfig{
Address: "localhost:8500",
GRPCTLS: true,
GRPCTLSVerify: false,
CertFile: "/path/to/client.crt",
KeyFile: "/path/to/client.key",
CAFile: "/path/to/ca.crt",
CAPath: "/path/to/cacerts",
}
assert.Equal(t, expectedConfig, config)
})
// Test when there's an error parsing a boolean value from an environment variable
t.Run("Error Parsing Bool", func(t *testing.T) {
// Mock environment variable with an invalid boolean value
t.Setenv(GRPCTLSEnvName, "invalid_boolean_value")
// Load and expect an error
config, err := LoadGRPCConfig(nil)
assert.Error(t, err, "failed to parse CONSUL_GRPC_TLS: strconv.ParseBool: parsing \"invalid_boolean_value\": invalid syntax")
assert.Nil(t, config)
})
}

View File

@ -5,54 +5,60 @@ package client
import ( import (
"flag" "flag"
"strings"
) )
type GRPCFlags struct { type GRPCFlags struct {
address StringValue address TValue[string]
grpcTLS TValue[bool]
certFile TValue[string]
keyFile TValue[string]
caFile TValue[string]
caPath TValue[string]
} }
// mergeFlagsIntoGRPCConfig merges flag values into grpc config // MergeFlagsIntoGRPCConfig merges flag values into grpc config
// caller has to parse the CLI args before loading them into flag values // caller has to parse the CLI args before loading them into flag values
func (f *GRPCFlags) mergeFlagsIntoGRPCConfig(c *GRPCConfig) { // The flags take precedence over the environment values
func (f *GRPCFlags) MergeFlagsIntoGRPCConfig(c *GRPCConfig) {
if strings.HasPrefix(strings.ToLower(f.address.String()), "https://") {
c.GRPCTLS = true
}
f.address.Set(removeSchemaFromGRPCAddress(f.address.String()))
f.address.Merge(&c.Address) f.address.Merge(&c.Address)
// won't overwrite the value if it's false
if f.grpcTLS.v != nil && *f.grpcTLS.v {
f.grpcTLS.Merge(&c.GRPCTLS)
}
f.certFile.Merge(&c.CertFile)
f.keyFile.Merge(&c.KeyFile)
f.caFile.Merge(&c.CAFile)
f.caPath.Merge(&c.CAPath)
} }
// merge the client flags into command line flags then parse command line flags
func (f *GRPCFlags) ClientFlags() *flag.FlagSet { func (f *GRPCFlags) ClientFlags() *flag.FlagSet {
fs := flag.NewFlagSet("", flag.ContinueOnError) fs := flag.NewFlagSet("", flag.ContinueOnError)
fs.Var(&f.address, "grpc-addr", fs.Var(&f.address, "grpc-addr",
"The `address` and port of the Consul GRPC agent. The value can be an IP "+ "The `address` and `port` of the Consul GRPC agent. The value can be an IP "+
"address or DNS address, but it must also include the port. This can "+ "address or DNS address, but it must also include the port. This can also be specified "+
"also be specified via the CONSUL_GRPC_ADDR environment variable. The "+ "via the CONSUL_GRPC_ADDR environment variable. The default value is "+
"default value is 127.0.0.1:8502. It supports TLS communication "+ "127.0.0.1:8502. If you intend to communicate in TLS mode, you have to either "+
"by setting the environment variable CONSUL_GRPC_TLS=true.") "include https:// schema in the address, use grpc-tls flag or set environment variable "+
"CONSUL_GRPC_TLS = true, otherwise it uses plaintext mode")
fs.Var(&f.caFile, "grpc-tls",
"Set to true if you aim to communicate in TLS mode in the GRPC call.")
fs.Var(&f.certFile, "client-cert",
"Path to a client cert file to use for TLS when 'verify_incoming' is enabled. This "+
"can also be specified via the CONSUL_GRPC_CLIENT_CERT environment variable.")
fs.Var(&f.keyFile, "client-key",
"Path to a client key file to use for TLS when 'verify_incoming' is enabled. This "+
"can also be specified via the CONSUL_GRPC_CLIENT_KEY environment variable.")
fs.Var(&f.caFile, "ca-file",
"Path to a CA file to use for TLS when communicating with Consul. This "+
"can also be specified via the CONSUL_CACERT environment variable.")
fs.Var(&f.caPath, "ca-path",
"Path to a directory of CA certificates to use for TLS when communicating "+
"with Consul. This can also be specified via the CONSUL_CAPATH environment variable.")
return fs return fs
} }
type StringValue struct {
v *string
}
// Set implements the flag.Value interface.
func (s *StringValue) Set(v string) error {
if s.v == nil {
s.v = new(string)
}
*(s.v) = v
return nil
}
// String implements the flag.Value interface.
func (s *StringValue) String() string {
var current string
if s.v != nil {
current = *(s.v)
}
return current
}
// Merge will overlay this value if it has been set.
func (s *StringValue) Merge(onto *string) {
if s.v != nil {
*onto = *(s.v)
}
}

View File

@ -0,0 +1,98 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package client
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMergeFlagsIntoGRPCConfig(t *testing.T) {
t.Run("MergeFlagsIntoGRPCConfig", func(t *testing.T) {
// Setup GRPCFlags with some flag values
flags := &GRPCFlags{
address: TValue[string]{v: stringPointer("https://example.com:8502")},
grpcTLS: TValue[bool]{v: boolPointer(true)},
certFile: TValue[string]{v: stringPointer("/path/to/client.crt")},
keyFile: TValue[string]{v: stringPointer("/path/to/client.key")},
caFile: TValue[string]{v: stringPointer("/path/to/ca.crt")},
caPath: TValue[string]{v: stringPointer("/path/to/cacerts")},
}
// Setup GRPCConfig with some initial values
config := &GRPCConfig{
Address: "localhost:8500",
GRPCTLS: false,
GRPCTLSVerify: true,
CertFile: "/path/to/default/client.crt",
KeyFile: "/path/to/default/client.key",
CAFile: "/path/to/default/ca.crt",
CAPath: "/path/to/default/cacerts",
}
// Call MergeFlagsIntoGRPCConfig to merge flag values into the config
flags.MergeFlagsIntoGRPCConfig(config)
// Validate the merged config
expectedConfig := &GRPCConfig{
Address: "example.com:8502",
GRPCTLS: true,
GRPCTLSVerify: true,
CertFile: "/path/to/client.crt",
KeyFile: "/path/to/client.key",
CAFile: "/path/to/ca.crt",
CAPath: "/path/to/cacerts",
}
assert.Equal(t, expectedConfig, config)
})
t.Run("MergeFlagsIntoGRPCConfig: allow empty values", func(t *testing.T) {
// Setup GRPCFlags with some flag values
flags := &GRPCFlags{
address: TValue[string]{v: stringPointer("http://example.com:8502")},
grpcTLS: TValue[bool]{},
certFile: TValue[string]{v: stringPointer("/path/to/client.crt")},
keyFile: TValue[string]{v: stringPointer("/path/to/client.key")},
caFile: TValue[string]{v: stringPointer("/path/to/ca.crt")},
caPath: TValue[string]{v: stringPointer("/path/to/cacerts")},
}
// Setup GRPCConfig with some initial values
config := &GRPCConfig{
Address: "localhost:8500",
GRPCTLSVerify: true,
CertFile: "/path/to/default/client.crt",
KeyFile: "/path/to/default/client.key",
CAFile: "/path/to/default/ca.crt",
CAPath: "/path/to/default/cacerts",
}
// Call MergeFlagsIntoGRPCConfig to merge flag values into the config
flags.MergeFlagsIntoGRPCConfig(config)
// Validate the merged config
expectedConfig := &GRPCConfig{
Address: "example.com:8502",
GRPCTLS: false,
GRPCTLSVerify: true,
CertFile: "/path/to/client.crt",
KeyFile: "/path/to/client.key",
CAFile: "/path/to/ca.crt",
CAPath: "/path/to/cacerts",
}
assert.Equal(t, expectedConfig, config)
})
}
// Utility function to convert string to string pointer
func stringPointer(s string) *string {
return &s
}
// Utility function to convert bool to bool pointer
func boolPointer(b bool) *bool {
return &b
}

View File

@ -0,0 +1,93 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package client
import (
"crypto/tls"
"fmt"
"strconv"
"strings"
"github.com/hashicorp/go-rootcerts"
)
// tls.Config is used to establish communication in TLS mode
func SetupTLSConfig(c *GRPCConfig) (*tls.Config, error) {
tlsConfig := &tls.Config{
InsecureSkipVerify: !c.GRPCTLSVerify,
}
if c.CertFile != "" && c.KeyFile != "" {
tlsCert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
if err != nil {
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{tlsCert}
}
var caConfig *rootcerts.Config
if c.CAFile != "" || c.CAPath != "" {
caConfig = &rootcerts.Config{
CAFile: c.CAFile,
CAPath: c.CAPath,
}
}
// load system CA certs if user doesn't provide any
if err := rootcerts.ConfigureTLS(tlsConfig, caConfig); err != nil {
return nil, err
}
return tlsConfig, nil
}
func removeSchemaFromGRPCAddress(addr string) string {
// Parse as host:port with option http prefix
grpcAddr := strings.TrimPrefix(addr, "http://")
grpcAddr = strings.TrimPrefix(grpcAddr, "https://")
return grpcAddr
}
type TValue[T string | bool] struct {
v *T
}
// Set implements the flag.Value interface.
func (t *TValue[T]) Set(v string) error {
if t.v == nil {
t.v = new(T)
}
var err error
// have to use interface{}(t.v) to do type assertion
switch interface{}(t.v).(type) {
case *string:
// have to use interface{}(t.v).(*string) to assert t.v as *string
*(interface{}(t.v).(*string)) = v
case *bool:
// have to use interface{}(t.v).(*bool) to assert t.v as *bool
*(interface{}(t.v).(*bool)), err = strconv.ParseBool(v)
default:
err = fmt.Errorf("unsupported type %T", t.v)
}
return err
}
// String implements the flag.Value interface.
func (t *TValue[T]) String() string {
var current T
if t.v != nil {
current = *(t.v)
}
return fmt.Sprintf("%v", current)
}
// Merge will overlay this value if it has been set.
func (t *TValue[T]) Merge(onto *T) error {
if onto == nil {
return fmt.Errorf("onto is nil")
}
if t.v != nil {
*onto = *(t.v)
}
return nil
}

View File

@ -0,0 +1,70 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package client
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestTValue(t *testing.T) {
t.Run("String: set", func(t *testing.T) {
var tv TValue[string]
err := tv.Set("testString")
assert.NoError(t, err)
assert.Equal(t, *tv.v, "testString")
})
t.Run("String: merge", func(t *testing.T) {
var tv TValue[string]
var onto string
testStr := "testString"
tv.v = &testStr
tv.Merge(&onto)
assert.Equal(t, onto, "testString")
})
t.Run("String: merge nil", func(t *testing.T) {
var tv TValue[string]
var onto *string = nil
testStr := "testString"
tv.v = &testStr
err := tv.Merge(onto)
assert.Equal(t, err.Error(), "onto is nil")
})
t.Run("Get string", func(t *testing.T) {
var tv TValue[string]
testStr := "testString"
tv.v = &testStr
assert.Equal(t, tv.String(), "testString")
})
t.Run("Bool: set", func(t *testing.T) {
var tv TValue[bool]
err := tv.Set("true")
assert.NoError(t, err)
assert.Equal(t, *tv.v, true)
})
t.Run("Bool: merge", func(t *testing.T) {
var tv TValue[bool]
var onto bool
testBool := true
tv.v = &testBool
tv.Merge(&onto)
assert.Equal(t, onto, true)
})
}