swarm: make stream.Protocol() return type protocol.ID
This commit is contained in:
parent
36c66c0e93
commit
f74add8a19
|
@ -116,10 +116,10 @@ func (h *BasicHost) newStreamHandler(s inet.Stream) {
|
|||
}
|
||||
return
|
||||
}
|
||||
s.SetProtocol(protocol.ID(protoID))
|
||||
|
||||
logStream := mstream.WrapStream(s, protocol.ID(protoID), h.bwc)
|
||||
logStream := mstream.WrapStream(s, h.bwc)
|
||||
|
||||
s.SetProtocol(protoID)
|
||||
go handle(protoID, logStream)
|
||||
}
|
||||
|
||||
|
@ -155,7 +155,7 @@ func (h *BasicHost) IDService() *identify.IDService {
|
|||
func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) {
|
||||
h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
|
||||
is := rwc.(inet.Stream)
|
||||
is.SetProtocol(p)
|
||||
is.SetProtocol(protocol.ID(p))
|
||||
handler(is)
|
||||
return nil
|
||||
})
|
||||
|
@ -166,7 +166,7 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler
|
|||
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) {
|
||||
h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
|
||||
is := rwc.(inet.Stream)
|
||||
is.SetProtocol(p)
|
||||
is.SetProtocol(protocol.ID(p))
|
||||
handler(is)
|
||||
return nil
|
||||
})
|
||||
|
@ -187,27 +187,26 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
|
|||
return h.newStream(ctx, p, pref)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
var protoStrs []string
|
||||
for _, pid := range pids {
|
||||
s, err := h.newStream(ctx, p, pid)
|
||||
protoStrs = append(protoStrs, string(pid))
|
||||
}
|
||||
|
||||
s, err := h.Network().NewStream(ctx, p)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
log.Infof("NewStream to %s for %s failed: %s", p, pid, err)
|
||||
continue
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = s.Read(nil)
|
||||
selected, err := msmux.SelectOneOf(protoStrs, s)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
log.Infof("NewStream to %s for %s failed (on read): %s", p, pid, err)
|
||||
continue
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
selpid := protocol.ID(selected)
|
||||
s.SetProtocol(selpid)
|
||||
h.setPreferredProtocol(p, selpid)
|
||||
|
||||
h.setPreferredProtocol(p, pid)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
return mstream.WrapStream(s, h.bwc), nil
|
||||
}
|
||||
|
||||
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) protocol.ID {
|
||||
|
@ -257,9 +256,9 @@ func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) (
|
|||
return nil, err
|
||||
}
|
||||
|
||||
s.SetProtocol(string(pid))
|
||||
s.SetProtocol(pid)
|
||||
|
||||
logStream := mstream.WrapStream(s, pid, h.bwc)
|
||||
logStream := mstream.WrapStream(s, h.bwc)
|
||||
|
||||
lzcon := msmux.NewMSSelect(logStream, string(pid))
|
||||
return &streamWrapper{
|
||||
|
|
|
@ -76,7 +76,7 @@ func getHostPair(ctx context.Context, t *testing.T) (host.Host, host.Host) {
|
|||
return h1, h2
|
||||
}
|
||||
|
||||
func assertWait(t *testing.T, c chan string, exp string) {
|
||||
func assertWait(t *testing.T, c chan protocol.ID, exp protocol.ID) {
|
||||
select {
|
||||
case proto := <-c:
|
||||
if proto != exp {
|
||||
|
@ -99,7 +99,7 @@ func TestHostProtoPreference(t *testing.T) {
|
|||
protoNew := protocol.ID("/testing/1.1.0")
|
||||
protoMinor := protocol.ID("/testing/1.2.0")
|
||||
|
||||
connectedOn := make(chan string, 16)
|
||||
connectedOn := make(chan protocol.ID, 16)
|
||||
|
||||
handler := func(s inet.Stream) {
|
||||
connectedOn <- s.Protocol()
|
||||
|
@ -113,10 +113,10 @@ func TestHostProtoPreference(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertWait(t, connectedOn, string(protoOld))
|
||||
assertWait(t, connectedOn, protoOld)
|
||||
s.Close()
|
||||
|
||||
mfunc, err := host.MultistreamSemverMatcher(string(protoMinor))
|
||||
mfunc, err := host.MultistreamSemverMatcher(protoMinor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -135,7 +135,7 @@ func TestHostProtoPreference(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertWait(t, connectedOn, string(protoOld))
|
||||
assertWait(t, connectedOn, protoOld)
|
||||
|
||||
s2.Close()
|
||||
|
||||
|
@ -144,12 +144,7 @@ func TestHostProtoPreference(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = s3.Read(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertWait(t, connectedOn, string(protoMinor))
|
||||
assertWait(t, connectedOn, protoMinor)
|
||||
s3.Close()
|
||||
}
|
||||
|
||||
|
@ -179,7 +174,7 @@ func TestHostProtoPreknowledge(t *testing.T) {
|
|||
h1 := testutil.GenHostSwarm(t, ctx)
|
||||
h2 := testutil.GenHostSwarm(t, ctx)
|
||||
|
||||
conn := make(chan string, 16)
|
||||
conn := make(chan protocol.ID, 16)
|
||||
handler := func(s inet.Stream) {
|
||||
conn <- s.Protocol()
|
||||
s.Close()
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
package host
|
||||
|
||||
import (
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol"
|
||||
"strings"
|
||||
|
||||
semver "github.com/coreos/go-semver/semver"
|
||||
)
|
||||
|
||||
func MultistreamSemverMatcher(base string) (func(string) bool, error) {
|
||||
parts := strings.Split(base, "/")
|
||||
func MultistreamSemverMatcher(base protocol.ID) (func(string) bool, error) {
|
||||
parts := strings.Split(string(base), "/")
|
||||
vers, err := semver.NewVersion(parts[len(parts)-1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -19,18 +19,18 @@ type meteredStream struct {
|
|||
mesRecv metrics.StreamMeterCallback
|
||||
}
|
||||
|
||||
func newMeteredStream(base inet.Stream, pid protocol.ID, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream {
|
||||
func newMeteredStream(base inet.Stream, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream {
|
||||
return &meteredStream{
|
||||
Stream: base,
|
||||
mesSent: sentCB,
|
||||
mesRecv: recvCB,
|
||||
protoKey: pid,
|
||||
protoKey: base.Protocol(),
|
||||
peerKey: p,
|
||||
}
|
||||
}
|
||||
|
||||
func WrapStream(base inet.Stream, pid protocol.ID, bwc metrics.Reporter) inet.Stream {
|
||||
return newMeteredStream(base, pid, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream)
|
||||
func WrapStream(base inet.Stream, bwc metrics.Reporter) inet.Stream {
|
||||
return newMeteredStream(base, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream)
|
||||
}
|
||||
|
||||
func (s *meteredStream) Read(b []byte) (int, error) {
|
||||
|
|
|
@ -24,6 +24,10 @@ func (fs *FakeStream) Write(b []byte) (int, error) {
|
|||
return len(b), nil
|
||||
}
|
||||
|
||||
func (fs *FakeStream) Protocol() protocol.ID {
|
||||
return "TEST"
|
||||
}
|
||||
|
||||
func TestCallbacksWork(t *testing.T) {
|
||||
fake := new(FakeStream)
|
||||
|
||||
|
@ -38,7 +42,7 @@ func TestCallbacksWork(t *testing.T) {
|
|||
recv += n
|
||||
}
|
||||
|
||||
ms := newMeteredStream(fake, protocol.ID("TEST"), peer.ID("PEER"), recvCB, sentCB)
|
||||
ms := newMeteredStream(fake, peer.ID("PEER"), recvCB, sentCB)
|
||||
|
||||
toWrite := int64(100000)
|
||||
toRead := int64(100000)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
ma "github.com/jbenet/go-multiaddr"
|
||||
"github.com/jbenet/goprocess"
|
||||
conn "github.com/libp2p/go-libp2p/p2p/net/conn"
|
||||
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
|
||||
context "golang.org/x/net/context"
|
||||
)
|
||||
|
||||
|
@ -26,8 +27,8 @@ type Stream interface {
|
|||
io.Writer
|
||||
io.Closer
|
||||
|
||||
Protocol() string
|
||||
SetProtocol(string)
|
||||
Protocol() protocol.ID
|
||||
SetProtocol(protocol.ID)
|
||||
|
||||
// Conn returns the connection this stream is part of.
|
||||
Conn() Conn
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
process "github.com/jbenet/goprocess"
|
||||
inet "github.com/libp2p/go-libp2p/p2p/net"
|
||||
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
|
||||
)
|
||||
|
||||
// stream implements inet.Stream
|
||||
|
@ -17,7 +18,7 @@ type stream struct {
|
|||
toDeliver chan *transportObject
|
||||
proc process.Process
|
||||
|
||||
protocol string
|
||||
protocol protocol.ID
|
||||
}
|
||||
|
||||
type transportObject struct {
|
||||
|
@ -50,11 +51,11 @@ func (s *stream) Write(p []byte) (n int, err error) {
|
|||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *stream) Protocol() string {
|
||||
func (s *stream) Protocol() protocol.ID {
|
||||
return s.protocol
|
||||
}
|
||||
|
||||
func (s *stream) SetProtocol(proto string) {
|
||||
func (s *stream) SetProtocol(proto protocol.ID) {
|
||||
s.protocol = proto
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package swarm
|
|||
|
||||
import (
|
||||
inet "github.com/libp2p/go-libp2p/p2p/net"
|
||||
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
|
||||
|
||||
ps "github.com/jbenet/go-peerstream"
|
||||
)
|
||||
|
@ -10,7 +11,7 @@ import (
|
|||
// our Conn and Swarm (instead of just the ps.Conn and ps.Swarm)
|
||||
type Stream struct {
|
||||
stream *ps.Stream
|
||||
protocol string
|
||||
protocol protocol.ID
|
||||
}
|
||||
|
||||
// Stream returns the underlying peerstream.Stream
|
||||
|
@ -44,11 +45,11 @@ func (s *Stream) Close() error {
|
|||
return s.stream.Close()
|
||||
}
|
||||
|
||||
func (s *Stream) Protocol() string {
|
||||
func (s *Stream) Protocol() protocol.ID {
|
||||
return s.protocol
|
||||
}
|
||||
|
||||
func (s *Stream) SetProtocol(p string) {
|
||||
func (s *Stream) SetProtocol(p protocol.ID) {
|
||||
s.protocol = p
|
||||
}
|
||||
|
||||
|
|
|
@ -86,8 +86,10 @@ func (ids *IDService) IdentifyConn(c inet.Conn) {
|
|||
return
|
||||
}
|
||||
|
||||
s.SetProtocol(ID)
|
||||
|
||||
bwc := ids.Host.GetBandwidthReporter()
|
||||
s = mstream.WrapStream(s, ID, bwc)
|
||||
s = mstream.WrapStream(s, bwc)
|
||||
|
||||
// ok give the response to our handler.
|
||||
if err := msmux.SelectProtoOrFail(ID, s); err != nil {
|
||||
|
@ -115,7 +117,7 @@ func (ids *IDService) RequestHandler(s inet.Stream) {
|
|||
c := s.Conn()
|
||||
|
||||
bwc := ids.Host.GetBandwidthReporter()
|
||||
s = mstream.WrapStream(s, ID, bwc)
|
||||
s = mstream.WrapStream(s, bwc)
|
||||
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
mes := pb.Identify{}
|
||||
|
|
Loading…
Reference in New Issue