swarm: make stream.Protocol() return type protocol.ID
This commit is contained in:
parent
36c66c0e93
commit
f74add8a19
|
@ -116,10 +116,10 @@ func (h *BasicHost) newStreamHandler(s inet.Stream) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
s.SetProtocol(protocol.ID(protoID))
|
||||||
|
|
||||||
logStream := mstream.WrapStream(s, protocol.ID(protoID), h.bwc)
|
logStream := mstream.WrapStream(s, h.bwc)
|
||||||
|
|
||||||
s.SetProtocol(protoID)
|
|
||||||
go handle(protoID, logStream)
|
go handle(protoID, logStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,7 +155,7 @@ func (h *BasicHost) IDService() *identify.IDService {
|
||||||
func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) {
|
func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) {
|
||||||
h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
|
h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
|
||||||
is := rwc.(inet.Stream)
|
is := rwc.(inet.Stream)
|
||||||
is.SetProtocol(p)
|
is.SetProtocol(protocol.ID(p))
|
||||||
handler(is)
|
handler(is)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -166,7 +166,7 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler
|
||||||
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) {
|
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) {
|
||||||
h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
|
h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
|
||||||
is := rwc.(inet.Stream)
|
is := rwc.(inet.Stream)
|
||||||
is.SetProtocol(p)
|
is.SetProtocol(protocol.ID(p))
|
||||||
handler(is)
|
handler(is)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -187,27 +187,26 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
|
||||||
return h.newStream(ctx, p, pref)
|
return h.newStream(ctx, p, pref)
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastErr error
|
var protoStrs []string
|
||||||
for _, pid := range pids {
|
for _, pid := range pids {
|
||||||
s, err := h.newStream(ctx, p, pid)
|
protoStrs = append(protoStrs, string(pid))
|
||||||
if err != nil {
|
|
||||||
lastErr = err
|
|
||||||
log.Infof("NewStream to %s for %s failed: %s", p, pid, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = s.Read(nil)
|
|
||||||
if err != nil {
|
|
||||||
lastErr = err
|
|
||||||
log.Infof("NewStream to %s for %s failed (on read): %s", p, pid, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
h.setPreferredProtocol(p, pid)
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, lastErr
|
s, err := h.Network().NewStream(ctx, p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
selected, err := msmux.SelectOneOf(protoStrs, s)
|
||||||
|
if err != nil {
|
||||||
|
s.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
selpid := protocol.ID(selected)
|
||||||
|
s.SetProtocol(selpid)
|
||||||
|
h.setPreferredProtocol(p, selpid)
|
||||||
|
|
||||||
|
return mstream.WrapStream(s, h.bwc), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) protocol.ID {
|
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) protocol.ID {
|
||||||
|
@ -257,9 +256,9 @@ func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) (
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SetProtocol(string(pid))
|
s.SetProtocol(pid)
|
||||||
|
|
||||||
logStream := mstream.WrapStream(s, pid, h.bwc)
|
logStream := mstream.WrapStream(s, h.bwc)
|
||||||
|
|
||||||
lzcon := msmux.NewMSSelect(logStream, string(pid))
|
lzcon := msmux.NewMSSelect(logStream, string(pid))
|
||||||
return &streamWrapper{
|
return &streamWrapper{
|
||||||
|
|
|
@ -76,7 +76,7 @@ func getHostPair(ctx context.Context, t *testing.T) (host.Host, host.Host) {
|
||||||
return h1, h2
|
return h1, h2
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertWait(t *testing.T, c chan string, exp string) {
|
func assertWait(t *testing.T, c chan protocol.ID, exp protocol.ID) {
|
||||||
select {
|
select {
|
||||||
case proto := <-c:
|
case proto := <-c:
|
||||||
if proto != exp {
|
if proto != exp {
|
||||||
|
@ -99,7 +99,7 @@ func TestHostProtoPreference(t *testing.T) {
|
||||||
protoNew := protocol.ID("/testing/1.1.0")
|
protoNew := protocol.ID("/testing/1.1.0")
|
||||||
protoMinor := protocol.ID("/testing/1.2.0")
|
protoMinor := protocol.ID("/testing/1.2.0")
|
||||||
|
|
||||||
connectedOn := make(chan string, 16)
|
connectedOn := make(chan protocol.ID, 16)
|
||||||
|
|
||||||
handler := func(s inet.Stream) {
|
handler := func(s inet.Stream) {
|
||||||
connectedOn <- s.Protocol()
|
connectedOn <- s.Protocol()
|
||||||
|
@ -113,10 +113,10 @@ func TestHostProtoPreference(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertWait(t, connectedOn, string(protoOld))
|
assertWait(t, connectedOn, protoOld)
|
||||||
s.Close()
|
s.Close()
|
||||||
|
|
||||||
mfunc, err := host.MultistreamSemverMatcher(string(protoMinor))
|
mfunc, err := host.MultistreamSemverMatcher(protoMinor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -135,7 +135,7 @@ func TestHostProtoPreference(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertWait(t, connectedOn, string(protoOld))
|
assertWait(t, connectedOn, protoOld)
|
||||||
|
|
||||||
s2.Close()
|
s2.Close()
|
||||||
|
|
||||||
|
@ -144,12 +144,7 @@ func TestHostProtoPreference(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s3.Read(nil)
|
assertWait(t, connectedOn, protoMinor)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assertWait(t, connectedOn, string(protoMinor))
|
|
||||||
s3.Close()
|
s3.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,7 +174,7 @@ func TestHostProtoPreknowledge(t *testing.T) {
|
||||||
h1 := testutil.GenHostSwarm(t, ctx)
|
h1 := testutil.GenHostSwarm(t, ctx)
|
||||||
h2 := testutil.GenHostSwarm(t, ctx)
|
h2 := testutil.GenHostSwarm(t, ctx)
|
||||||
|
|
||||||
conn := make(chan string, 16)
|
conn := make(chan protocol.ID, 16)
|
||||||
handler := func(s inet.Stream) {
|
handler := func(s inet.Stream) {
|
||||||
conn <- s.Protocol()
|
conn <- s.Protocol()
|
||||||
s.Close()
|
s.Close()
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
package host
|
package host
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/libp2p/go-libp2p/p2p/protocol"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
semver "github.com/coreos/go-semver/semver"
|
semver "github.com/coreos/go-semver/semver"
|
||||||
)
|
)
|
||||||
|
|
||||||
func MultistreamSemverMatcher(base string) (func(string) bool, error) {
|
func MultistreamSemverMatcher(base protocol.ID) (func(string) bool, error) {
|
||||||
parts := strings.Split(base, "/")
|
parts := strings.Split(string(base), "/")
|
||||||
vers, err := semver.NewVersion(parts[len(parts)-1])
|
vers, err := semver.NewVersion(parts[len(parts)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -19,18 +19,18 @@ type meteredStream struct {
|
||||||
mesRecv metrics.StreamMeterCallback
|
mesRecv metrics.StreamMeterCallback
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMeteredStream(base inet.Stream, pid protocol.ID, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream {
|
func newMeteredStream(base inet.Stream, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream {
|
||||||
return &meteredStream{
|
return &meteredStream{
|
||||||
Stream: base,
|
Stream: base,
|
||||||
mesSent: sentCB,
|
mesSent: sentCB,
|
||||||
mesRecv: recvCB,
|
mesRecv: recvCB,
|
||||||
protoKey: pid,
|
protoKey: base.Protocol(),
|
||||||
peerKey: p,
|
peerKey: p,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WrapStream(base inet.Stream, pid protocol.ID, bwc metrics.Reporter) inet.Stream {
|
func WrapStream(base inet.Stream, bwc metrics.Reporter) inet.Stream {
|
||||||
return newMeteredStream(base, pid, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream)
|
return newMeteredStream(base, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *meteredStream) Read(b []byte) (int, error) {
|
func (s *meteredStream) Read(b []byte) (int, error) {
|
||||||
|
|
|
@ -24,6 +24,10 @@ func (fs *FakeStream) Write(b []byte) (int, error) {
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fs *FakeStream) Protocol() protocol.ID {
|
||||||
|
return "TEST"
|
||||||
|
}
|
||||||
|
|
||||||
func TestCallbacksWork(t *testing.T) {
|
func TestCallbacksWork(t *testing.T) {
|
||||||
fake := new(FakeStream)
|
fake := new(FakeStream)
|
||||||
|
|
||||||
|
@ -38,7 +42,7 @@ func TestCallbacksWork(t *testing.T) {
|
||||||
recv += n
|
recv += n
|
||||||
}
|
}
|
||||||
|
|
||||||
ms := newMeteredStream(fake, protocol.ID("TEST"), peer.ID("PEER"), recvCB, sentCB)
|
ms := newMeteredStream(fake, peer.ID("PEER"), recvCB, sentCB)
|
||||||
|
|
||||||
toWrite := int64(100000)
|
toWrite := int64(100000)
|
||||||
toRead := int64(100000)
|
toRead := int64(100000)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
ma "github.com/jbenet/go-multiaddr"
|
ma "github.com/jbenet/go-multiaddr"
|
||||||
"github.com/jbenet/goprocess"
|
"github.com/jbenet/goprocess"
|
||||||
conn "github.com/libp2p/go-libp2p/p2p/net/conn"
|
conn "github.com/libp2p/go-libp2p/p2p/net/conn"
|
||||||
|
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
|
||||||
context "golang.org/x/net/context"
|
context "golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,8 +27,8 @@ type Stream interface {
|
||||||
io.Writer
|
io.Writer
|
||||||
io.Closer
|
io.Closer
|
||||||
|
|
||||||
Protocol() string
|
Protocol() protocol.ID
|
||||||
SetProtocol(string)
|
SetProtocol(protocol.ID)
|
||||||
|
|
||||||
// Conn returns the connection this stream is part of.
|
// Conn returns the connection this stream is part of.
|
||||||
Conn() Conn
|
Conn() Conn
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
process "github.com/jbenet/goprocess"
|
process "github.com/jbenet/goprocess"
|
||||||
inet "github.com/libp2p/go-libp2p/p2p/net"
|
inet "github.com/libp2p/go-libp2p/p2p/net"
|
||||||
|
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
// stream implements inet.Stream
|
// stream implements inet.Stream
|
||||||
|
@ -17,7 +18,7 @@ type stream struct {
|
||||||
toDeliver chan *transportObject
|
toDeliver chan *transportObject
|
||||||
proc process.Process
|
proc process.Process
|
||||||
|
|
||||||
protocol string
|
protocol protocol.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type transportObject struct {
|
type transportObject struct {
|
||||||
|
@ -50,11 +51,11 @@ func (s *stream) Write(p []byte) (n int, err error) {
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stream) Protocol() string {
|
func (s *stream) Protocol() protocol.ID {
|
||||||
return s.protocol
|
return s.protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stream) SetProtocol(proto string) {
|
func (s *stream) SetProtocol(proto protocol.ID) {
|
||||||
s.protocol = proto
|
s.protocol = proto
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package swarm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
inet "github.com/libp2p/go-libp2p/p2p/net"
|
inet "github.com/libp2p/go-libp2p/p2p/net"
|
||||||
|
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
|
||||||
|
|
||||||
ps "github.com/jbenet/go-peerstream"
|
ps "github.com/jbenet/go-peerstream"
|
||||||
)
|
)
|
||||||
|
@ -10,7 +11,7 @@ import (
|
||||||
// our Conn and Swarm (instead of just the ps.Conn and ps.Swarm)
|
// our Conn and Swarm (instead of just the ps.Conn and ps.Swarm)
|
||||||
type Stream struct {
|
type Stream struct {
|
||||||
stream *ps.Stream
|
stream *ps.Stream
|
||||||
protocol string
|
protocol protocol.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream returns the underlying peerstream.Stream
|
// Stream returns the underlying peerstream.Stream
|
||||||
|
@ -44,11 +45,11 @@ func (s *Stream) Close() error {
|
||||||
return s.stream.Close()
|
return s.stream.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) Protocol() string {
|
func (s *Stream) Protocol() protocol.ID {
|
||||||
return s.protocol
|
return s.protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) SetProtocol(p string) {
|
func (s *Stream) SetProtocol(p protocol.ID) {
|
||||||
s.protocol = p
|
s.protocol = p
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -86,8 +86,10 @@ func (ids *IDService) IdentifyConn(c inet.Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.SetProtocol(ID)
|
||||||
|
|
||||||
bwc := ids.Host.GetBandwidthReporter()
|
bwc := ids.Host.GetBandwidthReporter()
|
||||||
s = mstream.WrapStream(s, ID, bwc)
|
s = mstream.WrapStream(s, bwc)
|
||||||
|
|
||||||
// ok give the response to our handler.
|
// ok give the response to our handler.
|
||||||
if err := msmux.SelectProtoOrFail(ID, s); err != nil {
|
if err := msmux.SelectProtoOrFail(ID, s); err != nil {
|
||||||
|
@ -115,7 +117,7 @@ func (ids *IDService) RequestHandler(s inet.Stream) {
|
||||||
c := s.Conn()
|
c := s.Conn()
|
||||||
|
|
||||||
bwc := ids.Host.GetBandwidthReporter()
|
bwc := ids.Host.GetBandwidthReporter()
|
||||||
s = mstream.WrapStream(s, ID, bwc)
|
s = mstream.WrapStream(s, bwc)
|
||||||
|
|
||||||
w := ggio.NewDelimitedWriter(s)
|
w := ggio.NewDelimitedWriter(s)
|
||||||
mes := pb.Identify{}
|
mes := pb.Identify{}
|
||||||
|
|
Loading…
Reference in New Issue