2022-04-06 10:36:06 -04:00

301 lines
7.4 KiB
Go

/*
Package net provides utility functions for working with IPs (net.IP).
*/
package net
import (
"encoding/binary"
"fmt"
"math"
"net"
)
// IPVersion names the version of an IP address ("IPv4" or "IPv6").
type IPVersion string

// Storage-layout and version constants shared across the package.
const (
	// IPv4Uint32Count is the number of uint32 words holding an IPv4 address.
	IPv4Uint32Count = 1
	// IPv6Uint32Count is the number of uint32 words holding an IPv6 address.
	IPv6Uint32Count = 4
	// BitsPerUint32 is the width of one storage word in bits.
	BitsPerUint32 = 32
	// BytePerUint32 is the width of one storage word in bytes.
	BytePerUint32 = 4

	// IPv4 identifies version 4 addresses.
	IPv4 IPVersion = "IPv4"
	// IPv6 identifies version 6 addresses.
	IPv6 IPVersion = "IPv6"
)

// Sentinel errors returned by this package; compare against them with ==.
var (
	// ErrInvalidBitPosition is returned when a requested bit position lies
	// outside the network number.
	ErrInvalidBitPosition = fmt.Errorf("bit position not valid")
	// ErrVersionMismatch is returned when two inputs are of different IP
	// versions (different word counts).
	ErrVersionMismatch = fmt.Errorf("Network input version mismatch")
	// ErrNoGreatestCommonBit is returned when two network numbers already
	// differ in their most significant bit, so they share no common prefix.
	ErrNoGreatestCommonBit = fmt.Errorf("No greatest common bit")
)

// NetworkNumber stores an IP address as big-endian uint32 words: IPv4 uses
// one word and IPv6 uses four, with word 0 the most significant.
type NetworkNumber []uint32
// NewNetworkNumber converts ip into a NetworkNumber. Addresses that have a
// 4-byte form (including IPv4-mapped IPv6 addresses) become a single word;
// other 16-byte addresses become four words. It returns nil when ip is nil or
// has a length that is neither IPv4 nor IPv6.
func NewNetworkNumber(ip net.IP) NetworkNumber {
	if ip == nil {
		return nil
	}
	words := IPv4Uint32Count
	raw := ip.To4()
	if raw == nil {
		words = IPv6Uint32Count
		if raw = ip.To16(); raw == nil {
			return nil
		}
	}
	nn := make(NetworkNumber, words)
	for i := range nn {
		off := i * BytePerUint32
		nn[i] = binary.BigEndian.Uint32(raw[off : off+BytePerUint32])
	}
	return nn
}
// ToV4 returns n unchanged when it holds exactly one word (an IPv4 address),
// and nil otherwise.
func (n NetworkNumber) ToV4() NetworkNumber {
	if len(n) == IPv4Uint32Count {
		return n
	}
	return nil
}
// ToV6 returns n unchanged when it holds exactly four words (an IPv6
// address), and nil otherwise.
func (n NetworkNumber) ToV6() NetworkNumber {
	if len(n) == IPv6Uint32Count {
		return n
	}
	return nil
}
// ToIP converts n back into a net.IP by writing each word big-endian. A
// one-word (IPv4) number is returned in net.IPv4's 16-byte representation;
// any other length is returned as the raw 4*len(n) bytes.
func (n NetworkNumber) ToIP() net.IP {
	buf := make(net.IP, BytePerUint32*len(n))
	for i, word := range n {
		binary.BigEndian.PutUint32(buf[i*BytePerUint32:], word)
	}
	if len(buf) != net.IPv4len {
		return buf
	}
	return net.IPv4(buf[0], buf[1], buf[2], buf[3])
}
// Equal reports whether n and n1 have the same length and identical words.
// Two empty (or nil) network numbers are equal; previously this indexed n[0]
// unconditionally and panicked on empty input, and for lengths other than one
// or four words it compared only the first word.
func (n NetworkNumber) Equal(n1 NetworkNumber) bool {
	if len(n) != len(n1) {
		return false
	}
	for i := range n {
		if n[i] != n1[i] {
			return false
		}
	}
	return true
}
// Next returns a new NetworkNumber one greater than n; n is not modified.
// Incrementing the highest address of a family wraps around to all zeros.
func (n NetworkNumber) Next() NetworkNumber {
	next := make(NetworkNumber, len(n))
	copy(next, n)
	// Add one to the least significant word and carry upward while a word
	// overflows to zero.
	i := len(next) - 1
	for i >= 0 {
		next[i]++
		if next[i] != 0 {
			break
		}
		i--
	}
	return next
}
// Previous returns a new NetworkNumber one less than n; n is not modified.
// Decrementing the all-zeros address wraps around to the highest address.
func (n NetworkNumber) Previous() NetworkNumber {
	prev := make(NetworkNumber, len(n))
	copy(prev, n)
	// Subtract one from the least significant word; a word that underflows to
	// MaxUint32 borrows from the next more significant word.
	for i := len(prev) - 1; i >= 0; i-- {
		prev[i]--
		if prev[i] != math.MaxUint32 {
			break
		}
	}
	return prev
}
// Bit returns the value (0 or 1) of the bit at position, counting from the
// least significant bit of the whole number as position 0. For example,
// "128.0.0.0" has a 1 at position 31 and 0 at positions 30 through 0.
// It returns ErrInvalidBitPosition when position is outside the number.
func (n NetworkNumber) Bit(position uint) (uint32, error) {
	totalBits := len(n) * BitsPerUint32
	if int(position) >= totalBits {
		return 0, ErrInvalidBitPosition
	}
	// Word 0 is the most significant, so low positions live in the last word.
	word := n[len(n)-1-int(position/BitsPerUint32)]
	// Offset of the bit within its 32-bit word.
	offset := position % BitsPerUint32
	return (word >> offset) & 1, nil
}
// LeastCommonBitPosition compares n and n1 from the most significant bit down
// and returns the lowest bit position of their shared leading prefix: every
// bit at that position or above is identical in both numbers. Positions count
// from the least significant bit (position 0). For example, numbers that agree
// on the first bit and differ on the second yield 31 for IPv4 and 127 for IPv6.
// Identical numbers yield 0.
//
// It returns ErrVersionMismatch if n and n1 have different lengths, and
// ErrNoGreatestCommonBit if they already differ in the most significant bit.
func (n NetworkNumber) LeastCommonBitPosition(n1 NetworkNumber) (uint, error) {
	if len(n) != len(n1) {
		return 0, ErrVersionMismatch
	}
	for i := range n {
		diff := n[i] ^ n1[i]
		if diff == 0 {
			continue
		}
		// Find the most significant differing bit in this word; diff != 0
		// guarantees the scan stops before bit underflows.
		bit := uint(BitsPerUint32 - 1)
		for diff&(uint32(1)<<bit) == 0 {
			bit--
		}
		if i == 0 && bit == BitsPerUint32-1 {
			return 0, ErrNoGreatestCommonBit
		}
		// Words after i sit below this one and contribute 32 positions each.
		wordsBelow := uint(len(n) - i - 1)
		return bit + 1 + wordsBelow*BitsPerUint32, nil
	}
	return 0, nil
}
// Network represents a block of network numbers, also known as CIDR.
type Network struct {
	// Number is the address the block is anchored at. Contains compares masked
	// addresses against it directly, so it is expected to have its host bits
	// cleared (as Masked produces).
	Number NetworkNumber
	// Mask is the prefix length in bits (e.g. 24 for a /24).
	Mask NetworkNumberMask
}
// NewNetwork builds a Network from ipNet, taking the prefix length from its
// mask. The address is stored as given; host bits are not cleared. A
// non-canonical mask makes net.IPMask.Size report 0, which is used as-is.
func NewNetwork(ipNet net.IPNet) Network {
	prefixLen, _ := ipNet.Mask.Size()
	var nw Network
	nw.Number = NewNetworkNumber(ipNet.IP)
	nw.Mask = NetworkNumberMask(prefixLen)
	return nw
}
// Masked returns a copy of n re-expressed with a prefix length of ones: the
// number keeps only its leading ones bits and the rest are cleared.
func (n Network) Masked(ones int) Network {
	prefix := NetworkNumberMask(ones)
	return Network{Number: prefix.Mask(n.Number), Mask: prefix}
}
// sub returns a - b, saturating at zero instead of wrapping when b > a.
func sub(a, b uint8) uint8 {
	if b > a {
		return 0
	}
	return a - b
}
// mask returns the per-word bit masks for a prefix of length m, from the most
// significant word (mask1) to the least significant (mask4).
func mask(m NetworkNumberMask) (mask1, mask2, mask3, mask4 uint32) {
	// Each word is shifted left by however many prefix bits it is missing.
	// sub clamps that count at zero (word fully inside the prefix), and Go
	// defines a uint32 shifted by 32 or more as 0 (word fully outside it).
	const allOnes = ^uint32(0)
	prefix := uint8(m)
	return allOnes << sub(32, prefix),
		allOnes << sub(64, prefix),
		allOnes << sub(96, prefix),
		allOnes << sub(128, prefix)
}
// Contains reports whether nn lies inside n: nn with n's prefix mask applied
// must equal n.Number word for word. It returns false when nn and n.Number
// differ in length (IPv4 vs IPv6) or have an unsupported length.
//
// NOTE(review): n.Number is compared without masking, so a Network whose
// Number still has host bits set will not match any address.
func (n Network) Contains(nn NetworkNumber) bool {
	if len(n.Number) != len(nn) {
		return false
	}
	// The unused local `ones` constant that used to live here was dead code.
	mask1, mask2, mask3, mask4 := mask(n.Mask)
	switch len(n.Number) {
	case IPv4Uint32Count:
		return nn[0]&mask1 == n.Number[0]
	case IPv6Uint32Count:
		return nn[0]&mask1 == n.Number[0] &&
			nn[1]&mask2 == n.Number[1] &&
			nn[2]&mask3 == n.Number[2] &&
			nn[3]&mask4 == n.Number[3]
	default:
		return false
	}
}
// Covers reports whether n fully contains the network o. The base address of
// o must lie inside n, and o's prefix must be at least as long as n's.
func (n Network) Covers(o Network) bool {
	// Both checks are side-effect free; the cheap prefix comparison goes first.
	return n.Mask <= o.Mask && n.Contains(o.Number)
}
// LeastCommonBitPosition returns the lowest bit position (counted from the
// least significant bit) at or above which the numbers of n and n1 agree. The
// result is never smaller than the number of host bits of the shorter of the
// two prefixes. Errors from NetworkNumber.LeastCommonBitPosition
// (ErrVersionMismatch, ErrNoGreatestCommonBit) are returned unchanged.
func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) {
	prefix := n.Mask
	if n1.Mask < prefix {
		prefix = n1.Mask
	}
	// Number of host bits below the shorter prefix.
	maskPosition := len(n1.Number)*BitsPerUint32 - int(prefix)
	lcb, err := n.Number.LeastCommonBitPosition(n1.Number)
	if err != nil {
		return 0, err
	}
	// Plain integer comparison; the values are at most 128, so no float64
	// round-trip through math.Max is needed.
	if maskPosition > int(lcb) {
		return uint(maskPosition), nil
	}
	return lcb, nil
}
// Equal reports whether n and n1 have equal numbers and equal prefix lengths.
func (n Network) Equal(n1 Network) bool {
	if !n.Number.Equal(n1.Number) {
		return false
	}
	return n.Mask == n1.Mask
}
// String formats the network in CIDR notation: the address as rendered by
// net.IP.String, a slash, and the decimal prefix length.
func (n Network) String() string {
	return n.Number.ToIP().String() + "/" + fmt.Sprint(int(n.Mask))
}
// IPNet converts n into the standard library's net.IPNet. The mask comes from
// net.CIDRMask, which yields nil when the prefix length is out of range for
// the address size.
func (n Network) IPNet() net.IPNet {
	addr := n.Number.ToIP()
	totalBits := len(n.Number) * BitsPerUint32
	return net.IPNet{
		IP:   addr,
		Mask: net.CIDRMask(int(n.Mask), totalBits),
	}
}
// NetworkNumberMask is the prefix length of a network in bits (e.g. 24 for a
// /24). It is not an address; per-word bit masks are derived from it when
// needed.
type NetworkNumberMask int
// Mask returns a new NetworkNumber equal to n with every bit after the first m
// bits cleared; n itself is not modified. Inputs that are neither one word
// (IPv4) nor four words (IPv6) yield an all-zero result of the same length.
func (m NetworkNumberMask) Mask(n NetworkNumber) NetworkNumber {
	m1, m2, m3, m4 := mask(m)
	wordMasks := [IPv6Uint32Count]uint32{m1, m2, m3, m4}
	out := make(NetworkNumber, len(n))
	switch len(n) {
	case IPv4Uint32Count, IPv6Uint32Count:
		for i, word := range n {
			out[i] = word & wordMasks[i]
		}
	}
	return out
}
// NextIP returns the address immediately after ip, wrapping from the highest
// address of its family to the lowest. An ip that is neither IPv4 nor IPv6
// yields an empty net.IP.
func NextIP(ip net.IP) net.IP {
	nn := NewNetworkNumber(ip)
	return nn.Next().ToIP()
}
// PreviousIP returns the address immediately before ip, wrapping from the
// lowest address of its family to the highest. An ip that is neither IPv4 nor
// IPv6 yields an empty net.IP.
func PreviousIP(ip net.IP) net.IP {
	nn := NewNetworkNumber(ip)
	return nn.Previous().ToIP()
}