124 lines
2.8 KiB
Go
Raw Normal View History

package multistream
import (
2019-06-09 09:24:20 +02:00
"bytes"
"errors"
"io"
)
// Sentinel errors reported by the protocol negotiation functions in this file.
var (
	// ErrNotSupported is the error returned when the muxer does not support
	// the protocol specified for the handshake.
	ErrNotSupported = errors.New("protocol not supported")

	// ErrNoProtocols is the error returned when no protocols have been
	// specified.
	ErrNoProtocols = errors.New("no protocols specified")
)

// SelectProtoOrFail performs the initial multistream handshake
// to inform the muxer of the protocol that will be used to communicate
// on this ReadWriteCloser. It returns an error if, for example,
// the muxer does not know how to handle this protocol.
func SelectProtoOrFail(proto string, rwc io.ReadWriteCloser) error {
2019-06-09 09:24:20 +02:00
errCh := make(chan error, 1)
go func() {
var buf bytes.Buffer
delimWrite(&buf, []byte(ProtocolID))
delimWrite(&buf, []byte(proto))
_, err := io.Copy(rwc, &buf)
errCh <- err
}()
// We have to read *both* errors.
err1 := readMultistreamHeader(rwc)
err2 := readProto(proto, rwc)
if werr := <-errCh; werr != nil {
return werr
}
2019-06-09 09:24:20 +02:00
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return nil
}
// SelectOneOf will perform handshakes with the protocols on the given slice
// until it finds one which is supported by the muxer.
func SelectOneOf(protos []string, rwc io.ReadWriteCloser) (string, error) {
2019-06-09 09:24:20 +02:00
if len(protos) == 0 {
return "", ErrNoProtocols
}
2019-06-09 09:24:20 +02:00
// Use SelectProtoOrFail to pipeline the /multistream/1.0.0 handshake
// with an attempt to negotiate the first protocol. If that fails, we
// can continue negotiating the rest of the protocols normally.
//
// This saves us a round trip.
switch err := SelectProtoOrFail(protos[0], rwc); err {
case nil:
return protos[0], nil
case ErrNotSupported: // try others
default:
return "", err
}
for _, p := range protos[1:] {
err := trySelect(p, rwc)
switch err {
case nil:
return p, nil
case ErrNotSupported:
default:
return "", err
}
}
return "", ErrNotSupported
}
func handshake(rwc io.ReadWriteCloser) error {
2019-06-09 09:24:20 +02:00
errCh := make(chan error, 1)
go func() {
errCh <- delimWriteBuffered(rwc, []byte(ProtocolID))
}()
2019-06-09 09:24:20 +02:00
if err := readMultistreamHeader(rwc); err != nil {
return err
}
2019-06-09 09:24:20 +02:00
return <-errCh
}
2019-06-09 09:24:20 +02:00
func readMultistreamHeader(r io.ReadWriter) error {
tok, err := ReadNextToken(r)
if err != nil {
return err
}
2019-06-09 09:24:20 +02:00
if tok != ProtocolID {
return errors.New("received mismatch in protocol id")
}
return nil
}
func trySelect(proto string, rwc io.ReadWriteCloser) error {
2019-06-09 09:24:20 +02:00
err := delimWriteBuffered(rwc, []byte(proto))
if err != nil {
return err
}
2019-06-09 09:24:20 +02:00
return readProto(proto, rwc)
}
2019-06-09 09:24:20 +02:00
// readProto reads the muxer's reply to a proposal of proto from rw.
// An echo of proto means acceptance (nil), "na" means rejection
// (ErrNotSupported), and anything else is reported as an unrecognized
// response. The echo check comes first, so it wins even if proto is "na".
func readProto(proto string, rw io.ReadWriter) error {
	reply, err := ReadNextToken(rw)
	switch {
	case err != nil:
		return err
	case reply == proto:
		return nil
	case reply == "na":
		return ErrNotSupported
	}
	return errors.New("unrecognized response: " + reply)
}