diff --git a/dht/dht.go b/dht/dht.go index ec00c1fa..4fef734d 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -65,6 +65,8 @@ type ServerConfig struct { IPBlocklist iplist.Ranger // Used to secure the server's ID. Defaults to the Conn's LocalAddr(). PublicIP net.IP + + OnQuery func(*Msg, net.Addr) bool } // ServerStats instance is returned by Server.Stats() and stores Server metrics @@ -139,6 +141,8 @@ func (n *node) IsSecure() bool { if n.id.IsUnset() { return false } + // TODO (@onetruecathal): Exempt local peers from security + // check as per security extension recommendations return NodeIdSecure(n.id.ByteString(), n.addr.IP()) } diff --git a/dht/dht_test.go b/dht/dht_test.go index 082b15b3..affeba83 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -6,6 +6,7 @@ import ( "math/rand" "net" "testing" + "time" "github.com/anacrolix/missinggo" "github.com/stretchr/testify/assert" @@ -217,3 +218,53 @@ func TestAnnounceTimeout(t *testing.T) { func TestEqualPointers(t *testing.T) { assert.EqualValues(t, &Msg{R: &Return{}}, &Msg{R: &Return{}}) } + +func TestHook(t *testing.T) { + t.Log("TestHook: Starting with Ping intercept/passthrough") + srv, err := NewServer(&ServerConfig{ + Addr: "127.0.0.1:5678", + NoDefaultBootstrap: true, + }) + require.NoError(t, err) + defer srv.Close() + // Establish server with a hook attached to "ping" + hookCalled := make(chan bool) + srv0, err := NewServer(&ServerConfig{ + Addr: "127.0.0.1:5679", + BootstrapNodes: []string{"127.0.0.1:5678"}, + OnQuery: func(m *Msg, addr net.Addr) bool { + if m.Q == "ping" { + hookCalled <- true + } + return true + }, + }) + require.NoError(t, err) + defer srv0.Close() + // Ping srv0 from srv to trigger hook. Should also receive a response. + t.Log("TestHook: Servers created, hook for ping established. Calling Ping.") + tn, err := srv.Ping(&net.UDPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: srv0.Addr().(*net.UDPAddr).Port, + }) + assert.NoError(t, err) + defer tn.Close() + // Await response from hooked server + tn.SetResponseHandler(func(msg Msg, b bool) { + t.Log("TestHook: Sender received response from pinged hook server, so normal execution resumed.") + }) + // Await signal that hook has been called. + select { + case <-hookCalled: + { + // Success, hook was triggered. Todo: Ensure that "ok" channel + // receives, also, indicating normal handling proceeded also. + t.Log("TestHook: Received ping, hook called and returned to normal execution!") + return + } + case <-time.After(time.Second * 1): + { + t.Error("Failed to see evidence of ping hook being called after 2 seconds.") + } + } +} diff --git a/dht/server.go b/dht/server.go index b25579c1..73a4cc3c 100644 --- a/dht/server.go +++ b/dht/server.go @@ -251,6 +251,12 @@ func (s *Server) nodeByID(id string) *node { func (s *Server) handleQuery(source dHTAddr, m Msg) { node := s.getNode(source, m.SenderID()) node.lastGotQuery = time.Now() + if s.config.OnQuery != nil { + propagate := s.config.OnQuery(&m, source.UDPAddr()) + if !propagate { + return + } + } // Don't respond. if s.config.Passive { return @@ -340,6 +346,7 @@ func (s *Server) getNode(addr dHTAddr, id string) (n *node) { if len(s.nodes) >= maxNodes { return } + // Exclude insecure nodes from the node table. if !s.config.NoSecurity && !n.IsSecure() { return }