402 lines
10 KiB
Go
Raw Normal View History

2023-05-19 16:23:55 -04:00
// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package lifecycle
import (
"context"
"errors"
"fmt"
"io"
2023-05-19 16:23:55 -04:00
"reflect"
"strings"
"sync"
"time"
"go.uber.org/fx/fxevent"
"go.uber.org/fx/internal/fxclock"
"go.uber.org/fx/internal/fxreflect"
"go.uber.org/multierr"
)
2023-05-19 16:23:55 -04:00
// Reflection types for each of the supported hook function signatures. These
// are used in cases in which the Callable constraint matches a user-defined
// function type that cannot be converted to an underlying function type with
// a conventional conversion or type switch.
var (
_reflFunc = reflect.TypeOf(Func(nil))
_reflErrorFunc = reflect.TypeOf(ErrorFunc(nil))
_reflContextFunc = reflect.TypeOf(ContextFunc(nil))
_reflContextErrorFunc = reflect.TypeOf(ContextErrorFunc(nil))
)
// Discrete function signatures that are allowed as part of a [Callable].
type (
// A Func can be converted to a ContextErrorFunc.
Func = func()
// An ErrorFunc can be converted to a ContextErrorFunc.
ErrorFunc = func() error
// A ContextFunc can be converted to a ContextErrorFunc.
ContextFunc = func(context.Context)
// A ContextErrorFunc is used as a [Hook.OnStart] or [Hook.OnStop]
// function.
ContextErrorFunc = func(context.Context) error
)
// A Callable is a constraint that matches functions that are, or can be
// converted to, functions suitable for a Hook.
//
// Callable must be identical to [fx.HookFunc].
type Callable interface {
~Func | ~ErrorFunc | ~ContextFunc | ~ContextErrorFunc
}
// Wrap wraps x into a ContextErrorFunc suitable for a Hook.
func Wrap[T Callable](x T) (ContextErrorFunc, string) {
if x == nil {
return nil, ""
}
switch fn := any(x).(type) {
case Func:
return func(context.Context) error {
fn()
return nil
}, fxreflect.FuncName(x)
case ErrorFunc:
return func(context.Context) error {
return fn()
}, fxreflect.FuncName(x)
case ContextFunc:
return func(ctx context.Context) error {
fn(ctx)
return nil
}, fxreflect.FuncName(x)
case ContextErrorFunc:
return fn, fxreflect.FuncName(x)
}
// Since (1) we're already using reflect in Fx, (2) we're not particularly
// concerned with performance, and (3) unsafe would require discrete build
// targets for appengine (etc), just use reflect to convert user-defined
// function types to their underlying function types and then call Wrap
// again with the converted value.
reflVal := reflect.ValueOf(x)
switch {
case reflVal.CanConvert(_reflFunc):
return Wrap(reflVal.Convert(_reflFunc).Interface().(Func))
case reflVal.CanConvert(_reflErrorFunc):
return Wrap(reflVal.Convert(_reflErrorFunc).Interface().(ErrorFunc))
case reflVal.CanConvert(_reflContextFunc):
return Wrap(reflVal.Convert(_reflContextFunc).Interface().(ContextFunc))
default:
// Is already convertible to ContextErrorFunc.
return Wrap(reflVal.Convert(_reflContextErrorFunc).Interface().(ContextErrorFunc))
}
}
// A Hook is a pair of start and stop callbacks, either of which can be nil,
// plus a string identifying the supplier of the hook.
type Hook struct {
2023-05-19 16:23:55 -04:00
OnStart func(context.Context) error
OnStop func(context.Context) error
OnStartName string
OnStopName string
callerFrame fxreflect.Frame
}
2023-05-19 16:23:55 -04:00
type appState int
const (
stopped appState = iota
starting
incompleteStart
started
stopping
)
func (as appState) String() string {
switch as {
case stopped:
return "stopped"
case starting:
return "starting"
case incompleteStart:
return "incompleteStart"
case started:
return "started"
case stopping:
return "stopping"
default:
return "invalidState"
}
}
// Lifecycle coordinates application lifecycle hooks.
type Lifecycle struct {
clock fxclock.Clock
logger fxevent.Logger
2023-05-19 16:23:55 -04:00
state appState
hooks []Hook
numStarted int
startRecords HookRecords
stopRecords HookRecords
runningHook Hook
mu sync.Mutex
}
// New constructs a new Lifecycle.
func New(logger fxevent.Logger, clock fxclock.Clock) *Lifecycle {
return &Lifecycle{logger: logger, clock: clock}
}
// Append adds a Hook to the lifecycle.
func (l *Lifecycle) Append(hook Hook) {
// Save the caller's stack frame to report file/line number.
if f := fxreflect.CallerStack(2, 0); len(f) > 0 {
hook.callerFrame = f[0]
}
l.hooks = append(l.hooks, hook)
}
// Start runs all OnStart hooks, returning immediately if it encounters an
// error.
func (l *Lifecycle) Start(ctx context.Context) error {
if ctx == nil {
return errors.New("called OnStart with nil context")
}
l.mu.Lock()
2023-05-19 16:23:55 -04:00
if l.state != stopped {
defer l.mu.Unlock()
return fmt.Errorf("attempted to start lifecycle when in state: %v", l.state)
}
l.numStarted = 0
l.state = starting
l.startRecords = make(HookRecords, 0, len(l.hooks))
l.mu.Unlock()
2024-06-05 16:10:03 -04:00
returnState := incompleteStart
2023-05-19 16:23:55 -04:00
defer func() {
l.mu.Lock()
l.state = returnState
l.mu.Unlock()
}()
for _, hook := range l.hooks {
// if ctx has cancelled, bail out of the loop.
if err := ctx.Err(); err != nil {
return err
}
if hook.OnStart != nil {
l.mu.Lock()
l.runningHook = hook
l.mu.Unlock()
runtime, err := l.runStartHook(ctx, hook)
if err != nil {
return err
}
l.mu.Lock()
l.startRecords = append(l.startRecords, HookRecord{
CallerFrame: hook.callerFrame,
Func: hook.OnStart,
Runtime: runtime,
})
l.mu.Unlock()
}
l.numStarted++
}
2023-05-19 16:23:55 -04:00
returnState = started
return nil
}
func (l *Lifecycle) runStartHook(ctx context.Context, hook Hook) (runtime time.Duration, err error) {
2023-05-19 16:23:55 -04:00
funcName := hook.OnStartName
if len(funcName) == 0 {
funcName = fxreflect.FuncName(hook.OnStart)
}
l.logger.LogEvent(&fxevent.OnStartExecuting{
CallerName: hook.callerFrame.Function,
FunctionName: funcName,
})
defer func() {
l.logger.LogEvent(&fxevent.OnStartExecuted{
CallerName: hook.callerFrame.Function,
FunctionName: funcName,
Runtime: runtime,
Err: err,
})
}()
begin := l.clock.Now()
err = hook.OnStart(ctx)
return l.clock.Since(begin), err
}
// Stop runs any OnStop hooks whose OnStart counterpart succeeded. OnStop
// hooks run in reverse order.
func (l *Lifecycle) Stop(ctx context.Context) error {
if ctx == nil {
return errors.New("called OnStop with nil context")
}
2023-05-19 16:23:55 -04:00
l.mu.Lock()
if l.state != started && l.state != incompleteStart && l.state != starting {
2023-05-19 16:23:55 -04:00
defer l.mu.Unlock()
return nil
}
l.state = stopping
l.mu.Unlock()
defer func() {
l.mu.Lock()
l.state = stopped
l.mu.Unlock()
}()
l.mu.Lock()
l.stopRecords = make(HookRecords, 0, l.numStarted)
// Take a snapshot of hook state to avoid races.
allHooks := l.hooks[:]
numStarted := l.numStarted
l.mu.Unlock()
// Run backward from last successful OnStart.
var errs []error
for ; numStarted > 0; numStarted-- {
if err := ctx.Err(); err != nil {
return err
}
hook := allHooks[numStarted-1]
if hook.OnStop == nil {
continue
}
l.mu.Lock()
l.runningHook = hook
l.mu.Unlock()
runtime, err := l.runStopHook(ctx, hook)
if err != nil {
// For best-effort cleanup, keep going after errors.
errs = append(errs, err)
}
l.mu.Lock()
l.stopRecords = append(l.stopRecords, HookRecord{
CallerFrame: hook.callerFrame,
Func: hook.OnStop,
Runtime: runtime,
})
l.mu.Unlock()
}
return multierr.Combine(errs...)
}
func (l *Lifecycle) runStopHook(ctx context.Context, hook Hook) (runtime time.Duration, err error) {
2023-05-19 16:23:55 -04:00
funcName := hook.OnStopName
if len(funcName) == 0 {
funcName = fxreflect.FuncName(hook.OnStop)
}
l.logger.LogEvent(&fxevent.OnStopExecuting{
CallerName: hook.callerFrame.Function,
FunctionName: funcName,
})
defer func() {
l.logger.LogEvent(&fxevent.OnStopExecuted{
CallerName: hook.callerFrame.Function,
FunctionName: funcName,
Runtime: runtime,
Err: err,
})
}()
begin := l.clock.Now()
err = hook.OnStop(ctx)
return l.clock.Since(begin), err
}
// RunningHookCaller returns the name of the hook that was running when a Start/Stop
// hook timed out.
func (l *Lifecycle) RunningHookCaller() string {
l.mu.Lock()
defer l.mu.Unlock()
return l.runningHook.callerFrame.Function
}
// HookRecord keeps track of each Hook's execution time, the caller that appended the Hook, and function that ran as the Hook.
type HookRecord struct {
CallerFrame fxreflect.Frame // stack frame of the caller
Func func(context.Context) error // function that ran as sanitized name
Runtime time.Duration // how long the hook ran
}
// HookRecords is a Stringer wrapper of HookRecord slice.
type HookRecords []HookRecord
func (rs HookRecords) Len() int {
return len(rs)
}
func (rs HookRecords) Less(i, j int) bool {
// Sort by runtime, greater ones at top.
return rs[i].Runtime > rs[j].Runtime
}
func (rs HookRecords) Swap(i, j int) {
rs[i], rs[j] = rs[j], rs[i]
}
// Used for logging startup errors.
func (rs HookRecords) String() string {
var b strings.Builder
for _, r := range rs {
fmt.Fprintf(&b, "%s took %v from %s",
fxreflect.FuncName(r.Func), r.Runtime, r.CallerFrame)
}
return b.String()
}
// Format implements fmt.Formatter to handle "%+v".
func (rs HookRecords) Format(w fmt.State, c rune) {
if !w.Flag('+') {
// Without %+v, fall back to String().
io.WriteString(w, rs.String())
return
}
for _, r := range rs {
fmt.Fprintf(w, "\n%s took %v from:\n\t%+v",
fxreflect.FuncName(r.Func),
r.Runtime,
r.CallerFrame)
}
fmt.Fprintf(w, "\n")
}