From af083cc5ba2d38a558d3f5422f43fb0c49ccfc54 Mon Sep 17 00:00:00 2001 From: Alessandro De Blasis Date: Thu, 9 Jun 2022 15:48:34 +0100 Subject: [PATCH] tests: added syscall mocking and tests for Check_OSService --- agent/checks/check_windows_test.go | 430 +++++++++++++++++++++++++++++ agent/checks/os_service_windows.go | 82 +++++- 2 files changed, 501 insertions(+), 11 deletions(-) create mode 100644 agent/checks/check_windows_test.go diff --git a/agent/checks/check_windows_test.go b/agent/checks/check_windows_test.go new file mode 100644 index 0000000000..ad63ada6d5 --- /dev/null +++ b/agent/checks/check_windows_test.go @@ -0,0 +1,430 @@ +//go:build windows +// +build windows + +package checks + +import ( + "errors" + "testing" + "time" + + "github.com/hashicorp/consul/agent/mock" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/sdk/testutil" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" +) + +func TestOSServiceClient(t *testing.T) { + type args struct { + returnsOpenSCManagerError error + returnsOpenServiceError error + returnsServiceQueryError error + returnsServiceCloseError error + returnsSCMgrDisconnectError error + returnsServiceState svc.State + } + tests := []struct { + name string + args args + maybeHealthy *bool + }{ + // healthy + {"should pass for healthy service", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, boolPointer(true)}, + {"should pass for healthy service even when there's an error closing the service handle", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: errors.New("error while closing the service handle"), + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, boolPointer(true)}, + {"should pass for healthy service even when there's an error disconnecting from SCManager", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"), + returnsServiceState: svc.Running, + }, boolPointer(true)}, + + // warning + {"should be in warning state for any state that's not Running, Paused or Stopped", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.StartPending, + }, nil}, + {"should be in warning state when we cannot connect to the service manager", args{ + returnsOpenSCManagerError: errors.New("cannot connect to service manager"), + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, nil}, + {"should be in warning state when we cannot open the service", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: errors.New("service testService does not exist"), + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, nil}, + {"should be in warning state when we cannot query the service state", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: errors.New("cannot query testService state"), + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, nil}, + {"should be in warning state for for any state that's not Running, Paused or Stopped when there's an error closing the service handle", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: errors.New("error while closing the service handle"), + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.StartPending, + }, nil}, + {"should be in warning state for for any state that's not Running, Paused or Stopped when there's an error disconnecting from SCManager", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"), + returnsServiceState: svc.StartPending, + }, nil}, + + // critical + {"should fail for paused service", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Paused, + }, boolPointer(false)}, + {"should fail for stopped service", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Stopped, + }, boolPointer(false)}, + {"should fail for stopped service even when there's an error closing the service handle", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: errors.New("error while closing the service handle"), + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Stopped, + }, boolPointer(false)}, + {"should fail for stopped service even when there's an error disconnecting from SCManager", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"), + returnsServiceState: svc.Stopped, + }, boolPointer(false)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + old := win + defer func() { win = old }() + win = fakeWindowsOS{ + returnsOpenSCManagerError: tt.args.returnsOpenSCManagerError, + returnsOpenServiceError: tt.args.returnsOpenServiceError, + returnsServiceQueryError: tt.args.returnsServiceQueryError, + returnsServiceCloseError: tt.args.returnsServiceCloseError, + returnsSCMgrDisconnectError: tt.args.returnsSCMgrDisconnectError, + returnsServiceState: tt.args.returnsServiceState, + } + probe, err := NewOSServiceClient() + if (tt.args.returnsOpenSCManagerError != nil && err == nil) || (tt.args.returnsOpenSCManagerError == nil && err != nil) { + t.Errorf("FAIL: %s. Expected error on OpenSCManager %v , but err == %v", tt.name, tt.args.returnsOpenSCManagerError, err) + } + if err != nil { + return + } + actualError := probe.Check("testService") + actuallyHealthy := actualError == nil + actualErrorIsCritical := errors.Is(actualError, ErrOSServiceStatusCritical) + actualWarning := !actuallyHealthy && !actualErrorIsCritical + expectedHealthy := tt.maybeHealthy != nil && *tt.maybeHealthy + expectedWarning := tt.maybeHealthy == nil + if expectedHealthy && !actuallyHealthy { + t.Errorf("FAIL: %s. Expected healthy %t, but err == %v", tt.name, boolVal(tt.maybeHealthy), actualError) + } + + if expectedWarning && !actualWarning { + t.Errorf("FAIL: %s. Expected non critical error, but err == %v", tt.name, actualError) + } + }) + } +} + +func TestCheck_OSService(t *testing.T) { + type args struct { + returnsOpenSCManagerError error + returnsOpenServiceError error + returnsServiceQueryError error + returnsServiceCloseError error + returnsSCMgrDisconnectError error + returnsServiceState svc.State + } + tests := []struct { + desc string + args args + state string + }{ + //healthy + {"should pass for healthy service", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, api.HealthPassing}, + {"should pass for healthy service even when there's an error closing the service handle", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: errors.New("error while closing the service handle"), + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, api.HealthPassing}, + {"should pass for healthy service even when there's an error disconnecting from SCManager", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"), + returnsServiceState: svc.Running, + }, api.HealthPassing}, + + // // warning + {"should be in warning state for any state that's not Running, Paused or Stopped", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.StartPending, + }, api.HealthWarning}, + {"should be in warning state when we cannot connect to the service manager", args{ + returnsOpenSCManagerError: errors.New("cannot connect to service manager"), + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, api.HealthWarning}, + {"should be in warning state when we cannot open the service", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: errors.New("service testService does not exist"), + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, api.HealthWarning}, + {"should be in warning state when we cannot query the service state", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: errors.New("cannot query testService state"), + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Running, + }, api.HealthWarning}, + {"should be in warning state for for any state that's not Running, Paused or Stopped when there's an error closing the service handle", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: errors.New("error while closing the service handle"), + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.StartPending, + }, api.HealthWarning}, + {"should be in warning state for for any state that's not Running, Paused or Stopped when there's an error disconnecting from SCManager", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"), + returnsServiceState: svc.StartPending, + }, api.HealthWarning}, + + // critical + {"should fail for paused service", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Paused, + }, api.HealthCritical}, + {"should fail for stopped service", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Stopped, + }, api.HealthCritical}, + {"should fail for stopped service even when there's an error closing the service handle", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: errors.New("error while closing the service handle"), + returnsSCMgrDisconnectError: nil, + returnsServiceState: svc.Stopped, + }, api.HealthCritical}, + {"should fail for stopped service even when there's an error disconnecting from SCManager", args{ + returnsOpenSCManagerError: nil, + returnsOpenServiceError: nil, + returnsServiceQueryError: nil, + returnsServiceCloseError: nil, + returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"), + returnsServiceState: svc.Stopped, + }, api.HealthCritical}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + old := win + defer func() { win = old }() + win = fakeWindowsOS{ + returnsOpenSCManagerError: tt.args.returnsOpenSCManagerError, + returnsOpenServiceError: tt.args.returnsOpenServiceError, + returnsServiceQueryError: tt.args.returnsServiceQueryError, + returnsServiceCloseError: tt.args.returnsServiceCloseError, + returnsSCMgrDisconnectError: tt.args.returnsSCMgrDisconnectError, + returnsServiceState: tt.args.returnsServiceState, + } + c, err := NewOSServiceClient() + if (tt.args.returnsOpenSCManagerError != nil && err == nil) || (tt.args.returnsOpenSCManagerError == nil && err != nil) { + t.Errorf("FAIL: %s. Expected error on OpenSCManager %v , but err == %v", tt.desc, tt.args.returnsOpenSCManagerError, err) + } + if err != nil { + return + } + + notif, upd := mock.NewNotifyChan() + logger := testutil.Logger(t) + statusHandler := NewStatusHandler(notif, logger, 0, 0, 0) + id := structs.NewCheckID("chk", nil) + + check := &CheckOSService{ + CheckID: id, + OSService: "testService", + Interval: 25 * time.Millisecond, + Client: c, + Logger: logger, + StatusHandler: statusHandler, + } + check.Start() + defer check.Stop() + + <-upd // wait for update + + if got, want := notif.State(id), tt.state; got != want { + t.Fatalf("got status %q want %q", got, want) + } + }) + } +} + +const ( + validSCManagerHandle = windows.Handle(1) + validOpenServiceHandle = windows.Handle(2) +) + +type fakeWindowsOS struct { + returnsOpenSCManagerError error + returnsOpenServiceError error + returnsServiceQueryError error + returnsServiceCloseError error + returnsSCMgrDisconnectError error + returnsServiceState svc.State +} + +func (f fakeWindowsOS) OpenSCManager(machineName *uint16, databaseName *uint16, access uint32) (handle windows.Handle, err error) { + if f.returnsOpenSCManagerError != nil { + return windows.InvalidHandle, f.returnsOpenSCManagerError + } + return validSCManagerHandle, nil +} +func (f fakeWindowsOS) OpenService(mgr windows.Handle, serviceName *uint16, access uint32) (handle windows.Handle, err error) { + if f.returnsOpenServiceError != nil { + return windows.InvalidHandle, f.returnsOpenServiceError + } + return validOpenServiceHandle, nil +} + +func (f fakeWindowsOS) getWindowsSvcMgr(h windows.Handle) windowsSvcMgr { + return &fakeWindowsSvcMgr{ + Handle: h, + returnsDisconnectError: f.returnsSCMgrDisconnectError, + } +} +func (fakeWindowsOS) getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle { + return sm.(*fakeWindowsSvcMgr).Handle +} + +func (f fakeWindowsOS) getWindowsSvc(name string, h windows.Handle) windowsSvc { + return &fakeWindowsSvc{ + Name: name, + Handle: h, + returnsCloseError: f.returnsServiceCloseError, + returnsServiceQueryError: f.returnsServiceQueryError, + returnsServiceState: f.returnsServiceState, + } +} + +type fakeWindowsSvcMgr struct { + Handle windows.Handle + + returnsDisconnectError error +} + +func (f fakeWindowsSvcMgr) Disconnect() error { return f.returnsDisconnectError } + +type fakeWindowsSvc struct { + Handle windows.Handle + Name string + + returnsServiceQueryError error + returnsCloseError error + returnsServiceState svc.State +} + +func (f fakeWindowsSvc) Close() error { return f.returnsCloseError } +func (f fakeWindowsSvc) Query() (svc.Status, error) { + if f.returnsServiceQueryError != nil { + return svc.Status{}, f.returnsServiceQueryError + } + return svc.Status{State: f.returnsServiceState}, nil +} + +func boolPointer(b bool) *bool { + return &b +} + +func boolVal(v *bool) bool { + if v == nil { + return false + } + return *v +} diff --git a/agent/checks/os_service_windows.go b/agent/checks/os_service_windows.go index 1aa3504486..631543b7b2 100644 --- a/agent/checks/os_service_windows.go +++ b/agent/checks/os_service_windows.go @@ -12,13 +12,17 @@ import ( "golang.org/x/sys/windows/svc/mgr" ) +var ( + win windowsSystem = windowsOS{} +) + type OSServiceClient struct { scHandle windows.Handle } func NewOSServiceClient() (*OSServiceClient, error) { var s *uint16 - scHandle, err := windows.OpenSCManager(s, nil, windows.SC_MANAGER_CONNECT) + scHandle, err := win.OpenSCManager(s, nil, windows.SC_MANAGER_CONNECT) if err != nil { return nil, fmt.Errorf("error connecting to service manager: %w", err) @@ -29,15 +33,32 @@ func NewOSServiceClient() (*OSServiceClient, error) { }, nil } -func (client *OSServiceClient) Check(serviceName string) error { - m := &mgr.Mgr{Handle: client.scHandle} - defer m.Disconnect() - svcHandle, err := windows.OpenService(m.Handle, syscall.StringToUTF16Ptr(serviceName), windows.SC_MANAGER_ENUMERATE_SERVICE) +func (client *OSServiceClient) Check(serviceName string) (err error) { + var isHealthy bool + + m := win.getWindowsSvcMgr(client.scHandle) + defer func() { + errDisconnect := m.Disconnect() + if isHealthy || errDisconnect == nil || err != nil { + return + } + //unreachable at the moment but we might want to log this error. leaving here for code-review + err = errDisconnect + }() + + svcHandle, err := win.OpenService(win.getWindowsSvcMgrHandle(m), syscall.StringToUTF16Ptr(serviceName), windows.SC_MANAGER_ENUMERATE_SERVICE) if err != nil { return fmt.Errorf("error accessing service: %w", err) } - service := &mgr.Service{Name: serviceName, Handle: svcHandle} - defer service.Close() + service := win.getWindowsSvc(serviceName, svcHandle) + defer func() { + errClose := service.Close() + if isHealthy || errClose == nil || err != nil { + return + } + //unreachable at the moment but we might want to log this error. leaving here for code-review + err = errClose + }() status, err := service.Query() if err != nil { return fmt.Errorf("error querying service status: %w", err) @@ -45,10 +66,49 @@ func (client *OSServiceClient) Check(serviceName string) error { switch status.State { case svc.Running: - return nil - case svc.Stopped: - return ErrOSServiceStatusCritical + err = nil + isHealthy = true + case svc.Paused, svc.Stopped: + err = ErrOSServiceStatusCritical default: - return fmt.Errorf("service status: %v", status.State) + err = fmt.Errorf("service status: %v", status.State) } + + return err +} + +type windowsOS struct{} + +func (windowsOS) OpenSCManager(machineName *uint16, databaseName *uint16, access uint32) (handle windows.Handle, err error) { + return windows.OpenSCManager(machineName, databaseName, access) +} +func (windowsOS) OpenService(mgr windows.Handle, serviceName *uint16, access uint32) (handle windows.Handle, err error) { + return windows.OpenService(mgr, serviceName, access) +} + +func (windowsOS) getWindowsSvcMgr(h windows.Handle) windowsSvcMgr { return &mgr.Mgr{Handle: h} } +func (windowsOS) getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle { + return sm.(*mgr.Mgr).Handle +} + +func (windowsOS) getWindowsSvc(name string, h windows.Handle) windowsSvc { + return &mgr.Service{Name: name, Handle: h} +} + +type windowsSystem interface { + OpenSCManager(machineName *uint16, databaseName *uint16, access uint32) (handle windows.Handle, err error) + OpenService(mgr windows.Handle, serviceName *uint16, access uint32) (handle windows.Handle, err error) + + getWindowsSvcMgr(h windows.Handle) windowsSvcMgr + getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle + getWindowsSvc(name string, h windows.Handle) windowsSvc +} + +type windowsSvcMgr interface { + Disconnect() error +} + +type windowsSvc interface { + Close() error + Query() (svc.Status, error) }