2022-12-05 12:50:40 +11:00
|
|
|
package server
|
2022-12-05 12:52:19 +11:00
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"context"
|
|
|
|
"crypto/rand"
|
|
|
|
"encoding/binary"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"net"
|
|
|
|
"net/netip"
|
|
|
|
|
|
|
|
"github.com/anacrolix/dht/v2/krpc"
|
|
|
|
"github.com/anacrolix/log"
|
|
|
|
"github.com/anacrolix/torrent/tracker/udp"
|
|
|
|
)
|
|
|
|
|
|
|
|
type ConnectionTrackerAddr = string
|
|
|
|
|
|
|
|
type ConnectionTracker interface {
|
|
|
|
Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
|
|
|
|
Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
|
|
|
|
}
|
|
|
|
|
|
|
|
type (
	// InfoHash is a 20-byte info hash, as used to key AnnounceTracker queries.
	InfoHash = [20]byte

	// GetPeersOpts is reserved for future peer-selection options, such as
	// filtering by IP version, avoiding an announcer's own IP or key, or
	// limiting how many peers are returned. It currently has no fields.
	GetPeersOpts struct{}

	// PeerInfo describes one peer returned by AnnounceTracker.GetPeers. Its
	// address and port come from the embedded netip.AddrPort.
	PeerInfo struct {
		netip.AddrPort
	}
)
|
|
|
|
|
|
|
|
type AnnounceTracker interface {
|
|
|
|
TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error
|
|
|
|
Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
|
|
|
|
GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error)
|
|
|
|
}
|
|
|
|
|
|
|
|
type Server struct {
|
|
|
|
ConnTracker ConnectionTracker
|
|
|
|
SendResponse func(data []byte, addr net.Addr) (int, error)
|
|
|
|
AnnounceTracker AnnounceTracker
|
|
|
|
}
|
|
|
|
|
|
|
|
type RequestSourceAddr = net.Addr
|
|
|
|
|
|
|
|
func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error {
|
|
|
|
var h udp.RequestHeader
|
|
|
|
var r bytes.Reader
|
|
|
|
r.Reset(body)
|
|
|
|
err := udp.Read(&r, &h)
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("reading request header: %w", err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
switch h.Action {
|
|
|
|
case udp.ActionConnect:
|
|
|
|
err = me.handleConnect(ctx, source, h.TransactionId)
|
|
|
|
case udp.ActionAnnounce:
|
|
|
|
err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
|
|
|
|
default:
|
|
|
|
err = fmt.Errorf("unimplemented")
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("handling action %v: %w", h.Action, err)
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (me *Server) handleAnnounce(
|
|
|
|
ctx context.Context,
|
|
|
|
addrFamily udp.AddrFamily,
|
|
|
|
source RequestSourceAddr,
|
|
|
|
connId udp.ConnectionId,
|
|
|
|
tid udp.TransactionId,
|
|
|
|
r *bytes.Reader,
|
|
|
|
) error {
|
|
|
|
ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("checking conn id: %w", err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if !ok {
|
|
|
|
return fmt.Errorf("invalid connection id: %v", connId)
|
|
|
|
}
|
|
|
|
var req udp.AnnounceRequest
|
|
|
|
err = udp.Read(r, &req)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
// TODO: This should be done asynchronously to responding to the announce.
|
|
|
|
err = me.AnnounceTracker.TrackAnnounce(ctx, req, source)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{})
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
|
|
|
|
for _, p := range peers {
|
|
|
|
var ip net.IP
|
|
|
|
switch addrFamily {
|
|
|
|
default:
|
|
|
|
continue
|
|
|
|
case udp.AddrFamilyIpv4:
|
|
|
|
if !p.Addr().Unmap().Is4() {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
ipBuf := p.Addr().As4()
|
|
|
|
ip = ipBuf[:]
|
|
|
|
case udp.AddrFamilyIpv6:
|
|
|
|
ipBuf := p.Addr().As16()
|
|
|
|
ip = ipBuf[:]
|
|
|
|
}
|
|
|
|
nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
|
|
|
|
IP: ip[:],
|
|
|
|
Port: int(p.Port()),
|
|
|
|
})
|
|
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
|
|
err = udp.Write(&buf, udp.ResponseHeader{
|
|
|
|
Action: udp.ActionAnnounce,
|
|
|
|
TransactionId: tid,
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
err = udp.Write(&buf, udp.AnnounceResponseHeader{})
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("marshalling compact node addrs: %w", err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
log.Print(nodeAddrs)
|
|
|
|
buf.Write(b)
|
|
|
|
n, err := me.SendResponse(buf.Bytes(), source)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if n < buf.Len() {
|
|
|
|
err = io.ErrShortWrite
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
|
|
|
|
connId := randomConnectionId()
|
|
|
|
err := me.ConnTracker.Add(ctx, source.String(), connId)
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("recording conn id: %w", err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
|
|
udp.Write(&buf, udp.ResponseHeader{
|
|
|
|
Action: udp.ActionConnect,
|
|
|
|
TransactionId: tid,
|
|
|
|
})
|
|
|
|
udp.Write(&buf, udp.ConnectionResponse{connId})
|
|
|
|
n, err := me.SendResponse(buf.Bytes(), source)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if n < buf.Len() {
|
|
|
|
err = io.ErrShortWrite
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func randomConnectionId() udp.ConnectionId {
|
|
|
|
var b [8]byte
|
|
|
|
_, err := rand.Read(b[:])
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
return int64(binary.BigEndian.Uint64(b[:]))
|
|
|
|
}
|
|
|
|
|
|
|
|
func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
defer cancel()
|
|
|
|
for {
|
|
|
|
var b [1500]byte
|
|
|
|
n, addr, err := pc.ReadFrom(b[:])
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
go func() {
|
|
|
|
err := s.HandleRequest(ctx, family, addr, b[:n])
|
|
|
|
if err != nil {
|
|
|
|
log.Printf("error handling %v byte request from %v: %v", n, addr, err)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|
|
|
|
}
|