diff --git a/consul/pool.go b/consul/pool.go
index 7fc93950d2..2323b36b3a 100644
--- a/consul/pool.go
+++ b/consul/pool.go
@@ -41,7 +41,9 @@ type StreamClient struct {
 
 // Conn is a pooled connection to a Consul server
 type Conn struct {
-	refCount int32
+	refCount    int32
+	shouldClose int32
+
 	addr     net.Addr
 	session  muxSession
 	lastUsed time.Time
@@ -93,7 +95,7 @@ func (c *Conn) getClient() (*StreamClient, error) {
 func (c *Conn) returnClient(client *StreamClient) {
 	didSave := false
 	c.clientLock.Lock()
-	if c.clients.Len() < c.pool.maxStreams {
+	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
 		c.clients.PushFront(client)
 		didSave = true
 	}
@@ -184,14 +186,12 @@ func (p *ConnPool) acquire(addr net.Addr, version int) (*Conn, error) {
 // getPooled is used to return a pooled connection
 func (p *ConnPool) getPooled(addr net.Addr, version int) *Conn {
 	p.Lock()
-	defer p.Unlock()
-
-	// Look for an existing connection
 	c := p.pool[addr.String()]
 	if c != nil {
 		c.lastUsed = time.Now()
 		atomic.AddInt32(&c.refCount, 1)
 	}
+	p.Unlock()
 	return c
 }
 
@@ -261,29 +261,41 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
 
 	// Track this connection, handle potential race condition
 	p.Lock()
-	defer p.Unlock()
 	if existing := p.pool[addr.String()]; existing != nil {
-		session.Close()
+		c.Close()
+		p.Unlock()
 		return existing, nil
 	} else {
 		p.pool[addr.String()] = c
+		p.Unlock()
 		return c, nil
 	}
 }
 
 // clearConn is used to clear any cached connection, potentially in response to an erro
-func (p *ConnPool) clearConn(addr net.Addr) {
+func (p *ConnPool) clearConn(conn *Conn) {
+	// Ensure returned streams are closed
+	atomic.StoreInt32(&conn.shouldClose, 1)
+
+	// Clear from the cache
 	p.Lock()
-	defer p.Unlock()
-	if conn, ok := p.pool[addr.String()]; ok {
+	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
+		delete(p.pool, conn.addr.String())
+	}
+	p.Unlock()
+
+	// Close down immediately if idle; otherwise releaseConn closes it once the last reference is dropped
+	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
 		conn.Close()
-		delete(p.pool, addr.String())
 	}
 }
 
 // releaseConn is invoked when we are done with a conn to reduce the ref count
 func (p *ConnPool) releaseConn(conn *Conn) {
-	atomic.AddInt32(&conn.refCount, -1)
+	refCount := atomic.AddInt32(&conn.refCount, -1)
+	if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
+		conn.Close()
+	}
 }
 
 // getClient is used to get a usable client for an address and protocol version
@@ -299,7 +311,8 @@ START:
 	// Get a client
 	client, err := conn.getClient()
 	if err != nil {
-		p.clearConn(addr)
+		p.clearConn(conn)
+		p.releaseConn(conn)
 
 		// Try to redial, possible that the TCP session closed due to timeout
 		if retries == 0 {
@@ -313,23 +326,24 @@ START:
 
 // RPC is used to make an RPC call to a remote host
 func (p *ConnPool) RPC(addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
+	// Get a usable client
 	conn, sc, err := p.getClient(addr, version)
-	defer func() {
-		conn.returnClient(sc)
-		p.releaseConn(conn)
-	}()
+	if err != nil {
+		return fmt.Errorf("rpc error: %v", err)
+	}
 
 	// Make the RPC call
 	err = sc.client.Call(method, args, reply)
-
-	// Fast path the non-error case
-	if err == nil {
-		return nil
+	if err != nil {
+		p.clearConn(conn)
+		p.releaseConn(conn)
+		return fmt.Errorf("rpc error: %v", err)
 	}
 
-	// Do-not re-use as a pre-caution
-	p.clearConn(addr)
-	return fmt.Errorf("rpc error: %v", err)
+	// Done with the connection
+	conn.returnClient(sc)
+	p.releaseConn(conn)
+	return nil
 }
 
 // Reap is used to close conns open over maxTime