status-go/vendor/github.com/multiformats/go-multistream/lazyClient.go

161 lines
4.1 KiB
Go

package multistream
import (
"fmt"
"io"
"sync"
)
// NewMSSelect returns a new Multistream which is able to perform
// protocol selection with a MultistreamMuxer.
func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
return &lazyClientConn[T]{
protos: []T{ProtocolID, proto},
con: c,
}
}
// NewMultistream returns a multistream for the given protocol. This will not
// perform any protocol selection. If you are using a MultistreamMuxer, use
// NewMSSelect.
func NewMultistream[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
return &lazyClientConn[T]{
protos: []T{proto},
con: c,
}
}
// lazyClientConn is a ReadWriteCloser adapter that lazily negotiates a protocol
// using multistream-select on first use.
//
// It *does not* block writes waiting for the other end to respond. Instead, it
// simply assumes the negotiation went successfully and starts writing data.
// See: https://github.com/multiformats/go-multistream/issues/20
type lazyClientConn[T StringLike] struct {
// Used to ensure we only trigger the write half of the handshake once.
rhandshakeOnce sync.Once
rerr error
// Used to ensure we only trigger the read half of the handshake once.
whandshakeOnce sync.Once
werr error
// The sequence of protocols to negotiate.
protos []T
// The inner connection.
con io.ReadWriteCloser
}
// Read reads data from the io.ReadWriteCloser.
//
// If the protocol hasn't yet been negotiated, this method triggers the write
// half of the handshake and then waits for the read half to complete.
//
// It returns an error if the read half of the handshake fails.
func (l *lazyClientConn[T]) Read(b []byte) (int, error) {
l.rhandshakeOnce.Do(func() {
go l.whandshakeOnce.Do(l.doWriteHandshake)
l.doReadHandshake()
})
if l.rerr != nil {
return 0, l.rerr
}
if len(b) == 0 {
return 0, nil
}
return l.con.Read(b)
}
func (l *lazyClientConn[T]) doReadHandshake() {
for _, proto := range l.protos {
// read protocol
tok, err := ReadNextToken[T](l.con)
if err != nil {
l.rerr = err
return
}
if tok == "na" {
l.rerr = ErrNotSupported[T]{[]T{proto}}
return
}
if tok != proto {
l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", tok, proto)
return
}
}
}
func (l *lazyClientConn[T]) doWriteHandshake() {
l.doWriteHandshakeWithData(nil)
}
// Perform the write handshake but *also* write some extra data.
func (l *lazyClientConn[T]) doWriteHandshakeWithData(extra []byte) int {
buf := getWriter(l.con)
defer putWriter(buf)
for _, proto := range l.protos {
l.werr = delimWrite(buf, []byte(proto))
if l.werr != nil {
return 0
}
}
n := 0
if len(extra) > 0 {
n, l.werr = buf.Write(extra)
if l.werr != nil {
return n
}
}
l.werr = buf.Flush()
return n
}
// Write writes the given buffer to the underlying connection.
//
// If the protocol has not yet been negotiated, write waits for the write half
// of the handshake to complete triggers (but does not wait for) the read half.
//
// Write *also* ignores errors from the read half of the handshake (in case the
// stream is actually write only).
func (l *lazyClientConn[T]) Write(b []byte) (int, error) {
n := 0
l.whandshakeOnce.Do(func() {
go l.rhandshakeOnce.Do(l.doReadHandshake)
n = l.doWriteHandshakeWithData(b)
})
if l.werr != nil || n > 0 {
return n, l.werr
}
return l.con.Write(b)
}
// Close closes the underlying io.ReadWriteCloser
//
// This does not flush anything.
func (l *lazyClientConn[T]) Close() error {
// As the client, we flush the handshake on close to cover an
// interesting edge-case where the server only speaks a single protocol
// and responds eagerly with that protocol before waiting for out
// handshake.
//
// Again, we must not read the error because the other end may have
// closed the stream for reading. I mean, we're the initiator so that's
// strange... but it's still allowed
_ = l.Flush()
return l.con.Close()
}
// Flush sends the handshake.
func (l *lazyClientConn[T]) Flush() error {
l.whandshakeOnce.Do(func() {
go l.rhandshakeOnce.Do(l.doReadHandshake)
l.doWriteHandshake()
})
return l.werr
}