2024-01-18 14:28:06 +00:00

152 lines
3.6 KiB
Go

package multistream
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"runtime/debug"
)
// ErrNotSupported is returned when the muxer accepts none of the protocols
// offered during the handshake.
type ErrNotSupported[T StringLike] struct {
	// Protos lists the protocols that the muxer rejected.
	Protos []T
}

// Error implements the error interface by listing the rejected protocols.
func (e ErrNotSupported[T]) Error() string {
	return fmt.Sprintf("protocols not supported: %v", e.Protos)
}

// Is reports whether target is an ErrNotSupported value of the same type
// parameter. The Protos contents are ignored, so
// errors.Is(err, ErrNotSupported[T]{}) matches any such error.
func (e ErrNotSupported[T]) Is(target error) bool {
	switch target.(type) {
	case ErrNotSupported[T]:
		return true
	default:
		return false
	}
}
// ErrNoProtocols is the error returned when no protocols have been
// specified (for example, SelectOneOf called with an empty slice).
var ErrNoProtocols = errors.New("no protocols specified")
// SelectProtoOrFail performs the initial multistream handshake
// to inform the muxer of the protocol that will be used to communicate
// on this ReadWriteCloser. It returns an error if, for example,
// the muxer does not know how to handle this protocol.
//
// The multistream header and the proposal for proto are written together
// from a separate goroutine while the calling goroutine reads the peer's
// header and reply, so header and proposal share a single round trip.
// A write error takes precedence over read errors. A panic raised while
// reading or writing is recovered and returned as an error.
//
// NOTE(review): if the write to rwc blocks indefinitely (e.g. the peer never
// reads), this function blocks waiting for the writer's result.
func SelectProtoOrFail[T StringLike](proto T, rwc io.ReadWriteCloser) (err error) {
	defer func() {
		if rerr := recover(); rerr != nil {
			fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack())
			err = fmt.Errorf("panic selecting protocol: %s", rerr)
		}
	}()

	// Buffered so the writer can always deliver its single result, even if
	// this goroutine has already returned (e.g. after a recovered read panic).
	errCh := make(chan error, 1)
	go func() {
		// The recover above only sees panics on the calling goroutine; without
		// this one, a panicking rwc.Write would crash the whole process.
		// A panic can only occur before the goroutine's single send, so
		// errCh never receives more than one value.
		defer func() {
			if rerr := recover(); rerr != nil {
				fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack())
				errCh <- fmt.Errorf("panic selecting protocol: %s", rerr)
			}
		}()

		var buf bytes.Buffer
		if err := delitmWriteAll(&buf, []byte(ProtocolID), []byte(proto)); err != nil {
			errCh <- err
			return
		}
		_, err := io.Copy(rwc, &buf)
		errCh <- err
	}()

	// We have to read *both* errors.
	err1 := readMultistreamHeader(rwc)
	err2 := readProto(proto, rwc)
	if werr := <-errCh; werr != nil {
		return werr
	}
	if err1 != nil {
		return err1
	}
	if err2 != nil {
		return err2
	}
	return nil
}
// SelectOneOf negotiates the first protocol in protos that the muxer accepts
// and returns it.
//
// The first candidate is proposed together with the multistream header
// (via SelectProtoOrFail) to save a round trip; the remaining candidates are
// then tried one at a time. If every candidate is rejected, the returned
// ErrNotSupported lists all of protos. An empty protos yields ErrNoProtocols.
// A panic during negotiation is recovered and returned as an error.
func SelectOneOf[T StringLike](protos []T, rwc io.ReadWriteCloser) (proto T, err error) {
	defer func() {
		if r := recover(); r != nil {
			fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", r, debug.Stack())
			err = fmt.Errorf("panic selecting one of protocols: %s", r)
		}
	}()

	if len(protos) == 0 {
		return "", ErrNoProtocols
	}

	first, rest := protos[0], protos[1:]

	firstErr := SelectProtoOrFail(first, rwc)
	if firstErr == nil {
		return first, nil
	}
	// Anything other than a plain "not supported" answer aborts negotiation.
	if _, rejected := firstErr.(ErrNotSupported[T]); !rejected {
		return "", firstErr
	}

	chosen, restErr := selectProtosOrFail(rest, rwc)
	if _, rejected := restErr.(ErrNotSupported[T]); rejected {
		// Report the full candidate list, not just the tail that was retried.
		return "", ErrNotSupported[T]{protos}
	}
	return chosen, restErr
}
// selectProtosOrFail proposes each protocol in order on rwc and returns the
// first one the peer accepts. A rejection ("na") moves on to the next
// candidate; any other error is returned immediately. If every candidate is
// rejected, it returns ErrNotSupported listing protos.
func selectProtosOrFail[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) {
	for i := range protos {
		candidate := protos[i]
		selErr := trySelect(candidate, rwc)
		if selErr == nil {
			return candidate, nil
		}
		if _, rejected := selErr.(ErrNotSupported[T]); !rejected {
			return "", selErr
		}
	}
	return "", ErrNotSupported[T]{protos}
}
// readMultistreamHeader reads the next token from r and checks that it equals
// ProtocolID. It returns any read error unchanged, and a descriptive error
// naming both the received and the expected id on mismatch.
func readMultistreamHeader(r io.Reader) error {
	tok, err := ReadNextToken[string](r)
	if err != nil {
		return err
	}
	if tok != ProtocolID {
		// Report what the peer actually sent; %q keeps untrusted bytes
		// printable and makes an empty token visible.
		return fmt.Errorf("received mismatch in protocol id: got %q, want %q", tok, ProtocolID)
	}
	return nil
}
// trySelect writes a single delimited proposal for proto to rwc and then
// reads and interprets the peer's answer via readProto.
func trySelect[T StringLike](proto T, rwc io.ReadWriteCloser) error {
	msg := []byte(proto)
	if werr := delimWriteBuffered(rwc, msg); werr != nil {
		return werr
	}
	return readProto(proto, rwc)
}
// readProto reads the peer's reply to a proposal of proto. The function
// returns nil when the peer echoes proto back, and ErrNotSupported for
// proto when the reply is "na". Any other reply produces an "unrecognized
// response" error; read errors are returned unchanged.
func readProto[T StringLike](proto T, r io.Reader) error {
	reply, rerr := ReadNextToken[T](r)
	if rerr != nil {
		return rerr
	}
	// The echo check comes first, so a protocol literally named "na" is
	// still treated as accepted.
	if reply == proto {
		return nil
	}
	if reply == "na" {
		return ErrNotSupported[T]{[]T{proto}}
	}
	return fmt.Errorf("unrecognized response: %s", reply)
}