111 lines
3.2 KiB
Go
Raw Normal View History

package tcpreuse
import (
"context"
"fmt"
"math/rand"
"net"
)
// multiDialer picks a local source address for outgoing TCP dials from the
// set of addresses the process is already listening on, grouped by the kind
// of IP each listener is bound to.
type multiDialer struct {
	// loopback holds the addresses of listeners bound to a loopback IP.
	loopback []*net.TCPAddr
	// unspecified holds the addresses of listeners bound to an unspecified IP.
	unspecified []*net.TCPAddr
	// global is the fallback source address. newMultiDialer builds it from
	// the unspecified IP it is given and the port of one global-unicast
	// listener; it stays nil when no such listener exists.
	global *net.TCPAddr
}
// Dial connects to addr on the given network. It is DialContext with a
// background context, i.e. without caller-controlled cancellation or deadline.
func (d *multiDialer) Dial(network, addr string) (net.Conn, error) {
	ctx := context.Background()
	return d.DialContext(ctx, network, addr)
}
// randAddr returns a pseudo-randomly chosen element of addrs, or nil when
// addrs is empty.
func randAddr(addrs []*net.TCPAddr) *net.TCPAddr {
	n := len(addrs)
	if n == 0 {
		return nil
	}
	return addrs[rand.Intn(n)]
}
// DialContext resolves addr as a TCP address and dials it through reuseDial,
// using a local source address picked from the listeners this dialer knows
// about.
//
// The source is chosen by destination kind, in order of preference:
//
//   - Loopback destination:
//     1. a random explicit loopback listener;
//     2. a random listener on an unspecified address (which is therefore
//     also listening on localhost);
//     3. d.global, for lack of anything better.
//   - Global unicast destination:
//     1. a random listener on an unspecified address (the most general case);
//     2. d.global.
//   - Any other destination (e.g. link-local) is rejected with an
//     "undialable IP" error.
//
// d.global may be nil when there is no global listener; in that case a nil
// source is handed to reuseDial.
func (d *multiDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	raddr, err := net.ResolveTCPAddr(network, addr)
	if err != nil {
		return nil, err
	}

	dst := raddr.IP
	toLoopback := dst.IsLoopback()
	if !toLoopback && !dst.IsGlobalUnicast() {
		return nil, fmt.Errorf("undialable IP: %s", dst)
	}

	// From here on the destination is either loopback or global unicast.
	// The two cases differ only in that loopback dials try the explicit
	// loopback listeners first.
	var src *net.TCPAddr
	switch {
	case toLoopback && len(d.loopback) > 0:
		src = randAddr(d.loopback)
	case len(d.unspecified) > 0:
		src = randAddr(d.unspecified)
	default:
		src = d.global
	}
	return reuseDial(ctx, src, network, addr)
}
// newMultiDialer builds a dialer that chooses its source address from the
// given set of listeners.
//
// Each listener address is classified by its IP:
//
//   - loopback addresses are collected in the loopback list;
//   - unspecified addresses are collected in the unspecified list;
//   - for global unicast addresses only one port is kept: the fallback source
//     becomes unspec paired with the port of the first such listener seen.
//     Further global listeners are ignored with a warning.
//
// Listeners bound to any other kind of address are skipped. Because listeners
// is a map, "first" is whatever iteration order yields and is not stable
// across calls.
//
// Every listener's Addr() must be a *net.TCPAddr; the type assertion panics
// otherwise.
func newMultiDialer(unspec net.IP, listeners map[*listener]struct{}) dialer {
	m := new(multiDialer)
	for l := range listeners {
		laddr := l.Addr().(*net.TCPAddr)
		switch {
		case laddr.IP.IsLoopback():
			m.loopback = append(m.loopback, laddr)
		case laddr.IP.IsGlobalUnicast():
			// Listening on several global ports is awkward.
			//
			// The proper fix would be to ask the OS (e.g. via
			// netlink) which source address it would use for a
			// given destination and then pick one of the ports we
			// listen on for that address. That is a lot of work.
			//
			// Instead, always dial from the unspecified address
			// using the first global port we come across.
			//
			// TODO: Port priority? Addr priority?
			if m.global == nil {
				m.global = &net.TCPAddr{
					IP:   unspec,
					Port: laddr.Port,
				}
			} else {
				// Warningf, not Warning: the message carries format
				// verbs. Both arguments are *net.TCPAddr, so %s
				// renders them as "ip:port" via String().
				log.Warningf("listening on external interfaces on multiple ports, will dial from %s, not %s", m.global, laddr)
			}
		case laddr.IP.IsUnspecified():
			m.unspecified = append(m.unspecified, laddr)
		}
	}
	return m
}