402 lines
10 KiB
Go
402 lines
10 KiB
Go
// Copyright (c) 2022 Uber Technologies, Inc.
|
|
//
|
|
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
// of this software and associated documentation files (the "Software"), to deal
|
|
// in the Software without restriction, including without limitation the rights
|
|
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
// copies of the Software, and to permit persons to whom the Software is
|
|
// furnished to do so, subject to the following conditions:
|
|
//
|
|
// The above copyright notice and this permission notice shall be included in
|
|
// all copies or substantial portions of the Software.
|
|
//
|
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
// THE SOFTWARE.
|
|
|
|
package lifecycle
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"go.uber.org/fx/fxevent"
|
|
"go.uber.org/fx/internal/fxclock"
|
|
"go.uber.org/fx/internal/fxreflect"
|
|
"go.uber.org/multierr"
|
|
)
|
|
|
|
// Reflection types for each of the supported hook function signatures. These
|
|
// are used in cases in which the Callable constraint matches a user-defined
|
|
// function type that cannot be converted to an underlying function type with
|
|
// a conventional conversion or type switch.
|
|
var (
|
|
_reflFunc = reflect.TypeOf(Func(nil))
|
|
_reflErrorFunc = reflect.TypeOf(ErrorFunc(nil))
|
|
_reflContextFunc = reflect.TypeOf(ContextFunc(nil))
|
|
_reflContextErrorFunc = reflect.TypeOf(ContextErrorFunc(nil))
|
|
)
|
|
|
|
// Discrete function signatures that are allowed as part of a [Callable].
|
|
type (
|
|
// A Func can be converted to a ContextErrorFunc.
|
|
Func = func()
|
|
// An ErrorFunc can be converted to a ContextErrorFunc.
|
|
ErrorFunc = func() error
|
|
// A ContextFunc can be converted to a ContextErrorFunc.
|
|
ContextFunc = func(context.Context)
|
|
// A ContextErrorFunc is used as a [Hook.OnStart] or [Hook.OnStop]
|
|
// function.
|
|
ContextErrorFunc = func(context.Context) error
|
|
)
|
|
|
|
// A Callable is a constraint that matches functions that are, or can be
|
|
// converted to, functions suitable for a Hook.
|
|
//
|
|
// Callable must be identical to [fx.HookFunc].
|
|
type Callable interface {
|
|
~Func | ~ErrorFunc | ~ContextFunc | ~ContextErrorFunc
|
|
}
|
|
|
|
// Wrap wraps x into a ContextErrorFunc suitable for a Hook.
|
|
func Wrap[T Callable](x T) (ContextErrorFunc, string) {
|
|
if x == nil {
|
|
return nil, ""
|
|
}
|
|
|
|
switch fn := any(x).(type) {
|
|
case Func:
|
|
return func(context.Context) error {
|
|
fn()
|
|
return nil
|
|
}, fxreflect.FuncName(x)
|
|
case ErrorFunc:
|
|
return func(context.Context) error {
|
|
return fn()
|
|
}, fxreflect.FuncName(x)
|
|
case ContextFunc:
|
|
return func(ctx context.Context) error {
|
|
fn(ctx)
|
|
return nil
|
|
}, fxreflect.FuncName(x)
|
|
case ContextErrorFunc:
|
|
return fn, fxreflect.FuncName(x)
|
|
}
|
|
|
|
// Since (1) we're already using reflect in Fx, (2) we're not particularly
|
|
// concerned with performance, and (3) unsafe would require discrete build
|
|
// targets for appengine (etc), just use reflect to convert user-defined
|
|
// function types to their underlying function types and then call Wrap
|
|
// again with the converted value.
|
|
reflVal := reflect.ValueOf(x)
|
|
switch {
|
|
case reflVal.CanConvert(_reflFunc):
|
|
return Wrap(reflVal.Convert(_reflFunc).Interface().(Func))
|
|
case reflVal.CanConvert(_reflErrorFunc):
|
|
return Wrap(reflVal.Convert(_reflErrorFunc).Interface().(ErrorFunc))
|
|
case reflVal.CanConvert(_reflContextFunc):
|
|
return Wrap(reflVal.Convert(_reflContextFunc).Interface().(ContextFunc))
|
|
default:
|
|
// Is already convertible to ContextErrorFunc.
|
|
return Wrap(reflVal.Convert(_reflContextErrorFunc).Interface().(ContextErrorFunc))
|
|
}
|
|
}
|
|
|
|
// A Hook is a pair of start and stop callbacks, either of which can be nil,
|
|
// plus a string identifying the supplier of the hook.
|
|
type Hook struct {
|
|
OnStart func(context.Context) error
|
|
OnStop func(context.Context) error
|
|
OnStartName string
|
|
OnStopName string
|
|
|
|
callerFrame fxreflect.Frame
|
|
}
|
|
|
|
type appState int
|
|
|
|
const (
|
|
stopped appState = iota
|
|
starting
|
|
incompleteStart
|
|
started
|
|
stopping
|
|
)
|
|
|
|
func (as appState) String() string {
|
|
switch as {
|
|
case stopped:
|
|
return "stopped"
|
|
case starting:
|
|
return "starting"
|
|
case incompleteStart:
|
|
return "incompleteStart"
|
|
case started:
|
|
return "started"
|
|
case stopping:
|
|
return "stopping"
|
|
default:
|
|
return "invalidState"
|
|
}
|
|
}
|
|
|
|
// Lifecycle coordinates application lifecycle hooks.
|
|
type Lifecycle struct {
|
|
clock fxclock.Clock
|
|
logger fxevent.Logger
|
|
state appState
|
|
hooks []Hook
|
|
numStarted int
|
|
startRecords HookRecords
|
|
stopRecords HookRecords
|
|
runningHook Hook
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// New constructs a new Lifecycle.
|
|
func New(logger fxevent.Logger, clock fxclock.Clock) *Lifecycle {
|
|
return &Lifecycle{logger: logger, clock: clock}
|
|
}
|
|
|
|
// Append adds a Hook to the lifecycle.
|
|
func (l *Lifecycle) Append(hook Hook) {
|
|
// Save the caller's stack frame to report file/line number.
|
|
if f := fxreflect.CallerStack(2, 0); len(f) > 0 {
|
|
hook.callerFrame = f[0]
|
|
}
|
|
l.hooks = append(l.hooks, hook)
|
|
}
|
|
|
|
// Start runs all OnStart hooks, returning immediately if it encounters an
|
|
// error.
|
|
func (l *Lifecycle) Start(ctx context.Context) error {
|
|
if ctx == nil {
|
|
return errors.New("called OnStart with nil context")
|
|
}
|
|
|
|
l.mu.Lock()
|
|
if l.state != stopped {
|
|
defer l.mu.Unlock()
|
|
return fmt.Errorf("attempted to start lifecycle when in state: %v", l.state)
|
|
}
|
|
l.numStarted = 0
|
|
l.state = starting
|
|
|
|
l.startRecords = make(HookRecords, 0, len(l.hooks))
|
|
l.mu.Unlock()
|
|
|
|
var returnState appState = incompleteStart
|
|
defer func() {
|
|
l.mu.Lock()
|
|
l.state = returnState
|
|
l.mu.Unlock()
|
|
}()
|
|
|
|
for _, hook := range l.hooks {
|
|
// if ctx has cancelled, bail out of the loop.
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if hook.OnStart != nil {
|
|
l.mu.Lock()
|
|
l.runningHook = hook
|
|
l.mu.Unlock()
|
|
|
|
runtime, err := l.runStartHook(ctx, hook)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
l.mu.Lock()
|
|
l.startRecords = append(l.startRecords, HookRecord{
|
|
CallerFrame: hook.callerFrame,
|
|
Func: hook.OnStart,
|
|
Runtime: runtime,
|
|
})
|
|
l.mu.Unlock()
|
|
}
|
|
l.numStarted++
|
|
}
|
|
|
|
returnState = started
|
|
return nil
|
|
}
|
|
|
|
func (l *Lifecycle) runStartHook(ctx context.Context, hook Hook) (runtime time.Duration, err error) {
|
|
funcName := hook.OnStartName
|
|
if len(funcName) == 0 {
|
|
funcName = fxreflect.FuncName(hook.OnStart)
|
|
}
|
|
|
|
l.logger.LogEvent(&fxevent.OnStartExecuting{
|
|
CallerName: hook.callerFrame.Function,
|
|
FunctionName: funcName,
|
|
})
|
|
defer func() {
|
|
l.logger.LogEvent(&fxevent.OnStartExecuted{
|
|
CallerName: hook.callerFrame.Function,
|
|
FunctionName: funcName,
|
|
Runtime: runtime,
|
|
Err: err,
|
|
})
|
|
}()
|
|
|
|
begin := l.clock.Now()
|
|
err = hook.OnStart(ctx)
|
|
return l.clock.Since(begin), err
|
|
}
|
|
|
|
// Stop runs any OnStop hooks whose OnStart counterpart succeeded. OnStop
|
|
// hooks run in reverse order.
|
|
func (l *Lifecycle) Stop(ctx context.Context) error {
|
|
if ctx == nil {
|
|
return errors.New("called OnStop with nil context")
|
|
}
|
|
|
|
l.mu.Lock()
|
|
if l.state != started && l.state != incompleteStart && l.state != starting {
|
|
defer l.mu.Unlock()
|
|
return nil
|
|
}
|
|
l.state = stopping
|
|
l.mu.Unlock()
|
|
|
|
defer func() {
|
|
l.mu.Lock()
|
|
l.state = stopped
|
|
l.mu.Unlock()
|
|
}()
|
|
|
|
l.mu.Lock()
|
|
l.stopRecords = make(HookRecords, 0, l.numStarted)
|
|
// Take a snapshot of hook state to avoid races.
|
|
allHooks := l.hooks[:]
|
|
numStarted := l.numStarted
|
|
l.mu.Unlock()
|
|
|
|
// Run backward from last successful OnStart.
|
|
var errs []error
|
|
for ; numStarted > 0; numStarted-- {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
hook := allHooks[numStarted-1]
|
|
if hook.OnStop == nil {
|
|
continue
|
|
}
|
|
|
|
l.mu.Lock()
|
|
l.runningHook = hook
|
|
l.mu.Unlock()
|
|
|
|
runtime, err := l.runStopHook(ctx, hook)
|
|
if err != nil {
|
|
// For best-effort cleanup, keep going after errors.
|
|
errs = append(errs, err)
|
|
}
|
|
|
|
l.mu.Lock()
|
|
l.stopRecords = append(l.stopRecords, HookRecord{
|
|
CallerFrame: hook.callerFrame,
|
|
Func: hook.OnStop,
|
|
Runtime: runtime,
|
|
})
|
|
l.mu.Unlock()
|
|
}
|
|
|
|
return multierr.Combine(errs...)
|
|
}
|
|
|
|
func (l *Lifecycle) runStopHook(ctx context.Context, hook Hook) (runtime time.Duration, err error) {
|
|
funcName := hook.OnStopName
|
|
if len(funcName) == 0 {
|
|
funcName = fxreflect.FuncName(hook.OnStop)
|
|
}
|
|
|
|
l.logger.LogEvent(&fxevent.OnStopExecuting{
|
|
CallerName: hook.callerFrame.Function,
|
|
FunctionName: funcName,
|
|
})
|
|
defer func() {
|
|
l.logger.LogEvent(&fxevent.OnStopExecuted{
|
|
CallerName: hook.callerFrame.Function,
|
|
FunctionName: funcName,
|
|
Runtime: runtime,
|
|
Err: err,
|
|
})
|
|
}()
|
|
|
|
begin := l.clock.Now()
|
|
err = hook.OnStop(ctx)
|
|
return l.clock.Since(begin), err
|
|
}
|
|
|
|
// RunningHookCaller returns the name of the hook that was running when a Start/Stop
|
|
// hook timed out.
|
|
func (l *Lifecycle) RunningHookCaller() string {
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
return l.runningHook.callerFrame.Function
|
|
}
|
|
|
|
// HookRecord keeps track of each Hook's execution time, the caller that appended the Hook, and function that ran as the Hook.
|
|
type HookRecord struct {
|
|
CallerFrame fxreflect.Frame // stack frame of the caller
|
|
Func func(context.Context) error // function that ran as sanitized name
|
|
Runtime time.Duration // how long the hook ran
|
|
}
|
|
|
|
// HookRecords is a Stringer wrapper of HookRecord slice.
|
|
type HookRecords []HookRecord
|
|
|
|
func (rs HookRecords) Len() int {
|
|
return len(rs)
|
|
}
|
|
|
|
func (rs HookRecords) Less(i, j int) bool {
|
|
// Sort by runtime, greater ones at top.
|
|
return rs[i].Runtime > rs[j].Runtime
|
|
}
|
|
|
|
func (rs HookRecords) Swap(i, j int) {
|
|
rs[i], rs[j] = rs[j], rs[i]
|
|
}
|
|
|
|
// Used for logging startup errors.
|
|
func (rs HookRecords) String() string {
|
|
var b strings.Builder
|
|
for _, r := range rs {
|
|
fmt.Fprintf(&b, "%s took %v from %s",
|
|
fxreflect.FuncName(r.Func), r.Runtime, r.CallerFrame)
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
// Format implements fmt.Formatter to handle "%+v".
|
|
func (rs HookRecords) Format(w fmt.State, c rune) {
|
|
if !w.Flag('+') {
|
|
// Without %+v, fall back to String().
|
|
io.WriteString(w, rs.String())
|
|
return
|
|
}
|
|
|
|
for _, r := range rs {
|
|
fmt.Fprintf(w, "\n%s took %v from:\n\t%+v",
|
|
fxreflect.FuncName(r.Func),
|
|
r.Runtime,
|
|
r.CallerFrame)
|
|
}
|
|
fmt.Fprintf(w, "\n")
|
|
}
|