diff --git a/codec.go b/codec.go index fe672ac..6234f6e 100644 --- a/codec.go +++ b/codec.go @@ -96,54 +96,58 @@ func validateBytes(b []byte) (err error) { return nil } -func bytesToString(b []byte) (ret string, err error) { - s := "" +func readComponent(b []byte) (int, Component, error) { + var offset int + code, n, err := ReadVarintCode(b) + if err != nil { + return 0, Component{}, err + } + offset += n - for len(b) > 0 { - code, n, err := ReadVarintCode(b) - if err != nil { - return "", err - } - - b = b[n:] - p := ProtocolWithCode(code) - if p.Code == 0 { - return "", fmt.Errorf("no protocol with code %d", code) - } - s += "/" + p.Name - - if p.Size == 0 { - continue - } - - n, size, err := sizeForAddr(p, b) - if err != nil { - return "", err - } - - b = b[n:] - - if len(b) < size || size < 0 { - return "", fmt.Errorf("invalid value for size") - } - - if p.Transcoder == nil { - return "", fmt.Errorf("no transcoder for %s protocol", p.Name) - } - a, err := p.Transcoder.BytesToString(b[:size]) - if err != nil { - return "", err - } - if p.Path && len(a) > 0 && a[0] == '/' { - a = a[1:] - } - if len(a) > 0 { - s += "/" + a - } - b = b[size:] + p := ProtocolWithCode(code) + if p.Code == 0 { + return 0, Component{}, fmt.Errorf("no protocol with code %d", code) } - return s, nil + if p.Size == 0 { + return offset, Component{ + bytes: b[:offset], + offset: offset, + protocol: p, + }, nil + } + + n, size, err := sizeForAddr(p, b[offset:]) + if err != nil { + return 0, Component{}, err + } + + offset += n + + if len(b[offset:]) < size || size < 0 { + return 0, Component{}, fmt.Errorf("invalid value for size") + } + + return offset + size, Component{ + bytes: b[:offset+size], + protocol: p, + offset: offset, + }, nil +} + +func bytesToString(b []byte) (ret string, err error) { + var buf strings.Builder + + for len(b) > 0 { + n, c, err := readComponent(b) + if err != nil { + return "", err + } + b = b[n:] + c.writeTo(&buf) + } + + return buf.String(), nil } func sizeForAddr(p Protocol, b []byte) (skip, size int, err error) { diff --git a/component.go b/component.go new file mode 100644 index 0000000..e6cc35a --- /dev/null +++ b/component.go @@ -0,0 +1,149 @@ +package multiaddr + +import ( + "bytes" + "encoding/binary" + "fmt" + "strings" +) + +// Component is a single multiaddr Component. +type Component struct { + bytes []byte + protocol Protocol + offset int +} + +func (c *Component) Bytes() []byte { + return c.bytes +} + +func (c *Component) Equal(o Multiaddr) bool { + return bytes.Equal(c.bytes, o.Bytes()) +} + +func (c *Component) Protocols() []Protocol { + return []Protocol{c.protocol} +} + +func (c *Component) Decapsulate(o Multiaddr) Multiaddr { + if c.Equal(o) { + return nil + } + return c +} + +func (c *Component) Encapsulate(o Multiaddr) Multiaddr { + m := multiaddr{bytes: c.bytes} + return m.Encapsulate(o) +} + +func (c *Component) ValueForProtocol(code int) (string, error) { + if c.protocol.Code != code { + return "", ErrProtocolNotFound + } + return c.Value(), nil +} + +func (c *Component) Protocol() Protocol { + return c.protocol +} + +func (c *Component) RawValue() []byte { + return c.bytes[c.offset:] +} + +func (c *Component) Value() string { + if c.protocol.Transcoder == nil { + return "" + } + value, err := c.protocol.Transcoder.BytesToString(c.bytes[c.offset:]) + if err != nil { + // This Component must have been checked. + panic(err) + } + return value +} + +func (c *Component) String() string { + var b strings.Builder + c.writeTo(&b) + return b.String() +} + +// writeTo is an efficient, private function for string-formatting a multiaddr. +// Trust me, we tend to allocate a lot when doing this. +func (c *Component) writeTo(b *strings.Builder) { + b.WriteByte('/') + b.WriteString(c.protocol.Name) + value := c.Value() + if len(value) == 0 { + return + } + if !(c.protocol.Path && value[0] == '/') { + b.WriteByte('/') + } + b.WriteString(value) +} + +// NewComponent constructs a new multiaddr component +func NewComponent(protocol, value string) (*Component, error) { + p := ProtocolWithName(protocol) + if p.Code == 0 { + return nil, fmt.Errorf("unsupported protocol: %s", protocol) + } + if p.Transcoder != nil { + bts, err := p.Transcoder.StringToBytes(value) + if err != nil { + return nil, err + } + return newComponent(p, bts), nil + } else if value != "" { + return nil, fmt.Errorf("protocol %s doesn't take a value", p.Name) + } + return newComponent(p, nil), nil + // TODO: handle path /? +} + +func newComponent(protocol Protocol, bvalue []byte) *Component { + size := len(bvalue) + size += len(protocol.VCode) + if protocol.Size < 0 { + size += VarintSize(len(bvalue)) + } + maddr := make([]byte, size) + var offset int + offset += copy(maddr[offset:], protocol.VCode) + if protocol.Size < 0 { + offset += binary.PutUvarint(maddr[offset:], uint64(len(bvalue))) + } + copy(maddr[offset:], bvalue) + + // For debugging + if len(maddr) != offset+len(bvalue) { + panic("incorrect length") + } + + return &Component{ + bytes: maddr, + protocol: protocol, + offset: offset, + } +} + +// ForEach walks over the multiaddr, component by component. +// +// This function iterates over components *by value* to avoid allocating. +func ForEach(m Multiaddr, cb func(c Component) bool) { + b := m.Bytes() + for len(b) > 0 { + n, c, err := readComponent(b) + if err != nil { + panic(err) + } + if !cb(c) { + return + } + b = b[n:] + } +} diff --git a/interface.go b/interface.go index 1f46184..34bffd9 100644 --- a/interface.go +++ b/interface.go @@ -43,5 +43,8 @@ type Multiaddr interface { Decapsulate(Multiaddr) Multiaddr // ValueForProtocol returns the value (if any) following the specified protocol + // + // Note: protocols can appear multiple times in a single multiaddr. + // Consider using `ForEach` to walk over the addr manually. ValueForProtocol(code int) (string, error) } diff --git a/multiaddr.go b/multiaddr.go index 9b5c251..2c07dd3 100644 --- a/multiaddr.go +++ b/multiaddr.go @@ -127,16 +127,15 @@ func (m multiaddr) Decapsulate(o Multiaddr) Multiaddr { var ErrProtocolNotFound = fmt.Errorf("protocol not found in multiaddr") -func (m multiaddr) ValueForProtocol(code int) (string, error) { - for _, sub := range Split(m) { - p := sub.Protocols()[0] - if p.Code == code { - if p.Size == 0 { - return "", nil - } - return strings.SplitN(sub.String(), "/", 3)[2], nil +func (m multiaddr) ValueForProtocol(code int) (value string, err error) { + err = ErrProtocolNotFound + ForEach(m, func(c Component) bool { + if c.Protocol().Code == code { + value = c.Value() + err = nil + return false } - } - - return "", ErrProtocolNotFound + return true + }) + return } diff --git a/util.go b/util.go index 49eff9d..7a2c414 100644 --- a/util.go +++ b/util.go @@ -4,15 +4,11 @@ import "fmt" // Split returns the sub-address portions of a multiaddr. func Split(m Multiaddr) []Multiaddr { - split, err := bytesSplit(m.Bytes()) - if err != nil { - panic(fmt.Errorf("invalid multiaddr %s", m.String())) - } - - addrs := make([]Multiaddr, len(split)) - for i, addr := range split { - addrs[i] = multiaddr{bytes: addr} - } + var addrs []Multiaddr + ForEach(m, func(c Component) bool { + addrs = append(addrs, &c) + return true + }) return addrs }