diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 2a697556..68b3e08e 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -156,6 +156,17 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler }) } +// SetStreamHandlerMatch sets the protocol handler on the Host's Mux +// using a matching function to do protocol comparisons +func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) { + h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error { + is := rwc.(inet.Stream) + is.SetProtocol(p) + handler(is) + return nil + }) +} + // RemoveStreamHandler returns .. func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { h.Mux().RemoveHandler(string(pid)) diff --git a/p2p/host/host.go b/p2p/host/host.go index 40bb6881..65810e03 100644 --- a/p2p/host/host.go +++ b/p2p/host/host.go @@ -48,6 +48,10 @@ type Host interface { // (Threadsafe) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) + // SetStreamHandlerMatch sets the protocol handler on the Host's Mux + // using a matching function for protocol selection. + SetStreamHandlerMatch(protocol.ID, func(string) bool, inet.StreamHandler) + // RemoveStreamHandler removes a handler on the mux that was set by // SetStreamHandler RemoveStreamHandler(pid protocol.ID) diff --git a/p2p/host/match.go b/p2p/host/match.go new file mode 100644 index 00000000..dfee37e2 --- /dev/null +++ b/p2p/host/match.go @@ -0,0 +1,35 @@ +package host + +import ( + "strings" + + semver "github.com/coreos/go-semver/semver" +) + +func MultistreamSemverMatcher(base string) (func(string) bool, error) { + parts := strings.Split(base, "/") + vers, err := semver.NewVersion(parts[len(parts)-1]) + if err != nil { + return nil, err + } + + return func(check string) bool { + chparts := strings.Split(check, "/") + if len(chparts) != len(parts) { + return false + } + + for i, v := range chparts[:len(chparts)-1] { + if parts[i] != v { + return false + } + } + + chvers, err := semver.NewVersion(chparts[len(chparts)-1]) + if err != nil { + return false + } + + return vers.Major == chvers.Major && vers.Minor >= chvers.Minor + }, nil +} diff --git a/p2p/host/match_test.go b/p2p/host/match_test.go new file mode 100644 index 00000000..14bb9553 --- /dev/null +++ b/p2p/host/match_test.go @@ -0,0 +1,33 @@ +package host + +import ( + "testing" +) + +func TestSemverMatching(t *testing.T) { + m, err := MultistreamSemverMatcher("/testing/4.3.5") + if err != nil { + t.Fatal(err) + } + + cases := map[string]bool{ + "/testing/4.3.0": true, + "/testing/4.3.7": true, + "/testing/4.3.5": true, + "/testing/4.2.7": true, + "/testing/4.0.0": true, + "/testing/5.0.0": false, + "/cars/dogs/4.3.5": false, + "/foo/1.0.0": false, + "": false, + "dogs": false, + "/foo": false, + "/foo/1.1.1.1": false, + } + + for p, ok := range cases { + if m(p) != ok { + t.Fatalf("expected %s to be %t", p, ok) + } + } +} diff --git a/p2p/host/routed/routed.go b/p2p/host/routed/routed.go index 3f04a072..4d25afdd 100644 --- a/p2p/host/routed/routed.go +++ b/p2p/host/routed/routed.go @@ -110,6 +110,10 @@ func (rh *RoutedHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandl rh.host.SetStreamHandler(pid, handler) } +func (rh *RoutedHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) { + rh.host.SetStreamHandlerMatch(pid, m, handler) +} + func (rh *RoutedHost) RemoveStreamHandler(pid protocol.ID) { rh.host.RemoveStreamHandler(pid) } @@ -125,3 +129,5 @@ func (rh *RoutedHost) Close() error { func (rh *RoutedHost) GetBandwidthReporter() metrics.Reporter { return rh.host.GetBandwidthReporter() } + +var _ (host.Host) = (*RoutedHost)(nil) diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 3c44f24a..cba8b8ec 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -173,6 +173,7 @@ func (ids *IDService) consumeMessage(mes *pb.Identify, c inet.Conn) { p := c.RemotePeer() // mes.Protocols + ids.Host.Peerstore().SetProtocols(p, mes.Protocols) // mes.ObservedAddr ids.consumeObservedAddress(mes.GetObservedAddr(), c)