diff --git a/multiaddr_test.go b/multiaddr_test.go index f9f0a47..6545897 100644 --- a/multiaddr_test.go +++ b/multiaddr_test.go @@ -216,6 +216,46 @@ func TestProtocols(t *testing.T) { } +func TestProtocolsWithString(t *testing.T) { + pwn := ProtocolWithName + good := map[string][]Protocol{ + "/ip4": []Protocol{pwn("ip4")}, + "/ip4/tcp": []Protocol{pwn("ip4"), pwn("tcp")}, + "ip4/tcp/udp/ip6": []Protocol{pwn("ip4"), pwn("tcp"), pwn("udp"), pwn("ip6")}, + "////////ip4/tcp": []Protocol{pwn("ip4"), pwn("tcp")}, + "ip4/udp/////////": []Protocol{pwn("ip4"), pwn("udp")}, + "////////ip4/tcp////////": []Protocol{pwn("ip4"), pwn("tcp")}, + } + + for s, ps1 := range good { + ps2, err := ProtocolsWithString(s) + if err != nil { + t.Error("ProtocolsWithString(%s) should have succeeded", s) + } + + for i, ps1p := range ps1 { + ps2p := ps2[i] + if ps1p.Code != ps2p.Code { + t.Errorf("mismatch: %s != %s, %s", ps1p.Name, ps2p.Name, s) + } + } + } + + bad := []string{ + "dsijafd", // bogus proto + "/ip4/tcp/fidosafoidsa", // bogus proto + "////////ip4/tcp/21432141/////////", // bogus proto + "////////ip4///////tcp/////////", // empty protos in between + } + + for _, s := range bad { + if _, err := ProtocolsWithString(s); err == nil { + t.Error("ProtocolsWithString(%s) should have failed", s) + } + } + +} + func TestEncapsulate(t *testing.T) { m, err := NewMultiaddr("/ip4/127.0.0.1/udp/1234") if err != nil { diff --git a/protocols.go b/protocols.go index 3be0b30..eaddc61 100644 --- a/protocols.go +++ b/protocols.go @@ -2,6 +2,8 @@ package multiaddr import ( "encoding/binary" + "fmt" + "strings" ) // Protocol is a Multiaddr protocol description structure. @@ -62,6 +64,25 @@ func ProtocolWithCode(c int) Protocol { return Protocol{} } +// ProtocolsWithString returns a slice of protocols matching given string. +func ProtocolsWithString(s string) ([]Protocol, error) { + s = strings.Trim(s, "/") + sp := strings.Split(s, "/") + if len(sp) == 0 { + return nil, nil + } + + t := make([]Protocol, len(sp)) + for i, name := range sp { + p := ProtocolWithName(name) + if p.Code == 0 { + return nil, fmt.Errorf("no protocol with name: %s", name) + } + t[i] = p + } + return t, nil +} + // CodeToVarint converts an integer to a varint-encoded []byte func CodeToVarint(num int) []byte { buf := make([]byte, (num/7)+1) // varint package is uint64