mirror of https://github.com/status-im/consul.git
Factor out duplicate functions into a lib package
Consolidate code duplication and tests into a single lib package. Most of these functions were from various **/util.go functions that couldn't be imported due to cyclic imports. The consul/lib package is intended to be a terminal node in an import DAG and a place to stash various consul-only helper functions. Pulled in hashicorp/go-uuid instead of consolidating UUID access.
This commit is contained in:
parent
40707934d2
commit
7af6a94edb
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/hashicorp/consul/consul"
|
||||
"github.com/hashicorp/consul/consul/state"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
)
|
||||
|
@ -600,8 +601,8 @@ func (a *Agent) sendCoordinate() {
|
|||
for {
|
||||
rate := a.config.SyncCoordinateRateTarget
|
||||
min := a.config.SyncCoordinateIntervalMin
|
||||
intv := rateScaledInterval(rate, min, len(a.LANMembers()))
|
||||
intv = intv + randomStagger(intv)
|
||||
intv := lib.RateScaledInterval(rate, min, len(a.LANMembers()))
|
||||
intv = intv + lib.RandomStagger(intv)
|
||||
|
||||
select {
|
||||
case <-time.After(intv):
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/armon/circbuf"
|
||||
docker "github.com/fsouza/go-dockerclient"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
)
|
||||
|
||||
|
@ -131,7 +132,7 @@ func (c *CheckMonitor) Stop() {
|
|||
// run is invoked by a goroutine to run until Stop() is called
|
||||
func (c *CheckMonitor) run() {
|
||||
// Get the randomized initial pause time
|
||||
initialPauseTime := randomStagger(c.Interval)
|
||||
initialPauseTime := lib.RandomStagger(c.Interval)
|
||||
c.Logger.Printf("[DEBUG] agent: pausing %v before first invocation of %s", initialPauseTime, c.Script)
|
||||
next := time.After(initialPauseTime)
|
||||
for {
|
||||
|
@ -366,7 +367,7 @@ func (c *CheckHTTP) Stop() {
|
|||
// run is invoked by a goroutine to run until Stop() is called
|
||||
func (c *CheckHTTP) run() {
|
||||
// Get the randomized initial pause time
|
||||
initialPauseTime := randomStagger(c.Interval)
|
||||
initialPauseTime := lib.RandomStagger(c.Interval)
|
||||
c.Logger.Printf("[DEBUG] agent: pausing %v before first HTTP request of %s", initialPauseTime, c.HTTP)
|
||||
next := time.After(initialPauseTime)
|
||||
for {
|
||||
|
@ -482,7 +483,7 @@ func (c *CheckTCP) Stop() {
|
|||
// run is invoked by a goroutine to run until Stop() is called
|
||||
func (c *CheckTCP) run() {
|
||||
// Get the randomized initial pause time
|
||||
initialPauseTime := randomStagger(c.Interval)
|
||||
initialPauseTime := lib.RandomStagger(c.Interval)
|
||||
c.Logger.Printf("[DEBUG] agent: pausing %v before first socket connection of %s", initialPauseTime, c.TCP)
|
||||
next := time.After(initialPauseTime)
|
||||
for {
|
||||
|
@ -580,7 +581,7 @@ func (c *CheckDocker) Stop() {
|
|||
// run is invoked by a goroutine to run until Stop() is called
|
||||
func (c *CheckDocker) run() {
|
||||
// Get the randomized initial pause time
|
||||
initialPauseTime := randomStagger(c.Interval)
|
||||
initialPauseTime := lib.RandomStagger(c.Interval)
|
||||
c.Logger.Printf("[DEBUG] agent: pausing %v before first invocation of %s -c %s in container %s", initialPauseTime, c.Shell, c.Script, c.DockerContainerID)
|
||||
next := time.After(initialPauseTime)
|
||||
for {
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/armon/go-metrics/datadog"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/watch"
|
||||
"github.com/hashicorp/go-checkpoint"
|
||||
"github.com/hashicorp/go-reap"
|
||||
|
@ -424,7 +425,7 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
|
|||
|
||||
// Do an immediate check within the next 30 seconds
|
||||
go func() {
|
||||
time.Sleep(randomStagger(30 * time.Second))
|
||||
time.Sleep(lib.RandomStagger(30 * time.Second))
|
||||
c.checkpointResults(checkpoint.Check(updateParams))
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/watch"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
@ -634,7 +635,7 @@ func DecodeConfig(r io.Reader) (*Config, error) {
|
|||
allowedKeys := []string{"service", "services", "check", "checks"}
|
||||
var unused []string
|
||||
for _, field := range md.Unused {
|
||||
if !strContains(allowedKeys, field) {
|
||||
if !lib.StrContains(allowedKeys, field) {
|
||||
unused = append(unused, field)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/lib"
|
||||
)
|
||||
|
||||
func TestConfigEncryptBytes(t *testing.T) {
|
||||
|
@ -1103,7 +1105,7 @@ func TestDecodeConfig_Service(t *testing.T) {
|
|||
t.Fatalf("bad: %v", serv)
|
||||
}
|
||||
|
||||
if !strContains(serv.Tags, "master") {
|
||||
if !lib.StrContains(serv.Tags, "master") {
|
||||
t.Fatalf("bad: %v", serv)
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/consul/consul"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -252,7 +253,7 @@ func (l *localState) UpdateCheck(checkID, status, output string) {
|
|||
if l.config.CheckUpdateInterval > 0 && check.Status == status {
|
||||
check.Output = output
|
||||
if _, ok := l.deferCheck[checkID]; !ok {
|
||||
intv := time.Duration(uint64(l.config.CheckUpdateInterval)/2) + randomStagger(l.config.CheckUpdateInterval)
|
||||
intv := time.Duration(uint64(l.config.CheckUpdateInterval)/2) + lib.RandomStagger(l.config.CheckUpdateInterval)
|
||||
deferSync := time.AfterFunc(intv, func() {
|
||||
l.Lock()
|
||||
if _, ok := l.checkStatus[checkID]; ok {
|
||||
|
@ -302,11 +303,11 @@ SYNC:
|
|||
case <-l.consulCh:
|
||||
// Stagger the retry on leader election, avoid a thundering heard
|
||||
select {
|
||||
case <-time.After(randomStagger(aeScale(syncStaggerIntv, len(l.iface.LANMembers())))):
|
||||
case <-time.After(lib.RandomStagger(aeScale(syncStaggerIntv, len(l.iface.LANMembers())))):
|
||||
case <-shutdownCh:
|
||||
return
|
||||
}
|
||||
case <-time.After(syncRetryIntv + randomStagger(aeScale(syncRetryIntv, len(l.iface.LANMembers())))):
|
||||
case <-time.After(syncRetryIntv + lib.RandomStagger(aeScale(syncRetryIntv, len(l.iface.LANMembers())))):
|
||||
case <-shutdownCh:
|
||||
return
|
||||
}
|
||||
|
@ -317,7 +318,7 @@ SYNC:
|
|||
|
||||
// Schedule the next full sync, with a random stagger
|
||||
aeIntv := aeScale(l.config.AEInterval, len(l.iface.LANMembers()))
|
||||
aeIntv = aeIntv + randomStagger(aeIntv)
|
||||
aeIntv = aeIntv + lib.RandomStagger(aeIntv)
|
||||
aeTimer := time.After(aeIntv)
|
||||
|
||||
// Wait for sync events
|
||||
|
|
|
@ -10,8 +10,17 @@ import (
|
|||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
func generateUUID() (ret string) {
|
||||
var err error
|
||||
if ret, err = uuid.GenerateUUID(); err != nil {
|
||||
return "DEADC0DE-BADD-CAFE-D00D-FEEDFACECAFE"
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func TestRexecWriter(t *testing.T) {
|
||||
writer := &rexecWriter{
|
||||
BufCh: make(chan []byte, 16),
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"regexp"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -78,7 +79,10 @@ func (a *Agent) UserEvent(dc, token string, params *UserEvent) error {
|
|||
}
|
||||
|
||||
// Format message
|
||||
params.ID = generateUUID()
|
||||
var err error
|
||||
if params.ID, err = uuid.GenerateUUID(); err != nil {
|
||||
return fmt.Errorf("UUID generation failed: %v", err)
|
||||
}
|
||||
params.Version = userEventMaxVersion
|
||||
payload, err := encodeMsgPack(¶ms)
|
||||
if err != nil {
|
||||
|
|
|
@ -3,10 +3,8 @@ package agent
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
crand "crypto/rand"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
@ -39,32 +37,6 @@ func aeScale(interval time.Duration, n int) time.Duration {
|
|||
return time.Duration(multiplier) * interval
|
||||
}
|
||||
|
||||
// rateScaledInterval is used to choose an interval to perform an action in order
|
||||
// to target an aggregate number of actions per second across the whole cluster.
|
||||
func rateScaledInterval(rate float64, min time.Duration, n int) time.Duration {
|
||||
interval := time.Duration(float64(time.Second) * float64(n) / rate)
|
||||
if interval < min {
|
||||
return min
|
||||
}
|
||||
|
||||
return interval
|
||||
}
|
||||
|
||||
// Returns a random stagger interval between 0 and the duration
|
||||
func randomStagger(intv time.Duration) time.Duration {
|
||||
return time.Duration(uint64(rand.Int63()) % uint64(intv))
|
||||
}
|
||||
|
||||
// strContains checks if a list contains a string
|
||||
func strContains(l []string, s string) bool {
|
||||
for _, v := range l {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ExecScript returns a command to execute a script
|
||||
func ExecScript(script string) (*exec.Cmd, error) {
|
||||
var shell, flag string
|
||||
|
@ -82,21 +54,6 @@ func ExecScript(script string) (*exec.Cmd, error) {
|
|||
return cmd, nil
|
||||
}
|
||||
|
||||
// generateUUID is used to generate a random UUID
|
||||
func generateUUID() string {
|
||||
buf := make([]byte, 16)
|
||||
if _, err := crand.Read(buf); err != nil {
|
||||
panic(fmt.Errorf("failed to read random bytes: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%12x",
|
||||
buf[0:4],
|
||||
buf[4:6],
|
||||
buf[6:8],
|
||||
buf[8:10],
|
||||
buf[10:16])
|
||||
}
|
||||
|
||||
// decodeMsgPack is used to decode a MsgPack encoded object
|
||||
func decodeMsgPack(buf []byte, out interface{}) error {
|
||||
return codec.NewDecoder(bytes.NewReader(buf), msgpackHandle).Decode(out)
|
||||
|
|
|
@ -24,39 +24,6 @@ func TestAEScale(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRateScaledInterval(t *testing.T) {
|
||||
min := 1 * time.Second
|
||||
rate := 200.0
|
||||
if v := rateScaledInterval(rate, min, 0); v != min {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := rateScaledInterval(rate, min, 100); v != min {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := rateScaledInterval(rate, min, 200); v != 1*time.Second {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := rateScaledInterval(rate, min, 1000); v != 5*time.Second {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := rateScaledInterval(rate, min, 5000); v != 25*time.Second {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := rateScaledInterval(rate, min, 10000); v != 50*time.Second {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomStagger(t *testing.T) {
|
||||
intv := time.Minute
|
||||
for i := 0; i < 10; i++ {
|
||||
stagger := randomStagger(intv)
|
||||
if stagger < 0 || stagger >= intv {
|
||||
t.Fatalf("Bad: %v", stagger)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringHash(t *testing.T) {
|
||||
in := "hello world"
|
||||
expected := "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
// ACL endpoint is used to manipulate ACLs
|
||||
|
@ -62,7 +63,11 @@ func (a *ACL) Apply(args *structs.ACLRequest, reply *string) error {
|
|||
if args.ACL.ID == "" {
|
||||
state := a.srv.fsm.State()
|
||||
for {
|
||||
args.ACL.ID = generateUUID()
|
||||
if args.ACL.ID, err = uuid.GenerateUUID(); err != nil {
|
||||
a.srv.logger.Printf("[ERR] consul.acl: UUID generation failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, acl, err := state.ACLGet(args.ACL.ID)
|
||||
if err != nil {
|
||||
a.srv.logger.Printf("[ERR] consul.acl: ACL lookup failed: %v", err)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
)
|
||||
|
@ -436,7 +437,7 @@ func TestACLEndpoint_List(t *testing.T) {
|
|||
if s.ID == anonymousToken || s.ID == "root" {
|
||||
continue
|
||||
}
|
||||
if !strContains(ids, s.ID) {
|
||||
if !lib.StrContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Name != "User token" {
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
)
|
||||
|
@ -978,7 +979,7 @@ func TestCatalogNodeServices(t *testing.T) {
|
|||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
services := out.NodeServices.Services
|
||||
if !strContains(services["db"].Tags, "primary") || services["db"].Port != 5000 {
|
||||
if !lib.StrContains(services["db"].Tags, "primary") || services["db"].Port != 5000 {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
if len(services["web"].Tags) != 0 || services["web"].Port != 80 {
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
|
||||
"github.com/hashicorp/consul/consul/state"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/raft"
|
||||
)
|
||||
|
||||
|
@ -38,6 +40,14 @@ func makeLog(buf []byte) *raft.Log {
|
|||
}
|
||||
}
|
||||
|
||||
func generateUUID() (ret string) {
|
||||
var err error
|
||||
if ret, err = uuid.GenerateUUID(); err != nil {
|
||||
return "DEADC0DE-BADD-CAFE-D00D-FEEDFACECAFE"
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func TestFSM_RegisterNode(t *testing.T) {
|
||||
fsm, err := NewFSM(nil, os.Stderr)
|
||||
if err != nil {
|
||||
|
@ -452,7 +462,7 @@ func TestFSM_SnapshotRestore(t *testing.T) {
|
|||
if len(fooSrv.Services) != 2 {
|
||||
t.Fatalf("Bad: %v", fooSrv)
|
||||
}
|
||||
if !strContains(fooSrv.Services["db"].Tags, "primary") {
|
||||
if !lib.StrContains(fooSrv.Services["db"].Tags, "primary") {
|
||||
t.Fatalf("Bad: %v", fooSrv)
|
||||
}
|
||||
if fooSrv.Services["db"].Port != 5000 {
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
)
|
||||
|
@ -377,10 +378,10 @@ func TestHealth_ServiceNodes(t *testing.T) {
|
|||
if nodes[1].Node.Node != "foo" {
|
||||
t.Fatalf("Bad: %v", nodes[1])
|
||||
}
|
||||
if !strContains(nodes[0].Service.Tags, "slave") {
|
||||
if !lib.StrContains(nodes[0].Service.Tags, "slave") {
|
||||
t.Fatalf("Bad: %v", nodes[0])
|
||||
}
|
||||
if !strContains(nodes[1].Service.Tags, "master") {
|
||||
if !lib.StrContains(nodes[1].Service.Tags, "master") {
|
||||
t.Fatalf("Bad: %v", nodes[1])
|
||||
}
|
||||
if nodes[0].Checks[0].Status != structs.HealthWarning {
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
)
|
||||
|
@ -56,7 +57,7 @@ func TestInternal_NodeInfo(t *testing.T) {
|
|||
if nodes[0].Node != "foo" {
|
||||
t.Fatalf("Bad: %v", nodes[0])
|
||||
}
|
||||
if !strContains(nodes[0].Services[0].Tags, "master") {
|
||||
if !lib.StrContains(nodes[0].Services[0].Tags, "master") {
|
||||
t.Fatalf("Bad: %v", nodes[0])
|
||||
}
|
||||
if nodes[0].Checks[0].Status != structs.HealthPassing {
|
||||
|
@ -130,7 +131,7 @@ func TestInternal_NodeDump(t *testing.T) {
|
|||
switch node.Node {
|
||||
case "foo":
|
||||
foundFoo = true
|
||||
if !strContains(node.Services[0].Tags, "master") {
|
||||
if !lib.StrContains(node.Services[0].Tags, "master") {
|
||||
t.Fatalf("Bad: %v", nodes[0])
|
||||
}
|
||||
if node.Checks[0].Status != structs.HealthPassing {
|
||||
|
@ -139,7 +140,7 @@ func TestInternal_NodeDump(t *testing.T) {
|
|||
|
||||
case "bar":
|
||||
foundBar = true
|
||||
if !strContains(node.Services[0].Tags, "slave") {
|
||||
if !lib.StrContains(node.Services[0].Tags, "slave") {
|
||||
t.Fatalf("Bad: %v", nodes[1])
|
||||
}
|
||||
if node.Checks[0].Status != structs.HealthWarning {
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -41,7 +42,9 @@ func (p *PreparedQuery) Apply(args *structs.PreparedQueryRequest, reply *string)
|
|||
// to collide since this isn't inside a write transaction.
|
||||
state := p.srv.fsm.State()
|
||||
for {
|
||||
args.Query.ID = generateUUID()
|
||||
if args.Query.ID, err = uuid.GenerateUUID(); err != nil {
|
||||
return fmt.Errorf("UUID generation for prepared query failed: %v", err)
|
||||
}
|
||||
_, query, err := state.PreparedQueryGet(args.Query.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Prepared query lookup failed: %v", err)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/consul/state"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/memberlist"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/hashicorp/yamux"
|
||||
|
@ -329,7 +330,7 @@ func (s *Server) blockingRPC(queryOpts *structs.QueryOptions, queryMeta *structs
|
|||
}
|
||||
|
||||
// Apply a small amount of jitter to the request.
|
||||
queryOpts.MaxQueryTime += randomStagger(queryOpts.MaxQueryTime / jitterFraction)
|
||||
queryOpts.MaxQueryTime += lib.RandomStagger(queryOpts.MaxQueryTime / jitterFraction)
|
||||
|
||||
// Setup a query timeout.
|
||||
timeout = time.NewTimer(queryOpts.MaxQueryTime)
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
// Session endpoint is used to manipulate sessions for KV
|
||||
|
@ -61,7 +62,11 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
|
|||
// Generate a new session ID, verify uniqueness
|
||||
state := s.srv.fsm.State()
|
||||
for {
|
||||
args.Session.ID = generateUUID()
|
||||
var err error
|
||||
if args.Session.ID, err = uuid.GenerateUUID(); err != nil {
|
||||
s.srv.logger.Printf("[ERR] consul.session: UUID generation failed: %v", err)
|
||||
return err
|
||||
}
|
||||
_, sess, err := state.SessionGet(args.Session.ID)
|
||||
if err != nil {
|
||||
s.srv.logger.Printf("[ERR] consul.session: Session lookup failed: %v", err)
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
)
|
||||
|
@ -217,7 +218,7 @@ func TestSessionEndpoint_List(t *testing.T) {
|
|||
}
|
||||
for i := 0; i < len(sessions.Sessions); i++ {
|
||||
s := sessions.Sessions[i]
|
||||
if !strContains(ids, s.ID) {
|
||||
if !lib.StrContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
|
@ -318,7 +319,7 @@ func TestSessionEndpoint_Renew(t *testing.T) {
|
|||
}
|
||||
for i := 0; i < len(sessions.Sessions); i++ {
|
||||
s := sessions.Sessions[i]
|
||||
if !strContains(ids, s.ID) {
|
||||
if !lib.StrContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
|
@ -352,7 +353,7 @@ func TestSessionEndpoint_Renew(t *testing.T) {
|
|||
}
|
||||
|
||||
s := session.Sessions[0]
|
||||
if !strContains(ids, s.ID) {
|
||||
if !lib.StrContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
|
@ -379,7 +380,7 @@ func TestSessionEndpoint_Renew(t *testing.T) {
|
|||
|
||||
for i := 0; i < len(sessionsL1.Sessions); i++ {
|
||||
s := sessionsL1.Sessions[i]
|
||||
if !strContains(ids, s.ID) {
|
||||
if !lib.StrContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
|
@ -411,7 +412,7 @@ func TestSessionEndpoint_Renew(t *testing.T) {
|
|||
if len(sessionsL2.Sessions) != 0 {
|
||||
for i := 0; i < len(sessionsL2.Sessions); i++ {
|
||||
s := sessionsL2.Sessions[i]
|
||||
if !strContains(ids, s.ID) {
|
||||
if !lib.StrContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
|
@ -476,7 +477,7 @@ func TestSessionEndpoint_NodeSessions(t *testing.T) {
|
|||
}
|
||||
for i := 0; i < len(sessions.Sessions); i++ {
|
||||
s := sessions.Sessions[i]
|
||||
if !strContains(ids, s.ID) {
|
||||
if !lib.StrContains(ids, s.ID) {
|
||||
t.Fatalf("bad: %v", s)
|
||||
}
|
||||
if s.Node != "foo" {
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
)
|
||||
|
||||
|
@ -1189,16 +1190,6 @@ func TestStateStore_Services(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// strContains checks if a list contains a string
|
||||
func strContains(l []string, s string) bool {
|
||||
for _, v := range l {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestStateStore_ServiceNodes(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
|
@ -1249,7 +1240,7 @@ func TestStateStore_ServiceNodes(t *testing.T) {
|
|||
if nodes[0].ServiceID != "db" {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if !strContains(nodes[0].ServiceTags, "slave") {
|
||||
if !lib.StrContains(nodes[0].ServiceTags, "slave") {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if nodes[0].ServicePort != 8000 {
|
||||
|
@ -1265,7 +1256,7 @@ func TestStateStore_ServiceNodes(t *testing.T) {
|
|||
if nodes[1].ServiceID != "db2" {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if !strContains(nodes[1].ServiceTags, "slave") {
|
||||
if !lib.StrContains(nodes[1].ServiceTags, "slave") {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if nodes[1].ServicePort != 8001 {
|
||||
|
@ -1281,7 +1272,7 @@ func TestStateStore_ServiceNodes(t *testing.T) {
|
|||
if nodes[2].ServiceID != "db" {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if !strContains(nodes[2].ServiceTags, "master") {
|
||||
if !lib.StrContains(nodes[2].ServiceTags, "master") {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if nodes[2].ServicePort != 8000 {
|
||||
|
@ -1328,7 +1319,7 @@ func TestStateStore_ServiceTagNodes(t *testing.T) {
|
|||
if nodes[0].Address != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if !strContains(nodes[0].ServiceTags, "master") {
|
||||
if !lib.StrContains(nodes[0].ServiceTags, "master") {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if nodes[0].ServicePort != 8000 {
|
||||
|
@ -1375,7 +1366,7 @@ func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) {
|
|||
if nodes[0].Address != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if !strContains(nodes[0].ServiceTags, "master") {
|
||||
if !lib.StrContains(nodes[0].ServiceTags, "master") {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if nodes[0].ServicePort != 8000 {
|
||||
|
@ -1409,7 +1400,7 @@ func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) {
|
|||
if nodes[0].Address != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if !strContains(nodes[0].ServiceTags, "dev") {
|
||||
if !lib.StrContains(nodes[0].ServiceTags, "dev") {
|
||||
t.Fatalf("bad: %v", nodes)
|
||||
}
|
||||
if nodes[0].ServicePort != 8001 {
|
||||
|
|
|
@ -1,17 +1,13 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/serf/serf"
|
||||
)
|
||||
|
@ -83,24 +79,6 @@ func init() {
|
|||
privateBlocks[5] = block
|
||||
}
|
||||
|
||||
// strContains checks if a list contains a string
|
||||
func strContains(l []string, s string) bool {
|
||||
for _, v := range l {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ToLowerList(l []string) []string {
|
||||
var out []string
|
||||
for _, value := range l {
|
||||
out = append(out, strings.ToLower(value))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ensurePath is used to make sure a path exists
|
||||
func ensurePath(path string, dir bool) error {
|
||||
if !dir {
|
||||
|
@ -309,23 +287,3 @@ func runtimeStats() map[string]string {
|
|||
"cpu_count": strconv.FormatInt(int64(runtime.NumCPU()), 10),
|
||||
}
|
||||
}
|
||||
|
||||
// generateUUID is used to generate a random UUID
|
||||
func generateUUID() string {
|
||||
buf := make([]byte, 16)
|
||||
if _, err := crand.Read(buf); err != nil {
|
||||
panic(fmt.Errorf("failed to read random bytes: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%12x",
|
||||
buf[0:4],
|
||||
buf[4:6],
|
||||
buf[6:8],
|
||||
buf[8:10],
|
||||
buf[10:16])
|
||||
}
|
||||
|
||||
// Returns a random stagger interval between 0 and the duration
|
||||
func randomStagger(intv time.Duration) time.Duration {
|
||||
return time.Duration(uint64(rand.Int63()) % uint64(intv))
|
||||
}
|
||||
|
|
|
@ -6,30 +6,10 @@ import (
|
|||
"net"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/serf/serf"
|
||||
)
|
||||
|
||||
func TestStrContains(t *testing.T) {
|
||||
l := []string{"a", "b", "c"}
|
||||
if !strContains(l, "b") {
|
||||
t.Fatalf("should contain")
|
||||
}
|
||||
if strContains(l, "d") {
|
||||
t.Fatalf("should not contain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToLowerList(t *testing.T) {
|
||||
l := []string{"ABC", "Abc", "abc"}
|
||||
for _, value := range ToLowerList(l) {
|
||||
if value != "abc" {
|
||||
t.Fatalf("failed lowercasing")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrivateIP(t *testing.T) {
|
||||
ip, _, err := net.ParseCIDR("10.1.2.3/32")
|
||||
if err != nil {
|
||||
|
@ -295,13 +275,3 @@ func TestGenerateUUID(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomStagger(t *testing.T) {
|
||||
intv := time.Minute
|
||||
for i := 0; i < 10; i++ {
|
||||
stagger := randomStagger(intv)
|
||||
if stagger < 0 || stagger >= intv {
|
||||
t.Fatalf("Bad: %v", stagger)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Returns a random stagger interval between 0 and the duration
|
||||
func RandomStagger(intv time.Duration) time.Duration {
|
||||
return time.Duration(uint64(rand.Int63()) % uint64(intv))
|
||||
}
|
||||
|
||||
// RateScaledInterval is used to choose an interval to perform an action in
|
||||
// order to target an aggregate number of actions per second across the whole
|
||||
// cluster.
|
||||
func RateScaledInterval(rate float64, min time.Duration, n int) time.Duration {
|
||||
interval := time.Duration(float64(time.Second) * float64(n) / rate)
|
||||
if interval < min {
|
||||
return min
|
||||
}
|
||||
|
||||
return interval
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRandomStagger(t *testing.T) {
|
||||
intv := time.Minute
|
||||
for i := 0; i < 10; i++ {
|
||||
stagger := RandomStagger(intv)
|
||||
if stagger < 0 || stagger >= intv {
|
||||
t.Fatalf("Bad: %v", stagger)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateScaledInterval(t *testing.T) {
|
||||
min := 1 * time.Second
|
||||
rate := 200.0
|
||||
if v := RateScaledInterval(rate, min, 0); v != min {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := RateScaledInterval(rate, min, 100); v != min {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := RateScaledInterval(rate, min, 200); v != 1*time.Second {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := RateScaledInterval(rate, min, 1000); v != 5*time.Second {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := RateScaledInterval(rate, min, 5000); v != 25*time.Second {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
if v := RateScaledInterval(rate, min, 10000); v != 50*time.Second {
|
||||
t.Fatalf("Bad: %v", v)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// StrContains checks if a list contains a string
|
||||
func StrContains(l []string, s string) bool {
|
||||
for _, v := range l {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ToLowerList(l []string) []string {
|
||||
var out []string
|
||||
for _, value := range l {
|
||||
out = append(out, strings.ToLower(value))
|
||||
}
|
||||
return out
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStrContains(t *testing.T) {
|
||||
l := []string{"a", "b", "c"}
|
||||
if !StrContains(l, "b") {
|
||||
t.Fatalf("should contain")
|
||||
}
|
||||
if StrContains(l, "d") {
|
||||
t.Fatalf("should not contain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToLowerList(t *testing.T) {
|
||||
l := []string{"ABC", "Abc", "abc"}
|
||||
for _, value := range ToLowerList(l) {
|
||||
if value != "abc" {
|
||||
t.Fatalf("failed lowercasing")
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue