mirror of https://github.com/status-im/consul.git
Pulling in the RPC framework from serf
This commit is contained in:
parent
53298520ad
commit
f4692b468f
|
@ -3,6 +3,7 @@ package agent
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/hashicorp/consul/consul"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
@ -187,3 +188,59 @@ func (a *Agent) Shutdown() error {
|
|||
func (a *Agent) ShutdownCh() <-chan struct{} {
|
||||
return a.shutdownCh
|
||||
}
|
||||
|
||||
// JoinLAN is used to have the agent join a LAN cluster
|
||||
func (a *Agent) JoinLAN(addrs []string) (n int, err error) {
|
||||
a.logger.Printf("[INFO] agent: (LAN) joining: %v", addrs)
|
||||
if a.server != nil {
|
||||
n, err = a.server.JoinLAN(addrs)
|
||||
} else {
|
||||
n, err = a.client.JoinLAN(addrs)
|
||||
}
|
||||
a.logger.Printf("[INFO] agent: (LAN) joined: %d Err: %v", n, err)
|
||||
return
|
||||
}
|
||||
|
||||
// JoinWAN is used to have the agent join a WAN cluster
|
||||
func (a *Agent) JoinWAN(addrs []string) (n int, err error) {
|
||||
a.logger.Printf("[INFO] agent: (WAN) joining: %v", addrs)
|
||||
if a.server != nil {
|
||||
n, err = a.server.JoinWAN(addrs)
|
||||
} else {
|
||||
err = fmt.Errorf("Must be a server to join WAN cluster")
|
||||
}
|
||||
a.logger.Printf("[INFO] agent: (WAN) joined: %d Err: %v", n, err)
|
||||
return
|
||||
}
|
||||
|
||||
// ForceLeave is used to remove a failed node from the cluster
|
||||
func (a *Agent) ForceLeave(node string) (err error) {
|
||||
a.logger.Printf("[INFO] Force leaving node: %v", node)
|
||||
if a.server != nil {
|
||||
err = a.server.RemoveFailedNode(node)
|
||||
} else {
|
||||
err = a.client.RemoveFailedNode(node)
|
||||
}
|
||||
if err != nil {
|
||||
a.logger.Printf("[WARN] Failed to remove node: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Used to retrieve the LAN members
|
||||
func (a *Agent) LANMembers() []serf.Member {
|
||||
if a.server != nil {
|
||||
return a.server.LANMembers()
|
||||
} else {
|
||||
return a.client.LANMembers()
|
||||
}
|
||||
}
|
||||
|
||||
// Used to retrieve the WAN members
|
||||
func (a *Agent) WANMembers() []serf.Member {
|
||||
if a.server != nil {
|
||||
return a.server.WANMembers()
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package agent
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/hashicorp/consul/consul"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
@ -18,22 +19,24 @@ func nextConfig() *Config {
|
|||
|
||||
conf.Bootstrap = true
|
||||
conf.Datacenter = "dc1"
|
||||
conf.NodeName = fmt.Sprintf("Node %d", idx)
|
||||
conf.HTTPAddr = fmt.Sprintf("127.0.0.1:%d", 8500+10*idx)
|
||||
conf.RPCAddr = fmt.Sprintf("127.0.0.1:%d", 8400+10*idx)
|
||||
conf.SerfBindAddr = "127.0.0.1"
|
||||
conf.SerfLanPort = int(8301 + 10*idx)
|
||||
conf.SerfWanPort = int(8302 + 10*idx)
|
||||
conf.Server = true
|
||||
conf.ServerAddr = fmt.Sprintf("127.0.0.1:%d", 8100+10*idx)
|
||||
|
||||
cons := consul.DefaultConfig()
|
||||
conf.ConsulConfig = cons
|
||||
|
||||
cons.SerfLANConfig.MemberlistConfig.ProbeTimeout = 200 * time.Millisecond
|
||||
cons.SerfLANConfig.MemberlistConfig.ProbeInterval = time.Second
|
||||
cons.SerfLANConfig.MemberlistConfig.ProbeTimeout = 100 * time.Millisecond
|
||||
cons.SerfLANConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond
|
||||
cons.SerfLANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond
|
||||
|
||||
cons.SerfWANConfig.MemberlistConfig.ProbeTimeout = 200 * time.Millisecond
|
||||
cons.SerfWANConfig.MemberlistConfig.ProbeInterval = time.Second
|
||||
cons.SerfWANConfig.MemberlistConfig.ProbeTimeout = 100 * time.Millisecond
|
||||
cons.SerfWANConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond
|
||||
cons.SerfWANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond
|
||||
|
||||
cons.RaftConfig.HeartbeatTimeout = 40 * time.Millisecond
|
||||
|
@ -42,14 +45,14 @@ func nextConfig() *Config {
|
|||
return conf
|
||||
}
|
||||
|
||||
func makeAgent(t *testing.T, conf *Config) (string, *Agent) {
|
||||
func makeAgentLog(t *testing.T, conf *Config, l io.Writer) (string, *Agent) {
|
||||
dir, err := ioutil.TempDir("", "agent")
|
||||
if err != nil {
|
||||
t.Fatalf(fmt.Sprintf("err: %v", err))
|
||||
}
|
||||
|
||||
conf.DataDir = dir
|
||||
agent, err := Create(conf, nil)
|
||||
agent, err := Create(conf, l)
|
||||
if err != nil {
|
||||
os.RemoveAll(dir)
|
||||
t.Fatalf(fmt.Sprintf("err: %v", err))
|
||||
|
@ -58,6 +61,10 @@ func makeAgent(t *testing.T, conf *Config) (string, *Agent) {
|
|||
return dir, agent
|
||||
}
|
||||
|
||||
func makeAgent(t *testing.T, conf *Config) (string, *Agent) {
|
||||
return makeAgentLog(t, conf, nil)
|
||||
}
|
||||
|
||||
func TestAgentStartStop(t *testing.T) {
|
||||
dir, agent := makeAgent(t, nextConfig())
|
||||
defer os.RemoveAll(dir)
|
||||
|
|
|
@ -0,0 +1,544 @@
|
|||
package agent
|
||||
|
||||
/*
|
||||
The agent exposes an RPC mechanism that is used for both controlling
|
||||
Consul as well as providing a fast streaming mechanism for events. This
|
||||
allows other applications to easily leverage Consul without embedding.
|
||||
|
||||
We additionally make use of the RPC layer to also handle calls from
|
||||
the CLI to unify the code paths. This results in a split Request/Response
|
||||
as well as streaming mode of operation.
|
||||
|
||||
The system is fairly simple, each client opens a TCP connection to the
|
||||
agent. The connection is initialized with a handshake which establishes
|
||||
the protocol version being used. This is to allow for future changes to
|
||||
the protocol.
|
||||
|
||||
Once initialized, clients send commands and wait for responses. Certain
|
||||
commands will cause the client to subscribe to events, and those will be
|
||||
pushed down the socket as they are received. This provides a low-latency
|
||||
mechanism for applications to send and receive events, while also providing
|
||||
a flexible control mechanism for Consul.
|
||||
*/
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"github.com/hashicorp/logutils"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
"github.com/ugorji/go/codec"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
MinRPCVersion = 1
|
||||
MaxRPCVersion = 1
|
||||
)
|
||||
|
||||
const (
|
||||
handshakeCommand = "handshake"
|
||||
forceLeaveCommand = "force-leave"
|
||||
joinCommand = "join"
|
||||
membersLANCommand = "members-lan"
|
||||
membersWANCommand = "members-wan"
|
||||
stopCommand = "stop"
|
||||
monitorCommand = "monitor"
|
||||
leaveCommand = "leave"
|
||||
)
|
||||
|
||||
const (
|
||||
unsupportedCommand = "Unsupported command"
|
||||
unsupportedRPCVersion = "Unsupported RPC version"
|
||||
duplicateHandshake = "Handshake already performed"
|
||||
handshakeRequired = "Handshake required"
|
||||
monitorExists = "Monitor already exists"
|
||||
)
|
||||
|
||||
// Request header is sent before each request
|
||||
type requestHeader struct {
|
||||
Command string
|
||||
Seq uint64
|
||||
}
|
||||
|
||||
// Response header is sent before each response
|
||||
type responseHeader struct {
|
||||
Seq uint64
|
||||
Error string
|
||||
}
|
||||
|
||||
type handshakeRequest struct {
|
||||
Version int32
|
||||
}
|
||||
|
||||
type eventRequest struct {
|
||||
Name string
|
||||
Payload []byte
|
||||
Coalesce bool
|
||||
}
|
||||
|
||||
type forceLeaveRequest struct {
|
||||
Node string
|
||||
}
|
||||
|
||||
type joinRequest struct {
|
||||
Existing []string
|
||||
WAN bool
|
||||
}
|
||||
|
||||
type joinResponse struct {
|
||||
Num int32
|
||||
}
|
||||
|
||||
type membersResponse struct {
|
||||
Members []Member
|
||||
}
|
||||
|
||||
type monitorRequest struct {
|
||||
LogLevel string
|
||||
}
|
||||
|
||||
type streamRequest struct {
|
||||
Type string
|
||||
}
|
||||
|
||||
type stopRequest struct {
|
||||
Stop uint64
|
||||
}
|
||||
|
||||
type logRecord struct {
|
||||
Log string
|
||||
}
|
||||
|
||||
type userEventRecord struct {
|
||||
Event string
|
||||
LTime serf.LamportTime
|
||||
Name string
|
||||
Payload []byte
|
||||
Coalesce bool
|
||||
}
|
||||
|
||||
type Member struct {
|
||||
Name string
|
||||
Addr net.IP
|
||||
Port uint16
|
||||
Role string
|
||||
Status string
|
||||
ProtocolMin uint8
|
||||
ProtocolMax uint8
|
||||
ProtocolCur uint8
|
||||
DelegateMin uint8
|
||||
DelegateMax uint8
|
||||
DelegateCur uint8
|
||||
}
|
||||
|
||||
type memberEventRecord struct {
|
||||
Event string
|
||||
Members []Member
|
||||
}
|
||||
|
||||
type AgentRPC struct {
|
||||
sync.Mutex
|
||||
agent *Agent
|
||||
clients map[string]*rpcClient
|
||||
listener net.Listener
|
||||
logger *log.Logger
|
||||
logWriter *logWriter
|
||||
stop bool
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
type rpcClient struct {
|
||||
name string
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
writer *bufio.Writer
|
||||
dec *codec.Decoder
|
||||
enc *codec.Encoder
|
||||
writeLock sync.Mutex
|
||||
version int32 // From the handshake, 0 before
|
||||
logStreamer *logStream
|
||||
}
|
||||
|
||||
// send is used to send an object using the MsgPack encoding. send
|
||||
// is serialized to prevent write overlaps, while properly buffering.
|
||||
func (c *rpcClient) Send(header *responseHeader, obj interface{}) error {
|
||||
c.writeLock.Lock()
|
||||
defer c.writeLock.Unlock()
|
||||
|
||||
if err := c.enc.Encode(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if obj != nil {
|
||||
if err := c.enc.Encode(obj); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *rpcClient) String() string {
|
||||
return fmt.Sprintf("rpc.client: %v", c.conn)
|
||||
}
|
||||
|
||||
// NewAgentRPC is used to create a new Agent RPC handler
|
||||
func NewAgentRPC(agent *Agent, listener net.Listener,
|
||||
logOutput io.Writer, logWriter *logWriter) *AgentRPC {
|
||||
if logOutput == nil {
|
||||
logOutput = os.Stderr
|
||||
}
|
||||
rpc := &AgentRPC{
|
||||
agent: agent,
|
||||
clients: make(map[string]*rpcClient),
|
||||
listener: listener,
|
||||
logger: log.New(logOutput, "", log.LstdFlags),
|
||||
logWriter: logWriter,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go rpc.listen()
|
||||
return rpc
|
||||
}
|
||||
|
||||
// Shutdown is used to shutdown the RPC layer
|
||||
func (i *AgentRPC) Shutdown() {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
if i.stop {
|
||||
return
|
||||
}
|
||||
|
||||
i.stop = true
|
||||
close(i.stopCh)
|
||||
i.listener.Close()
|
||||
|
||||
// Close the existing connections
|
||||
for _, client := range i.clients {
|
||||
client.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// listen is a long running routine that listens for new clients
|
||||
func (i *AgentRPC) listen() {
|
||||
for {
|
||||
conn, err := i.listener.Accept()
|
||||
if err != nil {
|
||||
if i.stop {
|
||||
return
|
||||
}
|
||||
i.logger.Printf("[ERR] agent.rpc: Failed to accept client: %v", err)
|
||||
continue
|
||||
}
|
||||
i.logger.Printf("[INFO] agent.rpc: Accepted client: %v", conn.RemoteAddr())
|
||||
|
||||
// Wrap the connection in a client
|
||||
client := &rpcClient{
|
||||
name: conn.RemoteAddr().String(),
|
||||
conn: conn,
|
||||
reader: bufio.NewReader(conn),
|
||||
writer: bufio.NewWriter(conn),
|
||||
}
|
||||
client.dec = codec.NewDecoder(client.reader,
|
||||
&codec.MsgpackHandle{RawToString: true, WriteExt: true})
|
||||
client.enc = codec.NewEncoder(client.writer,
|
||||
&codec.MsgpackHandle{RawToString: true, WriteExt: true})
|
||||
if err != nil {
|
||||
i.logger.Printf("[ERR] agent.rpc: Failed to create decoder: %v", err)
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// Register the client
|
||||
i.Lock()
|
||||
if !i.stop {
|
||||
i.clients[client.name] = client
|
||||
go i.handleClient(client)
|
||||
} else {
|
||||
conn.Close()
|
||||
}
|
||||
i.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// deregisterClient is called to cleanup after a client disconnects
|
||||
func (i *AgentRPC) deregisterClient(client *rpcClient) {
|
||||
// Close the socket
|
||||
client.conn.Close()
|
||||
|
||||
// Remove from the clients list
|
||||
i.Lock()
|
||||
delete(i.clients, client.name)
|
||||
i.Unlock()
|
||||
|
||||
// Remove from the log writer
|
||||
if client.logStreamer != nil {
|
||||
i.logWriter.DeregisterHandler(client.logStreamer)
|
||||
client.logStreamer.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// handleClient is a long running routine that handles a single client
|
||||
func (i *AgentRPC) handleClient(client *rpcClient) {
|
||||
defer i.deregisterClient(client)
|
||||
var reqHeader requestHeader
|
||||
for {
|
||||
// Decode the header
|
||||
if err := client.dec.Decode(&reqHeader); err != nil {
|
||||
if err != io.EOF && !i.stop {
|
||||
i.logger.Printf("[ERR] agent.rpc: failed to decode request header: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Evaluate the command
|
||||
if err := i.handleRequest(client, &reqHeader); err != nil {
|
||||
i.logger.Printf("[ERR] agent.rpc: Failed to evaluate request: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest is used to evaluate a single client command
|
||||
func (i *AgentRPC) handleRequest(client *rpcClient, reqHeader *requestHeader) error {
|
||||
// Look for a command field
|
||||
command := reqHeader.Command
|
||||
seq := reqHeader.Seq
|
||||
|
||||
// Ensure the handshake is performed before other commands
|
||||
if command != handshakeCommand && client.version == 0 {
|
||||
respHeader := responseHeader{Seq: seq, Error: handshakeRequired}
|
||||
client.Send(&respHeader, nil)
|
||||
return fmt.Errorf(handshakeRequired)
|
||||
}
|
||||
|
||||
// Dispatch command specific handlers
|
||||
switch command {
|
||||
case handshakeCommand:
|
||||
return i.handleHandshake(client, seq)
|
||||
|
||||
case membersLANCommand:
|
||||
return i.handleMembersLAN(client, seq)
|
||||
|
||||
case membersWANCommand:
|
||||
return i.handleMembersWAN(client, seq)
|
||||
|
||||
case monitorCommand:
|
||||
return i.handleMonitor(client, seq)
|
||||
|
||||
case stopCommand:
|
||||
return i.handleStop(client, seq)
|
||||
|
||||
case forceLeaveCommand:
|
||||
return i.handleForceLeave(client, seq)
|
||||
|
||||
case joinCommand:
|
||||
return i.handleJoin(client, seq)
|
||||
|
||||
case leaveCommand:
|
||||
return i.handleLeave(client, seq)
|
||||
|
||||
default:
|
||||
respHeader := responseHeader{Seq: seq, Error: unsupportedCommand}
|
||||
client.Send(&respHeader, nil)
|
||||
return fmt.Errorf("command '%s' not recognized", command)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *AgentRPC) handleHandshake(client *rpcClient, seq uint64) error {
|
||||
var req handshakeRequest
|
||||
if err := client.dec.Decode(&req); err != nil {
|
||||
return fmt.Errorf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
resp := responseHeader{
|
||||
Seq: seq,
|
||||
Error: "",
|
||||
}
|
||||
|
||||
// Check the version
|
||||
if req.Version < MinRPCVersion || req.Version > MaxRPCVersion {
|
||||
resp.Error = unsupportedRPCVersion
|
||||
} else if client.version != 0 {
|
||||
resp.Error = duplicateHandshake
|
||||
} else {
|
||||
client.version = req.Version
|
||||
}
|
||||
return client.Send(&resp, nil)
|
||||
}
|
||||
|
||||
func (i *AgentRPC) handleForceLeave(client *rpcClient, seq uint64) error {
|
||||
var req forceLeaveRequest
|
||||
if err := client.dec.Decode(&req); err != nil {
|
||||
return fmt.Errorf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
// Attempt leave
|
||||
err := i.agent.ForceLeave(req.Node)
|
||||
|
||||
// Respond
|
||||
resp := responseHeader{
|
||||
Seq: seq,
|
||||
Error: errToString(err),
|
||||
}
|
||||
return client.Send(&resp, nil)
|
||||
}
|
||||
|
||||
func (i *AgentRPC) handleJoin(client *rpcClient, seq uint64) error {
|
||||
var req joinRequest
|
||||
if err := client.dec.Decode(&req); err != nil {
|
||||
return fmt.Errorf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
// Attempt the join
|
||||
var num int
|
||||
var err error
|
||||
if req.WAN {
|
||||
num, err = i.agent.JoinWAN(req.Existing)
|
||||
} else {
|
||||
num, err = i.agent.JoinLAN(req.Existing)
|
||||
}
|
||||
|
||||
// Respond
|
||||
header := responseHeader{
|
||||
Seq: seq,
|
||||
Error: errToString(err),
|
||||
}
|
||||
resp := joinResponse{
|
||||
Num: int32(num),
|
||||
}
|
||||
return client.Send(&header, &resp)
|
||||
}
|
||||
|
||||
func (i *AgentRPC) handleMembersLAN(client *rpcClient, seq uint64) error {
|
||||
raw := i.agent.LANMembers()
|
||||
return formatMembers(raw, client, seq)
|
||||
}
|
||||
|
||||
func (i *AgentRPC) handleMembersWAN(client *rpcClient, seq uint64) error {
|
||||
raw := i.agent.WANMembers()
|
||||
return formatMembers(raw, client, seq)
|
||||
}
|
||||
|
||||
func formatMembers(raw []serf.Member, client *rpcClient, seq uint64) error {
|
||||
members := make([]Member, 0, len(raw))
|
||||
for _, m := range raw {
|
||||
sm := Member{
|
||||
Name: m.Name,
|
||||
Addr: m.Addr,
|
||||
Port: m.Port,
|
||||
Role: m.Role,
|
||||
Status: m.Status.String(),
|
||||
ProtocolMin: m.ProtocolMin,
|
||||
ProtocolMax: m.ProtocolMax,
|
||||
ProtocolCur: m.ProtocolCur,
|
||||
DelegateMin: m.DelegateMin,
|
||||
DelegateMax: m.DelegateMax,
|
||||
DelegateCur: m.DelegateCur,
|
||||
}
|
||||
members = append(members, sm)
|
||||
}
|
||||
|
||||
header := responseHeader{
|
||||
Seq: seq,
|
||||
Error: "",
|
||||
}
|
||||
resp := membersResponse{
|
||||
Members: members,
|
||||
}
|
||||
return client.Send(&header, &resp)
|
||||
}
|
||||
|
||||
func (i *AgentRPC) handleMonitor(client *rpcClient, seq uint64) error {
|
||||
var req monitorRequest
|
||||
if err := client.dec.Decode(&req); err != nil {
|
||||
return fmt.Errorf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
resp := responseHeader{
|
||||
Seq: seq,
|
||||
Error: "",
|
||||
}
|
||||
|
||||
// Upper case the log level
|
||||
req.LogLevel = strings.ToUpper(req.LogLevel)
|
||||
|
||||
// Create a level filter
|
||||
filter := LevelFilter()
|
||||
filter.MinLevel = logutils.LogLevel(req.LogLevel)
|
||||
if !ValidateLevelFilter(filter.MinLevel, filter) {
|
||||
resp.Error = fmt.Sprintf("Unknown log level: %s", filter.MinLevel)
|
||||
goto SEND
|
||||
}
|
||||
|
||||
// Check if there is an existing monitor
|
||||
if client.logStreamer != nil {
|
||||
resp.Error = monitorExists
|
||||
goto SEND
|
||||
}
|
||||
|
||||
// Create a log streamer
|
||||
client.logStreamer = newLogStream(client, filter, seq, i.logger)
|
||||
|
||||
// Register with the log writer. Defer so that we can respond before
|
||||
// registration, avoids any possible race condition
|
||||
defer i.logWriter.RegisterHandler(client.logStreamer)
|
||||
|
||||
SEND:
|
||||
return client.Send(&resp, nil)
|
||||
}
|
||||
|
||||
func (i *AgentRPC) handleStop(client *rpcClient, seq uint64) error {
|
||||
var req stopRequest
|
||||
if err := client.dec.Decode(&req); err != nil {
|
||||
return fmt.Errorf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
// Remove a log monitor if any
|
||||
if client.logStreamer != nil && client.logStreamer.seq == req.Stop {
|
||||
i.logWriter.DeregisterHandler(client.logStreamer)
|
||||
client.logStreamer.Stop()
|
||||
client.logStreamer = nil
|
||||
}
|
||||
|
||||
// Always succeed
|
||||
resp := responseHeader{Seq: seq, Error: ""}
|
||||
return client.Send(&resp, nil)
|
||||
}
|
||||
|
||||
func (i *AgentRPC) handleLeave(client *rpcClient, seq uint64) error {
|
||||
i.logger.Printf("[INFO] agent.rpc: Graceful leave triggered")
|
||||
|
||||
// Do the leave
|
||||
err := i.agent.Leave()
|
||||
if err != nil {
|
||||
i.logger.Printf("[ERR] agent.rpc: leave failed: %v", err)
|
||||
}
|
||||
resp := responseHeader{Seq: seq, Error: errToString(err)}
|
||||
|
||||
// Send and wait
|
||||
err = client.Send(&resp, nil)
|
||||
|
||||
// Trigger a shutdown!
|
||||
if err := i.agent.Shutdown(); err != nil {
|
||||
i.logger.Printf("[ERR] agent.rpc: shutdown failed: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Used to convert an error to a string representation
|
||||
func errToString(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
|
@ -0,0 +1,399 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"github.com/hashicorp/logutils"
|
||||
"github.com/ugorji/go/codec"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
clientClosed = fmt.Errorf("client closed")
|
||||
)
|
||||
|
||||
type seqCallback struct {
|
||||
handler func(*responseHeader)
|
||||
}
|
||||
|
||||
func (sc *seqCallback) Handle(resp *responseHeader) {
|
||||
sc.handler(resp)
|
||||
}
|
||||
func (sc *seqCallback) Cleanup() {}
|
||||
|
||||
// seqHandler interface is used to handle responses
|
||||
type seqHandler interface {
|
||||
Handle(*responseHeader)
|
||||
Cleanup()
|
||||
}
|
||||
|
||||
// RPCClient is the RPC client to make requests to the agent RPC.
|
||||
type RPCClient struct {
|
||||
seq uint64
|
||||
|
||||
conn *net.TCPConn
|
||||
reader *bufio.Reader
|
||||
writer *bufio.Writer
|
||||
dec *codec.Decoder
|
||||
enc *codec.Encoder
|
||||
writeLock sync.Mutex
|
||||
|
||||
dispatch map[uint64]seqHandler
|
||||
dispatchLock sync.Mutex
|
||||
|
||||
shutdown bool
|
||||
shutdownCh chan struct{}
|
||||
shutdownLock sync.Mutex
|
||||
}
|
||||
|
||||
// send is used to send an object using the MsgPack encoding. send
|
||||
// is serialized to prevent write overlaps, while properly buffering.
|
||||
func (c *RPCClient) send(header *requestHeader, obj interface{}) error {
|
||||
c.writeLock.Lock()
|
||||
defer c.writeLock.Unlock()
|
||||
|
||||
if c.shutdown {
|
||||
return clientClosed
|
||||
}
|
||||
|
||||
if err := c.enc.Encode(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if obj != nil {
|
||||
if err := c.enc.Encode(obj); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewRPCClient is used to create a new RPC client given the address.
|
||||
// This will properly dial, handshake, and start listening
|
||||
func NewRPCClient(addr string) (*RPCClient, error) {
|
||||
// Try to dial to agent
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the client
|
||||
client := &RPCClient{
|
||||
seq: 0,
|
||||
conn: conn.(*net.TCPConn),
|
||||
reader: bufio.NewReader(conn),
|
||||
writer: bufio.NewWriter(conn),
|
||||
dispatch: make(map[uint64]seqHandler),
|
||||
shutdownCh: make(chan struct{}),
|
||||
}
|
||||
client.dec = codec.NewDecoder(client.reader,
|
||||
&codec.MsgpackHandle{RawToString: true, WriteExt: true})
|
||||
client.enc = codec.NewEncoder(client.writer,
|
||||
&codec.MsgpackHandle{RawToString: true, WriteExt: true})
|
||||
go client.listen()
|
||||
|
||||
// Do the initial handshake
|
||||
if err := client.handshake(); err != nil {
|
||||
client.Close()
|
||||
return nil, err
|
||||
}
|
||||
return client, err
|
||||
}
|
||||
|
||||
// StreamHandle is an opaque handle passed to stop to stop streaming
|
||||
type StreamHandle uint64
|
||||
|
||||
// Close is used to free any resources associated with the client
|
||||
func (c *RPCClient) Close() error {
|
||||
c.shutdownLock.Lock()
|
||||
defer c.shutdownLock.Unlock()
|
||||
|
||||
if !c.shutdown {
|
||||
c.shutdown = true
|
||||
close(c.shutdownCh)
|
||||
c.deregisterAll()
|
||||
return c.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceLeave is used to ask the agent to issue a leave command for
|
||||
// a given node
|
||||
func (c *RPCClient) ForceLeave(node string) error {
|
||||
header := requestHeader{
|
||||
Command: forceLeaveCommand,
|
||||
Seq: c.getSeq(),
|
||||
}
|
||||
req := forceLeaveRequest{
|
||||
Node: node,
|
||||
}
|
||||
return c.genericRPC(&header, &req, nil)
|
||||
}
|
||||
|
||||
// Join is used to instruct the agent to attempt a join
|
||||
func (c *RPCClient) Join(addrs []string, wan bool) (int, error) {
|
||||
header := requestHeader{
|
||||
Command: joinCommand,
|
||||
Seq: c.getSeq(),
|
||||
}
|
||||
req := joinRequest{
|
||||
Existing: addrs,
|
||||
WAN: wan,
|
||||
}
|
||||
var resp joinResponse
|
||||
|
||||
err := c.genericRPC(&header, &req, &resp)
|
||||
return int(resp.Num), err
|
||||
}
|
||||
|
||||
// LANMembers is used to fetch a list of known members
|
||||
func (c *RPCClient) LANMembers() ([]Member, error) {
|
||||
header := requestHeader{
|
||||
Command: membersLANCommand,
|
||||
Seq: c.getSeq(),
|
||||
}
|
||||
var resp membersResponse
|
||||
|
||||
err := c.genericRPC(&header, nil, &resp)
|
||||
return resp.Members, err
|
||||
}
|
||||
|
||||
// WANMembers is used to fetch a list of known members
|
||||
func (c *RPCClient) WANMembers() ([]Member, error) {
|
||||
header := requestHeader{
|
||||
Command: membersWANCommand,
|
||||
Seq: c.getSeq(),
|
||||
}
|
||||
var resp membersResponse
|
||||
|
||||
err := c.genericRPC(&header, nil, &resp)
|
||||
return resp.Members, err
|
||||
}
|
||||
|
||||
// Leave is used to trigger a graceful leave and shutdown
|
||||
func (c *RPCClient) Leave() error {
|
||||
header := requestHeader{
|
||||
Command: leaveCommand,
|
||||
Seq: c.getSeq(),
|
||||
}
|
||||
return c.genericRPC(&header, nil, nil)
|
||||
}
|
||||
|
||||
type monitorHandler struct {
|
||||
client *RPCClient
|
||||
closed bool
|
||||
init bool
|
||||
initCh chan<- error
|
||||
logCh chan<- string
|
||||
seq uint64
|
||||
}
|
||||
|
||||
func (mh *monitorHandler) Handle(resp *responseHeader) {
|
||||
// Initialize on the first response
|
||||
if !mh.init {
|
||||
mh.init = true
|
||||
mh.initCh <- strToError(resp.Error)
|
||||
return
|
||||
}
|
||||
|
||||
// Decode logs for all other responses
|
||||
var rec logRecord
|
||||
if err := mh.client.dec.Decode(&rec); err != nil {
|
||||
log.Printf("[ERR] Failed to decode log: %v", err)
|
||||
mh.client.deregisterHandler(mh.seq)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case mh.logCh <- rec.Log:
|
||||
default:
|
||||
log.Printf("[ERR] Dropping log! Monitor channel full")
|
||||
}
|
||||
}
|
||||
|
||||
func (mh *monitorHandler) Cleanup() {
|
||||
if !mh.closed {
|
||||
if !mh.init {
|
||||
mh.init = true
|
||||
mh.initCh <- fmt.Errorf("Stream closed")
|
||||
}
|
||||
close(mh.logCh)
|
||||
mh.closed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Monitor is used to subscribe to the logs of the agent
|
||||
func (c *RPCClient) Monitor(level logutils.LogLevel, ch chan<- string) (StreamHandle, error) {
|
||||
// Setup the request
|
||||
seq := c.getSeq()
|
||||
header := requestHeader{
|
||||
Command: monitorCommand,
|
||||
Seq: seq,
|
||||
}
|
||||
req := monitorRequest{
|
||||
LogLevel: string(level),
|
||||
}
|
||||
|
||||
// Create a monitor handler
|
||||
initCh := make(chan error, 1)
|
||||
handler := &monitorHandler{
|
||||
client: c,
|
||||
initCh: initCh,
|
||||
logCh: ch,
|
||||
seq: seq,
|
||||
}
|
||||
c.handleSeq(seq, handler)
|
||||
|
||||
// Send the request
|
||||
if err := c.send(&header, &req); err != nil {
|
||||
c.deregisterHandler(seq)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Wait for a response
|
||||
select {
|
||||
case err := <-initCh:
|
||||
return StreamHandle(seq), err
|
||||
case <-c.shutdownCh:
|
||||
c.deregisterHandler(seq)
|
||||
return 0, clientClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Stop is used to unsubscribe from logs or event streams
|
||||
func (c *RPCClient) Stop(handle StreamHandle) error {
|
||||
// Deregister locally first to stop delivery
|
||||
c.deregisterHandler(uint64(handle))
|
||||
|
||||
header := requestHeader{
|
||||
Command: stopCommand,
|
||||
Seq: c.getSeq(),
|
||||
}
|
||||
req := stopRequest{
|
||||
Stop: uint64(handle),
|
||||
}
|
||||
return c.genericRPC(&header, &req, nil)
|
||||
}
|
||||
|
||||
// handshake is used to perform the initial handshake on connect
|
||||
func (c *RPCClient) handshake() error {
|
||||
header := requestHeader{
|
||||
Command: handshakeCommand,
|
||||
Seq: c.getSeq(),
|
||||
}
|
||||
req := handshakeRequest{
|
||||
Version: MaxRPCVersion,
|
||||
}
|
||||
return c.genericRPC(&header, &req, nil)
|
||||
}
|
||||
|
||||
// genericRPC is used to send a request and wait for an
|
||||
// errorSequenceResponse, potentially returning an error
|
||||
func (c *RPCClient) genericRPC(header *requestHeader, req interface{}, resp interface{}) error {
|
||||
// Setup a response handler
|
||||
errCh := make(chan error, 1)
|
||||
handler := func(respHeader *responseHeader) {
|
||||
if resp != nil {
|
||||
err := c.dec.Decode(resp)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
errCh <- strToError(respHeader.Error)
|
||||
}
|
||||
c.handleSeq(header.Seq, &seqCallback{handler: handler})
|
||||
defer c.deregisterHandler(header.Seq)
|
||||
|
||||
// Send the request
|
||||
if err := c.send(header, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for a response
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-c.shutdownCh:
|
||||
return clientClosed
|
||||
}
|
||||
}
|
||||
|
||||
// strToError converts a string to an error if not blank
|
||||
func strToError(s string) error {
|
||||
if s != "" {
|
||||
return fmt.Errorf(s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSeq returns the next sequence number in a safe manner
|
||||
func (c *RPCClient) getSeq() uint64 {
|
||||
return atomic.AddUint64(&c.seq, 1)
|
||||
}
|
||||
|
||||
// deregisterAll is used to deregister all handlers
|
||||
func (c *RPCClient) deregisterAll() {
|
||||
c.dispatchLock.Lock()
|
||||
defer c.dispatchLock.Unlock()
|
||||
|
||||
for _, seqH := range c.dispatch {
|
||||
seqH.Cleanup()
|
||||
}
|
||||
c.dispatch = make(map[uint64]seqHandler)
|
||||
}
|
||||
|
||||
// deregisterHandler is used to deregister a handler
|
||||
func (c *RPCClient) deregisterHandler(seq uint64) {
|
||||
c.dispatchLock.Lock()
|
||||
seqH, ok := c.dispatch[seq]
|
||||
delete(c.dispatch, seq)
|
||||
c.dispatchLock.Unlock()
|
||||
|
||||
if ok {
|
||||
seqH.Cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// handleSeq is used to setup a handlerto wait on a response for
|
||||
// a given sequence number.
|
||||
func (c *RPCClient) handleSeq(seq uint64, handler seqHandler) {
|
||||
c.dispatchLock.Lock()
|
||||
defer c.dispatchLock.Unlock()
|
||||
c.dispatch[seq] = handler
|
||||
}
|
||||
|
||||
// respondSeq is used to respond to a given sequence number
|
||||
func (c *RPCClient) respondSeq(seq uint64, respHeader *responseHeader) {
|
||||
c.dispatchLock.Lock()
|
||||
seqL, ok := c.dispatch[seq]
|
||||
c.dispatchLock.Unlock()
|
||||
|
||||
// Get a registered listener, ignore if none
|
||||
if ok {
|
||||
seqL.Handle(respHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// listen is used to processes data coming over the RPC channel,
|
||||
// and wrote it to the correct destination based on seq no
|
||||
func (c *RPCClient) listen() {
|
||||
defer c.Close()
|
||||
var respHeader responseHeader
|
||||
for {
|
||||
if err := c.dec.Decode(&respHeader); err != nil {
|
||||
if !c.shutdown {
|
||||
log.Printf("[ERR] agent.client: Failed to decode response header: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
c.respondSeq(respHeader.Seq, &respHeader)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,264 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
"github.com/hashicorp/serf/testutil"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type rpcParts struct {
|
||||
dir string
|
||||
client *RPCClient
|
||||
agent *Agent
|
||||
rpc *AgentRPC
|
||||
}
|
||||
|
||||
func (r *rpcParts) Close() {
|
||||
r.client.Close()
|
||||
r.rpc.Shutdown()
|
||||
r.agent.Shutdown()
|
||||
os.RemoveAll(r.dir)
|
||||
}
|
||||
|
||||
// testRPCClient returns an RPCClient connected to an RPC server that
|
||||
// serves only this connection.
|
||||
func testRPCClient(t *testing.T) *rpcParts {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
lw := NewLogWriter(512)
|
||||
mult := io.MultiWriter(os.Stderr, lw)
|
||||
|
||||
conf := nextConfig()
|
||||
dir, agent := makeAgentLog(t, conf, mult)
|
||||
rpc := NewAgentRPC(agent, l, mult, lw)
|
||||
|
||||
rpcClient, err := NewRPCClient(l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
return &rpcParts{
|
||||
dir: dir,
|
||||
client: rpcClient,
|
||||
agent: agent,
|
||||
rpc: rpc,
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCClientForceLeave(t *testing.T) {
|
||||
p1 := testRPCClient(t)
|
||||
p2 := testRPCClient(t)
|
||||
defer p1.Close()
|
||||
defer p2.Close()
|
||||
testutil.Yield()
|
||||
|
||||
s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfLanPort)
|
||||
if _, err := p1.agent.JoinLAN([]string{s2Addr}); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
testutil.Yield()
|
||||
|
||||
if err := p2.agent.Shutdown(); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
if err := p1.client.ForceLeave(p2.agent.config.NodeName); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
testutil.Yield()
|
||||
|
||||
m := p1.agent.LANMembers()
|
||||
if len(m) != 2 {
|
||||
t.Fatalf("should have 2 members: %#v", m)
|
||||
}
|
||||
|
||||
if m[1].Status != serf.StatusLeft {
|
||||
t.Fatalf("should be left: %#v %v", m[1], m[1].Status == serf.StatusLeft)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCClientJoinLAN(t *testing.T) {
|
||||
p1 := testRPCClient(t)
|
||||
p2 := testRPCClient(t)
|
||||
defer p1.Close()
|
||||
defer p2.Close()
|
||||
testutil.Yield()
|
||||
|
||||
s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfLanPort)
|
||||
n, err := p1.client.Join([]string{s2Addr}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if n != 1 {
|
||||
t.Fatalf("n != 1: %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCClientJoinWAN(t *testing.T) {
|
||||
p1 := testRPCClient(t)
|
||||
p2 := testRPCClient(t)
|
||||
defer p1.Close()
|
||||
defer p2.Close()
|
||||
testutil.Yield()
|
||||
|
||||
s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfWanPort)
|
||||
n, err := p1.client.Join([]string{s2Addr}, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if n != 1 {
|
||||
t.Fatalf("n != 1: %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCClientLANMembers(t *testing.T) {
|
||||
p1 := testRPCClient(t)
|
||||
p2 := testRPCClient(t)
|
||||
defer p1.Close()
|
||||
defer p2.Close()
|
||||
testutil.Yield()
|
||||
|
||||
mem, err := p1.client.LANMembers()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if len(mem) != 1 {
|
||||
t.Fatalf("bad: %#v", mem)
|
||||
}
|
||||
|
||||
s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfLanPort)
|
||||
_, err = p1.client.Join([]string{s2Addr}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
testutil.Yield()
|
||||
|
||||
mem, err = p1.client.LANMembers()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if len(mem) != 2 {
|
||||
t.Fatalf("bad: %#v", mem)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCClientWANMembers(t *testing.T) {
|
||||
p1 := testRPCClient(t)
|
||||
p2 := testRPCClient(t)
|
||||
defer p1.Close()
|
||||
defer p2.Close()
|
||||
testutil.Yield()
|
||||
|
||||
mem, err := p1.client.WANMembers()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if len(mem) != 1 {
|
||||
t.Fatalf("bad: %#v", mem)
|
||||
}
|
||||
|
||||
s2Addr := fmt.Sprintf("127.0.0.1:%d", p2.agent.config.SerfWanPort)
|
||||
_, err = p1.client.Join([]string{s2Addr}, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
testutil.Yield()
|
||||
|
||||
mem, err = p1.client.WANMembers()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if len(mem) != 2 {
|
||||
t.Fatalf("bad: %#v", mem)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCClientLeave(t *testing.T) {
|
||||
p1 := testRPCClient(t)
|
||||
defer p1.Close()
|
||||
testutil.Yield()
|
||||
|
||||
if err := p1.client.Leave(); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
testutil.Yield()
|
||||
|
||||
select {
|
||||
case <-p1.agent.ShutdownCh():
|
||||
default:
|
||||
t.Fatalf("agent should be shutdown!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCClientMonitor(t *testing.T) {
|
||||
p1 := testRPCClient(t)
|
||||
defer p1.Close()
|
||||
testutil.Yield()
|
||||
|
||||
eventCh := make(chan string, 64)
|
||||
if handle, err := p1.client.Monitor("debug", eventCh); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
} else {
|
||||
defer p1.client.Stop(handle)
|
||||
}
|
||||
|
||||
testutil.Yield()
|
||||
|
||||
found := false
|
||||
OUTER1:
|
||||
for {
|
||||
select {
|
||||
case e := <-eventCh:
|
||||
if strings.Contains(e, "Accepted client") {
|
||||
found = true
|
||||
}
|
||||
default:
|
||||
break OUTER1
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("should log client accept")
|
||||
}
|
||||
|
||||
// Join a bad thing to generate more events
|
||||
p1.agent.JoinLAN(nil)
|
||||
testutil.Yield()
|
||||
|
||||
found = false
|
||||
OUTER2:
|
||||
for {
|
||||
select {
|
||||
case e := <-eventCh:
|
||||
if strings.Contains(e, "joining") {
|
||||
found = true
|
||||
}
|
||||
default:
|
||||
break OUTER2
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("should log joining")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/logutils"
|
||||
"log"
|
||||
)
|
||||
|
||||
type streamClient interface {
|
||||
Send(*responseHeader, interface{}) error
|
||||
}
|
||||
|
||||
// logStream is used to stream logs to a client over RPC
|
||||
type logStream struct {
|
||||
client streamClient
|
||||
filter *logutils.LevelFilter
|
||||
logCh chan string
|
||||
logger *log.Logger
|
||||
seq uint64
|
||||
}
|
||||
|
||||
func newLogStream(client streamClient, filter *logutils.LevelFilter,
|
||||
seq uint64, logger *log.Logger) *logStream {
|
||||
ls := &logStream{
|
||||
client: client,
|
||||
filter: filter,
|
||||
logCh: make(chan string, 512),
|
||||
logger: logger,
|
||||
seq: seq,
|
||||
}
|
||||
go ls.stream()
|
||||
return ls
|
||||
}
|
||||
|
||||
func (ls *logStream) HandleLog(l string) {
|
||||
// Check the log level
|
||||
if !ls.filter.Check([]byte(l)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Do a non-blocking send
|
||||
select {
|
||||
case ls.logCh <- l:
|
||||
default:
|
||||
// We can't log syncronously, since we are already being invoked
|
||||
// from the logWriter, and a log will need to invoke Write() which
|
||||
// already holds the lock. We must therefor do the log async, so
|
||||
// as to not deadlock
|
||||
go ls.logger.Printf("[WARN] Dropping logs to %v", ls.client)
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *logStream) Stop() {
|
||||
close(ls.logCh)
|
||||
}
|
||||
|
||||
func (ls *logStream) stream() {
|
||||
header := responseHeader{Seq: ls.seq, Error: ""}
|
||||
rec := logRecord{Log: ""}
|
||||
|
||||
for line := range ls.logCh {
|
||||
rec.Log = line
|
||||
if err := ls.client.Send(&header, &rec); err != nil {
|
||||
ls.logger.Printf("[ERR] Failed to stream log to %v: %v",
|
||||
ls.client, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/logutils"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MockStreamClient struct {
|
||||
headers []*responseHeader
|
||||
objs []interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *MockStreamClient) Send(h *responseHeader, o interface{}) error {
|
||||
m.headers = append(m.headers, h)
|
||||
m.objs = append(m.objs, o)
|
||||
return m.err
|
||||
}
|
||||
|
||||
func TestRPCLogStream(t *testing.T) {
|
||||
sc := &MockStreamClient{}
|
||||
filter := LevelFilter()
|
||||
filter.MinLevel = logutils.LogLevel("INFO")
|
||||
|
||||
ls := newLogStream(sc, filter, 42, log.New(os.Stderr, "", log.LstdFlags))
|
||||
defer ls.Stop()
|
||||
|
||||
log := "[DEBUG] this is a test log"
|
||||
log2 := "[INFO] This should pass"
|
||||
ls.HandleLog(log)
|
||||
ls.HandleLog(log2)
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if len(sc.headers) != 1 {
|
||||
t.Fatalf("expected 1 messages!")
|
||||
}
|
||||
for _, h := range sc.headers {
|
||||
if h.Seq != 42 {
|
||||
t.Fatalf("bad seq")
|
||||
}
|
||||
if h.Error != "" {
|
||||
t.Fatalf("bad err")
|
||||
}
|
||||
}
|
||||
|
||||
obj1 := sc.objs[0].(*logRecord)
|
||||
if obj1.Log != log2 {
|
||||
t.Fatalf("bad event %#v", obj1)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue