torrent/dht/dht.go

422 lines
7.6 KiB
Go
Raw Normal View History

2014-05-24 06:51:56 +00:00
package dht
import (
"crypto"
_ "crypto/sha1"
2014-05-24 06:51:56 +00:00
"encoding/binary"
2014-05-25 11:34:29 +00:00
"errors"
2014-05-24 06:51:56 +00:00
"fmt"
"github.com/nsf/libtorgo/bencode"
"io"
"log"
"net"
"os"
2014-05-27 06:28:56 +00:00
"sync"
2014-05-24 06:51:56 +00:00
"time"
)
type Server struct {
ID string
2014-05-25 11:34:29 +00:00
Socket *net.UDPConn
2014-05-24 06:51:56 +00:00
transactions []*transaction
transactionIDInt uint64
nodes map[string]*Node
2014-05-27 06:28:56 +00:00
mu sync.Mutex
2014-05-24 06:51:56 +00:00
}
// Node is what the server remembers about one remote DHT node.
type Node struct {
	addr          *net.UDPAddr // where the node is reached
	id            string       // node ID; 20 bytes once known, empty otherwise
	lastHeardFrom time.Time    // when a message from the node was last received
	lastSentTo    time.Time    // when a message was last written to the node
}

// Good reports whether the node has a 20-byte ID and was heard from less
// than 15 minutes ago.
func (n *Node) Good() bool {
	const maxSilence = 15 * time.Minute
	return len(n.id) == 20 && time.Since(n.lastHeardFrom) < maxSilence
}
2014-05-24 06:51:56 +00:00
// Msg is a decoded KRPC message: a dictionary from string keys to
// arbitrary bencode values.
type Msg map[string]interface{}

// Compile-time check that Msg satisfies fmt.Stringer.
var _ fmt.Stringer = Msg(nil)

// String renders the message in Go syntax, which keeps binary strings
// readable in logs.
func (m Msg) String() string {
	const goSyntax = "%#v"
	return fmt.Sprintf(goSyntax, m)
}
type transaction struct {
remoteAddr net.Addr
t string
Response chan Msg
2014-05-25 11:34:29 +00:00
response chan Msg
}
2014-05-24 06:51:56 +00:00
func (s *Server) setDefaults() {
if s.ID == "" {
var id [20]byte
h := crypto.SHA1.New()
ss, err := os.Hostname()
2014-05-24 06:51:56 +00:00
if err != nil {
log.Print(err)
}
ss += s.Socket.LocalAddr().String()
h.Write([]byte(ss))
if b := h.Sum(id[:0:20]); len(b) != 20 {
panic(len(b))
}
if len(id) != 20 {
panic(len(id))
2014-05-24 06:51:56 +00:00
}
s.ID = string(id[:])
}
}
2014-05-25 11:34:29 +00:00
func (s *Server) Init() {
s.setDefaults()
2014-05-24 06:51:56 +00:00
}
func (s *Server) Serve() error {
for {
var b [1500]byte
2014-05-25 11:34:29 +00:00
n, addr, err := s.Socket.ReadFromUDP(b[:])
2014-05-24 06:51:56 +00:00
if err != nil {
return err
}
var d map[string]interface{}
err = bencode.Unmarshal(b[:n], &d)
if err != nil {
log.Printf("bad krpc message: %s", err)
continue
}
2014-05-27 06:28:56 +00:00
s.mu.Lock()
if d["y"] == "q" {
s.handleQuery(addr, d)
2014-05-27 06:28:56 +00:00
s.mu.Unlock()
continue
}
2014-05-24 06:51:56 +00:00
t := s.findResponseTransaction(d["t"].(string), addr)
if t == nil {
log.Printf("unexpected message: %#v", d)
2014-05-27 06:28:56 +00:00
s.mu.Unlock()
continue
}
2014-05-25 11:34:29 +00:00
t.response <- d
2014-05-24 06:51:56 +00:00
s.removeTransaction(t)
id := ""
if d["y"] == "r" {
id = d["r"].(map[string]interface{})["id"].(string)
}
s.heardFromNode(addr, id)
2014-05-27 06:28:56 +00:00
s.mu.Unlock()
}
}
func (s *Server) AddNode(ni NodeInfo) {
if s.nodes == nil {
s.nodes = make(map[string]*Node)
}
n := s.getNode(ni.Addr)
if n.id == "" {
n.id = string(ni.ID[:])
2014-05-24 06:51:56 +00:00
}
}
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
2014-05-27 06:28:56 +00:00
log.Print(m["q"])
if m["q"] != "ping" {
return
}
2014-05-27 06:28:56 +00:00
s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
s.reply(source, m["t"].(string))
}
// reply sends a KRPC response for transaction t to addr. The response body
// contains only this server's node ID.
//
// Failing to encode the fixed-shape message is a programming error and
// panics. Failing to write it is an ordinary network condition: it is
// logged and otherwise ignored so a bad peer or transient socket error
// cannot bring down the serve loop.
func (s *Server) reply(addr *net.UDPAddr, t string) {
	m := map[string]interface{}{
		"t": t,
		"y": "r",
		"r": map[string]string{
			"id": s.IDString(),
		},
	}
	b, err := bencode.Marshal(m)
	if err != nil {
		panic(err)
	}
	if err := s.writeToNode(b, addr); err != nil {
		log.Printf("error replying to %s: %s", addr, err)
	}
}
2014-05-25 11:34:29 +00:00
func (s *Server) heardFromNode(addr *net.UDPAddr, id string) {
2014-05-24 06:51:56 +00:00
n := s.getNode(addr)
n.id = id
n.lastHeardFrom = time.Now()
}
2014-05-25 11:34:29 +00:00
func (s *Server) getNode(addr *net.UDPAddr) (n *Node) {
2014-05-24 06:51:56 +00:00
n = s.nodes[addr.String()]
if n == nil {
n = &Node{
addr: addr,
}
s.nodes[addr.String()] = n
}
return
}
// writeToNode sends b to node in a single datagram. It returns the socket's
// write error, or io.ErrShortWrite when fewer than len(b) bytes were
// written. Only after a complete write is the node's last-sent time
// updated.
func (s *Server) writeToNode(b []byte, node *net.UDPAddr) (err error) {
	written, err := s.Socket.WriteTo(b, node)
	switch {
	case err != nil:
		return err
	case written != len(b):
		return io.ErrShortWrite
	}
	s.sentToNode(node)
	return nil
}
2014-05-25 11:34:29 +00:00
func (s *Server) sentToNode(addr *net.UDPAddr) {
2014-05-24 06:51:56 +00:00
n := s.getNode(addr)
n.lastSentTo = time.Now()
}
// findResponseTransaction returns the first outstanding transaction whose
// ID equals transactionID and whose remote address has the same string form
// as sourceNode. It returns nil when there is no such transaction.
func (s *Server) findResponseTransaction(transactionID string, sourceNode net.Addr) *transaction {
	for i := range s.transactions {
		candidate := s.transactions[i]
		if candidate.t != transactionID {
			continue
		}
		if candidate.remoteAddr.String() != sourceNode.String() {
			continue
		}
		return candidate
	}
	return nil
}
// nextTransactionID returns the current transaction counter encoded as an
// unsigned varint (held in a string) and then advances the counter.
func (s *Server) nextTransactionID() string {
	current := s.transactionIDInt
	s.transactionIDInt++
	buf := make([]byte, binary.MaxVarintLen64)
	size := binary.PutUvarint(buf, current)
	return string(buf[:size])
}
// removeTransaction drops t from the list of outstanding transactions by
// moving the last element into its slot and shortening the slice, so list
// order is not preserved. It panics if t is not in the list.
func (s *Server) removeTransaction(t *transaction) {
	for i := range s.transactions {
		if s.transactions[i] != t {
			continue
		}
		tail := len(s.transactions) - 1
		s.transactions[i] = s.transactions[tail]
		s.transactions = s.transactions[:tail]
		return
	}
	panic("transaction not found")
}
// addTransaction appends t to the list of outstanding transactions so that
// findResponseTransaction can later match a response to it.
func (s *Server) addTransaction(t *transaction) {
	s.transactions = append(s.transactions, t)
}
// IDString returns the server's node ID. It panics with "bad node id"
// unless the ID is exactly 20 bytes long.
func (s *Server) IDString() string {
	if len(s.ID) == 20 {
		return s.ID
	}
	panic("bad node id")
}
2014-05-25 11:34:29 +00:00
func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *transaction, err error) {
2014-05-24 06:51:56 +00:00
tid := s.nextTransactionID()
if a == nil {
a = make(map[string]string, 1)
}
a["id"] = s.IDString()
d := map[string]interface{}{
"t": tid,
"y": "q",
"q": q,
"a": a,
}
b, err := bencode.Marshal(d)
if err != nil {
return
}
t = &transaction{
remoteAddr: node,
t: tid,
Response: make(chan Msg, 1),
}
2014-05-25 11:34:29 +00:00
t.response = t.Response
2014-05-24 06:51:56 +00:00
s.addTransaction(t)
err = s.writeToNode(b, node)
2014-05-24 06:51:56 +00:00
if err != nil {
s.removeTransaction(t)
}
return
}
2014-05-27 06:28:56 +00:00
// CompactNodeInfoLen is the size in bytes of a node in compact form:
// a 20-byte node ID, a 4-byte IPv4 address and a 2-byte big-endian port.
const CompactNodeInfoLen = 26

// NodeInfo identifies a DHT node by its ID and UDP address.
type NodeInfo struct {
	ID   [20]byte
	Addr *net.UDPAddr
}

// PutCompact writes the compact form of ni into the first
// CompactNodeInfoLen bytes of b.
//
// It returns an error, leaving b untouched, when b is shorter than
// CompactNodeInfoLen, when ni.Addr is nil, or when the address is not IPv4
// (the compact form has no room for anything else).
func (ni *NodeInfo) PutCompact(b []byte) error {
	if len(b) < CompactNodeInfoLen {
		return fmt.Errorf("compact node info needs %d bytes, buffer has %d", CompactNodeInfoLen, len(b))
	}
	if ni.Addr == nil {
		return errors.New("node info has no address")
	}
	ip := ni.Addr.IP.To4()
	if ip == nil {
		return fmt.Errorf("compact node info requires an IPv4 address, got %q", ni.Addr.IP.String())
	}
	copy(b[:20], ni.ID[:])
	copy(b[20:24], ip)
	binary.BigEndian.PutUint16(b[24:26], uint16(ni.Addr.Port))
	return nil
}

// UnmarshalCompact fills ni from the compact form in b, which must be
// exactly CompactNodeInfoLen bytes long; otherwise an error is returned and
// ni is not modified. ni.Addr is allocated if it is nil, and reused (its IP
// and Port overwritten) if it is not.
func (ni *NodeInfo) UnmarshalCompact(b []byte) error {
	if len(b) != CompactNodeInfoLen {
		return errors.New("expected 26 bytes")
	}
	copy(ni.ID[:], b[:20])
	if ni.Addr == nil {
		ni.Addr = &net.UDPAddr{}
	}
	ni.Addr.IP = net.IPv4(b[20], b[21], b[22], b[23])
	ni.Addr.Port = int(binary.BigEndian.Uint16(b[24:26]))
	return nil
}
2014-05-25 11:34:29 +00:00
func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
2014-05-24 06:51:56 +00:00
return s.query(node, "ping", nil)
}
2014-05-25 11:34:29 +00:00
type findNodeResponse struct {
2014-05-27 06:28:56 +00:00
Nodes []NodeInfo
2014-05-25 11:34:29 +00:00
}
// getResponseNodes extracts the "nodes" string from the "r" dictionary of
// m. The type assertions are deliberately unchecked: if the message does
// not have the expected shape the resulting panic is recovered and turned
// into an error that includes the panic value and the whole message.
func getResponseNodes(m Msg) (s string, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("couldn't get response nodes: %s: %#v", r, m)
		}
	}()
	body := m["r"].(map[string]interface{})
	s = body["nodes"].(string)
	return
}
2014-05-25 11:34:29 +00:00
func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
b, err := getResponseNodes(m)
if err != nil {
return err
}
2014-05-25 11:34:29 +00:00
for i := 0; i < len(b); i += 26 {
2014-05-27 06:28:56 +00:00
var n NodeInfo
err := n.UnmarshalCompact([]byte(b[i : i+26]))
2014-05-25 11:34:29 +00:00
if err != nil {
return err
}
me.Nodes = append(me.Nodes, n)
}
return nil
}
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
if err != nil {
return
}
ch := make(chan Msg)
t.response = ch
go func() {
d, ok := <-t.response
if !ok {
close(t.Response)
return
}
if d["y"] == "r" {
var r findNodeResponse
err = r.UnmarshalKRPCMsg(d)
if err != nil {
log.Print(err)
} else {
2014-05-27 06:28:56 +00:00
s.mu.Lock()
2014-05-25 11:34:29 +00:00
for _, cni := range r.Nodes {
n := s.getNode(cni.Addr)
n.id = string(cni.ID[:])
}
2014-05-27 06:28:56 +00:00
s.mu.Unlock()
2014-05-25 11:34:29 +00:00
}
}
t.Response <- d
}()
return
}
2014-05-27 06:28:56 +00:00
// addRootNode resolves the well-known bootstrap router
// router.bittorrent.com:6881 over IPv4 and stores it in the node table
// (without an ID), replacing any existing entry for that address. It
// returns the resolution error, if any.
func (s *Server) addRootNode() error {
	addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
	if err != nil {
		return err
	}
	// The table may never have been allocated (e.g. Bootstrap on a fresh
	// Server); assigning into a nil map would panic.
	if s.nodes == nil {
		s.nodes = make(map[string]*Node)
	}
	s.nodes[addr.String()] = &Node{
		addr: addr,
	}
	return nil
}
// Populates the node table.
func (s *Server) Bootstrap() (err error) {
s.mu.Lock()
defer s.mu.Unlock()
2014-05-25 11:34:29 +00:00
if len(s.nodes) == 0 {
2014-05-27 06:28:56 +00:00
err = s.addRootNode()
2014-05-25 11:34:29 +00:00
if err != nil {
2014-05-27 06:28:56 +00:00
return
2014-05-25 11:34:29 +00:00
}
2014-05-27 06:28:56 +00:00
}
for _, node := range s.nodes {
var t *transaction
s.mu.Unlock()
t, err = s.FindNode(node.addr, s.ID)
s.mu.Lock()
if err != nil {
return
2014-05-25 11:34:29 +00:00
}
2014-05-27 06:28:56 +00:00
go func() {
<-t.Response
}()
2014-05-25 11:34:29 +00:00
}
2014-05-27 06:28:56 +00:00
return
}
func (s *Server) GoodNodes() (nis []NodeInfo) {
s.mu.Lock()
defer s.mu.Unlock()
for _, node := range s.nodes {
if !node.Good() {
continue
2014-05-25 11:34:29 +00:00
}
2014-05-27 06:28:56 +00:00
ni := NodeInfo{
Addr: node.addr,
}
if n := copy(ni.ID[:], node.id); n != 20 {
panic(n)
}
nis = append(nis, ni)
2014-05-25 11:34:29 +00:00
}
2014-05-27 06:28:56 +00:00
return
}
func (s *Server) StopServing() {
s.Socket.Close()
2014-05-25 11:34:29 +00:00
}