diff --git a/transcoders.go b/transcoders.go index 86ce200..b5408f0 100644 --- a/transcoders.go +++ b/transcoders.go @@ -73,6 +73,9 @@ func ip6zoneBtS(b []byte) (string, error) { } func ip6zoneVal(b []byte) error { + if len(b) == 0 { + return fmt.Errorf("invalid length (should be > 0)") + } // Not supported as this would break multiaddrs. if bytes.IndexByte(b, '/') >= 0 { return fmt.Errorf("IPv6 zone ID contains '/': %s", string(b)) diff --git a/util.go b/util.go index f08788b..d1044da 100644 --- a/util.go +++ b/util.go @@ -4,6 +4,9 @@ import "fmt" // Split returns the sub-address portions of a multiaddr. func Split(m Multiaddr) []Multiaddr { + if _, ok := m.(*Component); ok { + return []Multiaddr{m} + } var addrs []Multiaddr ForEach(m, func(c Component) bool { addrs = append(addrs, &c) @@ -58,6 +61,11 @@ func StringCast(s string) Multiaddr { // SplitFirst returns the first component and the rest of the multiaddr. func SplitFirst(m Multiaddr) (*Component, Multiaddr) { + // Shortcut if we already have a component + if c, ok := m.(*Component); ok { + return c, nil + } + b := m.Bytes() if len(b) == 0 { return nil, nil @@ -74,6 +82,11 @@ func SplitFirst(m Multiaddr) (*Component, Multiaddr) { // SplitLast returns the rest of the multiaddr and the last component. func SplitLast(m Multiaddr) (Multiaddr, *Component) { + // Shortcut if we already have a component + if c, ok := m.(*Component); ok { + return nil, c + } + b := m.Bytes() if len(b) == 0 { return nil, nil @@ -106,6 +119,13 @@ func SplitLast(m Multiaddr) (Multiaddr, *Component) { // component on which the callback first returns will be included in the // *second* multiaddr. func SplitFunc(m Multiaddr, cb func(Component) bool) (Multiaddr, Multiaddr) { + // Shortcut if we already have a component + if c, ok := m.(*Component); ok { + if cb(*c) { + return nil, m + } + return m, nil + } b := m.Bytes() if len(b) == 0 { return nil, nil @@ -140,6 +160,12 @@ func SplitFunc(m Multiaddr, cb func(Component) bool) (Multiaddr, Multiaddr) { // // This function iterates over components *by value* to avoid allocating. func ForEach(m Multiaddr, cb func(c Component) bool) { + // Shortcut if we already have a component + if c, ok := m.(*Component); ok { + cb(*c) + return + } + b := m.Bytes() for len(b) > 0 { n, c, err := readComponent(b) diff --git a/util_test.go b/util_test.go index 3210ca1..a0ee55c 100644 --- a/util_test.go +++ b/util_test.go @@ -60,6 +60,48 @@ func TestSplitFirstLast(t *testing.T) { t.Errorf("expected %s to be %s", rest, restExp) } } + + c, err := NewComponent("ip4", "127.0.0.1") + if err != nil { + t.Fatal(err) + } + + ci, m := SplitFirst(c) + if !ci.Equal(c) || m != nil { + t.Error("split first on component failed") + } + m, ci = SplitLast(c) + if !ci.Equal(c) || m != nil { + t.Error("split last on component failed") + } + cis := Split(c) + if len(cis) != 1 || !cis[0].Equal(c) { + t.Error("split on component failed") + } + m1, m2 := SplitFunc(c, func(c Component) bool { + return true + }) + if m1 != nil || !m2.Equal(c) { + t.Error("split func(true) on component failed") + } + m1, m2 = SplitFunc(c, func(c Component) bool { + return false + }) + if !m1.Equal(c) || m2 != nil { + t.Error("split func(false) on component failed") + } + + i := 0 + ForEach(c, func(ci Component) bool { + if i != 0 { + t.Error("expected exactly one component") + } + i++ + if !ci.Equal(c) { + t.Error("foreach on component failed") + } + return true + }) } func TestSplitFunc(t *testing.T) {