cleanup panics and make NewFromBytes faster
This commit is contained in:
parent
41d1117052
commit
9c4a0baf6d
98
codec.go
98
codec.go
|
@ -3,7 +3,6 @@ package multiaddr
|
|||
import (
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
@ -52,27 +51,42 @@ func stringToBytes(s string) ([]byte, error) {
|
|||
return b, nil
|
||||
}
|
||||
|
||||
func bytesToString(b []byte) (ret string, err error) {
|
||||
func validateBytes(b []byte) (err error) {
|
||||
// panic handler, in case we try accessing bytes incorrectly.
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
ret = ""
|
||||
switch e := e.(type) {
|
||||
case error:
|
||||
err = e
|
||||
case string:
|
||||
err = errors.New(e)
|
||||
default:
|
||||
err = fmt.Errorf("%v", e)
|
||||
}
|
||||
for len(b) > 0 {
|
||||
code, n, err := ReadVarintCode(b)
|
||||
b = b[n:]
|
||||
p := ProtocolWithCode(code)
|
||||
if p.Code == 0 {
|
||||
return fmt.Errorf("no protocol with code %d", code)
|
||||
}
|
||||
}()
|
||||
|
||||
if p.Size == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
size, err := sizeForAddr(p, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(b) < size {
|
||||
return fmt.Errorf("invalid value for size")
|
||||
}
|
||||
b = b[size:]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func bytesToString(b []byte) (ret string, err error) {
|
||||
s := ""
|
||||
|
||||
for len(b) > 0 {
|
||||
code, n, err := ReadVarintCode(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
code, n := ReadVarintCode(b)
|
||||
b = b[n:]
|
||||
p := ProtocolWithCode(code)
|
||||
if p.Code == 0 {
|
||||
|
@ -84,7 +98,11 @@ func bytesToString(b []byte) (ret string, err error) {
|
|||
continue
|
||||
}
|
||||
|
||||
size := sizeForAddr(p, b)
|
||||
size, err := sizeForAddr(p, b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
a, err := addressBytesToString(p, b[:size])
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -98,36 +116,40 @@ func bytesToString(b []byte) (ret string, err error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func sizeForAddr(p Protocol, b []byte) int {
|
||||
func sizeForAddr(p Protocol, b []byte) (int, error) {
|
||||
switch {
|
||||
case p.Size > 0:
|
||||
return (p.Size / 8)
|
||||
return (p.Size / 8), nil
|
||||
case p.Size == 0:
|
||||
return 0
|
||||
return 0, nil
|
||||
default:
|
||||
size, n := ReadVarintCode(b)
|
||||
return size + n
|
||||
size, n, err := ReadVarintCode(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return size + n, nil
|
||||
}
|
||||
}
|
||||
|
||||
func bytesSplit(b []byte) (ret [][]byte, err error) {
|
||||
// panic handler, in case we try accessing bytes incorrectly.
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
ret = [][]byte{}
|
||||
err = e.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
ret = [][]byte{}
|
||||
func bytesSplit(b []byte) ([][]byte, error) {
|
||||
var ret [][]byte
|
||||
for len(b) > 0 {
|
||||
code, n := ReadVarintCode(b)
|
||||
code, n, err := ReadVarintCode(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := ProtocolWithCode(code)
|
||||
if p.Code == 0 {
|
||||
return [][]byte{}, fmt.Errorf("no protocol with code %d", b[0])
|
||||
return nil, fmt.Errorf("no protocol with code %d", b[0])
|
||||
}
|
||||
|
||||
size, err := sizeForAddr(p, b[n:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
size := sizeForAddr(p, b[n:])
|
||||
length := n + size
|
||||
ret = append(ret, b[:length])
|
||||
b = b[length:]
|
||||
|
@ -228,10 +250,14 @@ func addressBytesToString(p Protocol, b []byte) (string, error) {
|
|||
|
||||
case P_IPFS: // ipfs
|
||||
// the address is a varint-prefixed multihash string representation
|
||||
size, n := ReadVarintCode(b)
|
||||
size, n, err := ReadVarintCode(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
b = b[n:]
|
||||
if len(b) != size {
|
||||
panic("inconsistent lengths")
|
||||
return "", fmt.Errorf("inconsistent lengths")
|
||||
}
|
||||
m, err := mh.Cast(b)
|
||||
if err != nil {
|
||||
|
|
19
multiaddr.go
19
multiaddr.go
|
@ -23,11 +23,11 @@ func NewMultiaddr(s string) (Multiaddr, error) {
|
|||
// NewMultiaddrBytes initializes a Multiaddr from a byte representation.
|
||||
// It validates it as an input string.
|
||||
func NewMultiaddrBytes(b []byte) (Multiaddr, error) {
|
||||
s, err := bytesToString(b)
|
||||
if err != nil {
|
||||
if err := validateBytes(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewMultiaddr(s)
|
||||
|
||||
return &multiaddr{bytes: b}, nil
|
||||
}
|
||||
|
||||
// Equal tests whether two multiaddrs are equal
|
||||
|
@ -64,11 +64,14 @@ func (m *multiaddr) Protocols() []Protocol {
|
|||
}
|
||||
}()
|
||||
|
||||
size := 0
|
||||
ps := []Protocol{}
|
||||
b := m.bytes[:]
|
||||
var ps []Protocol
|
||||
b := m.bytes
|
||||
for len(b) > 0 {
|
||||
code, n := ReadVarintCode(b)
|
||||
code, n, err := ReadVarintCode(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
p := ProtocolWithCode(code)
|
||||
if p.Code == 0 {
|
||||
// this is a panic (and not returning err) because this should've been
|
||||
|
@ -78,7 +81,7 @@ func (m *multiaddr) Protocols() []Protocol {
|
|||
ps = append(ps, p)
|
||||
b = b[n:]
|
||||
|
||||
size = sizeForAddr(p, b)
|
||||
size, err := sizeForAddr(p, b)
|
||||
b = b[size:]
|
||||
}
|
||||
return ps
|
||||
|
|
|
@ -3,7 +3,9 @@ package multiaddr
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newMultiaddr(t *testing.T, a string) Multiaddr {
|
||||
|
@ -342,3 +344,21 @@ func TestGetValue(t *testing.T) {
|
|||
assertValueForProto(t, a, P_UDP, "12345")
|
||||
assertValueForProto(t, a, P_UTP, "")
|
||||
}
|
||||
|
||||
func TestFuzzBytes(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
// Bump up these numbers if you want to stress this
|
||||
buf := make([]byte, 256)
|
||||
for i := 0; i < 2000; i++ {
|
||||
l := rand.Intn(len(buf))
|
||||
rand.Read(buf[:l])
|
||||
|
||||
// just checking that it doesnt panic
|
||||
ma, err := NewMultiaddrBytes(buf[:l])
|
||||
if err == nil {
|
||||
// for any valid multiaddrs, make sure these calls don't panic
|
||||
ma.String()
|
||||
ma.Protocols()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
11
protocols.go
11
protocols.go
|
@ -117,16 +117,19 @@ func CodeToVarint(num int) []byte {
|
|||
|
||||
// VarintToCode converts a varint-encoded []byte to an integer protocol code
|
||||
func VarintToCode(buf []byte) int {
|
||||
num, _ := ReadVarintCode(buf)
|
||||
num, _, err := ReadVarintCode(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
// ReadVarintCode reads a varint code from the beginning of buf.
|
||||
// returns the code, and the number of bytes read.
|
||||
func ReadVarintCode(buf []byte) (int, int) {
|
||||
func ReadVarintCode(buf []byte) (int, int, error) {
|
||||
num, n := binary.Uvarint(buf)
|
||||
if n < 0 {
|
||||
panic("varints larger than uint64 not yet supported")
|
||||
return 0, 0, fmt.Errorf("varints larger than uint64 not yet supported")
|
||||
}
|
||||
return int(num), n
|
||||
return int(num), n, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue