diff --git a/consul/client_test.go b/consul/client_test.go index 59785b57b8..a783c3a52c 100644 --- a/consul/client_test.go +++ b/consul/client_test.go @@ -2,12 +2,14 @@ package consul import ( "fmt" - "github.com/hashicorp/consul/consul/structs" - "github.com/hashicorp/consul/testutil" "net" "os" "testing" "time" + + "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/consul/testutil" + "github.com/hashicorp/serf/serf" ) func testClientConfig(t *testing.T, NodeName string) (string, *Config) { @@ -44,6 +46,17 @@ func testClientDC(t *testing.T, dc string) (string, *Client) { return dir, client } +func testClientWithConfig(t *testing.T, cb func(c *Config)) (string, *Client) { + name := fmt.Sprintf("Client %d", getPort()) + dir, config := testClientConfig(t, name) + cb(config) + client, err := NewClient(config) + if err != nil { + t.Fatalf("err: %v", err) + } + return dir, client +} + func TestClient_StartStop(t *testing.T) { dir, client := testClient(t) defer os.RemoveAll(dir) @@ -178,3 +191,81 @@ func TestClient_RPC_TLS(t *testing.T) { t.Fatalf("err: %v", err) }) } + +func TestClientServer_UserEvent(t *testing.T) { + clientOut := make(chan serf.UserEvent, 2) + dir1, c1 := testClientWithConfig(t, func(conf *Config) { + conf.UserEventHandler = func(e serf.UserEvent) { + clientOut <- e + } + }) + defer os.RemoveAll(dir1) + defer c1.Shutdown() + + serverOut := make(chan serf.UserEvent, 2) + dir2, s1 := testServerWithConfig(t, func(conf *Config) { + conf.UserEventHandler = func(e serf.UserEvent) { + serverOut <- e + } + }) + defer os.RemoveAll(dir2) + defer s1.Shutdown() + + // Try to join + addr := fmt.Sprintf("127.0.0.1:%d", + s1.config.SerfLANConfig.MemberlistConfig.BindPort) + if _, err := c1.JoinLAN([]string{addr}); err != nil { + t.Fatalf("err: %v", err) + } + + // Check the members + testutil.WaitForResult(func() (bool, error) { + return len(c1.LANMembers()) == 2 && len(s1.LANMembers()) == 2, nil + }, func(err error) { + t.Fatalf("bad len") + }) + + // Fire the user event + err := c1.UserEvent("foo", []byte("bar")) + if err != nil { + t.Fatalf("err: %v", err) + } + + err = s1.UserEvent("bar", []byte("baz")) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Wait for all the events + var serverFoo, serverBar, clientFoo, clientBar bool + for i := 0; i < 4; i++ { + select { + case e := <-clientOut: + switch e.Name { + case "foo": + clientFoo = true + case "bar": + clientBar = true + default: + t.Fatalf("Bad: %#v", e) + } + + case e := <-serverOut: + switch e.Name { + case "foo": + serverFoo = true + case "bar": + serverBar = true + default: + t.Fatalf("Bad: %#v", e) + } + + case <-time.After(10 * time.Second): + t.Fatalf("timeout") + } + } + + if !(serverFoo && serverBar && clientFoo && clientBar) { + t.Fatalf("missing events") + } +}