agent: refactor DNS and HTTP server

* refactor DNS server to be ready for multiple bind addresses
* drop tcpKeepAliveListener since it is default for the HTTP servers
* add startup timeout watcher for HTTP servers identical to DNS server
This commit is contained in:
Frank Schroeder 2017-05-24 15:22:56 +02:00
parent f4aa2ada4f
commit b6c69ebf5d
No known key found for this signature in database
GPG Key ID: 4D65C6EAEC87DECD
7 changed files with 241 additions and 281 deletions

View File

@ -3,6 +3,7 @@ package agent
import (
"context"
"crypto/sha512"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
@ -150,13 +151,13 @@ type Agent struct {
endpointsLock sync.RWMutex
// dnsAddr is the address the DNS server binds to
dnsAddr net.Addr
dnsAddrs []ProtoAddr
// dnsServer provides the DNS API
dnsServers []*DNSServer
// httpAddrs are the addresses per protocol the HTTP server binds to
httpAddrs map[string][]net.Addr
httpAddrs []ProtoAddr
// httpServers provides the HTTP API on various endpoints
httpServers []*HTTPServer
@ -172,7 +173,7 @@ func NewAgent(c *Config) (*Agent, error) {
if c.DataDir == "" && !c.DevMode {
return nil, fmt.Errorf("Must configure a DataDir")
}
dnsAddr, err := c.ClientListener(c.Addresses.DNS, c.Ports.DNS)
dnsAddrs, err := c.DNSAddrs()
if err != nil {
return nil, fmt.Errorf("Invalid DNS bind address: %s", err)
}
@ -199,7 +200,7 @@ func NewAgent(c *Config) (*Agent, error) {
reloadCh: make(chan chan error),
shutdownCh: make(chan struct{}),
endpoints: make(map[string]string),
dnsAddr: dnsAddr,
dnsAddrs: dnsAddrs,
httpAddrs: httpAddrs,
}
if err := a.resolveTmplAddrs(); err != nil {
@ -285,122 +286,169 @@ func (a *Agent) Start() error {
return err
}
// start dns server
if c.Ports.DNS > 0 {
srv, err := NewDNSServer(a, &c.DNSConfig, logOutput, a.logger, c.Domain, a.dnsAddr.String(), c.DNSRecursors)
if err != nil {
return fmt.Errorf("error starting DNS server: %s", err)
// start DNS servers
if err := a.listenAndServeDNS(); err != nil {
return err
}
a.dnsServers = []*DNSServer{srv}
// create listeners and unstarted servers
// see comment on listenHTTP why we are doing this
httpln, err := a.listenHTTP(a.httpAddrs)
if err != nil {
return err
}
// start HTTP servers
return a.startHTTP(a.httpAddrs)
}
func (a *Agent) startHTTP(httpAddrs map[string][]net.Addr) error {
// ln contains the list of pending listeners until the
// actual server is created and the listeners are used.
var ln []net.Listener
// cleanup the listeners on error. ln should be empty on success.
defer func() {
for _, l := range ln {
l.Close()
}
}()
// bind to the listeners for all addresses and protocols
// before we start the servers so that we can fail early
// if we can't bind to one of the addresses.
for proto, addrs := range httpAddrs {
for _, addr := range addrs {
switch addr.(type) {
case *net.UnixAddr:
switch proto {
case "http":
if _, err := os.Stat(addr.String()); !os.IsNotExist(err) {
a.logger.Printf("[WARN] agent: Replacing socket %q", addr.String())
}
l, err := ListenUnix(addr.String(), a.config.UnixSockets)
if err != nil {
for _, l := range httpln {
srv := NewHTTPServer(l.Addr().String(), a)
if err := a.serveHTTP(l, srv); err != nil {
return err
}
ln = append(ln, l)
default:
return fmt.Errorf("invalid protocol: %q", proto)
}
case *net.TCPAddr:
switch proto {
case "http":
l, err := ListenTCP(addr.String())
if err != nil {
return err
}
ln = append(ln, l)
case "https":
tlscfg, err := a.config.IncomingTLSConfig()
if err != nil {
return fmt.Errorf("invalid TLS configuration: %s", err)
}
l, err := ListenTLS(addr.String(), tlscfg)
if err != nil {
return err
}
ln = append(ln, l)
default:
return fmt.Errorf("invalid protocol: %q", proto)
}
default:
return fmt.Errorf("invalid address type: %T", addr)
}
}
}
// https://github.com/golang/go/issues/20239
//
// In go1.8.1 there is a race between Serve and Shutdown. If
// Shutdown is called before the Serve go routine was scheduled then
// the Serve go routine never returns. This deadlocks the agent
// shutdown for some tests since it will wait forever.
//
// We solve this with another WaitGroup which checks that the Serve
// go routine was called and after that it should be safe to call
// Shutdown on that server.
var up sync.WaitGroup
for _, l := range ln {
l := l // capture loop var
// create a server per listener instead of a single
// server with multiple listeners to take advantage
// of the Addr field for logging. Since the server
// does not keep state and they all share the same
// agent there is no overhead.
addr := l.Addr().String()
srv := NewHTTPServer(addr, a)
a.httpServers = append(a.httpServers, srv)
}
return nil
}
up.Add(1)
func (a *Agent) listenAndServeDNS() error {
notif := make(chan ProtoAddr, len(a.dnsAddrs))
for _, p := range a.dnsAddrs {
p := p // capture loop var
// create server
s, err := NewDNSServer(a)
if err != nil {
return err
}
a.dnsServers = append(a.dnsServers, s)
// start server
a.wgServers.Add(1)
go func() {
defer a.wgServers.Done()
up.Done()
a.logger.Printf("[INFO] agent: Starting HTTP server on %s", addr)
if err := srv.Serve(l); err != nil && err != http.ErrServerClosed {
a.logger.Print(err)
err := s.ListenAndServe(p.Net, p.Addr, func() { notif <- p })
if err != nil && !strings.Contains(err.Error(), "accept") {
a.logger.Printf("[ERR] agent: Error starting DNS server %s (%s): ", p.Addr, p.Net, err)
}
}()
}
up.Wait()
ln = nil
// wait for servers to be up
// todo(fs): not sure whether this is the right approach.
// todo(fs): maybe a failing server should trigger an agent shutdown.
timeout := time.After(time.Second)
for range a.dnsAddrs {
select {
case p := <-notif:
a.logger.Printf("[INFO] agent: Started DNS server %s (%s)", p.Addr, p.Net)
continue
case <-timeout:
return fmt.Errorf("agent: timeout starting DNS servers")
}
}
return nil
}
// listenHTTP binds listeners to the provided addresses and also returns
// pre-configured HTTP servers which are not yet started. The motivation is
// that in the current startup/shutdown setup we de-couple the listener
// creation from the server startup assuming that if any of the listeners
// cannot be bound we fail immediately and later failures do not occur.
// Therefore, starting a server with a running listener is assumed to not
// produce an error.
//
// The second motivation is that an HTTPS server needs to use the same TLSConfig
// on both the listener and the HTTP server. When listeners and servers are
// created at different times this becomes difficult to handle without keeping
// the TLS configuration somewhere or recreating it.
//
// This approach should ultimately be refactored to the point where we just
// start the server and any error should trigger a proper shutdown of the agent.
func (a *Agent) listenHTTP(addrs []ProtoAddr) ([]net.Listener, error) {
var ln []net.Listener
for _, p := range addrs {
var l net.Listener
var err error
switch {
case p.Net == "unix":
l, err = a.listenSocket(p.Addr, a.config.UnixSockets)
case p.Net == "tcp" && p.Proto == "http":
l, err = net.Listen("tcp", p.Addr)
case p.Net == "tcp" && p.Proto == "https":
var tlscfg *tls.Config
tlscfg, err = a.config.IncomingTLSConfig()
if err != nil {
break
}
l, err = tls.Listen("tcp", p.Addr, tlscfg)
}
if err != nil {
for _, l := range ln {
l.Close()
}
return nil, err
}
ln = append(ln, l)
}
return ln, nil
}
func (a *Agent) listenSocket(path string, perm FilePermissions) (net.Listener, error) {
if _, err := os.Stat(path); !os.IsNotExist(err) {
a.logger.Printf("[WARN] agent: Replacing socket %q", path)
}
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("error removing socket file: %s", err)
}
l, err := net.Listen("unix", path)
if err != nil {
return nil, err
}
if err := setFilePermissions(path, perm); err != nil {
return nil, fmt.Errorf("Failed setting up HTTP socket: %s", err)
}
return l, nil
}
func (a *Agent) serveHTTP(l net.Listener, srv *HTTPServer) error {
// https://github.com/golang/go/issues/20239
//
// In go.8.1 there is a race between Serve and Shutdown. If
// Shutdown is called before the Serve go routine was scheduled then
// the Serve go routine never returns. This deadlocks the agent
// shutdown for some tests since it will wait forever.
if strings.Contains("*tls.listener", fmt.Sprintf("%T", l)) {
srv.proto = "https"
}
notif := make(chan string)
a.wgServers.Add(1)
go func() {
defer a.wgServers.Done()
notif <- srv.Addr
err := srv.Serve(l)
if err != nil && err != http.ErrServerClosed {
a.logger.Print(err)
}
}()
select {
case addr := <-notif:
if srv.proto == "https" {
a.logger.Printf("[INFO] agent: Started HTTPS server on %s", addr)
} else {
a.logger.Printf("[INFO] agent: Started HTTP server on %s", addr)
}
return nil
case <-time.After(time.Second):
return fmt.Errorf("agent: timeout starting HTTP servers")
}
}
// consulConfig is used to return a consul configuration
func (a *Agent) consulConfig() (*consul.Config, error) {
// Start with the provided config or default config
@ -982,12 +1030,19 @@ func (a *Agent) Shutdown() error {
a.logger.Println("[INFO] agent: Requesting shutdown")
// Stop all API endpoints
a.logger.Println("[INFO] agent: Stopping DNS endpoints")
for _, srv := range a.dnsServers {
a.logger.Printf("[INFO] agent: Stopping DNS server %s (%s)", srv.Server.Addr, srv.Server.Net)
srv.Shutdown()
}
for _, srv := range a.httpServers {
a.logger.Println("[INFO] agent: Stopping HTTP endpoint", srv.Addr)
// http server is HTTPS if TLSConfig is not nil and NextProtos does not only contain "h2"
// the latter seems to be a side effect of HTTP/2 support in go 1.8. TLSConfig != nil is
// no longer sufficient to check for an HTTPS server.
if srv.proto == "https" {
a.logger.Println("[INFO] agent: Stopping HTTPS server", srv.Addr)
} else {
a.logger.Println("[INFO] agent: Stopping HTTP server", srv.Addr)
}
// old behavior: just die
// srv.Close()

View File

@ -1,6 +1,7 @@
package agent
import (
"encoding/json"
"fmt"
"io"
"net"
@ -1034,3 +1035,13 @@ Usage: consul agent [options]
return strings.TrimSpace(helpText)
}
func printJSON(name string, v interface{}) {
fmt.Println(name)
b, err := json.MarshalIndent(v, "", " ")
if err != nil {
fmt.Printf("%#v\n", v)
return
}
fmt.Println(string(b))
}

View File

@ -780,26 +780,49 @@ func (c *Config) IncomingTLSConfig() (*tls.Config, error) {
return tc.IncomingTLSConfig()
}
type ProtoAddr struct {
Proto, Net, Addr string
}
func (p ProtoAddr) String() string {
return p.Proto + "+" + p.Net + "://" + p.Addr
}
func (c *Config) DNSAddrs() ([]ProtoAddr, error) {
if c.Ports.DNS == 0 {
return nil, nil
}
a, err := c.ClientListener(c.Addresses.DNS, c.Ports.DNS)
if err != nil {
return nil, err
}
addrs := []ProtoAddr{
{"dns", "tcp", a.String()},
{"dns", "udp", a.String()},
}
return addrs, nil
}
// HTTPAddrs returns the bind addresses for the HTTP server and
// the application protocol which should be served, e.g. 'http'
// or 'https'.
func (c *Config) HTTPAddrs() (map[string][]net.Addr, error) {
m := map[string][]net.Addr{}
func (c *Config) HTTPAddrs() ([]ProtoAddr, error) {
var addrs []ProtoAddr
if c.Ports.HTTP > 0 {
a, err := c.ClientListener(c.Addresses.HTTP, c.Ports.HTTP)
if err != nil {
return nil, err
}
m["http"] = []net.Addr{a}
addrs = append(addrs, ProtoAddr{"http", a.Network(), a.String()})
}
if c.Ports.HTTPS > 0 {
a, err := c.ClientListener(c.Addresses.HTTPS, c.Ports.HTTPS)
if err != nil {
return nil, err
}
m["https"] = []net.Addr{a}
addrs = append(addrs, ProtoAddr{"https", a.Network(), a.String()})
}
return m, nil
return addrs, nil
}
// Bool is used to initialize bool pointers in struct literals.

View File

@ -3,12 +3,9 @@ package agent
import (
"encoding/hex"
"fmt"
"io"
"log"
"net"
"os"
"strings"
"sync"
"time"
"github.com/armon/go-metrics"
@ -33,127 +30,56 @@ const (
// DNSServer is used to wrap an Agent and expose various
// service discovery endpoints using a DNS interface.
type DNSServer struct {
*dns.Server
agent *Agent
config *DNSConfig
dnsHandler *dns.ServeMux
dnsServer *dns.Server
dnsServerTCP *dns.Server
domain string
recursors []string
logger *log.Logger
}
// Shutdown stops the DNS Servers
func (d *DNSServer) Shutdown() {
if err := d.dnsServer.Shutdown(); err != nil {
d.logger.Printf("[ERR] dns: error stopping udp server: %v", err)
}
if err := d.dnsServerTCP.Shutdown(); err != nil {
d.logger.Printf("[ERR] dns: error stopping tcp server: %v", err)
}
}
// NewDNSServer starts a new DNS server to provide an agent interface
func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, logger *log.Logger, domain string, bind string, recursors []string) (*DNSServer, error) {
if logger == nil {
if logOutput == nil {
logOutput = os.Stderr
}
logger = log.New(logOutput, "", log.LstdFlags)
}
// Make sure domain is FQDN, make it case insensitive for ServeMux
domain = dns.Fqdn(strings.ToLower(domain))
// Construct the DNS components
mux := dns.NewServeMux()
var wg sync.WaitGroup
// Setup the servers
server := &dns.Server{
Addr: bind,
Net: "udp",
Handler: mux,
UDPSize: 65535,
NotifyStartedFunc: wg.Done,
}
serverTCP := &dns.Server{
Addr: bind,
Net: "tcp",
Handler: mux,
NotifyStartedFunc: wg.Done,
}
// Create the server
srv := &DNSServer{
agent: agent,
config: config,
dnsHandler: mux,
dnsServer: server,
dnsServerTCP: serverTCP,
domain: domain,
recursors: recursors,
logger: logger,
}
// Register mux handler, for reverse lookup
mux.HandleFunc("arpa.", srv.handlePtr)
// Register mux handlers
mux.HandleFunc(domain, srv.handleQuery)
if len(recursors) > 0 {
validatedRecursors := make([]string, len(recursors))
for idx, recursor := range recursors {
recursor, err := recursorAddr(recursor)
func NewDNSServer(a *Agent) (*DNSServer, error) {
var recursors []string
for _, r := range a.config.DNSRecursors {
ra, err := recursorAddr(r)
if err != nil {
return nil, fmt.Errorf("Invalid recursor address: %v", err)
}
validatedRecursors[idx] = recursor
recursors = append(recursors, ra)
}
srv.recursors = validatedRecursors
mux.HandleFunc(".", srv.handleRecurse)
// Make sure domain is FQDN, make it case insensitive for ServeMux
domain := dns.Fqdn(strings.ToLower(a.config.Domain))
srv := &DNSServer{
agent: a,
config: &a.config.DNSConfig,
domain: domain,
logger: a.logger,
recursors: recursors,
}
wg.Add(2)
// Async start the DNS Servers, handle a potential error
errCh := make(chan error, 2)
go func() {
if err := server.ListenAndServe(); err != nil && !strings.Contains(err.Error(), "accept") {
srv.logger.Printf("[ERR] dns: error starting udp server: %v", err)
errCh <- fmt.Errorf("dns udp setup failed: %v", err)
}
}()
go func() {
if err := serverTCP.ListenAndServe(); err != nil && !strings.Contains(err.Error(), "accept") {
srv.logger.Printf("[ERR] dns: error starting tcp server: %v", err)
errCh <- fmt.Errorf("dns tcp setup failed: %v", err)
}
}()
// Wait for NotifyStartedFunc callbacks indicating server has started
startCh := make(chan struct{})
go func() {
wg.Wait()
close(startCh)
}()
// Wait for either the check, listen error, or timeout
select {
case <-startCh:
return srv, nil
case e := <-errCh:
server.Shutdown()
serverTCP.Shutdown()
return nil, e
case <-time.After(time.Second):
server.Shutdown()
serverTCP.Shutdown()
return nil, fmt.Errorf("timeout setting up DNS server")
}
func (s *DNSServer) ListenAndServe(network, addr string, notif func()) error {
mux := dns.NewServeMux()
mux.HandleFunc("arpa.", s.handlePtr)
mux.HandleFunc(s.domain, s.handleQuery)
if len(s.recursors) > 0 {
mux.HandleFunc(".", s.handleRecurse)
}
s.Server = &dns.Server{
Addr: addr,
Net: network,
Handler: mux,
NotifyStartedFunc: notif,
}
if network == "udp" {
s.UDPSize = 65535
}
return s.Server.ListenAndServe()
}
// recursorAddr is used to add a port to the recursor if omitted.

View File

@ -408,9 +408,10 @@ func TestDNS_ReverseLookup(t *testing.T) {
func TestDNS_ReverseLookup_CustomDomain(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), nil)
cfg := TestConfig()
cfg.Domain = "custom"
a := NewTestAgent(t.Name(), cfg)
defer a.Shutdown()
a.dns.domain = dns.Fqdn("custom")
// Register node
args := &structs.RegisterRequest{

View File

@ -19,10 +19,11 @@ import (
type HTTPServer struct {
*http.Server
agent *Agent
proto string
}
func NewHTTPServer(addr string, a *Agent) *HTTPServer {
s := &HTTPServer{&http.Server{Addr: addr}, a}
s := &HTTPServer{Server: &http.Server{Addr: addr}, agent: a}
s.Server.Handler = s.handler(s.agent.config.EnableDebug)
return s
}

View File

@ -1,57 +0,0 @@
package agent
import (
"crypto/tls"
"fmt"
"net"
"os"
"time"
)
func ListenTCP(addr string) (net.Listener, error) {
l, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
l = tcpKeepAliveListener{l.(*net.TCPListener)}
return l, nil
}
func ListenTLS(addr string, cfg *tls.Config) (net.Listener, error) {
l, err := ListenTCP(addr)
if err != nil {
return nil, err
}
return tls.NewListener(l, cfg), nil
}
func ListenUnix(addr string, perm FilePermissions) (net.Listener, error) {
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("error removing socket file: %s", err)
}
l, err := net.Listen("unix", addr)
if err != nil {
return nil, err
}
if err := setFilePermissions(addr, perm); err != nil {
return nil, fmt.Errorf("Failed setting up HTTP socket: %s", err)
}
return l, nil
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by NewHttpServer so
// dead TCP connections eventually go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(30 * time.Second)
return tc, nil
}