consistently use protocol.ID instead of strings (#2004)

* Change PeerStore interface to use protocol.ID

This reduces the string to protocol.ID translations happening
at various places in the code

* Fix misc cases of protocol.ID conversion

* Merge multistream changes

* Use protocol.ID in network.ConnectionState

* don't update examples

* fix error message tests

* merge new go-multistream changes

* update test-plans go mod

* change transport back to string
This commit is contained in:
Sukun 2023-01-27 15:09:59 +05:30 committed by GitHub
parent 3919359872
commit 6b9c11680e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 204 additions and 198 deletions

View File

@ -52,7 +52,7 @@ type Host interface {
// SetStreamHandlerMatch sets the protocol handler on the Host's Mux
// using a matching function for protocol selection.
SetStreamHandlerMatch(protocol.ID, func(string) bool, network.StreamHandler)
SetStreamHandlerMatch(protocol.ID, func(protocol.ID) bool, network.StreamHandler)
// RemoveStreamHandler removes a handler on the mux that was set by
// SetStreamHandler

View File

@ -6,6 +6,7 @@ import (
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
ma "github.com/multiformats/go-multiaddr"
)
@ -37,9 +38,9 @@ type Conn interface {
// ConnectionState holds information about the connection.
type ConnectionState struct {
// The stream multiplexer used on this connection (if any). For example: /yamux/1.0.0
StreamMultiplexer string
StreamMultiplexer protocol.ID
// The security protocol used on this connection (if any). For example: /tls/1.0.0
Security string
Security protocol.ID
// the transport used on this connection. For example: tcp
Transport string
}

View File

@ -11,6 +11,7 @@ import (
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/record"
ma "github.com/multiformats/go-multiaddr"
@ -230,19 +231,19 @@ type Metrics interface {
// ProtoBook tracks the protocols supported by peers.
type ProtoBook interface {
GetProtocols(peer.ID) ([]string, error)
AddProtocols(peer.ID, ...string) error
SetProtocols(peer.ID, ...string) error
RemoveProtocols(peer.ID, ...string) error
GetProtocols(peer.ID) ([]protocol.ID, error)
AddProtocols(peer.ID, ...protocol.ID) error
SetProtocols(peer.ID, ...protocol.ID) error
RemoveProtocols(peer.ID, ...protocol.ID) error
// SupportsProtocols returns the set of protocols the peer supports from among the given protocols.
// If the returned error is not nil, the result is indeterminate.
SupportsProtocols(peer.ID, ...string) ([]string, error)
SupportsProtocols(peer.ID, ...protocol.ID) ([]protocol.ID, error)
// FirstSupportedProtocol returns the first protocol that the peer supports among the given protocols.
// If the peer does not support any of the given protocols, this function will return an empty string and a nil error.
// If the peer does not support any of the given protocols, this function will return an empty protocol.ID and a nil error.
// If the returned error is not nil, the result is indeterminate.
FirstSupportedProtocol(peer.ID, ...string) (string, error)
FirstSupportedProtocol(peer.ID, ...protocol.ID) (protocol.ID, error)
// RemovePeer removes all protocols associated with a peer.
RemovePeer(peer.ID)

View File

@ -3,6 +3,8 @@ package protocol
import (
"io"
"github.com/multiformats/go-multistream"
)
// HandlerFunc is a user-provided function used by the Router to
@ -11,7 +13,7 @@ import (
// Will be invoked with the protocol ID string as the first argument,
// which may differ from the ID used for registration if the handler
// was registered using a match function.
type HandlerFunc = func(protocol string, rwc io.ReadWriteCloser) error
type HandlerFunc = multistream.HandlerFunc[ID]
// Router is an interface that allows users to add and remove protocol handlers,
// which will be invoked when incoming stream requests for registered protocols
@ -25,7 +27,7 @@ type Router interface {
// AddHandler registers the given handler to be invoked for
// an exact literal match of the given protocol ID string.
AddHandler(protocol string, handler HandlerFunc)
AddHandler(protocol ID, handler HandlerFunc)
// AddHandlerWithFunc registers the given handler to be invoked
// when the provided match function returns true.
@ -35,17 +37,17 @@ type Router interface {
// the protocol. Note that the protocol ID argument is not
// used for matching; if you want to match the protocol ID
// string exactly, you must check for it in your match function.
AddHandlerWithFunc(protocol string, match func(string) bool, handler HandlerFunc)
AddHandlerWithFunc(protocol ID, match func(ID) bool, handler HandlerFunc)
// RemoveHandler removes the registered handler (if any) for the
// given protocol ID string.
RemoveHandler(protocol string)
RemoveHandler(protocol ID)
// Protocols returns a list of all registered protocol ID strings.
// Note that the Router may be able to handle protocol IDs not
// included in this list if handlers were added with match functions
// using AddHandlerWithFunc.
Protocols() []string
Protocols() []ID
}
// Negotiator is a component capable of reaching agreement over what protocols
@ -55,7 +57,7 @@ type Negotiator interface {
// inbound stream, returning after the protocol has been determined and the
// Negotiator has finished using the stream for negotiation. Returns an
// error if negotiation fails.
Negotiate(rwc io.ReadWriteCloser) (string, HandlerFunc, error)
Negotiate(rwc io.ReadWriteCloser) (ID, HandlerFunc, error)
// Handle calls Negotiate to determine which protocol handler to use for an
// inbound stream, then invokes the protocol handler function, passing it

2
go.mod
View File

@ -42,7 +42,7 @@ require (
github.com/multiformats/go-multibase v0.1.1
github.com/multiformats/go-multicodec v0.7.0
github.com/multiformats/go-multihash v0.2.1
github.com/multiformats/go-multistream v0.3.3
github.com/multiformats/go-multistream v0.4.0
github.com/multiformats/go-varint v0.0.7
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58
github.com/prometheus/client_golang v1.14.0

4
go.sum
View File

@ -384,8 +384,8 @@ github.com/multiformats/go-multicodec v0.7.0/go.mod h1:GUC8upxSBE4oG+q3kWZRw/+6y
github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew=
github.com/multiformats/go-multihash v0.2.1 h1:aem8ZT0VA2nCHHk7bPJ1BjUbHNciqZC/d16Vve9l108=
github.com/multiformats/go-multihash v0.2.1/go.mod h1:WxoMcYG85AZVQUyRyo9s4wULvW5qrI9vb2Lt6evduFc=
github.com/multiformats/go-multistream v0.3.3 h1:d5PZpjwRgVlbwfdTDjife7XszfZd8KYWfROYFlGcR8o=
github.com/multiformats/go-multistream v0.3.3/go.mod h1:ODRoqamLUsETKS9BNcII4gcRsJBU5VAwRIv7O39cEXg=
github.com/multiformats/go-multistream v0.4.0 h1:5i4JbawClkbuaX+mIVXiHQYVPxUW+zjv6w7jtSRukxc=
github.com/multiformats/go-multistream v0.4.0/go.mod h1:BS6ZSYcA4NwYEaIMeCtpJydp2Dc+fNRA6uJMSu/m8+4=
github.com/multiformats/go-varint v0.0.1/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE=
github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8=
github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=

View File

@ -23,8 +23,8 @@ import (
)
const (
protoIDv1 = string(relayv1.ProtoID)
protoIDv2 = string(circuitv2_proto.ProtoIDv2Hop)
protoIDv1 = relayv1.ProtoID
protoIDv2 = circuitv2_proto.ProtoIDv2Hop
)
// Terminology:

View File

@ -70,7 +70,7 @@ type BasicHost struct {
network network.Network
psManager *pstoremanager.PeerstoreManager
mux *msmux.MultistreamMuxer
mux *msmux.MultistreamMuxer[protocol.ID]
ids identify.IDService
hps *holepunch.Service
pings *ping.PingService
@ -108,7 +108,7 @@ var _ host.Host = (*BasicHost)(nil)
// customize construction of the *BasicHost.
type HostOpts struct {
// MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted.
MultistreamMuxer *msmux.MultistreamMuxer
MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID]
// NegotiationTimeout determines the read and write timeouts on streams.
// If 0 or omitted, it will use DefaultNegotiationTimeout.
@ -168,7 +168,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
h := &BasicHost{
network: n,
psManager: psManager,
mux: msmux.NewMultistreamMuxer(),
mux: msmux.NewMultistreamMuxer[protocol.ID](),
negtimeout: DefaultNegotiationTimeout,
AddrsFactory: DefaultAddrsFactory,
maResolver: madns.DefaultResolver,
@ -407,7 +407,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) {
}
}
if err := s.SetProtocol(protocol.ID(protoID)); err != nil {
if err := s.SetProtocol(protoID); err != nil {
log.Debugf("error setting stream protocol: %s", err)
s.Reset()
return
@ -571,9 +571,9 @@ func (h *BasicHost) EventBus() event.Bus {
//
// (Threadsafe)
func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) {
h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
h.Mux().AddHandler(pid, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(protocol.ID(p))
is.SetProtocol(p)
handler(is)
return nil
})
@ -584,10 +584,10 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler network.StreamHand
// SetStreamHandlerMatch sets the protocol handler on the Host's Mux
// using a matching function to do protocol comparisons
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler network.StreamHandler) {
h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) {
h.Mux().AddHandlerWithFunc(pid, m, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(protocol.ID(p))
is.SetProtocol(p)
handler(is)
return nil
})
@ -598,7 +598,7 @@ func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool,
// RemoveStreamHandler returns ..
func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
h.Mux().RemoveHandler(string(pid))
h.Mux().RemoveHandler(pid)
h.emitters.evtLocalProtocolsUpdated.Emit(event.EvtLocalProtocolsUpdated{
Removed: []protocol.ID{pid},
})
@ -637,9 +637,7 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return nil, ctx.Err()
}
pidStrings := protocol.ConvertToStrings(pids)
pref, err := h.preferredProtocol(p, pidStrings)
pref, err := h.preferredProtocol(p, pids)
if err != nil {
_ = s.Reset()
return nil, err
@ -647,7 +645,7 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
if pref != "" {
s.SetProtocol(pref)
lzcon := msmux.NewMSSelect(s, string(pref))
lzcon := msmux.NewMSSelect(s, pref)
return &streamWrapper{
Stream: s,
rw: lzcon,
@ -655,10 +653,10 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
}
// Negotiate the protocol in the background, obeying the context.
var selected string
var selected protocol.ID
errCh := make(chan error, 1)
go func() {
selected, err = msmux.SelectOneOf(pidStrings, s)
selected, err = msmux.SelectOneOf(pids, s)
errCh <- err
}()
select {
@ -674,13 +672,12 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return nil, ctx.Err()
}
selpid := protocol.ID(selected)
s.SetProtocol(selpid)
s.SetProtocol(selected)
h.Peerstore().AddProtocols(p, selected)
return s, nil
}
func (h *BasicHost) preferredProtocol(p peer.ID, pids []string) (protocol.ID, error) {
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) (protocol.ID, error) {
supported, err := h.Peerstore().SupportsProtocols(p, pids...)
if err != nil {
return "", err
@ -688,7 +685,7 @@ func (h *BasicHost) preferredProtocol(p peer.ID, pids []string) (protocol.ID, er
var out protocol.ID
if len(supported) > 0 {
out = protocol.ID(supported[0])
out = supported[0]
}
return out, nil
}

View File

@ -158,8 +158,8 @@ func TestProtocolHandlerEvents(t *testing.T) {
h.SetStreamHandler(protocol.TestingID, func(s network.Stream) {})
assert([]protocol.ID{protocol.TestingID}, nil)
h.SetStreamHandler(protocol.ID("foo"), func(s network.Stream) {})
assert([]protocol.ID{protocol.ID("foo")}, nil)
h.SetStreamHandler("foo", func(s network.Stream) {})
assert([]protocol.ID{"foo"}, nil)
h.RemoveStreamHandler(protocol.TestingID)
assert(nil, []protocol.ID{protocol.TestingID})
}
@ -273,9 +273,9 @@ func TestHostProtoPreference(t *testing.T) {
defer h2.Close()
const (
protoOld = protocol.ID("/testing")
protoNew = protocol.ID("/testing/1.1.0")
protoMinor = protocol.ID("/testing/1.2.0")
protoOld = "/testing"
protoNew = "/testing/1.1.0"
protoMinor = "/testing/1.2.0"
)
connectedOn := make(chan protocol.ID)
@ -299,7 +299,7 @@ func TestHostProtoPreference(t *testing.T) {
assertWait(t, connectedOn, protoOld)
s.Close()
h2.SetStreamHandlerMatch(protoMinor, func(string) bool { return true }, handler)
h2.SetStreamHandlerMatch(protoMinor, func(protocol.ID) bool { return true }, handler)
// remembered preference will be chosen first, even when the other side newly supports it
s2, err := h1.NewStream(context.Background(), h2.ID(), protoMinor, protoNew, protoOld)
require.NoError(t, err)

View File

@ -27,7 +27,7 @@ var log = logging.Logger("blankhost")
// BlankHost is the thinnest implementation of the host.Host interface
type BlankHost struct {
n network.Network
mux *mstream.MultistreamMuxer
mux *mstream.MultistreamMuxer[protocol.ID]
cmgr connmgr.ConnManager
eventbus event.Bus
emitters struct {
@ -65,7 +65,7 @@ func NewBlankHost(n network.Network, options ...Option) *BlankHost {
bh := &BlankHost{
n: n,
cmgr: cfg.cmgr,
mux: mstream.NewMultistreamMuxer(),
mux: mstream.NewMultistreamMuxer[protocol.ID](),
}
if bh.eventbus == nil {
bh.eventbus = eventbus.NewBus()
@ -158,35 +158,29 @@ func (bh *BlankHost) NewStream(ctx context.Context, p peer.ID, protos ...protoco
return nil, err
}
protoStrs := make([]string, len(protos))
for i, pid := range protos {
protoStrs[i] = string(pid)
}
selected, err := mstream.SelectOneOf(protoStrs, s)
selected, err := mstream.SelectOneOf(protos, s)
if err != nil {
s.Reset()
return nil, err
}
selpid := protocol.ID(selected)
s.SetProtocol(selpid)
s.SetProtocol(selected)
bh.Peerstore().AddProtocols(p, selected)
return s, nil
}
func (bh *BlankHost) RemoveStreamHandler(pid protocol.ID) {
bh.Mux().RemoveHandler(string(pid))
bh.Mux().RemoveHandler(pid)
bh.emitters.evtLocalProtocolsUpdated.Emit(event.EvtLocalProtocolsUpdated{
Removed: []protocol.ID{pid},
})
}
func (bh *BlankHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) {
bh.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
bh.Mux().AddHandler(pid, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(protocol.ID(p))
is.SetProtocol(p)
handler(is)
return nil
})
@ -195,10 +189,10 @@ func (bh *BlankHost) SetStreamHandler(pid protocol.ID, handler network.StreamHan
})
}
func (bh *BlankHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler network.StreamHandler) {
bh.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
func (bh *BlankHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) {
bh.Mux().AddHandlerWithFunc(pid, m, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(protocol.ID(p))
is.SetProtocol(p)
handler(is)
return nil
})
@ -216,7 +210,7 @@ func (bh *BlankHost) newStreamHandler(s network.Stream) {
return
}
s.SetProtocol(protocol.ID(protoID))
s.SetProtocol(protoID)
go handle(protoID, s)
}

View File

@ -8,6 +8,7 @@ import (
pool "github.com/libp2p/go-buffer-pool"
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
ds "github.com/ipfs/go-datastore"
"github.com/ipfs/go-datastore/query"
@ -28,7 +29,7 @@ func init() {
// Gob registers basic types by default.
//
// Register complex types used by the peerstore itself.
gob.Register(make(map[string]struct{}))
gob.Register(make(map[protocol.ID]struct{}))
}
// NewPeerMetadata creates a metadata store backed by a persistent db. It uses gob for serialisation.

View File

@ -7,6 +7,7 @@ import (
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
)
type protoSegment struct {
@ -58,12 +59,12 @@ func NewProtoBook(meta pstore.PeerMetadata, opts ...ProtoBookOption) (*dsProtoBo
return pb, nil
}
func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error {
func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...protocol.ID) error {
if len(protos) > pb.maxProtos {
return errTooManyProtocols
}
protomap := make(map[string]struct{}, len(protos))
protomap := make(map[protocol.ID]struct{}, len(protos))
for _, proto := range protos {
protomap[proto] = struct{}{}
}
@ -75,7 +76,7 @@ func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error {
return pb.meta.Put(p, "protocols", protomap)
}
func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error {
func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...protocol.ID) error {
s := pb.segments.get(p)
s.Lock()
defer s.Unlock()
@ -95,7 +96,7 @@ func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error {
return pb.meta.Put(p, "protocols", pmap)
}
func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) {
func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
@ -105,7 +106,7 @@ func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) {
return nil, err
}
res := make([]string, 0, len(pmap))
res := make([]protocol.ID, 0, len(pmap))
for proto := range pmap {
res = append(res, proto)
}
@ -113,7 +114,7 @@ func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) {
return res, nil
}
func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) {
func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...protocol.ID) ([]protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
@ -123,7 +124,7 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string,
return nil, err
}
res := make([]string, 0, len(protos))
res := make([]protocol.ID, 0, len(protos))
for _, proto := range protos {
if _, ok := pmap[proto]; ok {
res = append(res, proto)
@ -133,7 +134,7 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string,
return res, nil
}
func (pb *dsProtoBook) FirstSupportedProtocol(p peer.ID, protos ...string) (string, error) {
func (pb *dsProtoBook) FirstSupportedProtocol(p peer.ID, protos ...protocol.ID) (protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
@ -151,7 +152,7 @@ func (pb *dsProtoBook) FirstSupportedProtocol(p peer.ID, protos ...string) (stri
return "", nil
}
func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...protocol.ID) error {
s := pb.segments.get(p)
s.Lock()
defer s.Unlock()
@ -173,15 +174,15 @@ func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
return pb.meta.Put(p, "protocols", pmap)
}
func (pb *dsProtoBook) getProtocolMap(p peer.ID) (map[string]struct{}, error) {
func (pb *dsProtoBook) getProtocolMap(p peer.ID) (map[protocol.ID]struct{}, error) {
iprotomap, err := pb.meta.Get(p, "protocols")
switch err {
default:
return nil, err
case pstore.ErrNotFound:
return make(map[string]struct{}), nil
return make(map[protocol.ID]struct{}), nil
case nil:
cast, ok := iprotomap.(map[string]struct{})
cast, ok := iprotomap.(map[protocol.ID]struct{})
if !ok {
return nil, fmt.Errorf("stored protocol set was not a map")
}

View File

@ -6,11 +6,12 @@ import (
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
)
type protoSegment struct {
sync.RWMutex
protocols map[peer.ID]map[string]struct{}
protocols map[peer.ID]map[protocol.ID]struct{}
}
type protoSegments [256]*protoSegment
@ -27,7 +28,7 @@ type memoryProtoBook struct {
maxProtos int
lk sync.RWMutex
interned map[string]string
interned map[protocol.ID]protocol.ID
}
var _ pstore.ProtoBook = (*memoryProtoBook)(nil)
@ -43,11 +44,11 @@ func WithMaxProtocols(num int) ProtoBookOption {
func NewProtoBook(opts ...ProtoBookOption) (*memoryProtoBook, error) {
pb := &memoryProtoBook{
interned: make(map[string]string, 256),
interned: make(map[protocol.ID]protocol.ID, 256),
segments: func() (ret protoSegments) {
for i := range ret {
ret[i] = &protoSegment{
protocols: make(map[peer.ID]map[string]struct{}),
protocols: make(map[peer.ID]map[protocol.ID]struct{}),
}
}
return ret
@ -63,7 +64,7 @@ func NewProtoBook(opts ...ProtoBookOption) (*memoryProtoBook, error) {
return pb, nil
}
func (pb *memoryProtoBook) internProtocol(proto string) string {
func (pb *memoryProtoBook) internProtocol(proto protocol.ID) protocol.ID {
// check if it is interned with the read lock
pb.lk.RLock()
interned, ok := pb.interned[proto]
@ -87,12 +88,12 @@ func (pb *memoryProtoBook) internProtocol(proto string) string {
return proto
}
func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error {
func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...protocol.ID) error {
if len(protos) > pb.maxProtos {
return errTooManyProtocols
}
newprotos := make(map[string]struct{}, len(protos))
newprotos := make(map[protocol.ID]struct{}, len(protos))
for _, proto := range protos {
newprotos[pb.internProtocol(proto)] = struct{}{}
}
@ -105,14 +106,14 @@ func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error {
return nil
}
func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error {
func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...protocol.ID) error {
s := pb.segments.get(p)
s.Lock()
defer s.Unlock()
protomap, ok := s.protocols[p]
if !ok {
protomap = make(map[string]struct{})
protomap = make(map[protocol.ID]struct{})
s.protocols[p] = protomap
}
if len(protomap)+len(protos) > pb.maxProtos {
@ -125,12 +126,12 @@ func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error {
return nil
}
func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) {
func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
out := make([]string, 0, len(s.protocols[p]))
out := make([]protocol.ID, 0, len(s.protocols[p]))
for k := range s.protocols[p] {
out = append(out, k)
}
@ -138,7 +139,7 @@ func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) {
return out, nil
}
func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...protocol.ID) error {
s := pb.segments.get(p)
s.Lock()
defer s.Unlock()
@ -155,12 +156,12 @@ func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
return nil
}
func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) {
func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...protocol.ID) ([]protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
out := make([]string, 0, len(protos))
out := make([]protocol.ID, 0, len(protos))
for _, proto := range protos {
if _, ok := s.protocols[p][proto]; ok {
out = append(out, proto)
@ -170,7 +171,7 @@ func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]str
return out, nil
}
func (pb *memoryProtoBook) FirstSupportedProtocol(p peer.ID, protos ...string) (string, error) {
func (pb *memoryProtoBook) FirstSupportedProtocol(p peer.ID, protos ...protocol.ID) (protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()

View File

@ -12,6 +12,7 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
@ -44,6 +45,10 @@ func TestPeerstore(t *testing.T, factory PeerstoreFactory) {
}
}
func sortProtos(protos []protocol.ID) {
sort.Slice(protos, func(i, j int) bool { return protos[i] < protos[j] })
}
func testAddrStream(ps pstore.Peerstore) func(t *testing.T) {
return func(t *testing.T) {
addrs, pid := getAddrs(t, 100), peer.ID("testpeer")
@ -209,14 +214,14 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) {
return func(t *testing.T) {
t.Run("adding and removing protocols", func(t *testing.T) {
p1 := peer.ID("TESTPEER")
protos := []string{"a", "b", "c", "d"}
protos := []protocol.ID{"a", "b", "c", "d"}
require.NoError(t, ps.AddProtocols(p1, protos...))
out, err := ps.GetProtocols(p1)
require.NoError(t, err)
require.Len(t, out, len(protos), "got wrong number of protocols back")
sort.Strings(out)
sortProtos(out)
for i, p := range protos {
if out[i] != p {
t.Fatal("got wrong protocol")
@ -233,7 +238,7 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) {
b, err := ps.FirstSupportedProtocol(p1, "q", "w", "a", "y", "b")
require.NoError(t, err)
require.Equal(t, "a", b)
require.Equal(t, protocol.ID("a"), b)
b, err = ps.FirstSupportedProtocol(p1, "q", "x", "z")
require.NoError(t, err)
@ -241,9 +246,9 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) {
b, err = ps.FirstSupportedProtocol(p1, "a")
require.NoError(t, err)
require.Equal(t, "a", b)
require.Equal(t, protocol.ID("a"), b)
protos = []string{"other", "yet another", "one more"}
protos = []protocol.ID{"other", "yet another", "one more"}
require.NoError(t, ps.SetProtocols(p1, protos...))
supported, err = ps.SupportsProtocols(p1, "q", "w", "a", "y", "b")
@ -253,8 +258,8 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) {
supported, err = ps.GetProtocols(p1)
require.NoError(t, err)
sort.Strings(supported)
sort.Strings(protos)
sortProtos(supported)
sortProtos(protos)
if !reflect.DeepEqual(supported, protos) {
t.Fatalf("expected previously set protos; expected: %v, have: %v", protos, supported)
}
@ -270,7 +275,7 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) {
t.Run("removing peer", func(t *testing.T) {
p := peer.ID("foobar")
protos := []string{"a", "b"}
protos := []protocol.ID{"a", "b"}
require.NoError(t, ps.SetProtocols(p, protos...))
out, err := ps.GetProtocols(p)
@ -383,9 +388,9 @@ func getAddrs(t *testing.T, n int) []ma.Multiaddr {
func TestPeerstoreProtoStoreLimits(t *testing.T, ps pstore.Peerstore, limit int) {
p := peer.ID("foobar")
protocols := make([]string, limit)
protocols := make([]protocol.ID, limit)
for i := 0; i < limit; i++ {
protocols[i] = fmt.Sprintf("protocol %d", i)
protocols[i] = protocol.ID(fmt.Sprintf("protocol %d", i))
}
t.Run("setting protocols", func(t *testing.T) {

View File

@ -12,6 +12,7 @@ import (
gomock "github.com/golang/mock/gomock"
crypto "github.com/libp2p/go-libp2p/core/crypto"
peer "github.com/libp2p/go-libp2p/core/peer"
protocol "github.com/libp2p/go-libp2p/core/protocol"
multiaddr "github.com/multiformats/go-multiaddr"
)
@ -77,7 +78,7 @@ func (mr *MockPeerstoreMockRecorder) AddPrivKey(arg0, arg1 interface{}) *gomock.
}
// AddProtocols mocks base method.
func (m *MockPeerstore) AddProtocols(arg0 peer.ID, arg1 ...string) error {
func (m *MockPeerstore) AddProtocols(arg0 peer.ID, arg1 ...protocol.ID) error {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
@ -164,14 +165,14 @@ func (mr *MockPeerstoreMockRecorder) Close() *gomock.Call {
}
// FirstSupportedProtocol mocks base method.
func (m *MockPeerstore) FirstSupportedProtocol(arg0 peer.ID, arg1 ...string) (string, error) {
func (m *MockPeerstore) FirstSupportedProtocol(arg0 peer.ID, arg1 ...protocol.ID) (protocol.ID, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "FirstSupportedProtocol", varargs...)
ret0, _ := ret[0].(string)
ret0, _ := ret[0].(protocol.ID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@ -199,10 +200,10 @@ func (mr *MockPeerstoreMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call {
}
// GetProtocols mocks base method.
func (m *MockPeerstore) GetProtocols(arg0 peer.ID) ([]string, error) {
func (m *MockPeerstore) GetProtocols(arg0 peer.ID) ([]protocol.ID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProtocols", arg0)
ret0, _ := ret[0].([]string)
ret0, _ := ret[0].([]protocol.ID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@ -350,7 +351,7 @@ func (mr *MockPeerstoreMockRecorder) RemovePeer(arg0 interface{}) *gomock.Call {
}
// RemoveProtocols mocks base method.
func (m *MockPeerstore) RemoveProtocols(arg0 peer.ID, arg1 ...string) error {
func (m *MockPeerstore) RemoveProtocols(arg0 peer.ID, arg1 ...protocol.ID) error {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
@ -393,7 +394,7 @@ func (mr *MockPeerstoreMockRecorder) SetAddrs(arg0, arg1, arg2 interface{}) *gom
}
// SetProtocols mocks base method.
func (m *MockPeerstore) SetProtocols(arg0 peer.ID, arg1 ...string) error {
func (m *MockPeerstore) SetProtocols(arg0 peer.ID, arg1 ...protocol.ID) error {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
@ -412,14 +413,14 @@ func (mr *MockPeerstoreMockRecorder) SetProtocols(arg0 interface{}, arg1 ...inte
}
// SupportsProtocols mocks base method.
func (m *MockPeerstore) SupportsProtocols(arg0 peer.ID, arg1 ...string) ([]string, error) {
func (m *MockPeerstore) SupportsProtocols(arg0 peer.ID, arg1 ...protocol.ID) ([]protocol.ID, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "SupportsProtocols", varargs...)
ret0, _ := ret[0].([]string)
ret0, _ := ret[0].([]protocol.ID)
ret1, _ := ret[1].(error)
return ret0, ret1
}

View File

@ -87,7 +87,7 @@ func (r *resourceManager) ListProtocols() []protocol.ID {
}
sort.Slice(result, func(i, j int) bool {
return strings.Compare(string(result[i]), string(result[j])) < 0
return result[i] < result[j]
})
return result

View File

@ -188,7 +188,7 @@ func (rh *RoutedHost) SetStreamHandler(pid protocol.ID, handler network.StreamHa
rh.host.SetStreamHandler(pid, handler)
}
func (rh *RoutedHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler network.StreamHandler) {
func (rh *RoutedHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) {
rh.host.SetStreamHandlerMatch(pid, m, handler)
}

View File

@ -122,12 +122,12 @@ func appendConnectionState(tags []string, cs network.ConnectionState) []string {
// This shouldn't happen, unless the transport doesn't properly set the Transport field in the ConnectionState.
tags = append(tags, "unknown")
} else {
tags = append(tags, cs.Transport)
tags = append(tags, string(cs.Transport))
}
// These might be empty, depending on the transport.
// For example, QUIC doesn't set security nor muxer.
tags = append(tags, cs.Security)
tags = append(tags, cs.StreamMultiplexer)
tags = append(tags, string(cs.Security))
tags = append(tags, string(cs.StreamMultiplexer))
return tags
}

View File

@ -56,8 +56,8 @@ func (t *transportConn) Close() error {
func (t *transportConn) ConnState() network.ConnectionState {
return network.ConnectionState{
StreamMultiplexer: string(t.muxer),
Security: string(t.security),
StreamMultiplexer: t.muxer,
Security: t.security,
Transport: "tcp",
}
}

View File

@ -405,7 +405,7 @@ func TestNoCommonSecurityProto(t *testing.T) {
}()
_, err = dial(t, ub, ln.Multiaddr(), idA, &network.NullScope{})
require.EqualError(t, err, "failed to negotiate security protocol: protocol not supported")
require.ErrorContains(t, err, "failed to negotiate security protocol: protocols not supported")
select {
case <-done:
t.Fatal("didn't expect to accept a connection")

View File

@ -53,13 +53,13 @@ type upgrader struct {
connGater connmgr.ConnectionGater
rcmgr network.ResourceManager
muxerMuxer *mss.MultistreamMuxer
muxerMuxer *mss.MultistreamMuxer[protocol.ID]
muxers []StreamMuxer
muxerIDs []string
muxerIDs []protocol.ID
security []sec.SecureTransport
securityMuxer *mss.MultistreamMuxer
securityIDs []string
securityMuxer *mss.MultistreamMuxer[protocol.ID]
securityIDs []protocol.ID
// AcceptTimeout is the maximum duration an Accept is allowed to take.
// This includes the time between accepting the raw network connection,
@ -77,10 +77,10 @@ func New(security []sec.SecureTransport, muxers []StreamMuxer, psk ipnet.PSK, rc
rcmgr: rcmgr,
connGater: connGater,
psk: psk,
muxerMuxer: mss.NewMultistreamMuxer(),
muxerMuxer: mss.NewMultistreamMuxer[protocol.ID](),
muxers: muxers,
security: security,
securityMuxer: mss.NewMultistreamMuxer(),
securityMuxer: mss.NewMultistreamMuxer[protocol.ID](),
}
for _, opt := range opts {
if err := opt(u); err != nil {
@ -90,15 +90,15 @@ func New(security []sec.SecureTransport, muxers []StreamMuxer, psk ipnet.PSK, rc
if u.rcmgr == nil {
u.rcmgr = &network.NullResourceManager{}
}
u.muxerIDs = make([]string, 0, len(muxers))
u.muxerIDs = make([]protocol.ID, 0, len(muxers))
for _, m := range muxers {
u.muxerMuxer.AddHandler(string(m.ID), nil)
u.muxerIDs = append(u.muxerIDs, string(m.ID))
u.muxerMuxer.AddHandler(m.ID, nil)
u.muxerIDs = append(u.muxerIDs, m.ID)
}
u.securityIDs = make([]string, 0, len(security))
u.securityIDs = make([]protocol.ID, 0, len(security))
for _, s := range security {
u.securityMuxer.AddHandler(string(s.ID()), nil)
u.securityIDs = append(u.securityIDs, string(s.ID()))
u.securityMuxer.AddHandler(s.ID(), nil)
u.securityIDs = append(u.securityIDs, s.ID())
}
return u, nil
}
@ -219,7 +219,7 @@ func (u *upgrader) negotiateMuxer(nc net.Conn, isServer bool) (*StreamMuxer, err
return nil, err
}
var proto string
var proto protocol.ID
if isServer {
selected, _, err := u.muxerMuxer.Negotiate(nc)
if err != nil {
@ -244,9 +244,9 @@ func (u *upgrader) negotiateMuxer(nc net.Conn, isServer bool) (*StreamMuxer, err
return nil, fmt.Errorf("selected protocol we don't have a transport for")
}
func (u *upgrader) getMuxerByID(id string) *StreamMuxer {
func (u *upgrader) getMuxerByID(id protocol.ID) *StreamMuxer {
for _, m := range u.muxers {
if string(m.ID) == id {
if m.ID == id {
return &m
}
}
@ -265,7 +265,7 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b
if err != nil {
return "", nil, err
}
return protocol.ID(muxerSelected), c, nil
return muxerSelected, c, nil
}
type result struct {
@ -298,9 +298,9 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b
}
}
func (u *upgrader) getSecurityByID(id string) sec.SecureTransport {
func (u *upgrader) getSecurityByID(id protocol.ID) sec.SecureTransport {
for _, s := range u.security {
if string(s.ID()) == id {
if s.ID() == id {
return s
}
}
@ -309,7 +309,7 @@ func (u *upgrader) getSecurityByID(id string) sec.SecureTransport {
func (u *upgrader) negotiateSecurity(ctx context.Context, insecure net.Conn, server bool) (sec.SecureTransport, bool, error) {
type result struct {
proto string
proto protocol.ID
iamserver bool
err error
}

View File

@ -27,7 +27,7 @@ func TestReservationFailures(t *testing.T) {
{
name: "unsupported protocol",
streamHandler: nil,
err: "protocol not supported",
err: "protocols not supported",
},
{
name: "wrong message type",

View File

@ -224,7 +224,7 @@ func TestFailuresOnInitiator(t *testing.T) {
hps := addHolePunchService(t, h2, opts...)
// wait until the hole punching protocol has actually started
require.Eventually(t, func() bool {
protos, _ := h2.Peerstore().SupportsProtocols(h1.ID(), string(holepunch.Protocol))
protos, _ := h2.Peerstore().SupportsProtocols(h1.ID(), holepunch.Protocol)
return len(protos) > 0
}, 200*time.Millisecond, 10*time.Millisecond)

View File

@ -500,7 +500,7 @@ func (ids *idService) createBaseIdentifyResponse(
localAddr := conn.LocalMultiaddr()
// set protocols this node is currently handling
mes.Protocols = snapshot.protocols
mes.Protocols = protocol.ConvertToStrings(snapshot.protocols)
// observed address so other side is informed of their
// "public" address, at least in relation to us.
@ -560,7 +560,7 @@ func (ids *idService) getSignedRecord(snapshot *identifySnapshot) []byte {
}
// diff takes two slices of strings (a and b) and computes which elements were added and removed in b
func diff(a, b []string) (added, removed []string) {
func diff(a, b []protocol.ID) (added, removed []protocol.ID) {
// This is O(n^2), but it's fine because the slices are small.
for _, x := range b {
var found bool
@ -593,13 +593,14 @@ func (ids *idService) consumeMessage(mes *pb.Identify, c network.Conn, isPush bo
p := c.RemotePeer()
supported, _ := ids.Host.Peerstore().GetProtocols(p)
added, removed := diff(supported, mes.Protocols)
ids.Host.Peerstore().SetProtocols(p, mes.Protocols...)
mesProtocols := protocol.ConvertFromStrings(mes.Protocols)
added, removed := diff(supported, mesProtocols)
ids.Host.Peerstore().SetProtocols(p, mesProtocols...)
if isPush {
ids.emitters.evtPeerProtocolsUpdated.Emit(event.EvtPeerProtocolsUpdated{
Peer: p,
Added: protocol.ConvertFromStrings(added),
Removed: protocol.ConvertFromStrings(removed),
Added: added,
Removed: removed,
})
}

View File

@ -383,7 +383,7 @@ func TestIdentifyPushWhileIdentifyingConn(t *testing.T) {
handler := func(s network.Stream) {
<-block
w := pbio.NewDelimitedWriter(s)
w.WriteMsg(&pb.Identify{Protocols: h1.Mux().Protocols()})
w.WriteMsg(&pb.Identify{Protocols: protocol.ConvertToStrings(h1.Mux().Protocols())})
s.Close()
}
h1.RemoveStreamHandler(identify.ID)
@ -587,14 +587,14 @@ func TestSendPush(t *testing.T) {
// h1 starts listening on a new protocol and h2 finds out about that through a push
h1.SetStreamHandler("rand", func(network.Stream) {})
require.Eventually(t, func() bool {
sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []string{"rand"}...)
sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []protocol.ID{"rand"}...)
return err == nil && len(sup) == 1 && sup[0] == "rand"
}, time.Second, 10*time.Millisecond)
// h1 stops listening on a protocol and h2 finds out about it via a push
h1.RemoveStreamHandler("rand")
require.Eventually(t, func() bool {
sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []string{"rand"}...)
sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []protocol.ID{"rand"}...)
return err == nil && len(sup) == 0
}, time.Second, 10*time.Millisecond)
}
@ -613,9 +613,9 @@ func TestLargeIdentifyMessage(t *testing.T) {
// add protocol strings to make the message larger
// about 2K of protocol strings
for i := 0; i < 500; i++ {
r := fmt.Sprintf("rand%d", i)
h1.SetStreamHandler(protocol.ID(r), func(network.Stream) {})
h2.SetStreamHandler(protocol.ID(r), func(network.Stream) {})
r := protocol.ID(fmt.Sprintf("rand%d", i))
h1.SetStreamHandler(r, func(network.Stream) {})
h2.SetStreamHandler(r, func(network.Stream) {})
}
h1p := h1.ID()
@ -719,9 +719,9 @@ func TestLargePushMessage(t *testing.T) {
// add protocol strings to make the message larger
// about 2K of protocol strings
for i := 0; i < 500; i++ {
r := fmt.Sprintf("rand%d", i)
h1.SetStreamHandler(protocol.ID(r), func(network.Stream) {})
h2.SetStreamHandler(protocol.ID(r), func(network.Stream) {})
r := protocol.ID(fmt.Sprintf("rand%d", i))
h1.SetStreamHandler(r, func(network.Stream) {})
h2.SetStreamHandler(r, func(network.Stream) {})
}
h1p := h1.ID()

View File

@ -16,7 +16,7 @@ import (
var errProtocolNotSupported = errors.New("protocol not supported")
type identifySnapshot struct {
protocols []string
protocols []protocol.ID
addrs []ma.Multiaddr
record *record.Envelope
}
@ -103,7 +103,7 @@ func (ph *peerHandler) sendPush(ctx context.Context) error {
return nil
}
func (ph *peerHandler) openStream(ctx context.Context, proto string) (network.Stream, error) {
func (ph *peerHandler) openStream(ctx context.Context, proto protocol.ID) (network.Stream, error) {
// wait for the other peer to send us an Identify response on "all" connections we have with it
// so we can look at it's supported protocols and avoid a multistream-select roundtrip to negotiate the protocol
// if we know for a fact that it doesn't support the protocol.
@ -127,5 +127,5 @@ func (ph *peerHandler) openStream(ctx context.Context, proto string) (network.St
// negotiate a stream without opening a new connection as we "should" already have a connection.
ctx, cancel := context.WithTimeout(network.WithNoDial(ctx, "should already have connection"), 30*time.Second)
defer cancel()
return ph.ids.Host.NewStream(ctx, ph.pid, protocol.ID(proto))
return ph.ids.Host.NewStream(ctx, ph.pid, proto)
}

View File

@ -12,6 +12,7 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
)
type secureSession struct {
@ -134,7 +135,7 @@ func (s *secureSession) Close() error {
return s.insecureConn.Close()
}
func SessionWithConnState(s *secureSession, muxer string) *secureSession {
func SessionWithConnState(s *secureSession, muxer protocol.ID) *secureSession {
if s != nil {
s.connectionState.StreamMultiplexer = muxer
}

View File

@ -23,7 +23,7 @@ type Transport struct {
protocolID protocol.ID
localID peer.ID
privateKey crypto.PrivKey
muxers []string
muxers []protocol.ID
}
var _ sec.SecureTransport = &Transport{}
@ -36,16 +36,16 @@ func New(id protocol.ID, privkey crypto.PrivKey, muxers []tptu.StreamMuxer) (*Tr
return nil, err
}
smuxers := make([]string, 0, len(muxers))
muxerIDs := make([]protocol.ID, 0, len(muxers))
for _, m := range muxers {
smuxers = append(smuxers, string(m.ID))
muxerIDs = append(muxerIDs, m.ID)
}
return &Transport{
protocolID: id,
localID: localID,
privateKey: privkey,
muxers: smuxers,
muxers: muxerIDs,
}, nil
}
@ -87,7 +87,7 @@ func (t *Transport) ID() protocol.ID {
return t.protocolID
}
func matchMuxers(initiatorMuxers, responderMuxers []string) string {
func matchMuxers(initiatorMuxers, responderMuxers []protocol.ID) protocol.ID {
for _, initMuxer := range initiatorMuxers {
for _, respMuxer := range responderMuxers {
if initMuxer == respMuxer {
@ -100,7 +100,7 @@ func matchMuxers(initiatorMuxers, responderMuxers []string) string {
type transportEarlyDataHandler struct {
transport *Transport
receivedMuxers []string
receivedMuxers []protocol.ID
}
var _ EarlyDataHandler = &transportEarlyDataHandler{}
@ -111,19 +111,19 @@ func newTransportEDH(t *Transport) *transportEarlyDataHandler {
func (i *transportEarlyDataHandler) Send(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions {
return &pb.NoiseExtensions{
StreamMuxers: i.transport.muxers,
StreamMuxers: protocol.ConvertToStrings(i.transport.muxers),
}
}
func (i *transportEarlyDataHandler) Received(_ context.Context, _ net.Conn, extension *pb.NoiseExtensions) error {
// Discard messages with size or the number of protocols exceeding extension limit for security.
if extension != nil && len(extension.StreamMuxers) <= maxProtoNum {
i.receivedMuxers = extension.GetStreamMuxers()
i.receivedMuxers = protocol.ConvertFromStrings(extension.GetStreamMuxers())
}
return nil
}
func (i *transportEarlyDataHandler) MatchMuxers(isInitiator bool) string {
func (i *transportEarlyDataHandler) MatchMuxers(isInitiator bool) protocol.ID {
if isInitiator {
return matchMuxers(i.transport.muxers, i.receivedMuxers)
}

View File

@ -15,6 +15,7 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
@ -37,7 +38,7 @@ func newTestTransport(t *testing.T, typ, bits int) *Transport {
}
}
func newTestTransportWithMuxers(t *testing.T, typ, bits int, muxers []string) *Transport {
func newTestTransportWithMuxers(t *testing.T, typ, bits int, muxers []protocol.ID) *Transport {
transport := newTestTransport(t, typ, bits)
transport.muxers = muxers
return transport
@ -632,9 +633,9 @@ func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) {
}
type noiseEarlyDataTestCase struct {
clientProtos []string
serverProtos []string
expectedResult string
clientProtos []protocol.ID
serverProtos []protocol.ID
expectedResult protocol.ID
}
func TestHandshakeWithTransportEarlyData(t *testing.T) {
@ -645,43 +646,43 @@ func TestHandshakeWithTransportEarlyData(t *testing.T) {
expectedResult: "",
},
{
clientProtos: []string{"muxer1"},
serverProtos: []string{"muxer1"},
clientProtos: []protocol.ID{"muxer1"},
serverProtos: []protocol.ID{"muxer1"},
expectedResult: "muxer1",
},
{
clientProtos: []string{"muxer1"},
serverProtos: []string{},
clientProtos: []protocol.ID{"muxer1"},
serverProtos: []protocol.ID{},
expectedResult: "",
},
{
clientProtos: []string{},
serverProtos: []string{"muxer2"},
clientProtos: []protocol.ID{},
serverProtos: []protocol.ID{"muxer2"},
expectedResult: "",
},
{
clientProtos: []string{"muxer2"},
serverProtos: []string{"muxer1"},
clientProtos: []protocol.ID{"muxer2"},
serverProtos: []protocol.ID{"muxer1"},
expectedResult: "",
},
{
clientProtos: []string{"muxer1", "muxer2"},
serverProtos: []string{"muxer2", "muxer1"},
clientProtos: []protocol.ID{"muxer1", "muxer2"},
serverProtos: []protocol.ID{"muxer2", "muxer1"},
expectedResult: "muxer1",
},
{
clientProtos: []string{"muxer3", "muxer2", "muxer1"},
serverProtos: []string{"muxer2", "muxer1"},
clientProtos: []protocol.ID{"muxer3", "muxer2", "muxer1"},
serverProtos: []protocol.ID{"muxer2", "muxer1"},
expectedResult: "muxer2",
},
{
clientProtos: []string{"muxer1", "muxer2"},
serverProtos: []string{"muxer3"},
clientProtos: []protocol.ID{"muxer1", "muxer2"},
serverProtos: []protocol.ID{"muxer3"},
expectedResult: "",
},
}
noiseHandshake := func(t *testing.T, initProtos, respProtos []string, expectedProto string) {
noiseHandshake := func(t *testing.T, initProtos, respProtos []protocol.ID, expectedProto protocol.ID) {
initTransport := newTestTransportWithMuxers(t, crypto.Ed25519, 2048, initProtos)
respTransport := newTestTransportWithMuxers(t, crypto.Ed25519, 2048, respProtos)

View File

@ -171,7 +171,7 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se
privKey: t.privKey,
remotePeer: remotePeerID,
remotePubKey: remotePubKey,
connectionState: network.ConnectionState{StreamMultiplexer: nextProto},
connectionState: network.ConnectionState{StreamMultiplexer: protocol.ID(nextProto)},
}, nil
}

View File

@ -180,7 +180,7 @@ func TestHandshakeSucceeds(t *testing.T) {
type testcase struct {
clientProtos []protocol.ID
serverProtos []protocol.ID
expectedResult string
expectedResult protocol.ID
}
func TestHandshakeWithNextProtoSucceeds(t *testing.T) {
@ -225,7 +225,7 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) {
clientID, clientKey := createPeer(t)
serverID, serverKey := createPeer(t)
handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport, expectedMuxer string) {
handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport, expectedMuxer protocol.ID) {
clientInsecureConn, serverInsecureConn := connect(t)
serverConnChan := make(chan sec.SecureConn)

View File

@ -69,7 +69,7 @@ func TestMuxerNegotiation(t *testing.T) {
Name: "no preference overlap",
ServerPreference: []libp2p.Option{yamuxOpt},
ClientPreference: []libp2p.Option{mplexOpt},
Error: "failed to negotiate stream multiplexer: protocol not supported",
Error: "failed to negotiate stream multiplexer: protocols not supported",
},
}
@ -119,7 +119,7 @@ func TestMuxerNegotiation(t *testing.T) {
require.NoError(t, err)
conns := client.Network().ConnsToPeer(server.ID())
require.Len(t, conns, 1, "expected exactly one connection")
require.Equal(t, tc.Expected, protocol.ID(conns[0].ConnState().StreamMultiplexer))
require.Equal(t, tc.Expected, conns[0].ConnState().StreamMultiplexer)
})
}
}

View File

@ -8,7 +8,6 @@ import (
"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/security/noise"
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
@ -45,7 +44,7 @@ func TestSecurityNegotiation(t *testing.T) {
Name: "no overlap",
ServerPreference: []libp2p.Option{noiseOpt},
ClientPreference: []libp2p.Option{tlsOpt},
Error: "failed to negotiate security protocol: protocol not supported",
Error: "failed to negotiate security protocol: protocols not supported",
},
}
@ -84,7 +83,7 @@ func TestSecurityNegotiation(t *testing.T) {
require.NoError(t, err)
conns := client.Network().ConnsToPeer(server.ID())
require.Len(t, conns, 1, "expected exactly one connection")
require.Equal(t, tc.Expected, protocol.ID(conns[0].ConnState().Security))
require.Equal(t, tc.Expected, conns[0].ConnState().Security)
})
}
}

View File

@ -67,7 +67,7 @@ require (
github.com/multiformats/go-multibase v0.1.1 // indirect
github.com/multiformats/go-multicodec v0.7.0 // indirect
github.com/multiformats/go-multihash v0.2.1 // indirect
github.com/multiformats/go-multistream v0.3.3 // indirect
github.com/multiformats/go-multistream v0.4.0 // indirect
github.com/multiformats/go-varint v0.0.7 // indirect
github.com/onsi/ginkgo/v2 v2.5.1 // indirect
github.com/opencontainers/runtime-spec v1.0.2 // indirect

View File

@ -335,8 +335,8 @@ github.com/multiformats/go-multicodec v0.7.0/go.mod h1:GUC8upxSBE4oG+q3kWZRw/+6y
github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew=
github.com/multiformats/go-multihash v0.2.1 h1:aem8ZT0VA2nCHHk7bPJ1BjUbHNciqZC/d16Vve9l108=
github.com/multiformats/go-multihash v0.2.1/go.mod h1:WxoMcYG85AZVQUyRyo9s4wULvW5qrI9vb2Lt6evduFc=
github.com/multiformats/go-multistream v0.3.3 h1:d5PZpjwRgVlbwfdTDjife7XszfZd8KYWfROYFlGcR8o=
github.com/multiformats/go-multistream v0.3.3/go.mod h1:ODRoqamLUsETKS9BNcII4gcRsJBU5VAwRIv7O39cEXg=
github.com/multiformats/go-multistream v0.4.0 h1:5i4JbawClkbuaX+mIVXiHQYVPxUW+zjv6w7jtSRukxc=
github.com/multiformats/go-multistream v0.4.0/go.mod h1:BS6ZSYcA4NwYEaIMeCtpJydp2Dc+fNRA6uJMSu/m8+4=
github.com/multiformats/go-varint v0.0.1/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE=
github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8=
github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=