Merge branch 'udp-tracker-no-dial' into te

This commit is contained in:
Matt Joiner 2021-11-29 11:19:54 +11:00
commit 5dba8f96e4
4 changed files with 83 additions and 39 deletions

View File

@ -4,13 +4,17 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"net"
"sync" "sync"
"time" "time"
"github.com/anacrolix/dht/v2/krpc"
) )
// Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
// connection specifics.
type Client struct { type Client struct {
mu sync.Mutex mu sync.Mutex
connId ConnectionId connId ConnectionId
@ -20,11 +24,16 @@ type Client struct {
} }
func (cl *Client) Announce( func (cl *Client) Announce(
ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options, ctx context.Context, req AnnounceRequest, opts Options,
// Decides whether the response body is IPv6 or IPv4, see BEP 15.
ipv6 func(net.Addr) bool,
) ( ) (
respHdr AnnounceResponseHeader, err error, respHdr AnnounceResponseHeader,
// A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
peers AnnounceResponsePeers,
err error,
) { ) {
respBody, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...)) respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
if err != nil { if err != nil {
return return
} }
@ -34,6 +43,11 @@ func (cl *Client) Announce(
err = fmt.Errorf("reading response header: %w", err) err = fmt.Errorf("reading response header: %w", err)
return return
} }
if ipv6(addr) {
peers = &krpc.CompactIPv6NodeAddrs{}
} else {
peers = &krpc.CompactIPv4NodeAddrs{}
}
err = peers.UnmarshalBinary(r.Bytes()) err = peers.UnmarshalBinary(r.Bytes())
if err != nil { if err != nil {
err = fmt.Errorf("reading response peers: %w", err) err = fmt.Errorf("reading response peers: %w", err)
@ -41,13 +55,13 @@ func (cl *Client) Announce(
return return
} }
// There's no way to pass options in a scrape, since we don't when the request body ends.
func (cl *Client) Scrape( func (cl *Client) Scrape(
ctx context.Context, ihs []InfoHash, ctx context.Context, ihs []InfoHash,
) ( ) (
out ScrapeResponse, err error, out ScrapeResponse, err error,
) { ) {
// There's no way to pass options in a scrape, since we don't when the request body ends. respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
respBody, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
if err != nil { if err != nil {
return return
} }
@ -75,7 +89,7 @@ func (cl *Client) connect(ctx context.Context) (err error) {
if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute { if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
return nil return nil
} }
respBody, err := cl.request(ctx, ActionConnect, nil) respBody, _, err := cl.request(ctx, ActionConnect, nil)
if err != nil { if err != nil {
return err return err
} }
@ -132,7 +146,7 @@ func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte,
} }
} }
func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) { func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
respChan := make(chan DispatchedResponse, 1) respChan := make(chan DispatchedResponse, 1)
t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) { t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
respChan <- dr respChan <- dr
@ -148,8 +162,9 @@ func (cl *Client) request(ctx context.Context, action Action, body []byte) (resp
case dr := <-respChan: case dr := <-respChan:
if dr.Header.Action == action { if dr.Header.Action == action {
respBody = dr.Body respBody = dr.Body
addr = dr.Addr
} else if dr.Header.Action == ActionError { } else if dr.Header.Action == ActionError {
err = errors.New(string(dr.Body)) err = fmt.Errorf("error response: %s", dr.Body)
} else { } else {
err = fmt.Errorf("unexpected response action %v", dr.Header.Action) err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
} }

View File

@ -2,44 +2,52 @@ package udp
import ( import (
"context" "context"
"log"
"net" "net"
"github.com/anacrolix/dht/v2/krpc"
"github.com/anacrolix/missinggo/v2" "github.com/anacrolix/missinggo/v2"
) )
type NewConnClientOpts struct { type NewConnClientOpts struct {
// The network to operate to use, such as "udp4", "udp", "udp6".
Network string Network string
Host string // Tracker address
Ipv6 *bool Host string
// If non-nil, forces either IPv4 or IPv6 in the UDP tracker wire protocol.
Ipv6 *bool
} }
// Manages a Client with a specific connection.
type ConnClient struct { type ConnClient struct {
Client Client Client Client
conn net.Conn conn net.PacketConn
d Dispatcher d Dispatcher
readErr error readErr error
ipv6 bool closed bool
newOpts NewConnClientOpts
} }
func (cc *ConnClient) reader() { func (cc *ConnClient) reader() {
b := make([]byte, 0x800) b := make([]byte, 0x800)
for { for {
n, err := cc.conn.Read(b) n, addr, err := cc.conn.ReadFrom(b)
if err != nil { if err != nil {
// TODO: Do bad things to the dispatcher, and incoming calls to the client if we have a // TODO: Do bad things to the dispatcher, and incoming calls to the client if we have a
// read error. // read error.
cc.readErr = err cc.readErr = err
if !cc.closed {
panic(err)
}
break break
} }
_ = cc.d.Dispatch(b[:n]) err = cc.d.Dispatch(b[:n], addr)
// if err != nil { if err != nil {
// log.Printf("dispatching packet received on %v (%q): %v", cc.conn, string(b[:n]), err) log.Printf("dispatching packet received on %v: %v", cc.conn.LocalAddr(), err)
// } }
} }
} }
func ipv6(opt *bool, network string, conn net.Conn) bool { func ipv6(opt *bool, network string, remoteAddr net.Addr) bool {
if opt != nil { if opt != nil {
return *opt return *opt
} }
@ -49,21 +57,40 @@ func ipv6(opt *bool, network string, conn net.Conn) bool {
case "udp6": case "udp6":
return true return true
} }
rip := missinggo.AddrIP(conn.RemoteAddr()) rip := missinggo.AddrIP(remoteAddr)
return rip.To16() != nil && rip.To4() == nil return rip.To16() != nil && rip.To4() == nil
} }
// Allows a UDP Client to write packets to an endpoint without knowing about the network specifics.
type clientWriter struct {
pc net.PacketConn
network string
address string
}
func (me clientWriter) Write(p []byte) (n int, err error) {
addr, err := net.ResolveUDPAddr(me.network, me.address)
if err != nil {
return
}
return me.pc.WriteTo(p, addr)
}
func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) { func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
conn, err := net.Dial(opts.Network, opts.Host) conn, err := net.ListenPacket(opts.Network, ":0")
if err != nil { if err != nil {
return return
} }
cc = &ConnClient{ cc = &ConnClient{
Client: Client{ Client: Client{
Writer: conn, Writer: clientWriter{
pc: conn,
network: opts.Network,
address: opts.Host,
},
}, },
conn: conn, conn: conn,
ipv6: ipv6(opts.Ipv6, opts.Network, conn), newOpts: opts,
} }
cc.Client.Dispatcher = &cc.d cc.Client.Dispatcher = &cc.d
go cc.reader() go cc.reader()
@ -71,6 +98,7 @@ func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
} }
func (c *ConnClient) Close() error { func (c *ConnClient) Close() error {
c.closed = true
return c.conn.Close() return c.conn.Close()
} }
@ -79,13 +107,7 @@ func (c *ConnClient) Announce(
) ( ) (
h AnnounceResponseHeader, nas AnnounceResponsePeers, err error, h AnnounceResponseHeader, nas AnnounceResponsePeers, err error,
) { ) {
nas = func() AnnounceResponsePeers { return c.Client.Announce(ctx, req, opts, func(addr net.Addr) bool {
if c.ipv6 { return ipv6(c.newOpts.Ipv6, c.newOpts.Network, addr)
return &krpc.CompactIPv6NodeAddrs{} })
} else {
return &krpc.CompactIPv4NodeAddrs{}
}
}()
h, err = c.Client.Announce(ctx, req, nas, opts)
return
} }

View File

@ -3,16 +3,18 @@ package udp
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net"
"sync" "sync"
) )
// Maintains a mapping of transaction IDs to handlers.
type Dispatcher struct { type Dispatcher struct {
mu sync.RWMutex mu sync.RWMutex
transactions map[TransactionId]Transaction transactions map[TransactionId]Transaction
} }
// The caller owns b. // The caller owns b.
func (me *Dispatcher) Dispatch(b []byte) error { func (me *Dispatcher) Dispatch(b []byte, addr net.Addr) error {
buf := bytes.NewBuffer(b) buf := bytes.NewBuffer(b)
var rh ResponseHeader var rh ResponseHeader
err := Read(buf, &rh) err := Read(buf, &rh)
@ -25,6 +27,7 @@ func (me *Dispatcher) Dispatch(b []byte) error {
t.h(DispatchedResponse{ t.h(DispatchedResponse{
Header: rh, Header: rh,
Body: append([]byte(nil), buf.Bytes()...), Body: append([]byte(nil), buf.Bytes()...),
Addr: addr,
}) })
return nil return nil
} else { } else {
@ -61,5 +64,8 @@ func (me *Dispatcher) NewTransaction(h TransactionResponseHandler) Transaction {
type DispatchedResponse struct { type DispatchedResponse struct {
Header ResponseHeader Header ResponseHeader
Body []byte // Response payload, after the header.
Body []byte
// Response source address
Addr net.Addr
} }

View File

@ -40,7 +40,7 @@ func TestAnnounceLocalhost(t *testing.T) {
}, },
} }
var err error var err error
srv.pc, err = net.ListenPacket("udp", ":0") srv.pc, err = net.ListenPacket("udp", "localhost:0")
require.NoError(t, err) require.NoError(t, err)
defer srv.pc.Close() defer srv.pc.Close()
go func() { go func() {
@ -92,7 +92,7 @@ func TestUDPTracker(t *testing.T) {
t.Skip(err) t.Skip(err)
} }
require.NoError(t, err) require.NoError(t, err)
t.Log(ar) t.Logf("%+v", ar)
} }
func TestAnnounceRandomInfoHashThirdParty(t *testing.T) { func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
@ -143,7 +143,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
// Check that URLPath option is done correctly. // Check that URLPath option is done correctly.
func TestURLPathOption(t *testing.T) { func TestURLPathOption(t *testing.T) {
conn, err := net.ListenUDP("udp", nil) conn, err := net.ListenPacket("udp", "localhost:0")
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -161,6 +161,7 @@ func TestURLPathOption(t *testing.T) {
announceErr <- err announceErr <- err
}() }()
var b [512]byte var b [512]byte
// conn.SetReadDeadline(time.Now().Add(time.Second))
_, addr, _ := conn.ReadFrom(b[:]) _, addr, _ := conn.ReadFrom(b[:])
r := bytes.NewReader(b[:]) r := bytes.NewReader(b[:])
var h udp.RequestHeader var h udp.RequestHeader