consul/agent/connect/ca/provider_vault_auth_azure.go

146 lines
3.8 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package ca
import (
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/go-cleanhttp"
	"github.com/hashicorp/vault/sdk/helper/jsonutil"
)
// NewAzureAuthClient builds a VaultAuthClient for Vault's Azure auth method.
//
// If legacyCheck matches authMethod.Params against the legacy login keys
// (vm_name, vmss_name, resource_group_name, subscription_id, jwt), the client
// is returned unchanged so the params are used as supplied (backwards
// compatibility). Otherwise "role" and "resource" must both be non-blank
// strings, and AzureLoginDataGen is installed as the client's LoginDataGen.
func NewAzureAuthClient(authMethod *structs.VaultAuthMethod) (*VaultAuthClient, error) {
	params := authMethod.Params
	client := NewVaultAPIAuthClient(authMethod, "")

	// Older configurations carry the login data directly in params; keep
	// honoring them for backwards compatibility.
	if legacyCheck(params, "vm_name", "vmss_name", "resource_group_name", "subscription_id", "jwt") {
		return client, nil
	}

	// nonBlank reports whether params[key] is a string with non-whitespace content.
	nonBlank := func(key string) bool {
		s, ok := params[key].(string)
		return ok && strings.TrimSpace(s) != ""
	}
	if !nonBlank("role") {
		return nil, fmt.Errorf("missing 'role' value")
	}
	if !nonBlank("resource") {
		return nil, fmt.Errorf("missing 'resource' value")
	}

	client.LoginDataGen = AzureLoginDataGen
	return client, nil
}
// instanceEndpoint is the Azure instance metadata URL queried for details of
// the VM this agent runs on. It is a variable rather than a constant so tests
// can point it at a fake server.
var instanceEndpoint = "http://169.254.169.254/metadata/instance"

// identityEndpoint is the metadata URL queried for the OAuth2 access token that
// is used as the JWT in the Vault login payload. Variable so tests can override it.
var identityEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

// apiVersion is sent as the "api-version" query parameter on every metadata
// request. 2018-02-01 is the minimum version needed for identity metadata.
var apiVersion = "2018-02-01"
type (
	// instanceData is the decoded instance metadata response; only the
	// Compute section is retained.
	instanceData struct {
		Compute Compute
	}

	// Compute holds the VM identification fields read from instance metadata.
	// The fields have no json tags, so decoding relies on encoding/json's
	// case-insensitive field-name matching (assuming jsonutil.DecodeJSON wraps
	// encoding/json — confirm if that changes).
	Compute struct {
		Name              string
		ResourceGroupName string
		SubscriptionID    string
		VMScaleSetName    string
	}

	// identityData is the decoded identity-endpoint response; only the access
	// token (used as the login JWT) is read.
	identityData struct {
		AccessToken string `json:"access_token"`
	}
)
// AzureLoginDataGen builds the login payload for Vault's Azure auth method.
//
// It reads VM details from instanceEndpoint and an access token from
// identityEndpoint. authMethod.Params must contain "role" and "resource" as
// strings. The optional "object_id" and "client_id" strings are forwarded to the
// identity request. The returned map holds role, vm_name, vmss_name,
// resource_group_name, subscription_id and jwt.
//
// An error is returned if a required param is missing or not a string, if
// either metadata request fails, or if a response cannot be decoded. Missing
// params used to cause a panic.
func AzureLoginDataGen(authMethod *structs.VaultAuthMethod) (map[string]any, error) {
	params := authMethod.Params

	// Use checked assertions: this function is exported and may be called
	// without NewAzureAuthClient having validated the params first.
	role, ok := params["role"].(string)
	if !ok {
		return nil, fmt.Errorf("missing 'role' value")
	}
	resource, ok := params["resource"].(string)
	if !ok {
		return nil, fmt.Errorf("missing 'resource' value")
	}

	// Every entry in metaConf (role included) is sent as a query parameter
	// on the identity request.
	metaConf := map[string]string{
		"role":     role,
		"resource": resource,
	}
	if objectID, ok := params["object_id"].(string); ok {
		metaConf["object_id"] = objectID
	}
	if clientID, ok := params["client_id"].(string); ok {
		metaConf["client_id"] = clientID
	}

	// Fetch instance data.
	var instance instanceData
	body, err := getMetadataInfo(instanceEndpoint, nil)
	if err != nil {
		return nil, err
	}
	if err = jsonutil.DecodeJSON(body, &instance); err != nil {
		return nil, fmt.Errorf("error parsing instance metadata response: %w", err)
	}

	// Fetch the JWT.
	var identity identityData
	body, err = getMetadataInfo(identityEndpoint, metaConf)
	if err != nil {
		return nil, err
	}
	if err = jsonutil.DecodeJSON(body, &identity); err != nil {
		return nil, fmt.Errorf("error parsing identity metadata response: %w", err)
	}

	return map[string]any{
		"role":                role,
		"vm_name":             instance.Compute.Name,
		"vmss_name":           instance.Compute.VMScaleSetName,
		"resource_group_name": instance.Compute.ResourceGroupName,
		"subscription_id":     instance.Compute.SubscriptionID,
		"jwt":                 identity.AccessToken,
	}, nil
}
// getMetadataInfo sends a GET request to a metadata endpoint and returns the
// raw response body.
//
// The package-level apiVersion is always sent as the "api-version" query
// parameter, and every key/value in query is added next to it. The request
// carries the "Metadata: true" and "User-Agent: Consul" headers.
//
// An error is returned if the request fails or exceeds the timeout, if the body
// cannot be read, or if the status is not 200 OK. In the non-200 case the body
// is included in the error.
func getMetadataInfo(endpoint string, query map[string]string) ([]byte, error) {
	// Bound the whole exchange (connect, headers and body). The client from
	// cleanhttp.DefaultClient does not set http.Client.Timeout, so without this
	// an unresponsive endpoint could block the caller indefinitely.
	const requestTimeout = 30 * time.Second

	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	q := req.URL.Query()
	q.Add("api-version", apiVersion)
	for k, v := range query {
		q.Add(k, v)
	}
	req.URL.RawQuery = q.Encode()
	req.Header.Set("Metadata", "true")
	req.Header.Set("User-Agent", "Consul")

	client := cleanhttp.DefaultClient()
	client.Timeout = requestTimeout
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error fetching metadata from %s: %w", endpoint, err)
	}
	// net/http guarantees a non-nil Response whenever Do returns a nil error.
	defer resp.Body.Close()

	// Read the body before checking the status so it can be reported on failure.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading metadata from %s: %w", endpoint, err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("error response in metadata from %s: %s", endpoint, body)
	}
	return body, nil
}