/* Package net provides utility functions for working with IPs (net.IP). */ package net import ( "encoding/binary" "fmt" "math" "net" ) // IPVersion is version of IP address. type IPVersion string // Helper constants. const ( IPv4Uint32Count = 1 IPv6Uint32Count = 4 BitsPerUint32 = 32 BytePerUint32 = 4 IPv4 IPVersion = "IPv4" IPv6 IPVersion = "IPv6" ) // ErrInvalidBitPosition is returned when bits requested is not valid. var ErrInvalidBitPosition = fmt.Errorf("bit position not valid") // ErrVersionMismatch is returned upon mismatch in network input versions. var ErrVersionMismatch = fmt.Errorf("Network input version mismatch") // ErrNoGreatestCommonBit is an error returned when no greatest common bit // exists for the cidr ranges. var ErrNoGreatestCommonBit = fmt.Errorf("No greatest common bit") // NetworkNumber represents an IP address using uint32 as internal storage. // IPv4 usings 1 uint32, while IPv6 uses 4 uint32. type NetworkNumber []uint32 // NewNetworkNumber returns a equivalent NetworkNumber to given IP address, // return nil if ip is neither IPv4 nor IPv6. func NewNetworkNumber(ip net.IP) NetworkNumber { if ip == nil { return nil } coercedIP := ip.To4() parts := 1 if coercedIP == nil { coercedIP = ip.To16() parts = 4 } if coercedIP == nil { return nil } nn := make(NetworkNumber, parts) for i := 0; i < parts; i++ { idx := i * net.IPv4len nn[i] = binary.BigEndian.Uint32(coercedIP[idx : idx+net.IPv4len]) } return nn } // ToV4 returns ip address if ip is IPv4, returns nil otherwise. func (n NetworkNumber) ToV4() NetworkNumber { if len(n) != IPv4Uint32Count { return nil } return n } // ToV6 returns ip address if ip is IPv6, returns nil otherwise. func (n NetworkNumber) ToV6() NetworkNumber { if len(n) != IPv6Uint32Count { return nil } return n } // ToIP returns equivalent net.IP. func (n NetworkNumber) ToIP() net.IP { ip := make(net.IP, len(n)*BytePerUint32) for i := 0; i < len(n); i++ { idx := i * net.IPv4len binary.BigEndian.PutUint32(ip[idx:idx+net.IPv4len], n[i]) } if len(ip) == net.IPv4len { ip = net.IPv4(ip[0], ip[1], ip[2], ip[3]) } return ip } // Equal is the equality test for 2 network numbers. func (n NetworkNumber) Equal(n1 NetworkNumber) bool { if len(n) != len(n1) { return false } if n[0] != n1[0] { return false } if len(n) == IPv6Uint32Count { return n[1] == n1[1] && n[2] == n1[2] && n[3] == n1[3] } return true } // Next returns the next logical network number. func (n NetworkNumber) Next() NetworkNumber { newIP := make(NetworkNumber, len(n)) copy(newIP, n) for i := len(newIP) - 1; i >= 0; i-- { newIP[i]++ if newIP[i] > 0 { break } } return newIP } // Previous returns the previous logical network number. func (n NetworkNumber) Previous() NetworkNumber { newIP := make(NetworkNumber, len(n)) copy(newIP, n) for i := len(newIP) - 1; i >= 0; i-- { newIP[i]-- if newIP[i] < math.MaxUint32 { break } } return newIP } // Bit returns uint32 representing the bit value at given position, e.g., // "128.0.0.0" has bit value of 1 at position 31, and 0 for positions 30 to 0. func (n NetworkNumber) Bit(position uint) (uint32, error) { if int(position) > len(n)*BitsPerUint32-1 { return 0, ErrInvalidBitPosition } idx := len(n) - 1 - int(position/BitsPerUint32) // Mod 31 to get array index. rShift := position & (BitsPerUint32 - 1) return (n[idx] >> rShift) & 1, nil } // LeastCommonBitPosition returns the smallest position of the preceding common // bits of the 2 network numbers, and returns an error ErrNoGreatestCommonBit // if the two network number diverges from the first bit. // e.g., if the network number diverges after the 1st bit, it returns 131 for // IPv6 and 31 for IPv4 . func (n NetworkNumber) LeastCommonBitPosition(n1 NetworkNumber) (uint, error) { if len(n) != len(n1) { return 0, ErrVersionMismatch } for i := 0; i < len(n); i++ { mask := uint32(1) << 31 pos := uint(31) for ; mask > 0; mask >>= 1 { if n[i]&mask != n1[i]&mask { if i == 0 && pos == 31 { return 0, ErrNoGreatestCommonBit } return (pos + 1) + uint(BitsPerUint32)*uint(len(n)-i-1), nil } pos-- } } return 0, nil } // Network represents a block of network numbers, also known as CIDR. type Network struct { Number NetworkNumber Mask NetworkNumberMask } // NewNetwork returns Network built using given net.IPNet. func NewNetwork(ipNet net.IPNet) Network { ones, _ := ipNet.Mask.Size() return Network{ Number: NewNetworkNumber(ipNet.IP), Mask: NetworkNumberMask(ones), } } // Masked returns a new network conforming to new mask. func (n Network) Masked(ones int) Network { mask := NetworkNumberMask(ones) return Network{ Number: mask.Mask(n.Number), Mask: mask, } } func sub(a, b uint8) uint8 { res := a - b if res > a { res = 0 } return res } func mask(m NetworkNumberMask) (mask1, mask2, mask3, mask4 uint32) { // We're relying on overflow here. const ones uint32 = 0xFFFFFFFF mask1 = ones << sub(1*32, uint8(m)) mask2 = ones << sub(2*32, uint8(m)) mask3 = ones << sub(3*32, uint8(m)) mask4 = ones << sub(4*32, uint8(m)) return } // Contains returns true if NetworkNumber is in range of Network, false // otherwise. func (n Network) Contains(nn NetworkNumber) bool { if len(n.Number) != len(nn) { return false } const ones uint32 = 0xFFFFFFFF mask1, mask2, mask3, mask4 := mask(n.Mask) switch len(n.Number) { case IPv4Uint32Count: return nn[0]&mask1 == n.Number[0] case IPv6Uint32Count: return nn[0]&mask1 == n.Number[0] && nn[1]&mask2 == n.Number[1] && nn[2]&mask3 == n.Number[2] && nn[3]&mask4 == n.Number[3] default: return false } } // Contains returns true if Network covers o, false otherwise func (n Network) Covers(o Network) bool { return n.Contains(o.Number) && n.Mask <= o.Mask } // LeastCommonBitPosition returns the smallest position of the preceding common // bits of the 2 networks, and returns an error ErrNoGreatestCommonBit // if the two network number diverges from the first bit. func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) { maskSize := n.Mask if n1.Mask < n.Mask { maskSize = n1.Mask } maskPosition := len(n1.Number)*BitsPerUint32 - int(maskSize) lcb, err := n.Number.LeastCommonBitPosition(n1.Number) if err != nil { return 0, err } return uint(math.Max(float64(maskPosition), float64(lcb))), nil } // Equal is the equality test for 2 networks. func (n Network) Equal(n1 Network) bool { return n.Number.Equal(n1.Number) && n.Mask == n1.Mask } func (n Network) String() string { return fmt.Sprintf("%s/%d", n.Number.ToIP(), n.Mask) } func (n Network) IPNet() net.IPNet { return net.IPNet{ IP: n.Number.ToIP(), Mask: net.CIDRMask(int(n.Mask), len(n.Number)*32), } } // NetworkNumberMask is an IP address. type NetworkNumberMask int // Mask returns a new masked NetworkNumber from given NetworkNumber. func (m NetworkNumberMask) Mask(n NetworkNumber) NetworkNumber { mask1, mask2, mask3, mask4 := mask(m) result := make(NetworkNumber, len(n)) switch len(n) { case IPv4Uint32Count: result[0] = n[0] & mask1 case IPv6Uint32Count: result[0] = n[0] & mask1 result[1] = n[1] & mask2 result[2] = n[2] & mask3 result[3] = n[3] & mask4 } return result } // NextIP returns the next sequential ip. func NextIP(ip net.IP) net.IP { return NewNetworkNumber(ip).Next().ToIP() } // PreviousIP returns the previous sequential ip. func PreviousIP(ip net.IP) net.IP { return NewNetworkNumber(ip).Previous().ToIP() }