91 lines
2.5 KiB
Go
91 lines
2.5 KiB
Go
package tcpreuse
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
|
|
"github.com/libp2p/go-netroute"
|
|
)
|
|
|
|
// multiDialer chooses the local (source) address for outgoing TCP dials from
// the addresses of a set of listeners; see DialContext for the selection order.
type multiDialer struct {
	// listeningAddresses is the address of every listener, matched against the
	// OS-preferred source IP for a destination.
	listeningAddresses []*net.TCPAddr

	// loopback holds listeners bound to a loopback IP; unspecified holds
	// listeners bound to 0.0.0.0 or ::.
	loopback, unspecified []*net.TCPAddr

	// fallback is the source address used when no unspecified-address
	// listener exists.
	fallback net.TCPAddr
}
|
|
|
|
func (d *multiDialer) Dial(network, addr string) (net.Conn, error) {
|
|
return d.DialContext(context.Background(), network, addr)
|
|
}
|
|
|
|
// randAddr returns one element of addrs picked with the math/rand
// package-level generator, or nil when addrs is empty (or nil).
func randAddr(addrs []*net.TCPAddr) *net.TCPAddr {
	count := len(addrs)
	if count == 0 {
		return nil
	}
	return addrs[rand.Intn(count)]
}
|
|
|
|
// DialContext dials a target addr.
|
|
// Dialing preference is
|
|
// * If there is a listener on the local interface the OS expects to use to route towards addr, use that.
|
|
// * If there is a listener on a loopback address, addr is loopback, use that.
|
|
// * If there is a listener on an undefined address (0.0.0.0 or ::), use that.
|
|
// * Use the fallback IP specified during construction, with a port that's already being listened on, if one exists.
|
|
func (d *multiDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
tcpAddr, err := net.ResolveTCPAddr(network, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ip := tcpAddr.IP
|
|
if !ip.IsLoopback() && !ip.IsGlobalUnicast() {
|
|
return nil, fmt.Errorf("undialable IP: %s", ip)
|
|
}
|
|
|
|
if router, err := netroute.New(); err == nil {
|
|
if _, _, preferredSrc, err := router.Route(ip); err == nil {
|
|
for _, optAddr := range d.listeningAddresses {
|
|
if optAddr.IP.Equal(preferredSrc) {
|
|
return reuseDial(ctx, optAddr, network, addr)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if ip.IsLoopback() && len(d.loopback) > 0 {
|
|
return reuseDial(ctx, randAddr(d.loopback), network, addr)
|
|
}
|
|
if len(d.unspecified) == 0 {
|
|
return reuseDial(ctx, &d.fallback, network, addr)
|
|
}
|
|
|
|
return reuseDial(ctx, randAddr(d.unspecified), network, addr)
|
|
}
|
|
|
|
func newMultiDialer(unspec net.IP, listeners map[*listener]struct{}) (m dialer) {
|
|
addrs := make([]*net.TCPAddr, 0)
|
|
loopback := make([]*net.TCPAddr, 0)
|
|
unspecified := make([]*net.TCPAddr, 0)
|
|
existingPort := 0
|
|
|
|
for l := range listeners {
|
|
addr := l.Addr().(*net.TCPAddr)
|
|
addrs = append(addrs, addr)
|
|
if addr.IP.IsLoopback() {
|
|
loopback = append(loopback, addr)
|
|
} else if addr.IP.IsGlobalUnicast() && existingPort == 0 {
|
|
existingPort = addr.Port
|
|
} else if addr.IP.IsUnspecified() {
|
|
unspecified = append(unspecified, addr)
|
|
}
|
|
}
|
|
m = &multiDialer{
|
|
listeningAddresses: addrs,
|
|
loopback: loopback,
|
|
unspecified: unspecified,
|
|
fallback: net.TCPAddr{IP: unspec, Port: existingPort},
|
|
}
|
|
return
|
|
}
|