2018-07-04 10:51:47 +00:00
|
|
|
package multistream
|
|
|
|
|
|
|
|
import (
|
2019-06-09 07:24:20 +00:00
|
|
|
"bytes"
|
2018-07-04 10:51:47 +00:00
|
|
|
"errors"
|
2022-08-19 16:34:07 +00:00
|
|
|
"fmt"
|
2018-07-04 10:51:47 +00:00
|
|
|
"io"
|
2022-08-19 16:34:07 +00:00
|
|
|
"os"
|
|
|
|
"runtime/debug"
|
2018-07-04 10:51:47 +00:00
|
|
|
)
|
|
|
|
|
2023-02-22 21:58:17 +00:00
|
|
|
// ErrNotSupported is the error returned when the muxer doesn't support
|
|
|
|
// the protocols tried for the handshake.
|
|
|
|
type ErrNotSupported[T StringLike] struct {
|
|
|
|
|
|
|
|
// Slice of protocols that were not supported by the muxer
|
|
|
|
Protos []T
|
|
|
|
}
|
|
|
|
|
|
|
|
func (e ErrNotSupported[T]) Error() string {
|
|
|
|
return fmt.Sprintf("protocols not supported: %v", e.Protos)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (e ErrNotSupported[T]) Is(target error) bool {
|
|
|
|
_, ok := target.(ErrNotSupported[T])
|
|
|
|
return ok
|
|
|
|
}
|
2018-07-04 10:51:47 +00:00
|
|
|
|
2019-06-09 07:24:20 +00:00
|
|
|
// ErrNoProtocols is the error returned when the no protocols have been
|
|
|
|
// specified.
|
|
|
|
var ErrNoProtocols = errors.New("no protocols specified")
|
|
|
|
|
2018-07-04 10:51:47 +00:00
|
|
|
// SelectProtoOrFail performs the initial multistream handshake
|
|
|
|
// to inform the muxer of the protocol that will be used to communicate
|
|
|
|
// on this ReadWriteCloser. It returns an error if, for example,
|
|
|
|
// the muxer does not know how to handle this protocol.
|
2023-02-22 21:58:17 +00:00
|
|
|
func SelectProtoOrFail[T StringLike](proto T, rwc io.ReadWriteCloser) (err error) {
|
2022-08-19 16:34:07 +00:00
|
|
|
defer func() {
|
|
|
|
if rerr := recover(); rerr != nil {
|
|
|
|
fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack())
|
|
|
|
err = fmt.Errorf("panic selecting protocol: %s", rerr)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
2019-06-09 07:24:20 +00:00
|
|
|
errCh := make(chan error, 1)
|
|
|
|
go func() {
|
|
|
|
var buf bytes.Buffer
|
2021-10-19 13:43:41 +00:00
|
|
|
if err := delitmWriteAll(&buf, []byte(ProtocolID), []byte(proto)); err != nil {
|
|
|
|
errCh <- err
|
|
|
|
return
|
|
|
|
}
|
2019-06-09 07:24:20 +00:00
|
|
|
_, err := io.Copy(rwc, &buf)
|
|
|
|
errCh <- err
|
|
|
|
}()
|
|
|
|
// We have to read *both* errors.
|
|
|
|
err1 := readMultistreamHeader(rwc)
|
|
|
|
err2 := readProto(proto, rwc)
|
|
|
|
if werr := <-errCh; werr != nil {
|
|
|
|
return werr
|
2018-07-04 10:51:47 +00:00
|
|
|
}
|
2019-06-09 07:24:20 +00:00
|
|
|
if err1 != nil {
|
|
|
|
return err1
|
|
|
|
}
|
|
|
|
if err2 != nil {
|
|
|
|
return err2
|
|
|
|
}
|
|
|
|
return nil
|
2018-07-04 10:51:47 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// SelectOneOf will perform handshakes with the protocols on the given slice
|
|
|
|
// until it finds one which is supported by the muxer.
|
2023-02-22 21:58:17 +00:00
|
|
|
func SelectOneOf[T StringLike](protos []T, rwc io.ReadWriteCloser) (proto T, err error) {
|
2022-08-19 16:34:07 +00:00
|
|
|
defer func() {
|
|
|
|
if rerr := recover(); rerr != nil {
|
|
|
|
fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack())
|
|
|
|
err = fmt.Errorf("panic selecting one of protocols: %s", rerr)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
2019-06-09 07:24:20 +00:00
|
|
|
if len(protos) == 0 {
|
|
|
|
return "", ErrNoProtocols
|
2018-07-04 10:51:47 +00:00
|
|
|
}
|
|
|
|
|
2019-06-09 07:24:20 +00:00
|
|
|
// Use SelectProtoOrFail to pipeline the /multistream/1.0.0 handshake
|
|
|
|
// with an attempt to negotiate the first protocol. If that fails, we
|
|
|
|
// can continue negotiating the rest of the protocols normally.
|
|
|
|
//
|
|
|
|
// This saves us a round trip.
|
2023-02-22 21:58:17 +00:00
|
|
|
switch err := SelectProtoOrFail(protos[0], rwc); err.(type) {
|
2019-06-09 07:24:20 +00:00
|
|
|
case nil:
|
|
|
|
return protos[0], nil
|
2023-02-22 21:58:17 +00:00
|
|
|
case ErrNotSupported[T]: // try others
|
2019-06-09 07:24:20 +00:00
|
|
|
default:
|
|
|
|
return "", err
|
|
|
|
}
|
2023-02-22 21:58:17 +00:00
|
|
|
proto, err = selectProtosOrFail(protos[1:], rwc)
|
|
|
|
if _, ok := err.(ErrNotSupported[T]); ok {
|
|
|
|
return "", ErrNotSupported[T]{protos}
|
|
|
|
}
|
|
|
|
return proto, err
|
2021-10-19 13:43:41 +00:00
|
|
|
}
|
|
|
|
|
2023-02-22 21:58:17 +00:00
|
|
|
func selectProtosOrFail[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) {
|
2021-10-19 13:43:41 +00:00
|
|
|
for _, p := range protos {
|
2018-07-04 10:51:47 +00:00
|
|
|
err := trySelect(p, rwc)
|
2023-02-22 21:58:17 +00:00
|
|
|
switch err := err.(type) {
|
2018-07-04 10:51:47 +00:00
|
|
|
case nil:
|
|
|
|
return p, nil
|
2023-02-22 21:58:17 +00:00
|
|
|
case ErrNotSupported[T]:
|
2018-07-04 10:51:47 +00:00
|
|
|
default:
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
}
|
2023-02-22 21:58:17 +00:00
|
|
|
return "", ErrNotSupported[T]{protos}
|
2018-07-04 10:51:47 +00:00
|
|
|
}
|
|
|
|
|
2021-06-16 20:19:45 +00:00
|
|
|
func readMultistreamHeader(r io.Reader) error {
|
2023-02-22 21:58:17 +00:00
|
|
|
tok, err := ReadNextToken[string](r)
|
2018-07-04 10:51:47 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2019-06-09 07:24:20 +00:00
|
|
|
if tok != ProtocolID {
|
|
|
|
return errors.New("received mismatch in protocol id")
|
|
|
|
}
|
2018-07-04 10:51:47 +00:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2023-02-22 21:58:17 +00:00
|
|
|
func trySelect[T StringLike](proto T, rwc io.ReadWriteCloser) error {
|
2019-06-09 07:24:20 +00:00
|
|
|
err := delimWriteBuffered(rwc, []byte(proto))
|
2018-07-04 10:51:47 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2019-06-09 07:24:20 +00:00
|
|
|
return readProto(proto, rwc)
|
|
|
|
}
|
2018-07-04 10:51:47 +00:00
|
|
|
|
2023-02-22 21:58:17 +00:00
|
|
|
func readProto[T StringLike](proto T, r io.Reader) error {
|
|
|
|
tok, err := ReadNextToken[T](r)
|
2018-07-04 10:51:47 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
switch tok {
|
|
|
|
case proto:
|
|
|
|
return nil
|
|
|
|
case "na":
|
2023-02-22 21:58:17 +00:00
|
|
|
return ErrNotSupported[T]{[]T{proto}}
|
2018-07-04 10:51:47 +00:00
|
|
|
default:
|
2023-02-22 21:58:17 +00:00
|
|
|
return fmt.Errorf("unrecognized response: %s", tok)
|
2018-07-04 10:51:47 +00:00
|
|
|
}
|
|
|
|
}
|