Merge pull request #25 from whyrusleeping/cleanup-new-bytes

cleanup panics and make NewFromBytes faster
This commit is contained in:
Juan Benet 2016-05-09 21:22:45 -04:00
commit f3dff105e4
5 changed files with 187 additions and 63 deletions

View File

@ -1,9 +1,8 @@
language: go language: go
go: go:
- 1.3
- release
- tip - tip
- 1.6.1
script: script:
- go test -race -cpu=5 -v ./... - go test -race -cpu=5 -v ./...

126
codec.go
View File

@ -1,9 +1,9 @@
package multiaddr package multiaddr
import ( import (
"bytes"
"encoding/base32" "encoding/base32"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@ -17,7 +17,7 @@ func stringToBytes(s string) ([]byte, error) {
// consume trailing slashes // consume trailing slashes
s = strings.TrimRight(s, "/") s = strings.TrimRight(s, "/")
b := []byte{} b := new(bytes.Buffer)
sp := strings.Split(s, "/") sp := strings.Split(s, "/")
if sp[0] != "" { if sp[0] != "" {
@ -32,7 +32,7 @@ func stringToBytes(s string) ([]byte, error) {
if p.Code == 0 { if p.Code == 0 {
return nil, fmt.Errorf("no protocol with name %s", sp[0]) return nil, fmt.Errorf("no protocol with name %s", sp[0])
} }
b = append(b, CodeToVarint(p.Code)...) b.Write(CodeToVarint(p.Code))
sp = sp[1:] sp = sp[1:]
if p.Size == 0 { // no length. if p.Size == 0 { // no length.
@ -42,37 +42,59 @@ func stringToBytes(s string) ([]byte, error) {
if len(sp) < 1 { if len(sp) < 1 {
return nil, fmt.Errorf("protocol requires address, none given: %s", p.Name) return nil, fmt.Errorf("protocol requires address, none given: %s", p.Name)
} }
a, err := addressStringToBytes(p, sp[0]) a, err := addressStringToBytes(p, sp[0])
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse %s: %s %s", p.Name, sp[0], err) return nil, fmt.Errorf("failed to parse %s: %s %s", p.Name, sp[0], err)
} }
b = append(b, a...) b.Write(a)
sp = sp[1:] sp = sp[1:]
} }
return b, nil
return b.Bytes(), nil
}
// validateBytes walks b component by component and reports the first
// structural problem it finds, without producing the string form.
//
// It returns an error when a varint cannot be read, when a protocol code is
// not known to ProtocolWithCode, or when the address size computed by
// sizeForAddr is negative or larger than the bytes remaining. It returns nil
// when all of b is consumed cleanly (including when b is empty).
func validateBytes(b []byte) (err error) {
for len(b) > 0 {
// Every component starts with a varint protocol code; n is the number
// of bytes that varint occupied.
code, n, err := ReadVarintCode(b)
if err != nil {
return err
}
b = b[n:]
p := ProtocolWithCode(code)
if p.Code == 0 {
return fmt.Errorf("no protocol with code %d", code)
}
// Size == 0: sizeForAddr would return 0 here, so there are no address
// bytes to skip for this protocol.
if p.Size == 0 {
continue
}
size, err := sizeForAddr(p, b)
if err != nil {
return err
}
// Guard the slice below: size must fit in what is left, and must not be
// negative (it is derived from an int conversion of a decoded varint).
if len(b) < size || size < 0 {
return fmt.Errorf("invalid value for size")
}
b = b[size:]
}
return nil
} }
func bytesToString(b []byte) (ret string, err error) { func bytesToString(b []byte) (ret string, err error) {
// panic handler, in case we try accessing bytes incorrectly.
defer func() {
if e := recover(); e != nil {
ret = ""
switch e := e.(type) {
case error:
err = e
case string:
err = errors.New(e)
default:
err = fmt.Errorf("%v", e)
}
}
}()
s := "" s := ""
for len(b) > 0 { for len(b) > 0 {
code, n, err := ReadVarintCode(b)
if err != nil {
return "", err
}
code, n := ReadVarintCode(b)
b = b[n:] b = b[n:]
p := ProtocolWithCode(code) p := ProtocolWithCode(code)
if p.Code == 0 { if p.Code == 0 {
@ -84,7 +106,15 @@ func bytesToString(b []byte) (ret string, err error) {
continue continue
} }
size := sizeForAddr(p, b) size, err := sizeForAddr(p, b)
if err != nil {
return "", err
}
if len(b) < size || size < 0 {
return "", fmt.Errorf("invalid value for size")
}
a, err := addressBytesToString(p, b[:size]) a, err := addressBytesToString(p, b[:size])
if err != nil { if err != nil {
return "", err return "", err
@ -98,36 +128,40 @@ func bytesToString(b []byte) (ret string, err error) {
return s, nil return s, nil
} }
func sizeForAddr(p Protocol, b []byte) int { func sizeForAddr(p Protocol, b []byte) (int, error) {
switch { switch {
case p.Size > 0: case p.Size > 0:
return (p.Size / 8) return (p.Size / 8), nil
case p.Size == 0: case p.Size == 0:
return 0 return 0, nil
default: default:
size, n := ReadVarintCode(b) size, n, err := ReadVarintCode(b)
return size + n if err != nil {
return 0, err
}
return size + n, nil
} }
} }
func bytesSplit(b []byte) (ret [][]byte, err error) { func bytesSplit(b []byte) ([][]byte, error) {
// panic handler, in case we try accessing bytes incorrectly. var ret [][]byte
defer func() {
if e := recover(); e != nil {
ret = [][]byte{}
err = e.(error)
}
}()
ret = [][]byte{}
for len(b) > 0 { for len(b) > 0 {
code, n := ReadVarintCode(b) code, n, err := ReadVarintCode(b)
if err != nil {
return nil, err
}
p := ProtocolWithCode(code) p := ProtocolWithCode(code)
if p.Code == 0 { if p.Code == 0 {
return [][]byte{}, fmt.Errorf("no protocol with code %d", b[0]) return nil, fmt.Errorf("no protocol with code %d", b[0])
}
size, err := sizeForAddr(p, b[n:])
if err != nil {
return nil, err
} }
size := sizeForAddr(p, b[n:])
length := n + size length := n + size
ret = append(ret, b[:length]) ret = append(ret, b[:length])
b = b[length:] b = b[length:]
@ -228,17 +262,21 @@ func addressBytesToString(p Protocol, b []byte) (string, error) {
case P_IPFS: // ipfs case P_IPFS: // ipfs
// the address is a varint-prefixed multihash string representation // the address is a varint-prefixed multihash string representation
size, n := ReadVarintCode(b) size, n, err := ReadVarintCode(b)
if err != nil {
return "", err
}
b = b[n:] b = b[n:]
if len(b) != size { if len(b) != size {
panic("inconsistent lengths") return "", fmt.Errorf("inconsistent lengths")
} }
m, err := mh.Cast(b) m, err := mh.Cast(b)
if err != nil { if err != nil {
return "", err return "", err
} }
return m.B58String(), nil return m.B58String(), nil
default:
return "", fmt.Errorf("unknown protocol")
} }
return "", fmt.Errorf("unknown protocol")
} }

View File

@ -3,6 +3,7 @@ package multiaddr
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"log"
"strings" "strings"
) )
@ -12,7 +13,13 @@ type multiaddr struct {
} }
// NewMultiaddr parses and validates an input string, returning a *Multiaddr // NewMultiaddr parses and validates an input string, returning a *Multiaddr
func NewMultiaddr(s string) (Multiaddr, error) { func NewMultiaddr(s string) (a Multiaddr, err error) {
defer func() {
if e := recover(); e != nil {
log.Printf("Panic in NewMultiaddr on input %q: %s", s, e)
err = fmt.Errorf("%v", e)
}
}()
b, err := stringToBytes(s) b, err := stringToBytes(s)
if err != nil { if err != nil {
return nil, err return nil, err
@ -22,12 +29,19 @@ func NewMultiaddr(s string) (Multiaddr, error) {
// NewMultiaddrBytes initializes a Multiaddr from a byte representation. // NewMultiaddrBytes initializes a Multiaddr from a byte representation.
// It validates it as an input string. // It validates it as an input string.
func NewMultiaddrBytes(b []byte) (Multiaddr, error) { func NewMultiaddrBytes(b []byte) (a Multiaddr, err error) {
s, err := bytesToString(b) defer func() {
if err != nil { if e := recover(); e != nil {
log.Printf("Panic in NewMultiaddrBytes on input %q: %s", b, e)
err = fmt.Errorf("%v", e)
}
}()
if err := validateBytes(b); err != nil {
return nil, err return nil, err
} }
return NewMultiaddr(s)
return &multiaddr{bytes: b}, nil
} }
// Equal tests whether two multiaddrs are equal // Equal tests whether two multiaddrs are equal
@ -64,11 +78,14 @@ func (m *multiaddr) Protocols() []Protocol {
} }
}() }()
size := 0 var ps []Protocol
ps := []Protocol{} b := m.bytes
b := m.bytes[:]
for len(b) > 0 { for len(b) > 0 {
code, n := ReadVarintCode(b) code, n, err := ReadVarintCode(b)
if err != nil {
panic(err)
}
p := ProtocolWithCode(code) p := ProtocolWithCode(code)
if p.Code == 0 { if p.Code == 0 {
// this is a panic (and not returning err) because this should've been // this is a panic (and not returning err) because this should've been
@ -78,7 +95,11 @@ func (m *multiaddr) Protocols() []Protocol {
ps = append(ps, p) ps = append(ps, p)
b = b[n:] b = b[n:]
size = sizeForAddr(p, b) size, err := sizeForAddr(p, b)
if err != nil {
panic(err)
}
b = b[size:] b = b[size:]
} }
return ps return ps

View File

@ -3,7 +3,10 @@ package multiaddr
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"math/rand"
"strings"
"testing" "testing"
"time"
) )
func newMultiaddr(t *testing.T, a string) Multiaddr { func newMultiaddr(t *testing.T, a string) Multiaddr {
@ -138,6 +141,10 @@ func TestStringToBytes(t *testing.T) {
if !bytes.Equal(b1, b2) { if !bytes.Equal(b1, b2) {
t.Error("failed to convert", s, "to", b1, "got", b2) t.Error("failed to convert", s, "to", b1, "got", b2)
} }
if err := validateBytes(b2); err != nil {
t.Error(err)
}
} }
testString("/ip4/127.0.0.1/udp/1234", "047f0000011104d2") testString("/ip4/127.0.0.1/udp/1234", "047f0000011104d2")
@ -153,6 +160,10 @@ func TestBytesToString(t *testing.T) {
t.Error("failed to decode hex", h) t.Error("failed to decode hex", h)
} }
if err := validateBytes(b); err != nil {
t.Error(err)
}
s2, err := bytesToString(b) s2, err := bytesToString(b)
if err != nil { if err != nil {
t.Error("failed to convert", b) t.Error("failed to convert", b)
@ -246,7 +257,7 @@ func TestProtocolsWithString(t *testing.T) {
for s, ps1 := range good { for s, ps1 := range good {
ps2, err := ProtocolsWithString(s) ps2, err := ProtocolsWithString(s)
if err != nil { if err != nil {
t.Error("ProtocolsWithString(%s) should have succeeded", s) t.Errorf("ProtocolsWithString(%s) should have succeeded", s)
} }
for i, ps1p := range ps1 { for i, ps1p := range ps1 {
@ -266,7 +277,7 @@ func TestProtocolsWithString(t *testing.T) {
for _, s := range bad { for _, s := range bad {
if _, err := ProtocolsWithString(s); err == nil { if _, err := ProtocolsWithString(s); err == nil {
t.Error("ProtocolsWithString(%s) should have failed", s) t.Errorf("ProtocolsWithString(%s) should have failed", s)
} }
} }
@ -309,7 +320,7 @@ func assertValueForProto(t *testing.T, a Multiaddr, p int, exp string) {
} }
if fv != exp { if fv != exp {
t.Fatalf("expected %q for %d in %d, but got %q instead", exp, p, a, fv) t.Fatalf("expected %q for %d in %s, but got %q instead", exp, p, a, fv)
} }
} }
@ -342,3 +353,55 @@ func TestGetValue(t *testing.T) {
assertValueForProto(t, a, P_UDP, "12345") assertValueForProto(t, a, P_UDP, "12345")
assertValueForProto(t, a, P_UTP, "") assertValueForProto(t, a, P_UTP, "")
} }
// TestFuzzBytes feeds random byte strings to NewMultiaddrBytes and, for every
// input it accepts, calls String and Protocols on the result. It makes no
// assertions: it passes as long as none of those calls panics.
func TestFuzzBytes(t *testing.T) {
	// Log the seed so a failing run can be replayed by hard-coding it here.
	seed := time.Now().UnixNano()
	t.Logf("TestFuzzBytes seed: %d", seed)
	rand.Seed(seed)

	// Bump up these numbers if you want to stress this.
	buf := make([]byte, 256)
	for i := 0; i < 2000; i++ {
		l := rand.Intn(len(buf))
		rand.Read(buf[:l])

		// Just checking that it doesn't panic.
		ma, err := NewMultiaddrBytes(buf[:l])
		if err != nil {
			continue
		}

		// For any valid multiaddr, make sure these calls don't panic.
		_ = ma.String()
		ma.Protocols()
	}
}
// randMaddrString builds a random "/"-separated string for fuzzing the
// multiaddr parser. It has between 0 and 255 parts; each part is, with
// probability 3/5, drawn from a small corpus of plausible tokens, and
// otherwise a run of up to 255 random bytes. The result always starts
// with "/".
func randMaddrString() string {
	corpus := []string{"tcp", "ip", "udp", "ipfs", "0.0.0.0", "127.0.0.1", "12345", "QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP"}

	count := rand.Intn(256)
	parts := make([]string, count)
	for i := range parts {
		if rand.Intn(5) < 3 {
			// Plausible token.
			parts[i] = corpus[rand.Intn(len(corpus))]
			continue
		}
		// Garbage token made of random bytes.
		junk := make([]byte, rand.Intn(256))
		rand.Read(junk)
		parts[i] = string(junk)
	}

	return "/" + strings.Join(parts, "/")
}
// TestFuzzString feeds strings from randMaddrString to NewMultiaddr and, for
// every input it accepts, calls String and Protocols on the result. It makes
// no assertions: it passes as long as none of those calls panics.
func TestFuzzString(t *testing.T) {
	// Log the seed so a failing run can be replayed by hard-coding it here.
	seed := time.Now().UnixNano()
	t.Logf("TestFuzzString seed: %d", seed)
	rand.Seed(seed)

	// Bump up these numbers if you want to stress this.
	for i := 0; i < 2000; i++ {
		// Just checking that it doesn't panic.
		ma, err := NewMultiaddr(randMaddrString())
		if err != nil {
			continue
		}

		// For any valid multiaddr, make sure these calls don't panic.
		_ = ma.String()
		ma.Protocols()
	}
}

View File

@ -117,16 +117,19 @@ func CodeToVarint(num int) []byte {
// VarintToCode converts a varint-encoded []byte to an integer protocol code // VarintToCode converts a varint-encoded []byte to an integer protocol code
func VarintToCode(buf []byte) int { func VarintToCode(buf []byte) int {
num, _ := ReadVarintCode(buf) num, _, err := ReadVarintCode(buf)
if err != nil {
panic(err)
}
return num return num
} }
// ReadVarintCode reads a varint code from the beginning of buf. // ReadVarintCode reads a varint code from the beginning of buf.
// returns the code, and the number of bytes read. // returns the code, and the number of bytes read.
func ReadVarintCode(buf []byte) (int, int) { func ReadVarintCode(buf []byte) (int, int, error) {
num, n := binary.Uvarint(buf) num, n := binary.Uvarint(buf)
if n < 0 { if n < 0 {
panic("varints larger than uint64 not yet supported") return 0, 0, fmt.Errorf("varints larger than uint64 not yet supported")
} }
return int(num), n return int(num), n, nil
} }