161 lines
4.1 KiB
Go
161 lines
4.1 KiB
Go
package multistream
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
)
|
|
|
|
// NewMSSelect returns a new Multistream which is able to perform
|
|
// protocol selection with a MultistreamMuxer.
|
|
func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
|
|
return &lazyClientConn[T]{
|
|
protos: []T{ProtocolID, proto},
|
|
con: c,
|
|
}
|
|
}
|
|
|
|
// NewMultistream returns a multistream for the given protocol. This will not
|
|
// perform any protocol selection. If you are using a MultistreamMuxer, use
|
|
// NewMSSelect.
|
|
func NewMultistream[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
|
|
return &lazyClientConn[T]{
|
|
protos: []T{proto},
|
|
con: c,
|
|
}
|
|
}
|
|
|
|
// lazyClientConn is a ReadWriteCloser adapter that lazily negotiates a protocol
// using multistream-select on first use.
//
// It *does not* block writes waiting for the other end to respond. Instead, it
// simply assumes the negotiation went successfully and starts writing data.
// See: https://github.com/multiformats/go-multistream/issues/20
type lazyClientConn[T StringLike] struct {
	// Used to ensure we only trigger the read half of the handshake once.
	rhandshakeOnce sync.Once
	// Result of the read half of the handshake; set by doReadHandshake and
	// returned from Read.
	rerr error

	// Used to ensure we only trigger the write half of the handshake once.
	whandshakeOnce sync.Once
	// Result of the write half of the handshake; set by
	// doWriteHandshakeWithData and returned from Write and Flush.
	werr error

	// The sequence of protocols to negotiate. The write half sends them in
	// this order and the read half expects them echoed back in this order.
	protos []T

	// The inner connection.
	con io.ReadWriteCloser
}
|
|
|
|
// Read reads data from the io.ReadWriteCloser.
|
|
//
|
|
// If the protocol hasn't yet been negotiated, this method triggers the write
|
|
// half of the handshake and then waits for the read half to complete.
|
|
//
|
|
// It returns an error if the read half of the handshake fails.
|
|
func (l *lazyClientConn[T]) Read(b []byte) (int, error) {
|
|
l.rhandshakeOnce.Do(func() {
|
|
go l.whandshakeOnce.Do(l.doWriteHandshake)
|
|
l.doReadHandshake()
|
|
})
|
|
if l.rerr != nil {
|
|
return 0, l.rerr
|
|
}
|
|
if len(b) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
return l.con.Read(b)
|
|
}
|
|
|
|
func (l *lazyClientConn[T]) doReadHandshake() {
|
|
for _, proto := range l.protos {
|
|
// read protocol
|
|
tok, err := ReadNextToken[T](l.con)
|
|
if err != nil {
|
|
l.rerr = err
|
|
return
|
|
}
|
|
|
|
if tok == "na" {
|
|
l.rerr = ErrNotSupported[T]{[]T{proto}}
|
|
return
|
|
}
|
|
if tok != proto {
|
|
l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", tok, proto)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (l *lazyClientConn[T]) doWriteHandshake() {
|
|
l.doWriteHandshakeWithData(nil)
|
|
}
|
|
|
|
// Perform the write handshake but *also* write some extra data.
|
|
func (l *lazyClientConn[T]) doWriteHandshakeWithData(extra []byte) int {
|
|
buf := getWriter(l.con)
|
|
defer putWriter(buf)
|
|
|
|
for _, proto := range l.protos {
|
|
l.werr = delimWrite(buf, []byte(proto))
|
|
if l.werr != nil {
|
|
return 0
|
|
}
|
|
}
|
|
|
|
n := 0
|
|
if len(extra) > 0 {
|
|
n, l.werr = buf.Write(extra)
|
|
if l.werr != nil {
|
|
return n
|
|
}
|
|
}
|
|
l.werr = buf.Flush()
|
|
return n
|
|
}
|
|
|
|
// Write writes the given buffer to the underlying connection.
|
|
//
|
|
// If the protocol has not yet been negotiated, write waits for the write half
|
|
// of the handshake to complete triggers (but does not wait for) the read half.
|
|
//
|
|
// Write *also* ignores errors from the read half of the handshake (in case the
|
|
// stream is actually write only).
|
|
func (l *lazyClientConn[T]) Write(b []byte) (int, error) {
|
|
n := 0
|
|
l.whandshakeOnce.Do(func() {
|
|
go l.rhandshakeOnce.Do(l.doReadHandshake)
|
|
n = l.doWriteHandshakeWithData(b)
|
|
})
|
|
if l.werr != nil || n > 0 {
|
|
return n, l.werr
|
|
}
|
|
return l.con.Write(b)
|
|
}
|
|
|
|
// Close sends the handshake if it has not been sent yet and then closes the
// underlying io.ReadWriteCloser.
//
// Any error from sending the handshake is discarded; the returned error is
// the one from closing the underlying connection.
func (l *lazyClientConn[T]) Close() error {
	// As the client, we flush the handshake on close to cover an
	// interesting edge-case where the server only speaks a single protocol
	// and responds eagerly with that protocol before waiting for our
	// handshake.
	//
	// Again, we must not read the error because the other end may have
	// closed the stream for reading. I mean, we're the initiator so that's
	// strange... but it's still allowed
	_ = l.Flush()
	return l.con.Close()
}
|
|
|
|
// Flush sends the handshake.
|
|
func (l *lazyClientConn[T]) Flush() error {
|
|
l.whandshakeOnce.Do(func() {
|
|
go l.rhandshakeOnce.Do(l.doReadHandshake)
|
|
l.doWriteHandshake()
|
|
})
|
|
return l.werr
|
|
}
|