2022-04-06 10:36:06 -04:00

405 lines
10 KiB
Go

package cidranger
import (
"fmt"
"net"
"strings"
rnet "github.com/libp2p/go-cidranger/net"
)
// prefixTrie is a path-compressed (PC) trie implementation of the
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents the current CIDR
// block.
//
// For IPv4, the trie structure guarantees a max depth of 32 as IPv4 addresses
// are 32 bits long and each bit represents a prefix tree starting at that bit.
// This property also guarantees constant lookup time in Big-O notation.
//
// Path compression compresses a string of nodes with only 1 child into a single
// node, decreasing the number of lookups necessary during containment tests.
//
// Level compression dictates the number of direct children of a node by
// allowing it to handle multiple bits in the path. The heuristic (based on
// children population) to decide when compression and decompression happen
// is outlined in the blog post linked above, and will be explored in more
// depth in this project in the future.
//
// Note: IPv4 and IPv6 network addresses can not be inserted into the same
// prefix trie; use the versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
// parent is the node this trie hangs under; it is nil for the root trie,
// which is how qualifiesForPathCompression recognizes the root.
parent *prefixTrie
// children holds the sub-tries, indexed by the value (0 or 1) of the
// network-number bit at targetBitPosition.
children [2]*prefixTrie
// numBitsSkipped is the number of leading network bits fixed by this node's
// prefix (network is masked to this length); the node branches on the next
// bit, see targetBitPosition.
numBitsSkipped uint
// numBitsHandled is set to 1 by newPrefixTree but is not read anywhere in
// this file. NOTE(review): presumably reserved for the level-compression TODO.
numBitsHandled uint
// network is the CIDR prefix represented by this node.
network rnet.Network
// entry is the RangerEntry stored for exactly network, or nil when this node
// only exists as a path prefix.
entry RangerEntry
size int // This is only maintained in the root trie.
}
// ip4ZeroCIDR and ip6ZeroCIDR are the all-covering networks (0.0.0.0/0 and
// ::/0) used as the root prefix of IPv4 and IPv6 tries respectively.
var ip4ZeroCIDR, ip6ZeroCIDR net.IPNet

func init() {
	// Both inputs are constant, well-formed CIDR literals, so the parse
	// errors are deliberately ignored.
	_, allIPv4, _ := net.ParseCIDR("0.0.0.0/0")
	_, allIPv6, _ := net.ParseCIDR("0::0/0")
	ip4ZeroCIDR, ip6ZeroCIDR = *allIPv4, *allIPv6
}
// newRanger returns an empty prefix trie for the given IP version, exposed
// through the Ranger interface.
func newRanger(version rnet.IPVersion) Ranger {
	var r Ranger = newPrefixTree(version)
	return r
}
// newPrefixTree creates an empty root prefixTrie spanning the entire address
// space of the requested version: ::/0 for rnet.IPv6 and 0.0.0.0/0 for any
// other value. The root fixes no prefix bits, so numBitsSkipped stays zero.
func newPrefixTree(version rnet.IPVersion) *prefixTrie {
	root := &prefixTrie{numBitsHandled: 1}
	switch version {
	case rnet.IPv6:
		root.network = rnet.NewNetwork(ip6ZeroCIDR)
	default:
		root.network = rnet.NewNetwork(ip4ZeroCIDR)
	}
	return root
}
// newPathprefixTrie creates an entry-less node whose prefix fixes the first
// numBitsSkipped bits of network (the network is masked down to that length).
// The IP version is inferred from the width of network.Number.
func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
	var node *prefixTrie
	if len(network.Number) == rnet.IPv6Uint32Count {
		node = newPrefixTree(rnet.IPv6)
	} else {
		node = newPrefixTree(rnet.IPv4)
	}
	node.numBitsSkipped = numBitsSkipped
	node.network = network.Masked(int(numBitsSkipped))
	return node
}
// newEntryTrie builds a node that represents exactly network (its prefix
// length equals the network's mask size) and stores entry on it.
func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
	prefixLen := uint(network.Mask)
	node := newPathprefixTrie(network, prefixLen)
	node.entry = entry
	return node
}
// Insert adds entry to the trie under entry.Network(). Inserting a network
// that is already present replaces its entry without changing Len.
func (p *prefixTrie) Insert(entry RangerEntry) error {
	added, err := p.insert(rnet.NewNetwork(entry.Network()), entry)
	if added {
		p.size++
	}
	return err
}
// Remove deletes the entry stored for exactly network and returns it. The
// returned entry is nil when no such network was inserted, in which case Len
// is unchanged.
func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
	removed, err := p.remove(rnet.NewNetwork(network))
	if removed != nil {
		p.size--
	}
	return removed, err
}
// Contains reports whether ip falls inside any inserted network. It returns
// ErrInvalidNetworkNumberInput when ip cannot be converted to a network
// number.
func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
	if number := rnet.NewNetworkNumber(ip); number != nil {
		return p.contains(number)
	}
	return false, ErrInvalidNetworkNumberInput
}
// ContainingNetworks returns every inserted RangerEntry whose network contains
// ip, ordered from the shortest prefix to the longest. It returns
// ErrInvalidNetworkNumberInput when ip cannot be converted to a network
// number.
func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
	if number := rnet.NewNetworkNumber(ip); number != nil {
		return p.containingNetworks(number)
	}
	return nil, ErrInvalidNetworkNumberInput
}
// CoveredNetworks returns the RangerEntry(s) whose networks are completely
// subsumed by the given network.
func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
	// Converted inline instead of via a local named "net", which would shadow
	// the net package inside this method.
	return p.coveredNetworks(rnet.NewNetwork(network))
}
// Len reports how many networks are stored in the ranger. The counter is kept
// on the root trie only.
func (p *prefixTrie) Len() int { return p.size }
// String renders the trie rooted at p as an indented tree, one node per line,
// for visualization and debugging. Each child line is prefixed with its
// branch bit (0 or 1).
func (p *prefixTrie) String() string {
	var out strings.Builder
	fmt.Fprintf(&out, "%s (target_pos:%d:has_entry:%t)", p.network,
		p.targetBitPosition(), p.hasEntry())
	indent := strings.Repeat("| ", p.level()+1)
	for bit, child := range p.children {
		if child != nil {
			fmt.Fprintf(&out, "\n%s%d--> %s", indent, bit, child.String())
		}
	}
	return out.String()
}
// contains reports whether number lies in any entry-bearing network on the
// path from p downward. The walk follows number's bit at each node's target
// position and stops at the first node holding an entry.
func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
	for node := p; node != nil; {
		if !node.network.Contains(number) {
			return false, nil
		}
		if node.hasEntry() {
			return true, nil
		}
		// A node with no target bit has nothing further to branch on.
		if node.targetBitPosition() < 0 {
			return false, nil
		}
		bit, err := node.targetBitFromIP(number)
		if err != nil {
			return false, err
		}
		node = node.children[bit]
	}
	return false, nil
}
// containingNetworks collects the entries of every node on the path from p
// downward whose network contains number, in root-to-leaf (ascending prefix)
// order. The result is an empty, non-nil slice when nothing matches, and nil
// when an error occurs.
func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
	matches := []RangerEntry{}
	for node := p; node != nil; {
		if !node.network.Contains(number) {
			break
		}
		if node.hasEntry() {
			matches = append(matches, node.entry)
		}
		// A node with no target bit has nothing further to branch on.
		if node.targetBitPosition() < 0 {
			break
		}
		bit, err := node.targetBitFromIP(number)
		if err != nil {
			return nil, err
		}
		node = node.children[bit]
	}
	return matches, nil
}
// coveredNetworks returns every entry whose network lies entirely within
// network. It descends along network's bits until it reaches a node whose
// prefix is covered by network, then gathers that node's whole subtree.
// The result is nil when nothing is covered.
func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
	if network.Covers(p.network) {
		// The whole subtree is covered. Gather it with a plain synchronous
		// recursion; a read-only traversal needs no goroutines or channels.
		return p.appendSubtreeEntries(nil), nil
	}
	if p.targetBitPosition() < 0 {
		return nil, nil
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	if child := p.children[bit]; child != nil {
		return child.coveredNetworks(network)
	}
	return nil, nil
}

// appendSubtreeEntries appends the entries of the subtree rooted at p to dst
// in depth-first pre-order (the node's own entry, then children[0]'s subtree,
// then children[1]'s) and returns the extended slice.
func (p *prefixTrie) appendSubtreeEntries(dst []RangerEntry) []RangerEntry {
	if p.hasEntry() {
		dst = append(dst, p.entry)
	}
	for _, child := range p.children {
		if child != nil {
			dst = child.appendSubtreeEntries(dst)
		}
	}
	return dst
}
// insert stores entry for network in the subtree rooted at p, creating a leaf
// or an intermediate path-prefix node as needed. It reports whether a new
// entry was added; it returns false when an existing entry for network was
// replaced or when an error occurred.
func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) {
	if p.network.Equal(network) {
		sizeIncreased := p.entry == nil
		p.entry = entry
		return sizeIncreased, nil
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return false, err
	}
	existingChild := p.children[bit]
	// No existing child, insert new leaf trie.
	if existingChild == nil {
		p.appendTrie(bit, newEntryTrie(network, entry))
		return true, nil
	}
	// Check whether it is necessary to insert an additional path prefix between
	// the current trie and the existing child, in the case that the inserted
	// network diverges on its path to the existing child.
	lcb, err := network.LeastCommonBitPosition(existingChild.network)
	if err != nil {
		// Previously ignored. Continuing would treat lcb as 0 and descend into
		// a child the network does not belong to (NOTE(review): presumably an
		// IPv4/IPv6 mix, which this trie does not support).
		return false, err
	}
	divergingBitPos := int(lcb) - 1
	if divergingBitPos > existingChild.targetBitPosition() {
		pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
		if err := p.insertPrefix(bit, pathPrefix, existingChild); err != nil {
			return false, err
		}
		// The new path prefix is now the child to descend into.
		existingChild = pathPrefix
	}
	return existingChild.insert(network, entry)
}
// appendTrie attaches prefix as p's child in the slot selected by bit and
// points prefix back at p.
func (p *prefixTrie) appendTrie(bit uint32, prefix *prefixTrie) {
	p.children[bit], prefix.parent = prefix, p
}
// insertPrefix splices pathPrefix between p and child, which currently sits in
// p.children[bit]. Afterwards p.children[bit] is pathPrefix, child hangs under
// pathPrefix in the slot chosen by child's bit at pathPrefix's target position,
// and both parent pointers are updated.
//
// The slot lookup runs before any link is changed. An error therefore leaves
// the trie untouched instead of leaving child's subtree detached.
func (p *prefixTrie) insertPrefix(bit uint32, pathPrefix, child *prefixTrie) error {
	pathPrefixBit, err := pathPrefix.targetBitFromIP(child.network.Number)
	if err != nil {
		return err
	}
	// Set parent/child relationship between current trie and inserted pathPrefix.
	p.children[bit] = pathPrefix
	pathPrefix.parent = p
	// Set parent/child relationship between inserted pathPrefix and original child.
	pathPrefix.children[pathPrefixBit] = child
	child.parent = pathPrefix
	return nil
}
// remove deletes the entry stored for exactly network from the subtree rooted
// at p and returns it. It returns a nil entry and nil error when network is
// not present. After removal the emptied node is path-compressed away when
// possible.
func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
	if p.hasEntry() && p.network.Equal(network) {
		entry := p.entry
		p.entry = nil
		if err := p.compressPathIfPossible(); err != nil {
			return nil, err
		}
		return entry, nil
	}
	// A node whose prefix already uses every address bit (e.g. a /32 leaf) has
	// no target bit to branch on. Without this guard, removing an absent
	// network that reaches such a leaf would pass a negative position to
	// targetBitFromIP. contains, containingNetworks and coveredNetworks
	// already guard the same way.
	if p.targetBitPosition() < 0 {
		return nil, nil
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	if child := p.children[bit]; child != nil {
		return child.remove(network)
	}
	return nil, nil
}
// qualifiesForPathCompression reports whether p carries no information of its
// own and can therefore be spliced out of the trie.
func (p *prefixTrie) qualifiesForPathCompression() bool {
	switch {
	case p.hasEntry():
		return false // it records a CIDR entry
	case p.childrenCount() > 1:
		return false // it is a genuine branch point
	case p.parent == nil:
		return false // it is the root trie
	default:
		return true
	}
}
// compressPathIfPossible splices p out of the trie when it qualifies for path
// compression (see qualifiesForPathCompression). It walks up to the nearest
// ancestor that does not qualify itself and attaches p's lone child (or nil
// when p is childless) to that ancestor in p's place. Any qualifying ancestors
// in between are dropped as well. Compression is then retried at that
// ancestor, since losing p may have made it compressible too.
func (p *prefixTrie) compressPathIfPossible() error {
	if !p.qualifiesForPathCompression() {
		// Does not qualify to be compressed.
		return nil
	}

	// Find the lone child, if any.
	var loneChild *prefixTrie
	for _, child := range p.children {
		if child != nil {
			loneChild = child
			break
		}
	}

	// Find the root of the current single-child lineage. The root trie never
	// qualifies, so this loop always terminates on a non-nil node.
	parent := p.parent
	for parent.qualifiesForPathCompression() {
		parent = parent.parent
	}
	parentBit, err := parent.targetBitFromIP(p.network.Number)
	if err != nil {
		return err
	}
	parent.children[parentBit] = loneChild
	if loneChild != nil {
		// Keep the back-pointer consistent with the new link. Otherwise the
		// child keeps pointing at the spliced-out p: level() (and String()'s
		// indentation) over-count, and detached nodes stay reachable.
		loneChild.parent = parent
	}

	// Attempt to further apply path compression at the lineage parent, in case
	// the current lineage was compressed into it.
	return parent.compressPathIfPossible()
}
// childrenCount returns how many of p's child slots are occupied (0, 1 or 2).
func (p *prefixTrie) childrenCount() int {
	occupied := 0
	for i := range p.children {
		if p.children[i] != nil {
			occupied++
		}
	}
	return occupied
}
// totalNumberOfBits returns the address width, in bits, of this node's
// network: rnet.BitsPerUint32 bits per uint32 word of the network number.
func (p *prefixTrie) totalNumberOfBits() uint {
	words := uint(len(p.network.Number))
	return words * rnet.BitsPerUint32
}
// targetBitPosition returns the position of the bit this node branches on,
// counted from the least significant bit. It is -1 when the node's prefix
// already spans every address bit.
func (p *prefixTrie) targetBitPosition() int {
	remaining := p.totalNumberOfBits() - p.numBitsSkipped
	return int(remaining) - 1
}
// targetBitFromIP returns the bit of n at this node's target position, which
// selects the child slot n belongs to.
func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
	pos := p.targetBitPosition()
	// The int-to-uint conversion is unchecked; callers must not use this on a
	// node whose target position is negative.
	return n.Bit(uint(pos))
}
// hasEntry reports whether a RangerEntry is stored on this node.
func (p *prefixTrie) hasEntry() bool { return p.entry != nil }
// level returns the number of parent links between p and the root (0 for the
// root itself).
func (p *prefixTrie) level() int {
	depth := 0
	for ancestor := p.parent; ancestor != nil; ancestor = ancestor.parent {
		depth++
	}
	return depth
}
// walkDepth streams the entries of the subtree rooted at p over the returned
// channel in depth-first pre-order: the node's own entry, then children[0]'s
// subtree, then children[1]'s. The channel is closed once the walk completes.
// It is intended for unit testing.
//
// The channel is unbuffered and the walking goroutine only exits after every
// entry has been received, so callers must drain it.
func (p *prefixTrie) walkDepth() <-chan RangerEntry {
	entries := make(chan RangerEntry)
	go func() {
		defer close(entries)
		// A single goroutine visiting recursively yields the same order as one
		// goroutine and channel per node, without copying each entry through
		// every ancestor's channel.
		var visit func(node *prefixTrie)
		visit = func(node *prefixTrie) {
			if node.hasEntry() {
				entries <- node.entry
			}
			for _, child := range node.children {
				if child != nil {
					visit(child)
				}
			}
		}
		visit(p)
	}()
	return entries
}