consul/agent/checks/os_service_windows.go

125 lines
3.5 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build windows
// +build windows
package checks
import (
"fmt"
"syscall"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
)
// win is the windowsSystem implementation that OSServiceClient.Check routes
// its service-manager calls through. It defaults to the real OS bindings
// (windowsOS); presumably it is a package variable so tests can substitute a
// fake — TODO confirm against the test files.
var win windowsSystem = windowsOS{}
// OSServiceClient checks the state of Windows services. It carries no state,
// so the zero value is usable as well.
type OSServiceClient struct{}

// NewOSServiceClient returns a ready-to-use OSServiceClient. In this
// implementation the error result is always nil.
func NewOSServiceClient() (*OSServiceClient, error) {
	client := new(OSServiceClient)
	return client, nil
}
// Check reports the state of the Windows service named serviceName.
//
// It returns nil when the service is Running. A Paused or Stopped service
// yields an error that wraps ErrOSServiceStatusCritical. Any other state (for
// example StartPending or StopPending) yields an error that only describes
// the state. Failures to convert the name, connect to the service control
// manager, open the service, or query its status are returned as wrapped
// errors.
func (client *OSServiceClient) Check(serviceName string) (err error) {
	// Convert the name before touching the SCM so a malformed name (one
	// containing NUL) fails fast without opening any handles.
	svcNamePtr, err := syscall.UTF16PtrFromString(serviceName)
	if err != nil {
		return fmt.Errorf("service name must not contain NUL bytes: %w", err)
	}

	// Every OS interaction goes through win, so the windowsSystem seam covers
	// the whole call sequence, including the initial SCM connection.
	h, err := win.OpenSCManager(nil, nil, windows.SC_MANAGER_CONNECT)
	if err != nil {
		return fmt.Errorf("failed to connect to Windows service manager: %w", err)
	}
	m := win.getWindowsSvcMgr(h)
	// Best-effort cleanup: a failed disconnect does not change the check result.
	defer m.Disconnect()

	// OpenService expects *service* access rights (not SC_MANAGER_* rights);
	// reading the status only needs SERVICE_QUERY_STATUS.
	svcHandle, err := win.OpenService(win.getWindowsSvcMgrHandle(m), svcNamePtr, windows.SERVICE_QUERY_STATUS)
	if err != nil {
		return fmt.Errorf("error accessing service: %w", err)
	}
	service := win.getWindowsSvc(serviceName, svcHandle)
	// Best-effort cleanup, same as the manager handle above.
	defer service.Close()

	status, err := service.Query()
	if err != nil {
		return fmt.Errorf("error querying service status: %w", err)
	}

	switch status.State {
	case svc.Running:
		return nil
	case svc.Paused, svc.Stopped:
		return fmt.Errorf("service status: %v - %w", svcStateString(status.State), ErrOSServiceStatusCritical)
	default:
		// Transitional/unknown states are reported without the critical sentinel.
		return fmt.Errorf("service status: %v", svcStateString(status.State))
	}
}
// windowsOS is the production windowsSystem. Each method is a thin
// pass-through to golang.org/x/sys/windows or its svc/mgr package.
type windowsOS struct{}

// OpenSCManager forwards directly to windows.OpenSCManager.
func (windowsOS) OpenSCManager(machineName *uint16, databaseName *uint16, access uint32) (handle windows.Handle, err error) {
	handle, err = windows.OpenSCManager(machineName, databaseName, access)
	return handle, err
}

// OpenService forwards directly to windows.OpenService, opening serviceName
// on the service control manager identified by scm.
func (windowsOS) OpenService(scm windows.Handle, serviceName *uint16, access uint32) (handle windows.Handle, err error) {
	handle, err = windows.OpenService(scm, serviceName, access)
	return handle, err
}

// getWindowsSvcMgr wraps a raw SCM handle in a *mgr.Mgr.
func (windowsOS) getWindowsSvcMgr(h windows.Handle) windowsSvcMgr {
	manager := &mgr.Mgr{Handle: h}
	return manager
}

// getWindowsSvcMgrHandle extracts the raw handle from a manager created by
// getWindowsSvcMgr. The unchecked type assertion panics if sm is not a
// *mgr.Mgr.
func (windowsOS) getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle {
	manager := sm.(*mgr.Mgr)
	return manager.Handle
}

// getWindowsSvc wraps an open service handle, together with its name, in a
// *mgr.Service.
func (windowsOS) getWindowsSvc(name string, h windows.Handle) windowsSvc {
	service := &mgr.Service{Name: name, Handle: h}
	return service
}
type (
	// windowsSystem abstracts the service-control-manager calls and the
	// mgr wrappers used to check a service. windowsOS is the production
	// implementation.
	windowsSystem interface {
		OpenSCManager(machine, database *uint16, access uint32) (windows.Handle, error)
		OpenService(scm windows.Handle, serviceName *uint16, access uint32) (windows.Handle, error)
		getWindowsSvcMgr(h windows.Handle) windowsSvcMgr
		getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle
		getWindowsSvc(name string, h windows.Handle) windowsSvc
	}

	// windowsSvcMgr is the part of a service manager (*mgr.Mgr in
	// production) that needs to be released after use.
	windowsSvcMgr interface {
		Disconnect() error
	}

	// windowsSvc is the part of an open service (*mgr.Service in
	// production) needed to read its status and release it.
	windowsSvc interface {
		Query() (svc.Status, error)
		Close() error
	}
)
// svcStateString returns a human-readable name for a Windows service state,
// such as "Running" or "StopPending". Any value outside the documented set of
// states is rendered as its decimal number.
//
// source: https://pkg.go.dev/golang.org/x/sys/windows/svc#pkg-constants
func svcStateString(state svc.State) string {
	var name string
	switch state {
	case svc.Stopped:
		name = "Stopped"
	case svc.StartPending:
		name = "StartPending"
	case svc.StopPending:
		name = "StopPending"
	case svc.Running:
		name = "Running"
	case svc.ContinuePending:
		name = "ContinuePending"
	case svc.PausePending:
		name = "PausePending"
	case svc.Paused:
		name = "Paused"
	default:
		// Unrecognised state: fall back to the underlying uint32 value.
		name = fmt.Sprintf("%d", uint32(state))
	}
	return name
}