2022-04-06 11:48:16 +02:00

173 lines
3.2 KiB
Go

// Package connctx wraps net.Conn using context.Context.
package connctx
import (
"context"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// ErrClosing is the error WriteContext returns when it is called after the
// ConnCtx has been closed with Close.
var ErrClosing = errors.New("use of closed network connection")
// Reader is the read half of a context-controlled connection.
type Reader interface {
	// ReadContext is a read into b that is governed by ctx. It returns the
	// number of bytes read and any error encountered.
	ReadContext(ctx context.Context, b []byte) (int, error)
}
// Writer is the write half of a context-controlled connection.
type Writer interface {
	// WriteContext is a write of b that is governed by ctx. It returns the
	// number of bytes written and any error encountered.
	WriteContext(ctx context.Context, b []byte) (int, error)
}
// ReadWriter is a composite of Reader and Writer.
type ReadWriter interface {
Reader
Writer
}
// ConnCtx is a wrapper of net.Conn using context.Context.
//
// It combines the context-controlled Reader and Writer with io.Closer and
// exposes the addresses and the underlying connection of the wrapped net.Conn.
type ConnCtx interface {
Reader
Writer
io.Closer
// LocalAddr returns the local network address.
LocalAddr() net.Addr
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
// Conn returns the underlying net.Conn.
Conn() net.Conn
}
// connCtx is the ConnCtx implementation handed out by New.
type connCtx struct {
	// nextConn is the wrapped connection that all I/O is delegated to.
	nextConn net.Conn

	// closed is closed by Close to signal that the wrapper must no longer be
	// used; closeOnce ensures that happens a single time.
	closed    chan struct{}
	closeOnce sync.Once

	// readMu and writeMu serialize ReadContext and WriteContext calls
	// respectively. Close acquires both before closing the closed channel.
	readMu, writeMu sync.Mutex
}
// veryOld is a timestamp one nanosecond after the Unix epoch. ReadContext and
// WriteContext install it as the connection deadline once their context is
// done, so that the deadline is already in the past.
var veryOld = time.Unix(0, 1) //nolint:gochecknoglobals
// New returns a ConnCtx that delegates to conn. The returned value starts in
// the open state: its closed channel is freshly allocated and not yet closed.
func New(conn net.Conn) ConnCtx {
	return &connCtx{
		closed:   make(chan struct{}),
		nextConn: conn,
	}
}
// ReadContext reads from the wrapped connection into b while honoring ctx.
//
// Calls are serialized through readMu. If the wrapper has already been closed
// it returns io.EOF without touching the connection. Otherwise a watcher
// goroutine runs for the duration of the Read: when ctx becomes done it sets
// the read deadline to veryOld, waits for Read to return, and then clears the
// deadline again with the zero time.
//
// The returned error is chosen as follows: if ctx is done and no bytes were
// read, ctx.Err() replaces whatever Read returned; otherwise, if Read
// succeeded but changing the deadline failed, that failure is returned.
func (c *connCtx) ReadContext(ctx context.Context, b []byte) (int, error) {
	c.readMu.Lock()
	defer c.readMu.Unlock()

	// Non-blocking check of the closed flag.
	select {
	case <-c.closed:
		return 0, io.EOF
	default:
	}

	var (
		readDone    = make(chan struct{}) // closed right after Read returns
		watcher     sync.WaitGroup
		deadlineErr atomic.Value // SetReadDeadline failure, if any
	)

	watch := func() {
		defer watcher.Done()

		select {
		case <-readDone:
			// Read finished before ctx was done; nothing to undo.
			return
		case <-ctx.Done():
		}

		// ctx is done: move the deadline into the past.
		if err := c.nextConn.SetReadDeadline(veryOld); err != nil {
			deadlineErr.Store(err)
			return
		}

		// Only clear the deadline once the Read call is over.
		<-readDone
		if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil {
			deadlineErr.Store(err)
		}
	}

	watcher.Add(1)
	go watch()

	n, err := c.nextConn.Read(b)

	close(readDone)
	// Do not return (and release readMu) before the watcher has finished
	// touching the deadline.
	watcher.Wait()

	if ctxErr := ctx.Err(); n == 0 && ctxErr != nil {
		err = ctxErr
	}
	if err == nil {
		if stored := deadlineErr.Load(); stored != nil {
			err = stored.(error)
		}
	}

	return n, err
}
// WriteContext writes b to the wrapped connection while honoring ctx.
//
// Calls are serialized through writeMu. If the wrapper has already been
// closed it returns ErrClosing without touching the connection. Otherwise a
// watcher goroutine runs for the duration of the Write: when ctx becomes done
// it sets the write deadline to veryOld, waits for Write to return, and then
// clears the deadline again with the zero time.
//
// The returned error is chosen as follows: if ctx is done and no bytes were
// written, ctx.Err() replaces whatever Write returned; otherwise, if Write
// succeeded but changing the deadline failed, that failure is returned.
func (c *connCtx) WriteContext(ctx context.Context, b []byte) (int, error) {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	// Non-blocking check of the closed flag.
	select {
	case <-c.closed:
		return 0, ErrClosing
	default:
	}

	var (
		writeDone   = make(chan struct{}) // closed right after Write returns
		watcher     sync.WaitGroup
		deadlineErr atomic.Value // SetWriteDeadline failure, if any
	)

	watch := func() {
		defer watcher.Done()

		select {
		case <-writeDone:
			// Write finished before ctx was done; nothing to undo.
			return
		case <-ctx.Done():
		}

		// ctx is done: move the deadline into the past.
		if err := c.nextConn.SetWriteDeadline(veryOld); err != nil {
			deadlineErr.Store(err)
			return
		}

		// Only clear the deadline once the Write call is over.
		<-writeDone
		if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil {
			deadlineErr.Store(err)
		}
	}

	watcher.Add(1)
	go watch()

	n, err := c.nextConn.Write(b)

	close(writeDone)
	// Do not return (and release writeMu) before the watcher has finished
	// touching the deadline.
	watcher.Wait()

	if ctxErr := ctx.Err(); n == 0 && ctxErr != nil {
		err = ctxErr
	}
	if err == nil {
		if stored := deadlineErr.Load(); stored != nil {
			err = stored.(error)
		}
	}

	return n, err
}
// Close closes the wrapped connection and marks the wrapper as closed.
//
// The wrapped connection's Close is invoked on every call and its result is
// what Close returns. Closing of the closed channel happens only on the first
// call, guarded by closeOnce.
func (c *connCtx) Close() error {
	closeErr := c.nextConn.Close()

	c.closeOnce.Do(func() {
		// Take both I/O locks so the channel is closed only while no
		// ReadContext or WriteContext call is holding its mutex. The
		// deferred unlocks release readMu first, then writeMu.
		c.writeMu.Lock()
		defer c.writeMu.Unlock()
		c.readMu.Lock()
		defer c.readMu.Unlock()

		close(c.closed)
	})

	return closeErr
}
// LocalAddr returns the local network address as reported by the wrapped
// connection.
func (c *connCtx) LocalAddr() net.Addr {
return c.nextConn.LocalAddr()
}
// RemoteAddr returns the remote network address as reported by the wrapped
// connection.
func (c *connCtx) RemoteAddr() net.Addr {
return c.nextConn.RemoteAddr()
}
// Conn returns the wrapped net.Conn itself. Calls made directly on the
// returned connection do not go through ReadContext/WriteContext and so are
// not covered by readMu, writeMu or the closed check.
func (c *connCtx) Conn() net.Conn {
return c.nextConn
}