diff --git a/discover/table.go b/discover/table.go index 2f61d99..f721bcb 100644 --- a/discover/table.go +++ b/discover/table.go @@ -79,6 +79,8 @@ type Table struct { closeReq chan struct{} closed chan struct{} + nodeIsValidFn func(enode.Node) bool + nodeAddedHook func(*node) // for testing } @@ -99,17 +101,18 @@ type bucket struct { ips netutil.DistinctNetSet } -func newTable(t transport, db *enode.DB, bootnodes []*enode.Node, log log.Logger) (*Table, error) { +func newTable(t transport, db *enode.DB, bootnodes []*enode.Node, nodeIsValidFn func(enode.Node) bool, log log.Logger) (*Table, error) { tab := &Table{ - net: t, - db: db, - refreshReq: make(chan chan struct{}), - initDone: make(chan struct{}), - closeReq: make(chan struct{}), - closed: make(chan struct{}), - rand: mrand.New(mrand.NewSource(0)), - ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, - log: log, + net: t, + db: db, + refreshReq: make(chan chan struct{}), + initDone: make(chan struct{}), + closeReq: make(chan struct{}), + closed: make(chan struct{}), + rand: mrand.New(mrand.NewSource(0)), + ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, + nodeIsValidFn: nodeIsValidFn, + log: log, } if err := tab.setFallbackNodes(bootnodes); err != nil { return nil, err @@ -513,6 +516,10 @@ func (tab *Table) addVerifiedNode(n *node) { return } + if tab.nodeIsValidFn != nil && !tab.nodeIsValidFn(n.Node) { + return + } + tab.mutex.Lock() defer tab.mutex.Unlock() b := tab.bucket(n.ID()) diff --git a/discover/table_test.go b/discover/table_test.go index 5f40c96..cbffdef 100644 --- a/discover/table_test.go +++ b/discover/table_test.go @@ -314,8 +314,18 @@ func TestTable_addVerifiedNode(t *testing.T) { // Insert two nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) + n3 := nodeAtDistance(tab.self().ID(), 256, net.IP{66, 77, 88, 3}) + + // Check if node is valid before adding it + validFN := func(node enode.Node) bool { + return !node.IP().Equal(n3.IP()) // Node 3 is invalid + } + + tab.nodeIsValidFn = validFN + tab.addSeenNode(n1) tab.addSeenNode(n2) + tab.addVerifiedNode(n3) // Verify bucket content: bcontent := []*node{n1, n2} @@ -343,14 +353,24 @@ func TestTable_addSeenNode(t *testing.T) { defer db.Close() defer tab.close() - // Insert two nodes. + // Insert three nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) + n3 := nodeAtDistance(tab.self().ID(), 256, net.IP{66, 77, 88, 3}) + + // Check if node is valid before adding it + validFN := func(node enode.Node) bool { + return !node.IP().Equal(n3.IP()) // Node 3 is invalid + } + + tab.nodeIsValidFn = validFN + tab.addSeenNode(n1) tab.addSeenNode(n2) + tab.addSeenNode(n3) // Verify bucket content: - bcontent := []*node{n1, n2} + bcontent := []*node{n1, n2} // n3 shouldnt have been added if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) { t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries) } diff --git a/discover/v4_udp.go b/discover/v4_udp.go index 238fa09..095006c 100644 --- a/discover/v4_udp.go +++ b/discover/v4_udp.go @@ -142,7 +142,7 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { log: cfg.Log, } - tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log) + tab, err := newTable(t, ln.Database(), cfg.Bootnodes, cfg.ValidNodeFn, t.log) if err != nil { return nil, err } diff --git a/discover/v5_udp.go b/discover/v5_udp.go index 1e3bf4f..4dadc41 100644 --- a/discover/v5_udp.go +++ b/discover/v5_udp.go @@ -164,7 +164,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { closeCtx: closeCtx, cancelCloseCtx: cancelCloseCtx, } - tab, err := newTable(t, t.db, cfg.Bootnodes, cfg.Log) + tab, err := newTable(t, t.db, cfg.Bootnodes, cfg.ValidNodeFn, cfg.Log) if err != nil { return nil, err }