mirror of
https://github.com/logos-messaging/go-multiaddr.git
synced 2026-01-05 22:43:10 +00:00
Merge pull request #25 from whyrusleeping/cleanup-new-bytes
cleanup panics and make NewFromBytes faster
This commit is contained in:
commit
f3dff105e4
@ -1,9 +1,8 @@
|
|||||||
language: go
|
language: go
|
||||||
|
|
||||||
go:
|
go:
|
||||||
- 1.3
|
|
||||||
- release
|
|
||||||
- tip
|
- tip
|
||||||
|
- 1.6.1
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- go test -race -cpu=5 -v ./...
|
- go test -race -cpu=5 -v ./...
|
||||||
|
|||||||
126
codec.go
126
codec.go
@ -1,9 +1,9 @@
|
|||||||
package multiaddr
|
package multiaddr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -17,7 +17,7 @@ func stringToBytes(s string) ([]byte, error) {
|
|||||||
// consume trailing slashes
|
// consume trailing slashes
|
||||||
s = strings.TrimRight(s, "/")
|
s = strings.TrimRight(s, "/")
|
||||||
|
|
||||||
b := []byte{}
|
b := new(bytes.Buffer)
|
||||||
sp := strings.Split(s, "/")
|
sp := strings.Split(s, "/")
|
||||||
|
|
||||||
if sp[0] != "" {
|
if sp[0] != "" {
|
||||||
@ -32,7 +32,7 @@ func stringToBytes(s string) ([]byte, error) {
|
|||||||
if p.Code == 0 {
|
if p.Code == 0 {
|
||||||
return nil, fmt.Errorf("no protocol with name %s", sp[0])
|
return nil, fmt.Errorf("no protocol with name %s", sp[0])
|
||||||
}
|
}
|
||||||
b = append(b, CodeToVarint(p.Code)...)
|
b.Write(CodeToVarint(p.Code))
|
||||||
sp = sp[1:]
|
sp = sp[1:]
|
||||||
|
|
||||||
if p.Size == 0 { // no length.
|
if p.Size == 0 { // no length.
|
||||||
@ -42,37 +42,59 @@ func stringToBytes(s string) ([]byte, error) {
|
|||||||
if len(sp) < 1 {
|
if len(sp) < 1 {
|
||||||
return nil, fmt.Errorf("protocol requires address, none given: %s", p.Name)
|
return nil, fmt.Errorf("protocol requires address, none given: %s", p.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
a, err := addressStringToBytes(p, sp[0])
|
a, err := addressStringToBytes(p, sp[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse %s: %s %s", p.Name, sp[0], err)
|
return nil, fmt.Errorf("failed to parse %s: %s %s", p.Name, sp[0], err)
|
||||||
}
|
}
|
||||||
b = append(b, a...)
|
b.Write(a)
|
||||||
sp = sp[1:]
|
sp = sp[1:]
|
||||||
}
|
}
|
||||||
return b, nil
|
|
||||||
|
return b.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateBytes(b []byte) (err error) {
|
||||||
|
for len(b) > 0 {
|
||||||
|
code, n, err := ReadVarintCode(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
b = b[n:]
|
||||||
|
p := ProtocolWithCode(code)
|
||||||
|
if p.Code == 0 {
|
||||||
|
return fmt.Errorf("no protocol with code %d", code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Size == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
size, err := sizeForAddr(p, b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(b) < size || size < 0 {
|
||||||
|
return fmt.Errorf("invalid value for size")
|
||||||
|
}
|
||||||
|
|
||||||
|
b = b[size:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func bytesToString(b []byte) (ret string, err error) {
|
func bytesToString(b []byte) (ret string, err error) {
|
||||||
// panic handler, in case we try accessing bytes incorrectly.
|
|
||||||
defer func() {
|
|
||||||
if e := recover(); e != nil {
|
|
||||||
ret = ""
|
|
||||||
switch e := e.(type) {
|
|
||||||
case error:
|
|
||||||
err = e
|
|
||||||
case string:
|
|
||||||
err = errors.New(e)
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("%v", e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
s := ""
|
s := ""
|
||||||
|
|
||||||
for len(b) > 0 {
|
for len(b) > 0 {
|
||||||
|
code, n, err := ReadVarintCode(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
code, n := ReadVarintCode(b)
|
|
||||||
b = b[n:]
|
b = b[n:]
|
||||||
p := ProtocolWithCode(code)
|
p := ProtocolWithCode(code)
|
||||||
if p.Code == 0 {
|
if p.Code == 0 {
|
||||||
@ -84,7 +106,15 @@ func bytesToString(b []byte) (ret string, err error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
size := sizeForAddr(p, b)
|
size, err := sizeForAddr(p, b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(b) < size || size < 0 {
|
||||||
|
return "", fmt.Errorf("invalid value for size")
|
||||||
|
}
|
||||||
|
|
||||||
a, err := addressBytesToString(p, b[:size])
|
a, err := addressBytesToString(p, b[:size])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -98,36 +128,40 @@ func bytesToString(b []byte) (ret string, err error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sizeForAddr(p Protocol, b []byte) int {
|
func sizeForAddr(p Protocol, b []byte) (int, error) {
|
||||||
switch {
|
switch {
|
||||||
case p.Size > 0:
|
case p.Size > 0:
|
||||||
return (p.Size / 8)
|
return (p.Size / 8), nil
|
||||||
case p.Size == 0:
|
case p.Size == 0:
|
||||||
return 0
|
return 0, nil
|
||||||
default:
|
default:
|
||||||
size, n := ReadVarintCode(b)
|
size, n, err := ReadVarintCode(b)
|
||||||
return size + n
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return size + n, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func bytesSplit(b []byte) (ret [][]byte, err error) {
|
func bytesSplit(b []byte) ([][]byte, error) {
|
||||||
// panic handler, in case we try accessing bytes incorrectly.
|
var ret [][]byte
|
||||||
defer func() {
|
|
||||||
if e := recover(); e != nil {
|
|
||||||
ret = [][]byte{}
|
|
||||||
err = e.(error)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
ret = [][]byte{}
|
|
||||||
for len(b) > 0 {
|
for len(b) > 0 {
|
||||||
code, n := ReadVarintCode(b)
|
code, n, err := ReadVarintCode(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
p := ProtocolWithCode(code)
|
p := ProtocolWithCode(code)
|
||||||
if p.Code == 0 {
|
if p.Code == 0 {
|
||||||
return [][]byte{}, fmt.Errorf("no protocol with code %d", b[0])
|
return nil, fmt.Errorf("no protocol with code %d", b[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
size, err := sizeForAddr(p, b[n:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
size := sizeForAddr(p, b[n:])
|
|
||||||
length := n + size
|
length := n + size
|
||||||
ret = append(ret, b[:length])
|
ret = append(ret, b[:length])
|
||||||
b = b[length:]
|
b = b[length:]
|
||||||
@ -228,17 +262,21 @@ func addressBytesToString(p Protocol, b []byte) (string, error) {
|
|||||||
|
|
||||||
case P_IPFS: // ipfs
|
case P_IPFS: // ipfs
|
||||||
// the address is a varint-prefixed multihash string representation
|
// the address is a varint-prefixed multihash string representation
|
||||||
size, n := ReadVarintCode(b)
|
size, n, err := ReadVarintCode(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
b = b[n:]
|
b = b[n:]
|
||||||
if len(b) != size {
|
if len(b) != size {
|
||||||
panic("inconsistent lengths")
|
return "", fmt.Errorf("inconsistent lengths")
|
||||||
}
|
}
|
||||||
m, err := mh.Cast(b)
|
m, err := mh.Cast(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return m.B58String(), nil
|
return m.B58String(), nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unknown protocol")
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("unknown protocol")
|
|
||||||
}
|
}
|
||||||
|
|||||||
41
multiaddr.go
41
multiaddr.go
@ -3,6 +3,7 @@ package multiaddr
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -12,7 +13,13 @@ type multiaddr struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewMultiaddr parses and validates an input string, returning a *Multiaddr
|
// NewMultiaddr parses and validates an input string, returning a *Multiaddr
|
||||||
func NewMultiaddr(s string) (Multiaddr, error) {
|
func NewMultiaddr(s string) (a Multiaddr, err error) {
|
||||||
|
defer func() {
|
||||||
|
if e := recover(); e != nil {
|
||||||
|
log.Printf("Panic in NewMultiaddr on input %q: %s", s, e)
|
||||||
|
err = fmt.Errorf("%v", e)
|
||||||
|
}
|
||||||
|
}()
|
||||||
b, err := stringToBytes(s)
|
b, err := stringToBytes(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -22,12 +29,19 @@ func NewMultiaddr(s string) (Multiaddr, error) {
|
|||||||
|
|
||||||
// NewMultiaddrBytes initializes a Multiaddr from a byte representation.
|
// NewMultiaddrBytes initializes a Multiaddr from a byte representation.
|
||||||
// It validates it as an input string.
|
// It validates it as an input string.
|
||||||
func NewMultiaddrBytes(b []byte) (Multiaddr, error) {
|
func NewMultiaddrBytes(b []byte) (a Multiaddr, err error) {
|
||||||
s, err := bytesToString(b)
|
defer func() {
|
||||||
if err != nil {
|
if e := recover(); e != nil {
|
||||||
|
log.Printf("Panic in NewMultiaddrBytes on input %q: %s", b, e)
|
||||||
|
err = fmt.Errorf("%v", e)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := validateBytes(b); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewMultiaddr(s)
|
|
||||||
|
return &multiaddr{bytes: b}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Equal tests whether two multiaddrs are equal
|
// Equal tests whether two multiaddrs are equal
|
||||||
@ -64,11 +78,14 @@ func (m *multiaddr) Protocols() []Protocol {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
size := 0
|
var ps []Protocol
|
||||||
ps := []Protocol{}
|
b := m.bytes
|
||||||
b := m.bytes[:]
|
|
||||||
for len(b) > 0 {
|
for len(b) > 0 {
|
||||||
code, n := ReadVarintCode(b)
|
code, n, err := ReadVarintCode(b)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
p := ProtocolWithCode(code)
|
p := ProtocolWithCode(code)
|
||||||
if p.Code == 0 {
|
if p.Code == 0 {
|
||||||
// this is a panic (and not returning err) because this should've been
|
// this is a panic (and not returning err) because this should've been
|
||||||
@ -78,7 +95,11 @@ func (m *multiaddr) Protocols() []Protocol {
|
|||||||
ps = append(ps, p)
|
ps = append(ps, p)
|
||||||
b = b[n:]
|
b = b[n:]
|
||||||
|
|
||||||
size = sizeForAddr(p, b)
|
size, err := sizeForAddr(p, b)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
b = b[size:]
|
b = b[size:]
|
||||||
}
|
}
|
||||||
return ps
|
return ps
|
||||||
|
|||||||
@ -3,7 +3,10 @@ package multiaddr
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newMultiaddr(t *testing.T, a string) Multiaddr {
|
func newMultiaddr(t *testing.T, a string) Multiaddr {
|
||||||
@ -138,6 +141,10 @@ func TestStringToBytes(t *testing.T) {
|
|||||||
if !bytes.Equal(b1, b2) {
|
if !bytes.Equal(b1, b2) {
|
||||||
t.Error("failed to convert", s, "to", b1, "got", b2)
|
t.Error("failed to convert", s, "to", b1, "got", b2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := validateBytes(b2); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
testString("/ip4/127.0.0.1/udp/1234", "047f0000011104d2")
|
testString("/ip4/127.0.0.1/udp/1234", "047f0000011104d2")
|
||||||
@ -153,6 +160,10 @@ func TestBytesToString(t *testing.T) {
|
|||||||
t.Error("failed to decode hex", h)
|
t.Error("failed to decode hex", h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := validateBytes(b); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
s2, err := bytesToString(b)
|
s2, err := bytesToString(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to convert", b)
|
t.Error("failed to convert", b)
|
||||||
@ -246,7 +257,7 @@ func TestProtocolsWithString(t *testing.T) {
|
|||||||
for s, ps1 := range good {
|
for s, ps1 := range good {
|
||||||
ps2, err := ProtocolsWithString(s)
|
ps2, err := ProtocolsWithString(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("ProtocolsWithString(%s) should have succeeded", s)
|
t.Errorf("ProtocolsWithString(%s) should have succeeded", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, ps1p := range ps1 {
|
for i, ps1p := range ps1 {
|
||||||
@ -266,7 +277,7 @@ func TestProtocolsWithString(t *testing.T) {
|
|||||||
|
|
||||||
for _, s := range bad {
|
for _, s := range bad {
|
||||||
if _, err := ProtocolsWithString(s); err == nil {
|
if _, err := ProtocolsWithString(s); err == nil {
|
||||||
t.Error("ProtocolsWithString(%s) should have failed", s)
|
t.Errorf("ProtocolsWithString(%s) should have failed", s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -309,7 +320,7 @@ func assertValueForProto(t *testing.T, a Multiaddr, p int, exp string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if fv != exp {
|
if fv != exp {
|
||||||
t.Fatalf("expected %q for %d in %d, but got %q instead", exp, p, a, fv)
|
t.Fatalf("expected %q for %d in %s, but got %q instead", exp, p, a, fv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -342,3 +353,55 @@ func TestGetValue(t *testing.T) {
|
|||||||
assertValueForProto(t, a, P_UDP, "12345")
|
assertValueForProto(t, a, P_UDP, "12345")
|
||||||
assertValueForProto(t, a, P_UTP, "")
|
assertValueForProto(t, a, P_UTP, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFuzzBytes(t *testing.T) {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
// Bump up these numbers if you want to stress this
|
||||||
|
buf := make([]byte, 256)
|
||||||
|
for i := 0; i < 2000; i++ {
|
||||||
|
l := rand.Intn(len(buf))
|
||||||
|
rand.Read(buf[:l])
|
||||||
|
|
||||||
|
// just checking that it doesnt panic
|
||||||
|
ma, err := NewMultiaddrBytes(buf[:l])
|
||||||
|
if err == nil {
|
||||||
|
// for any valid multiaddrs, make sure these calls don't panic
|
||||||
|
_ = ma.String()
|
||||||
|
ma.Protocols()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func randMaddrString() string {
|
||||||
|
good_corpus := []string{"tcp", "ip", "udp", "ipfs", "0.0.0.0", "127.0.0.1", "12345", "QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP"}
|
||||||
|
|
||||||
|
size := rand.Intn(256)
|
||||||
|
parts := make([]string, 0, size)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
switch rand.Intn(5) {
|
||||||
|
case 0, 1, 2:
|
||||||
|
parts = append(parts, good_corpus[rand.Intn(len(good_corpus))])
|
||||||
|
default:
|
||||||
|
badbuf := make([]byte, rand.Intn(256))
|
||||||
|
rand.Read(badbuf)
|
||||||
|
parts = append(parts, string(badbuf))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "/" + strings.Join(parts, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFuzzString(t *testing.T) {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
// Bump up these numbers if you want to stress this
|
||||||
|
for i := 0; i < 2000; i++ {
|
||||||
|
|
||||||
|
// just checking that it doesnt panic
|
||||||
|
ma, err := NewMultiaddr(randMaddrString())
|
||||||
|
if err == nil {
|
||||||
|
// for any valid multiaddrs, make sure these calls don't panic
|
||||||
|
_ = ma.String()
|
||||||
|
ma.Protocols()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
11
protocols.go
11
protocols.go
@ -117,16 +117,19 @@ func CodeToVarint(num int) []byte {
|
|||||||
|
|
||||||
// VarintToCode converts a varint-encoded []byte to an integer protocol code
|
// VarintToCode converts a varint-encoded []byte to an integer protocol code
|
||||||
func VarintToCode(buf []byte) int {
|
func VarintToCode(buf []byte) int {
|
||||||
num, _ := ReadVarintCode(buf)
|
num, _, err := ReadVarintCode(buf)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
return num
|
return num
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadVarintCode reads a varint code from the beginning of buf.
|
// ReadVarintCode reads a varint code from the beginning of buf.
|
||||||
// returns the code, and the number of bytes read.
|
// returns the code, and the number of bytes read.
|
||||||
func ReadVarintCode(buf []byte) (int, int) {
|
func ReadVarintCode(buf []byte) (int, int, error) {
|
||||||
num, n := binary.Uvarint(buf)
|
num, n := binary.Uvarint(buf)
|
||||||
if n < 0 {
|
if n < 0 {
|
||||||
panic("varints larger than uint64 not yet supported")
|
return 0, 0, fmt.Errorf("varints larger than uint64 not yet supported")
|
||||||
}
|
}
|
||||||
return int(num), n
|
return int(num), n, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user