// Copyright 2012 Google, Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file in the root of the source
// tree.

// Originally found in
// https://github.com/google/gopacket/blob/master/routing/routing.go
//   - Route selection modified to choose the most specific (longest-prefix)
//     matching route, since route priority alone is not enough to pick one.
|
|
package netroute
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
)
|
|
|
|
// routeInfoInMemory is the Go counterpart of 'struct rtmsg', the fixed header
// described in the RTM_NEWROUTE section of
// http://man7.org/linux/man-pages/man7/rtnetlink.7.html.
// NOTE(review): presumably decoded from raw message bytes, so keep field
// order and sizes in sync with rtmsg.
type routeInfoInMemory struct {
	// rtm_family, rtm_dst_len, rtm_src_len, rtm_tos
	Family, DstLen, SrcLen, TOS byte

	// rtm_table, rtm_protocol, rtm_scope, rtm_type
	Table, Protocol, Scope, Type byte

	// rtm_flags
	Flags uint32
}
|
|
|
|
// rtInfo describes a single route.
type rtInfo struct {
	// Src and Dst are the source and destination prefixes the route applies
	// to; route lookups treat a nil prefix as matching any address.
	Src *net.IPNet
	Dst *net.IPNet

	// Gateway is the next hop and PrefSrc the preferred source address
	// recorded on the route; either may be nil.
	Gateway net.IP
	PrefSrc net.IP

	// InputIface and OutputIface are interface indices. Route lookups treat
	// an InputIface of 0 as "any input interface".
	InputIface  uint32
	OutputIface uint32

	// Priority is the key routeSlice sorts on (ascending).
	Priority uint32
}

// routeSlice is a list of routes that satisfies sort.Interface, ordering
// routes by ascending Priority.
type routeSlice []*rtInfo

// Len reports how many routes the slice holds.
func (rs routeSlice) Len() int { return len(rs) }

// Less reports whether route i has a strictly lower Priority than route j.
func (rs routeSlice) Less(i, j int) bool {
	a, b := rs[i], rs[j]
	return a.Priority < b.Priority
}

// Swap exchanges the routes at positions i and j.
func (rs routeSlice) Swap(i, j int) {
	tmp := rs[i]
	rs[i] = rs[j]
	rs[j] = tmp
}
|
|
|
|
type router struct {
|
|
ifaces map[int]net.Interface
|
|
addrs map[int]ipAddrs
|
|
v4, v6 routeSlice
|
|
}
|
|
|
|
func (r *router) String() string {
|
|
strs := []string{"ROUTER", "--- V4 ---"}
|
|
for _, route := range r.v4 {
|
|
strs = append(strs, fmt.Sprintf("%+v", *route))
|
|
}
|
|
strs = append(strs, "--- V6 ---")
|
|
for _, route := range r.v6 {
|
|
strs = append(strs, fmt.Sprintf("%+v", *route))
|
|
}
|
|
return strings.Join(strs, "\n")
|
|
}
|
|
|
|
// ipAddrs holds one IPv4 and one IPv6 address for an interface. RouteWithSrc
// falls back to these when a route carries no preferred source address.
type ipAddrs struct {
	v4 net.IP // fallback source for IPv4 destinations
	v6 net.IP // fallback source for IPv6 destinations
}
|
|
|
|
func (r *router) Route(dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
|
|
return r.RouteWithSrc(nil, nil, dst)
|
|
}
|
|
|
|
func (r *router) RouteWithSrc(input net.HardwareAddr, src, dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
|
|
var ifaceIndex int
|
|
switch {
|
|
case dst.To4() != nil:
|
|
ifaceIndex, gateway, preferredSrc, err = r.route(r.v4, input, src, dst)
|
|
case dst.To16() != nil:
|
|
ifaceIndex, gateway, preferredSrc, err = r.route(r.v6, input, src, dst)
|
|
default:
|
|
err = errors.New("IP is not valid as IPv4 or IPv6")
|
|
return
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Interfaces are 1-indexed, but we store them in a 0-indexed array.
|
|
correspondingIface, ok := r.ifaces[ifaceIndex]
|
|
if !ok {
|
|
err = errors.New("Route refereced unknown interface")
|
|
}
|
|
iface = &correspondingIface
|
|
|
|
if preferredSrc == nil {
|
|
switch {
|
|
case dst.To4() != nil:
|
|
preferredSrc = r.addrs[ifaceIndex].v4
|
|
case dst.To16() != nil:
|
|
preferredSrc = r.addrs[ifaceIndex].v6
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (r *router) route(routes routeSlice, input net.HardwareAddr, src, dst net.IP) (iface int, gateway, preferredSrc net.IP, err error) {
|
|
var inputIndex uint32
|
|
if input != nil {
|
|
for i, iface := range r.ifaces {
|
|
if bytes.Equal(input, iface.HardwareAddr) {
|
|
// Convert from zero- to one-indexed.
|
|
inputIndex = uint32(i + 1)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
var mostSpecificRt *rtInfo
|
|
for _, rt := range routes {
|
|
if rt.InputIface != 0 && rt.InputIface != inputIndex {
|
|
continue
|
|
}
|
|
if src != nil && rt.Src != nil && !rt.Src.Contains(src) {
|
|
continue
|
|
}
|
|
if rt.Dst != nil && !rt.Dst.Contains(dst) {
|
|
continue
|
|
}
|
|
if mostSpecificRt != nil {
|
|
var candSpec, curSpec int
|
|
if rt.Dst != nil {
|
|
candSpec, _ = rt.Dst.Mask.Size()
|
|
}
|
|
if mostSpecificRt.Dst != nil {
|
|
curSpec, _ = mostSpecificRt.Dst.Mask.Size()
|
|
}
|
|
if candSpec < curSpec {
|
|
continue
|
|
}
|
|
}
|
|
mostSpecificRt = rt
|
|
}
|
|
if mostSpecificRt != nil {
|
|
return int(mostSpecificRt.OutputIface), mostSpecificRt.Gateway, mostSpecificRt.PrefSrc, nil
|
|
}
|
|
err = fmt.Errorf("no route found for %v", dst)
|
|
return
|
|
}
|