211 lines
4.9 KiB
Go
211 lines
4.9 KiB
Go
package tcp
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"sync/atomic"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// Checker contains an epoll instance for TCP handshake checking.
//
// The zero value is not usable as-is: NewChecker / NewCheckerZeroLinger set
// the poller fd to -1 and allocate isReady. After constructing a Checker, run
// CheckingLoop before issuing checks.
type Checker struct {
	// pipePool supplies and takes back the channels on which connect results
	// are delivered (presumably the source of getPipe/putBackPipe).
	pipePool
	// resultPipes maps a connecting socket fd to the channel its result is
	// delivered on (registerResultPipe / popResultPipe / deregisterResultPipe).
	resultPipes
	// pollerLock serializes creating and closing the poller.
	pollerLock sync.Mutex
	// _pollerFd is the poller fd, or -1 when no poller is open. It is accessed
	// through pollerFD/setPollerFD, which use atomic loads/stores.
	_pollerFd int32
	// zeroLinger is the linger setting CheckAddr uses for its sockets.
	zeroLinger bool
	// isReady is closed once CheckingLoop has created the poller, and is
	// replaced with a fresh channel when CheckingLoop returns.
	isReady chan struct{}
}
|
|
|
|
// NewChecker creates a Checker with linger set to zero.
|
|
func NewChecker() *Checker {
|
|
return NewCheckerZeroLinger(true)
|
|
}
|
|
|
|
// NewCheckerZeroLinger creates a Checker with zeroLinger set to given value.
|
|
func NewCheckerZeroLinger(zeroLinger bool) *Checker {
|
|
return &Checker{
|
|
pipePool: newPipePoolSyncPool(),
|
|
resultPipes: newResultPipesSyncMap(),
|
|
_pollerFd: -1,
|
|
zeroLinger: zeroLinger,
|
|
isReady: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// CheckingLoop must be called before anything else.
|
|
// NOTE: this function blocks until ctx got canceled.
|
|
func (c *Checker) CheckingLoop(ctx context.Context) error {
|
|
pollerFd, err := c.createPoller()
|
|
if err != nil {
|
|
return errors.Wrap(err, "error creating poller")
|
|
}
|
|
defer c.closePoller()
|
|
|
|
c.setReady()
|
|
defer c.resetReady()
|
|
|
|
return c.pollingLoop(ctx, pollerFd)
|
|
}
|
|
|
|
func (c *Checker) createPoller() (int, error) {
|
|
c.pollerLock.Lock()
|
|
defer c.pollerLock.Unlock()
|
|
|
|
if c.pollerFD() > 0 {
|
|
// return if already initialized
|
|
return -1, ErrCheckerAlreadyStarted
|
|
}
|
|
|
|
pollerFd, err := createPoller()
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
c.setPollerFD(pollerFd)
|
|
|
|
return pollerFd, nil
|
|
}
|
|
|
|
func (c *Checker) closePoller() error {
|
|
c.pollerLock.Lock()
|
|
defer c.pollerLock.Unlock()
|
|
var err error
|
|
if c.pollerFD() > 0 {
|
|
err = syscall.Close(c.pollerFD())
|
|
}
|
|
c.setPollerFD(-1)
|
|
return err
|
|
}
|
|
|
|
func (c *Checker) setReady() {
|
|
close(c.isReady)
|
|
}
|
|
|
|
func (c *Checker) resetReady() {
|
|
c.isReady = make(chan struct{})
|
|
}
|
|
|
|
// pollerTimeout is passed as the timeout of every pollEvents call made by
// pollingLoop. Cancellation is only checked between calls, so this presumably
// also bounds how long shutdown takes after ctx is canceled.
const pollerTimeout time.Duration = 1 * time.Second
|
|
|
|
func (c *Checker) pollingLoop(ctx context.Context, pollerFd int) error {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
default:
|
|
evts, err := pollEvents(pollerFd, pollerTimeout)
|
|
if err != nil {
|
|
// fatal error
|
|
return errors.Wrap(err, "error during polling loop")
|
|
}
|
|
|
|
c.handlePollerEvents(evts)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Checker) handlePollerEvents(evts []event) {
|
|
for _, e := range evts {
|
|
if pipe, exists := c.resultPipes.popResultPipe(e.Fd); exists {
|
|
pipe <- e.Err
|
|
}
|
|
// error pipe not found
|
|
// in this case, e.Fd should have been handled in the previous event.
|
|
}
|
|
}
|
|
|
|
func (c *Checker) pollerFD() int {
|
|
return int(atomic.LoadInt32(&c._pollerFd))
|
|
}
|
|
|
|
func (c *Checker) setPollerFD(fd int) {
|
|
atomic.StoreInt32(&c._pollerFd, int32(fd))
|
|
}
|
|
|
|
// CheckAddr performs a TCP check with given TCP address and timeout
|
|
// A successful check will result in nil error
|
|
// ErrTimeout is returned if timeout
|
|
// zeroLinger is an optional parameter indicating if linger should be set to zero
|
|
// for this particular connection
|
|
// Note: timeout includes domain resolving
|
|
func (c *Checker) CheckAddr(addr string, timeout time.Duration) (err error) {
|
|
return c.CheckAddrZeroLinger(addr, timeout, c.zeroLinger)
|
|
}
|
|
|
|
// CheckAddrZeroLinger is like CheckAddr with an extra parameter indicating whether to enable zero linger.
|
|
func (c *Checker) CheckAddrZeroLinger(addr string, timeout time.Duration, zeroLinger bool) error {
|
|
// Set deadline
|
|
deadline := time.Now().Add(timeout)
|
|
|
|
// Parse address
|
|
rAddr, err := parseSockAddr(addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Create socket with options set
|
|
fd, err := createSocketZeroLinger(zeroLinger)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Socket should be closed anyway
|
|
defer syscall.Close(fd)
|
|
|
|
// Connect to the address
|
|
if success, cErr := connect(fd, rAddr); cErr != nil {
|
|
// If there was an error, return it.
|
|
return &ErrConnect{cErr}
|
|
} else if success {
|
|
// If the connect was successful, we are done.
|
|
return nil
|
|
}
|
|
// Otherwise wait for the result of connect.
|
|
return c.waitConnectResult(fd, deadline.Sub(time.Now()))
|
|
}
|
|
|
|
func (c *Checker) waitConnectResult(fd int, timeout time.Duration) error {
|
|
// get a pipe of connect result
|
|
resultPipe := c.getPipe()
|
|
defer func() {
|
|
c.resultPipes.deregisterResultPipe(fd)
|
|
c.putBackPipe(resultPipe)
|
|
}()
|
|
|
|
// this must be done before registerEvents
|
|
c.resultPipes.registerResultPipe(fd, resultPipe)
|
|
// Register to epoll for later error checking
|
|
if err := registerEvents(c.pollerFD(), fd); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Wait for connect result
|
|
return c.waitPipeTimeout(resultPipe, timeout)
|
|
}
|
|
|
|
func (c *Checker) waitPipeTimeout(pipe chan error, timeout time.Duration) error {
|
|
select {
|
|
case ret := <-pipe:
|
|
return ret
|
|
case <-time.After(timeout):
|
|
return ErrTimeout
|
|
}
|
|
}
|
|
|
|
// WaitReady returns a chan which is closed when the Checker is ready for use.
|
|
func (c *Checker) WaitReady() <-chan struct{} {
|
|
return c.isReady
|
|
}
|
|
|
|
// IsReady returns a bool indicates whether the Checker is ready for use
|
|
func (c *Checker) IsReady() bool {
|
|
return c.pollerFD() > 0
|
|
}
|
|
|
|
// PollerFd returns the inner fd of poller instance.
|
|
// NOTE: Use this only when you really know what you are doing.
|
|
func (c *Checker) PollerFd() int {
|
|
return c.pollerFD()
|
|
}
|