Merge branch 'udp-tracker-no-dial' into te
This commit is contained in:
commit
5dba8f96e4
|
@ -4,13 +4,17 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/anacrolix/dht/v2/krpc"
|
||||
)
|
||||
|
||||
// Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
|
||||
// connection specifics.
|
||||
type Client struct {
|
||||
mu sync.Mutex
|
||||
connId ConnectionId
|
||||
|
@ -20,11 +24,16 @@ type Client struct {
|
|||
}
|
||||
|
||||
func (cl *Client) Announce(
|
||||
ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options,
|
||||
ctx context.Context, req AnnounceRequest, opts Options,
|
||||
// Decides whether the response body is IPv6 or IPv4, see BEP 15.
|
||||
ipv6 func(net.Addr) bool,
|
||||
) (
|
||||
respHdr AnnounceResponseHeader, err error,
|
||||
respHdr AnnounceResponseHeader,
|
||||
// A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
|
||||
peers AnnounceResponsePeers,
|
||||
err error,
|
||||
) {
|
||||
respBody, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
|
||||
respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -34,6 +43,11 @@ func (cl *Client) Announce(
|
|||
err = fmt.Errorf("reading response header: %w", err)
|
||||
return
|
||||
}
|
||||
if ipv6(addr) {
|
||||
peers = &krpc.CompactIPv6NodeAddrs{}
|
||||
} else {
|
||||
peers = &krpc.CompactIPv4NodeAddrs{}
|
||||
}
|
||||
err = peers.UnmarshalBinary(r.Bytes())
|
||||
if err != nil {
|
||||
err = fmt.Errorf("reading response peers: %w", err)
|
||||
|
@ -41,13 +55,13 @@ func (cl *Client) Announce(
|
|||
return
|
||||
}
|
||||
|
||||
// There's no way to pass options in a scrape, since we don't when the request body ends.
|
||||
func (cl *Client) Scrape(
|
||||
ctx context.Context, ihs []InfoHash,
|
||||
) (
|
||||
out ScrapeResponse, err error,
|
||||
) {
|
||||
// There's no way to pass options in a scrape, since we don't when the request body ends.
|
||||
respBody, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
|
||||
respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -75,7 +89,7 @@ func (cl *Client) connect(ctx context.Context) (err error) {
|
|||
if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
|
||||
return nil
|
||||
}
|
||||
respBody, err := cl.request(ctx, ActionConnect, nil)
|
||||
respBody, _, err := cl.request(ctx, ActionConnect, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -132,7 +146,7 @@ func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte,
|
|||
}
|
||||
}
|
||||
|
||||
func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) {
|
||||
func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
|
||||
respChan := make(chan DispatchedResponse, 1)
|
||||
t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
|
||||
respChan <- dr
|
||||
|
@ -148,8 +162,9 @@ func (cl *Client) request(ctx context.Context, action Action, body []byte) (resp
|
|||
case dr := <-respChan:
|
||||
if dr.Header.Action == action {
|
||||
respBody = dr.Body
|
||||
addr = dr.Addr
|
||||
} else if dr.Header.Action == ActionError {
|
||||
err = errors.New(string(dr.Body))
|
||||
err = fmt.Errorf("error response: %s", dr.Body)
|
||||
} else {
|
||||
err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
|
||||
}
|
||||
|
|
|
@ -2,44 +2,52 @@ package udp
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/anacrolix/dht/v2/krpc"
|
||||
"github.com/anacrolix/missinggo/v2"
|
||||
)
|
||||
|
||||
type NewConnClientOpts struct {
|
||||
// The network to operate to use, such as "udp4", "udp", "udp6".
|
||||
Network string
|
||||
Host string
|
||||
Ipv6 *bool
|
||||
// Tracker address
|
||||
Host string
|
||||
// If non-nil, forces either IPv4 or IPv6 in the UDP tracker wire protocol.
|
||||
Ipv6 *bool
|
||||
}
|
||||
|
||||
// Manages a Client with a specific connection.
|
||||
type ConnClient struct {
|
||||
Client Client
|
||||
conn net.Conn
|
||||
conn net.PacketConn
|
||||
d Dispatcher
|
||||
readErr error
|
||||
ipv6 bool
|
||||
closed bool
|
||||
newOpts NewConnClientOpts
|
||||
}
|
||||
|
||||
func (cc *ConnClient) reader() {
|
||||
b := make([]byte, 0x800)
|
||||
for {
|
||||
n, err := cc.conn.Read(b)
|
||||
n, addr, err := cc.conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
// TODO: Do bad things to the dispatcher, and incoming calls to the client if we have a
|
||||
// read error.
|
||||
cc.readErr = err
|
||||
if !cc.closed {
|
||||
panic(err)
|
||||
}
|
||||
break
|
||||
}
|
||||
_ = cc.d.Dispatch(b[:n])
|
||||
// if err != nil {
|
||||
// log.Printf("dispatching packet received on %v (%q): %v", cc.conn, string(b[:n]), err)
|
||||
// }
|
||||
err = cc.d.Dispatch(b[:n], addr)
|
||||
if err != nil {
|
||||
log.Printf("dispatching packet received on %v: %v", cc.conn.LocalAddr(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ipv6(opt *bool, network string, conn net.Conn) bool {
|
||||
func ipv6(opt *bool, network string, remoteAddr net.Addr) bool {
|
||||
if opt != nil {
|
||||
return *opt
|
||||
}
|
||||
|
@ -49,21 +57,40 @@ func ipv6(opt *bool, network string, conn net.Conn) bool {
|
|||
case "udp6":
|
||||
return true
|
||||
}
|
||||
rip := missinggo.AddrIP(conn.RemoteAddr())
|
||||
rip := missinggo.AddrIP(remoteAddr)
|
||||
return rip.To16() != nil && rip.To4() == nil
|
||||
}
|
||||
|
||||
// Allows a UDP Client to write packets to an endpoint without knowing about the network specifics.
|
||||
type clientWriter struct {
|
||||
pc net.PacketConn
|
||||
network string
|
||||
address string
|
||||
}
|
||||
|
||||
func (me clientWriter) Write(p []byte) (n int, err error) {
|
||||
addr, err := net.ResolveUDPAddr(me.network, me.address)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return me.pc.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
|
||||
conn, err := net.Dial(opts.Network, opts.Host)
|
||||
conn, err := net.ListenPacket(opts.Network, ":0")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cc = &ConnClient{
|
||||
Client: Client{
|
||||
Writer: conn,
|
||||
Writer: clientWriter{
|
||||
pc: conn,
|
||||
network: opts.Network,
|
||||
address: opts.Host,
|
||||
},
|
||||
},
|
||||
conn: conn,
|
||||
ipv6: ipv6(opts.Ipv6, opts.Network, conn),
|
||||
conn: conn,
|
||||
newOpts: opts,
|
||||
}
|
||||
cc.Client.Dispatcher = &cc.d
|
||||
go cc.reader()
|
||||
|
@ -71,6 +98,7 @@ func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
|
|||
}
|
||||
|
||||
func (c *ConnClient) Close() error {
|
||||
c.closed = true
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
|
@ -79,13 +107,7 @@ func (c *ConnClient) Announce(
|
|||
) (
|
||||
h AnnounceResponseHeader, nas AnnounceResponsePeers, err error,
|
||||
) {
|
||||
nas = func() AnnounceResponsePeers {
|
||||
if c.ipv6 {
|
||||
return &krpc.CompactIPv6NodeAddrs{}
|
||||
} else {
|
||||
return &krpc.CompactIPv4NodeAddrs{}
|
||||
}
|
||||
}()
|
||||
h, err = c.Client.Announce(ctx, req, nas, opts)
|
||||
return
|
||||
return c.Client.Announce(ctx, req, opts, func(addr net.Addr) bool {
|
||||
return ipv6(c.newOpts.Ipv6, c.newOpts.Network, addr)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,16 +3,18 @@ package udp
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Maintains a mapping of transaction IDs to handlers.
|
||||
type Dispatcher struct {
|
||||
mu sync.RWMutex
|
||||
transactions map[TransactionId]Transaction
|
||||
}
|
||||
|
||||
// The caller owns b.
|
||||
func (me *Dispatcher) Dispatch(b []byte) error {
|
||||
func (me *Dispatcher) Dispatch(b []byte, addr net.Addr) error {
|
||||
buf := bytes.NewBuffer(b)
|
||||
var rh ResponseHeader
|
||||
err := Read(buf, &rh)
|
||||
|
@ -25,6 +27,7 @@ func (me *Dispatcher) Dispatch(b []byte) error {
|
|||
t.h(DispatchedResponse{
|
||||
Header: rh,
|
||||
Body: append([]byte(nil), buf.Bytes()...),
|
||||
Addr: addr,
|
||||
})
|
||||
return nil
|
||||
} else {
|
||||
|
@ -61,5 +64,8 @@ func (me *Dispatcher) NewTransaction(h TransactionResponseHandler) Transaction {
|
|||
|
||||
type DispatchedResponse struct {
|
||||
Header ResponseHeader
|
||||
Body []byte
|
||||
// Response payload, after the header.
|
||||
Body []byte
|
||||
// Response source address
|
||||
Addr net.Addr
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestAnnounceLocalhost(t *testing.T) {
|
|||
},
|
||||
}
|
||||
var err error
|
||||
srv.pc, err = net.ListenPacket("udp", ":0")
|
||||
srv.pc, err = net.ListenPacket("udp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
defer srv.pc.Close()
|
||||
go func() {
|
||||
|
@ -92,7 +92,7 @@ func TestUDPTracker(t *testing.T) {
|
|||
t.Skip(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Log(ar)
|
||||
t.Logf("%+v", ar)
|
||||
}
|
||||
|
||||
func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
|
||||
|
@ -143,7 +143,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
|
|||
|
||||
// Check that URLPath option is done correctly.
|
||||
func TestURLPathOption(t *testing.T) {
|
||||
conn, err := net.ListenUDP("udp", nil)
|
||||
conn, err := net.ListenPacket("udp", "localhost:0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -161,6 +161,7 @@ func TestURLPathOption(t *testing.T) {
|
|||
announceErr <- err
|
||||
}()
|
||||
var b [512]byte
|
||||
// conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, addr, _ := conn.ReadFrom(b[:])
|
||||
r := bytes.NewReader(b[:])
|
||||
var h udp.RequestHeader
|
||||
|
|
Loading…
Reference in New Issue