Merge pull request #50 from multiformats/feat/unix-sockets

Add support for unix sockets
This commit is contained in:
bigs 2018-12-04 10:31:14 -05:00 committed by GitHub
commit 1879060a4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 94 additions and 7 deletions

View File

@ -3,6 +3,7 @@ package manet
import ( import (
"fmt" "fmt"
"net" "net"
"path/filepath"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns" madns "github.com/multiformats/go-multiaddr-dns"
@ -61,6 +62,8 @@ func parseBasicNetMaddr(maddr ma.Multiaddr) (net.Addr, error) {
return net.ResolveUDPAddr(network, host) return net.ResolveUDPAddr(network, host)
case "ip", "ip4", "ip6": case "ip", "ip4", "ip6":
return net.ResolveIPAddr(network, host) return net.ResolveIPAddr(network, host)
case "unix":
return net.ResolveUnixAddr(network, host)
} }
return nil, fmt.Errorf("network not supported: %s", network) return nil, fmt.Errorf("network not supported: %s", network)
@ -96,7 +99,8 @@ func FromIP(ip net.IP) (ma.Multiaddr, error) {
// DialArgs is a convenience function that returns network and address as // DialArgs is a convenience function that returns network and address as
// expected by net.Dial. See https://godoc.org/net#Dial for an overview of // expected by net.Dial. See https://godoc.org/net#Dial for an overview of
// possible return values (we do not support the unix* ones yet). // possible return values (we do not support the unixpacket ones yet). Unix
// addresses do not, at present, compose.
func DialArgs(m ma.Multiaddr) (string, string, error) { func DialArgs(m ma.Multiaddr) (string, string, error) {
var ( var (
zone, network, ip, port string zone, network, ip, port string
@ -137,6 +141,10 @@ func DialArgs(m ma.Multiaddr) (string, string, error) {
hostname = true hostname = true
ip = c.Value() ip = c.Value()
return true return true
case ma.P_UNIX:
network = "unix"
ip = c.Value()
return false
} }
case "ip4": case "ip4":
switch c.Protocol().Code { switch c.Protocol().Code {
@ -184,6 +192,8 @@ func DialArgs(m ma.Multiaddr) (string, string, error) {
return network, ip + ":" + port, nil return network, ip + ":" + port, nil
} }
return network, "[" + ip + "]" + ":" + port, nil return network, "[" + ip + "]" + ":" + port, nil
case "unix":
return network, ip, nil
default: default:
return "", "", fmt.Errorf("%s is not a 'thin waist' address", m) return "", "", fmt.Errorf("%s is not a 'thin waist' address", m)
} }
@ -248,3 +258,12 @@ func parseIPPlusNetAddr(a net.Addr) (ma.Multiaddr, error) {
} }
return FromIP(ac.IP) return FromIP(ac.IP)
} }
func parseUnixNetAddr(a net.Addr) (ma.Multiaddr, error) {
ac, ok := a.(*net.UnixAddr)
if !ok {
return nil, errIncorrectNetAddr
}
cleaned := filepath.Clean(ac.Name)
return ma.NewComponent("unix", cleaned)
}

17
net.go
View File

@ -167,7 +167,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
// ok, Dial! // ok, Dial!
var nconn net.Conn var nconn net.Conn
switch rnet { switch rnet {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "unix":
nconn, err = d.Dialer.DialContext(ctx, rnet, rnaddr) nconn, err = d.Dialer.DialContext(ctx, rnet, rnaddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -178,7 +178,9 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
// get local address (pre-specified or assigned within net.Conn) // get local address (pre-specified or assigned within net.Conn)
local := d.LocalAddr local := d.LocalAddr
if local == nil { // This block helps us avoid parsing addresses in transports (such as unix
// sockets) that don't have local addresses when dialing out.
if local == nil && nconn.LocalAddr().String() != "" {
local, err = FromNetAddr(nconn.LocalAddr()) local, err = FromNetAddr(nconn.LocalAddr())
if err != nil { if err != nil {
return nil, err return nil, err
@ -243,9 +245,14 @@ func (l *maListener) Accept() (Conn, error) {
return nil, err return nil, err
} }
raddr, err := FromNetAddr(nconn.RemoteAddr()) var raddr ma.Multiaddr
if err != nil { // This block protects us in transports (i.e. unix sockets) that don't have
return nil, fmt.Errorf("failed to convert connn.RemoteAddr: %s", err) // remote addresses for inbound connections.
if nconn.RemoteAddr().String() != "" {
raddr, err = FromNetAddr(nconn.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("failed to convert conn.RemoteAddr: %s", err)
}
} }
return wrap(nconn, l.laddr, raddr), nil return wrap(nconn, l.laddr, raddr), nil

View File

@ -3,9 +3,13 @@ package manet
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os"
"path/filepath"
"sync" "sync"
"testing" "testing"
"time"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
) )
@ -75,6 +79,62 @@ func TestDial(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestUnixSockets(t *testing.T) {
dir, err := ioutil.TempDir(os.TempDir(), "manettest")
if err != nil {
t.Fatal(err)
}
path := filepath.Join(dir, "listen.sock")
maddr := newMultiaddr(t, "/unix/"+path)
listener, err := Listen(maddr)
if err != nil {
t.Fatal(err)
}
payload := []byte("hello")
// listen
done := make(chan struct{}, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != len(payload) {
t.Fatal("failed to read appropriate number of bytes")
}
if !bytes.Equal(buf[0:n], payload) {
t.Fatal("payload did not match")
}
done <- struct{}{}
}()
// dial
conn, err := Dial(maddr)
if err != nil {
t.Fatal(err)
}
n, err := conn.Write(payload)
if err != nil {
t.Fatal(err)
}
if n != len(payload) {
t.Fatal("failed to write appropriate number of bytes")
}
select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for read")
}
}
func TestListen(t *testing.T) { func TestListen(t *testing.T) {
maddr := newMultiaddr(t, "/ip4/127.0.0.1/tcp/4322") maddr := newMultiaddr(t, "/ip4/127.0.0.1/tcp/4322")

View File

@ -21,8 +21,9 @@ func init() {
defaultCodecs.RegisterFromNetAddr(parseUDPNetAddr, "udp", "udp4", "udp6") defaultCodecs.RegisterFromNetAddr(parseUDPNetAddr, "udp", "udp4", "udp6")
defaultCodecs.RegisterFromNetAddr(parseIPNetAddr, "ip", "ip4", "ip6") defaultCodecs.RegisterFromNetAddr(parseIPNetAddr, "ip", "ip4", "ip6")
defaultCodecs.RegisterFromNetAddr(parseIPPlusNetAddr, "ip+net") defaultCodecs.RegisterFromNetAddr(parseIPPlusNetAddr, "ip+net")
defaultCodecs.RegisterFromNetAddr(parseUnixNetAddr, "unix")
defaultCodecs.RegisterToNetAddr(parseBasicNetMaddr, "tcp", "udp", "ip6", "ip4") defaultCodecs.RegisterToNetAddr(parseBasicNetMaddr, "tcp", "udp", "ip6", "ip4", "unix")
} }
// CodecMap holds a map of NetCodecs indexed by their Protocol ID // CodecMap holds a map of NetCodecs indexed by their Protocol ID