110 lines
3.1 KiB
Go
110 lines
3.1 KiB
Go
package csms
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
|
|
"github.com/libp2p/go-libp2p-core/peer"
|
|
"github.com/libp2p/go-libp2p-core/sec"
|
|
mss "github.com/multiformats/go-multistream"
|
|
)
|
|
|
|
// SSMuxer is a multistream stream security transport multiplexer.
|
|
//
|
|
// SSMuxer is safe to use without initialization. However, it's not safe to move
|
|
// after use.
|
|
type SSMuxer struct {
|
|
mux mss.MultistreamMuxer
|
|
tpts map[string]sec.SecureTransport
|
|
OrderPreference []string
|
|
}
|
|
|
|
var _ sec.SecureMuxer = (*SSMuxer)(nil)
|
|
|
|
// AddTransport adds a stream security transport to this multistream muxer.
|
|
//
|
|
// This method is *not* thread-safe. It should be called only when initializing
|
|
// the SSMuxer.
|
|
func (sm *SSMuxer) AddTransport(path string, transport sec.SecureTransport) {
|
|
if sm.tpts == nil {
|
|
sm.tpts = make(map[string]sec.SecureTransport, 1)
|
|
}
|
|
|
|
sm.mux.AddHandler(path, nil)
|
|
sm.tpts[path] = transport
|
|
sm.OrderPreference = append(sm.OrderPreference, path)
|
|
}
|
|
|
|
// SecureInbound secures an inbound connection using this multistream
|
|
// multiplexed stream security transport.
|
|
func (sm *SSMuxer) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) {
|
|
tpt, _, err := sm.selectProto(ctx, insecure, true)
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
sconn, err := tpt.SecureInbound(ctx, insecure, p)
|
|
return sconn, true, err
|
|
}
|
|
|
|
// SecureOutbound secures an outbound connection using this multistream
|
|
// multiplexed stream security transport.
|
|
func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) {
|
|
tpt, server, err := sm.selectProto(ctx, insecure, false)
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
|
|
var sconn sec.SecureConn
|
|
if server {
|
|
sconn, err = tpt.SecureInbound(ctx, insecure, p)
|
|
if err != nil {
|
|
return nil, false, fmt.Errorf("failed to secure inbound connection: %s", err)
|
|
}
|
|
// ensure the correct peer connected to us
|
|
if sconn.RemotePeer() != p {
|
|
sconn.Close()
|
|
log.Printf("Handshake failed to properly authenticate peer. Authenticated %s, expected %s.", sconn.RemotePeer(), p)
|
|
return nil, false, fmt.Errorf("unexpected peer")
|
|
}
|
|
} else {
|
|
sconn, err = tpt.SecureOutbound(ctx, insecure, p)
|
|
}
|
|
|
|
return sconn, server, err
|
|
}
|
|
|
|
func (sm *SSMuxer) selectProto(ctx context.Context, insecure net.Conn, server bool) (sec.SecureTransport, bool, error) {
|
|
var proto string
|
|
var err error
|
|
var iamserver bool
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
if server {
|
|
iamserver = true
|
|
proto, _, err = sm.mux.Negotiate(insecure)
|
|
} else {
|
|
proto, iamserver, err = mss.SelectWithSimopenOrFail(sm.OrderPreference, insecure)
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
if tpt, ok := sm.tpts[proto]; ok {
|
|
return tpt, iamserver, nil
|
|
}
|
|
return nil, false, fmt.Errorf("selected unknown security transport")
|
|
case <-ctx.Done():
|
|
// We *must* do this. We have outstanding work on the connection
|
|
// and it's no longer safe to use.
|
|
insecure.Close()
|
|
<-done // wait to stop using the connection.
|
|
return nil, false, ctx.Err()
|
|
}
|
|
}
|