253 lines
6.0 KiB
Go
Raw Normal View History

2017-06-06 00:01:05 +02:00
package middleware
import (
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
2018-02-21 00:48:10 +01:00
"regexp"
"strings"
"sync"
2017-06-06 00:01:05 +02:00
"sync/atomic"
"time"
"github.com/labstack/echo"
)
// TODO: Handle TLS proxy
type (
// ProxyConfig defines the config for Proxy middleware.
ProxyConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Balancer defines a load balancing technique.
// Required.
Balancer ProxyBalancer
2018-02-21 00:48:10 +01:00
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on.
// Examples:
// "/old": "/new",
// "/api/*": "/$1",
// "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2",
Rewrite map[string]string
rewriteRegex map[*regexp.Regexp]string
2017-06-06 00:01:05 +02:00
}
// ProxyTarget defines the upstream target.
ProxyTarget struct {
2018-02-21 00:48:10 +01:00
Name string
URL *url.URL
2017-06-06 00:01:05 +02:00
}
2018-02-21 00:48:10 +01:00
// ProxyBalancer defines an interface to implement a load balancing technique.
ProxyBalancer interface {
AddTarget(*ProxyTarget) bool
RemoveTarget(string) bool
Next() *ProxyTarget
2017-06-06 00:01:05 +02:00
}
2018-02-21 00:48:10 +01:00
commonBalancer struct {
targets []*ProxyTarget
mutex sync.RWMutex
2017-06-06 00:01:05 +02:00
}
2018-02-21 00:48:10 +01:00
// RandomBalancer implements a random load balancing technique.
randomBalancer struct {
*commonBalancer
random *rand.Rand
}
// RoundRobinBalancer implements a round-robin load balancing technique.
roundRobinBalancer struct {
*commonBalancer
i uint32
2017-06-06 00:01:05 +02:00
}
)
2017-12-07 23:00:56 +01:00
// DefaultProxyConfig is the default Proxy middleware config.
// It only sets Skipper; a Balancer must still be supplied by the caller.
var DefaultProxyConfig = ProxyConfig{Skipper: DefaultSkipper}
2017-06-06 00:01:05 +02:00
// proxyHTTP returns a reverse proxy handler that forwards plain HTTP
// requests to the single upstream described by t.
func proxyHTTP(t *ProxyTarget) http.Handler {
	upstream := t.URL
	reverse := httputil.NewSingleHostReverseProxy(upstream)
	return reverse
}
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2017-12-07 23:00:56 +01:00
in, _, err := c.Response().Hijack()
2017-06-06 00:01:05 +02:00
if err != nil {
2017-12-07 23:00:56 +01:00
c.Error(fmt.Errorf("proxy raw, hijack error=%v, url=%s", t.URL, err))
2017-06-06 00:01:05 +02:00
return
}
defer in.Close()
out, err := net.Dial("tcp", t.URL.Host)
if err != nil {
2017-12-07 23:00:56 +01:00
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err))
2017-06-06 00:01:05 +02:00
c.Error(he)
return
}
defer out.Close()
2017-12-07 23:00:56 +01:00
// Write header
2017-06-06 00:01:05 +02:00
err = r.Write(out)
if err != nil {
2017-12-07 23:00:56 +01:00
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err))
2017-06-06 00:01:05 +02:00
c.Error(he)
return
}
errc := make(chan error, 2)
cp := func(dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src)
errc <- err
}
go cp(out, in)
go cp(in, out)
err = <-errc
if err != nil && err != io.EOF {
2017-12-07 23:00:56 +01:00
c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)
2017-06-06 00:01:05 +02:00
}
})
}
2018-02-21 00:48:10 +01:00
// NewRandomBalancer returns a random proxy balancer over targets.
// The given slice is stored directly, not copied.
func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {
	return &randomBalancer{
		commonBalancer: &commonBalancer{targets: targets},
	}
}
// NewRoundRobinBalancer returns a round-robin proxy balancer over targets.
// The given slice is stored directly, not copied.
func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer {
	return &roundRobinBalancer{
		commonBalancer: &commonBalancer{targets: targets},
	}
}
// AddTarget adds an upstream target to the list. It reports false, leaving
// the list unchanged, when a target with the same Name is already present.
//
// The duplicate scan and the append run under one write lock. Two concurrent
// callers therefore cannot both add the same name, and the read of b.targets
// does not race with other writers.
func (b *commonBalancer) AddTarget(target *ProxyTarget) bool {
	b.mutex.Lock()
	defer b.mutex.Unlock()
	for _, t := range b.targets {
		if t.Name == target.Name {
			return false
		}
	}
	b.targets = append(b.targets, target)
	return true
}
// RemoveTarget removes the first upstream target whose Name equals name and
// reports whether one was removed. The removal shifts later targets down in
// place within the existing backing array.
func (b *commonBalancer) RemoveTarget(name string) bool {
	b.mutex.Lock()
	defer b.mutex.Unlock()

	idx := -1
	for i := range b.targets {
		if b.targets[i].Name == name {
			idx = i
			break
		}
	}
	if idx < 0 {
		return false
	}
	b.targets = append(b.targets[:idx], b.targets[idx+1:]...)
	return true
}
2017-06-06 00:01:05 +02:00
// Next randomly returns an upstream target.
2018-02-21 00:48:10 +01:00
func (b *randomBalancer) Next() *ProxyTarget {
if b.random == nil {
b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
2017-06-06 00:01:05 +02:00
}
2018-02-21 00:48:10 +01:00
b.mutex.RLock()
defer b.mutex.RUnlock()
return b.targets[b.random.Intn(len(b.targets))]
2017-06-06 00:01:05 +02:00
}
// Next returns an upstream target using round-robin technique.
2018-02-21 00:48:10 +01:00
func (b *roundRobinBalancer) Next() *ProxyTarget {
b.i = b.i % uint32(len(b.targets))
t := b.targets[b.i]
atomic.AddUint32(&b.i, 1)
2017-06-06 00:01:05 +02:00
return t
}
2017-12-07 23:00:56 +01:00
// Proxy returns a Proxy middleware.
//
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
// It is equivalent to ProxyWithConfig called with DefaultProxyConfig plus the given balancer.
func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
	config := DefaultProxyConfig // struct copy; the package default is left untouched
	config.Balancer = balancer
	return ProxyWithConfig(config)
}
// ProxyWithConfig returns a Proxy middleware with config.
// See: `Proxy()`
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
2017-06-06 00:01:05 +02:00
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultLoggerConfig.Skipper
}
if config.Balancer == nil {
panic("echo: proxy middleware requires balancer")
}
2018-02-21 00:48:10 +01:00
config.rewriteRegex = map[*regexp.Regexp]string{}
// Initialize
for k, v := range config.Rewrite {
k = strings.Replace(k, "*", "(\\S*)", -1)
config.rewriteRegex[regexp.MustCompile(k)] = v
}
2017-06-06 00:01:05 +02:00
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
2017-12-07 23:00:56 +01:00
if config.Skipper(c) {
return next(c)
}
2017-06-06 00:01:05 +02:00
req := c.Request()
res := c.Response()
tgt := config.Balancer.Next()
2018-02-21 00:48:10 +01:00
// Rewrite
for k, v := range config.rewriteRegex {
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
req.URL.Path = replacer.Replace(v)
}
}
2017-06-06 00:01:05 +02:00
// Fix header
if req.Header.Get(echo.HeaderXRealIP) == "" {
req.Header.Set(echo.HeaderXRealIP, c.RealIP())
}
if req.Header.Get(echo.HeaderXForwardedProto) == "" {
req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
}
if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
}
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default:
proxyHTTP(tgt).ServeHTTP(res, req)
}
return
}
}
}