From 7bc7d4cd4ff5edbded3af3880512d0fa4fec2ef1 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 11 Dec 2013 15:34:10 -0800 Subject: [PATCH] Adding support for deregistration --- consul/catalog_endpoint.go | 10 ++- consul/catalog_endpoint_test.go | 26 +++++++ consul/fsm.go | 17 ++++ consul/fsm_test.go | 134 ++++++++++++++++++++++++++++++++ consul/state_store.go | 42 ++++++++-- consul/state_store_test.go | 57 ++++++++++++++ 6 files changed, 278 insertions(+), 8 deletions(-) diff --git a/consul/catalog_endpoint.go b/consul/catalog_endpoint.go index 7d4639ac1e..285015483d 100644 --- a/consul/catalog_endpoint.go +++ b/consul/catalog_endpoint.go @@ -27,7 +27,6 @@ func (c *Catalog) Register(args *rpc.RegisterRequest, reply *struct{}) error { return err } - // Run it through raft _, err := c.srv.raftApply(rpc.RegisterRequestType, args) if err != nil { c.srv.logger.Printf("[ERR] Register failed: %v", err) @@ -38,5 +37,14 @@ func (c *Catalog) Register(args *rpc.RegisterRequest, reply *struct{}) error { // Deregister is used to remove a service registration for a given node. func (c *Catalog) Deregister(args *rpc.DeregisterRequest, reply *struct{}) error { + if done, err := c.srv.forward("Catalog.Deregister", args.Datacenter, args, reply); done { + return err + } + + _, err := c.srv.raftApply(rpc.DeregisterRequestType, args) + if err != nil { + c.srv.logger.Printf("[ERR] Deregister failed: %v", err) + return err + } return nil } diff --git a/consul/catalog_endpoint_test.go b/consul/catalog_endpoint_test.go index 5383377dae..14920a17e7 100644 --- a/consul/catalog_endpoint_test.go +++ b/consul/catalog_endpoint_test.go @@ -36,3 +36,29 @@ func TestCatalogRegister(t *testing.T) { t.Fatalf("err: %v", err) } } + +func TestCatalogDeregister(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + client := rpcClient(t, s1) + defer client.Close() + + arg := rpc.DeregisterRequest{ + Datacenter: "dc1", + Node: "foo", + } + var out struct{} + + err := client.Call("Catalog.Deregister", &arg, &out) + if err == nil || err.Error() != "No cluster leader" { + t.Fatalf("err: %v", err) + } + + // Wait for leader + time.Sleep(100 * time.Millisecond) + + if err := client.Call("Catalog.Deregister", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } +} diff --git a/consul/fsm.go b/consul/fsm.go index db83a37f7b..eb3ff1b62b 100644 --- a/consul/fsm.go +++ b/consul/fsm.go @@ -38,6 +38,8 @@ func (c *consulFSM) Apply(buf []byte) interface{} { switch rpc.MessageType(buf[0]) { case rpc.RegisterRequestType: return c.applyRegister(buf[1:]) + case rpc.DeregisterRequestType: + return c.applyDeregister(buf[1:]) default: panic(fmt.Errorf("failed to apply request: %#v", buf)) } @@ -59,6 +61,21 @@ func (c *consulFSM) applyRegister(buf []byte) interface{} { return nil } +func (c *consulFSM) applyDeregister(buf []byte) interface{} { + var req rpc.DeregisterRequest + if err := rpc.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + + // Either remove the service entry or the whole node + if req.ServiceName != "" { + c.state.DeleteNodeService(req.Node, req.ServiceName) + } else { + c.state.DeleteNode(req.Node) + } + return nil +} + func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) { snap := &consulSnapshot{fsm: c} return snap, nil diff --git a/consul/fsm_test.go b/consul/fsm_test.go index 40df4856c9..080479a0fb 100644 --- a/consul/fsm_test.go +++ b/consul/fsm_test.go @@ -11,6 +11,39 @@ func TestFSM_RegisterNode(t *testing.T) { t.Fatalf("err: %v", err) } + req := rpc.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + } + buf, err := rpc.Encode(rpc.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(buf) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are registered + if found, _ := fsm.state.GetNode("foo"); !found { + t.Fatalf("not found!") + } + + // Verify service registered + services := fsm.state.NodeServices("foo") + if len(services) != 0 { + t.Fatalf("Services: %v", services) + } +} + +func TestFSM_RegisterNode_Service(t *testing.T) { + fsm, err := NewFSM() + if err != nil { + t.Fatalf("err: %v", err) + } + req := rpc.RegisterRequest{ Datacenter: "dc1", Node: "foo", @@ -40,3 +73,104 @@ func TestFSM_RegisterNode(t *testing.T) { t.Fatalf("not registered!") } } + +func TestFSM_DeregisterService(t *testing.T) { + fsm, err := NewFSM() + if err != nil { + t.Fatalf("err: %v", err) + } + + req := rpc.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + ServiceName: "db", + ServiceTag: "master", + ServicePort: 8000, + } + buf, err := rpc.Encode(rpc.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(buf) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + dereg := rpc.DeregisterRequest{ + Datacenter: "dc1", + Node: "foo", + ServiceName: "db", + } + buf, err = rpc.Encode(rpc.DeregisterRequestType, dereg) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp = fsm.Apply(buf) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are registered + if found, _ := fsm.state.GetNode("foo"); !found { + t.Fatalf("not found!") + } + + // Verify service not registered + services := fsm.state.NodeServices("foo") + if _, ok := services["db"]; ok { + t.Fatalf("db registered!") + } +} + +func TestFSM_DeregisterNode(t *testing.T) { + fsm, err := NewFSM() + if err != nil { + t.Fatalf("err: %v", err) + } + + req := rpc.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + ServiceName: "db", + ServiceTag: "master", + ServicePort: 8000, + } + buf, err := rpc.Encode(rpc.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(buf) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + dereg := rpc.DeregisterRequest{ + Datacenter: "dc1", + Node: "foo", + } + buf, err = rpc.Encode(rpc.DeregisterRequestType, dereg) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp = fsm.Apply(buf) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are registered + if found, _ := fsm.state.GetNode("foo"); found { + t.Fatalf("found!") + } + + // Verify service not registered + services := fsm.state.NodeServices("foo") + if len(services) != 0 { + t.Fatalf("Services: %v", services) + } +} diff --git a/consul/state_store.go b/consul/state_store.go index f7bbe02733..cd18c82f5b 100644 --- a/consul/state_store.go +++ b/consul/state_store.go @@ -14,6 +14,8 @@ const ( queryNodes queryEnsureService queryNodeServices + queryDeleteNodeService + queryDeleteNode ) // NoodeServices maps the Service name to a tag and port @@ -67,6 +69,7 @@ func (s *StateStore) initialize() error { // Set the pragma first pragmas := []string{ "pragma journal_mode=memory;", + "pragma foreign_keys=ON;", } for _, p := range pragmas { if _, err := s.db.Exec(p); err != nil { @@ -77,8 +80,8 @@ func (s *StateStore) initialize() error { // Create the tables tables := []string{ `CREATE TABLE nodes (name text unique, address text);`, - `CREATE TABLE services (node text references nodes, service text, tag text, port integer);`, - `CREATE INDEX servName on services(service);`, + `CREATE TABLE services (node text REFERENCES nodes(name) ON DELETE CASCADE, service text, tag text, port integer);`, + `CREATE INDEX servName ON services(service);`, } for _, t := range tables { if _, err := s.db.Exec(t); err != nil { @@ -88,11 +91,13 @@ func (s *StateStore) initialize() error { // Prepare the queries queries := map[namedQuery]string{ - queryEnsureNode: "INSERT OR REPLACE INTO nodes (name, address) VALUES (?, ?)", - queryNode: "SELECT address FROM nodes where name=?", - queryNodes: "SELECT * FROM nodes", - queryEnsureService: "INSERT OR REPLACE INTO services (node, service, tag, port) VALUES (?, ?, ?, ?)", - queryNodeServices: "SELECT service, tag, port from services where node=?", + queryEnsureNode: "INSERT OR REPLACE INTO nodes (name, address) VALUES (?, ?)", + queryNode: "SELECT address FROM nodes where name=?", + queryNodes: "SELECT * FROM nodes", + queryEnsureService: "INSERT OR REPLACE INTO services (node, service, tag, port) VALUES (?, ?, ?, ?)", + queryNodeServices: "SELECT service, tag, port from services where node=?", + queryDeleteNodeService: "DELETE FROM services WHERE node=? AND service=?", + queryDeleteNode: "DELETE FROM nodes WHERE name=?", } for name, query := range queries { stmt, err := s.db.Prepare(query) @@ -118,6 +123,17 @@ func (s *StateStore) checkSet(res sql.Result, err error) error { return nil } +func (s *StateStore) checkDelete(res sql.Result, err error) error { + if err != nil { + return err + } + _, err = res.RowsAffected() + if err != nil { + return err + } + return nil +} + // EnsureNode is used to ensure a given node exists, with the provided address func (s *StateStore) EnsureNode(name string, address string) error { stmt := s.prepared[queryEnsureNode] @@ -186,3 +202,15 @@ func (s *StateStore) NodeServices(name string) NodeServices { return services } + +// DeleteNodeService is used to delete a node service +func (s *StateStore) DeleteNodeService(node, service string) error { + stmt := s.prepared[queryDeleteNodeService] + return s.checkDelete(stmt.Exec(node, service)) +} + +// DeleteNode is used to delete a node and all it's services +func (s *StateStore) DeleteNode(node string) error { + stmt := s.prepared[queryDeleteNode] + return s.checkDelete(stmt.Exec(node)) +} diff --git a/consul/state_store_test.go b/consul/state_store_test.go index bbae9936ec..cc3a8c915f 100644 --- a/consul/state_store_test.go +++ b/consul/state_store_test.go @@ -95,3 +95,60 @@ func TestEnsureService(t *testing.T) { t.Fatalf("Bad entry: %#v", entry) } } + +func TestDeleteNodeService(t *testing.T) { + store, err := NewStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + if err := store.EnsureNode("foo", "127.0.0.1"); err != nil { + t.Fatalf("err: %v") + } + + if err := store.EnsureService("foo", "api", "", 5000); err != nil { + t.Fatalf("err: %v") + } + + if err := store.DeleteNodeService("foo", "api"); err != nil { + t.Fatalf("err: %v") + } + + services := store.NodeServices("foo") + _, ok := services["api"] + if ok { + t.Fatalf("has api: %#v", services) + } +} + +func TestDeleteNode(t *testing.T) { + store, err := NewStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + if err := store.EnsureNode("foo", "127.0.0.1"); err != nil { + t.Fatalf("err: %v") + } + + if err := store.EnsureService("foo", "api", "", 5000); err != nil { + t.Fatalf("err: %v") + } + + if err := store.DeleteNode("foo"); err != nil { + t.Fatalf("err: %v") + } + + services := store.NodeServices("foo") + _, ok := services["api"] + if ok { + t.Fatalf("has api: %#v", services) + } + + found, _ := store.GetNode("foo") + if found { + t.Fatalf("found node") + } +}