From f6beed899b4bcd290249482bfbbf93ee87096ce0 Mon Sep 17 00:00:00 2001 From: Tevin Zhang Date: Wed, 1 Jun 2016 17:45:24 +0800 Subject: [PATCH] Make Shaker's methods goroutine safe --- err.go | 7 +++++++ shaker.go | 29 +++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/err.go b/err.go index 3f21730..36803ae 100644 --- a/err.go +++ b/err.go @@ -1,8 +1,15 @@ package tcp +import ( + "errors" +) + // ErrTimeout indicates I/O timeout var ErrTimeout = &timeoutError{} +// ErrNotInitialized occurs while the Shaker is not initialized +var ErrNotInitialized = errors.New("not initialized") + type timeoutError struct{} func (e *timeoutError) Error() string { return "I/O timeout" } diff --git a/shaker.go b/shaker.go index b5f5f5a..36576e9 100644 --- a/shaker.go +++ b/shaker.go @@ -26,12 +26,15 @@ // Usually this means the server will not notice the health checking // traffic at all, thus the act of health chekcing will not be // considered as some misbehaviour of client. +// +// Shaker's methods may be called by multiple goroutines simultaneously. package tcp import ( "fmt" "os" "runtime" + "sync" "syscall" "time" ) @@ -40,12 +43,19 @@ const maxEpollEvents = 32 // Shaker contains an epoll instance for TCP handshake checking type Shaker struct { + sync.RWMutex epollFd int } // Init creates inner epoll instance, call this before anything else func (s *Shaker) Init() error { var err error + s.Lock() + defer s.Unlock() + if s.epollFd > 0 { + return nil + } + s.epollFd, err = syscall.EpollCreate1(0) if err != nil { return os.NewSyscallError("epoll_create1", err) @@ -101,13 +111,22 @@ func (s *Shaker) Test(addr string, timeout time.Duration) error { // Close closes the inner epoll fd func (s *Shaker) Close() error { - return syscall.Close(s.epollFd) + s.Lock() + defer s.Unlock() + if s.epollFd > 0 { + err := syscall.Close(s.epollFd) + s.epollFd = 0 + return err + } + return nil } func (s *Shaker) addEpoll(fd int) error { var event syscall.EpollEvent event.Events = syscall.EPOLLOUT event.Fd = int32(fd) + s.RLock() + defer s.RUnlock() if err := syscall.EpollCtl(s.epollFd, syscall.EPOLL_CTL_ADD, fd, &event); err != nil { return os.NewSyscallError("epoll_ctl", err) } @@ -118,7 +137,13 @@ func (s *Shaker) addEpoll(fd int) error { // The boolean returned indicates whether the connect is successful func (s *Shaker) wait(fd int, timeoutMS int) (bool, error) { var events [maxEpollEvents]syscall.EpollEvent - nevents, err := syscall.EpollWait(s.epollFd, events[:], timeoutMS) + s.RLock() + epollFd := s.epollFd + if epollFd <= 0 { + return false, ErrNotInitialized + } + s.RUnlock() + nevents, err := syscall.EpollWait(epollFd, events[:], timeoutMS) if err != nil { return false, os.NewSyscallError("epoll_wait", err) }