diff --git a/component.go b/component.go index e6cc35a..e31fe5d 100644 --- a/component.go +++ b/component.go @@ -130,20 +130,3 @@ func newComponent(protocol Protocol, bvalue []byte) *Component { offset: offset, } } - -// ForEach walks over the multiaddr, component by component. -// -// This function iterates over components *by value* to avoid allocating. -func ForEach(m Multiaddr, cb func(c Component) bool) { - b := m.Bytes() - for len(b) > 0 { - n, c, err := readComponent(b) - if err != nil { - panic(err) - } - if !cb(c) { - return - } - b = b[n:] - } -} diff --git a/util.go b/util.go index 7a2c414..f08788b 100644 --- a/util.go +++ b/util.go @@ -55,3 +55,100 @@ func StringCast(s string) Multiaddr { } return m } + +// SplitFirst returns the first component and the rest of the multiaddr. +func SplitFirst(m Multiaddr) (*Component, Multiaddr) { + b := m.Bytes() + if len(b) == 0 { + return nil, nil + } + n, c, err := readComponent(b) + if err != nil { + panic(err) + } + if len(b) == n { + return &c, nil + } + return &c, multiaddr{b[n:]} +} + +// SplitLast returns the rest of the multiaddr and the last component. +func SplitLast(m Multiaddr) (Multiaddr, *Component) { + b := m.Bytes() + if len(b) == 0 { + return nil, nil + } + + var ( + c Component + err error + offset int + ) + for { + var n int + n, c, err = readComponent(b[offset:]) + if err != nil { + panic(err) + } + if len(b) == n+offset { + // Reached end + if offset == 0 { + // Only one component + return nil, &c + } + return multiaddr{b[:offset]}, &c + } + offset += n + } +} + +// SplitFunc splits the multiaddr when the callback first returns true. The +// component on which the callback first returns will be included in the +// *second* multiaddr. +func SplitFunc(m Multiaddr, cb func(Component) bool) (Multiaddr, Multiaddr) { + b := m.Bytes() + if len(b) == 0 { + return nil, nil + } + var ( + c Component + err error + offset int + ) + for offset < len(b) { + var n int + n, c, err = readComponent(b[offset:]) + if err != nil { + panic(err) + } + if cb(c) { + break + } + offset += n + } + switch offset { + case 0: + return nil, m + case len(b): + return m, nil + default: + return multiaddr{b[:offset]}, multiaddr{b[offset:]} + } +} + +// ForEach walks over the multiaddr, component by component. +// +// This function iterates over components *by value* to avoid allocating. +func ForEach(m Multiaddr, cb func(c Component) bool) { + b := m.Bytes() + for len(b) > 0 { + n, c, err := readComponent(b) + if err != nil { + panic(err) + } + if !cb(c) { + return + } + b = b[n:] + } +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..3210ca1 --- /dev/null +++ b/util_test.go @@ -0,0 +1,101 @@ +package multiaddr + +import ( + "strings" + "testing" +) + +func TestSplitFirstLast(t *testing.T) { + ipStr := "/ip4/0.0.0.0" + tcpStr := "/tcp/123" + quicStr := "/quic" + ipfsStr := "/ipfs/QmPSQnBKM9g7BaUcZCvswUJVscQ1ipjmwxN5PXCjkp9EQ7" + + for _, x := range [][]string{ + []string{ipStr, tcpStr, quicStr, ipfsStr}, + []string{ipStr, tcpStr, ipfsStr}, + []string{ipStr, tcpStr}, + []string{ipStr}, + []string{}, + } { + addr := StringCast(strings.Join(x, "")) + head, tail := SplitFirst(addr) + rest, last := SplitLast(addr) + if len(x) == 0 { + if head != nil { + t.Error("expected head to be nil") + } + if tail != nil { + t.Error("expected tail to be nil") + } + if rest != nil { + t.Error("expected rest to be nil") + } + if last != nil { + t.Error("expected last to be nil") + } + continue + } + if !head.Equal(StringCast(x[0])) { + t.Errorf("expected %s to be %s", head, x[0]) + } + if !last.Equal(StringCast(x[len(x)-1])) { + t.Errorf("expected %s to be %s", head, x[len(x)-1]) + } + if len(x) == 1 { + if tail != nil { + t.Error("expected tail to be nil") + } + if rest != nil { + t.Error("expected rest to be nil") + } + continue + } + tailExp := strings.Join(x[1:], "") + if !tail.Equal(StringCast(tailExp)) { + t.Errorf("expected %s to be %s", tail, tailExp) + } + restExp := strings.Join(x[:len(x)-1], "") + if !rest.Equal(StringCast(restExp)) { + t.Errorf("expected %s to be %s", rest, restExp) + } + } +} + +func TestSplitFunc(t *testing.T) { + ipStr := "/ip4/0.0.0.0" + tcpStr := "/tcp/123" + quicStr := "/quic" + ipfsStr := "/ipfs/QmPSQnBKM9g7BaUcZCvswUJVscQ1ipjmwxN5PXCjkp9EQ7" + + for _, x := range [][]string{ + []string{ipStr, tcpStr, quicStr, ipfsStr}, + []string{ipStr, tcpStr, ipfsStr}, + []string{ipStr, tcpStr}, + []string{ipStr}, + } { + addr := StringCast(strings.Join(x, "")) + for i, cs := range x { + target := StringCast(cs) + a, b := SplitFunc(addr, func(c Component) bool { + return c.Equal(target) + }) + if i == 0 { + if a != nil { + t.Error("expected nil addr") + } + } else { + if !a.Equal(StringCast(strings.Join(x[:i], ""))) { + t.Error("split failed") + } + if !b.Equal(StringCast(strings.Join(x[i:], ""))) { + t.Error("split failed") + } + } + } + a, b := SplitFunc(addr, func(_ Component) bool { return false }) + if !a.Equal(addr) || b != nil { + t.Error("should not have split") + } + } +}