Merge pull request #612 from hashicorp/f-refactor

Refactor UNIX domain socket
This commit is contained in:
Ryan Uber 2015-01-16 14:34:19 -08:00
commit df52ac6bae
14 changed files with 259 additions and 415 deletions

View File

@ -120,8 +120,8 @@ func DefaultConfig() *Config {
HttpClient: http.DefaultClient, HttpClient: http.DefaultClient,
} }
if len(os.Getenv("CONSUL_HTTP_ADDR")) > 0 { if addr := os.Getenv("CONSUL_HTTP_ADDR"); addr != "" {
config.Address = os.Getenv("CONSUL_HTTP_ADDR") config.Address = addr
} }
return config return config
@ -137,11 +137,7 @@ func NewClient(config *Config) (*Client, error) {
// bootstrap the config // bootstrap the config
defConfig := DefaultConfig() defConfig := DefaultConfig()
switch { if len(config.Address) == 0 {
case len(config.Address) != 0:
case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0:
config.Address = os.Getenv("CONSUL_HTTP_ADDR")
default:
config.Address = defConfig.Address config.Address = defConfig.Address
} }
@ -153,14 +149,15 @@ func NewClient(config *Config) (*Client, error) {
config.HttpClient = defConfig.HttpClient config.HttpClient = defConfig.HttpClient
} }
if strings.HasPrefix(config.Address, "unix://") { if parts := strings.SplitN(config.Address, "unix://", 2); len(parts) == 2 {
shortStr := strings.TrimPrefix(config.Address, "unix://") config.HttpClient = &http.Client{
t := &http.Transport{} Transport: &http.Transport{
t.Dial = func(_, _ string) (net.Conn, error) { Dial: func(_, _ string) (net.Conn, error) {
return net.Dial("unix", shortStr) return net.Dial("unix", parts[1])
},
},
} }
config.HttpClient.Transport = t config.Address = parts[1]
config.Address = shortStr
} }
client := &Client{ client := &Client{

View File

@ -8,6 +8,8 @@ import (
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"runtime"
"testing" "testing"
"time" "time"
@ -42,6 +44,10 @@ type testServerConfig struct {
Ports testPortConfig `json:"ports,omitempty"` Ports testPortConfig `json:"ports,omitempty"`
} }
// Callback functions for modifying config
type configCallback func(c *Config)
type serverConfigCallback func(c *testServerConfig)
func defaultConfig() *testServerConfig { func defaultConfig() *testServerConfig {
return &testServerConfig{ return &testServerConfig{
Bootstrap: true, Bootstrap: true,
@ -72,7 +78,7 @@ func newTestServer(t *testing.T) *testServer {
return newTestServerWithConfig(t, func(c *testServerConfig) {}) return newTestServerWithConfig(t, func(c *testServerConfig) {})
} }
func newTestServerWithConfig(t *testing.T, cb func(c *testServerConfig)) *testServer { func newTestServerWithConfig(t *testing.T, cb serverConfigCallback) *testServer {
if path, err := exec.LookPath("consul"); err != nil || path == "" { if path, err := exec.LookPath("consul"); err != nil || path == "" {
t.Log("consul not found on $PATH, skipping") t.Log("consul not found on $PATH, skipping")
t.SkipNow() t.SkipNow()
@ -131,15 +137,20 @@ func makeClient(t *testing.T) (*Client, *testServer) {
}, func(c *testServerConfig) {}) }, func(c *testServerConfig) {})
} }
func makeClientWithConfig(t *testing.T, clientConfig func(c *Config), serverConfig func(c *testServerConfig)) (*Client, *testServer) { func makeClientWithConfig(t *testing.T, cb1 configCallback, cb2 serverConfigCallback) (*Client, *testServer) {
server := newTestServerWithConfig(t, serverConfig) // Make client config
conf := DefaultConfig() conf := DefaultConfig()
clientConfig(conf) cb1(conf)
// Create client
client, err := NewClient(conf) client, err := NewClient(conf)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
// Create server
server := newTestServerWithConfig(t, cb2)
// Allow the server some time to start, and verify we have a leader. // Allow the server some time to start, and verify we have a leader.
testutil.WaitForResult(func() (bool, error) { testutil.WaitForResult(func() (bool, error) {
req := client.newRequest("GET", "/v1/catalog/nodes") req := client.newRequest("GET", "/v1/catalog/nodes")
@ -278,3 +289,35 @@ func TestParseQueryMeta(t *testing.T) {
t.Fatalf("Bad: %v", qm) t.Fatalf("Bad: %v", qm)
} }
} }
func TestAPI_UnixSocket(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
tempDir, err := ioutil.TempDir("", "consul")
if err != nil {
t.Fatalf("err: %s", err)
}
defer os.RemoveAll(tempDir)
socket := filepath.Join(tempDir, "test.sock")
c, s := makeClientWithConfig(t, func(c *Config) {
c.Address = "unix://" + socket
}, func(c *testServerConfig) {
c.Addresses = &testAddressConfig{
HTTP: "unix://" + socket,
}
})
defer s.stop()
agent := c.Agent()
info, err := agent.Self()
if err != nil {
t.Fatalf("err: %s", err)
}
if info["Config"]["NodeName"] == "" {
t.Fatalf("bad: %v", info)
}
}

View File

@ -1,13 +1,10 @@
package api package api
import ( import (
"io/ioutil"
"os/user"
"runtime"
"testing" "testing"
) )
func TestStatusLeaderTCP(t *testing.T) { func TestStatusLeader(t *testing.T) {
c, s := makeClient(t) c, s := makeClient(t)
defer s.stop() defer s.stop()
@ -22,48 +19,6 @@ func TestStatusLeaderTCP(t *testing.T) {
} }
} }
func TestStatusLeaderUnix(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
tempdir, err := ioutil.TempDir("", "consul-test-")
if err != nil {
t.Fatal("Could not create a working directory")
}
socket := "unix://" + tempdir + "/unix-http-test.sock"
clientConfig := func(c *Config) {
c.Address = socket
}
serverConfig := func(c *testServerConfig) {
user, err := user.Current()
if err != nil {
t.Fatal("Could not get current user")
}
if c.Addresses == nil {
c.Addresses = &testAddressConfig{}
}
c.Addresses.HTTP = socket + ";" + user.Uid + ";" + user.Gid + ";640"
}
c, s := makeClientWithConfig(t, clientConfig, serverConfig)
defer s.stop()
status := c.Status()
leader, err := status.Leader()
if err != nil {
t.Fatalf("err: %v", err)
}
if leader == "" {
t.Fatalf("Expected leader")
}
}
func TestStatusPeers(t *testing.T) { func TestStatusPeers(t *testing.T) {
c, s := makeClient(t) c, s := makeClient(t)
defer s.stop() defer s.stop()

View File

@ -22,6 +22,13 @@ const (
// Path to save local agent checks // Path to save local agent checks
checksDir = "checks" checksDir = "checks"
// errSocketFileExists is the human-friendly error message displayed when
// trying to bind a socket to an existing file.
errSocketFileExists = "A file exists at the requested socket path %q. " +
"If Consul was not shut down properly, the socket file may " +
"be left behind. If the path looks correct, remove the file " +
"and try again."
) )
/* /*

View File

@ -7,10 +7,8 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"os/user"
"path/filepath" "path/filepath"
"reflect" "reflect"
"runtime"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -125,7 +123,7 @@ func TestAgentStartStop(t *testing.T) {
} }
} }
func TestAgent_RPCPingTCP(t *testing.T) { func TestAgent_RPCPing(t *testing.T) {
dir, agent := makeAgent(t, nextConfig()) dir, agent := makeAgent(t, nextConfig())
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
defer agent.Shutdown() defer agent.Shutdown()
@ -136,35 +134,6 @@ func TestAgent_RPCPingTCP(t *testing.T) {
} }
} }
func TestAgent_RPCPingUnix(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
nextConf := nextConfig()
tempdir, err := ioutil.TempDir("", "consul-test-")
if err != nil {
t.Fatal("Could not create a working directory")
}
user, err := user.Current()
if err != nil {
t.Fatal("Could not get current user")
}
nextConf.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640"
dir, agent := makeAgent(t, nextConf)
defer os.RemoveAll(dir)
defer agent.Shutdown()
var out struct{}
if err := agent.RPC("Status.Ping", struct{}{}, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestAgent_AddService(t *testing.T) { func TestAgent_AddService(t *testing.T) {
dir, agent := makeAgent(t, nextConfig()) dir, agent := makeAgent(t, nextConfig())
defer os.RemoveAll(dir) defer os.RemoveAll(dir)

View File

@ -295,9 +295,12 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
return err return err
} }
if _, ok := rpcAddr.(*net.UnixAddr); ok { // Error if we are trying to bind a domain socket to an existing path
// Remove the socket if it exists, or we'll get a bind error if path, ok := unixSocketAddr(config.Addresses.RPC); ok {
_ = os.Remove(rpcAddr.String()) if _, err := os.Stat(path); err == nil || !os.IsNotExist(err) {
c.Ui.Output(fmt.Sprintf(errSocketFileExists, path))
return fmt.Errorf(errSocketFileExists, path)
}
} }
rpcListener, err := net.Listen(rpcAddr.Network(), rpcAddr.String()) rpcListener, err := net.Listen(rpcAddr.Network(), rpcAddr.String())
@ -307,14 +310,6 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
return err return err
} }
if _, ok := rpcAddr.(*net.UnixAddr); ok {
if err := adjustUnixSocketPermissions(config.Addresses.RPC); err != nil {
agent.Shutdown()
c.Ui.Error(fmt.Sprintf("Error adjusting Unix socket permissions: %s", err))
return err
}
}
// Start the IPC layer // Start the IPC layer
c.Ui.Output("Starting Consul agent RPC...") c.Ui.Output("Starting Consul agent RPC...")
c.rpcServer = NewAgentRPC(agent, rpcListener, logOutput, logWriter) c.rpcServer = NewAgentRPC(agent, rpcListener, logOutput, logWriter)

View File

@ -7,11 +7,8 @@ import (
"io" "io"
"net" "net"
"os" "os"
"os/user"
"path/filepath" "path/filepath"
"regexp"
"sort" "sort"
"strconv"
"strings" "strings"
"time" "time"
@ -348,89 +345,13 @@ type Config struct {
WatchPlans []*watch.WatchPlan `mapstructure:"-" json:"-"` WatchPlans []*watch.WatchPlan `mapstructure:"-" json:"-"`
} }
// UnixSocket contains the parameters for a Unix socket interface // unixSocketAddr tests if a given address describes a domain socket,
type UnixSocket struct { // and returns the relevant path part of the string if it is.
// Path to the socket on-disk func unixSocketAddr(addr string) (string, bool) {
Path string
// uid of the owner of the socket
Uid int
// gid of the group of the socket
Gid int
// Permissions for the socket file
Permissions os.FileMode
}
func populateUnixSocket(addr string) (*UnixSocket, error) {
if !strings.HasPrefix(addr, "unix://") { if !strings.HasPrefix(addr, "unix://") {
return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr) return "", false
} }
return strings.TrimPrefix(addr, "unix://"), true
splitAddr := strings.Split(strings.TrimPrefix(addr, "unix://"), ";")
if len(splitAddr) != 4 {
return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr)
}
ret := &UnixSocket{Path: splitAddr[0]}
var userVal *user.User
var err error
regex := regexp.MustCompile("[\\d]+")
if regex.MatchString(splitAddr[1]) {
userVal, err = user.LookupId(splitAddr[1])
} else {
userVal, err = user.Lookup(splitAddr[1])
}
if err != nil {
return nil, fmt.Errorf("Invalid user given for Unix socket ownership: %v", splitAddr[1])
}
if uid64, err := strconv.ParseInt(userVal.Uid, 10, 32); err != nil {
return nil, fmt.Errorf("Failed to parse given user ID of %v into integer", userVal.Uid)
} else {
ret.Uid = int(uid64)
}
// Go doesn't currently have a way to look up gid from group name,
// so require a numeric gid; see
// https://codereview.appspot.com/101310044
if gid64, err := strconv.ParseInt(splitAddr[2], 10, 32); err != nil {
return nil, fmt.Errorf("Socket group must be given as numeric gid. Failed to parse given group ID of %v into integer", splitAddr[2])
} else {
ret.Gid = int(gid64)
}
if mode, err := strconv.ParseUint(splitAddr[3], 8, 32); err != nil {
return nil, fmt.Errorf("Failed to parse given mode of %v into integer", splitAddr[3])
} else {
if mode > 0777 {
return nil, fmt.Errorf("Given mode is invalid; must be an octal number between 0 and 777")
} else {
ret.Permissions = os.FileMode(mode)
}
}
return ret, nil
}
func adjustUnixSocketPermissions(addr string) error {
sock, err := populateUnixSocket(addr)
if err != nil {
return err
}
if err = os.Chown(sock.Path, sock.Uid, sock.Gid); err != nil {
return fmt.Errorf("Error attempting to change socket permissions to userid %v and groupid %v: %v", sock.Uid, sock.Gid, err)
}
if err = os.Chmod(sock.Path, sock.Permissions); err != nil {
return fmt.Errorf("Error attempting to change socket permissions to mode %v: %v", sock.Permissions, err)
}
return nil
} }
type dirEnts []os.FileInfo type dirEnts []os.FileInfo
@ -485,31 +406,14 @@ func (c *Config) ClientListener(override string, port int) (net.Addr, error) {
addr = c.ClientAddr addr = c.ClientAddr
} }
switch { if path, ok := unixSocketAddr(addr); ok {
case strings.HasPrefix(addr, "unix://"): return &net.UnixAddr{Name: path, Net: "unix"}, nil
sock, err := populateUnixSocket(addr)
if err != nil {
return nil, err
}
return &net.UnixAddr{Name: sock.Path, Net: "unix"}, nil
default:
ip := net.ParseIP(addr)
if ip == nil {
return nil, fmt.Errorf("Failed to parse IP: %v", addr)
}
if ip.IsUnspecified() {
ip = net.ParseIP("127.0.0.1")
}
if ip == nil {
return nil, fmt.Errorf("Failed to parse IP 127.0.0.1")
}
return &net.TCPAddr{IP: ip, Port: port}, nil
} }
ip := net.ParseIP(addr)
if ip == nil {
return nil, fmt.Errorf("Failed to parse IP: %v", addr)
}
return &net.TCPAddr{IP: ip, Port: port}, nil
} }
// DecodeConfig reads the configuration from the given reader in JSON // DecodeConfig reads the configuration from the given reader in JSON

View File

@ -4,12 +4,9 @@ import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"io/ioutil" "io/ioutil"
"net"
"os" "os"
"os/user"
"path/filepath" "path/filepath"
"reflect" "reflect"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -1073,107 +1070,13 @@ func TestReadConfigPaths_dir(t *testing.T) {
} }
func TestUnixSockets(t *testing.T) { func TestUnixSockets(t *testing.T) {
if runtime.GOOS == "windows" { path1, ok := unixSocketAddr("unix:///path/to/socket")
t.SkipNow() if !ok || path1 != "/path/to/socket" {
t.Fatalf("bad: %v %v", ok, path1)
} }
usr, err := user.Current() path2, ok := unixSocketAddr("notunix://blah")
if err != nil { if ok || path2 != "" {
t.Fatal("Could not get current user: ", err) t.Fatalf("bad: %v %v", ok, path2)
}
tempdir, err := ioutil.TempDir("", "consul-test-")
if err != nil {
t.Fatal("Could not create a working directory: ", err)
}
type SocketTestData struct {
Path string
Uid string
Gid string
Mode string
}
testUnixSocketPopulation := func(s SocketTestData) (*UnixSocket, error) {
return populateUnixSocket("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode)
}
testUnixSocketPermissions := func(s SocketTestData) error {
return adjustUnixSocketPermissions("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode)
}
_, err = populateUnixSocket("tcp://abc123")
if err == nil {
t.Fatal("Should have rejected invalid scheme")
}
_, err = populateUnixSocket("unix://x;y;z")
if err == nil {
t.Fatal("Should have rejected invalid number of parameters in Unix socket definition")
}
std := SocketTestData{
Path: tempdir + "/unix-config-test.sock",
Uid: usr.Uid,
Gid: usr.Gid,
Mode: "640",
}
std.Uid = "orasdfdsnfoinweroiu"
_, err = testUnixSocketPopulation(std)
if err == nil {
t.Fatal("Did not error on invalid username")
}
std.Uid = usr.Username
std.Gid = "foinfphawepofhewof"
_, err = testUnixSocketPopulation(std)
if err == nil {
t.Fatal("Did not error on invalid group (a name, must be gid)")
}
std.Gid = usr.Gid
std.Mode = "999"
_, err = testUnixSocketPopulation(std)
if err == nil {
t.Fatal("Did not error on invalid socket mode")
}
std.Uid = usr.Username
std.Mode = "640"
_, err = testUnixSocketPopulation(std)
if err != nil {
t.Fatal("Unix socket test failed (using username): ", err)
}
std.Uid = usr.Uid
sock, err := testUnixSocketPopulation(std)
if err != nil {
t.Fatal("Unix socket test failed (using uid): ", err)
}
addr := &net.UnixAddr{Name: sock.Path, Net: "unix"}
_, err = net.Listen(addr.Network(), addr.String())
if err != nil {
t.Fatal("Error creating socket for futher tests: ", err)
}
std.Uid = "-999999"
err = testUnixSocketPermissions(std)
if err == nil {
t.Fatal("Did not error on invalid uid")
}
std.Uid = usr.Uid
std.Gid = "-999999"
err = testUnixSocketPermissions(std)
if err == nil {
t.Fatal("Did not error on invalid uid")
}
std.Gid = usr.Gid
err = testUnixSocketPermissions(std)
if err != nil {
t.Fatal("Adjusting socket permissions failed: ", err)
} }
} }

View File

@ -59,28 +59,15 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS
return nil, err return nil, err
} }
if _, ok := httpAddr.(*net.UnixAddr); ok {
// Remove the socket if it exists, or we'll get a bind error
_ = os.Remove(httpAddr.String())
}
ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) ln, err := net.Listen(httpAddr.Network(), httpAddr.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err)
} }
switch httpAddr.(type) { if _, ok := unixSocketAddr(config.Addresses.HTTPS); ok {
case *net.UnixAddr:
if err := adjustUnixSocketPermissions(config.Addresses.HTTPS); err != nil {
return nil, err
}
list = tls.NewListener(ln, tlsConfig) list = tls.NewListener(ln, tlsConfig)
} else {
case *net.TCPAddr:
list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig) list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig)
default:
return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err)
} }
// Create the mux // Create the mux
@ -108,9 +95,11 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS
return nil, fmt.Errorf("Failed to get ClientListener address:port: %v", err) return nil, fmt.Errorf("Failed to get ClientListener address:port: %v", err)
} }
if _, ok := httpAddr.(*net.UnixAddr); ok { // Error if we are trying to bind a domain socket to an existing path
// Remove the socket if it exists, or we'll get a bind error if path, ok := unixSocketAddr(config.Addresses.HTTP); ok {
_ = os.Remove(httpAddr.String()) if _, err := os.Stat(path); err == nil || !os.IsNotExist(err) {
return nil, fmt.Errorf(errSocketFileExists, path)
}
} }
ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) ln, err := net.Listen(httpAddr.Network(), httpAddr.String())
@ -118,18 +107,10 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS
return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err)
} }
switch httpAddr.(type) { if _, ok := unixSocketAddr(config.Addresses.HTTP); ok {
case *net.UnixAddr:
if err := adjustUnixSocketPermissions(config.Addresses.HTTP); err != nil {
return nil, err
}
list = ln list = ln
} else {
case *net.TCPAddr:
list = tcpKeepAliveListener{ln.(*net.TCPListener)} list = tcpKeepAliveListener{ln.(*net.TCPListener)}
default:
return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err)
} }
// Create the mux // Create the mux

View File

@ -6,10 +6,12 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strconv" "strconv"
"testing" "testing"
"time" "time"
@ -19,7 +21,15 @@ import (
) )
func makeHTTPServer(t *testing.T) (string, *HTTPServer) { func makeHTTPServer(t *testing.T) (string, *HTTPServer) {
return makeHTTPServerWithConfig(t, nil)
}
func makeHTTPServerWithConfig(t *testing.T, cb func(c *Config)) (string, *HTTPServer) {
conf := nextConfig() conf := nextConfig()
if cb != nil {
cb(conf)
}
dir, agent := makeAgent(t, conf) dir, agent := makeAgent(t, conf)
uiDir := filepath.Join(dir, "ui") uiDir := filepath.Join(dir, "ui")
if err := os.Mkdir(uiDir, 755); err != nil { if err := os.Mkdir(uiDir, 755); err != nil {
@ -43,6 +53,93 @@ func encodeReq(obj interface{}) io.ReadCloser {
return ioutil.NopCloser(buf) return ioutil.NopCloser(buf)
} }
func TestHTTPServer_UnixSocket(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
tempDir, err := ioutil.TempDir("", "consul")
if err != nil {
t.Fatalf("err: %s", err)
}
defer os.RemoveAll(tempDir)
socket := filepath.Join(tempDir, "test.sock")
dir, srv := makeHTTPServerWithConfig(t, func(c *Config) {
c.Addresses.HTTP = "unix://" + socket
})
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer srv.agent.Shutdown()
// Ensure the socket was created
if _, err := os.Stat(socket); err != nil {
t.Fatalf("err: %s", err)
}
// Ensure we can get a response from the socket.
path, _ := unixSocketAddr(srv.agent.config.Addresses.HTTP)
client := &http.Client{
Transport: &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return net.Dial("unix", path)
},
},
}
// This URL doesn't look like it makes sense, but the scheme (http://) and
// the host (127.0.0.1) are required by the HTTP client library. In reality
// this will just use the custom dialer and talk to the socket.
resp, err := client.Get("http://127.0.0.1/v1/agent/self")
if err != nil {
t.Fatalf("err: %s", err)
}
defer resp.Body.Close()
if body, err := ioutil.ReadAll(resp.Body); err != nil || len(body) == 0 {
t.Fatalf("bad: %s %v", body, err)
}
}
func TestHTTPServer_UnixSocket_FileExists(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
tempDir, err := ioutil.TempDir("", "consul")
if err != nil {
t.Fatalf("err: %s", err)
}
defer os.RemoveAll(tempDir)
socket := filepath.Join(tempDir, "test.sock")
// Create a regular file at the socket path
if err := ioutil.WriteFile(socket, []byte("hello world"), 0644); err != nil {
t.Fatalf("err: %s", err)
}
fi, err := os.Stat(socket)
if err != nil {
t.Fatalf("err: %s", err)
}
if !fi.Mode().IsRegular() {
t.Fatalf("not a regular file: %s", socket)
}
conf := nextConfig()
conf.Addresses.HTTP = "unix://" + socket
dir, agent := makeAgent(t, conf)
defer os.RemoveAll(dir)
// Try to start the server with the same path anyways.
if servers, err := NewHTTPServers(agent, conf, agent.logOutput); err == nil {
for _, server := range servers {
server.Shutdown()
}
t.Fatalf("expected socket binding error")
}
}
func TestSetIndex(t *testing.T) { func TestSetIndex(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
setIndex(resp, 1000) setIndex(resp, 1000)

View File

@ -81,24 +81,19 @@ func (c *RPCClient) send(header *requestHeader, obj interface{}) error {
// NewRPCClient is used to create a new RPC client given the address. // NewRPCClient is used to create a new RPC client given the address.
// This will properly dial, handshake, and start listening // This will properly dial, handshake, and start listening
func NewRPCClient(addr string) (*RPCClient, error) { func NewRPCClient(addr string) (*RPCClient, error) {
sanedAddr := os.Getenv("CONSUL_RPC_ADDR") var conn net.Conn
if len(sanedAddr) == 0 { var err error
sanedAddr = addr
}
mode := "tcp" if envAddr := os.Getenv("CONSUL_RPC_ADDR"); envAddr != "" {
addr = envAddr
if strings.HasPrefix(sanedAddr, "unix://") {
sanedAddr = strings.TrimPrefix(sanedAddr, "unix://")
}
if strings.HasPrefix(sanedAddr, "/") {
mode = "unix"
} }
// Try to dial to agent // Try to dial to agent
conn, err := net.Dial(mode, sanedAddr) mode := "tcp"
if err != nil { if strings.HasPrefix(addr, "/") {
mode = "unix"
}
if conn, err = net.Dial(mode, addr); err != nil {
return nil, err return nil, err
} }

View File

@ -9,7 +9,7 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
"os/user" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
@ -69,6 +69,38 @@ func testRPCClientWithConfig(t *testing.T, cb func(c *Config)) *rpcParts {
} }
} }
func TestRPCClient_UnixSocket(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
tempDir, err := ioutil.TempDir("", "consul")
if err != nil {
t.Fatalf("err: %s", err)
}
defer os.RemoveAll(tempDir)
socket := filepath.Join(tempDir, "test.sock")
p1 := testRPCClientWithConfig(t, func(c *Config) {
c.Addresses.RPC = "unix://" + socket
})
defer p1.Close()
// Ensure the socket was created
if _, err := os.Stat(socket); err != nil {
t.Fatalf("err: %s", err)
}
// Ensure we can talk with the socket
mem, err := p1.client.LANMembers()
if err != nil {
t.Fatalf("err: %s", err)
}
if len(mem) != 1 {
t.Fatalf("bad: %#v", mem)
}
}
func TestRPCClientForceLeave(t *testing.T) { func TestRPCClientForceLeave(t *testing.T) {
p1 := testRPCClient(t) p1 := testRPCClient(t)
p2 := testRPCClient(t) p2 := testRPCClient(t)
@ -216,41 +248,6 @@ func TestRPCClientStats(t *testing.T) {
} }
} }
func TestRPCClientStatsUnix(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
tempdir, err := ioutil.TempDir("", "consul-test-")
if err != nil {
t.Fatal("Could not create a working directory: ", err)
}
user, err := user.Current()
if err != nil {
t.Fatal("Could not get current user: ", err)
}
cb := func(c *Config) {
c.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640"
}
p1 := testRPCClientWithConfig(t, cb)
stats, err := p1.client.Stats()
if err != nil {
t.Fatalf("err: %s", err)
}
if _, ok := stats["agent"]; !ok {
t.Fatalf("bad: %#v", stats)
}
if _, ok := stats["consul"]; !ok {
t.Fatalf("bad: %#v", stats)
}
}
func TestRPCClientLeave(t *testing.T) { func TestRPCClientLeave(t *testing.T) {
p1 := testRPCClient(t) p1 := testRPCClient(t)
defer p1.Close() defer p1.Close()

View File

@ -43,12 +43,10 @@ func HTTPClient(addr string) (*consulapi.Client, error) {
// HTTPClientDC returns a new Consul HTTP client with the given address and datacenter // HTTPClientDC returns a new Consul HTTP client with the given address and datacenter
func HTTPClientDC(addr, dc string) (*consulapi.Client, error) { func HTTPClientDC(addr, dc string) (*consulapi.Client, error) {
conf := consulapi.DefaultConfig() conf := consulapi.DefaultConfig()
switch { if envAddr := os.Getenv("CONSUL_HTTP_ADDR"); envAddr != "" {
case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0: addr = envAddr
conf.Address = os.Getenv("CONSUL_HTTP_ADDR")
default:
conf.Address = addr
} }
conf.Address = addr
conf.Datacenter = dc conf.Datacenter = dc
return consulapi.NewClient(conf) return consulapi.NewClient(conf)
} }

View File

@ -239,20 +239,23 @@ definitions support being updated during a reload.
However, because the caches are not actively invalidated, ACL policy may be stale However, because the caches are not actively invalidated, ACL policy may be stale
up to the TTL value. up to the TTL value.
* `addresses` - This is a nested object that allows setting bind addresses. For `rpc` * `addresses` - This is a nested object that allows setting bind addresses.
and `http`, a Unix socket can be specified in the following form: <br><br>
unix://[/path/to/socket];[username|uid];[gid];[mode]. The socket will be created Both `rpc` and `http` support binding to Unix domain sockets. A socket can be
in the specified location with the given username or uid, gid, and mode. The specified in the form `unix:///path/to/socket`. A new domain socket will be
user Consul is running as must have appropriate permissions to change the socket created at the given path. If the specified file path already exists, Consul
ownership to the given uid or gid. When running Consul agent commands against will refuse to start and return an error. For information on how to secure
Unix socket interfaces, use the `-rpc-addr` or `-http-addr` arguments to specify socket file permissions, refer to the manual page for your operating system.
the path to the socket, e.g. "unix://path/to/socket". You can also place the desired <br><br>
values in `CONSUL_RPC_ADDR` and `CONSUL_HTTP_ADDR` environment variables. For TCP When running Consul agent commands against Unix socket interfaces, use the
addresses, these should be in the form ip:port. `-rpc-addr` or `-http-addr` arguments to specify the path to the socket. You
The following keys are valid: can also place the desired values in `CONSUL_RPC_ADDR` and `CONSUL_HTTP_ADDR`
* `dns` - The DNS server. Defaults to `client_addr` environment variables. For TCP addresses, these should be in the form ip:port.
* `http` - The HTTP API. Defaults to `client_addr` <br><br>
* `rpc` - The RPC endpoint. Defaults to `client_addr` The following keys are valid:
* `dns` - The DNS server. Defaults to `client_addr`
* `http` - The HTTP API. Defaults to `client_addr`
* `rpc` - The RPC endpoint. Defaults to `client_addr`
* `advertise_addr` - Equivalent to the `-advertise` command-line flag. * `advertise_addr` - Equivalent to the `-advertise` command-line flag.