mirror of
https://github.com/logos-messaging/go-multiaddr.git
synced 2026-01-05 22:43:10 +00:00
Merge pull request #50 from multiformats/feat/unix-sockets
Add support for unix sockets
This commit is contained in:
commit
1879060a4f
21
convert.go
21
convert.go
@ -3,6 +3,7 @@ package manet
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
ma "github.com/multiformats/go-multiaddr"
|
ma "github.com/multiformats/go-multiaddr"
|
||||||
madns "github.com/multiformats/go-multiaddr-dns"
|
madns "github.com/multiformats/go-multiaddr-dns"
|
||||||
@ -61,6 +62,8 @@ func parseBasicNetMaddr(maddr ma.Multiaddr) (net.Addr, error) {
|
|||||||
return net.ResolveUDPAddr(network, host)
|
return net.ResolveUDPAddr(network, host)
|
||||||
case "ip", "ip4", "ip6":
|
case "ip", "ip4", "ip6":
|
||||||
return net.ResolveIPAddr(network, host)
|
return net.ResolveIPAddr(network, host)
|
||||||
|
case "unix":
|
||||||
|
return net.ResolveUnixAddr(network, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("network not supported: %s", network)
|
return nil, fmt.Errorf("network not supported: %s", network)
|
||||||
@ -96,7 +99,8 @@ func FromIP(ip net.IP) (ma.Multiaddr, error) {
|
|||||||
|
|
||||||
// DialArgs is a convenience function that returns network and address as
|
// DialArgs is a convenience function that returns network and address as
|
||||||
// expected by net.Dial. See https://godoc.org/net#Dial for an overview of
|
// expected by net.Dial. See https://godoc.org/net#Dial for an overview of
|
||||||
// possible return values (we do not support the unix* ones yet).
|
// possible return values (we do not support the unixpacket ones yet). Unix
|
||||||
|
// addresses do not, at present, compose.
|
||||||
func DialArgs(m ma.Multiaddr) (string, string, error) {
|
func DialArgs(m ma.Multiaddr) (string, string, error) {
|
||||||
var (
|
var (
|
||||||
zone, network, ip, port string
|
zone, network, ip, port string
|
||||||
@ -137,6 +141,10 @@ func DialArgs(m ma.Multiaddr) (string, string, error) {
|
|||||||
hostname = true
|
hostname = true
|
||||||
ip = c.Value()
|
ip = c.Value()
|
||||||
return true
|
return true
|
||||||
|
case ma.P_UNIX:
|
||||||
|
network = "unix"
|
||||||
|
ip = c.Value()
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
case "ip4":
|
case "ip4":
|
||||||
switch c.Protocol().Code {
|
switch c.Protocol().Code {
|
||||||
@ -184,6 +192,8 @@ func DialArgs(m ma.Multiaddr) (string, string, error) {
|
|||||||
return network, ip + ":" + port, nil
|
return network, ip + ":" + port, nil
|
||||||
}
|
}
|
||||||
return network, "[" + ip + "]" + ":" + port, nil
|
return network, "[" + ip + "]" + ":" + port, nil
|
||||||
|
case "unix":
|
||||||
|
return network, ip, nil
|
||||||
default:
|
default:
|
||||||
return "", "", fmt.Errorf("%s is not a 'thin waist' address", m)
|
return "", "", fmt.Errorf("%s is not a 'thin waist' address", m)
|
||||||
}
|
}
|
||||||
@ -248,3 +258,12 @@ func parseIPPlusNetAddr(a net.Addr) (ma.Multiaddr, error) {
|
|||||||
}
|
}
|
||||||
return FromIP(ac.IP)
|
return FromIP(ac.IP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseUnixNetAddr(a net.Addr) (ma.Multiaddr, error) {
|
||||||
|
ac, ok := a.(*net.UnixAddr)
|
||||||
|
if !ok {
|
||||||
|
return nil, errIncorrectNetAddr
|
||||||
|
}
|
||||||
|
cleaned := filepath.Clean(ac.Name)
|
||||||
|
return ma.NewComponent("unix", cleaned)
|
||||||
|
}
|
||||||
|
|||||||
17
net.go
17
net.go
@ -167,7 +167,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
|
|||||||
// ok, Dial!
|
// ok, Dial!
|
||||||
var nconn net.Conn
|
var nconn net.Conn
|
||||||
switch rnet {
|
switch rnet {
|
||||||
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
|
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "unix":
|
||||||
nconn, err = d.Dialer.DialContext(ctx, rnet, rnaddr)
|
nconn, err = d.Dialer.DialContext(ctx, rnet, rnaddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -178,7 +178,9 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
|
|||||||
|
|
||||||
// get local address (pre-specified or assigned within net.Conn)
|
// get local address (pre-specified or assigned within net.Conn)
|
||||||
local := d.LocalAddr
|
local := d.LocalAddr
|
||||||
if local == nil {
|
// This block helps us avoid parsing addresses in transports (such as unix
|
||||||
|
// sockets) that don't have local addresses when dialing out.
|
||||||
|
if local == nil && nconn.LocalAddr().String() != "" {
|
||||||
local, err = FromNetAddr(nconn.LocalAddr())
|
local, err = FromNetAddr(nconn.LocalAddr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -243,9 +245,14 @@ func (l *maListener) Accept() (Conn, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
raddr, err := FromNetAddr(nconn.RemoteAddr())
|
var raddr ma.Multiaddr
|
||||||
if err != nil {
|
// This block protects us in transports (i.e. unix sockets) that don't have
|
||||||
return nil, fmt.Errorf("failed to convert connn.RemoteAddr: %s", err)
|
// remote addresses for inbound connections.
|
||||||
|
if nconn.RemoteAddr().String() != "" {
|
||||||
|
raddr, err = FromNetAddr(nconn.RemoteAddr())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to convert conn.RemoteAddr: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return wrap(nconn, l.laddr, raddr), nil
|
return wrap(nconn, l.laddr, raddr), nil
|
||||||
|
|||||||
60
net_test.go
60
net_test.go
@ -3,9 +3,13 @@ package manet
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
ma "github.com/multiformats/go-multiaddr"
|
ma "github.com/multiformats/go-multiaddr"
|
||||||
)
|
)
|
||||||
@ -75,6 +79,62 @@ func TestDial(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnixSockets(t *testing.T) {
|
||||||
|
dir, err := ioutil.TempDir(os.TempDir(), "manettest")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
path := filepath.Join(dir, "listen.sock")
|
||||||
|
maddr := newMultiaddr(t, "/unix/"+path)
|
||||||
|
|
||||||
|
listener, err := Listen(maddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte("hello")
|
||||||
|
|
||||||
|
// listen
|
||||||
|
done := make(chan struct{}, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if n != len(payload) {
|
||||||
|
t.Fatal("failed to read appropriate number of bytes")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(buf[0:n], payload) {
|
||||||
|
t.Fatal("payload did not match")
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// dial
|
||||||
|
conn, err := Dial(maddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
n, err := conn.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if n != len(payload) {
|
||||||
|
t.Fatal("failed to write appropriate number of bytes")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for read")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestListen(t *testing.T) {
|
func TestListen(t *testing.T) {
|
||||||
|
|
||||||
maddr := newMultiaddr(t, "/ip4/127.0.0.1/tcp/4322")
|
maddr := newMultiaddr(t, "/ip4/127.0.0.1/tcp/4322")
|
||||||
|
|||||||
@ -21,8 +21,9 @@ func init() {
|
|||||||
defaultCodecs.RegisterFromNetAddr(parseUDPNetAddr, "udp", "udp4", "udp6")
|
defaultCodecs.RegisterFromNetAddr(parseUDPNetAddr, "udp", "udp4", "udp6")
|
||||||
defaultCodecs.RegisterFromNetAddr(parseIPNetAddr, "ip", "ip4", "ip6")
|
defaultCodecs.RegisterFromNetAddr(parseIPNetAddr, "ip", "ip4", "ip6")
|
||||||
defaultCodecs.RegisterFromNetAddr(parseIPPlusNetAddr, "ip+net")
|
defaultCodecs.RegisterFromNetAddr(parseIPPlusNetAddr, "ip+net")
|
||||||
|
defaultCodecs.RegisterFromNetAddr(parseUnixNetAddr, "unix")
|
||||||
|
|
||||||
defaultCodecs.RegisterToNetAddr(parseBasicNetMaddr, "tcp", "udp", "ip6", "ip4")
|
defaultCodecs.RegisterToNetAddr(parseBasicNetMaddr, "tcp", "udp", "ip6", "ip4", "unix")
|
||||||
}
|
}
|
||||||
|
|
||||||
// CodecMap holds a map of NetCodecs indexed by their Protocol ID
|
// CodecMap holds a map of NetCodecs indexed by their Protocol ID
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user