167 lines
4.4 KiB
Go
Raw Normal View History

2017-02-18 23:00:46 +01:00
package middleware
import (
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v4"
2017-02-18 23:00:46 +01:00
)
type (
// KeyAuthConfig defines the config for KeyAuth middleware.
KeyAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" that is used
// to extract key from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "query:<name>"
2018-02-21 00:48:10 +01:00
// - "form:<name>"
KeyLookup string `yaml:"key_lookup"`
2017-02-18 23:00:46 +01:00
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator
2021-05-30 00:25:30 +02:00
// ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error.
ErrorHandler KeyAuthErrorHandler
2017-02-18 23:00:46 +01:00
}
// KeyAuthValidator defines a function to validate KeyAuth credentials.
2017-06-06 00:01:05 +02:00
KeyAuthValidator func(string, echo.Context) (bool, error)
2017-02-18 23:00:46 +01:00
keyExtractor func(echo.Context) (string, error)
2021-05-30 00:25:30 +02:00
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(error, echo.Context) error
2017-02-18 23:00:46 +01:00
)
// DefaultKeyAuthConfig is the default KeyAuth middleware config. It reads the
// key from the Authorization header, expects the "Bearer" scheme, and uses
// DefaultSkipper.
var DefaultKeyAuthConfig = KeyAuthConfig{
	AuthScheme: "Bearer",
	KeyLookup:  "header:" + echo.HeaderAuthorization,
	Skipper:    DefaultSkipper,
}
// KeyAuth returns a KeyAuth middleware built from a copy of
// DefaultKeyAuthConfig, with fn installed as the key validator.
//
// For valid key it calls the next handler.
// For invalid key, it sends "401 - Unauthorized" response.
// For missing key, it sends "400 - Bad Request" response.
func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
	// Copy the defaults so the package-level config is never mutated.
	config := DefaultKeyAuthConfig
	config.Validator = fn
	return KeyAuthWithConfig(config)
}
// KeyAuthWithConfig returns an KeyAuth middleware with config.
// See `KeyAuth()`.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultKeyAuthConfig.Skipper
}
// Defaults
if config.AuthScheme == "" {
config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
}
if config.KeyLookup == "" {
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
}
if config.Validator == nil {
2017-12-07 23:00:56 +01:00
panic("echo: key-auth middleware requires a validator function")
2017-02-18 23:00:46 +01:00
}
// Initialize
parts := strings.Split(config.KeyLookup, ":")
extractor := keyFromHeader(parts[1], config.AuthScheme)
switch parts[0] {
case "query":
extractor = keyFromQuery(parts[1])
2018-02-21 00:48:10 +01:00
case "form":
extractor = keyFromForm(parts[1])
2017-02-18 23:00:46 +01:00
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
// Extract and verify key
key, err := extractor(c)
if err != nil {
2021-05-30 00:25:30 +02:00
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
2017-02-18 23:00:46 +01:00
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
2017-06-06 00:01:05 +02:00
valid, err := config.Validator(key, c)
if err != nil {
2021-05-30 00:25:30 +02:00
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
2019-06-16 23:33:25 +02:00
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "invalid key",
Internal: err,
}
2017-06-06 00:01:05 +02:00
} else if valid {
2017-02-18 23:00:46 +01:00
return next(c)
}
return echo.ErrUnauthorized
}
}
}
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if auth == "" {
return "", errors.New("missing key in request header")
2017-02-18 23:00:46 +01:00
}
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", errors.New("invalid key in the request header")
2017-02-18 23:00:46 +01:00
}
return auth, nil
}
}
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("missing key in the query string")
2017-02-18 23:00:46 +01:00
}
return key, nil
}
}
2018-02-21 00:48:10 +01:00
// keyFromForm returns a `keyExtractor` that extracts key from the form.
func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", errors.New("missing key in the form")
2018-02-21 00:48:10 +01:00
}
return key, nil
}
}