2016-09-04 20:04:43 +02:00

238 lines
6.5 KiB
Go

/*
Package httpclient provides an HTTP Transport that implements the
`RoundTripper` interface and can be used as a drop-in replacement for the
standard library's, providing:

  - connection timeouts
  - request timeouts

This is a thin wrapper around `http.Transport` that sets dial timeouts and uses
Go's internal timer scheduler to call the Go 1.1+ `CancelRequest()` API.
*/
package httpclient
import (
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"net/url"
"sync"
"time"
)
// Version reports the current version string of this package.
func Version() string {
	const current = "0.4.1"
	return current
}
// Transport implements the RoundTripper interface and can be used as a replacement
// for Go's built in http.Transport implementing end-to-end request timeouts.
//
// The wrapped http.Transport is built lazily, once, the first time a method
// that needs it is called; the exported fields are read at that point.
//
//	transport := &httpclient.Transport{
//		ConnectTimeout:        1 * time.Second,
//		ResponseHeaderTimeout: 5 * time.Second,
//		RequestTimeout:        10 * time.Second,
//	}
//	defer transport.Close()
//
//	client := &http.Client{Transport: transport}
//	req, _ := http.NewRequest("GET", "http://127.0.0.1/test", nil)
//	resp, err := client.Do(req)
//	if err != nil {
//		return err
//	}
//	defer resp.Body.Close()
type Transport struct {
// Proxy specifies a function to return a proxy for a given
// *http.Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
// If Proxy is nil or returns a nil *url.URL, no proxy is used.
Proxy func(*http.Request) (*url.URL, error)
// Dial specifies the dial function for creating TCP
// connections. This will override the Transport's ConnectTimeout and
// ReadWriteTimeout settings; TCPReadBufferSize and TCPWriteBufferSize
// are likewise only applied by the generated dialer.
// If Dial is nil, a dialer is generated on demand matching the Transport's
// options.
Dial func(network, addr string) (net.Conn, error)
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
// DisableKeepAlives, if true, prevents re-use of TCP connections
// between different HTTP requests.
DisableKeepAlives bool
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression bool
// MaxIdleConnsPerHost, if non-zero, controls the maximum number of idle
// (keep-alive) connections to keep per-host. If zero,
// http.DefaultMaxIdleConnsPerHost is used.
MaxIdleConnsPerHost int
// ConnectTimeout, if non-zero, is the maximum amount of time a dial will wait for
// a connect to complete. It is handed to net.DialTimeout by the generated dialer.
ConnectTimeout time.Duration
// ResponseHeaderTimeout, if non-zero, specifies the amount of
// time to wait for a server's response headers after fully
// writing the request (including its body, if any). This
// time does not include the time to read the response body.
ResponseHeaderTimeout time.Duration
// RequestTimeout, if non-zero, specifies the amount of time for the entire
// request to complete (including all of the above timeouts + entire response body).
// This should never be less than the sum total of the above two timeouts.
// The timer that enforces it is stopped when the response body is closed.
RequestTimeout time.Duration
// ReadWriteTimeout, if non-zero, will set a deadline for every Read and
// Write operation on the request connection.
ReadWriteTimeout time.Duration
// TCPWriteBufferSize, the size of the operating system's write
// buffer associated with the connection. If zero, the buffer size is
// left unchanged.
TCPWriteBufferSize int
// TCPReadBufferSize, the size of the operating system's read
// buffer associated with the connection. If zero, the buffer size is
// left unchanged.
TCPReadBufferSize int
// starter ensures lazyStart runs only once.
starter sync.Once
// transport is the wrapped http.Transport, created by lazyStart.
transport *http.Transport
}
// Close releases resources held by the Transport. It currently does nothing
// and always returns nil.
func (*Transport) Close() error {
	return nil
}
// lazyStart builds the wrapped http.Transport from the exported settings.
// Callers run it through t.starter so it executes at most once.
//
// When no Dial function was supplied, one is generated that:
//   - dials with net.DialTimeout using ConnectTimeout,
//   - applies TCPWriteBufferSize / TCPReadBufferSize when non-zero,
//   - wraps the connection in rwTimeoutConn when ReadWriteTimeout > 0.
//
// All three TCP-specific options require the dialed connection to be a
// *net.TCPConn. If it is not, the connection is closed and an error is
// returned instead of panicking on a failed type assertion. The connection is
// also closed on every other error path so that a failed dial never leaks a
// socket.
func (t *Transport) lazyStart() {
	if t.Dial == nil {
		t.Dial = func(netw, addr string) (net.Conn, error) {
			c, err := net.DialTimeout(netw, addr, t.ConnectTimeout)
			if err != nil {
				return nil, err
			}

			needsTCP := t.TCPReadBufferSize != 0 ||
				t.TCPWriteBufferSize != 0 ||
				t.ReadWriteTimeout > 0
			if !needsTCP {
				return c, nil
			}

			tcpConn, ok := c.(*net.TCPConn)
			if !ok {
				// Best-effort close; the type mismatch is the error the
				// caller needs to see.
				_ = c.Close()
				// Message kept as-is for callers that may match on it.
				return nil, errors.New("Not Tcp Connection")
			}

			if t.TCPWriteBufferSize != 0 {
				if err := tcpConn.SetWriteBuffer(t.TCPWriteBufferSize); err != nil {
					_ = tcpConn.Close() // best effort; report the setup failure
					return nil, err
				}
			}
			if t.TCPReadBufferSize != 0 {
				if err := tcpConn.SetReadBuffer(t.TCPReadBufferSize); err != nil {
					_ = tcpConn.Close() // best effort; report the setup failure
					return nil, err
				}
			}

			if t.ReadWriteTimeout > 0 {
				return &rwTimeoutConn{
					TCPConn:   tcpConn,
					rwTimeout: t.ReadWriteTimeout,
				}, nil
			}
			return c, nil
		}
	}

	t.transport = &http.Transport{
		Dial:                  t.Dial,
		Proxy:                 t.Proxy,
		TLSClientConfig:       t.TLSClientConfig,
		DisableKeepAlives:     t.DisableKeepAlives,
		DisableCompression:    t.DisableCompression,
		MaxIdleConnsPerHost:   t.MaxIdleConnsPerHost,
		ResponseHeaderTimeout: t.ResponseHeaderTimeout,
	}
}
// CancelRequest makes sure the wrapped transport has been initialized and
// then forwards the cancellation of req to it.
func (t *Transport) CancelRequest(req *http.Request) {
	t.starter.Do(t.lazyStart)
	inner := t.transport
	inner.CancelRequest(req)
}
// CloseIdleConnections makes sure the wrapped transport has been initialized
// and then forwards the call to its CloseIdleConnections method.
func (t *Transport) CloseIdleConnections() {
	t.starter.Do(t.lazyStart)
	inner := t.transport
	inner.CloseIdleConnections()
}
// RegisterProtocol makes sure the wrapped transport has been initialized and
// then registers rt as the handler for scheme on it.
func (t *Transport) RegisterProtocol(scheme string, rt http.RoundTripper) {
	t.starter.Do(t.lazyStart)
	inner := t.transport
	inner.RegisterProtocol(scheme, rt)
}
// RoundTrip executes req on the wrapped transport.
//
// With a positive RequestTimeout, a timer is armed before the request is
// sent; when it fires, the request is cancelled on the wrapped transport.
// The timer is stopped right away if the round trip fails. Otherwise the
// response body is wrapped so that closing it stops the timer.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	t.starter.Do(t.lazyStart)

	// No overall timeout configured: plain pass-through.
	if t.RequestTimeout <= 0 {
		return t.transport.RoundTrip(req)
	}

	cancelTimer := time.AfterFunc(t.RequestTimeout, func() {
		t.transport.CancelRequest(req)
	})

	resp, err = t.transport.RoundTrip(req)
	if err != nil {
		cancelTimer.Stop()
		return resp, err
	}

	// The timer keeps running while the caller reads the body.
	resp.Body = &bodyCloseInterceptor{ReadCloser: resp.Body, timer: cancelTimer}
	return resp, err
}
// bodyCloseInterceptor wraps a response body together with the request
// timeout timer so that closing the body also stops the timer.
type bodyCloseInterceptor struct {
// ReadCloser is the wrapped response body; reads go straight through to it.
io.ReadCloser
// timer is the request-timeout timer that Close stops.
timer *time.Timer
}
// Close stops the attached timer and then closes the wrapped body, returning
// the error from that close.
func (bci *bodyCloseInterceptor) Close() error {
	bci.timer.Stop()
	body := bci.ReadCloser
	return body.Close()
}
// rwTimeoutConn is a net.Conn that sets a deadline for every Read or Write
// operation.
type rwTimeoutConn struct {
// TCPConn is the underlying connection; methods not overridden here are
// promoted from it.
*net.TCPConn
// rwTimeout is added to the current time to form the deadline set before
// each Read and Write.
rwTimeout time.Duration
}
// Read pushes the connection deadline to now + rwTimeout and then reads into
// b. If the deadline cannot be set, it returns 0 and that error without
// reading.
func (c *rwTimeoutConn) Read(b []byte) (int, error) {
	deadline := time.Now().Add(c.rwTimeout)
	if err := c.TCPConn.SetDeadline(deadline); err != nil {
		return 0, err
	}
	return c.TCPConn.Read(b)
}
// Write pushes the connection deadline to now + rwTimeout and then writes b.
// If the deadline cannot be set, it returns 0 and that error without writing.
func (c *rwTimeoutConn) Write(b []byte) (int, error) {
	deadline := time.Now().Add(c.rwTimeout)
	if err := c.TCPConn.SetDeadline(deadline); err != nil {
		return 0, err
	}
	return c.TCPConn.Write(b)
}