diff --git a/codec.go b/codec.go index 64e7d04..e05bbb8 100644 --- a/codec.go +++ b/codec.go @@ -3,7 +3,6 @@ package multiaddr import ( "encoding/base32" "encoding/binary" - "errors" "fmt" "net" "strconv" @@ -52,27 +51,42 @@ func stringToBytes(s string) ([]byte, error) { return b, nil } -func bytesToString(b []byte) (ret string, err error) { +func validateBytes(b []byte) (err error) { // panic handler, in case we try accessing bytes incorrectly. - defer func() { - if e := recover(); e != nil { - ret = "" - switch e := e.(type) { - case error: - err = e - case string: - err = errors.New(e) - default: - err = fmt.Errorf("%v", e) - } + for len(b) > 0 { + code, n, err := ReadVarintCode(b) + b = b[n:] + p := ProtocolWithCode(code) + if p.Code == 0 { + return fmt.Errorf("no protocol with code %d", code) } - }() + if p.Size == 0 { + continue + } + + size, err := sizeForAddr(p, b) + if err != nil { + return err + } + + if len(b) < size { + return fmt.Errorf("invalid value for size") + } + b = b[size:] + } + + return nil +} +func bytesToString(b []byte) (ret string, err error) { s := "" for len(b) > 0 { + code, n, err := ReadVarintCode(b) + if err != nil { + return "", err + } - code, n := ReadVarintCode(b) b = b[n:] p := ProtocolWithCode(code) if p.Code == 0 { @@ -84,7 +98,11 @@ func bytesToString(b []byte) (ret string, err error) { continue } - size := sizeForAddr(p, b) + size, err := sizeForAddr(p, b) + if err != nil { + return "", err + } + a, err := addressBytesToString(p, b[:size]) if err != nil { return "", err @@ -98,36 +116,40 @@ func bytesToString(b []byte) (ret string, err error) { return s, nil } -func sizeForAddr(p Protocol, b []byte) int { +func sizeForAddr(p Protocol, b []byte) (int, error) { switch { case p.Size > 0: - return (p.Size / 8) + return (p.Size / 8), nil case p.Size == 0: - return 0 + return 0, nil default: - size, n := ReadVarintCode(b) - return size + n + size, n, err := ReadVarintCode(b) + if err != nil { + return 0, err + } + + return size + n, nil } } -func bytesSplit(b []byte) (ret [][]byte, err error) { - // panic handler, in case we try accessing bytes incorrectly. - defer func() { - if e := recover(); e != nil { - ret = [][]byte{} - err = e.(error) - } - }() - - ret = [][]byte{} +func bytesSplit(b []byte) ([][]byte, error) { + var ret [][]byte for len(b) > 0 { - code, n := ReadVarintCode(b) + code, n, err := ReadVarintCode(b) + if err != nil { + return nil, err + } + p := ProtocolWithCode(code) if p.Code == 0 { - return [][]byte{}, fmt.Errorf("no protocol with code %d", b[0]) + return nil, fmt.Errorf("no protocol with code %d", b[0]) + } + + size, err := sizeForAddr(p, b[n:]) + if err != nil { + return nil, err } - size := sizeForAddr(p, b[n:]) length := n + size ret = append(ret, b[:length]) b = b[length:] @@ -228,10 +250,14 @@ func addressBytesToString(p Protocol, b []byte) (string, error) { case P_IPFS: // ipfs // the address is a varint-prefixed multihash string representation - size, n := ReadVarintCode(b) + size, n, err := ReadVarintCode(b) + if err != nil { + return "", err + } + b = b[n:] if len(b) != size { - panic("inconsistent lengths") + return "", fmt.Errorf("inconsistent lengths") } m, err := mh.Cast(b) if err != nil { diff --git a/multiaddr.go b/multiaddr.go index 0dbcede..be04a7f 100644 --- a/multiaddr.go +++ b/multiaddr.go @@ -23,11 +23,11 @@ func NewMultiaddr(s string) (Multiaddr, error) { // NewMultiaddrBytes initializes a Multiaddr from a byte representation. // It validates it as an input string. func NewMultiaddrBytes(b []byte) (Multiaddr, error) { - s, err := bytesToString(b) - if err != nil { + if err := validateBytes(b); err != nil { return nil, err } - return NewMultiaddr(s) + + return &multiaddr{bytes: b}, nil } // Equal tests whether two multiaddrs are equal @@ -64,11 +64,14 @@ func (m *multiaddr) Protocols() []Protocol { } }() - size := 0 - ps := []Protocol{} - b := m.bytes[:] + var ps []Protocol + b := m.bytes for len(b) > 0 { - code, n := ReadVarintCode(b) + code, n, err := ReadVarintCode(b) + if err != nil { + panic(err) + } + p := ProtocolWithCode(code) if p.Code == 0 { // this is a panic (and not returning err) because this should've been @@ -78,7 +81,7 @@ func (m *multiaddr) Protocols() []Protocol { ps = append(ps, p) b = b[n:] - size = sizeForAddr(p, b) + size, err := sizeForAddr(p, b) b = b[size:] } return ps diff --git a/multiaddr_test.go b/multiaddr_test.go index 7c75274..ac81ac2 100644 --- a/multiaddr_test.go +++ b/multiaddr_test.go @@ -3,7 +3,9 @@ package multiaddr import ( "bytes" "encoding/hex" + "math/rand" "testing" + "time" ) func newMultiaddr(t *testing.T, a string) Multiaddr { @@ -342,3 +344,21 @@ func TestGetValue(t *testing.T) { assertValueForProto(t, a, P_UDP, "12345") assertValueForProto(t, a, P_UTP, "") } + +func TestFuzzBytes(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + // Bump up these numbers if you want to stress this + buf := make([]byte, 256) + for i := 0; i < 2000; i++ { + l := rand.Intn(len(buf)) + rand.Read(buf[:l]) + + // just checking that it doesnt panic + ma, err := NewMultiaddrBytes(buf[:l]) + if err == nil { + // for any valid multiaddrs, make sure these calls don't panic + ma.String() + ma.Protocols() + } + } +} diff --git a/protocols.go b/protocols.go index 8364d4c..7454590 100644 --- a/protocols.go +++ b/protocols.go @@ -117,16 +117,19 @@ func CodeToVarint(num int) []byte { // VarintToCode converts a varint-encoded []byte to an integer protocol code func VarintToCode(buf []byte) int { - num, _ := ReadVarintCode(buf) + num, _, err := ReadVarintCode(buf) + if err != nil { + panic(err) + } return num } // ReadVarintCode reads a varint code from the beginning of buf. // returns the code, and the number of bytes read. -func ReadVarintCode(buf []byte) (int, int) { +func ReadVarintCode(buf []byte) (int, int, error) { num, n := binary.Uvarint(buf) if n < 0 { - panic("varints larger than uint64 not yet supported") + return 0, 0, fmt.Errorf("varints larger than uint64 not yet supported") } - return int(num), n + return int(num), n, nil }