2
0
mirror of synced 2025-02-24 14:48:27 +00:00

204 lines
4.7 KiB
Go
Raw Normal View History

2022-12-05 17:52:03 +11:00
package udpTrackerServer
2022-12-05 12:52:19 +11:00
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"net"
"net/netip"
"github.com/anacrolix/dht/v2/krpc"
2022-12-08 14:04:23 +11:00
"github.com/anacrolix/generics"
2022-12-05 12:52:19 +11:00
"github.com/anacrolix/log"
2022-12-05 17:52:03 +11:00
"github.com/anacrolix/torrent/tracker"
2022-12-05 12:52:19 +11:00
"github.com/anacrolix/torrent/tracker/udp"
)
type ConnectionTrackerAddr = string
type ConnectionTracker interface {
Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
}
type InfoHash = [20]byte
2022-12-05 17:52:03 +11:00
type AnnounceTracker = tracker.AnnounceTracker
2022-12-05 12:52:19 +11:00
type Server struct {
2022-12-07 01:54:38 +11:00
ConnTracker ConnectionTracker
SendResponse func(data []byte, addr net.Addr) (int, error)
Announce tracker.AnnounceHandler
2022-12-05 12:52:19 +11:00
}
type RequestSourceAddr = net.Addr
2022-12-05 17:52:03 +11:00
func (me *Server) HandleRequest(
ctx context.Context,
family udp.AddrFamily,
source RequestSourceAddr,
body []byte,
) error {
2022-12-05 12:52:19 +11:00
var h udp.RequestHeader
var r bytes.Reader
r.Reset(body)
err := udp.Read(&r, &h)
if err != nil {
err = fmt.Errorf("reading request header: %w", err)
return err
}
switch h.Action {
case udp.ActionConnect:
err = me.handleConnect(ctx, source, h.TransactionId)
case udp.ActionAnnounce:
err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
default:
err = fmt.Errorf("unimplemented")
}
if err != nil {
err = fmt.Errorf("handling action %v: %w", h.Action, err)
}
return err
}
func (me *Server) handleAnnounce(
ctx context.Context,
addrFamily udp.AddrFamily,
source RequestSourceAddr,
connId udp.ConnectionId,
tid udp.TransactionId,
r *bytes.Reader,
) error {
2022-12-07 01:54:38 +11:00
// Should we set a timeout of 10s or something for the entire response, so that we give up if a
// retry is imminent?
2022-12-05 12:52:19 +11:00
ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
if err != nil {
err = fmt.Errorf("checking conn id: %w", err)
return err
}
if !ok {
return fmt.Errorf("incorrect connection id: %x", connId)
2022-12-05 12:52:19 +11:00
}
var req udp.AnnounceRequest
err = udp.Read(r, &req)
if err != nil {
return err
}
// TODO: This should be done asynchronously to responding to the announce.
2022-12-05 17:52:03 +11:00
announceAddr, err := netip.ParseAddrPort(source.String())
if err != nil {
err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
return err
}
2022-12-08 14:04:23 +11:00
opts := tracker.GetPeersOpts{MaxCount: generics.Some[uint](50)}
if addrFamily == udp.AddrFamilyIpv4 {
opts.MaxCount = generics.Some[uint](150)
}
peers, err := me.Announce.Serve(ctx, req, announceAddr, opts)
2022-12-05 12:52:19 +11:00
if err != nil {
return err
}
nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
for _, p := range peers {
var ip net.IP
switch addrFamily {
default:
continue
case udp.AddrFamilyIpv4:
if !p.Addr().Unmap().Is4() {
continue
}
ipBuf := p.Addr().As4()
ip = ipBuf[:]
case udp.AddrFamilyIpv6:
ipBuf := p.Addr().As16()
ip = ipBuf[:]
}
nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
IP: ip[:],
Port: int(p.Port()),
})
}
var buf bytes.Buffer
err = udp.Write(&buf, udp.ResponseHeader{
Action: udp.ActionAnnounce,
TransactionId: tid,
})
if err != nil {
return err
}
err = udp.Write(&buf, udp.AnnounceResponseHeader{})
if err != nil {
return err
}
b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
if err != nil {
err = fmt.Errorf("marshalling compact node addrs: %w", err)
return err
}
buf.Write(b)
n, err := me.SendResponse(buf.Bytes(), source)
if err != nil {
return err
}
if n < buf.Len() {
err = io.ErrShortWrite
}
return err
}
// handleConnect answers a connect request. It generates a random connection ID, records it with
// ConnTracker under the source address, and replies to source with a connect response carrying the
// new ID and the request's transaction ID. If SendResponse writes fewer bytes than were encoded,
// io.ErrShortWrite is returned.
func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
	connId := randomConnectionId()
	err := me.ConnTracker.Add(ctx, source.String(), connId)
	if err != nil {
		return fmt.Errorf("recording conn id: %w", err)
	}
	var buf bytes.Buffer
	// Encoding errors were previously discarded; surface them like handleAnnounce does.
	err = udp.Write(&buf, udp.ResponseHeader{
		Action:        udp.ActionConnect,
		TransactionId: tid,
	})
	if err != nil {
		return fmt.Errorf("writing response header: %w", err)
	}
	// Keyed field: go vet's composites check flags unkeyed literals of imported struct types.
	err = udp.Write(&buf, udp.ConnectionResponse{ConnectionId: connId})
	if err != nil {
		return fmt.Errorf("writing connection response: %w", err)
	}
	n, err := me.SendResponse(buf.Bytes(), source)
	if err != nil {
		return err
	}
	if n < buf.Len() {
		return io.ErrShortWrite
	}
	return nil
}
func randomConnectionId() udp.ConnectionId {
var b [8]byte
_, err := rand.Read(b[:])
if err != nil {
panic(err)
}
return binary.BigEndian.Uint64(b[:])
2022-12-05 12:52:19 +11:00
}
2022-12-05 17:52:03 +11:00
func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
2022-12-05 12:52:19 +11:00
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for {
var b [1500]byte
n, addr, err := pc.ReadFrom(b[:])
if err != nil {
return err
}
go func() {
err := s.HandleRequest(ctx, family, addr, b[:n])
if err != nil {
log.Printf("error handling %v byte request from %v: %v", n, addr, err)
}
}()
}
}