Add OnQuery hook, thanks to Cathal Garvey
This commit is contained in:
parent
6d8310d2df
commit
565fb20139
@ -65,6 +65,8 @@ type ServerConfig struct {
|
|||||||
IPBlocklist iplist.Ranger
|
IPBlocklist iplist.Ranger
|
||||||
// Used to secure the server's ID. Defaults to the Conn's LocalAddr().
|
// Used to secure the server's ID. Defaults to the Conn's LocalAddr().
|
||||||
PublicIP net.IP
|
PublicIP net.IP
|
||||||
|
|
||||||
|
OnQuery func(*Msg, net.Addr) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerStats instance is returned by Server.Stats() and stores Server metrics
|
// ServerStats instance is returned by Server.Stats() and stores Server metrics
|
||||||
@ -139,6 +141,8 @@ func (n *node) IsSecure() bool {
|
|||||||
if n.id.IsUnset() {
|
if n.id.IsUnset() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
// TODO (@onetruecathal): Exempt local peers from security
|
||||||
|
// check as per security extension recommendations
|
||||||
return NodeIdSecure(n.id.ByteString(), n.addr.IP())
|
return NodeIdSecure(n.id.ByteString(), n.addr.IP())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/anacrolix/missinggo"
|
"github.com/anacrolix/missinggo"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -217,3 +218,53 @@ func TestAnnounceTimeout(t *testing.T) {
|
|||||||
func TestEqualPointers(t *testing.T) {
|
func TestEqualPointers(t *testing.T) {
|
||||||
assert.EqualValues(t, &Msg{R: &Return{}}, &Msg{R: &Return{}})
|
assert.EqualValues(t, &Msg{R: &Return{}}, &Msg{R: &Return{}})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHook(t *testing.T) {
|
||||||
|
t.Log("TestHook: Starting with Ping intercept/passthrough")
|
||||||
|
srv, err := NewServer(&ServerConfig{
|
||||||
|
Addr: "127.0.0.1:5678",
|
||||||
|
NoDefaultBootstrap: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer srv.Close()
|
||||||
|
// Establish server with a hook attached to "ping"
|
||||||
|
hookCalled := make(chan bool)
|
||||||
|
srv0, err := NewServer(&ServerConfig{
|
||||||
|
Addr: "127.0.0.1:5679",
|
||||||
|
BootstrapNodes: []string{"127.0.0.1:5678"},
|
||||||
|
OnQuery: func(m *Msg, addr net.Addr) bool {
|
||||||
|
if m.Q == "ping" {
|
||||||
|
hookCalled <- true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer srv0.Close()
|
||||||
|
// Ping srv0 from srv to trigger hook. Should also receive a response.
|
||||||
|
t.Log("TestHook: Servers created, hook for ping established. Calling Ping.")
|
||||||
|
tn, err := srv.Ping(&net.UDPAddr{
|
||||||
|
IP: []byte{127, 0, 0, 1},
|
||||||
|
Port: srv0.Addr().(*net.UDPAddr).Port,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer tn.Close()
|
||||||
|
// Await response from hooked server
|
||||||
|
tn.SetResponseHandler(func(msg Msg, b bool) {
|
||||||
|
t.Log("TestHook: Sender received response from pinged hook server, so normal execution resumed.")
|
||||||
|
})
|
||||||
|
// Await signal that hook has been called.
|
||||||
|
select {
|
||||||
|
case <-hookCalled:
|
||||||
|
{
|
||||||
|
// Success, hook was triggered. Todo: Ensure that "ok" channel
|
||||||
|
// receives, also, indicating normal handling proceeded also.
|
||||||
|
t.Log("TestHook: Received ping, hook called and returned to normal execution!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second * 1):
|
||||||
|
{
|
||||||
|
t.Error("Failed to see evidence of ping hook being called after 2 seconds.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -251,6 +251,12 @@ func (s *Server) nodeByID(id string) *node {
|
|||||||
func (s *Server) handleQuery(source dHTAddr, m Msg) {
|
func (s *Server) handleQuery(source dHTAddr, m Msg) {
|
||||||
node := s.getNode(source, m.SenderID())
|
node := s.getNode(source, m.SenderID())
|
||||||
node.lastGotQuery = time.Now()
|
node.lastGotQuery = time.Now()
|
||||||
|
if s.config.OnQuery != nil {
|
||||||
|
propagate := s.config.OnQuery(&m, source.UDPAddr())
|
||||||
|
if !propagate {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
// Don't respond.
|
// Don't respond.
|
||||||
if s.config.Passive {
|
if s.config.Passive {
|
||||||
return
|
return
|
||||||
@ -340,6 +346,7 @@ func (s *Server) getNode(addr dHTAddr, id string) (n *node) {
|
|||||||
if len(s.nodes) >= maxNodes {
|
if len(s.nodes) >= maxNodes {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Exclude insecure nodes from the node table.
|
||||||
if !s.config.NoSecurity && !n.IsSecure() {
|
if !s.config.NoSecurity && !n.IsSecure() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user