diff --git a/dht/dht.go b/dht/dht.go index 6df2c225..499d59bf 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -139,8 +139,6 @@ func (n *node) IsSecure() bool { if n.id.IsUnset() { return false } - // TODO (@onetruecathal): Exempt local peers from security - // check as per security extension recommendations return NodeIdSecure(n.id.ByteString(), n.addr.IP()) } diff --git a/dht/security.go b/dht/security.go index e91d1607..9355b7b5 100644 --- a/dht/security.go +++ b/dht/security.go @@ -43,6 +43,9 @@ func SecureNodeId(id []byte, ip net.IP) { // Returns whether the node ID is considered secure. The id is the 20 raw // bytes. http://www.libtorrent.org/dht_sec.html func NodeIdSecure(id string, ip net.IP) bool { + if isLocalNetwork(ip) { + return true + } if len(id) != 20 { panic(fmt.Sprintf("%q", id)) } @@ -61,3 +64,43 @@ func NodeIdSecure(id string, ip net.IP) bool { } return true } + +var ( + classA, classB, classC *net.IPNet +) + +func mustParseCIDRIPNet(s string) *net.IPNet { + _, ret, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return ret +} + +func init() { + classA = mustParseCIDRIPNet("10.0.0.0/8") + classB = mustParseCIDRIPNet("172.16.0.0/12") + classC = mustParseCIDRIPNet("192.168.0.0/16") +} + +// Per http://www.libtorrent.org/dht_sec.html#enforcement, the IP is +// considered a local network address and should be exempted from node ID +// verification. +func isLocalNetwork(ip net.IP) bool { + if classA.Contains(ip) { + return true + } + if classB.Contains(ip) { + return true + } + if classC.Contains(ip) { + return true + } + if ip.IsLinkLocalUnicast() { + return true + } + if ip.IsLoopback() { + return true + } + return false +} diff --git a/dht/security_test.go b/dht/security_test.go index 81ce0be2..9cd84470 100644 --- a/dht/security_test.go +++ b/dht/security_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/anacrolix/missinggo" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -31,19 +32,22 @@ func TestDHTSec(t *testing.T) { {"84.124.73.14", "1b0321dd1bb1fe518101ceef99462b947a01fe01", true}, // spec[4] with the 3rd last bit changed. Not valid. {"43.213.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51303e", false}, + // Because class A network. + {"10.213.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51305a", true}, + // Because not class A, and id[0]&3 does not match. + {"12.213.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51305a", false}, + // Because class C. + {"192.168.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51305a", true}, } { ip := net.ParseIP(case_.ipStr) id, err := hex.DecodeString(case_.nodeIDHex) require.NoError(t, err) secure := NodeIdSecure(string(id), ip) - if secure != case_.valid { - t.Fatalf("case failed: %v", case_) - } + assert.Equal(t, case_.valid, secure, "%v", case_) if !secure { + // It's not secure, so secure it in place and then check it again. SecureNodeId(id, ip) - if !NodeIdSecure(string(id), ip) { - t.Fatal("failed to secure node id") - } + assert.True(t, NodeIdSecure(string(id), ip), "%v", case_) } } }