RPC and HTTP interfaces fully generically-sockified so Unix is supported.

Client works for RPC; will honor CONSUL_RPC_ADDR. HTTP works via consul/api;
honors CONSUL_HTTP_ADDR.

The format of a Unix socket in configuration data is:
"unix://[/path/to/socket];[username or uid];[gid];[mode]"

Obviously, the user must have appropriate permissions to create the socket
file in the given path and assign the requested uid/gid. Also note that Go does
not support gid lookups from group name, so gid must be numeric. See
https://codereview.appspot.com/101310044

When connecting from the client, the format is just the first part of the
above line:
"unix://[/path/to/socket]"

This code is copyright 2014 Akamai Technologies, Inc. <opensource@akamai.com>
This commit is contained in:
Jeff Mitchell 2015-01-08 16:38:09 +00:00
parent c835a04054
commit 11a3ce0bdd
6 changed files with 214 additions and 29 deletions

View File

@ -5,9 +5,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strconv" "strconv"
"strings"
"time" "time"
) )
@ -111,11 +114,17 @@ type Config struct {
// DefaultConfig returns a default configuration for the client // DefaultConfig returns a default configuration for the client
func DefaultConfig() *Config { func DefaultConfig() *Config {
return &Config{ config := &Config{
Address: "127.0.0.1:8500", Address: "127.0.0.1:8500",
Scheme: "http", Scheme: "http",
HttpClient: http.DefaultClient, HttpClient: http.DefaultClient,
} }
if len(os.Getenv("CONSUL_HTTP_ADDR")) > 0 {
config.Address = os.Getenv("CONSUL_HTTP_ADDR")
}
return config
} }
// Client provides a client to the Consul API // Client provides a client to the Consul API
@ -128,7 +137,11 @@ func NewClient(config *Config) (*Client, error) {
// bootstrap the config // bootstrap the config
defConfig := DefaultConfig() defConfig := DefaultConfig()
if len(config.Address) == 0 { switch {
case len(config.Address) != 0:
case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0:
config.Address = os.Getenv("CONSUL_HTTP_ADDR")
default:
config.Address = defConfig.Address config.Address = defConfig.Address
} }
@ -140,6 +153,16 @@ func NewClient(config *Config) (*Client, error) {
config.HttpClient = defConfig.HttpClient config.HttpClient = defConfig.HttpClient
} }
if strings.HasPrefix(config.Address, "unix://") {
shortStr := strings.TrimPrefix(config.Address, "unix://")
t := &http.Transport{}
t.Dial = func(_, _ string) (net.Conn, error) {
return net.Dial("unix", shortStr)
}
config.HttpClient.Transport = t
config.Address = shortStr
}
client := &Client{ client := &Client{
config: *config, config: *config,
} }
@ -206,9 +229,6 @@ func (r *request) toHTTP() (*http.Request, error) {
// Encode the query parameters // Encode the query parameters
r.url.RawQuery = r.params.Encode() r.url.RawQuery = r.params.Encode()
// Get the url sring
urlRaw := r.url.String()
// Check if we should encode the body // Check if we should encode the body
if r.body == nil && r.obj != nil { if r.body == nil && r.obj != nil {
if b, err := encodeBody(r.obj); err != nil { if b, err := encodeBody(r.obj); err != nil {
@ -219,14 +239,21 @@ func (r *request) toHTTP() (*http.Request, error) {
} }
// Create the HTTP request // Create the HTTP request
req, err := http.NewRequest(r.method, urlRaw, r.body) req, err := http.NewRequest(r.method, r.url.RequestURI(), r.body)
if err != nil {
return nil, err
}
req.URL.Host = r.url.Host
req.URL.Scheme = r.url.Scheme
req.Host = r.url.Host
// Setup auth // Setup auth
if err == nil && r.config.HttpAuth != nil { if err == nil && r.config.HttpAuth != nil {
req.SetBasicAuth(r.config.HttpAuth.Username, r.config.HttpAuth.Password) req.SetBasicAuth(r.config.HttpAuth.Username, r.config.HttpAuth.Password)
} }
return req, err return req, nil
} }
// newRequest is used to create a new request // newRequest is used to create a new request

View File

@ -295,13 +295,26 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
return err return err
} }
rpcListener, err := net.Listen("tcp", rpcAddr.String()) if _, ok := rpcAddr.(*net.UnixAddr); ok {
// Remove the socket if it exists, or we'll get a bind error
_ = os.Remove(rpcAddr.String())
}
rpcListener, err := net.Listen(rpcAddr.Network(), rpcAddr.String())
if err != nil { if err != nil {
agent.Shutdown() agent.Shutdown()
c.Ui.Error(fmt.Sprintf("Error starting RPC listener: %s", err)) c.Ui.Error(fmt.Sprintf("Error starting RPC listener: %s", err))
return err return err
} }
if _, ok := rpcAddr.(*net.UnixAddr); ok {
if err := adjustUnixSocketPermissions(config.Addresses.RPC); err != nil {
agent.Shutdown()
c.Ui.Error(fmt.Sprintf("Error adjusting Unix socket permissions: %s", err))
return err
}
}
// Start the IPC layer // Start the IPC layer
c.Ui.Output("Starting Consul agent RPC...") c.Ui.Output("Starting Consul agent RPC...")
c.rpcServer = NewAgentRPC(agent, rpcListener, logOutput, logWriter) c.rpcServer = NewAgentRPC(agent, rpcListener, logOutput, logWriter)
@ -319,6 +332,7 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
if config.Ports.DNS > 0 { if config.Ports.DNS > 0 {
dnsAddr, err := config.ClientListener(config.Addresses.DNS, config.Ports.DNS) dnsAddr, err := config.ClientListener(config.Addresses.DNS, config.Ports.DNS)
if err != nil { if err != nil {
agent.Shutdown()
c.Ui.Error(fmt.Sprintf("Invalid DNS bind address: %s", err)) c.Ui.Error(fmt.Sprintf("Invalid DNS bind address: %s", err))
return err return err
} }

View File

@ -7,8 +7,10 @@ import (
"io" "io"
"net" "net"
"os" "os"
"os/user"
"path/filepath" "path/filepath"
"sort" "sort"
"strconv"
"strings" "strings"
"time" "time"
@ -345,6 +347,82 @@ type Config struct {
WatchPlans []*watch.WatchPlan `mapstructure:"-" json:"-"` WatchPlans []*watch.WatchPlan `mapstructure:"-" json:"-"`
} }
// UnixSocket contains the parameters for a Unix socket interface
type UnixSocket struct {
// Path to the socket on-disk
Path string
// uid of the owner of the socket
Uid int
// gid of the group of the socket
Gid int
// Permissions for the socket file
Permissions os.FileMode
}
func populateUnixSocket(addr string) (*UnixSocket, error) {
if !strings.HasPrefix(addr, "unix://") {
return nil, fmt.Errorf("Failed to parse Unix address, format is [path];[user];[group];[mode]: %v", addr)
}
splitAddr := strings.Split(strings.TrimPrefix(addr, "unix://"), ";")
if len(splitAddr) != 4 {
return nil, fmt.Errorf("Failed to parse Unix address, format is [path];[user];[group];[mode]: %v", addr)
}
ret := &UnixSocket{Path: splitAddr[0]}
if userVal, err := user.Lookup(splitAddr[1]); err != nil {
return nil, fmt.Errorf("Invalid user given for Unix socket ownership: %v", splitAddr[1])
} else {
if uid64, err := strconv.ParseInt(userVal.Uid, 10, 32); err != nil {
return nil, fmt.Errorf("Failed to parse given user ID of %v into integer", userVal.Uid)
} else {
ret.Uid = int(uid64)
}
}
// Go doesn't currently have a way to look up gid from group name,
// so require a numeric gid; see
// https://codereview.appspot.com/101310044
if gid64, err := strconv.ParseInt(splitAddr[2], 10, 32); err != nil {
return nil, fmt.Errorf("Socket group must be given as numeric gid. Failed to parse given group ID of %v into integer", splitAddr[2])
} else {
ret.Gid = int(gid64)
}
if mode, err := strconv.ParseUint(splitAddr[3], 8, 32); err != nil {
return nil, fmt.Errorf("Failed to parse given mode of %v into integer", splitAddr[3])
} else {
if mode > 0777 {
return nil, fmt.Errorf("Given mode is invalid; must be an octal number between 0 and 777")
} else {
ret.Permissions = os.FileMode(mode)
}
}
return ret, nil
}
func adjustUnixSocketPermissions(addr string) error {
sock, err := populateUnixSocket(addr)
if err != nil {
return err
}
if err = os.Chown(sock.Path, sock.Uid, sock.Gid); err != nil {
return fmt.Errorf("Error attempting to change socket permissions to userid %v and groupid %v: %v", sock.Uid, sock.Gid, err)
}
if err = os.Chmod(sock.Path, sock.Permissions); err != nil {
return fmt.Errorf("Error attempting to change socket permissions to mode %v: %v", sock.Permissions, err)
}
return nil
}
type dirEnts []os.FileInfo type dirEnts []os.FileInfo
// DefaultConfig is used to return a sane default configuration // DefaultConfig is used to return a sane default configuration
@ -389,18 +467,30 @@ func (c *Config) EncryptBytes() ([]byte, error) {
// ClientListener is used to format a listener for a // ClientListener is used to format a listener for a
// port on a ClientAddr // port on a ClientAddr
func (c *Config) ClientListener(override string, port int) (*net.TCPAddr, error) { func (c *Config) ClientListener(override string, port int) (net.Addr, error) {
var addr string var addr string
if override != "" { if override != "" {
addr = override addr = override
} else { } else {
addr = c.ClientAddr addr = c.ClientAddr
} }
ip := net.ParseIP(addr)
if ip == nil { switch {
return nil, fmt.Errorf("Failed to parse IP: %v", addr) case strings.HasPrefix(addr, "unix://"):
sock, err := populateUnixSocket(addr)
if err != nil {
return nil, err
}
return &net.UnixAddr{Name: sock.Path, Net: "unix"}, nil
default:
ip := net.ParseIP(addr)
if ip == nil {
return nil, fmt.Errorf("Failed to parse IP: %v", addr)
}
return &net.TCPAddr{IP: ip, Port: port}, nil
} }
return &net.TCPAddr{IP: ip, Port: port}, nil
} }
// ClientListenerAddr is used to format an address for a // ClientListenerAddr is used to format an address for a
@ -410,8 +500,11 @@ func (c *Config) ClientListenerAddr(override string, port int) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
if addr.IP.IsUnspecified() {
addr.IP = net.ParseIP("127.0.0.1") if ipAddr, ok := addr.(*net.TCPAddr); ok {
if ipAddr.IP.IsUnspecified() {
ipAddr.IP = net.ParseIP("127.0.0.1")
}
} }
return addr.String(), nil return addr.String(), nil
} }

View File

@ -9,6 +9,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/pprof" "net/http/pprof"
"os"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -34,7 +35,7 @@ type HTTPServer struct {
func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPServer, error) { func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPServer, error) {
var tlsConfig *tls.Config var tlsConfig *tls.Config
var list net.Listener var list net.Listener
var httpAddr *net.TCPAddr var httpAddr net.Addr
var err error var err error
var servers []*HTTPServer var servers []*HTTPServer
@ -58,12 +59,29 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS
return nil, err return nil, err
} }
ln, err := net.Listen("tcp", httpAddr.String()) if _, ok := httpAddr.(*net.UnixAddr); ok {
if err != nil { // Remove the socket if it exists, or we'll get a bind error
return nil, err _ = os.Remove(httpAddr.String())
} }
list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig) ln, err := net.Listen(httpAddr.Network(), httpAddr.String())
if err != nil {
return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err)
}
switch httpAddr.(type) {
case *net.UnixAddr:
if err := adjustUnixSocketPermissions(config.Addresses.HTTPS); err != nil {
return nil, err
}
list = tls.NewListener(ln, tlsConfig)
case *net.TCPAddr:
list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig)
default:
return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err)
}
// Create the mux // Create the mux
mux := http.NewServeMux() mux := http.NewServeMux()
@ -90,13 +108,29 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS
return nil, fmt.Errorf("Failed to get ClientListener address:port: %v", err) return nil, fmt.Errorf("Failed to get ClientListener address:port: %v", err)
} }
// Create non-TLS listener if _, ok := httpAddr.(*net.UnixAddr); ok {
ln, err := net.Listen("tcp", httpAddr.String()) // Remove the socket if it exists, or we'll get a bind error
_ = os.Remove(httpAddr.String())
}
ln, err := net.Listen(httpAddr.Network(), httpAddr.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err)
} }
list = tcpKeepAliveListener{ln.(*net.TCPListener)} switch httpAddr.(type) {
case *net.UnixAddr:
if err := adjustUnixSocketPermissions(config.Addresses.HTTP); err != nil {
return nil, err
}
list = ln
case *net.TCPAddr:
list = tcpKeepAliveListener{ln.(*net.TCPListener)}
default:
return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err)
}
// Create the mux // Create the mux
mux := http.NewServeMux() mux := http.NewServeMux()

View File

@ -7,6 +7,8 @@ import (
"github.com/hashicorp/logutils" "github.com/hashicorp/logutils"
"log" "log"
"net" "net"
"os"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
) )
@ -34,7 +36,7 @@ type seqHandler interface {
type RPCClient struct { type RPCClient struct {
seq uint64 seq uint64
conn *net.TCPConn conn net.Conn
reader *bufio.Reader reader *bufio.Reader
writer *bufio.Writer writer *bufio.Writer
dec *codec.Decoder dec *codec.Decoder
@ -79,8 +81,18 @@ func (c *RPCClient) send(header *requestHeader, obj interface{}) error {
// NewRPCClient is used to create a new RPC client given the address. // NewRPCClient is used to create a new RPC client given the address.
// This will properly dial, handshake, and start listening // This will properly dial, handshake, and start listening
func NewRPCClient(addr string) (*RPCClient, error) { func NewRPCClient(addr string) (*RPCClient, error) {
sanedAddr := os.Getenv("CONSUL_RPC_ADDR")
if len(sanedAddr) == 0 {
sanedAddr = addr
}
mode := "tcp"
if strings.HasPrefix(sanedAddr, "unix://") {
sanedAddr = strings.TrimPrefix(sanedAddr, "unix://")
mode = "unix"
}
// Try to dial to agent // Try to dial to agent
conn, err := net.Dial("tcp", addr) conn, err := net.Dial(mode, sanedAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -88,7 +100,7 @@ func NewRPCClient(addr string) (*RPCClient, error) {
// Create the client // Create the client
client := &RPCClient{ client := &RPCClient{
seq: 0, seq: 0,
conn: conn.(*net.TCPConn), conn: conn,
reader: bufio.NewReader(conn), reader: bufio.NewReader(conn),
writer: bufio.NewWriter(conn), writer: bufio.NewWriter(conn),
dispatch: make(map[uint64]seqHandler), dispatch: make(map[uint64]seqHandler),

View File

@ -8,8 +8,8 @@ import (
"github.com/hashicorp/consul/command/agent" "github.com/hashicorp/consul/command/agent"
) )
// RPCAddrEnvName defines the environment variable name, which can set // RPCAddrEnvName defines an environment variable name which sets
// a default RPC address in case there is no -rpc-addr specified. // an RPC address if there is no -rpc-addr specified.
const RPCAddrEnvName = "CONSUL_RPC_ADDR" const RPCAddrEnvName = "CONSUL_RPC_ADDR"
// RPCAddrFlag returns a pointer to a string that will be populated // RPCAddrFlag returns a pointer to a string that will be populated
@ -43,7 +43,12 @@ func HTTPClient(addr string) (*consulapi.Client, error) {
// HTTPClientDC returns a new Consul HTTP client with the given address and datacenter // HTTPClientDC returns a new Consul HTTP client with the given address and datacenter
func HTTPClientDC(addr, dc string) (*consulapi.Client, error) { func HTTPClientDC(addr, dc string) (*consulapi.Client, error) {
conf := consulapi.DefaultConfig() conf := consulapi.DefaultConfig()
conf.Address = addr switch {
case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0:
conf.Address = os.Getenv("CONSUL_HTTP_ADDR")
default:
conf.Address = addr
}
conf.Datacenter = dc conf.Datacenter = dc
return consulapi.NewClient(conf) return consulapi.NewClient(conf)
} }