334 lines
7.2 KiB
Go
334 lines
7.2 KiB
Go
package multistream
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// ErrNotSupported is the error returned when the muxer does not support
|
|
// the protocol specified for the handshake.
|
|
var ErrNotSupported = errors.New("protocol not supported")
|
|
|
|
// ErrNoProtocols is the error returned when the no protocols have been
|
|
// specified.
|
|
var ErrNoProtocols = errors.New("no protocols specified")
|
|
|
|
const (
|
|
tieBreakerPrefix = "select:"
|
|
initiator = "initiator"
|
|
responder = "responder"
|
|
)
|
|
|
|
// SelectProtoOrFail performs the initial multistream handshake
|
|
// to inform the muxer of the protocol that will be used to communicate
|
|
// on this ReadWriteCloser. It returns an error if, for example,
|
|
// the muxer does not know how to handle this protocol.
|
|
func SelectProtoOrFail(proto string, rwc io.ReadWriteCloser) error {
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
var buf bytes.Buffer
|
|
if err := delitmWriteAll(&buf, []byte(ProtocolID), []byte(proto)); err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
_, err := io.Copy(rwc, &buf)
|
|
errCh <- err
|
|
}()
|
|
// We have to read *both* errors.
|
|
err1 := readMultistreamHeader(rwc)
|
|
err2 := readProto(proto, rwc)
|
|
if werr := <-errCh; werr != nil {
|
|
return werr
|
|
}
|
|
if err1 != nil {
|
|
return err1
|
|
}
|
|
if err2 != nil {
|
|
return err2
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SelectOneOf will perform handshakes with the protocols on the given slice
|
|
// until it finds one which is supported by the muxer.
|
|
func SelectOneOf(protos []string, rwc io.ReadWriteCloser) (string, error) {
|
|
if len(protos) == 0 {
|
|
return "", ErrNoProtocols
|
|
}
|
|
|
|
// Use SelectProtoOrFail to pipeline the /multistream/1.0.0 handshake
|
|
// with an attempt to negotiate the first protocol. If that fails, we
|
|
// can continue negotiating the rest of the protocols normally.
|
|
//
|
|
// This saves us a round trip.
|
|
switch err := SelectProtoOrFail(protos[0], rwc); err {
|
|
case nil:
|
|
return protos[0], nil
|
|
case ErrNotSupported: // try others
|
|
default:
|
|
return "", err
|
|
}
|
|
return selectProtosOrFail(protos[1:], rwc)
|
|
}
|
|
|
|
const simOpenProtocol = "/libp2p/simultaneous-connect"
|
|
|
|
// Performs protocol negotiation with the simultaneous open extension; the returned boolean
|
|
// indicator will be true if we should act as a server.
|
|
func SelectWithSimopenOrFail(protos []string, rwc io.ReadWriteCloser) (string, bool, error) {
|
|
if len(protos) == 0 {
|
|
return "", false, ErrNoProtocols
|
|
}
|
|
|
|
werrCh := make(chan error, 1)
|
|
go func() {
|
|
var buf bytes.Buffer
|
|
if err := delitmWriteAll(&buf, []byte(ProtocolID), []byte(simOpenProtocol), []byte(protos[0])); err != nil {
|
|
werrCh <- err
|
|
return
|
|
}
|
|
|
|
_, err := io.Copy(rwc, &buf)
|
|
werrCh <- err
|
|
}()
|
|
|
|
err := readMultistreamHeader(rwc)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
|
|
tok, err := ReadNextToken(rwc)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
|
|
if err = <-werrCh; err != nil {
|
|
return "", false, err
|
|
}
|
|
|
|
switch tok {
|
|
case simOpenProtocol:
|
|
// simultaneous open
|
|
return simOpen(protos, rwc)
|
|
case "na":
|
|
// client open
|
|
proto, err := clientOpen(protos, rwc)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
return proto, false, nil
|
|
default:
|
|
return "", false, errors.New("unexpected response: " + tok)
|
|
}
|
|
}
|
|
|
|
func clientOpen(protos []string, rwc io.ReadWriteCloser) (string, error) {
|
|
// check to see if we selected the pipelined protocol
|
|
tok, err := ReadNextToken(rwc)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
switch tok {
|
|
case protos[0]:
|
|
return tok, nil
|
|
case "na":
|
|
return selectProtosOrFail(protos[1:], rwc)
|
|
default:
|
|
return "", errors.New("unexpected response: " + tok)
|
|
}
|
|
}
|
|
|
|
func selectProtosOrFail(protos []string, rwc io.ReadWriteCloser) (string, error) {
|
|
for _, p := range protos {
|
|
err := trySelect(p, rwc)
|
|
switch err {
|
|
case nil:
|
|
return p, nil
|
|
case ErrNotSupported:
|
|
default:
|
|
return "", err
|
|
}
|
|
}
|
|
return "", ErrNotSupported
|
|
}
|
|
|
|
func simOpen(protos []string, rwc io.ReadWriteCloser) (string, bool, error) {
|
|
randBytes := make([]byte, 8)
|
|
_, err := rand.Read(randBytes)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
myNonce := binary.LittleEndian.Uint64(randBytes)
|
|
|
|
werrCh := make(chan error, 1)
|
|
go func() {
|
|
myselect := []byte(tieBreakerPrefix + strconv.FormatUint(myNonce, 10))
|
|
err := delimWriteBuffered(rwc, myselect)
|
|
werrCh <- err
|
|
}()
|
|
|
|
// skip exactly one protocol
|
|
// see https://github.com/multiformats/go-multistream/pull/42#discussion_r558757135
|
|
_, err = ReadNextToken(rwc)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
|
|
// read the tie breaker nonce
|
|
tok, err := ReadNextToken(rwc)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
if !strings.HasPrefix(tok, tieBreakerPrefix) {
|
|
return "", false, errors.New("tie breaker nonce not sent with the correct prefix")
|
|
}
|
|
|
|
if err = <-werrCh; err != nil {
|
|
return "", false, err
|
|
}
|
|
|
|
peerNonce, err := strconv.ParseUint(tok[len(tieBreakerPrefix):], 10, 64)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
|
|
var iamserver bool
|
|
|
|
if peerNonce == myNonce {
|
|
return "", false, errors.New("failed client selection; identical nonces")
|
|
}
|
|
iamserver = peerNonce > myNonce
|
|
|
|
var proto string
|
|
if iamserver {
|
|
proto, err = simOpenSelectServer(protos, rwc)
|
|
} else {
|
|
proto, err = simOpenSelectClient(protos, rwc)
|
|
}
|
|
|
|
return proto, iamserver, err
|
|
}
|
|
|
|
func simOpenSelectServer(protos []string, rwc io.ReadWriteCloser) (string, error) {
|
|
werrCh := make(chan error, 1)
|
|
go func() {
|
|
err := delimWriteBuffered(rwc, []byte(responder))
|
|
werrCh <- err
|
|
}()
|
|
|
|
tok, err := ReadNextToken(rwc)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if tok != initiator {
|
|
return "", errors.New("unexpected response: " + tok)
|
|
}
|
|
if err = <-werrCh; err != nil {
|
|
return "", err
|
|
}
|
|
|
|
for {
|
|
tok, err = ReadNextToken(rwc)
|
|
|
|
if err == io.EOF {
|
|
return "", ErrNotSupported
|
|
}
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
for _, p := range protos {
|
|
if tok == p {
|
|
err = delimWriteBuffered(rwc, []byte(p))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return p, nil
|
|
}
|
|
}
|
|
|
|
err = delimWriteBuffered(rwc, []byte("na"))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
func simOpenSelectClient(protos []string, rwc io.ReadWriteCloser) (string, error) {
|
|
werrCh := make(chan error, 1)
|
|
go func() {
|
|
err := delimWriteBuffered(rwc, []byte(initiator))
|
|
werrCh <- err
|
|
}()
|
|
|
|
tok, err := ReadNextToken(rwc)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if tok != responder {
|
|
return "", errors.New("unexpected response: " + tok)
|
|
}
|
|
if err = <-werrCh; err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return selectProtosOrFail(protos, rwc)
|
|
}
|
|
|
|
func handshake(rw io.ReadWriter) error {
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- delimWriteBuffered(rw, []byte(ProtocolID))
|
|
}()
|
|
|
|
if err := readMultistreamHeader(rw); err != nil {
|
|
return err
|
|
}
|
|
return <-errCh
|
|
}
|
|
|
|
func readMultistreamHeader(r io.Reader) error {
|
|
tok, err := ReadNextToken(r)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if tok != ProtocolID {
|
|
return errors.New("received mismatch in protocol id")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func trySelect(proto string, rwc io.ReadWriteCloser) error {
|
|
err := delimWriteBuffered(rwc, []byte(proto))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return readProto(proto, rwc)
|
|
}
|
|
|
|
func readProto(proto string, r io.Reader) error {
|
|
tok, err := ReadNextToken(r)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch tok {
|
|
case proto:
|
|
return nil
|
|
case "na":
|
|
return ErrNotSupported
|
|
default:
|
|
return errors.New("unrecognized response: " + tok)
|
|
}
|
|
}
|