2021-10-19 09:43:41 -04:00

108 lines
2.9 KiB
Go

package csms
import (
"context"
"fmt"
"net"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"
mss "github.com/multiformats/go-multistream"
)
// SSMuxer is a multistream stream security transport multiplexer.
//
// It negotiates, over a raw connection, which of its registered security
// transports to use and then hands the handshake to that transport.
//
// SSMuxer is safe to use without initialization. However, it's not safe to move
// after use.
type SSMuxer struct {
	// mux answers protocol negotiation when this side is the responder
	// (inbound connections). AddTransport registers each transport path on it.
	mux mss.MultistreamMuxer
	// tpts maps a negotiated protocol path to its transport. It is created
	// lazily by AddTransport, which is what makes the zero value usable.
	tpts map[string]sec.SecureTransport
	// OrderPreference lists protocol paths in the order they are offered to
	// the remote side when securing an outbound connection. AddTransport adds
	// registered paths to it.
	OrderPreference []string
}

// Compile-time check that *SSMuxer implements sec.SecureMuxer.
var _ sec.SecureMuxer = (*SSMuxer)(nil)
// AddTransport adds a stream security transport to this multistream muxer.
//
// path is the multistream protocol ID the transport is negotiated under.
// Paths are offered on outbound connections in the order they were first
// added. Adding a transport under a path that is already listed in
// OrderPreference replaces the transport but keeps the path's existing
// position instead of appending a duplicate entry (a duplicate would make
// outbound negotiation propose the same protocol more than once).
//
// This method is *not* thread-safe. It should be called only when initializing
// the SSMuxer.
func (sm *SSMuxer) AddTransport(path string, transport sec.SecureTransport) {
	if sm.tpts == nil {
		// Lazy init keeps the zero-value SSMuxer usable.
		sm.tpts = make(map[string]sec.SecureTransport, 1)
	}
	sm.mux.AddHandler(path, nil)
	sm.tpts[path] = transport

	for _, existing := range sm.OrderPreference {
		if existing == path {
			return // already offered; keep its original preference slot
		}
	}
	sm.OrderPreference = append(sm.OrderPreference, path)
}
// SecureInbound secures an inbound connection using this multistream
// multiplexed stream security transport.
//
// This side acts as the negotiation responder and then runs the chosen
// transport's inbound handshake. The returned bool is true whenever protocol
// negotiation succeeded (even if the handshake itself then fails), and false
// only when negotiation failed.
func (sm *SSMuxer) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, bool, error) {
	chosen, _, negErr := sm.selectProto(ctx, insecure, true)
	if negErr != nil {
		return nil, false, negErr
	}

	secured, handshakeErr := chosen.SecureInbound(ctx, insecure)
	return secured, true, handshakeErr
}
// SecureOutbound secures an outbound connection using this multistream
// multiplexed stream security transport.
//
// p is the peer that was dialed. Protocol selection may report that this side
// should act as the server. In that case the transport's SecureInbound
// handshake is used, and the authenticated remote peer must equal p.
// Otherwise the transport's SecureOutbound handshake is run against p.
//
// The returned bool reports whether this side acted as the server. On any
// error the returned connection is nil.
func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) {
	tpt, server, err := sm.selectProto(ctx, insecure, false)
	if err != nil {
		return nil, false, err
	}

	var sconn sec.SecureConn
	if !server {
		if sconn, err = tpt.SecureOutbound(ctx, insecure, p); err != nil {
			return nil, false, err
		}
		return sconn, false, nil
	}

	if sconn, err = tpt.SecureInbound(ctx, insecure); err != nil {
		// %w keeps the transport's error inspectable via errors.Is/As.
		return nil, false, fmt.Errorf("failed to secure inbound connection: %w", err)
	}
	// The inbound handshake authenticates whoever is on the other end; ensure
	// it is the peer we actually dialed.
	if remote := sconn.RemotePeer(); remote != p {
		// Best effort: we are already returning a more specific error.
		_ = sconn.Close()
		return nil, false, fmt.Errorf("unexpected peer: expected %s, got %s", p, remote)
	}
	return sconn, true, nil
}
// selectProto runs multistream protocol negotiation over insecure and returns
// the registered transport for the agreed protocol. It also returns a flag
// telling whether this side should act as the handshake server.
//
// When server is true, this side only responds (sm.mux.Negotiate), and the flag
// is always true on success. Otherwise this side offers sm.OrderPreference via
// mss.SelectWithSimopenOrFail, which reports the role this side ended up with.
//
// Negotiation runs in its own goroutine so ctx can interrupt it. If ctx is done
// first, insecure is closed and the goroutine is awaited before returning
// ctx.Err(), so nothing touches the connection after this function returns.
func (sm *SSMuxer) selectProto(ctx context.Context, insecure net.Conn, server bool) (sec.SecureTransport, bool, error) {
	type negotiation struct {
		proto     string
		iamserver bool
		err       error
	}

	// Capacity 1 lets the goroutine always deliver its result and exit.
	resultCh := make(chan negotiation, 1)
	go func() {
		var res negotiation
		if server {
			res.iamserver = true
			res.proto, _, res.err = sm.mux.Negotiate(insecure)
		} else {
			res.proto, res.iamserver, res.err = mss.SelectWithSimopenOrFail(sm.OrderPreference, insecure)
		}
		resultCh <- res
	}()

	var res negotiation
	select {
	case res = <-resultCh:
	case <-ctx.Done():
		// We *must* do this. We have outstanding work on the connection
		// and it's no longer safe to use.
		insecure.Close()
		<-resultCh // wait for the goroutine to stop using the connection
		return nil, false, ctx.Err()
	}

	if res.err != nil {
		return nil, false, res.err
	}
	tpt, known := sm.tpts[res.proto]
	if !known {
		return nil, false, fmt.Errorf("selected unknown security transport")
	}
	return tpt, res.iamserver, nil
}