Merge branch 'udp-tracker-no-dial' into te

This commit is contained in:
Matt Joiner 2021-11-29 11:19:54 +11:00
commit 5dba8f96e4
4 changed files with 83 additions and 39 deletions

View File

@ -4,13 +4,17 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/anacrolix/dht/v2/krpc"
)
// Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
// connection specifics.
type Client struct {
mu sync.Mutex
connId ConnectionId
@ -20,11 +24,16 @@ type Client struct {
}
func (cl *Client) Announce(
ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options,
ctx context.Context, req AnnounceRequest, opts Options,
// Decides whether the response body is IPv6 or IPv4, see BEP 15.
ipv6 func(net.Addr) bool,
) (
respHdr AnnounceResponseHeader, err error,
respHdr AnnounceResponseHeader,
// A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
peers AnnounceResponsePeers,
err error,
) {
respBody, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
if err != nil {
return
}
@ -34,6 +43,11 @@ func (cl *Client) Announce(
err = fmt.Errorf("reading response header: %w", err)
return
}
if ipv6(addr) {
peers = &krpc.CompactIPv6NodeAddrs{}
} else {
peers = &krpc.CompactIPv4NodeAddrs{}
}
err = peers.UnmarshalBinary(r.Bytes())
if err != nil {
err = fmt.Errorf("reading response peers: %w", err)
@ -41,13 +55,13 @@ func (cl *Client) Announce(
return
}
// There's no way to pass options in a scrape, since we don't when the request body ends.
func (cl *Client) Scrape(
ctx context.Context, ihs []InfoHash,
) (
out ScrapeResponse, err error,
) {
// There's no way to pass options in a scrape, since we don't when the request body ends.
respBody, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
if err != nil {
return
}
@ -75,7 +89,7 @@ func (cl *Client) connect(ctx context.Context) (err error) {
if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
return nil
}
respBody, err := cl.request(ctx, ActionConnect, nil)
respBody, _, err := cl.request(ctx, ActionConnect, nil)
if err != nil {
return err
}
@ -132,7 +146,7 @@ func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte,
}
}
func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) {
func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
respChan := make(chan DispatchedResponse, 1)
t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
respChan <- dr
@ -148,8 +162,9 @@ func (cl *Client) request(ctx context.Context, action Action, body []byte) (resp
case dr := <-respChan:
if dr.Header.Action == action {
respBody = dr.Body
addr = dr.Addr
} else if dr.Header.Action == ActionError {
err = errors.New(string(dr.Body))
err = fmt.Errorf("error response: %s", dr.Body)
} else {
err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
}

View File

@ -2,44 +2,52 @@ package udp
import (
"context"
"log"
"net"
"github.com/anacrolix/dht/v2/krpc"
"github.com/anacrolix/missinggo/v2"
)
type NewConnClientOpts struct {
// The network to operate to use, such as "udp4", "udp", "udp6".
Network string
Host string
Ipv6 *bool
// Tracker address
Host string
// If non-nil, forces either IPv4 or IPv6 in the UDP tracker wire protocol.
Ipv6 *bool
}
// Manages a Client with a specific connection.
type ConnClient struct {
Client Client
conn net.Conn
conn net.PacketConn
d Dispatcher
readErr error
ipv6 bool
closed bool
newOpts NewConnClientOpts
}
func (cc *ConnClient) reader() {
b := make([]byte, 0x800)
for {
n, err := cc.conn.Read(b)
n, addr, err := cc.conn.ReadFrom(b)
if err != nil {
// TODO: Do bad things to the dispatcher, and incoming calls to the client if we have a
// read error.
cc.readErr = err
if !cc.closed {
panic(err)
}
break
}
_ = cc.d.Dispatch(b[:n])
// if err != nil {
// log.Printf("dispatching packet received on %v (%q): %v", cc.conn, string(b[:n]), err)
// }
err = cc.d.Dispatch(b[:n], addr)
if err != nil {
log.Printf("dispatching packet received on %v: %v", cc.conn.LocalAddr(), err)
}
}
}
func ipv6(opt *bool, network string, conn net.Conn) bool {
func ipv6(opt *bool, network string, remoteAddr net.Addr) bool {
if opt != nil {
return *opt
}
@ -49,21 +57,40 @@ func ipv6(opt *bool, network string, conn net.Conn) bool {
case "udp6":
return true
}
rip := missinggo.AddrIP(conn.RemoteAddr())
rip := missinggo.AddrIP(remoteAddr)
return rip.To16() != nil && rip.To4() == nil
}
// Allows a UDP Client to write packets to an endpoint without knowing about the network specifics.
type clientWriter struct {
pc net.PacketConn
network string
address string
}
func (me clientWriter) Write(p []byte) (n int, err error) {
addr, err := net.ResolveUDPAddr(me.network, me.address)
if err != nil {
return
}
return me.pc.WriteTo(p, addr)
}
func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
conn, err := net.Dial(opts.Network, opts.Host)
conn, err := net.ListenPacket(opts.Network, ":0")
if err != nil {
return
}
cc = &ConnClient{
Client: Client{
Writer: conn,
Writer: clientWriter{
pc: conn,
network: opts.Network,
address: opts.Host,
},
},
conn: conn,
ipv6: ipv6(opts.Ipv6, opts.Network, conn),
conn: conn,
newOpts: opts,
}
cc.Client.Dispatcher = &cc.d
go cc.reader()
@ -71,6 +98,7 @@ func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
}
func (c *ConnClient) Close() error {
c.closed = true
return c.conn.Close()
}
@ -79,13 +107,7 @@ func (c *ConnClient) Announce(
) (
h AnnounceResponseHeader, nas AnnounceResponsePeers, err error,
) {
nas = func() AnnounceResponsePeers {
if c.ipv6 {
return &krpc.CompactIPv6NodeAddrs{}
} else {
return &krpc.CompactIPv4NodeAddrs{}
}
}()
h, err = c.Client.Announce(ctx, req, nas, opts)
return
return c.Client.Announce(ctx, req, opts, func(addr net.Addr) bool {
return ipv6(c.newOpts.Ipv6, c.newOpts.Network, addr)
})
}

View File

@ -3,16 +3,18 @@ package udp
import (
"bytes"
"fmt"
"net"
"sync"
)
// Maintains a mapping of transaction IDs to handlers.
type Dispatcher struct {
mu sync.RWMutex
transactions map[TransactionId]Transaction
}
// The caller owns b.
func (me *Dispatcher) Dispatch(b []byte) error {
func (me *Dispatcher) Dispatch(b []byte, addr net.Addr) error {
buf := bytes.NewBuffer(b)
var rh ResponseHeader
err := Read(buf, &rh)
@ -25,6 +27,7 @@ func (me *Dispatcher) Dispatch(b []byte) error {
t.h(DispatchedResponse{
Header: rh,
Body: append([]byte(nil), buf.Bytes()...),
Addr: addr,
})
return nil
} else {
@ -61,5 +64,8 @@ func (me *Dispatcher) NewTransaction(h TransactionResponseHandler) Transaction {
type DispatchedResponse struct {
Header ResponseHeader
Body []byte
// Response payload, after the header.
Body []byte
// Response source address
Addr net.Addr
}

View File

@ -40,7 +40,7 @@ func TestAnnounceLocalhost(t *testing.T) {
},
}
var err error
srv.pc, err = net.ListenPacket("udp", ":0")
srv.pc, err = net.ListenPacket("udp", "localhost:0")
require.NoError(t, err)
defer srv.pc.Close()
go func() {
@ -92,7 +92,7 @@ func TestUDPTracker(t *testing.T) {
t.Skip(err)
}
require.NoError(t, err)
t.Log(ar)
t.Logf("%+v", ar)
}
func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
@ -143,7 +143,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
// Check that URLPath option is done correctly.
func TestURLPathOption(t *testing.T) {
conn, err := net.ListenUDP("udp", nil)
conn, err := net.ListenPacket("udp", "localhost:0")
if err != nil {
panic(err)
}
@ -161,6 +161,7 @@ func TestURLPathOption(t *testing.T) {
announceErr <- err
}()
var b [512]byte
// conn.SetReadDeadline(time.Now().Add(time.Second))
_, addr, _ := conn.ReadFrom(b[:])
r := bytes.NewReader(b[:])
var h udp.RequestHeader