cleanup addrToIP logic

This commit is contained in:
Will Scott 2020-02-25 10:24:47 -08:00
parent 10b0a942b8
commit 426d729c51
2 changed files with 16 additions and 22 deletions

View File

@ -3,7 +3,7 @@ package autonat
import ( import (
"bytes" "bytes"
"context" "context"
"fmt" "errors"
"math/rand" "math/rand"
"net" "net"
"strings" "strings"
@ -106,31 +106,25 @@ func (as *AutoNATService) handleStream(s network.Stream) {
} }
// Optimistically extract the net.IP host from a multiaddress. // Optimistically extract the net.IP host from a multiaddress.
func addrToIP(addr ma.Multiaddr) net.IP { func addrToIP(addr ma.Multiaddr) (net.IP, error) {
n, ip, err := manet.DialArgs(addr) n, ip, err := manet.DialArgs(addr)
if err != nil { if err != nil {
return nil return nil, err
} }
// if no port: if strings.HasPrefix(n, "tcp") || strings.HasPrefix(n, "udp") {
if n == "ip" || n == "ip4" || n == "ip6" { ip, _, err = net.SplitHostPort(ip)
// Strip v6 zone if it's there. } else if !strings.HasPrefix(n, "ip") {
if strings.Contains(ip, "%") { return nil, errors.New("non-ip multiaddr")
ip = ip[:strings.Index(ip, "%")]
}
return net.ParseIP(ip)
} }
ip, _, err = net.SplitHostPort(ip)
if err != nil { if err != nil {
fmt.Printf("failed to split: %v", err) return nil, err
return nil
} }
// Strip v6 zone if it's there. // Strip v6 zone if it's there.
if strings.Contains(ip, "%") { if strings.Contains(ip, "%") {
ip = ip[:strings.Index(ip, "%")] ip = ip[:strings.Index(ip, "%")]
} }
return net.ParseIP(ip) return net.ParseIP(ip), nil
} }
func (as *AutoNATService) handleDial(p peer.ID, obsaddr ma.Multiaddr, mpi *pb.Message_PeerInfo) *pb.Message_DialResponse { func (as *AutoNATService) handleDial(p peer.ID, obsaddr ma.Multiaddr, mpi *pb.Message_PeerInfo) *pb.Message_DialResponse {
@ -158,7 +152,7 @@ func (as *AutoNATService) handleDial(p peer.ID, obsaddr ma.Multiaddr, mpi *pb.Me
if !as.skipDial(obsaddr) { if !as.skipDial(obsaddr) {
addrs = append(addrs, obsaddr) addrs = append(addrs, obsaddr)
seen[obsaddr.String()] = struct{}{} seen[obsaddr.String()] = struct{}{}
obsHost = addrToIP(obsaddr) obsHost, _ = addrToIP(obsaddr)
} }
for _, maddr := range mpi.GetAddrs() { for _, maddr := range mpi.GetAddrs() {
@ -172,7 +166,7 @@ func (as *AutoNATService) handleDial(p peer.ID, obsaddr ma.Multiaddr, mpi *pb.Me
continue continue
} }
if !bytes.Equal(obsHost, addrToIP(addr)) { if ip, err := addrToIP(addr); err != nil || !bytes.Equal(obsHost, ip) {
continue continue
} }

View File

@ -139,26 +139,26 @@ func TestAutoNATServiceDialRateLimiter(t *testing.T) {
func TestAddrToIP(t *testing.T) { func TestAddrToIP(t *testing.T) {
addr, _ := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0") addr, _ := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0")
if !addrToIP(addr).Equal(net.IPv4(127, 0, 0, 1)) { if ip, err := addrToIP(addr); err != nil || !ip.Equal(net.IPv4(127, 0, 0, 1)) {
t.Fatal("addrToIP of ipv4 localhost incorrect!") t.Fatal("addrToIP of ipv4 localhost incorrect!")
} }
addr, _ = ma.NewMultiaddr("/ip4/192.168.0.1/tcp/6") addr, _ = ma.NewMultiaddr("/ip4/192.168.0.1/tcp/6")
if !addrToIP(addr).Equal(net.IPv4(192, 168, 0, 1)) { if ip, err := addrToIP(addr); err != nil || !ip.Equal(net.IPv4(192, 168, 0, 1)) {
t.Fatal("addrToIP of ipv4 incorrect!") t.Fatal("addrToIP of ipv4 incorrect!")
} }
addr, _ = ma.NewMultiaddr("/ip6/::ffff:127.0.0.1/tcp/111") addr, _ = ma.NewMultiaddr("/ip6/::ffff:127.0.0.1/tcp/111")
if !addrToIP(addr).Equal(net.ParseIP("::ffff:127.0.0.1")) { if ip, err := addrToIP(addr); err != nil || !ip.Equal(net.ParseIP("::ffff:127.0.0.1")) {
t.Fatal("addrToIP of ipv6 incorrect!") t.Fatal("addrToIP of ipv6 incorrect!")
} }
addr, _ = ma.NewMultiaddr("/ip6zone/eth0/ip6/fe80::1") addr, _ = ma.NewMultiaddr("/ip6zone/eth0/ip6/fe80::1")
if !addrToIP(addr).Equal(net.ParseIP("fe80::1")) { if ip, err := addrToIP(addr); err != nil || !ip.Equal(net.ParseIP("fe80::1")) {
t.Fatal("addrToIP of ip6zone incorrect!") t.Fatal("addrToIP of ip6zone incorrect!")
} }
addr, _ = ma.NewMultiaddr("/unix/a/b/c/d") addr, _ = ma.NewMultiaddr("/unix/a/b/c/d")
if addrToIP(addr) != nil { if _, err := addrToIP(addr); err == nil {
t.Fatal("invalid addrToIP populates") t.Fatal("invalid addrToIP populates")
} }
} }