diff --git a/convert.go b/convert.go index 50c6fff..8e05842 100644 --- a/convert.go +++ b/convert.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "strings" + "sync" ma "github.com/jbenet/go-multiaddr" utp "github.com/jbenet/go-multiaddr-net/utp" @@ -11,105 +12,190 @@ import ( var errIncorrectNetAddr = fmt.Errorf("incorrect network addr conversion") +type AddrParser func(a net.Addr) (ma.Multiaddr, error) +type MaddrParser func(ma ma.Multiaddr) (net.Addr, error) + +var maddrParsers map[string]MaddrParser +var addrParsers map[string]AddrParser +var addrParsersLock sync.Mutex + +func init() { + addrParsers = make(map[string]AddrParser) + maddrParsers = make(map[string]MaddrParser) + + registerDefaultAddrParsers() + registerDefaultMaddrParsers() +} + +func registerDefaultAddrParsers() { + funcs := map[string]AddrParser{ + "tcp": ParseTcpNetAddr, + "udp": ParseUdpNetAddr, + "utp": ParseUtpNetAddr, + "ip": ParseIpNetAddr, + } + + for k, v := range funcs { + addrParsers[k] = v + addrParsers[k+"4"] = v + addrParsers[k+"6"] = v + } + + addrParsers["ip+net"] = ParseIpPlusNetAddr +} + +func RegisterAddressType(netname, maname string, ap AddrParser, mp MaddrParser) { + addrParsersLock.Lock() + defer addrParsersLock.Unlock() + addrParsers[netname] = ap + maddrParsers[maname] = mp +} + +func ParseTcpNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.TCPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get IP Addr + ipm, err := FromIP(ac.IP) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Get TCP Addr + tcpm, err := ma.NewMultiaddr(fmt.Sprintf("/tcp/%d", ac.Port)) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Encapsulate + return ipm.Encapsulate(tcpm), nil +} + +func ParseUdpNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.UDPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get IP Addr + ipm, err := FromIP(ac.IP) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Get UDP Addr + udpm, err := ma.NewMultiaddr(fmt.Sprintf("/udp/%d", ac.Port)) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Encapsulate + return ipm.Encapsulate(udpm), nil +} + +func ParseUtpNetAddr(a net.Addr) (ma.Multiaddr, error) { + acc, ok := a.(*utp.Addr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get UDP Addr + ac, ok := acc.Child().(*net.UDPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get IP Addr + ipm, err := FromIP(ac.IP) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Get UDP Addr + utpm, err := ma.NewMultiaddr(fmt.Sprintf("/udp/%d/utp", ac.Port)) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Encapsulate + return ipm.Encapsulate(utpm), nil +} + +func ParseIpNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.IPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + return FromIP(ac.IP) +} + +func ParseIpPlusNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.IPNet) + if !ok { + return nil, errIncorrectNetAddr + } + return FromIP(ac.IP) +} + +func getAddrParser(net string) (AddrParser, error) { + addrParsersLock.Lock() + defer addrParsersLock.Unlock() + + parser, ok := addrParsers[net] + if !ok { + return nil, fmt.Errorf("unknown network %v", net) + } + return parser, nil +} + // FromNetAddr converts a net.Addr type to a Multiaddr. func FromNetAddr(a net.Addr) (ma.Multiaddr, error) { if a == nil { return nil, fmt.Errorf("nil multiaddr") } - - switch a.Network() { - case "tcp", "tcp4", "tcp6": - ac, ok := a.(*net.TCPAddr) - if !ok { - return nil, errIncorrectNetAddr - } - - // Get IP Addr - ipm, err := FromIP(ac.IP) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Get TCP Addr - tcpm, err := ma.NewMultiaddr(fmt.Sprintf("/tcp/%d", ac.Port)) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Encapsulate - return ipm.Encapsulate(tcpm), nil - - case "udp", "upd4", "udp6": - ac, ok := a.(*net.UDPAddr) - if !ok { - return nil, errIncorrectNetAddr - } - - // Get IP Addr - ipm, err := FromIP(ac.IP) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Get UDP Addr - udpm, err := ma.NewMultiaddr(fmt.Sprintf("/udp/%d", ac.Port)) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Encapsulate - return ipm.Encapsulate(udpm), nil - - case "utp", "utp4", "utp6": - acc, ok := a.(*utp.Addr) - if !ok { - return nil, errIncorrectNetAddr - } - - // Get UDP Addr - ac, ok := acc.Child().(*net.UDPAddr) - if !ok { - return nil, errIncorrectNetAddr - } - - // Get IP Addr - ipm, err := FromIP(ac.IP) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Get UDP Addr - utpm, err := ma.NewMultiaddr(fmt.Sprintf("/udp/%d/utp", ac.Port)) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Encapsulate - return ipm.Encapsulate(utpm), nil - - case "ip", "ip4", "ip6": - ac, ok := a.(*net.IPAddr) - if !ok { - return nil, errIncorrectNetAddr - } - return FromIP(ac.IP) - - case "ip+net": - ac, ok := a.(*net.IPNet) - if !ok { - return nil, errIncorrectNetAddr - } - return FromIP(ac.IP) - - default: - return nil, fmt.Errorf("unknown network %v", a.Network()) + p, err := getAddrParser(a.Network()) + if err != nil { + return nil, err } + + return p(a) +} + +func getMaddrParser(name string) (MaddrParser, error) { + addrParsersLock.Lock() + defer addrParsersLock.Unlock() + p, ok := maddrParsers[name] + if !ok { + return nil, fmt.Errorf("network not supported: %s", name) + } + + return p, nil } // ToNetAddr converts a Multiaddr to a net.Addr // Must be ThinWaist. acceptable protocol stacks are: // /ip{4,6}/{tcp, udp} func ToNetAddr(maddr ma.Multiaddr) (net.Addr, error) { + protos := maddr.Protocols() + final := protos[len(protos)-1] + + p, err := getMaddrParser(final.Name) + if err != nil { + return nil, err + } + + return p(maddr) +} + +func registerDefaultMaddrParsers() { + for _, net := range []string{"tcp", "udp", "utp", "ip", "ip4", "ip6"} { + maddrParsers[net] = parseBasicNetAddr + } +} + +func parseBasicNetAddr(maddr ma.Multiaddr) (net.Addr, error) { network, host, err := DialArgs(maddr) if err != nil { return nil, err