301 lines
7.4 KiB
Go
301 lines
7.4 KiB
Go
/*
|
|
Package net provides utility functions for working with IPs (net.IP).
|
|
*/
|
|
package net
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"math"
|
|
"net"
|
|
)
|
|
|
|
// IPVersion names the IP protocol family of an address, e.g. "IPv4" or "IPv6".
type IPVersion string
|
|
|
|
// Helper constants.
|
|
const (
|
|
IPv4Uint32Count = 1
|
|
IPv6Uint32Count = 4
|
|
|
|
BitsPerUint32 = 32
|
|
BytePerUint32 = 4
|
|
|
|
IPv4 IPVersion = "IPv4"
|
|
IPv6 IPVersion = "IPv6"
|
|
)
|
|
|
|
// Sentinel errors returned by this package; compare against them with == or
// errors.Is.
var (
	// ErrInvalidBitPosition is returned when a requested bit position lies
	// outside the bit width of a network number.
	ErrInvalidBitPosition = fmt.Errorf("bit position not valid")

	// ErrVersionMismatch is returned when two network number inputs differ in
	// length, i.e. one is IPv4 and the other IPv6.
	ErrVersionMismatch = fmt.Errorf("Network input version mismatch")

	// ErrNoGreatestCommonBit is returned when two network numbers already
	// differ in their most significant bit, so they share no common prefix.
	ErrNoGreatestCommonBit = fmt.Errorf("No greatest common bit")
)
|
|
|
|
// NetworkNumber represents an IP address as a slice of uint32 words, most
// significant word first. An IPv4 address uses one word and an IPv6 address
// uses four.
type NetworkNumber []uint32
|
|
|
|
// NewNetworkNumber returns a equivalent NetworkNumber to given IP address,
|
|
// return nil if ip is neither IPv4 nor IPv6.
|
|
func NewNetworkNumber(ip net.IP) NetworkNumber {
|
|
if ip == nil {
|
|
return nil
|
|
}
|
|
coercedIP := ip.To4()
|
|
parts := 1
|
|
if coercedIP == nil {
|
|
coercedIP = ip.To16()
|
|
parts = 4
|
|
}
|
|
if coercedIP == nil {
|
|
return nil
|
|
}
|
|
nn := make(NetworkNumber, parts)
|
|
for i := 0; i < parts; i++ {
|
|
idx := i * net.IPv4len
|
|
nn[i] = binary.BigEndian.Uint32(coercedIP[idx : idx+net.IPv4len])
|
|
}
|
|
return nn
|
|
}
|
|
|
|
// ToV4 returns ip address if ip is IPv4, returns nil otherwise.
|
|
func (n NetworkNumber) ToV4() NetworkNumber {
|
|
if len(n) != IPv4Uint32Count {
|
|
return nil
|
|
}
|
|
return n
|
|
}
|
|
|
|
// ToV6 returns ip address if ip is IPv6, returns nil otherwise.
|
|
func (n NetworkNumber) ToV6() NetworkNumber {
|
|
if len(n) != IPv6Uint32Count {
|
|
return nil
|
|
}
|
|
return n
|
|
}
|
|
|
|
// ToIP returns equivalent net.IP.
|
|
func (n NetworkNumber) ToIP() net.IP {
|
|
ip := make(net.IP, len(n)*BytePerUint32)
|
|
for i := 0; i < len(n); i++ {
|
|
idx := i * net.IPv4len
|
|
binary.BigEndian.PutUint32(ip[idx:idx+net.IPv4len], n[i])
|
|
}
|
|
if len(ip) == net.IPv4len {
|
|
ip = net.IPv4(ip[0], ip[1], ip[2], ip[3])
|
|
}
|
|
return ip
|
|
}
|
|
|
|
// Equal is the equality test for 2 network numbers.
|
|
func (n NetworkNumber) Equal(n1 NetworkNumber) bool {
|
|
if len(n) != len(n1) {
|
|
return false
|
|
}
|
|
if n[0] != n1[0] {
|
|
return false
|
|
}
|
|
if len(n) == IPv6Uint32Count {
|
|
return n[1] == n1[1] && n[2] == n1[2] && n[3] == n1[3]
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Next returns the next logical network number.
|
|
func (n NetworkNumber) Next() NetworkNumber {
|
|
newIP := make(NetworkNumber, len(n))
|
|
copy(newIP, n)
|
|
for i := len(newIP) - 1; i >= 0; i-- {
|
|
newIP[i]++
|
|
if newIP[i] > 0 {
|
|
break
|
|
}
|
|
}
|
|
return newIP
|
|
}
|
|
|
|
// Previous returns the previous logical network number.
|
|
func (n NetworkNumber) Previous() NetworkNumber {
|
|
newIP := make(NetworkNumber, len(n))
|
|
copy(newIP, n)
|
|
for i := len(newIP) - 1; i >= 0; i-- {
|
|
newIP[i]--
|
|
if newIP[i] < math.MaxUint32 {
|
|
break
|
|
}
|
|
}
|
|
return newIP
|
|
}
|
|
|
|
// Bit returns uint32 representing the bit value at given position, e.g.,
|
|
// "128.0.0.0" has bit value of 1 at position 31, and 0 for positions 30 to 0.
|
|
func (n NetworkNumber) Bit(position uint) (uint32, error) {
|
|
if int(position) > len(n)*BitsPerUint32-1 {
|
|
return 0, ErrInvalidBitPosition
|
|
}
|
|
idx := len(n) - 1 - int(position/BitsPerUint32)
|
|
// Mod 31 to get array index.
|
|
rShift := position & (BitsPerUint32 - 1)
|
|
return (n[idx] >> rShift) & 1, nil
|
|
}
|
|
|
|
// LeastCommonBitPosition returns the smallest position of the preceding common
|
|
// bits of the 2 network numbers, and returns an error ErrNoGreatestCommonBit
|
|
// if the two network number diverges from the first bit.
|
|
// e.g., if the network number diverges after the 1st bit, it returns 131 for
|
|
// IPv6 and 31 for IPv4 .
|
|
func (n NetworkNumber) LeastCommonBitPosition(n1 NetworkNumber) (uint, error) {
|
|
if len(n) != len(n1) {
|
|
return 0, ErrVersionMismatch
|
|
}
|
|
for i := 0; i < len(n); i++ {
|
|
mask := uint32(1) << 31
|
|
pos := uint(31)
|
|
for ; mask > 0; mask >>= 1 {
|
|
if n[i]&mask != n1[i]&mask {
|
|
if i == 0 && pos == 31 {
|
|
return 0, ErrNoGreatestCommonBit
|
|
}
|
|
return (pos + 1) + uint(BitsPerUint32)*uint(len(n)-i-1), nil
|
|
}
|
|
pos--
|
|
}
|
|
}
|
|
return 0, nil
|
|
}
|
|
|
|
// Network represents a block of network numbers, also known as CIDR.
|
|
type Network struct {
|
|
Number NetworkNumber
|
|
Mask NetworkNumberMask
|
|
}
|
|
|
|
// NewNetwork returns Network built using given net.IPNet.
|
|
func NewNetwork(ipNet net.IPNet) Network {
|
|
ones, _ := ipNet.Mask.Size()
|
|
return Network{
|
|
Number: NewNetworkNumber(ipNet.IP),
|
|
Mask: NetworkNumberMask(ones),
|
|
}
|
|
}
|
|
|
|
// Masked returns a new network conforming to new mask.
|
|
func (n Network) Masked(ones int) Network {
|
|
mask := NetworkNumberMask(ones)
|
|
return Network{
|
|
Number: mask.Mask(n.Number),
|
|
Mask: mask,
|
|
}
|
|
}
|
|
|
|
// sub returns a - b, saturating at 0 instead of wrapping around when b > a.
func sub(a, b uint8) uint8 {
	if b > a {
		return 0
	}
	return a - b
}
|
|
|
|
func mask(m NetworkNumberMask) (mask1, mask2, mask3, mask4 uint32) {
|
|
// We're relying on overflow here.
|
|
const ones uint32 = 0xFFFFFFFF
|
|
mask1 = ones << sub(1*32, uint8(m))
|
|
mask2 = ones << sub(2*32, uint8(m))
|
|
mask3 = ones << sub(3*32, uint8(m))
|
|
mask4 = ones << sub(4*32, uint8(m))
|
|
return
|
|
}
|
|
|
|
// Contains returns true if NetworkNumber is in range of Network, false
|
|
// otherwise.
|
|
func (n Network) Contains(nn NetworkNumber) bool {
|
|
if len(n.Number) != len(nn) {
|
|
return false
|
|
}
|
|
const ones uint32 = 0xFFFFFFFF
|
|
|
|
mask1, mask2, mask3, mask4 := mask(n.Mask)
|
|
switch len(n.Number) {
|
|
case IPv4Uint32Count:
|
|
return nn[0]&mask1 == n.Number[0]
|
|
case IPv6Uint32Count:
|
|
return nn[0]&mask1 == n.Number[0] &&
|
|
nn[1]&mask2 == n.Number[1] &&
|
|
nn[2]&mask3 == n.Number[2] &&
|
|
nn[3]&mask4 == n.Number[3]
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Contains returns true if Network covers o, false otherwise
|
|
func (n Network) Covers(o Network) bool {
|
|
return n.Contains(o.Number) && n.Mask <= o.Mask
|
|
}
|
|
|
|
// LeastCommonBitPosition returns the smallest position of the preceding common
|
|
// bits of the 2 networks, and returns an error ErrNoGreatestCommonBit
|
|
// if the two network number diverges from the first bit.
|
|
func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) {
|
|
maskSize := n.Mask
|
|
if n1.Mask < n.Mask {
|
|
maskSize = n1.Mask
|
|
}
|
|
maskPosition := len(n1.Number)*BitsPerUint32 - int(maskSize)
|
|
lcb, err := n.Number.LeastCommonBitPosition(n1.Number)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return uint(math.Max(float64(maskPosition), float64(lcb))), nil
|
|
}
|
|
|
|
// Equal is the equality test for 2 networks.
|
|
func (n Network) Equal(n1 Network) bool {
|
|
return n.Number.Equal(n1.Number) && n.Mask == n1.Mask
|
|
}
|
|
|
|
func (n Network) String() string {
|
|
return fmt.Sprintf("%s/%d", n.Number.ToIP(), n.Mask)
|
|
}
|
|
|
|
func (n Network) IPNet() net.IPNet {
|
|
return net.IPNet{
|
|
IP: n.Number.ToIP(),
|
|
Mask: net.CIDRMask(int(n.Mask), len(n.Number)*32),
|
|
}
|
|
}
|
|
|
|
// NetworkNumberMask is a CIDR prefix length: the number of leading bits of a
// NetworkNumber that identify the network (for example 24 for a /24 block).
type NetworkNumberMask int
|
|
|
|
// Mask returns a new masked NetworkNumber from given NetworkNumber.
|
|
func (m NetworkNumberMask) Mask(n NetworkNumber) NetworkNumber {
|
|
mask1, mask2, mask3, mask4 := mask(m)
|
|
|
|
result := make(NetworkNumber, len(n))
|
|
switch len(n) {
|
|
case IPv4Uint32Count:
|
|
result[0] = n[0] & mask1
|
|
case IPv6Uint32Count:
|
|
result[0] = n[0] & mask1
|
|
result[1] = n[1] & mask2
|
|
result[2] = n[2] & mask3
|
|
result[3] = n[3] & mask4
|
|
}
|
|
return result
|
|
}
|
|
|
|
// NextIP returns the next sequential ip.
|
|
func NextIP(ip net.IP) net.IP {
|
|
return NewNetworkNumber(ip).Next().ToIP()
|
|
}
|
|
|
|
// PreviousIP returns the previous sequential ip.
|
|
func PreviousIP(ip net.IP) net.IP {
|
|
return NewNetworkNumber(ip).Previous().ToIP()
|
|
}
|