diff --git a/net.go b/net.go index e577d0a..8250a0b 100644 --- a/net.go +++ b/net.go @@ -4,6 +4,7 @@ import ( "fmt" "net" + utp "github.com/h2so5/utp" ma "github.com/jbenet/go-multiaddr" ) @@ -68,7 +69,7 @@ func (c *maConn) RemoteMultiaddr() ma.Multiaddr { // and RemoteAddr options are Multiaddrs, instead of net.Addrs. type Dialer struct { - // Dialer is just an embed net.Dialer, with all its options. + // Dialer is just an embedded net.Dialer, with all its options. net.Dialer // LocalAddr is the local address to use when dialing an @@ -103,9 +104,24 @@ func (d *Dialer) Dial(remote ma.Multiaddr) (Conn, error) { } // ok, Dial! - nconn, err := d.Dialer.Dial(rnet, rnaddr) - if err != nil { - return nil, err + var nconn net.Conn + switch rnet { + case "tcp": + nconn, err = d.Dialer.Dial(rnet, rnaddr) + if err != nil { + return nil, err + } + case "utp": + // construct utp dialer, with options on our net.Dialer + utpd := utp.Dialer{ + Timeout: d.Dialer.Timeout, + LocalAddr: d.Dialer.LocalAddr, + } + + nconn, err = utpd.Dial(rnet, rnaddr) + if err != nil { + return nil, err + } } // get local address (pre-specified or assigned within net.Conn) @@ -206,9 +222,18 @@ func Listen(laddr ma.Multiaddr) (Listener, error) { return nil, err } - nl, err := net.Listen(lnet, lnaddr) - if err != nil { - return nil, err + var nl net.Listener + switch lnet { + case "utp": + nl, err = utp.Listen(lnet, lnaddr) + if err != nil { + return nil, err + } + case "tcp": + nl, err = net.Listen(lnet, lnaddr) + if err != nil { + return nil, err + } } return &maListener{ diff --git a/net_test.go b/net_test.go index 0d8af9e..18c8853 100644 --- a/net_test.go +++ b/net_test.go @@ -12,7 +12,7 @@ import ( func newMultiaddr(t *testing.T, m string) ma.Multiaddr { maddr, err := ma.NewMultiaddr(m) if err != nil { - t.Fatalf("failed to construct multiaddr: %s", m) + t.Fatal("failed to construct multiaddr:", m, err) } return maddr } @@ -199,6 +199,67 @@ func TestListenAndDial(t *testing.T) { wg.Wait() } +func TestListenAndDialUTP(t *testing.T) { + + maddr := newMultiaddr(t, "/ip4/127.0.0.1/udp/4323/utp") + listener, err := Listen(maddr) + if err != nil { + t.Fatal("failed to listen") + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + + cB, err := listener.Accept() + if err != nil { + t.Fatal("failed to accept") + } + + if !cB.LocalMultiaddr().Equal(maddr) { + t.Fatal("local multiaddr not equal:", maddr, cB.LocalMultiaddr()) + } + + // echo out + buf := make([]byte, 1024) + for { + _, err := cB.Read(buf) + if err != nil { + break + } + cB.Write(buf) + } + + wg.Done() + }() + + cA, err := Dial(newMultiaddr(t, "/ip4/127.0.0.1/udp/4323/utp")) + if err != nil { + t.Fatal("failed to dial", err) + } + + buf := make([]byte, 1024) + if _, err := cA.Write([]byte("beep boop")); err != nil { + t.Fatal("failed to write:", err) + } + + if _, err := cA.Read(buf); err != nil { + t.Fatal("failed to read:", buf, err) + } + + if !bytes.Equal(buf[:9], []byte("beep boop")) { + t.Fatal("failed to echo:", buf) + } + + maddr2 := cA.RemoteMultiaddr() + if !maddr2.Equal(maddr) { + t.Fatal("remote multiaddr not equal:", maddr, maddr2) + } + + cA.Close() + wg.Wait() +} + func TestIPLoopback(t *testing.T) { if IP4Loopback.String() != "/ip4/127.0.0.1" { t.Error("IP4Loopback incorrect:", IP4Loopback)