474 lines
12 KiB
Go
Raw Normal View History

package client
import (
"crypto/tls"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/rpc"
"os"
"strings"
"sync"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/net-rpc-msgpackrpc"
)
const (
// DefaultEndpoint is the endpoint used if none is provided
DefaultEndpoint = "scada.hashicorp.com:7223"
// DefaultBackoff is the amount of time we back off if we encounter
// and error, and no specific backoff is available.
DefaultBackoff = 120 * time.Second
// DisconnectDelay is how long we delay the disconnect to allow
// the RPC to complete.
DisconnectDelay = time.Second
)
// CapabilityProvider is used to provide a given capability
// when requested remotely. They must return a connection
// that is bridged or an error.
type CapabilityProvider func(capability string, meta map[string]string, conn io.ReadWriteCloser) error
// ProviderService is the service being exposed
type ProviderService struct {
Service string
ServiceVersion string
Capabilities map[string]int
Meta map[string]string
ResourceType string
}
// ProviderConfig is used to parameterize a provider
type ProviderConfig struct {
// Endpoint is the SCADA endpoint, defaults to DefaultEndpoint
Endpoint string
// Service is the service to expose
Service *ProviderService
// Handlers are invoked to provide the named capability
Handlers map[string]CapabilityProvider
// ResourceGroup is the named group e.g. "hashicorp/prod"
ResourceGroup string
// Token is the Atlas authentication token
Token string
// Optional TLS configuration, defaults used otherwise
TLSConfig *tls.Config
// LogOutput is to control the log output
LogOutput io.Writer
}
// Provider is a high-level interface to SCADA by which
// clients declare themselves as a service providing capabilities.
// Provider manages the client/server interactions required,
// making it simpler to integrate.
type Provider struct {
config *ProviderConfig
logger *log.Logger
client *Client
clientLock sync.Mutex
noRetry bool // set when the server instructs us to not retry
backoff time.Duration // set when the server provides a longer backoff
backoffLock sync.Mutex
sessionID string
sessionAuth bool
sessionLock sync.RWMutex
shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex
}
// validateConfig is used to sanity check the configuration
func validateConfig(config *ProviderConfig) error {
// Validate the inputs
if config == nil {
return fmt.Errorf("missing config")
}
if config.Service == nil {
return fmt.Errorf("missing service")
}
if config.Service.Service == "" {
return fmt.Errorf("missing service name")
}
if config.Service.ServiceVersion == "" {
return fmt.Errorf("missing service version")
}
if config.Service.ResourceType == "" {
return fmt.Errorf("missing service resource type")
}
if config.Handlers == nil && len(config.Service.Capabilities) != 0 {
return fmt.Errorf("missing handlers")
}
for c := range config.Service.Capabilities {
if _, ok := config.Handlers[c]; !ok {
return fmt.Errorf("missing handler for '%s' capability", c)
}
}
if config.ResourceGroup == "" {
return fmt.Errorf("missing resource group")
}
if config.Token == "" {
config.Token = os.Getenv("ATLAS_TOKEN")
}
if config.Token == "" {
return fmt.Errorf("missing token")
}
// Default the endpoint
if config.Endpoint == "" {
config.Endpoint = DefaultEndpoint
if end := os.Getenv("SCADA_ENDPOINT"); end != "" {
config.Endpoint = end
}
}
return nil
}
// NewProvider is used to create a new provider
func NewProvider(config *ProviderConfig) (*Provider, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
// Create logger
if config.LogOutput == nil {
config.LogOutput = os.Stderr
}
logger := log.New(config.LogOutput, "", log.LstdFlags)
p := &Provider{
config: config,
logger: logger,
shutdownCh: make(chan struct{}),
}
go p.run()
return p, nil
}
// Shutdown is used to close the provider
func (p *Provider) Shutdown() {
p.shutdownLock.Lock()
p.shutdownLock.Unlock()
if p.shutdown {
return
}
p.shutdown = true
close(p.shutdownCh)
}
// IsShutdown checks if we have been shutdown
func (p *Provider) IsShutdown() bool {
select {
case <-p.shutdownCh:
return true
default:
return false
}
}
// backoffDuration is used to compute the next backoff duration
func (p *Provider) backoffDuration() time.Duration {
// Use the default backoff
backoff := DefaultBackoff
// Check for a server specified backoff
p.backoffLock.Lock()
if p.backoff != 0 {
backoff = p.backoff
}
if p.noRetry {
backoff = 0
}
p.backoffLock.Unlock()
return backoff
}
// wait is used to delay dialing on an error
func (p *Provider) wait() {
// Compute the backoff time
backoff := p.backoffDuration()
// Setup a wait timer
var wait <-chan time.Time
if backoff > 0 {
jitter := time.Duration(rand.Uint32()) % backoff
wait = time.After(backoff + jitter)
}
// Wait until timer or shutdown
select {
case <-wait:
case <-p.shutdownCh:
}
}
// run is a long running routine to manage the provider
func (p *Provider) run() {
for !p.IsShutdown() {
// Setup a new connection
client, err := p.clientSetup()
if err != nil {
p.wait()
continue
}
// Handle the session
doneCh := make(chan struct{})
go p.handleSession(client, doneCh)
// Wait for session termination or shutdown
select {
case <-doneCh:
p.wait()
case <-p.shutdownCh:
p.clientLock.Lock()
client.Close()
p.clientLock.Unlock()
return
}
}
}
// handleSession is used to handle an established session
func (p *Provider) handleSession(list net.Listener, doneCh chan struct{}) {
defer close(doneCh)
defer list.Close()
// Accept new connections
for !p.IsShutdown() {
conn, err := list.Accept()
if err != nil {
p.logger.Printf("[ERR] scada-client: failed to accept connection: %v", err)
return
}
p.logger.Printf("[DEBUG] scada-client: accepted connection")
go p.handleConnection(conn)
}
}
// handleConnection handles an incoming connection
func (p *Provider) handleConnection(conn net.Conn) {
// Create an RPC server to handle inbound
pe := &providerEndpoint{p: p}
rpcServer := rpc.NewServer()
rpcServer.RegisterName("Client", pe)
rpcCodec := msgpackrpc.NewCodec(false, false, conn)
defer func() {
if !pe.hijacked() {
conn.Close()
}
}()
for !p.IsShutdown() {
if err := rpcServer.ServeRequest(rpcCodec); err != nil {
if err != io.EOF && !strings.Contains(err.Error(), "closed") {
p.logger.Printf("[ERR] scada-client: RPC error: %v", err)
}
return
}
// Handle potential hijack in Client.Connect
if pe.hijacked() {
cb := pe.getHijack()
cb(conn)
return
}
}
}
// clientSetup is used to setup a new connection
func (p *Provider) clientSetup() (*Client, error) {
defer metrics.MeasureSince([]string{"scada", "setup"}, time.Now())
// Reset the previous backoff
p.backoffLock.Lock()
p.noRetry = false
p.backoff = 0
p.backoffLock.Unlock()
// Dial a new connection
opts := Opts{
Addr: p.config.Endpoint,
TLS: true,
TLSConfig: p.config.TLSConfig,
LogOutput: p.config.LogOutput,
}
client, err := DialOpts(&opts)
if err != nil {
p.logger.Printf("[ERR] scada-client: failed to dial: %v", err)
return nil, err
}
// Perform a handshake
resp, err := p.handshake(client)
if err != nil {
p.logger.Printf("[ERR] scada-client: failed to handshake: %v", err)
client.Close()
return nil, err
}
if resp != nil && resp.SessionID != "" {
p.logger.Printf("[DEBUG] scada-client: assigned session '%s'", resp.SessionID)
}
if resp != nil && !resp.Authenticated {
p.logger.Printf("[WARN] scada-client: authentication failed: %v", resp.Reason)
}
// Set the new client
p.clientLock.Lock()
if p.client != nil {
p.client.Close()
}
p.client = client
p.clientLock.Unlock()
p.sessionLock.Lock()
p.sessionID = resp.SessionID
p.sessionAuth = resp.Authenticated
p.sessionLock.Unlock()
return client, nil
}
// SessionID provides the current session ID
func (p *Provider) SessionID() string {
p.sessionLock.RLock()
defer p.sessionLock.RUnlock()
return p.sessionID
}
// SessionAuth checks if the current session is authenticated
func (p *Provider) SessionAuthenticated() bool {
p.sessionLock.RLock()
defer p.sessionLock.RUnlock()
return p.sessionAuth
}
// handshake does the initial handshake
func (p *Provider) handshake(client *Client) (*HandshakeResponse, error) {
defer metrics.MeasureSince([]string{"scada", "handshake"}, time.Now())
req := HandshakeRequest{
Service: p.config.Service.Service,
ServiceVersion: p.config.Service.ServiceVersion,
Capabilities: p.config.Service.Capabilities,
Meta: p.config.Service.Meta,
ResourceType: p.config.Service.ResourceType,
ResourceGroup: p.config.ResourceGroup,
Token: p.config.Token,
}
resp := new(HandshakeResponse)
if err := client.RPC("Session.Handshake", &req, resp); err != nil {
return nil, err
}
return resp, nil
}
type HijackFunc func(io.ReadWriteCloser)
// providerEndpoint is used to implement the Client.* RPC endpoints
// as part of the provider.
type providerEndpoint struct {
p *Provider
hijack HijackFunc
}
// Hijacked is used to check if the connection has been hijacked
func (pe *providerEndpoint) hijacked() bool {
return pe.hijack != nil
}
// GetHijack returns the hijack function
func (pe *providerEndpoint) getHijack() HijackFunc {
return pe.hijack
}
// Hijack is used to take over the yamux stream for Client.Connect
func (pe *providerEndpoint) setHijack(cb HijackFunc) {
pe.hijack = cb
}
// Connect is invoked by the broker to connect to a capability
func (pe *providerEndpoint) Connect(args *ConnectRequest, resp *ConnectResponse) error {
defer metrics.IncrCounter([]string{"scada", "connect", args.Capability}, 1)
pe.p.logger.Printf("[INFO] scada-client: connect requested (capability: %s)",
args.Capability)
// Handle potential flash
if args.Severity != "" && args.Message != "" {
pe.p.logger.Printf("[%s] scada-client: %s", args.Severity, args.Message)
}
// Look for the handler
handler := pe.p.config.Handlers[args.Capability]
if handler == nil {
pe.p.logger.Printf("[WARN] scada-client: requested capability '%s' not available",
args.Capability)
return fmt.Errorf("invalid capability")
}
// Hijack the connection
pe.setHijack(func(a io.ReadWriteCloser) {
if err := handler(args.Capability, args.Meta, a); err != nil {
pe.p.logger.Printf("[ERR] scada-client: '%s' handler error: %v",
args.Capability, err)
}
})
resp.Success = true
return nil
}
// Disconnect is invoked by the broker to ask us to backoff
func (pe *providerEndpoint) Disconnect(args *DisconnectRequest, resp *DisconnectResponse) error {
defer metrics.IncrCounter([]string{"scada", "disconnect"}, 1)
if args.Reason == "" {
args.Reason = "<no reason provided>"
}
pe.p.logger.Printf("[INFO] scada-client: disconnect requested (retry: %v, backoff: %v): %v",
!args.NoRetry, args.Backoff, args.Reason)
// Use the backoff information
pe.p.backoffLock.Lock()
pe.p.noRetry = args.NoRetry
pe.p.backoff = args.Backoff
pe.p.backoffLock.Unlock()
// Clear the session information
pe.p.sessionLock.Lock()
pe.p.sessionID = ""
pe.p.sessionAuth = false
pe.p.sessionLock.Unlock()
// Force the disconnect
time.AfterFunc(DisconnectDelay, func() {
pe.p.clientLock.Lock()
if pe.p.client != nil {
pe.p.client.Close()
}
pe.p.clientLock.Unlock()
})
return nil
}
// Flash is invoked by the broker log a message
func (pe *providerEndpoint) Flash(args *FlashRequest, resp *FlashResponse) error {
defer metrics.IncrCounter([]string{"scada", "flash"}, 1)
if args.Severity != "" && args.Message != "" {
pe.p.logger.Printf("[%s] scada-client: %s", args.Severity, args.Message)
}
return nil
}