diff --git a/.circleci/config.yml b/.circleci/config.yml index 385a263d95..8b0481585d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -108,7 +108,7 @@ jobs: rm -rf /tmp/vault* - run: | PACKAGE_NAMES=$(go list ./... | circleci tests split --split-by=timings --timings-type=classname) - gotestsum --format=short-verbose --junitfile $TEST_RESULTS_DIR/gotestsum-report.xml -- -tags=$GOTAGS -p 3 -cover -coverprofile=cov_$CIRCLE_NODE_INDEX.part $PACKAGE_NAMES + gotestsum --format=short-verbose --junitfile $TEST_RESULTS_DIR/gotestsum-report.xml -- -tags=$GOTAGS -p 2 -cover -coverprofile=cov_$CIRCLE_NODE_INDEX.part $PACKAGE_NAMES # save coverage report parts - persist_to_workspace: diff --git a/agent/consul/filter.go b/agent/consul/filter.go index 572d4ba1e7..10e584af5b 100644 --- a/agent/consul/filter.go +++ b/agent/consul/filter.go @@ -14,7 +14,10 @@ func (d *dirEntFilter) Len() int { return len(d.ent) } func (d *dirEntFilter) Filter(i int) bool { - return d.authorizer.KeyRead(d.ent[i].Key, nil) != acl.Allow + var entCtx acl.EnterpriseAuthorizerContext + d.ent[i].FillAuthzContext(&entCtx) + + return d.authorizer.KeyRead(d.ent[i].Key, &entCtx) != acl.Allow } func (d *dirEntFilter) Move(dst, src, span int) { copy(d.ent[dst:dst+span], d.ent[src:src+span]) @@ -27,30 +30,6 @@ func FilterDirEnt(authorizer acl.Authorizer, ent structs.DirEntries) structs.Dir return ent[:FilterEntries(&df)] } -type keyFilter struct { - authorizer acl.Authorizer - keys []string -} - -func (k *keyFilter) Len() int { - return len(k.keys) -} -func (k *keyFilter) Filter(i int) bool { - // TODO (namespaces) use a real ent authz context here - return k.authorizer.KeyRead(k.keys[i], nil) != acl.Allow -} - -func (k *keyFilter) Move(dst, src, span int) { - copy(k.keys[dst:dst+span], k.keys[src:src+span]) -} - -// FilterKeys is used to filter a list of keys by -// applying an ACL policy -func FilterKeys(authorizer acl.Authorizer, keys []string) []string { - kf := keyFilter{authorizer: authorizer, keys: keys} - return keys[:FilterEntries(&kf)] -} - type txnResultsFilter struct { authorizer acl.Authorizer results structs.TxnResults diff --git a/agent/consul/filter_test.go b/agent/consul/filter_test.go index a575494e4a..15369c1d8f 100644 --- a/agent/consul/filter_test.go +++ b/agent/consul/filter_test.go @@ -50,38 +50,6 @@ func TestFilter_DirEnt(t *testing.T) { } } -func TestFilter_Keys(t *testing.T) { - t.Parallel() - policy, _ := acl.NewPolicyFromSource("", 0, testFilterRules, acl.SyntaxLegacy, nil, nil) - aclR, _ := acl.NewPolicyAuthorizerWithDefaults(acl.DenyAll(), []*acl.Policy{policy}, nil) - - type tcase struct { - in []string - out []string - } - cases := []tcase{ - tcase{ - in: []string{"foo/test", "foo/priv/nope", "foo/other", "zoo"}, - out: []string{"foo/test", "foo/other"}, - }, - tcase{ - in: []string{"abe", "lincoln"}, - out: []string{}, - }, - tcase{ - in: []string{"abe", "foo/1", "foo/2", "foo/3", "nope"}, - out: []string{"foo/1", "foo/2", "foo/3"}, - }, - } - - for _, tc := range cases { - out := FilterKeys(aclR, tc.in) - if !reflect.DeepEqual(out, tc.out) { - t.Fatalf("bad: %#v %#v", out, tc.out) - } - } -} - func TestFilter_TxnResults(t *testing.T) { t.Parallel() policy, _ := acl.NewPolicyFromSource("", 0, testFilterRules, acl.SyntaxLegacy, nil, nil) diff --git a/agent/consul/fsm/commands_oss.go b/agent/consul/fsm/commands_oss.go index 5e48c1a6b3..064df0a3ac 100644 --- a/agent/consul/fsm/commands_oss.go +++ b/agent/consul/fsm/commands_oss.go @@ -93,15 +93,15 @@ func (c *FSM) applyKVSOperation(buf []byte, index uint64) interface{} { case api.KVSet: return c.state.KVSSet(index, &req.DirEnt) case api.KVDelete: - return c.state.KVSDelete(index, req.DirEnt.Key) + return c.state.KVSDelete(index, req.DirEnt.Key, &req.DirEnt.EnterpriseMeta) case api.KVDeleteCAS: - act, err := c.state.KVSDeleteCAS(index, req.DirEnt.ModifyIndex, req.DirEnt.Key) + act, err := c.state.KVSDeleteCAS(index, req.DirEnt.ModifyIndex, req.DirEnt.Key, &req.DirEnt.EnterpriseMeta) if err != nil { return err } return act case api.KVDeleteTree: - return c.state.KVSDeleteTree(index, req.DirEnt.Key) + return c.state.KVSDeleteTree(index, req.DirEnt.Key, &req.DirEnt.EnterpriseMeta) case api.KVCAS: act, err := c.state.KVSSetCAS(index, &req.DirEnt) if err != nil { @@ -141,7 +141,7 @@ func (c *FSM) applySessionOperation(buf []byte, index uint64) interface{} { } return req.Session.ID case structs.SessionDestroy: - return c.state.SessionDestroy(index, req.Session.ID) + return c.state.SessionDestroy(index, req.Session.ID, &req.Session.EnterpriseMeta) default: c.logger.Printf("[WARN] consul.fsm: Invalid Session operation '%s'", req.Op) return fmt.Errorf("Invalid Session operation '%s'", req.Op) diff --git a/agent/consul/fsm/commands_oss_test.go b/agent/consul/fsm/commands_oss_test.go index 60f5289a05..a162be43be 100644 --- a/agent/consul/fsm/commands_oss_test.go +++ b/agent/consul/fsm/commands_oss_test.go @@ -390,7 +390,7 @@ func TestFSM_KVSDelete(t *testing.T) { } // Verify key is not set - _, d, err := fsm.state.KVSGet(nil, "/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -437,7 +437,7 @@ func TestFSM_KVSDeleteTree(t *testing.T) { } // Verify key is not set - _, d, err := fsm.state.KVSGet(nil, "/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -472,7 +472,7 @@ func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { } // Verify key is set - _, d, err := fsm.state.KVSGet(nil, "/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -493,7 +493,7 @@ func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { } // Verify key is gone - _, d, err = fsm.state.KVSGet(nil, "/test/path") + _, d, err = fsm.state.KVSGet(nil, "/test/path", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -528,7 +528,7 @@ func TestFSM_KVSCheckAndSet(t *testing.T) { } // Verify key is set - _, d, err := fsm.state.KVSGet(nil, "/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -550,7 +550,7 @@ func TestFSM_KVSCheckAndSet(t *testing.T) { } // Verify key is updated - _, d, err = fsm.state.KVSGet(nil, "/test/path") + _, d, err = fsm.state.KVSGet(nil, "/test/path", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -566,9 +566,16 @@ func TestFSM_KVSLock(t *testing.T) { t.Fatalf("err: %v", err) } - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + err = fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + if err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ID: generateUUID(), Node: "foo"} - fsm.state.SessionCreate(2, session) + err = fsm.state.SessionCreate(2, session) + if err != nil { + t.Fatalf("err: %v", err) + } req := structs.KVSRequest{ Datacenter: "dc1", @@ -589,7 +596,7 @@ func TestFSM_KVSLock(t *testing.T) { } // Verify key is locked - _, d, err := fsm.state.KVSGet(nil, "/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -611,9 +618,16 @@ func TestFSM_KVSUnlock(t *testing.T) { t.Fatalf("err: %v", err) } - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + err = fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + if err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ID: generateUUID(), Node: "foo"} - fsm.state.SessionCreate(2, session) + err = fsm.state.SessionCreate(2, session) + if err != nil { + t.Fatalf("err: %v", err) + } req := structs.KVSRequest{ Datacenter: "dc1", @@ -652,7 +666,7 @@ func TestFSM_KVSUnlock(t *testing.T) { } // Verify key is unlocked - _, d, err := fsm.state.KVSGet(nil, "/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -675,8 +689,14 @@ func TestFSM_CoordinateUpdate(t *testing.T) { } // Register some nodes. - fsm.state.EnsureNode(1, &structs.Node{Node: "node1", Address: "127.0.0.1"}) - fsm.state.EnsureNode(2, &structs.Node{Node: "node2", Address: "127.0.0.1"}) + err = fsm.state.EnsureNode(1, &structs.Node{Node: "node1", Address: "127.0.0.1"}) + if err != nil { + t.Fatalf("err: %v", err) + } + err = fsm.state.EnsureNode(2, &structs.Node{Node: "node2", Address: "127.0.0.1"}) + if err != nil { + t.Fatalf("err: %v", err) + } // Write a batch of two coordinates. updates := structs.Coordinates{ @@ -715,12 +735,19 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) { t.Fatalf("err: %v", err) } - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) - fsm.state.EnsureCheck(2, &structs.HealthCheck{ + err = fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + if err != nil { + t.Fatalf("err: %v", err) + } + + err = fsm.state.EnsureCheck(2, &structs.HealthCheck{ Node: "foo", CheckID: "web", Status: api.HealthPassing, }) + if err != nil { + t.Fatalf("err: %v", err) + } // Create a new session req := structs.SessionRequest{ @@ -743,7 +770,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) { // Get the session id := resp.(string) - _, session, err := fsm.state.SessionGet(nil, id) + _, session, err := fsm.state.SessionGet(nil, id, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -779,7 +806,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) { t.Fatalf("resp: %v", resp) } - _, session, err = fsm.state.SessionGet(nil, id) + _, session, err = fsm.state.SessionGet(nil, id, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -914,8 +941,15 @@ func TestFSM_PreparedQuery_CRUD(t *testing.T) { } // Register a service to query on. - fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) - fsm.state.EnsureService(2, "foo", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.1", Port: 80}) + err = fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + if err != nil { + t.Fatalf("err: %v", err) + } + + err = fsm.state.EnsureService(2, "foo", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.1", Port: 80}) + if err != nil { + t.Fatalf("err: %v", err) + } // Create a new query. query := structs.PreparedQueryRequest{ @@ -1012,12 +1046,20 @@ func TestFSM_TombstoneReap(t *testing.T) { } // Create some tombstones - fsm.state.KVSSet(11, &structs.DirEntry{ + err = fsm.state.KVSSet(11, &structs.DirEntry{ Key: "/remove", Value: []byte("foo"), }) - fsm.state.KVSDelete(12, "/remove") - idx, _, err := fsm.state.KVSList(nil, "/remove") + if err != nil { + t.Fatalf("err: %v", err) + } + + err = fsm.state.KVSDelete(12, "/remove", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + idx, _, err := fsm.state.KVSList(nil, "/remove", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1085,7 +1127,7 @@ func TestFSM_Txn(t *testing.T) { } // Verify key is set directly in the state store. - _, d, err := fsm.state.KVSGet(nil, "/test/path") + _, d, err := fsm.state.KVSGet(nil, "/test/path", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -1599,7 +1641,11 @@ func TestFSM_Chunking_TermChange(t *testing.T) { // Now verify the other baseline, that when the term doesn't change we see // non-nil. First make the chunker have a clean state, then set the terms // to be the same. - fsm.chunker.RestoreState(nil) + err = fsm.chunker.RestoreState(nil) + if err != nil { + t.Fatalf("err: %v", err) + } + logs[1].Term = uint64(0) // We should see nil only for the first one diff --git a/agent/consul/fsm/snapshot_oss_test.go b/agent/consul/fsm/snapshot_oss_test.go index 39e46a9a73..c15235adb8 100644 --- a/agent/consul/fsm/snapshot_oss_test.go +++ b/agent/consul/fsm/snapshot_oss_test.go @@ -79,6 +79,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { }) session := &structs.Session{ID: generateUUID(), Node: "foo"} fsm.state.SessionCreate(9, session) + policy := &structs.ACLPolicy{ ID: structs.ACLPolicyGlobalManagementID, Name: "global-management", @@ -142,8 +143,8 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { Key: "/remove", Value: []byte("foo"), }) - fsm.state.KVSDelete(12, "/remove") - idx, _, err := fsm.state.KVSList(nil, "/remove") + fsm.state.KVSDelete(12, "/remove", nil) + idx, _, err := fsm.state.KVSList(nil, "/remove", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -350,7 +351,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { } // Verify key is set - _, d, err := fsm2.state.KVSGet(nil, "/test") + _, d, err := fsm2.state.KVSGet(nil, "/test", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -359,7 +360,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { } // Verify session is restored - idx, s, err := fsm2.state.SessionGet(nil, session.ID) + idx, s, err := fsm2.state.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/kvs_endpoint.go b/agent/consul/kvs_endpoint.go index 8c3722c496..b7900fec07 100644 --- a/agent/consul/kvs_endpoint.go +++ b/agent/consul/kvs_endpoint.go @@ -2,6 +2,7 @@ package consul import ( "fmt" + "strings" "time" "github.com/armon/go-metrics" @@ -22,7 +23,6 @@ type KVS struct { // must only be done on the leader. func kvsPreApply(srv *Server, rule acl.Authorizer, op api.KVOp, dirEnt *structs.DirEntry) (bool, error) { // Verify the entry. - if dirEnt.Key == "" && op != api.KVDeleteTree { return false, fmt.Errorf("Must provide key") } @@ -31,8 +31,10 @@ func kvsPreApply(srv *Server, rule acl.Authorizer, op api.KVOp, dirEnt *structs. if rule != nil { switch op { case api.KVDeleteTree: - // TODO (namespaces) use actual ent authz context - ensure we set the Sentinel Scope - if rule.KeyWritePrefix(dirEnt.Key, nil) != acl.Allow { + var authzContext acl.EnterpriseAuthorizerContext + dirEnt.FillAuthzContext(&authzContext) + + if rule.KeyWritePrefix(dirEnt.Key, &authzContext) != acl.Allow { return false, acl.ErrPermissionDenied } @@ -43,13 +45,17 @@ func kvsPreApply(srv *Server, rule acl.Authorizer, op api.KVOp, dirEnt *structs. // These could reveal information based on the outcome // of the transaction, and they operate on individual // keys so we check them here. - if rule.KeyRead(dirEnt.Key, nil) != acl.Allow { + var authzContext acl.EnterpriseAuthorizerContext + dirEnt.FillAuthzContext(&authzContext) + + if rule.KeyRead(dirEnt.Key, &authzContext) != acl.Allow { return false, acl.ErrPermissionDenied } default: var authzContext acl.EnterpriseAuthorizerContext dirEnt.FillAuthzContext(&authzContext) + if rule.KeyWrite(dirEnt.Key, &authzContext) != acl.Allow { return false, acl.ErrPermissionDenied } @@ -64,7 +70,7 @@ func kvsPreApply(srv *Server, rule acl.Authorizer, op api.KVOp, dirEnt *structs. // only the wall-time of the leader node is used, preventing any inconsistencies. if op == api.KVLock { state := srv.fsm.State() - expires := state.KVSLockDelay(dirEnt.Key) + expires := state.KVSLockDelay(dirEnt.Key, &dirEnt.EnterpriseMeta) if expires.After(time.Now()) { srv.logger.Printf("[WARN] consul.kvs: Rejecting lock of %s due to lock-delay until %v", dirEnt.Key, expires) @@ -82,12 +88,16 @@ func (k *KVS) Apply(args *structs.KVSRequest, reply *bool) error { } defer metrics.MeasureSince([]string{"kvs", "apply"}, time.Now()) + if err := k.srv.validateEnterpriseRequest(&args.DirEnt.EnterpriseMeta, true); err != nil { + return err + } + // Perform the pre-apply checks. - acl, err := k.srv.ResolveToken(args.Token) + rule, err := k.srv.ResolveToken(args.Token) if err != nil { return err } - ok, err := kvsPreApply(k.srv, acl, args.Op, &args.DirEnt) + ok, err := kvsPreApply(k.srv, rule, args.Op, &args.DirEnt) if err != nil { return err } @@ -118,20 +128,27 @@ func (k *KVS) Get(args *structs.KeyRequest, reply *structs.IndexedDirEntries) er if done, err := k.srv.forward("KVS.Get", args, args, reply); done { return err } + if err := k.srv.validateEnterpriseRequest(&args.EnterpriseMeta, false); err != nil { + return err + } - aclRule, err := k.srv.ResolveToken(args.Token) + var entCtx acl.EnterpriseAuthorizerContext + args.FillAuthzContext(&entCtx) + + rule, err := k.srv.ResolveToken(args.Token) if err != nil { return err } + return k.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, ent, err := state.KVSGet(ws, args.Key) + index, ent, err := state.KVSGet(ws, args.Key, &args.EnterpriseMeta) if err != nil { return err } - if aclRule != nil && aclRule.KeyRead(args.Key, nil) != acl.Allow { + if rule != nil && rule.KeyRead(args.Key, &entCtx) != acl.Allow { return acl.ErrPermissionDenied } @@ -157,13 +174,18 @@ func (k *KVS) List(args *structs.KeyRequest, reply *structs.IndexedDirEntries) e if done, err := k.srv.forward("KVS.List", args, args, reply); done { return err } - - aclToken, err := k.srv.ResolveToken(args.Token) - if err != nil { + if err := k.srv.validateEnterpriseRequest(&args.EnterpriseMeta, false); err != nil { return err } - if aclToken != nil && k.srv.config.ACLEnableKeyListPolicy && aclToken.KeyList(args.Key, nil) != acl.Allow { + var entCtx acl.EnterpriseAuthorizerContext + args.FillAuthzContext(&entCtx) + + rule, err := k.srv.ResolveToken(args.Token) + if err != nil { + return err + } + if rule != nil && k.srv.config.ACLEnableKeyListPolicy && rule.KeyList(args.Key, &entCtx) != acl.Allow { return acl.ErrPermissionDenied } @@ -171,12 +193,12 @@ func (k *KVS) List(args *structs.KeyRequest, reply *structs.IndexedDirEntries) e &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, ent, err := state.KVSList(ws, args.Key) + index, ent, err := state.KVSList(ws, args.Key, &args.EnterpriseMeta) if err != nil { return err } - if aclToken != nil { - ent = FilterDirEnt(aclToken, ent) + if rule != nil { + ent = FilterDirEnt(rule, ent) } if len(ent) == 0 { @@ -197,17 +219,25 @@ func (k *KVS) List(args *structs.KeyRequest, reply *structs.IndexedDirEntries) e } // ListKeys is used to list all keys with a given prefix to a separator. +// An optional separator may be specified, which can be used to slice off a part +// of the response so that only a subset of the prefix is returned. In this +// mode, the keys which are omitted are still counted in the returned index. func (k *KVS) ListKeys(args *structs.KeyListRequest, reply *structs.IndexedKeyList) error { if done, err := k.srv.forward("KVS.ListKeys", args, args, reply); done { return err } - - aclToken, err := k.srv.ResolveToken(args.Token) - if err != nil { + if err := k.srv.validateEnterpriseRequest(&args.EnterpriseMeta, false); err != nil { return err } - if aclToken != nil && k.srv.config.ACLEnableKeyListPolicy && aclToken.KeyList(args.Prefix, nil) != acl.Allow { + var entCtx acl.EnterpriseAuthorizerContext + args.FillAuthzContext(&entCtx) + + rule, err := k.srv.ResolveToken(args.Token) + if err != nil { + return err + } + if rule != nil && k.srv.config.ACLEnableKeyListPolicy && rule.KeyList(args.Prefix, &entCtx) != acl.Allow { return acl.ErrPermissionDenied } @@ -215,7 +245,7 @@ func (k *KVS) ListKeys(args *structs.KeyListRequest, reply *structs.IndexedKeyLi &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, keys, err := state.KVSListKeys(ws, args.Prefix, args.Seperator) + index, entries, err := state.KVSList(ws, args.Prefix, &args.EnterpriseMeta) if err != nil { return err } @@ -228,8 +258,37 @@ func (k *KVS) ListKeys(args *structs.KeyListRequest, reply *structs.IndexedKeyLi reply.Index = index } - if aclToken != nil { - keys = FilterKeys(aclToken, keys) + if rule != nil { + entries = FilterDirEnt(rule, entries) + } + + // Collect the keys from the filtered entries + prefixLen := len(args.Prefix) + sepLen := len(args.Seperator) + + var keys []string + seen := make(map[string]bool) + + for _, e := range entries { + // Always accumulate if no separator provided + if sepLen == 0 { + keys = append(keys, e.Key) + continue + } + + // Parse and de-duplicate the returned keys based on the + // key separator, if provided. + after := e.Key[prefixLen:] + sepIdx := strings.Index(after, args.Seperator) + if sepIdx > -1 { + key := e.Key[:prefixLen+sepIdx+sepLen] + if ok := seen[key]; !ok { + keys = append(keys, key) + seen[key] = true + } + } else { + keys = append(keys, e.Key) + } } reply.Keys = keys return nil diff --git a/agent/consul/kvs_endpoint_test.go b/agent/consul/kvs_endpoint_test.go index e3ba380df5..85f1b9f6d3 100644 --- a/agent/consul/kvs_endpoint_test.go +++ b/agent/consul/kvs_endpoint_test.go @@ -21,7 +21,7 @@ func TestKVS_Apply(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") arg := structs.KVSRequest{ Datacenter: "dc1", @@ -39,7 +39,7 @@ func TestKVS_Apply(t *testing.T) { // Verify state := s1.fsm.State() - _, d, err := state.KVSGet(nil, "test") + _, d, err := state.KVSGet(nil, "test", &arg.DirEnt.EnterpriseMeta) if err != nil { t.Fatalf("err: %v", err) } @@ -61,7 +61,7 @@ func TestKVS_Apply(t *testing.T) { } // Verify - _, d, err = state.KVSGet(nil, "test") + _, d, err = state.KVSGet(nil, "test", &arg.DirEnt.EnterpriseMeta) if err != nil { t.Fatalf("err: %v", err) } @@ -83,7 +83,7 @@ func TestKVS_Apply_ACLDeny(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Create the ACL arg := structs.ACLRequest{ @@ -142,7 +142,7 @@ func TestKVS_Get(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") arg := structs.KVSRequest{ Datacenter: "dc1", @@ -195,7 +195,7 @@ func TestKVS_Get_ACLDeny(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") arg := structs.KVSRequest{ Datacenter: "dc1", @@ -231,7 +231,7 @@ func TestKVSEndpoint_List(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") keys := []string{ "/test/key1", @@ -303,7 +303,7 @@ func TestKVSEndpoint_List_Blocking(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") keys := []string{ "/test/key1", @@ -404,7 +404,7 @@ func TestKVSEndpoint_List_ACLDeny(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") keys := []string{ "abe", @@ -491,7 +491,7 @@ func TestKVSEndpoint_List_ACLEnableKeyListPolicy(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") keys := []string{ "abe", @@ -610,7 +610,7 @@ func TestKVSEndpoint_ListKeys(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") keys := []string{ "/test/key1", @@ -685,7 +685,7 @@ func TestKVSEndpoint_ListKeys_ACLDeny(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") keys := []string{ "abe", @@ -760,10 +760,11 @@ func TestKVS_Apply_LockDelay(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Create and invalidate a session with a lock. state := s1.fsm.State() + if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { t.Fatalf("err: %v", err) } @@ -783,7 +784,8 @@ func TestKVS_Apply_LockDelay(t *testing.T) { if ok, err := state.KVSLock(3, d); err != nil || !ok { t.Fatalf("err: %v", err) } - if err := state.SessionDestroy(4, id); err != nil { + + if err := state.SessionDestroy(4, id, nil); err != nil { t.Fatalf("err: %v", err) } @@ -830,7 +832,7 @@ func TestKVS_Issue_1626(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Set up the first key. { diff --git a/agent/consul/prepared_query_endpoint_test.go b/agent/consul/prepared_query_endpoint_test.go index c4ae24233d..1a0c051da0 100644 --- a/agent/consul/prepared_query_endpoint_test.go +++ b/agent/consul/prepared_query_endpoint_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" + tokenStore "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" @@ -77,7 +78,7 @@ func TestPreparedQuery_Apply(t *testing.T) { query.Query.Service.Failover.NearestN = 0 query.Query.Session = "nope" err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Apply", &query, &reply) - if err == nil || !strings.Contains(err.Error(), "failed session lookup") { + if err == nil || !strings.Contains(err.Error(), "invalid session") { t.Fatalf("bad: %v", err) } @@ -852,7 +853,7 @@ func TestPreparedQuery_Get(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Create an ACL with write permissions for redis queries. var token string @@ -1105,7 +1106,7 @@ func TestPreparedQuery_List(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Create an ACL with write permissions for redis queries. var token string @@ -1461,16 +1462,16 @@ func TestPreparedQuery_Execute(t *testing.T) { codec2 := rpcClient(t, s2) defer codec2.Close() + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) testrpc.WaitForLeader(t, s1.RPC, "dc1") - testrpc.WaitForLeader(t, s2.RPC, "dc2") - - // Try to WAN join. joinWAN(t, s2, s1) + // Try to WAN join. retry.Run(t, func(r *retry.R) { if got, want := len(s1.WANMembers()), 2; got != want { r.Fatalf("got %d WAN members want %d", got, want) } }) + testrpc.WaitForLeader(t, s2.RPC, "dc2") // Create an ACL with read permission to the service. var execToken string @@ -2957,11 +2958,11 @@ func TestPreparedQuery_Wrapper(t *testing.T) { defer os.RemoveAll(dir2) defer s2.Shutdown() + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) testrpc.WaitForLeader(t, s1.RPC, "dc1") - testrpc.WaitForLeader(t, s2.RPC, "dc2") - // Try to WAN join. joinWAN(t, s2, s1) + testrpc.WaitForLeader(t, s2.RPC, "dc2") // Try all the operations on a real server via the wrapper. wrapper := &queryServerWrapper{s1} diff --git a/agent/consul/session_endpoint.go b/agent/consul/session_endpoint.go index 072cfce236..e3aac99e94 100644 --- a/agent/consul/session_endpoint.go +++ b/agent/consul/session_endpoint.go @@ -25,6 +25,10 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { } defer metrics.MeasureSince([]string{"session", "apply"}, time.Now()) + if err := s.srv.validateEnterpriseRequest(&args.Session.EnterpriseMeta, true); err != nil { + return err + } + // Verify the args if args.Session.ID == "" && args.Op == structs.SessionDestroy { return fmt.Errorf("Must provide ID") @@ -33,30 +37,35 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { return fmt.Errorf("Must provide Node") } + // TODO (namespaces) (acls) infer entmeta if not provided. + // The entMeta to populate will be the one in the Session struct, not SessionRequest + // This is because the Session is what is passed to downstream functions like raftApply + var entCtx acl.EnterpriseAuthorizerContext + args.Session.EnterpriseMeta.FillAuthzContext(&entCtx) + // Fetch the ACL token, if any, and apply the policy. rule, err := s.srv.ResolveToken(args.Token) if err != nil { return err } + if rule != nil && s.srv.config.ACLEnforceVersion8 { switch args.Op { case structs.SessionDestroy: state := s.srv.fsm.State() - _, existing, err := state.SessionGet(nil, args.Session.ID) + _, existing, err := state.SessionGet(nil, args.Session.ID, &args.Session.EnterpriseMeta) if err != nil { return fmt.Errorf("Session lookup failed: %v", err) } if existing == nil { return fmt.Errorf("Unknown session %q", args.Session.ID) } - // TODO (namespaces) - pass through a real ent authz ctx - if rule.SessionWrite(existing.Node, nil) != acl.Allow { + if rule.SessionWrite(existing.Node, &entCtx) != acl.Allow { return acl.ErrPermissionDenied } case structs.SessionCreate: - // TODO (namespaces) - pass through a real ent authz ctx - if rule.SessionWrite(args.Session.Node, nil) != acl.Allow { + if rule.SessionWrite(args.Session.Node, &entCtx) != acl.Allow { return acl.ErrPermissionDenied } @@ -102,7 +111,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { s.srv.logger.Printf("[ERR] consul.session: UUID generation failed: %v", err) return err } - _, sess, err := state.SessionGet(nil, args.Session.ID) + _, sess, err := state.SessionGet(nil, args.Session.ID, &args.Session.EnterpriseMeta) if err != nil { s.srv.logger.Printf("[ERR] consul.session: Session lookup failed: %v", err) return err @@ -147,11 +156,24 @@ func (s *Session) Get(args *structs.SessionSpecificRequest, return err } + if err := s.srv.validateEnterpriseRequest(&args.EnterpriseMeta, false); err != nil { + return err + } + + // TODO (namespaces) TODO (acls) infer args.entmeta if not provided + var entCtx acl.EnterpriseAuthorizerContext + args.FillAuthzContext(&entCtx) + + rule, err := s.srv.ResolveToken(args.Token) + if err != nil { + return err + } + return s.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, session, err := state.SessionGet(ws, args.Session) + index, session, err := state.SessionGet(ws, args.SessionID, &args.EnterpriseMeta) if err != nil { return err } @@ -162,7 +184,7 @@ func (s *Session) Get(args *structs.SessionSpecificRequest, } else { reply.Sessions = nil } - if err := s.srv.filterACL(args.Token, reply); err != nil { + if err := s.srv.filterACLWithAuthorizer(rule, reply); err != nil { return err } return nil @@ -170,23 +192,36 @@ func (s *Session) Get(args *structs.SessionSpecificRequest, } // List is used to list all the active sessions -func (s *Session) List(args *structs.DCSpecificRequest, +func (s *Session) List(args *structs.SessionSpecificRequest, reply *structs.IndexedSessions) error { if done, err := s.srv.forward("Session.List", args, args, reply); done { return err } + if err := s.srv.validateEnterpriseRequest(&args.EnterpriseMeta, false); err != nil { + return err + } + + // TODO (namespaces) TODO (acls) infer args.entmeta if not provided + var entCtx acl.EnterpriseAuthorizerContext + args.FillAuthzContext(&entCtx) + + rule, err := s.srv.ResolveToken(args.Token) + if err != nil { + return err + } + return s.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, sessions, err := state.SessionList(ws) + index, sessions, err := state.SessionList(ws, &args.EnterpriseMeta) if err != nil { return err } reply.Index, reply.Sessions = index, sessions - if err := s.srv.filterACL(args.Token, reply); err != nil { + if err := s.srv.filterACLWithAuthorizer(rule, reply); err != nil { return err } return nil @@ -200,17 +235,30 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest, return err } + if err := s.srv.validateEnterpriseRequest(&args.EnterpriseMeta, false); err != nil { + return err + } + + // TODO (namespaces) TODO (acls) infer args.entmeta if not provided + var entCtx acl.EnterpriseAuthorizerContext + args.FillAuthzContext(&entCtx) + + rule, err := s.srv.ResolveToken(args.Token) + if err != nil { + return err + } + return s.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, sessions, err := state.NodeSessions(ws, args.Node) + index, sessions, err := state.NodeSessions(ws, args.Node, &args.EnterpriseMeta) if err != nil { return err } reply.Index, reply.Sessions = index, sessions - if err := s.srv.filterACL(args.Token, reply); err != nil { + if err := s.srv.filterACLWithAuthorizer(rule, reply); err != nil { return err } return nil @@ -225,9 +273,13 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest, } defer metrics.MeasureSince([]string{"session", "renew"}, time.Now()) + if err := s.srv.validateEnterpriseRequest(&args.EnterpriseMeta, true); err != nil { + return err + } + // Get the session, from local state. state := s.srv.fsm.State() - index, session, err := state.SessionGet(nil, args.Session) + index, session, err := state.SessionGet(nil, args.SessionID, &args.EnterpriseMeta) if err != nil { return err } @@ -237,21 +289,24 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest, return nil } + // TODO (namespaces) (freddy):infer args.entmeta if not provided // Fetch the ACL token, if any, and apply the policy. + var entCtx acl.EnterpriseAuthorizerContext + args.FillAuthzContext(&entCtx) + rule, err := s.srv.ResolveToken(args.Token) if err != nil { return err } if rule != nil && s.srv.config.ACLEnforceVersion8 { - // TODO (namespaces) - pass through a real ent authz ctx - if rule.SessionWrite(session.Node, nil) != acl.Allow { + if rule.SessionWrite(session.Node, &entCtx) != acl.Allow { return acl.ErrPermissionDenied } } // Reset the session TTL timer. reply.Sessions = structs.Sessions{session} - if err := s.srv.resetSessionTimer(args.Session, session); err != nil { + if err := s.srv.resetSessionTimer(args.SessionID, session); err != nil { s.srv.logger.Printf("[ERR] consul.session: Session renew failed: %v", err) return err } diff --git a/agent/consul/session_endpoint_test.go b/agent/consul/session_endpoint_test.go index 3528284e27..bd42febc15 100644 --- a/agent/consul/session_endpoint_test.go +++ b/agent/consul/session_endpoint_test.go @@ -17,6 +17,7 @@ func TestSession_Apply(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -41,7 +42,7 @@ func TestSession_Apply(t *testing.T) { // Verify state := s1.fsm.State() - _, s, err := state.SessionGet(nil, out) + _, s, err := state.SessionGet(nil, out, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -63,7 +64,7 @@ func TestSession_Apply(t *testing.T) { } // Verify - _, s, err = state.SessionGet(nil, id) + _, s, err = state.SessionGet(nil, id, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -77,6 +78,7 @@ func TestSession_DeleteApply(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -102,7 +104,7 @@ func TestSession_DeleteApply(t *testing.T) { // Verify state := s1.fsm.State() - _, s, err := state.SessionGet(nil, out) + _, s, err := state.SessionGet(nil, out, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -127,7 +129,7 @@ func TestSession_DeleteApply(t *testing.T) { } // Verify - _, s, err = state.SessionGet(nil, id) + _, s, err = state.SessionGet(nil, id, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -147,6 +149,7 @@ func TestSession_Apply_ACLDeny(t *testing.T) { }) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -237,6 +240,7 @@ func TestSession_Get(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -257,7 +261,7 @@ func TestSession_Get(t *testing.T) { getR := structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: out, + SessionID: out, } var sessions structs.IndexedSessions if err := msgpackrpc.CallWithCodec(codec, "Session.Get", &getR, &sessions); err != nil { @@ -281,6 +285,7 @@ func TestSession_List(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -339,6 +344,7 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) { }) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -384,7 +390,7 @@ session "foo" { // 8 ACL enforcement isn't enabled. getR := structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: out, + SessionID: out, } { var sessions structs.IndexedSessions @@ -486,7 +492,7 @@ session "foo" { // Try to get a session that doesn't exist to make sure that's handled // correctly by the filter (it will get passed a nil slice). - getR.Session = "adf4238a-882b-9ddc-4a9d-5b6758e4159e" + getR.SessionID = "adf4238a-882b-9ddc-4a9d-5b6758e4159e" { var sessions structs.IndexedSessions if err := msgpackrpc.CallWithCodec(codec, "Session.Get", &getR, &sessions); err != nil { @@ -503,10 +509,12 @@ func TestSession_ApplyTimers(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() - testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + codec := rpcClient(t, s1) defer codec.Close() + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) arg := structs.SessionRequest{ Datacenter: "dc1", @@ -551,6 +559,7 @@ func TestSession_Renew(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + codec := rpcClient(t, s1) defer codec.Close() @@ -613,7 +622,7 @@ func TestSession_Renew(t *testing.T) { for i := 0; i < 3; i++ { renewR := structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: ids[i], + SessionID: ids[i], } var session structs.IndexedSessions if err := msgpackrpc.CallWithCodec(codec, "Session.Renew", &renewR, &session); err != nil { @@ -714,10 +723,12 @@ func TestSession_Renew_ACLDeny(t *testing.T) { }) defer os.RemoveAll(dir1) defer s1.Shutdown() - testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + codec := rpcClient(t, s1) defer codec.Close() + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + // Create the ACL. req := structs.ACLRequest{ Datacenter: "dc1", @@ -761,7 +772,7 @@ session "foo" { // enforcement. renewR := structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: id, + SessionID: id, } var session structs.IndexedSessions if err := msgpackrpc.CallWithCodec(codec, "Session.Renew", &renewR, &session); err != nil { @@ -787,6 +798,7 @@ func TestSession_NodeSessions(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -846,6 +858,7 @@ func TestSession_Apply_BadTTL(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() diff --git a/agent/consul/session_ttl.go b/agent/consul/session_ttl.go index fd12701fb5..e5ca429f7b 100644 --- a/agent/consul/session_ttl.go +++ b/agent/consul/session_ttl.go @@ -22,7 +22,8 @@ const ( func (s *Server) initializeSessionTimers() error { // Scan all sessions and reset their timer state := s.fsm.State() - _, sessions, err := state.SessionList(nil) + + _, sessions, err := state.SessionList(nil, structs.WildcardEnterpriseMeta()) if err != nil { return err } @@ -41,7 +42,7 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error { // Fault the session in if not given if session == nil { state := s.fsm.State() - _, s, err := state.SessionGet(nil, id) + _, s, err := state.SessionGet(nil, id, nil) if err != nil { return err } @@ -66,11 +67,11 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error { return nil } - s.createSessionTimer(session.ID, ttl) + s.createSessionTimer(session.ID, ttl, &session.EnterpriseMeta) return nil } -func (s *Server) createSessionTimer(id string, ttl time.Duration) { +func (s *Server) createSessionTimer(id string, ttl time.Duration, entMeta *structs.EnterpriseMeta) { // Reset the session timer // Adjust the given TTL by the TTL multiplier. This is done // to give a client a grace period and to compensate for network @@ -78,12 +79,12 @@ func (s *Server) createSessionTimer(id string, ttl time.Duration) { // before the TTL, but there is no explicit promise about the upper // bound so this is allowable. ttl = ttl * structs.SessionTTLMultiplier - s.sessionTimers.ResetOrCreate(id, ttl, func() { s.invalidateSession(id) }) + s.sessionTimers.ResetOrCreate(id, ttl, func() { s.invalidateSession(id, entMeta) }) } // invalidateSession is invoked when a session TTL is reached and we // need to invalidate the session. -func (s *Server) invalidateSession(id string) { +func (s *Server) invalidateSession(id string, entMeta *structs.EnterpriseMeta) { defer metrics.MeasureSince([]string{"session_ttl", "invalidate"}, time.Now()) // Clear the session timer @@ -97,6 +98,9 @@ func (s *Server) invalidateSession(id string) { ID: id, }, } + if entMeta != nil { + args.Session.EnterpriseMeta = *entMeta + } // Retry with exponential backoff to invalidate the session for attempt := uint(0); attempt < maxInvalidateAttempts; attempt++ { diff --git a/agent/consul/session_ttl_test.go b/agent/consul/session_ttl_test.go index dfa1b32e56..92e9ae2c55 100644 --- a/agent/consul/session_ttl_test.go +++ b/agent/consul/session_ttl_test.go @@ -157,7 +157,7 @@ func TestResetSessionTimerLocked(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - s1.createSessionTimer("foo", 5*time.Millisecond) + s1.createSessionTimer("foo", 5*time.Millisecond, nil) if s1.sessionTimers.Get("foo") == nil { t.Fatalf("missing timer") } @@ -178,7 +178,7 @@ func TestResetSessionTimerLocked_Renew(t *testing.T) { retry.Run(t, func(r *retry.R) { // create the timer and make verify it was created - s1.createSessionTimer("foo", ttl) + s1.createSessionTimer("foo", ttl, nil) if s1.sessionTimers.Get("foo") == nil { r.Fatalf("missing timer") } @@ -194,7 +194,7 @@ func TestResetSessionTimerLocked_Renew(t *testing.T) { retry.Run(t, func(r *retry.R) { // renew the session which will reset the TTL to 2*ttl // since that is the current SessionTTLMultiplier - s1.createSessionTimer("foo", ttl) + s1.createSessionTimer("foo", ttl, nil) if s1.sessionTimers.Get("foo") == nil { r.Fatal("missing timer") } @@ -231,6 +231,7 @@ func TestInvalidateSession(t *testing.T) { if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { t.Fatalf("err: %s", err) } + session := &structs.Session{ ID: generateUUID(), Node: "foo", @@ -241,10 +242,10 @@ func TestInvalidateSession(t *testing.T) { } // This should cause a destroy - s1.invalidateSession(session.ID) + s1.invalidateSession(session.ID, nil) // Check it is gone - _, sess, err := state.SessionGet(nil, session.ID) + _, sess, err := state.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -259,7 +260,7 @@ func TestClearSessionTimer(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() - s1.createSessionTimer("foo", 5*time.Millisecond) + s1.createSessionTimer("foo", 5*time.Millisecond, nil) err := s1.clearSessionTimer("foo") if err != nil { @@ -277,9 +278,9 @@ func TestClearAllSessionTimers(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() - s1.createSessionTimer("foo", 10*time.Millisecond) - s1.createSessionTimer("bar", 10*time.Millisecond) - s1.createSessionTimer("baz", 10*time.Millisecond) + s1.createSessionTimer("foo", 10*time.Millisecond, nil) + s1.createSessionTimer("bar", 10*time.Millisecond, nil) + s1.createSessionTimer("baz", 10*time.Millisecond, nil) s1.clearAllSessionTimers() diff --git a/agent/consul/snapshot_endpoint_test.go b/agent/consul/snapshot_endpoint_test.go index f3717b2928..24af70180a 100644 --- a/agent/consul/snapshot_endpoint_test.go +++ b/agent/consul/snapshot_endpoint_test.go @@ -169,7 +169,7 @@ func TestSnapshot_LeaderState(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") codec := rpcClient(t, s1) defer codec.Close() diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index 7e7e89fb8f..ce140801f8 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/types" "github.com/hashicorp/go-memdb" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-uuid" ) const ( @@ -669,6 +669,7 @@ func (s *Store) deleteNodeCASTxn(tx *memdb.Txn, idx, cidx uint64, nodeName strin // deleteNodeTxn is the inner method used for removing a node from // the store within a given transaction. +// TODO (namespaces) (catalog) access to catalog tables needs to become namespace aware for services/checks func (s *Store) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) error { // Look up the node. node, err := tx.First("nodes", "id", nodeName) @@ -744,19 +745,14 @@ func (s *Store) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) error } // Invalidate any sessions for this node. - sessions, err := tx.Get("sessions", "node", nodeName) + toDelete, err := s.allNodeSessionsTxn(tx, nodeName) if err != nil { - return fmt.Errorf("failed session lookup: %s", err) - } - var ids []string - for sess := sessions.Next(); sess != nil; sess = sessions.Next() { - ids = append(ids, sess.(*structs.Session).ID) + return err } - // Do the delete in a separate loop so we don't trash the iterator. - for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, id); err != nil { - return fmt.Errorf("failed session delete: %s", err) + for _, session := range toDelete { + if err := s.deleteSessionTxn(tx, idx, session.ID, &session.EnterpriseMeta); err != nil { + return fmt.Errorf("failed to delete session '%s': %v", session.ID, err) } } @@ -1605,7 +1601,8 @@ func (s *Store) ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthChec // Delete the session in a separate loop so we don't trash the // iterator. for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, id); err != nil { + // TODO (namespaces): Update when structs.HealthCheck supports Namespaces (&hc.EnterpriseMeta) + if err := s.deleteSessionTxn(tx, idx, id, nil); err != nil { return fmt.Errorf("failed deleting session: %s", err) } } @@ -1917,7 +1914,8 @@ func (s *Store) deleteCheckTxn(tx *memdb.Txn, idx uint64, node string, checkID t // Do the delete in a separate loop so we don't trash the iterator. for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, id); err != nil { + // TODO (namespaces): Update when structs.HealthCheck supports Namespaces (&hc.EnterpriseMeta) + if err := s.deleteSessionTxn(tx, idx, id, nil); err != nil { return fmt.Errorf("failed deleting session: %s", err) } } diff --git a/agent/consul/state/delay.go b/agent/consul/state/delay.go index 1a253b641e..96aa2d9c30 100644 --- a/agent/consul/state/delay.go +++ b/agent/consul/state/delay.go @@ -1,6 +1,9 @@ +// +build !consulent + package state import ( + "github.com/hashicorp/consul/agent/structs" "sync" "time" ) @@ -32,7 +35,7 @@ func NewDelay() *Delay { // GetExpiration returns the expiration time of a key lock delay. This must be // checked on the leader node, and not in KVSLock due to the variability of // clocks. -func (d *Delay) GetExpiration(key string) time.Time { +func (d *Delay) GetExpiration(key string, entMeta *structs.EnterpriseMeta) time.Time { d.lock.RLock() expires := d.delay[key] d.lock.RUnlock() @@ -41,7 +44,7 @@ func (d *Delay) GetExpiration(key string) time.Time { // SetExpiration sets the expiration time for the lock delay to the given // delay from the given now time. -func (d *Delay) SetExpiration(key string, now time.Time, delay time.Duration) { +func (d *Delay) SetExpiration(key string, now time.Time, delay time.Duration, entMeta *structs.EnterpriseMeta) { d.lock.Lock() defer d.lock.Unlock() diff --git a/agent/consul/state/delay_test.go b/agent/consul/state/delay_test.go index 68f67d3bef..507292f4aa 100644 --- a/agent/consul/state/delay_test.go +++ b/agent/consul/state/delay_test.go @@ -9,21 +9,21 @@ func TestDelay(t *testing.T) { d := NewDelay() // An unknown key should have a time in the past. - if exp := d.GetExpiration("nope"); !exp.Before(time.Now()) { + if exp := d.GetExpiration("nope", nil); !exp.Before(time.Now()) { t.Fatalf("bad: %v", exp) } // Add a key and set a short expiration. now := time.Now() delay := 250 * time.Millisecond - d.SetExpiration("bye", now, delay) - if exp := d.GetExpiration("bye"); !exp.After(now) { + d.SetExpiration("bye", now, delay, nil) + if exp := d.GetExpiration("bye", nil); !exp.After(now) { t.Fatalf("bad: %v", exp) } // Wait for the key to expire and check again. time.Sleep(2 * delay) - if exp := d.GetExpiration("bye"); !exp.Before(now) { + if exp := d.GetExpiration("bye", nil); !exp.Before(now) { t.Fatalf("bad: %v", exp) } } diff --git a/agent/consul/state/graveyard.go b/agent/consul/state/graveyard.go index 0ecd0974b1..96e0c9dbfa 100644 --- a/agent/consul/state/graveyard.go +++ b/agent/consul/state/graveyard.go @@ -2,7 +2,7 @@ package state import ( "fmt" - + "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/go-memdb" ) @@ -10,6 +10,8 @@ import ( type Tombstone struct { Key string Index uint64 + + structs.EnterpriseMeta } // Graveyard manages a set of tombstones. @@ -25,15 +27,18 @@ func NewGraveyard(gc *TombstoneGC) *Graveyard { } // InsertTxn adds a new tombstone. -func (g *Graveyard) InsertTxn(tx *memdb.Txn, key string, idx uint64) error { - // Insert the tombstone. - stone := &Tombstone{Key: key, Index: idx} - if err := tx.Insert("tombstones", stone); err != nil { - return fmt.Errorf("failed inserting tombstone: %s", err) +func (g *Graveyard) InsertTxn(tx *memdb.Txn, key string, idx uint64, entMeta *structs.EnterpriseMeta) error { + stone := &Tombstone{ + Key: key, + Index: idx, + } + if entMeta != nil { + stone.EnterpriseMeta = *entMeta } - if err := tx.Insert("index", &IndexEntry{"tombstones", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) + // Insert the tombstone. + if err := g.insertTombstoneWithTxn(tx, "tombstones", stone, false); err != nil { + return fmt.Errorf("failed inserting tombstone: %s", err) } // If GC is configured, then we hint that this index requires reaping. @@ -45,8 +50,8 @@ func (g *Graveyard) InsertTxn(tx *memdb.Txn, key string, idx uint64) error { // GetMaxIndexTxn returns the highest index tombstone whose key matches the // given context, using a prefix match. -func (g *Graveyard) GetMaxIndexTxn(tx *memdb.Txn, prefix string) (uint64, error) { - stones, err := tx.Get("tombstones", "id_prefix", prefix) +func (g *Graveyard) GetMaxIndexTxn(tx *memdb.Txn, prefix string, entMeta *structs.EnterpriseMeta) (uint64, error) { + stones, err := getWithTxn(tx, "tombstones", "id_prefix", prefix, entMeta) if err != nil { return 0, fmt.Errorf("failed querying tombstones: %s", err) } @@ -74,13 +79,10 @@ func (g *Graveyard) DumpTxn(tx *memdb.Txn) (memdb.ResultIterator, error) { // RestoreTxn is used when restoring from a snapshot. For general inserts, use // InsertTxn. func (g *Graveyard) RestoreTxn(tx *memdb.Txn, stone *Tombstone) error { - if err := tx.Insert("tombstones", stone); err != nil { + if err := g.insertTombstoneWithTxn(tx, "tombstones", stone, true); err != nil { return fmt.Errorf("failed inserting tombstone: %s", err) } - if err := indexUpdateMaxTxn(tx, stone.Index, "tombstones"); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } return nil } diff --git a/agent/consul/state/graveyard_oss.go b/agent/consul/state/graveyard_oss.go new file mode 100644 index 0000000000..5637706846 --- /dev/null +++ b/agent/consul/state/graveyard_oss.go @@ -0,0 +1,28 @@ +// +build !consulent + +package state + +import ( + "fmt" + + "github.com/hashicorp/go-memdb" +) + +func (g *Graveyard) insertTombstoneWithTxn(tx *memdb.Txn, + table string, stone *Tombstone, updateMax bool) error { + + if err := tx.Insert("tombstones", stone); err != nil { + return err + } + + if updateMax { + if err := indexUpdateMaxTxn(tx, stone.Index, "tombstones"); err != nil { + return fmt.Errorf("failed updating tombstone index: %v", err) + } + } else { + if err := tx.Insert("index", &IndexEntry{"tombstones", stone.Index}); err != nil { + return fmt.Errorf("failed updating tombstone index: %s", err) + } + } + return nil +} diff --git a/agent/consul/state/graveyard_test.go b/agent/consul/state/graveyard_test.go index 4b7f46e27f..332beaa35d 100644 --- a/agent/consul/state/graveyard_test.go +++ b/agent/consul/state/graveyard_test.go @@ -1,7 +1,6 @@ package state import ( - "reflect" "testing" "time" ) @@ -18,16 +17,16 @@ func TestGraveyard_Lifecycle(t *testing.T) { tx := s.db.Txn(true) defer tx.Abort() - if err := g.InsertTxn(tx, "foo/in/the/house", 2); err != nil { + if err := g.InsertTxn(tx, "foo/in/the/house", 2, nil); err != nil { t.Fatalf("err: %s", err) } - if err := g.InsertTxn(tx, "foo/bar/baz", 5); err != nil { + if err := g.InsertTxn(tx, "foo/bar/baz", 5, nil); err != nil { t.Fatalf("err: %s", err) } - if err := g.InsertTxn(tx, "foo/bar/zoo", 8); err != nil { + if err := g.InsertTxn(tx, "foo/bar/zoo", 8, nil); err != nil { t.Fatalf("err: %s", err) } - if err := g.InsertTxn(tx, "some/other/path", 9); err != nil { + if err := g.InsertTxn(tx, "some/other/path", 9, nil); err != nil { t.Fatalf("err: %s", err) } tx.Commit() @@ -38,25 +37,25 @@ func TestGraveyard_Lifecycle(t *testing.T) { tx := s.db.Txn(false) defer tx.Abort() - if idx, err := g.GetMaxIndexTxn(tx, "foo"); idx != 8 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "foo", nil); idx != 8 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "foo/in/the/house"); idx != 2 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "foo/in/the/house", nil); idx != 2 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/baz"); idx != 5 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/baz", nil); idx != 5 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/zoo"); idx != 8 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/zoo", nil); idx != 8 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "some/other/path"); idx != 9 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "some/other/path", nil); idx != 9 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, ""); idx != 9 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "", nil); idx != 9 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "nope"); idx != 0 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "nope", nil); idx != 0 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } }() @@ -77,25 +76,25 @@ func TestGraveyard_Lifecycle(t *testing.T) { tx := s.db.Txn(false) defer tx.Abort() - if idx, err := g.GetMaxIndexTxn(tx, "foo"); idx != 8 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "foo", nil); idx != 8 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "foo/in/the/house"); idx != 0 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "foo/in/the/house", nil); idx != 0 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/baz"); idx != 0 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/baz", nil); idx != 0 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/zoo"); idx != 8 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/zoo", nil); idx != 8 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "some/other/path"); idx != 9 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "some/other/path", nil); idx != 9 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, ""); idx != 9 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "", nil); idx != 9 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } - if idx, err := g.GetMaxIndexTxn(tx, "nope"); idx != 0 || err != nil { + if idx, err := g.GetMaxIndexTxn(tx, "nope", nil); idx != 0 || err != nil { t.Fatalf("bad: %d (%s)", idx, err) } }() @@ -125,7 +124,7 @@ func TestGraveyard_GC_Trigger(t *testing.T) { tx := s.db.Txn(true) defer tx.Abort() - if err := g.InsertTxn(tx, "foo/in/the/house", 2); err != nil { + if err := g.InsertTxn(tx, "foo/in/the/house", 2, nil); err != nil { t.Fatalf("err: %s", err) } }() @@ -140,7 +139,7 @@ func TestGraveyard_GC_Trigger(t *testing.T) { tx := s.db.Txn(true) defer tx.Abort() - if err := g.InsertTxn(tx, "foo/in/the/house", 2); err != nil { + if err := g.InsertTxn(tx, "foo/in/the/house", 2, nil); err != nil { t.Fatalf("err: %s", err) } tx.Commit() @@ -174,16 +173,16 @@ func TestGraveyard_Snapshot_Restore(t *testing.T) { tx := s.db.Txn(true) defer tx.Abort() - if err := g.InsertTxn(tx, "foo/in/the/house", 2); err != nil { + if err := g.InsertTxn(tx, "foo/in/the/house", 2, nil); err != nil { t.Fatalf("err: %s", err) } - if err := g.InsertTxn(tx, "foo/bar/baz", 5); err != nil { + if err := g.InsertTxn(tx, "foo/bar/baz", 5, nil); err != nil { t.Fatalf("err: %s", err) } - if err := g.InsertTxn(tx, "foo/bar/zoo", 8); err != nil { + if err := g.InsertTxn(tx, "foo/bar/zoo", 8, nil); err != nil { t.Fatalf("err: %s", err) } - if err := g.InsertTxn(tx, "some/other/path", 9); err != nil { + if err := g.InsertTxn(tx, "some/other/path", 9, nil); err != nil { t.Fatalf("err: %s", err) } tx.Commit() @@ -217,8 +216,16 @@ func TestGraveyard_Snapshot_Restore(t *testing.T) { &Tombstone{Key: "foo/in/the/house", Index: 2}, &Tombstone{Key: "some/other/path", Index: 9}, } - if !reflect.DeepEqual(dump, expected) { - t.Fatalf("bad: %v", dump) + if len(expected) != len(dump) { + t.Fatalf("expected %d, got %d tombstones", len(expected), len(dump)) + } + for i, e := range expected { + if e.Key != dump[i].Key { + t.Fatalf("expected key %s, got %s", e.Key, dump[i].Key) + } + if e.Index != dump[i].Index { + t.Fatalf("expected key %s, got %s", e.Key, dump[i].Key) + } } // Make another state store and restore from the dump. @@ -255,8 +262,16 @@ func TestGraveyard_Snapshot_Restore(t *testing.T) { } return dump }() - if !reflect.DeepEqual(dump, expected) { - t.Fatalf("bad: %v", dump) + if len(expected) != len(dump) { + t.Fatalf("expected %d, got %d tombstones", len(expected), len(dump)) + } + for i, e := range expected { + if e.Key != dump[i].Key { + t.Fatalf("expected key %s, got %s", e.Key, dump[i].Key) + } + if e.Index != dump[i].Index { + t.Fatalf("expected idx %d, got %d", e.Index, dump[i].Index) + } } }() } diff --git a/agent/consul/state/kvs.go b/agent/consul/state/kvs.go index eea081c91d..61f40897a9 100644 --- a/agent/consul/state/kvs.go +++ b/agent/consul/state/kvs.go @@ -2,7 +2,6 @@ package state import ( "fmt" - "strings" "time" "github.com/hashicorp/consul/agent/structs" @@ -19,10 +18,7 @@ func kvsTableSchema() *memdb.TableSchema { Name: "id", AllowMissing: false, Unique: true, - Indexer: &memdb.StringFieldIndex{ - Field: "Key", - Lowercase: false, - }, + Indexer: kvsIndexer(), }, "session": &memdb.IndexSchema{ Name: "session", @@ -46,10 +42,7 @@ func tombstonesTableSchema() *memdb.TableSchema { Name: "id", AllowMissing: false, Unique: true, - Indexer: &memdb.StringFieldIndex{ - Field: "Key", - Lowercase: false, - }, + Indexer: kvsIndexer(), }, }, } @@ -76,13 +69,10 @@ func (s *Snapshot) Tombstones() (memdb.ResultIterator, error) { // KVS is used when restoring from a snapshot. Use KVSSet for general inserts. func (s *Restore) KVS(entry *structs.DirEntry) error { - if err := s.tx.Insert("kvs", entry); err != nil { + if err := s.store.insertKVTxn(s.tx, entry, true); err != nil { return fmt.Errorf("failed inserting kvs entry: %s", err) } - if err := indexUpdateMaxTxn(s.tx, entry.ModifyIndex, "kvs"); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } return nil } @@ -131,7 +121,7 @@ func (s *Store) KVSSet(idx uint64, entry *structs.DirEntry) error { // whatever the existing session is. func (s *Store) kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry, updateSession bool) error { // Retrieve an existing KV pair - existingNode, err := tx.First("kvs", "id", entry.Key) + existingNode, err := firstWithTxn(tx, "kvs", "id", entry.Key, &entry.EnterpriseMeta) if err != nil { return fmt.Errorf("failed kvs lookup: %s", err) } @@ -161,32 +151,31 @@ func (s *Store) kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry, up } // Store the kv pair in the state store and update the index. - if err := tx.Insert("kvs", entry); err != nil { + if err := s.insertKVTxn(tx, entry, false); err != nil { return fmt.Errorf("failed inserting kvs entry: %s", err) } - if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } return nil } // KVSGet is used to retrieve a key/value pair from the state store. -func (s *Store) KVSGet(ws memdb.WatchSet, key string) (uint64, *structs.DirEntry, error) { +func (s *Store) KVSGet(ws memdb.WatchSet, key string, entMeta *structs.EnterpriseMeta) (uint64, *structs.DirEntry, error) { tx := s.db.Txn(false) defer tx.Abort() - return s.kvsGetTxn(tx, ws, key) + return s.kvsGetTxn(tx, ws, key, entMeta) } // kvsGetTxn is the inner method that gets a KVS entry inside an existing // transaction. -func (s *Store) kvsGetTxn(tx *memdb.Txn, ws memdb.WatchSet, key string) (uint64, *structs.DirEntry, error) { +func (s *Store) kvsGetTxn(tx *memdb.Txn, + ws memdb.WatchSet, key string, entMeta *structs.EnterpriseMeta) (uint64, *structs.DirEntry, error) { + // Get the table index. - idx := maxIndexTxn(tx, "kvs", "tombstones") + idx := kvsMaxIndex(tx, entMeta) // Retrieve the key. - watchCh, entry, err := tx.FirstWatch("kvs", "id", key) + watchCh, entry, err := firstWatchWithTxn(tx, "kvs", "id", key, entMeta) if err != nil { return 0, nil, fmt.Errorf("failed kvs lookup: %s", err) } @@ -201,41 +190,32 @@ func (s *Store) kvsGetTxn(tx *memdb.Txn, ws memdb.WatchSet, key string) (uint64, // prefix is left empty, all keys in the KVS will be returned. The returned // is the max index of the returned kvs entries or applicable tombstones, or // else it's the full table indexes for kvs and tombstones. -func (s *Store) KVSList(ws memdb.WatchSet, prefix string) (uint64, structs.DirEntries, error) { +func (s *Store) KVSList(ws memdb.WatchSet, + prefix string, entMeta *structs.EnterpriseMeta) (uint64, structs.DirEntries, error) { + tx := s.db.Txn(false) defer tx.Abort() - return s.kvsListTxn(tx, ws, prefix) + return s.kvsListTxn(tx, ws, prefix, entMeta) } // kvsListTxn is the inner method that gets a list of KVS entries matching a // prefix. -func (s *Store) kvsListTxn(tx *memdb.Txn, ws memdb.WatchSet, prefix string) (uint64, structs.DirEntries, error) { - // Get the table indexes. - idx := maxIndexTxn(tx, "kvs", "tombstones") +func (s *Store) kvsListTxn(tx *memdb.Txn, + ws memdb.WatchSet, prefix string, entMeta *structs.EnterpriseMeta) (uint64, structs.DirEntries, error) { - // Query the prefix and list the available keys - entries, err := tx.Get("kvs", "id_prefix", prefix) + // Get the table indexes. + idx := kvsMaxIndex(tx, entMeta) + + lindex, entries, err := s.kvsListEntriesTxn(tx, ws, prefix, entMeta) if err != nil { return 0, nil, fmt.Errorf("failed kvs lookup: %s", err) } - ws.Add(entries.WatchCh()) - - // Gather all of the keys found in the store - var ents structs.DirEntries - var lindex uint64 - for entry := entries.Next(); entry != nil; entry = entries.Next() { - e := entry.(*structs.DirEntry) - ents = append(ents, e) - if e.ModifyIndex > lindex { - lindex = e.ModifyIndex - } - } // Check for the highest index in the graveyard. If the prefix is empty // then just use the full table indexes since we are listing everything. if prefix != "" { - gindex, err := s.kvsGraveyard.GetMaxIndexTxn(tx, prefix) + gindex, err := s.kvsGraveyard.GetMaxIndexTxn(tx, prefix, entMeta) if err != nil { return 0, nil, fmt.Errorf("failed graveyard lookup: %s", err) } @@ -251,92 +231,17 @@ func (s *Store) kvsListTxn(tx *memdb.Txn, ws memdb.WatchSet, prefix string) (uin if lindex != 0 { idx = lindex } - return idx, ents, nil -} - -// KVSListKeys is used to query the KV store for keys matching the given prefix. -// An optional separator may be specified, which can be used to slice off a part -// of the response so that only a subset of the prefix is returned. In this -// mode, the keys which are omitted are still counted in the returned index. -func (s *Store) KVSListKeys(ws memdb.WatchSet, prefix, sep string) (uint64, []string, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the table indexes. - idx := maxIndexTxn(tx, "kvs", "tombstones") - - // Fetch keys using the specified prefix - entries, err := tx.Get("kvs", "id_prefix", prefix) - if err != nil { - return 0, nil, fmt.Errorf("failed kvs lookup: %s", err) - } - ws.Add(entries.WatchCh()) - - prefixLen := len(prefix) - sepLen := len(sep) - - var keys []string - var lindex uint64 - var last string - for entry := entries.Next(); entry != nil; entry = entries.Next() { - e := entry.(*structs.DirEntry) - - // Accumulate the high index - if e.ModifyIndex > lindex { - lindex = e.ModifyIndex - } - - // Always accumulate if no separator provided - if sepLen == 0 { - keys = append(keys, e.Key) - continue - } - - // Parse and de-duplicate the returned keys based on the - // key separator, if provided. - after := e.Key[prefixLen:] - sepIdx := strings.Index(after, sep) - if sepIdx > -1 { - key := e.Key[:prefixLen+sepIdx+sepLen] - if key != last { - keys = append(keys, key) - last = key - } - } else { - keys = append(keys, e.Key) - } - } - - // Check for the highest index in the graveyard. If the prefix is empty - // then just use the full table indexes since we are listing everything. - if prefix != "" { - gindex, err := s.kvsGraveyard.GetMaxIndexTxn(tx, prefix) - if err != nil { - return 0, nil, fmt.Errorf("failed graveyard lookup: %s", err) - } - if gindex > lindex { - lindex = gindex - } - } else { - lindex = idx - } - - // Use the sub index if it was set and there are entries, otherwise use - // the full table index from above. - if lindex != 0 { - idx = lindex - } - return idx, keys, nil + return idx, entries, nil } // KVSDelete is used to perform a shallow delete on a single key in the // the state store. -func (s *Store) KVSDelete(idx uint64, key string) error { +func (s *Store) KVSDelete(idx uint64, key string, entMeta *structs.EnterpriseMeta) error { tx := s.db.Txn(true) defer tx.Abort() // Perform the actual delete - if err := s.kvsDeleteTxn(tx, idx, key); err != nil { + if err := s.kvsDeleteTxn(tx, idx, key, entMeta); err != nil { return err } @@ -346,9 +251,9 @@ func (s *Store) KVSDelete(idx uint64, key string) error { // kvsDeleteTxn is the inner method used to perform the actual deletion // of a key/value pair within an existing transaction. -func (s *Store) kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error { +func (s *Store) kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string, entMeta *structs.EnterpriseMeta) error { // Look up the entry in the state store. - entry, err := tx.First("kvs", "id", key) + entry, err := firstWithTxn(tx, "kvs", "id", key, entMeta) if err != nil { return fmt.Errorf("failed kvs lookup: %s", err) } @@ -357,30 +262,22 @@ func (s *Store) kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error { } // Create a tombstone. - if err := s.kvsGraveyard.InsertTxn(tx, key, idx); err != nil { + if err := s.kvsGraveyard.InsertTxn(tx, key, idx, entMeta); err != nil { return fmt.Errorf("failed adding to graveyard: %s", err) } - // Delete the entry and update the index. - if err := tx.Delete("kvs", entry); err != nil { - return fmt.Errorf("failed deleting kvs entry: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - - return nil + return s.kvsDeleteWithEntry(tx, entry.(*structs.DirEntry), idx) } // KVSDeleteCAS is used to try doing a KV delete operation with a given // raft index. If the CAS index specified is not equal to the last // observed index for the given key, then the call is a noop, otherwise // a normal KV delete is invoked. -func (s *Store) KVSDeleteCAS(idx, cidx uint64, key string) (bool, error) { +func (s *Store) KVSDeleteCAS(idx, cidx uint64, key string, entMeta *structs.EnterpriseMeta) (bool, error) { tx := s.db.Txn(true) defer tx.Abort() - set, err := s.kvsDeleteCASTxn(tx, idx, cidx, key) + set, err := s.kvsDeleteCASTxn(tx, idx, cidx, key, entMeta) if !set || err != nil { return false, err } @@ -391,9 +288,9 @@ func (s *Store) KVSDeleteCAS(idx, cidx uint64, key string) (bool, error) { // kvsDeleteCASTxn is the inner method that does a CAS delete within an existing // transaction. -func (s *Store) kvsDeleteCASTxn(tx *memdb.Txn, idx, cidx uint64, key string) (bool, error) { +func (s *Store) kvsDeleteCASTxn(tx *memdb.Txn, idx, cidx uint64, key string, entMeta *structs.EnterpriseMeta) (bool, error) { // Retrieve the existing kvs entry, if any exists. - entry, err := tx.First("kvs", "id", key) + entry, err := firstWithTxn(tx, "kvs", "id", key, entMeta) if err != nil { return false, fmt.Errorf("failed kvs lookup: %s", err) } @@ -407,7 +304,7 @@ func (s *Store) kvsDeleteCASTxn(tx *memdb.Txn, idx, cidx uint64, key string) (bo } // Call the actual deletion if the above passed. - if err := s.kvsDeleteTxn(tx, idx, key); err != nil { + if err := s.kvsDeleteTxn(tx, idx, key, entMeta); err != nil { return false, err } return true, nil @@ -434,7 +331,7 @@ func (s *Store) KVSSetCAS(idx uint64, entry *structs.DirEntry) (bool, error) { // transaction. func (s *Store) kvsSetCASTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) (bool, error) { // Retrieve the existing entry. - existing, err := tx.First("kvs", "id", entry.Key) + existing, err := firstWithTxn(tx, "kvs", "id", entry.Key, &entry.EnterpriseMeta) if err != nil { return false, fmt.Errorf("failed kvs lookup: %s", err) } @@ -462,11 +359,11 @@ func (s *Store) kvsSetCASTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) // KVSDeleteTree is used to do a recursive delete on a key prefix // in the state store. If any keys are modified, the last index is // set, otherwise this is a no-op. -func (s *Store) KVSDeleteTree(idx uint64, prefix string) error { +func (s *Store) KVSDeleteTree(idx uint64, prefix string, entMeta *structs.EnterpriseMeta) error { tx := s.db.Txn(true) defer tx.Abort() - if err := s.kvsDeleteTreeTxn(tx, idx, prefix); err != nil { + if err := s.kvsDeleteTreeTxn(tx, idx, prefix, entMeta); err != nil { return err } @@ -474,35 +371,10 @@ func (s *Store) KVSDeleteTree(idx uint64, prefix string) error { return nil } -// kvsDeleteTreeTxn is the inner method that does a recursive delete inside an -// existing transaction. -func (s *Store) kvsDeleteTreeTxn(tx *memdb.Txn, idx uint64, prefix string) error { - - // For prefix deletes, only insert one tombstone and delete the entire subtree - - deleted, err := tx.DeletePrefix("kvs", "id_prefix", prefix) - - if err != nil { - return fmt.Errorf("failed recursive deleting kvs entry: %s", err) - } - - if deleted { - if prefix != "" { // don't insert a tombstone if the entire tree is deleted, all watchers on keys will see the max_index of the tree - if err := s.kvsGraveyard.InsertTxn(tx, prefix, idx); err != nil { - return fmt.Errorf("failed adding to graveyard: %s", err) - } - } - if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - } - return nil -} - // KVSLockDelay returns the expiration time for any lock delay associated with // the given key. -func (s *Store) KVSLockDelay(key string) time.Time { - return s.lockDelay.GetExpiration(key) +func (s *Store) KVSLockDelay(key string, entMeta *structs.EnterpriseMeta) time.Time { + return s.lockDelay.GetExpiration(key, entMeta) } // KVSLock is similar to KVSSet but only performs the set if the lock can be @@ -529,7 +401,7 @@ func (s *Store) kvsLockTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) ( } // Verify that the session exists. - sess, err := tx.First("sessions", "id", entry.Session) + sess, err := firstWithTxn(tx, "sessions", "id", entry.Session, &entry.EnterpriseMeta) if err != nil { return false, fmt.Errorf("failed session lookup: %s", err) } @@ -538,7 +410,7 @@ func (s *Store) kvsLockTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) ( } // Retrieve the existing entry. - existing, err := tx.First("kvs", "id", entry.Key) + existing, err := firstWithTxn(tx, "kvs", "id", entry.Key, &entry.EnterpriseMeta) if err != nil { return false, fmt.Errorf("failed kvs lookup: %s", err) } @@ -595,7 +467,7 @@ func (s *Store) kvsUnlockTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) } // Retrieve the existing entry. - existing, err := tx.First("kvs", "id", entry.Key) + existing, err := firstWithTxn(tx, "kvs", "id", entry.Key, &entry.EnterpriseMeta) if err != nil { return false, fmt.Errorf("failed kvs lookup: %s", err) } @@ -626,8 +498,10 @@ func (s *Store) kvsUnlockTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) // kvsCheckSessionTxn checks to see if the given session matches the current // entry for a key. -func (s *Store) kvsCheckSessionTxn(tx *memdb.Txn, key string, session string) (*structs.DirEntry, error) { - entry, err := tx.First("kvs", "id", key) +func (s *Store) kvsCheckSessionTxn(tx *memdb.Txn, + key string, session string, entMeta *structs.EnterpriseMeta) (*structs.DirEntry, error) { + + entry, err := firstWithTxn(tx, "kvs", "id", key, entMeta) if err != nil { return nil, fmt.Errorf("failed kvs lookup: %s", err) } @@ -645,8 +519,10 @@ func (s *Store) kvsCheckSessionTxn(tx *memdb.Txn, key string, session string) (* // kvsCheckIndexTxn checks to see if the given modify index matches the current // entry for a key. -func (s *Store) kvsCheckIndexTxn(tx *memdb.Txn, key string, cidx uint64) (*structs.DirEntry, error) { - entry, err := tx.First("kvs", "id", key) +func (s *Store) kvsCheckIndexTxn(tx *memdb.Txn, + key string, cidx uint64, entMeta *structs.EnterpriseMeta) (*structs.DirEntry, error) { + + entry, err := firstWithTxn(tx, "kvs", "id", key, entMeta) if err != nil { return nil, fmt.Errorf("failed kvs lookup: %s", err) } diff --git a/agent/consul/state/kvs_oss.go b/agent/consul/state/kvs_oss.go new file mode 100644 index 0000000000..76dcb2ab56 --- /dev/null +++ b/agent/consul/state/kvs_oss.go @@ -0,0 +1,95 @@ +// +build !consulent + +package state + +import ( + "fmt" + + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-memdb" +) + +func kvsIndexer() *memdb.StringFieldIndex { + return &memdb.StringFieldIndex{ + Field: "Key", + Lowercase: false, + } +} + +func (s *Store) insertKVTxn(tx *memdb.Txn, entry *structs.DirEntry, updateMax bool) error { + if err := tx.Insert("kvs", entry); err != nil { + return err + } + + if updateMax { + if err := indexUpdateMaxTxn(tx, entry.ModifyIndex, "kvs"); err != nil { + return fmt.Errorf("failed updating kvs index: %v", err) + } + } else { + if err := tx.Insert("index", &IndexEntry{"kvs", entry.ModifyIndex}); err != nil { + return fmt.Errorf("failed updating kvs index: %s", err) + } + } + return nil +} + +func (s *Store) kvsListEntriesTxn(tx *memdb.Txn, ws memdb.WatchSet, prefix string, entMeta *structs.EnterpriseMeta) (uint64, structs.DirEntries, error) { + var ents structs.DirEntries + var lindex uint64 + + entries, err := tx.Get("kvs", "id_prefix", prefix) + if err != nil { + return 0, nil, fmt.Errorf("failed kvs lookup: %s", err) + } + ws.Add(entries.WatchCh()) + + // Gather all of the keys found + for entry := entries.Next(); entry != nil; entry = entries.Next() { + e := entry.(*structs.DirEntry) + ents = append(ents, e) + if e.ModifyIndex > lindex { + lindex = e.ModifyIndex + } + } + return lindex, ents, nil +} + +// kvsDeleteTreeTxn is the inner method that does a recursive delete inside an +// existing transaction. +func (s *Store) kvsDeleteTreeTxn(tx *memdb.Txn, idx uint64, prefix string, entMeta *structs.EnterpriseMeta) error { + // For prefix deletes, only insert one tombstone and delete the entire subtree + deleted, err := tx.DeletePrefix("kvs", "id_prefix", prefix) + if err != nil { + return fmt.Errorf("failed recursive deleting kvs entry: %s", err) + } + + if deleted { + if prefix != "" { // don't insert a tombstone if the entire tree is deleted, all watchers on keys will see the max_index of the tree + if err := s.kvsGraveyard.InsertTxn(tx, prefix, idx, entMeta); err != nil { + return fmt.Errorf("failed adding to graveyard: %s", err) + } + } + + if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + } + return nil +} + +func kvsMaxIndex(tx *memdb.Txn, entMeta *structs.EnterpriseMeta) uint64 { + return maxIndexTxn(tx, "kvs", "tombstones") +} + +func (s *Store) kvsDeleteWithEntry(tx *memdb.Txn, entry *structs.DirEntry, idx uint64) error { + // Delete the entry and update the index. + if err := tx.Delete("kvs", entry); err != nil { + return fmt.Errorf("failed deleting kvs entry: %s", err) + } + + if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil { + return fmt.Errorf("failed updating kvs index: %s", err) + } + + return nil +} diff --git a/agent/consul/state/kvs_test.go b/agent/consul/state/kvs_test.go index e9f088c6df..5de0f72576 100644 --- a/agent/consul/state/kvs_test.go +++ b/agent/consul/state/kvs_test.go @@ -12,6 +12,74 @@ import ( "github.com/hashicorp/go-memdb" ) +func TestStateStore_ReapTombstones(t *testing.T) { + s := testStateStore(t) + + // Create some KV pairs. + testSetKey(t, s, 1, "foo", "foo", nil) + testSetKey(t, s, 2, "foo/bar", "bar", nil) + testSetKey(t, s, 3, "foo/baz", "bar", nil) + testSetKey(t, s, 4, "foo/moo", "bar", nil) + testSetKey(t, s, 5, "foo/zoo", "bar", nil) + + // Call a delete on some specific keys. + if err := s.KVSDelete(6, "foo/baz", nil); err != nil { + t.Fatalf("err: %s", err) + } + if err := s.KVSDelete(7, "foo/moo", nil); err != nil { + t.Fatalf("err: %s", err) + } + + // Pull out the list and check the index, which should come from the + // tombstones. + idx, _, err := s.KVSList(nil, "foo/", nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Reap the tombstones <= 6. + if err := s.ReapTombstones(6); err != nil { + t.Fatalf("err: %s", err) + } + + // Should still be good because 7 is in there. + idx, _, err = s.KVSList(nil, "foo/", nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Now reap them all. + if err := s.ReapTombstones(7); err != nil { + t.Fatalf("err: %s", err) + } + + // At this point the sub index will slide backwards. + idx, _, err = s.KVSList(nil, "foo/", nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // Make sure the tombstones are actually gone. + snap := s.Snapshot() + defer snap.Close() + stones, err := snap.Tombstones() + if err != nil { + t.Fatalf("err: %s", err) + } + if stones.Next() != nil { + t.Fatalf("unexpected extra tombstones") + } +} + func TestStateStore_GC(t *testing.T) { // Build up a fast GC. ttl := 10 * time.Millisecond @@ -29,14 +97,14 @@ func TestStateStore_GC(t *testing.T) { } // Create some KV pairs. - testSetKey(t, s, 1, "foo", "foo") - testSetKey(t, s, 2, "foo/bar", "bar") - testSetKey(t, s, 3, "foo/baz", "bar") - testSetKey(t, s, 4, "foo/moo", "bar") - testSetKey(t, s, 5, "foo/zoo", "bar") + testSetKey(t, s, 1, "foo", "foo", nil) + testSetKey(t, s, 2, "foo/bar", "bar", nil) + testSetKey(t, s, 3, "foo/baz", "bar", nil) + testSetKey(t, s, 4, "foo/moo", "bar", nil) + testSetKey(t, s, 5, "foo/zoo", "bar", nil) // Delete a key and make sure the GC sees it. - if err := s.KVSDelete(6, "foo/zoo"); err != nil { + if err := s.KVSDelete(6, "foo/zoo", nil); err != nil { t.Fatalf("err: %s", err) } select { @@ -49,7 +117,7 @@ func TestStateStore_GC(t *testing.T) { } // Check for the same behavior with a tree delete. - if err := s.KVSDeleteTree(7, "foo/moo"); err != nil { + if err := s.KVSDeleteTree(7, "foo/moo", nil); err != nil { t.Fatalf("err: %s", err) } select { @@ -62,7 +130,7 @@ func TestStateStore_GC(t *testing.T) { } // Check for the same behavior with a CAS delete. - if ok, err := s.KVSDeleteCAS(8, 3, "foo/baz"); !ok || err != nil { + if ok, err := s.KVSDeleteCAS(8, 3, "foo/baz", nil); !ok || err != nil { t.Fatalf("err: %s", err) } select { @@ -91,7 +159,7 @@ func TestStateStore_GC(t *testing.T) { if ok, err := s.KVSLock(11, d); !ok || err != nil { t.Fatalf("err: %v", err) } - if err := s.SessionDestroy(12, session.ID); err != nil { + if err := s.SessionDestroy(12, session.ID, nil); err != nil { t.Fatalf("err: %s", err) } select { @@ -104,80 +172,12 @@ func TestStateStore_GC(t *testing.T) { } } -func TestStateStore_ReapTombstones(t *testing.T) { - s := testStateStore(t) - - // Create some KV pairs. - testSetKey(t, s, 1, "foo", "foo") - testSetKey(t, s, 2, "foo/bar", "bar") - testSetKey(t, s, 3, "foo/baz", "bar") - testSetKey(t, s, 4, "foo/moo", "bar") - testSetKey(t, s, 5, "foo/zoo", "bar") - - // Call a delete on some specific keys. - if err := s.KVSDelete(6, "foo/baz"); err != nil { - t.Fatalf("err: %s", err) - } - if err := s.KVSDelete(7, "foo/moo"); err != nil { - t.Fatalf("err: %s", err) - } - - // Pull out the list and check the index, which should come from the - // tombstones. - idx, _, err := s.KVSList(nil, "foo/") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 7 { - t.Fatalf("bad index: %d", idx) - } - - // Reap the tombstones <= 6. - if err := s.ReapTombstones(6); err != nil { - t.Fatalf("err: %s", err) - } - - // Should still be good because 7 is in there. - idx, _, err = s.KVSList(nil, "foo/") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 7 { - t.Fatalf("bad index: %d", idx) - } - - // Now reap them all. - if err := s.ReapTombstones(7); err != nil { - t.Fatalf("err: %s", err) - } - - // At this point the sub index will slide backwards. - idx, _, err = s.KVSList(nil, "foo/") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 5 { - t.Fatalf("bad index: %d", idx) - } - - // Make sure the tombstones are actually gone. - snap := s.Snapshot() - defer snap.Close() - stones, err := snap.Tombstones() - if err != nil { - t.Fatalf("err: %s", err) - } - if stones.Next() != nil { - t.Fatalf("unexpected extra tombstones") - } -} - func TestStateStore_KVSSet_KVSGet(t *testing.T) { s := testStateStore(t) // Get on an nonexistent key returns nil. ws := memdb.NewWatchSet() - idx, result, err := s.KVSGet(ws, "foo") + idx, result, err := s.KVSGet(ws, "foo", nil) if result != nil || err != nil || idx != 0 { t.Fatalf("expected (0, nil, nil), got : (%#v, %#v, %#v)", idx, result, err) } @@ -196,7 +196,7 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { // Retrieve the K/V entry again. ws = memdb.NewWatchSet() - idx, result, err = s.KVSGet(ws, "foo") + idx, result, err = s.KVSGet(ws, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -231,7 +231,7 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { // Fetch the kv pair and check. ws = memdb.NewWatchSet() - idx, result, err = s.KVSGet(ws, "foo") + idx, result, err = s.KVSGet(ws, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -260,7 +260,7 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { // Fetch the kv pair and check. ws = memdb.NewWatchSet() - idx, result, err = s.KVSGet(ws, "foo") + idx, result, err = s.KVSGet(ws, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -298,7 +298,7 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { // Fetch the kv pair and check. ws = memdb.NewWatchSet() - idx, result, err = s.KVSGet(ws, "foo") + idx, result, err = s.KVSGet(ws, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -330,7 +330,7 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { // Fetch the kv pair and check. ws = memdb.NewWatchSet() - idx, result, err = s.KVSGet(ws, "foo") + idx, result, err = s.KVSGet(ws, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -348,14 +348,14 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { } // Setting some unrelated key should not fire the watch. - testSetKey(t, s, 8, "other", "yup") + testSetKey(t, s, 8, "other", "yup", nil) if watchFired(ws) { t.Fatalf("bad") } // Fetch a key that doesn't exist and make sure we get the right // response. - idx, result, err = s.KVSGet(nil, "nope") + idx, result, err = s.KVSGet(nil, "nope", nil) if result != nil || err != nil || idx != 8 { t.Fatalf("expected (8, nil, nil), got : (%#v, %#v, %#v)", idx, result, err) } @@ -370,7 +370,7 @@ func TestStateStore_KVSSet_KVSGet(t *testing.T) { require.Nil(t, s.KVSSet(1, entry)) require.Nil(t, s.KVSSet(2, entry)) - idx, _, err = s.KVSGet(ws, entry.Key) + idx, _, err = s.KVSGet(ws, entry.Key, nil) require.Nil(t, err) require.Equal(t, uint64(1), idx) @@ -381,23 +381,23 @@ func TestStateStore_KVSList(t *testing.T) { // Listing an empty KVS returns nothing ws := memdb.NewWatchSet() - idx, entries, err := s.KVSList(ws, "") + idx, entries, err := s.KVSList(ws, "", nil) if idx != 0 || entries != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, entries, err) } // Create some KVS entries - testSetKey(t, s, 1, "foo", "foo") - testSetKey(t, s, 2, "foo/bar", "bar") - testSetKey(t, s, 3, "foo/bar/zip", "zip") - testSetKey(t, s, 4, "foo/bar/zip/zorp", "zorp") - testSetKey(t, s, 5, "foo/bar/baz", "baz") + testSetKey(t, s, 1, "foo", "foo", nil) + testSetKey(t, s, 2, "foo/bar", "bar", nil) + testSetKey(t, s, 3, "foo/bar/zip", "zip", nil) + testSetKey(t, s, 4, "foo/bar/zip/zorp", "zorp", nil) + testSetKey(t, s, 5, "foo/bar/baz", "baz", nil) if !watchFired(ws) { t.Fatalf("bad") } // List out all of the keys - idx, entries, err = s.KVSList(nil, "") + idx, entries, err = s.KVSList(nil, "", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -411,7 +411,7 @@ func TestStateStore_KVSList(t *testing.T) { } // Try listing with a provided prefix - idx, entries, err = s.KVSList(nil, "foo/bar/zip") + idx, entries, err = s.KVSList(nil, "foo/bar/zip", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -429,18 +429,18 @@ func TestStateStore_KVSList(t *testing.T) { // Delete a key and make sure the index comes from the tombstone. ws = memdb.NewWatchSet() - idx, _, err = s.KVSList(ws, "foo/bar/baz") + idx, _, err = s.KVSList(ws, "foo/bar/baz", nil) if err != nil { t.Fatalf("err: %s", err) } - if err := s.KVSDelete(6, "foo/bar/baz"); err != nil { + if err := s.KVSDelete(6, "foo/bar/baz", nil); err != nil { t.Fatalf("err: %s", err) } if !watchFired(ws) { t.Fatalf("bad") } ws = memdb.NewWatchSet() - idx, _, err = s.KVSList(ws, "foo/bar/baz") + idx, _, err = s.KVSList(ws, "foo/bar/baz", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -450,13 +450,13 @@ func TestStateStore_KVSList(t *testing.T) { // Set a different key to bump the index. This shouldn't fire the // watch since there's a different prefix. - testSetKey(t, s, 7, "some/other/key", "") + testSetKey(t, s, 7, "some/other/key", "", nil) if watchFired(ws) { t.Fatalf("bad") } // Make sure we get the right index from the tombstone. - idx, _, err = s.KVSList(nil, "foo/bar/baz") + idx, _, err = s.KVSList(nil, "foo/bar/baz", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -469,7 +469,7 @@ func TestStateStore_KVSList(t *testing.T) { if err := s.ReapTombstones(6); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList(nil, "foo/bar/baz") + idx, _, err = s.KVSList(nil, "foo/bar/baz", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -478,7 +478,7 @@ func TestStateStore_KVSList(t *testing.T) { } // List all the keys to make sure the index is also correct. - idx, _, err = s.KVSList(nil, "") + idx, _, err = s.KVSList(nil, "", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -487,148 +487,22 @@ func TestStateStore_KVSList(t *testing.T) { } } -func TestStateStore_KVSListKeys(t *testing.T) { - s := testStateStore(t) - - // Listing keys with no results returns nil. - ws := memdb.NewWatchSet() - idx, keys, err := s.KVSListKeys(ws, "", "") - if idx != 0 || keys != nil || err != nil { - t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, keys, err) - } - - // Create some keys. - testSetKey(t, s, 1, "foo", "foo") - testSetKey(t, s, 2, "foo/bar", "bar") - testSetKey(t, s, 3, "foo/bar/baz", "baz") - testSetKey(t, s, 4, "foo/bar/zip", "zip") - testSetKey(t, s, 5, "foo/bar/zip/zam", "zam") - testSetKey(t, s, 6, "foo/bar/zip/zorp", "zorp") - testSetKey(t, s, 7, "some/other/prefix", "nack") - if !watchFired(ws) { - t.Fatalf("bad") - } - - // List all the keys. - idx, keys, err = s.KVSListKeys(nil, "", "") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(keys) != 7 { - t.Fatalf("bad keys: %#v", keys) - } - if idx != 7 { - t.Fatalf("bad index: %d", idx) - } - - // Query using a prefix and pass a separator. - idx, keys, err = s.KVSListKeys(nil, "foo/bar/", "/") - if err != nil { - t.Fatalf("err: %s", err) - } - if len(keys) != 3 { - t.Fatalf("bad keys: %#v", keys) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - - // Subset of the keys was returned. - expect := []string{"foo/bar/baz", "foo/bar/zip", "foo/bar/zip/"} - if !reflect.DeepEqual(keys, expect) { - t.Fatalf("bad keys: %#v", keys) - } - - // Listing keys with no separator returns everything. - idx, keys, err = s.KVSListKeys(nil, "foo", "") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 6 { - t.Fatalf("bad index: %d", idx) - } - expect = []string{"foo", "foo/bar", "foo/bar/baz", "foo/bar/zip", - "foo/bar/zip/zam", "foo/bar/zip/zorp"} - if !reflect.DeepEqual(keys, expect) { - t.Fatalf("bad keys: %#v", keys) - } - - // Delete a key and make sure the index comes from the tombstone. - ws = memdb.NewWatchSet() - idx, _, err = s.KVSListKeys(ws, "foo/bar/baz", "") - if err != nil { - t.Fatalf("err: %s", err) - } - if err := s.KVSDelete(8, "foo/bar/baz"); err != nil { - t.Fatalf("err: %s", err) - } - if !watchFired(ws) { - t.Fatalf("bad") - } - ws = memdb.NewWatchSet() - idx, _, err = s.KVSListKeys(ws, "foo/bar/baz", "") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 8 { - t.Fatalf("bad index: %d", idx) - } - - // Set a different key to bump the index. This shouldn't fire the watch - // since there's a different prefix. - testSetKey(t, s, 9, "some/other/key", "") - if watchFired(ws) { - t.Fatalf("bad") - } - - // Make sure the index still comes from the tombstone. - idx, _, err = s.KVSListKeys(nil, "foo/bar/baz", "") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 8 { - t.Fatalf("bad index: %d", idx) - } - - // Now reap the tombstones and make sure we get the latest index - // since there are no matching keys. - if err := s.ReapTombstones(8); err != nil { - t.Fatalf("err: %s", err) - } - idx, _, err = s.KVSListKeys(nil, "foo/bar/baz", "") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 9 { - t.Fatalf("bad index: %d", idx) - } - - // List all the keys to make sure the index is also correct. - idx, _, err = s.KVSListKeys(nil, "", "") - if err != nil { - t.Fatalf("err: %s", err) - } - if idx != 9 { - t.Fatalf("bad index: %d", idx) - } -} - func TestStateStore_KVSDelete(t *testing.T) { s := testStateStore(t) // Create some KV pairs - testSetKey(t, s, 1, "foo", "foo") - testSetKey(t, s, 2, "foo/bar", "bar") + testSetKey(t, s, 1, "foo", "foo", nil) + testSetKey(t, s, 2, "foo/bar", "bar", nil) // Call a delete on a specific key - if err := s.KVSDelete(3, "foo"); err != nil { + if err := s.KVSDelete(3, "foo", nil); err != nil { t.Fatalf("err: %s", err) } // The entry was removed from the state store tx := s.db.Txn(false) defer tx.Abort() - e, err := tx.First("kvs", "id", "foo") + e, err := firstWithTxn(tx, "kvs", "id", "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -637,7 +511,7 @@ func TestStateStore_KVSDelete(t *testing.T) { } // Try fetching the other keys to ensure they still exist - e, err = tx.First("kvs", "id", "foo/bar") + e, err = firstWithTxn(tx, "kvs", "id", "foo/bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -652,7 +526,7 @@ func TestStateStore_KVSDelete(t *testing.T) { // Check that the tombstone was created and that prevents the index // from sliding backwards. - idx, _, err := s.KVSList(nil, "foo") + idx, _, err := s.KVSList(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -665,7 +539,7 @@ func TestStateStore_KVSDelete(t *testing.T) { if err := s.ReapTombstones(3); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList(nil, "foo") + idx, _, err = s.KVSList(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -675,7 +549,7 @@ func TestStateStore_KVSDelete(t *testing.T) { // Deleting a nonexistent key should be idempotent and not return an // error - if err := s.KVSDelete(4, "foo"); err != nil { + if err := s.KVSDelete(4, "foo", nil); err != nil { t.Fatalf("err: %s", err) } if idx := s.maxIndex("kvs"); idx != 3 { @@ -687,19 +561,19 @@ func TestStateStore_KVSDeleteCAS(t *testing.T) { s := testStateStore(t) // Create some KV entries - testSetKey(t, s, 1, "foo", "foo") - testSetKey(t, s, 2, "bar", "bar") - testSetKey(t, s, 3, "baz", "baz") + testSetKey(t, s, 1, "foo", "foo", nil) + testSetKey(t, s, 2, "bar", "bar", nil) + testSetKey(t, s, 3, "baz", "baz", nil) // Do a CAS delete with an index lower than the entry - ok, err := s.KVSDeleteCAS(4, 1, "bar") + ok, err := s.KVSDeleteCAS(4, 1, "bar", nil) if ok || err != nil { t.Fatalf("expected (false, nil), got: (%v, %#v)", ok, err) } // Check that the index is untouched and the entry // has not been deleted. - idx, e, err := s.KVSGet(nil, "foo") + idx, e, err := s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -712,13 +586,13 @@ func TestStateStore_KVSDeleteCAS(t *testing.T) { // Do another CAS delete, this time with the correct index // which should cause the delete to take place. - ok, err = s.KVSDeleteCAS(4, 2, "bar") + ok, err = s.KVSDeleteCAS(4, 2, "bar", nil) if !ok || err != nil { t.Fatalf("expected (true, nil), got: (%v, %#v)", ok, err) } // Entry was deleted and index was updated - idx, e, err = s.KVSGet(nil, "bar") + idx, e, err = s.KVSGet(nil, "bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -730,11 +604,11 @@ func TestStateStore_KVSDeleteCAS(t *testing.T) { } // Add another key to bump the index. - testSetKey(t, s, 5, "some/other/key", "baz") + testSetKey(t, s, 5, "some/other/key", "baz", nil) // Check that the tombstone was created and that prevents the index // from sliding backwards. - idx, _, err = s.KVSList(nil, "bar") + idx, _, err = s.KVSList(nil, "bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -747,7 +621,7 @@ func TestStateStore_KVSDeleteCAS(t *testing.T) { if err := s.ReapTombstones(4); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList(nil, "bar") + idx, _, err = s.KVSList(nil, "bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -757,7 +631,7 @@ func TestStateStore_KVSDeleteCAS(t *testing.T) { // A delete on a nonexistent key should be idempotent and not return an // error - ok, err = s.KVSDeleteCAS(6, 2, "bar") + ok, err = s.KVSDeleteCAS(6, 2, "bar", nil) if !ok || err != nil { t.Fatalf("expected (true, nil), got: (%v, %#v)", ok, err) } @@ -786,7 +660,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { // Check that nothing was actually stored tx := s.db.Txn(false) - if e, err := tx.First("kvs", "id", "foo"); e != nil || err != nil { + if e, err := firstWithTxn(tx, "kvs", "id", "foo", nil); e != nil || err != nil { t.Fatalf("expected (nil, nil), got: (%#v, %#v)", e, err) } tx.Abort() @@ -812,7 +686,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was inserted - idx, entry, err := s.KVSGet(nil, "foo") + idx, entry, err := s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -854,7 +728,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was not updated in the store - idx, entry, err = s.KVSGet(nil, "foo") + idx, entry, err = s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -881,7 +755,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was updated - idx, entry, err = s.KVSGet(nil, "foo") + idx, entry, err = s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -908,7 +782,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was updated, but the session should have been ignored. - idx, entry, err = s.KVSGet(nil, "foo") + idx, entry, err = s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -953,7 +827,7 @@ func TestStateStore_KVSSetCAS(t *testing.T) { } // Entry was updated, and the lock status should have stayed the same. - idx, entry, err = s.KVSGet(nil, "foo") + idx, entry, err = s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -970,14 +844,14 @@ func TestStateStore_KVSDeleteTree(t *testing.T) { s := testStateStore(t) // Create kvs entries in the state store. - testSetKey(t, s, 1, "foo/bar", "bar") - testSetKey(t, s, 2, "foo/bar/baz", "baz") - testSetKey(t, s, 3, "foo/bar/zip", "zip") - testSetKey(t, s, 4, "foo/zorp", "zorp") + testSetKey(t, s, 1, "foo/bar", "bar", nil) + testSetKey(t, s, 2, "foo/bar/baz", "baz", nil) + testSetKey(t, s, 3, "foo/bar/zip", "zip", nil) + testSetKey(t, s, 4, "foo/zorp", "zorp", nil) // Calling tree deletion which affects nothing does not // modify the table index. - if err := s.KVSDeleteTree(9, "bar"); err != nil { + if err := s.KVSDeleteTree(9, "bar", nil); err != nil { t.Fatalf("err: %s", err) } if idx := s.maxIndex("kvs"); idx != 4 { @@ -985,7 +859,7 @@ func TestStateStore_KVSDeleteTree(t *testing.T) { } // Call tree deletion with a nested prefix. - if err := s.KVSDeleteTree(5, "foo/bar"); err != nil { + if err := s.KVSDeleteTree(5, "foo/bar", nil); err != nil { t.Fatalf("err: %s", err) } @@ -1017,7 +891,7 @@ func TestStateStore_KVSDeleteTree(t *testing.T) { // Check that the tombstones ware created and that prevents the index // from sliding backwards. - idx, _, err := s.KVSList(nil, "foo") + idx, _, err := s.KVSList(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1030,7 +904,7 @@ func TestStateStore_KVSDeleteTree(t *testing.T) { if err := s.ReapTombstones(5); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList(nil, "foo") + idx, _, err = s.KVSList(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1043,15 +917,15 @@ func TestStateStore_Watches_PrefixDelete(t *testing.T) { s := testStateStore(t) // Create some KVS entries - testSetKey(t, s, 1, "foo", "foo") - testSetKey(t, s, 2, "foo/bar", "bar") - testSetKey(t, s, 3, "foo/bar/zip", "zip") - testSetKey(t, s, 4, "foo/bar/zip/zorp", "zorp") - testSetKey(t, s, 5, "foo/bar/zip/zap", "zap") - testSetKey(t, s, 6, "foo/nope", "nope") + testSetKey(t, s, 1, "foo", "foo", nil) + testSetKey(t, s, 2, "foo/bar", "bar", nil) + testSetKey(t, s, 3, "foo/bar/zip", "zip", nil) + testSetKey(t, s, 4, "foo/bar/zip/zorp", "zorp", nil) + testSetKey(t, s, 5, "foo/bar/zip/zap", "zap", nil) + testSetKey(t, s, 6, "foo/nope", "nope", nil) ws := memdb.NewWatchSet() - got, _, err := s.KVSList(ws, "foo/bar") + got, _, err := s.KVSList(ws, "foo/bar", nil) if err != nil { t.Fatalf("unexpected err: %s", err) } @@ -1061,7 +935,7 @@ func TestStateStore_Watches_PrefixDelete(t *testing.T) { } // Delete a key and make sure the index comes from the tombstone. - if err := s.KVSDeleteTree(7, "foo/bar/zip"); err != nil { + if err := s.KVSDeleteTree(7, "foo/bar/zip", nil); err != nil { t.Fatalf("unexpected err: %s", err) } // Make sure watch fires @@ -1070,7 +944,7 @@ func TestStateStore_Watches_PrefixDelete(t *testing.T) { } //Verify index matches tombstone - got, _, err = s.KVSList(ws, "foo/bar") + got, _, err = s.KVSList(ws, "foo/bar", nil) if err != nil { t.Fatalf("unexpected err: %s", err) } @@ -1088,7 +962,7 @@ func TestStateStore_Watches_PrefixDelete(t *testing.T) { t.Fatalf("err: %s", err) } - got, _, err = s.KVSList(nil, "foo/bar") + got, _, err = s.KVSList(nil, "foo/bar", nil) wantIndex = 2 if err != nil { t.Fatalf("err: %s", err) @@ -1099,13 +973,13 @@ func TestStateStore_Watches_PrefixDelete(t *testing.T) { // Set a different key to bump the index. This shouldn't fire the // watch since there's a different prefix. - testSetKey(t, s, 8, "some/other/key", "") + testSetKey(t, s, 8, "some/other/key", "", nil) // Now ask for the index for a node within the prefix that was deleted // We expect to get the max index in the tree wantIndex = 8 ws = memdb.NewWatchSet() - got, _, err = s.KVSList(ws, "foo/bar/baz") + got, _, err = s.KVSList(ws, "foo/bar/baz", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1117,7 +991,7 @@ func TestStateStore_Watches_PrefixDelete(t *testing.T) { } // List all the keys to make sure the index returned is the max index - got, _, err = s.KVSList(nil, "") + got, _, err = s.KVSList(nil, "", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1126,11 +1000,11 @@ func TestStateStore_Watches_PrefixDelete(t *testing.T) { } // Delete all the keys, special case where tombstones are not inserted - if err := s.KVSDeleteTree(9, ""); err != nil { + if err := s.KVSDeleteTree(9, "", nil); err != nil { t.Fatalf("unexpected err: %s", err) } wantIndex = 9 - got, _, err = s.KVSList(nil, "/foo/bar") + got, _, err = s.KVSList(nil, "/foo/bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1145,7 +1019,7 @@ func TestStateStore_KVSLockDelay(t *testing.T) { // KVSLockDelay is exercised in the lock/unlock and session invalidation // cases below, so we just do a basic check on a nonexistent key here. - expires := s.KVSLockDelay("/not/there") + expires := s.KVSLockDelay("/not/there", nil) if expires.After(time.Now()) { t.Fatalf("bad: %v", expires) } @@ -1180,7 +1054,7 @@ func TestStateStore_KVSLock(t *testing.T) { } // Make sure the indexes got set properly. - idx, result, err := s.KVSGet(nil, "foo") + idx, result, err := s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1201,7 +1075,7 @@ func TestStateStore_KVSLock(t *testing.T) { // Make sure the indexes got set properly, note that the lock index // won't go up since we didn't lock it again. - idx, result, err = s.KVSGet(nil, "foo") + idx, result, err = s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1224,7 +1098,7 @@ func TestStateStore_KVSLock(t *testing.T) { } // Make sure the indexes got set properly. - idx, result, err = s.KVSGet(nil, "foo") + idx, result, err = s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1237,14 +1111,14 @@ func TestStateStore_KVSLock(t *testing.T) { } // Lock an existing key. - testSetKey(t, s, 8, "bar", "bar") + testSetKey(t, s, 8, "bar", "bar", nil) ok, err = s.KVSLock(9, &structs.DirEntry{Key: "bar", Value: []byte("xxx"), Session: session1}) if !ok || err != nil { t.Fatalf("didn't get the lock: %v %s", ok, err) } // Make sure the indexes got set properly. - idx, result, err = s.KVSGet(nil, "bar") + idx, result, err = s.KVSGet(nil, "bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1270,7 +1144,7 @@ func TestStateStore_KVSLock(t *testing.T) { } // Make sure the indexes didn't update. - idx, result, err = s.KVSGet(nil, "bar") + idx, result, err = s.KVSGet(nil, "bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1307,14 +1181,14 @@ func TestStateStore_KVSUnlock(t *testing.T) { } // Make a key and unlock it, without it being locked. - testSetKey(t, s, 4, "foo", "bar") + testSetKey(t, s, 4, "foo", "bar", nil) ok, err = s.KVSUnlock(5, &structs.DirEntry{Key: "foo", Value: []byte("baz"), Session: session1}) if ok || err != nil { t.Fatalf("didn't handle unlocking a non-locked key: %v %s", ok, err) } // Make sure the indexes didn't update. - idx, result, err := s.KVSGet(nil, "foo") + idx, result, err := s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1343,7 +1217,7 @@ func TestStateStore_KVSUnlock(t *testing.T) { } // Make sure the indexes didn't update. - idx, result, err = s.KVSGet(nil, "foo") + idx, result, err = s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1362,7 +1236,7 @@ func TestStateStore_KVSUnlock(t *testing.T) { } // Make sure the indexes got set properly. - idx, result, err = s.KVSGet(nil, "foo") + idx, result, err = s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1381,7 +1255,7 @@ func TestStateStore_KVSUnlock(t *testing.T) { } // Make sure the indexes didn't update. - idx, result, err = s.KVSGet(nil, "foo") + idx, result, err = s.KVSGet(nil, "foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1474,7 +1348,7 @@ func TestStateStore_KVS_Snapshot_Restore(t *testing.T) { restore.Commit() // Read the restored keys back out and verify they match. - idx, res, err := s.KVSList(nil, "") + idx, res, err := s.KVSList(nil, "", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1496,10 +1370,10 @@ func TestStateStore_Tombstone_Snapshot_Restore(t *testing.T) { s := testStateStore(t) // Insert a key and then delete it to create a tombstone. - testSetKey(t, s, 1, "foo/bar", "bar") - testSetKey(t, s, 2, "foo/bar/baz", "bar") - testSetKey(t, s, 3, "foo/bar/zoo", "bar") - if err := s.KVSDelete(4, "foo/bar"); err != nil { + testSetKey(t, s, 1, "foo/bar", "bar", nil) + testSetKey(t, s, 2, "foo/bar/baz", "bar", nil) + testSetKey(t, s, 3, "foo/bar/zoo", "bar", nil) + if err := s.KVSDelete(4, "foo/bar", nil); err != nil { t.Fatalf("err: %s", err) } @@ -1511,7 +1385,7 @@ func TestStateStore_Tombstone_Snapshot_Restore(t *testing.T) { if err := s.ReapTombstones(4); err != nil { t.Fatalf("err: %s", err) } - idx, _, err := s.KVSList(nil, "foo/bar") + idx, _, err := s.KVSList(nil, "foo/bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1548,7 +1422,7 @@ func TestStateStore_Tombstone_Snapshot_Restore(t *testing.T) { restore.Commit() // See if the stone works properly in a list query. - idx, _, err := s.KVSList(nil, "foo/bar") + idx, _, err := s.KVSList(nil, "foo/bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -1562,7 +1436,7 @@ func TestStateStore_Tombstone_Snapshot_Restore(t *testing.T) { if err := s.ReapTombstones(4); err != nil { t.Fatalf("err: %s", err) } - idx, _, err = s.KVSList(nil, "foo/bar") + idx, _, err = s.KVSList(nil, "foo/bar", nil) if err != nil { t.Fatalf("err: %s", err) } diff --git a/agent/consul/state/operations_oss.go b/agent/consul/state/operations_oss.go new file mode 100644 index 0000000000..48deec7863 --- /dev/null +++ b/agent/consul/state/operations_oss.go @@ -0,0 +1,26 @@ +// +build !consulent + +package state + +import ( + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-memdb" +) + +func firstWithTxn(tx *memdb.Txn, + table, index, idxVal string, entMeta *structs.EnterpriseMeta) (interface{}, error) { + + return tx.First(table, index, idxVal) +} + +func firstWatchWithTxn(tx *memdb.Txn, + table, index, idxVal string, entMeta *structs.EnterpriseMeta) (<-chan struct{}, interface{}, error) { + + return tx.FirstWatch(table, index, idxVal) +} + +func getWithTxn(tx *memdb.Txn, + table, index, idxVal string, entMeta *structs.EnterpriseMeta) (memdb.ResultIterator, error) { + + return tx.Get(table, index, idxVal) +} diff --git a/agent/consul/state/prepared_query.go b/agent/consul/state/prepared_query.go index 285355785e..89a8f83493 100644 --- a/agent/consul/state/prepared_query.go +++ b/agent/consul/state/prepared_query.go @@ -210,9 +210,9 @@ func (s *Store) preparedQuerySetTxn(tx *memdb.Txn, idx uint64, query *structs.Pr // Verify that the session exists. if query.Session != "" { - sess, err := tx.First("sessions", "id", query.Session) + sess, err := firstWithTxn(tx, "sessions", "id", query.Session, nil) if err != nil { - return fmt.Errorf("failed session lookup: %s", err) + return fmt.Errorf("invalid session: %v", err) } if sess == nil { return fmt.Errorf("invalid session %#v", query.Session) diff --git a/agent/consul/state/prepared_query_test.go b/agent/consul/state/prepared_query_test.go index 44495819e3..55eb774596 100644 --- a/agent/consul/state/prepared_query_test.go +++ b/agent/consul/state/prepared_query_test.go @@ -68,7 +68,7 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { // The set will still fail because the session is bogus. err = s.PreparedQuerySet(1, query) - if err == nil || !strings.Contains(err.Error(), "failed session lookup") { + if err == nil || !strings.Contains(err.Error(), "invalid session") { t.Fatalf("bad: %v", err) } diff --git a/agent/consul/state/session.go b/agent/consul/state/session.go index 9775ff639a..5bca46c3c9 100644 --- a/agent/consul/state/session.go +++ b/agent/consul/state/session.go @@ -19,18 +19,13 @@ func sessionsTableSchema() *memdb.TableSchema { Name: "id", AllowMissing: false, Unique: true, - Indexer: &memdb.UUIDFieldIndex{ - Field: "ID", - }, + Indexer: sessionIndexer(), }, "node": &memdb.IndexSchema{ Name: "node", AllowMissing: false, Unique: false, - Indexer: &memdb.StringFieldIndex{ - Field: "Node", - Lowercase: true, - }, + Indexer: nodeSessionsIndexer(), }, }, } @@ -108,28 +103,10 @@ func (s *Snapshot) Sessions() (memdb.ResultIterator, error) { // Session is used when restoring from a snapshot. For general inserts, use // SessionCreate. func (s *Restore) Session(sess *structs.Session) error { - // Insert the session. - if err := s.tx.Insert("sessions", sess); err != nil { + if err := s.store.insertSessionTxn(s.tx, sess, sess.ModifyIndex, true); err != nil { return fmt.Errorf("failed inserting session: %s", err) } - // Insert the check mappings. - for _, checkID := range sess.Checks { - mapping := &sessionCheck{ - Node: sess.Node, - CheckID: checkID, - Session: sess.ID, - } - if err := s.tx.Insert("session_checks", mapping); err != nil { - return fmt.Errorf("failed inserting session check mapping: %s", err) - } - } - - // Update the index. - if err := indexUpdateMaxTxn(s.tx, sess.ModifyIndex, "sessions"); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - return nil } @@ -206,44 +183,30 @@ func (s *Store) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Sessio } // Insert the session - if err := tx.Insert("sessions", sess); err != nil { + if err := s.insertSessionTxn(tx, sess, idx, false); err != nil { return fmt.Errorf("failed inserting session: %s", err) } - // Insert the check mappings - for _, checkID := range sess.Checks { - mapping := &sessionCheck{ - Node: sess.Node, - CheckID: checkID, - Session: sess.ID, - } - if err := tx.Insert("session_checks", mapping); err != nil { - return fmt.Errorf("failed inserting session check mapping: %s", err) - } - } - - // Update the index - if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - return nil } // SessionGet is used to retrieve an active session from the state store. -func (s *Store) SessionGet(ws memdb.WatchSet, sessionID string) (uint64, *structs.Session, error) { +func (s *Store) SessionGet(ws memdb.WatchSet, + sessionID string, entMeta *structs.EnterpriseMeta) (uint64, *structs.Session, error) { + tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, "sessions") + idx := s.sessionMaxIndex(tx, entMeta) // Look up the session by its ID - watchCh, session, err := tx.FirstWatch("sessions", "id", sessionID) + watchCh, session, err := firstWatchWithTxn(tx, "sessions", "id", sessionID, entMeta) if err != nil { return 0, nil, fmt.Errorf("failed session lookup: %s", err) } ws.Add(watchCh) + if session != nil { return idx, session.(*structs.Session), nil } @@ -251,15 +214,15 @@ func (s *Store) SessionGet(ws memdb.WatchSet, sessionID string) (uint64, *struct } // SessionList returns a slice containing all of the active sessions. -func (s *Store) SessionList(ws memdb.WatchSet) (uint64, structs.Sessions, error) { +func (s *Store) SessionList(ws memdb.WatchSet, entMeta *structs.EnterpriseMeta) (uint64, structs.Sessions, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, "sessions") + idx := s.sessionMaxIndex(tx, entMeta) // Query all of the active sessions. - sessions, err := tx.Get("sessions", "id") + sessions, err := getWithTxn(tx, "sessions", "id_prefix", "", entMeta) if err != nil { return 0, nil, fmt.Errorf("failed session lookup: %s", err) } @@ -276,24 +239,17 @@ func (s *Store) SessionList(ws memdb.WatchSet) (uint64, structs.Sessions, error) // NodeSessions returns a set of active sessions associated // with the given node ID. The returned index is the highest // index seen from the result set. -func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string) (uint64, structs.Sessions, error) { +func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string, entMeta *structs.EnterpriseMeta) (uint64, structs.Sessions, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, "sessions") + idx := s.sessionMaxIndex(tx, entMeta) // Get all of the sessions which belong to the node - sessions, err := tx.Get("sessions", "node", nodeID) + result, err := s.nodeSessionsTxn(tx, ws, nodeID, entMeta) if err != nil { - return 0, nil, fmt.Errorf("failed session lookup: %s", err) - } - ws.Add(sessions.WatchCh()) - - // Go over all of the sessions and return them as a slice - var result structs.Sessions - for session := sessions.Next(); session != nil; session = sessions.Next() { - result = append(result, session.(*structs.Session)) + return 0, nil, err } return idx, result, nil } @@ -301,12 +257,12 @@ func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string) (uint64, structs. // SessionDestroy is used to remove an active session. This will // implicitly invalidate the session and invoke the specified // session destroy behavior. -func (s *Store) SessionDestroy(idx uint64, sessionID string) error { +func (s *Store) SessionDestroy(idx uint64, sessionID string, entMeta *structs.EnterpriseMeta) error { tx := s.db.Txn(true) defer tx.Abort() // Call the session deletion. - if err := s.deleteSessionTxn(tx, idx, sessionID); err != nil { + if err := s.deleteSessionTxn(tx, idx, sessionID, entMeta); err != nil { return err } @@ -316,9 +272,9 @@ func (s *Store) SessionDestroy(idx uint64, sessionID string) error { // deleteSessionTxn is the inner method, which is used to do the actual // session deletion and handle session invalidation, etc. -func (s *Store) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string) error { +func (s *Store) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string, entMeta *structs.EnterpriseMeta) error { // Look up the session. - sess, err := tx.First("sessions", "id", sessionID) + sess, err := firstWithTxn(tx, "sessions", "id", sessionID, entMeta) if err != nil { return fmt.Errorf("failed session lookup: %s", err) } @@ -327,15 +283,12 @@ func (s *Store) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string) er } // Delete the session and write the new index. - if err := tx.Delete("sessions", sess); err != nil { - return fmt.Errorf("failed deleting session: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) + session := sess.(*structs.Session) + if err := s.sessionDeleteWithSession(tx, session, idx); err != nil { + return fmt.Errorf("failed deleting session: %v", err) } // Enforce the max lock delay. - session := sess.(*structs.Session) delay := session.LockDelay if delay > structs.MaxLockDelay { delay = structs.MaxLockDelay @@ -370,19 +323,19 @@ func (s *Store) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string) er // Apply the lock delay if present. if delay > 0 { - s.lockDelay.SetExpiration(e.Key, now, delay) + s.lockDelay.SetExpiration(e.Key, now, delay, entMeta) } } case structs.SessionKeysDelete: for _, obj := range kvs { e := obj.(*structs.DirEntry) - if err := s.kvsDeleteTxn(tx, idx, e.Key); err != nil { + if err := s.kvsDeleteTxn(tx, idx, e.Key, entMeta); err != nil { return fmt.Errorf("failed kvs delete: %s", err) } // Apply the lock delay if present. if delay > 0 { - s.lockDelay.SetExpiration(e.Key, now, delay) + s.lockDelay.SetExpiration(e.Key, now, delay, entMeta) } } default: diff --git a/agent/consul/state/session_oss.go b/agent/consul/state/session_oss.go new file mode 100644 index 0000000000..7edf6d02c6 --- /dev/null +++ b/agent/consul/state/session_oss.go @@ -0,0 +1,92 @@ +// +build !consulent + +package state + +import ( + "fmt" + + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-memdb" +) + +func sessionIndexer() *memdb.UUIDFieldIndex { + return &memdb.UUIDFieldIndex{ + Field: "ID", + } +} + +func nodeSessionsIndexer() *memdb.StringFieldIndex { + return &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + } +} + +func (s *Store) sessionDeleteWithSession(tx *memdb.Txn, session *structs.Session, idx uint64) error { + if err := tx.Delete("sessions", session); err != nil { + return fmt.Errorf("failed deleting session: %s", err) + } + + // Update the indexes + err := tx.Insert("index", &IndexEntry{"sessions", idx}) + if err != nil { + return fmt.Errorf("failed updating sessions index: %v", err) + } + return nil +} + +func (s *Store) insertSessionTxn(tx *memdb.Txn, session *structs.Session, idx uint64, updateMax bool) error { + if err := tx.Insert("sessions", session); err != nil { + return err + } + + // Insert the check mappings + for _, checkID := range session.Checks { + mapping := &sessionCheck{ + Node: session.Node, + CheckID: checkID, + Session: session.ID, + } + if err := tx.Insert("session_checks", mapping); err != nil { + return fmt.Errorf("failed inserting session check mapping: %s", err) + } + } + + // Update the index + if updateMax { + if err := indexUpdateMaxTxn(tx, idx, "sessions"); err != nil { + return fmt.Errorf("failed updating sessions index: %v", err) + } + } else { + err := tx.Insert("index", &IndexEntry{"sessions", idx}) + if err != nil { + return fmt.Errorf("failed updating sessions index: %v", err) + } + } + + return nil +} + +func (s *Store) allNodeSessionsTxn(tx *memdb.Txn, node string) (structs.Sessions, error) { + return s.nodeSessionsTxn(tx, nil, node, nil) +} + +func (s *Store) nodeSessionsTxn(tx *memdb.Txn, + ws memdb.WatchSet, node string, entMeta *structs.EnterpriseMeta) (structs.Sessions, error) { + + sessions, err := tx.Get("sessions", "node", node) + if err != nil { + return nil, fmt.Errorf("failed session lookup: %s", err) + } + ws.Add(sessions.WatchCh()) + + var result structs.Sessions + for session := sessions.Next(); session != nil; session = sessions.Next() { + result = append(result, session.(*structs.Session)) + } + return result, nil +} + +func (s *Store) sessionMaxIndex(tx *memdb.Txn, entMeta *structs.EnterpriseMeta) uint64 { + return maxIndexTxn(tx, "sessions") +} diff --git a/agent/consul/state/session_test.go b/agent/consul/state/session_test.go index 1e638f6ece..b5939367f4 100644 --- a/agent/consul/state/session_test.go +++ b/agent/consul/state/session_test.go @@ -2,6 +2,7 @@ package state import ( "fmt" + "github.com/stretchr/testify/assert" "reflect" "strings" "testing" @@ -18,7 +19,7 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { // SessionGet returns nil if the session doesn't exist ws := memdb.NewWatchSet() - idx, session, err := s.SessionGet(ws, testUUID()) + idx, session, err := s.SessionGet(ws, testUUID(), nil) if session != nil || err != nil { t.Fatalf("expected (nil, nil), got: (%#v, %#v)", session, err) } @@ -74,7 +75,7 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { // Retrieve the session again ws = memdb.NewWatchSet() - idx, session, err = s.SessionGet(ws, sess.ID) + idx, session, err = s.SessionGet(ws, sess.ID, nil) if err != nil { t.Fatalf("err: %s", err) } @@ -88,13 +89,15 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { ID: sess.ID, Behavior: structs.SessionKeysRelease, Node: "node1", - RaftIndex: structs.RaftIndex{ - CreateIndex: 2, - ModifyIndex: 2, - }, } - if !reflect.DeepEqual(expect, session) { - t.Fatalf("bad session: %#v", session) + if session.ID != expect.ID { + t.Fatalf("bad session ID: expected %s, got %s", expect.ID, session.ID) + } + if session.Node != expect.Node { + t.Fatalf("bad session Node: expected %s, got %s", expect.Node, session.Node) + } + if session.Behavior != expect.Behavior { + t.Fatalf("bad session Behavior: expected %s, got %s", expect.Behavior, session.Behavior) } // Registering with a non-existent check is disallowed @@ -176,7 +179,7 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { } // Pulling a nonexistent session gives the table index. - idx, session, err = s.SessionGet(nil, testUUID()) + idx, session, err = s.SessionGet(nil, testUUID(), nil) if err != nil { t.Fatalf("err: %s", err) } @@ -188,12 +191,12 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { } } -func TegstStateStore_SessionList(t *testing.T) { +func TestStateStore_SessionList(t *testing.T) { s := testStateStore(t) // Listing when no sessions exist returns nil ws := memdb.NewWatchSet() - idx, res, err := s.SessionList(ws) + idx, res, err := s.SessionList(ws, nil) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -231,15 +234,20 @@ func TegstStateStore_SessionList(t *testing.T) { } // List out all of the sessions - idx, sessionList, err := s.SessionList(nil) + idx, sessionList, err := s.SessionList(nil, nil) if err != nil { t.Fatalf("err: %s", err) } if idx != 6 { t.Fatalf("bad index: %d", idx) } - if !reflect.DeepEqual(sessionList, sessions) { - t.Fatalf("bad: %#v", sessions) + sessionMap := make(map[string]*structs.Session) + for _, session := range sessionList { + sessionMap[session.ID] = session + } + + for _, expect := range sessions { + assert.Equal(t, expect, sessionMap[expect.ID]) } } @@ -248,7 +256,7 @@ func TestStateStore_NodeSessions(t *testing.T) { // Listing sessions with no results returns nil ws := memdb.NewWatchSet() - idx, res, err := s.NodeSessions(ws, "node1") + idx, res, err := s.NodeSessions(ws, "node1", nil) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -290,7 +298,7 @@ func TestStateStore_NodeSessions(t *testing.T) { // Query all of the sessions associated with a specific // node in the state store. ws1 := memdb.NewWatchSet() - idx, res, err = s.NodeSessions(ws1, "node1") + idx, res, err = s.NodeSessions(ws1, "node1", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -302,7 +310,7 @@ func TestStateStore_NodeSessions(t *testing.T) { } ws2 := memdb.NewWatchSet() - idx, res, err = s.NodeSessions(ws2, "node2") + idx, res, err = s.NodeSessions(ws2, "node2", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -314,7 +322,7 @@ func TestStateStore_NodeSessions(t *testing.T) { } // Destroying a session on node1 should not affect node2's watch. - if err := s.SessionDestroy(100, sessions1[0].ID); err != nil { + if err := s.SessionDestroy(100, sessions1[0].ID, nil); err != nil { t.Fatalf("err: %s", err) } if !watchFired(ws1) { @@ -330,7 +338,7 @@ func TestStateStore_SessionDestroy(t *testing.T) { // Session destroy is idempotent and returns no error // if the session doesn't exist. - if err := s.SessionDestroy(1, testUUID()); err != nil { + if err := s.SessionDestroy(1, testUUID(), nil); err != nil { t.Fatalf("err: %s", err) } @@ -352,7 +360,7 @@ func TestStateStore_SessionDestroy(t *testing.T) { } // Destroy the session. - if err := s.SessionDestroy(3, sess.ID); err != nil { + if err := s.SessionDestroy(3, sess.ID, nil); err != nil { t.Fatalf("err: %s", err) } @@ -412,7 +420,7 @@ func TestStateStore_Session_Snapshot_Restore(t *testing.T) { defer snap.Close() // Alter the real state store. - if err := s.SessionDestroy(8, session1); err != nil { + if err := s.SessionDestroy(8, session1, nil); err != nil { t.Fatalf("err: %s", err) } @@ -456,7 +464,7 @@ func TestStateStore_Session_Snapshot_Restore(t *testing.T) { // Read the restored sessions back out and verify that they // match. - idx, res, err := s.SessionList(nil) + idx, res, err := s.SessionList(nil, nil) if err != nil { t.Fatalf("err: %s", err) } @@ -522,7 +530,7 @@ func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) { // Delete the node and make sure the watch fires. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -534,7 +542,7 @@ func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -577,7 +585,7 @@ func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) { // Delete the service and make sure the watch fires. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -589,7 +597,7 @@ func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -627,7 +635,7 @@ func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) { // Invalidate the check and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -640,7 +648,7 @@ func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -678,7 +686,7 @@ func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) { // Delete the check and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -690,7 +698,7 @@ func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -746,7 +754,7 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { // Delete the node and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -758,7 +766,7 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -770,7 +778,7 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { } // Key should be unlocked. - idx, d2, err := s.KVSGet(nil, "/foo") + idx, d2, err := s.KVSGet(nil, "/foo", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -788,7 +796,7 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { } // Key should have a lock delay. - expires := s.KVSLockDelay("/foo") + expires := s.KVSLockDelay("/foo", nil) if expires.Before(time.Now().Add(30 * time.Millisecond)) { t.Fatalf("Bad: %v", expires) } @@ -828,7 +836,7 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { // Delete the node and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -840,7 +848,7 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -852,7 +860,7 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { } // Key should be deleted. - idx, d2, err := s.KVSGet(nil, "/bar") + idx, d2, err := s.KVSGet(nil, "/bar", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -864,7 +872,7 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { } // Key should have a lock delay. - expires := s.KVSLockDelay("/bar") + expires := s.KVSLockDelay("/bar", nil) if expires.Before(time.Now().Add(30 * time.Millisecond)) { t.Fatalf("Bad: %v", expires) } @@ -896,11 +904,11 @@ func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) { // Invalidate the session and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } - if err := s.SessionDestroy(5, session.ID); err != nil { + if err := s.SessionDestroy(5, session.ID, nil); err != nil { t.Fatalf("err: %v", err) } if !watchFired(ws) { @@ -908,7 +916,7 @@ func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) { } // Make sure the session is gone. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/state/state_store.go b/agent/consul/state/state_store.go index b907dc3882..f7680b4919 100644 --- a/agent/consul/state/state_store.go +++ b/agent/consul/state/state_store.go @@ -3,9 +3,8 @@ package state import ( "errors" "fmt" - "github.com/hashicorp/consul/types" - memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/go-memdb" ) var ( diff --git a/agent/consul/state/state_store_test.go b/agent/consul/state/state_store_test.go index b35baee38e..4c9805d847 100644 --- a/agent/consul/state/state_store_test.go +++ b/agent/consul/state/state_store_test.go @@ -178,15 +178,22 @@ func testRegisterConnectNativeService(t *testing.T, s *Store, idx uint64, nodeID require.NoError(t, s.EnsureService(idx, nodeID, svc)) } -func testSetKey(t *testing.T, s *Store, idx uint64, key, value string) { - entry := &structs.DirEntry{Key: key, Value: []byte(value)} +func testSetKey(t *testing.T, s *Store, idx uint64, key, value string, entMeta *structs.EnterpriseMeta) { + entry := &structs.DirEntry{ + Key: key, + Value: []byte(value), + } + if entMeta != nil { + entry.EnterpriseMeta = *entMeta + } + if err := s.KVSSet(idx, entry); err != nil { t.Fatalf("err: %s", err) } tx := s.db.Txn(false) defer tx.Abort() - e, err := tx.First("kvs", "id", key) + e, err := firstWithTxn(tx, "kvs", "id", key, entMeta) if err != nil { t.Fatalf("err: %s", err) } @@ -223,7 +230,7 @@ func TestStateStore_Restore_Abort(t *testing.T) { } restore.Abort() - idx, entries, err := s.KVSList(nil, "") + idx, entries, err := s.KVSList(nil, "", nil) if err != nil { t.Fatalf("err: %s", err) } diff --git a/agent/consul/state/txn.go b/agent/consul/state/txn.go index 941cae4118..76fb2d8604 100644 --- a/agent/consul/state/txn.go +++ b/agent/consul/state/txn.go @@ -19,17 +19,17 @@ func (s *Store) txnKVS(tx *memdb.Txn, idx uint64, op *structs.TxnKVOp) (structs. err = s.kvsSetTxn(tx, idx, entry, false) case api.KVDelete: - err = s.kvsDeleteTxn(tx, idx, op.DirEnt.Key) + err = s.kvsDeleteTxn(tx, idx, op.DirEnt.Key, &op.DirEnt.EnterpriseMeta) case api.KVDeleteCAS: var ok bool - ok, err = s.kvsDeleteCASTxn(tx, idx, op.DirEnt.ModifyIndex, op.DirEnt.Key) + ok, err = s.kvsDeleteCASTxn(tx, idx, op.DirEnt.ModifyIndex, op.DirEnt.Key, &op.DirEnt.EnterpriseMeta) if !ok && err == nil { err = fmt.Errorf("failed to delete key %q, index is stale", op.DirEnt.Key) } case api.KVDeleteTree: - err = s.kvsDeleteTreeTxn(tx, idx, op.DirEnt.Key) + err = s.kvsDeleteTreeTxn(tx, idx, op.DirEnt.Key, &op.DirEnt.EnterpriseMeta) case api.KVCAS: var ok bool @@ -56,14 +56,14 @@ func (s *Store) txnKVS(tx *memdb.Txn, idx uint64, op *structs.TxnKVOp) (structs. } case api.KVGet: - _, entry, err = s.kvsGetTxn(tx, nil, op.DirEnt.Key) + _, entry, err = s.kvsGetTxn(tx, nil, op.DirEnt.Key, &op.DirEnt.EnterpriseMeta) if entry == nil && err == nil { err = fmt.Errorf("key %q doesn't exist", op.DirEnt.Key) } case api.KVGetTree: var entries structs.DirEntries - _, entries, err = s.kvsListTxn(tx, nil, op.DirEnt.Key) + _, entries, err = s.kvsListTxn(tx, nil, op.DirEnt.Key, &op.DirEnt.EnterpriseMeta) if err == nil { results := make(structs.TxnResults, 0, len(entries)) for _, e := range entries { @@ -74,13 +74,13 @@ func (s *Store) txnKVS(tx *memdb.Txn, idx uint64, op *structs.TxnKVOp) (structs. } case api.KVCheckSession: - entry, err = s.kvsCheckSessionTxn(tx, op.DirEnt.Key, op.DirEnt.Session) + entry, err = s.kvsCheckSessionTxn(tx, op.DirEnt.Key, op.DirEnt.Session, &op.DirEnt.EnterpriseMeta) case api.KVCheckIndex: - entry, err = s.kvsCheckIndexTxn(tx, op.DirEnt.Key, op.DirEnt.ModifyIndex) + entry, err = s.kvsCheckIndexTxn(tx, op.DirEnt.Key, op.DirEnt.ModifyIndex, &op.DirEnt.EnterpriseMeta) case api.KVCheckNotExists: - _, entry, err = s.kvsGetTxn(tx, nil, op.DirEnt.Key) + _, entry, err = s.kvsGetTxn(tx, nil, op.DirEnt.Key, &op.DirEnt.EnterpriseMeta) if entry != nil && err == nil { err = fmt.Errorf("key %q exists", op.DirEnt.Key) } @@ -110,6 +110,23 @@ func (s *Store) txnKVS(tx *memdb.Txn, idx uint64, op *structs.TxnKVOp) (structs. return nil, nil } +// txnSession handles all Session-related operations. +func (s *Store) txnSession(tx *memdb.Txn, idx uint64, op *structs.TxnSessionOp) error { + var err error + + switch op.Verb { + case api.SessionDelete: + err = s.sessionDeleteWithSession(tx, &op.Session, idx) + default: + err = fmt.Errorf("unknown Session verb %q", op.Verb) + } + if err != nil { + return fmt.Errorf("failed to delete session: %v", err) + } + + return nil +} + // txnIntention handles all Intention-related operations. func (s *Store) txnIntention(tx *memdb.Txn, idx uint64, op *structs.TxnIntentionOp) error { switch op.Op { @@ -332,6 +349,8 @@ func (s *Store) txnDispatch(tx *memdb.Txn, idx uint64, ops structs.TxnOps) (stru ret, err = s.txnService(tx, idx, op.Service) case op.Check != nil: ret, err = s.txnCheck(tx, idx, op.Check) + case op.Session != nil: + err = s.txnSession(tx, idx, op.Session) default: err = fmt.Errorf("no operation specified") } diff --git a/agent/consul/state/txn_test.go b/agent/consul/state/txn_test.go index c21e3a0ea9..2f908aa8a8 100644 --- a/agent/consul/state/txn_test.go +++ b/agent/consul/state/txn_test.go @@ -2,7 +2,6 @@ package state import ( "fmt" - "reflect" "strings" "testing" @@ -500,11 +499,11 @@ func TestStateStore_Txn_KVS(t *testing.T) { s := testStateStore(t) // Create KV entries in the state store. - testSetKey(t, s, 1, "foo/delete", "bar") - testSetKey(t, s, 2, "foo/bar/baz", "baz") - testSetKey(t, s, 3, "foo/bar/zip", "zip") - testSetKey(t, s, 4, "foo/zorp", "zorp") - testSetKey(t, s, 5, "foo/update", "stale") + testSetKey(t, s, 1, "foo/delete", "bar", nil) + testSetKey(t, s, 2, "foo/bar/baz", "baz", nil) + testSetKey(t, s, 3, "foo/bar/zip", "zip", nil) + testSetKey(t, s, 4, "foo/zorp", "zorp", nil) + testSetKey(t, s, 5, "foo/update", "stale", nil) // Make a real session. testRegisterNode(t, s, 6, "node1") @@ -776,14 +775,23 @@ func TestStateStore_Txn_KVS(t *testing.T) { if len(results) != len(expected) { t.Fatalf("bad: %v", results) } - for i := range results { - if !reflect.DeepEqual(results[i], expected[i]) { - t.Fatalf("bad %d", i) + for i, e := range expected { + if e.KV.Key != results[i].KV.Key { + t.Fatalf("expected key %s, got %s", e.KV.Key, results[i].KV.Key) + } + if e.KV.LockIndex != results[i].KV.LockIndex { + t.Fatalf("expected lock index %d, got %d", e.KV.LockIndex, results[i].KV.LockIndex) + } + if e.KV.CreateIndex != results[i].KV.CreateIndex { + t.Fatalf("expected create index %d, got %d", e.KV.CreateIndex, results[i].KV.CreateIndex) + } + if e.KV.ModifyIndex != results[i].KV.ModifyIndex { + t.Fatalf("expected modify index %d, got %d", e.KV.ModifyIndex, results[i].KV.ModifyIndex) } } // Pull the resulting state store contents. - idx, actual, err := s.KVSList(nil, "") + idx, actual, err := s.KVSList(nil, "", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -821,9 +829,21 @@ func TestStateStore_Txn_KVS(t *testing.T) { if len(actual) != len(entries) { t.Fatalf("bad len: %d != %d", len(actual), len(entries)) } - for i := range actual { - if !reflect.DeepEqual(actual[i], entries[i]) { - t.Fatalf("bad %d", i) + for i, e := range entries { + if e.Key != actual[i].Key { + t.Fatalf("expected key %s, got %s", e.Key, actual[i].Key) + } + if string(e.Value) != string(actual[i].Value) { + t.Fatalf("expected value %s, got %s", e.Value, actual[i].Value) + } + if e.LockIndex != actual[i].LockIndex { + t.Fatalf("expected lock index %d, got %d", e.LockIndex, actual[i].LockIndex) + } + if e.CreateIndex != actual[i].CreateIndex { + t.Fatalf("expected create index %d, got %d", e.CreateIndex, actual[i].CreateIndex) + } + if e.ModifyIndex != actual[i].ModifyIndex { + t.Fatalf("expected modify index %d, got %d", e.ModifyIndex, actual[i].ModifyIndex) } } } @@ -832,8 +852,8 @@ func TestStateStore_Txn_KVS_Rollback(t *testing.T) { s := testStateStore(t) // Create KV entries in the state store. - testSetKey(t, s, 1, "foo/delete", "bar") - testSetKey(t, s, 2, "foo/update", "stale") + testSetKey(t, s, 1, "foo/delete", "bar", nil) + testSetKey(t, s, 2, "foo/update", "stale", nil) testRegisterNode(t, s, 3, "node1") session := testUUID() @@ -852,7 +872,7 @@ func TestStateStore_Txn_KVS_Rollback(t *testing.T) { // This function verifies that the state store wasn't changed. verifyStateStore := func(desc string) { - idx, actual, err := s.KVSList(nil, "") + idx, actual, err := s.KVSList(nil, "", nil) if err != nil { t.Fatalf("err (%s): %s", desc, err) } @@ -892,9 +912,21 @@ func TestStateStore_Txn_KVS_Rollback(t *testing.T) { if len(actual) != len(entries) { t.Fatalf("bad len (%s): %d != %d", desc, len(actual), len(entries)) } - for i := range actual { - if !reflect.DeepEqual(actual[i], entries[i]) { - t.Fatalf("bad (%s): op %d: %v != %v", desc, i, *(actual[i]), *(entries[i])) + for i, e := range entries { + if e.Key != actual[i].Key { + t.Fatalf("expected key %s, got %s", e.Key, actual[i].Key) + } + if string(e.Value) != string(actual[i].Value) { + t.Fatalf("expected value %s, got %s", e.Value, actual[i].Value) + } + if e.LockIndex != actual[i].LockIndex { + t.Fatalf("expected lock index %d, got %d", e.LockIndex, actual[i].LockIndex) + } + if e.CreateIndex != actual[i].CreateIndex { + t.Fatalf("expected create index %d, got %d", e.CreateIndex, actual[i].CreateIndex) + } + if e.ModifyIndex != actual[i].ModifyIndex { + t.Fatalf("expected modify index %d, got %d", e.ModifyIndex, actual[i].ModifyIndex) } } } @@ -1027,9 +1059,9 @@ func TestStateStore_Txn_KVS_RO(t *testing.T) { s := testStateStore(t) // Create KV entries in the state store. - testSetKey(t, s, 1, "foo", "bar") - testSetKey(t, s, 2, "foo/bar/baz", "baz") - testSetKey(t, s, 3, "foo/bar/zip", "zip") + testSetKey(t, s, 1, "foo", "bar", nil) + testSetKey(t, s, 2, "foo/bar/baz", "baz", nil) + testSetKey(t, s, 3, "foo/bar/zip", "zip", nil) // Set up a transaction that hits all the read-only operations. ops := structs.TxnOps{ @@ -1129,9 +1161,18 @@ func TestStateStore_Txn_KVS_RO(t *testing.T) { if len(results) != len(expected) { t.Fatalf("bad: %v", results) } - for i := range results { - if !reflect.DeepEqual(results[i], expected[i]) { - t.Fatalf("bad %d", i) + for i, e := range expected { + if e.KV.Key != results[i].KV.Key { + t.Fatalf("expected key %s, got %s", e.KV.Key, results[i].KV.Key) + } + if e.KV.LockIndex != results[i].KV.LockIndex { + t.Fatalf("expected lock index %d, got %d", e.KV.LockIndex, results[i].KV.LockIndex) + } + if e.KV.CreateIndex != results[i].KV.CreateIndex { + t.Fatalf("expected create index %d, got %d", e.KV.CreateIndex, results[i].KV.CreateIndex) + } + if e.KV.ModifyIndex != results[i].KV.ModifyIndex { + t.Fatalf("expected modify index %d, got %d", e.KV.ModifyIndex, results[i].KV.ModifyIndex) } } } @@ -1140,9 +1181,9 @@ func TestStateStore_Txn_KVS_RO_Safety(t *testing.T) { s := testStateStore(t) // Create KV entries in the state store. - testSetKey(t, s, 1, "foo", "bar") - testSetKey(t, s, 2, "foo/bar/baz", "baz") - testSetKey(t, s, 3, "foo/bar/zip", "zip") + testSetKey(t, s, 1, "foo", "bar", nil) + testSetKey(t, s, 2, "foo/bar/baz", "baz", nil) + testSetKey(t, s, 3, "foo/bar/zip", "zip", nil) // Set up a transaction that hits all the read-only operations. ops := structs.TxnOps{ diff --git a/agent/consul/txn_endpoint_test.go b/agent/consul/txn_endpoint_test.go index 5cbfb56f2f..a929970e83 100644 --- a/agent/consul/txn_endpoint_test.go +++ b/agent/consul/txn_endpoint_test.go @@ -14,7 +14,6 @@ import ( "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/types" "github.com/hashicorp/net-rpc-msgpackrpc" - "github.com/pascaldekloe/goe/verify" "github.com/stretchr/testify/require" ) @@ -214,7 +213,7 @@ func TestTxn_Apply(t *testing.T) { // Verify the state store directly. state := s1.fsm.State() - _, d, err := state.KVSGet(nil, "test") + _, d, err := state.KVSGet(nil, "test", nil) if err != nil { t.Fatalf("err: %v", err) } @@ -262,6 +261,7 @@ func TestTxn_Apply(t *testing.T) { CreateIndex: d.CreateIndex, ModifyIndex: d.ModifyIndex, }, + EnterpriseMeta: d.EnterpriseMeta, }, }, &structs.TxnResult{ @@ -273,6 +273,7 @@ func TestTxn_Apply(t *testing.T) { CreateIndex: d.CreateIndex, ModifyIndex: d.ModifyIndex, }, + EnterpriseMeta: d.EnterpriseMeta, }, }, &structs.TxnResult{ @@ -295,7 +296,7 @@ func TestTxn_Apply(t *testing.T) { }, }, } - verify.Values(t, "", out, expected) + require.Equal(t, expected, out) } func TestTxn_Apply_ACLDeny(t *testing.T) { @@ -609,7 +610,7 @@ func TestTxn_Apply_ACLDeny(t *testing.T) { } } - verify.Values(t, "", out, expected) + require.Equal(expected, out) } func TestTxn_Apply_LockDelay(t *testing.T) { @@ -620,7 +621,7 @@ func TestTxn_Apply_LockDelay(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Create and invalidate a session with a lock. state := s1.fsm.State() @@ -643,7 +644,8 @@ func TestTxn_Apply_LockDelay(t *testing.T) { if ok, err := state.KVSLock(3, d); err != nil || !ok { t.Fatalf("err: %v", err) } - if err := state.SessionDestroy(4, id); err != nil { + + if err := state.SessionDestroy(4, id, nil); err != nil { t.Fatalf("err: %v", err) } @@ -777,6 +779,8 @@ func TestTxn_Read(t *testing.T) { // Verify the transaction's return value. svc.Weights = &structs.Weights{Passing: 1, Warning: 1} svc.RaftIndex = structs.RaftIndex{CreateIndex: 3, ModifyIndex: 3} + + entMeta := out.Results[0].KV.EnterpriseMeta expected := structs.TxnReadResponse{ TxnResponse: structs.TxnResponse{ Results: structs.TxnResults{ @@ -788,6 +792,7 @@ func TestTxn_Read(t *testing.T) { CreateIndex: 1, ModifyIndex: 1, }, + EnterpriseMeta: entMeta, }, }, &structs.TxnResult{ @@ -805,7 +810,7 @@ func TestTxn_Read(t *testing.T) { KnownLeader: true, }, } - verify.Values(t, "", out, expected) + require.Equal(expected, out) } func TestTxn_Read_ACLDeny(t *testing.T) { diff --git a/agent/http.go b/agent/http.go index a875e76aac..9e49414577 100644 --- a/agent/http.go +++ b/agent/http.go @@ -19,13 +19,13 @@ import ( "time" "github.com/NYTimes/gziphandler" - metrics "github.com/armon/go-metrics" + "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" - cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-cleanhttp" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" ) diff --git a/agent/kvs_endpoint.go b/agent/kvs_endpoint.go index 4e1d13e787..49cde3b626 100644 --- a/agent/kvs_endpoint.go +++ b/agent/kvs_endpoint.go @@ -18,6 +18,7 @@ func (s *HTTPServer) KVSEndpoint(resp http.ResponseWriter, req *http.Request) (i if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil } + s.parseEntMeta(req, &args.EnterpriseMeta) // Pull out the key name, validation left to each sub-handler args.Key = strings.TrimPrefix(req.URL.Path, "/v1/kv/") @@ -96,10 +97,11 @@ func (s *HTTPServer) KVSGetKeys(resp http.ResponseWriter, req *http.Request, arg // Construct the args listArgs := structs.KeyListRequest{ - Datacenter: args.Datacenter, - Prefix: args.Key, - Seperator: sep, - QueryOptions: args.QueryOptions, + Datacenter: args.Datacenter, + Prefix: args.Key, + Seperator: sep, + EnterpriseMeta: args.EnterpriseMeta, + QueryOptions: args.QueryOptions, } // Make the RPC @@ -135,9 +137,10 @@ func (s *HTTPServer) KVSPut(resp http.ResponseWriter, req *http.Request, args *s Datacenter: args.Datacenter, Op: api.KVSet, DirEnt: structs.DirEntry{ - Key: args.Key, - Flags: 0, - Value: nil, + Key: args.Key, + Flags: 0, + Value: nil, + EnterpriseMeta: args.EnterpriseMeta, }, } applyReq.Token = args.Token @@ -210,7 +213,8 @@ func (s *HTTPServer) KVSDelete(resp http.ResponseWriter, req *http.Request, args Datacenter: args.Datacenter, Op: api.KVDelete, DirEnt: structs.DirEntry{ - Key: args.Key, + Key: args.Key, + EnterpriseMeta: args.EnterpriseMeta, }, } applyReq.Token = args.Token diff --git a/agent/session_endpoint.go b/agent/session_endpoint.go index aba7dc1b29..7961aa1de6 100644 --- a/agent/session_endpoint.go +++ b/agent/session_endpoint.go @@ -31,6 +31,7 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request) } s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) + s.parseEntMeta(req, &args.Session.EnterpriseMeta) // Handle optional request body if req.ContentLength > 0 { @@ -79,6 +80,7 @@ func (s *HTTPServer) SessionDestroy(resp http.ResponseWriter, req *http.Request) } s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) + s.parseEntMeta(req, &args.Session.EnterpriseMeta) // Pull out the session id args.Session.ID = strings.TrimPrefix(req.URL.Path, "/v1/session/destroy/") @@ -101,10 +103,11 @@ func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) ( if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil } + s.parseEntMeta(req, &args.EnterpriseMeta) // Pull out the session id - args.Session = strings.TrimPrefix(req.URL.Path, "/v1/session/renew/") - if args.Session == "" { + args.SessionID = strings.TrimPrefix(req.URL.Path, "/v1/session/renew/") + if args.SessionID == "" { resp.WriteHeader(http.StatusBadRequest) fmt.Fprint(resp, "Missing session") return nil, nil @@ -115,7 +118,7 @@ func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) ( return nil, err } else if out.Sessions == nil { resp.WriteHeader(http.StatusNotFound) - fmt.Fprintf(resp, "Session id '%s' not found", args.Session) + fmt.Fprintf(resp, "Session id '%s' not found", args.SessionID) return nil, nil } @@ -128,10 +131,11 @@ func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (in if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil } + s.parseEntMeta(req, &args.EnterpriseMeta) // Pull out the session id - args.Session = strings.TrimPrefix(req.URL.Path, "/v1/session/info/") - if args.Session == "" { + args.SessionID = strings.TrimPrefix(req.URL.Path, "/v1/session/info/") + if args.SessionID == "" { resp.WriteHeader(http.StatusBadRequest) fmt.Fprint(resp, "Missing session") return nil, nil @@ -152,10 +156,11 @@ func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (in // SessionList is used to list all the sessions func (s *HTTPServer) SessionList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - args := structs.DCSpecificRequest{} + args := structs.SessionSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil } + s.parseEntMeta(req, &args.EnterpriseMeta) var out structs.IndexedSessions defer setMeta(resp, &out.QueryMeta) @@ -176,6 +181,7 @@ func (s *HTTPServer) SessionsForNode(resp http.ResponseWriter, req *http.Request if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil } + s.parseEntMeta(req, &args.EnterpriseMeta) // Pull out the node name args.Node = strings.TrimPrefix(req.URL.Path, "/v1/session/node/") diff --git a/agent/session_endpoint_test.go b/agent/session_endpoint_test.go index 6485e0254e..2e4c7e9991 100644 --- a/agent/session_endpoint_test.go +++ b/agent/session_endpoint_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "reflect" "testing" "time" @@ -13,13 +14,14 @@ import ( "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/types" - "github.com/pascaldekloe/goe/verify" ) func verifySession(t *testing.T, r *retry.R, a *TestAgent, want structs.Session) { + t.Helper() + args := &structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: want.ID, + SessionID: want.ID, } var out structs.IndexedSessions if err := a.RPC("Session.Get", args, &out); err != nil { @@ -34,7 +36,22 @@ func verifySession(t *testing.T, r *retry.R, a *TestAgent, want structs.Session) got := *(out.Sessions[0]) got.CreateIndex = 0 got.ModifyIndex = 0 - verify.Values(t, "", got, want) + + if got.ID != want.ID { + t.Fatalf("bad session ID: expected %s, got %s", want.ID, got.ID) + } + if got.Node != want.Node { + t.Fatalf("bad session Node: expected %s, got %s", want.Node, got.Node) + } + if got.Behavior != want.Behavior { + t.Fatalf("bad session Behavior: expected %s, got %s", want.Behavior, got.Behavior) + } + if got.LockDelay != want.LockDelay { + t.Fatalf("bad session LockDelay: expected %s, got %s", want.LockDelay, got.LockDelay) + } + if !reflect.DeepEqual(got.Checks, want.Checks) { + t.Fatalf("bad session Checks: expected %+v, got %+v", want.Checks, got.Checks) + } } func TestSessionCreate(t *testing.T) { @@ -224,7 +241,8 @@ func TestSessionCreate_NoCheck(t *testing.T) { } func makeTestSession(t *testing.T, srv *HTTPServer) string { - req, _ := http.NewRequest("PUT", "/v1/session/create", nil) + url := "/v1/session/create" + req, _ := http.NewRequest("PUT", url, nil) resp := httptest.NewRecorder() obj, err := srv.SessionCreate(resp, req) if err != nil { @@ -243,7 +261,8 @@ func makeTestSessionDelete(t *testing.T, srv *HTTPServer) string { } enc.Encode(raw) - req, _ := http.NewRequest("PUT", "/v1/session/create", body) + url := "/v1/session/create" + req, _ := http.NewRequest("PUT", url, body) resp := httptest.NewRecorder() obj, err := srv.SessionCreate(resp, req) if err != nil { @@ -262,7 +281,8 @@ func makeTestSessionTTL(t *testing.T, srv *HTTPServer, ttl string) string { } enc.Encode(raw) - req, _ := http.NewRequest("PUT", "/v1/session/create", body) + url := "/v1/session/create" + req, _ := http.NewRequest("PUT", url, body) resp := httptest.NewRecorder() obj, err := srv.SessionCreate(resp, req) if err != nil { diff --git a/agent/structs/structs.go b/agent/structs/structs.go index 3551cecace..94e916172c 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -515,6 +515,7 @@ func (r *ServiceSpecificRequest) CacheMinIndex() uint64 { type NodeSpecificRequest struct { Datacenter string Node string + EnterpriseMeta QueryOptions } @@ -1620,6 +1621,7 @@ type DirEntry struct { Value []byte Session string `json:",omitempty"` + EnterpriseMeta RaftIndex } @@ -1635,6 +1637,7 @@ func (d *DirEntry) Clone() *DirEntry { CreateIndex: d.CreateIndex, ModifyIndex: d.ModifyIndex, }, + EnterpriseMeta: d.EnterpriseMeta, } } @@ -1664,6 +1667,7 @@ func (r *KVSRequest) RequestDatacenter() string { type KeyRequest struct { Datacenter string Key string + EnterpriseMeta QueryOptions } @@ -1677,6 +1681,7 @@ type KeyListRequest struct { Prefix string Seperator string QueryOptions + EnterpriseMeta } func (r *KeyListRequest) RequestDatacenter() string { @@ -1718,6 +1723,7 @@ type Session struct { Behavior SessionBehavior // What to do when session is invalidated TTL string + EnterpriseMeta RaftIndex } @@ -1773,7 +1779,8 @@ func (r *SessionRequest) RequestDatacenter() string { // SessionSpecificRequest is used to request a session by ID type SessionSpecificRequest struct { Datacenter string - Session string + SessionID string + EnterpriseMeta QueryOptions } diff --git a/agent/structs/structs_oss.go b/agent/structs/structs_oss.go index e5220ef9b4..1d146be24a 100644 --- a/agent/structs/structs_oss.go +++ b/agent/structs/structs_oss.go @@ -19,6 +19,11 @@ func (m *EnterpriseMeta) addToHash(hasher hash.Hash) { // do nothing } +// WildcardEnterpriseMeta stub +func WildcardEnterpriseMeta() *EnterpriseMeta { + return nil +} + // ReplicationEnterpriseMeta stub func ReplicationEnterpriseMeta() *EnterpriseMeta { return nil diff --git a/agent/structs/txn.go b/agent/structs/txn.go index 87d247bd02..07aaa83bee 100644 --- a/agent/structs/txn.go +++ b/agent/structs/txn.go @@ -49,10 +49,17 @@ type TxnCheckOp struct { Check HealthCheck } -// TxnCheckResult is used to define the result of a single operation on a health -// check inside a transaction. +// TxnCheckResult is used to define the result of a single operation on a +// session inside a transaction. type TxnCheckResult *HealthCheck +// TxnSessionOp is used to define a single operation on a session inside a +// transaction. +type TxnSessionOp struct { + Verb api.SessionOp + Session Session +} + // TxnKVOp is used to define a single operation on an Intention inside a // transaction. type TxnIntentionOp IntentionRequest @@ -65,6 +72,7 @@ type TxnOp struct { Node *TxnNodeOp Service *TxnServiceOp Check *TxnCheckOp + Session *TxnSessionOp } // TxnOps is a list of operations within a transaction. diff --git a/agent/txn_endpoint_test.go b/agent/txn_endpoint_test.go index ebfe83613d..d0048b8873 100644 --- a/agent/txn_endpoint_test.go +++ b/agent/txn_endpoint_test.go @@ -6,18 +6,16 @@ import ( "fmt" "net/http" "net/http/httptest" - "reflect" "strconv" "strings" "testing" "time" + "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/raft" - "github.com/pascaldekloe/goe/verify" - - "github.com/hashicorp/consul/agent/structs" + "github.com/stretchr/testify/assert" ) func TestTxnEndpoint_Bad_JSON(t *testing.T) { @@ -213,7 +211,10 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { if len(txnResp.Results) != 2 { t.Fatalf("bad: %v", txnResp) } + index = txnResp.Results[0].KV.ModifyIndex + entMeta := txnResp.Results[0].KV.EnterpriseMeta + expected := structs.TxnResponse{ Results: structs.TxnResults{ &structs.TxnResult{ @@ -227,6 +228,7 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { CreateIndex: index, ModifyIndex: index, }, + EnterpriseMeta: entMeta, }, }, &structs.TxnResult{ @@ -240,13 +242,12 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { CreateIndex: index, ModifyIndex: index, }, + EnterpriseMeta: entMeta, }, }, }, } - if !reflect.DeepEqual(txnResp, expected) { - t.Fatalf("bad: %v", txnResp) - } + assert.Equal(t, expected, txnResp) } // Do a read-only transaction that should get routed to the @@ -291,6 +292,7 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { if !ok { t.Fatalf("bad type: %T", obj) } + entMeta := txnResp.Results[0].KV.EnterpriseMeta expected := structs.TxnReadResponse{ TxnResponse: structs.TxnResponse{ Results: structs.TxnResults{ @@ -305,6 +307,7 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { CreateIndex: index, ModifyIndex: index, }, + EnterpriseMeta: entMeta, }, }, &structs.TxnResult{ @@ -318,6 +321,7 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { CreateIndex: index, ModifyIndex: index, }, + EnterpriseMeta: entMeta, }, }, }, @@ -326,9 +330,7 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { KnownLeader: true, }, } - if !reflect.DeepEqual(txnResp, expected) { - t.Fatalf("bad: %v", txnResp) - } + assert.Equal(t, expected, txnResp) } // Now that we have an index we can do a CAS to make sure the @@ -369,7 +371,10 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { if len(txnResp.Results) != 2 { t.Fatalf("bad: %v", txnResp) } + modIndex := txnResp.Results[0].KV.ModifyIndex + entMeta := txnResp.Results[0].KV.EnterpriseMeta + expected := structs.TxnResponse{ Results: structs.TxnResults{ &structs.TxnResult{ @@ -381,6 +386,7 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { CreateIndex: index, ModifyIndex: modIndex, }, + EnterpriseMeta: entMeta, }, }, &structs.TxnResult{ @@ -392,13 +398,12 @@ func TestTxnEndpoint_KV_Actions(t *testing.T) { CreateIndex: index, ModifyIndex: modIndex, }, + EnterpriseMeta: entMeta, }, }, }, } - if !reflect.DeepEqual(txnResp, expected) { - t.Fatalf("bad: %v", txnResp) - } + assert.Equal(t, expected, txnResp) } }) @@ -601,7 +606,7 @@ func TestTxnEndpoint_UpdateCheck(t *testing.T) { }, }, } - verify.Values(t, "", txnResp, expected) + assert.Equal(t, expected, txnResp) } func TestConvertOps_ContentLength(t *testing.T) { diff --git a/api/api.go b/api/api.go index 81072ebb75..c418cd2293 100644 --- a/api/api.go +++ b/api/api.go @@ -300,6 +300,10 @@ type Config struct { // If provided it is read once at startup and never again. TokenFile string + // Namespace is the name of the namespace to send along for the request + // when no other Namespace ispresent in the QueryOptions + Namespace string + TLSConfig TLSConfig } @@ -801,6 +805,9 @@ func (c *Client) newRequest(method, path string) *request { if c.config.Datacenter != "" { r.params.Set("dc", c.config.Datacenter) } + if c.config.Namespace != "" { + r.params.Set("ns", c.config.Namespace) + } if c.config.WaitTime != 0 { r.params.Set("wait", durToMsec(r.config.WaitTime)) } diff --git a/api/kv.go b/api/kv.go index bd45a067c9..5f9ebdde4b 100644 --- a/api/kv.go +++ b/api/kv.go @@ -40,6 +40,10 @@ type KVPair struct { // interactions with this key over the same session must specify the same // session ID. Session string + + // Namespace is the namespace the KVPair is associated with + // Namespacing is a Consul Enterprise feature. + Namespace string } // KVPairs is a list of KVPair objects diff --git a/api/lock.go b/api/lock.go index 82339cb744..e7d76c5169 100644 --- a/api/lock.go +++ b/api/lock.go @@ -79,6 +79,7 @@ type LockOptions struct { MonitorRetryTime time.Duration // Optional, defaults to DefaultMonitorRetryTime LockWaitTime time.Duration // Optional, defaults to DefaultLockWaitTime LockTryOnce bool // Optional, defaults to false which means try forever + Namespace string // Optional, defaults to API client config, namespace of ACL token, or "default" namespace } // LockKey returns a handle to a lock struct which can be used @@ -140,6 +141,10 @@ func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { return nil, ErrLockHeld } + wOpts := WriteOptions{ + Namespace: l.opts.Namespace, + } + // Check if we need to create a session first l.lockSession = l.opts.Session if l.lockSession == "" { @@ -150,8 +155,9 @@ func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { l.sessionRenew = make(chan struct{}) l.lockSession = s + session := l.c.Session() - go session.RenewPeriodic(l.opts.SessionTTL, s, nil, l.sessionRenew) + go session.RenewPeriodic(l.opts.SessionTTL, s, &wOpts, l.sessionRenew) // If we fail to acquire the lock, cleanup the session defer func() { @@ -164,8 +170,9 @@ func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { // Setup the query options kv := l.c.KV() - qOpts := &QueryOptions{ + qOpts := QueryOptions{ WaitTime: l.opts.LockWaitTime, + Namespace: l.opts.Namespace, } start := time.Now() @@ -191,7 +198,7 @@ WAIT: attempts++ // Look for an existing lock, blocking until not taken - pair, meta, err := kv.Get(l.opts.Key, qOpts) + pair, meta, err := kv.Get(l.opts.Key, &qOpts) if err != nil { return nil, fmt.Errorf("failed to read lock: %v", err) } @@ -209,7 +216,8 @@ WAIT: // Try to acquire the lock pair = l.lockEntry(l.lockSession) - locked, _, err = kv.Acquire(pair, nil) + + locked, _, err = kv.Acquire(pair, &wOpts) if err != nil { return nil, fmt.Errorf("failed to acquire lock: %v", err) } @@ -218,7 +226,7 @@ WAIT: if !locked { // Determine why the lock failed qOpts.WaitIndex = 0 - pair, meta, err = kv.Get(l.opts.Key, qOpts) + pair, meta, err = kv.Get(l.opts.Key, &qOpts) if pair != nil && pair.Session != "" { //If the session is not null, this means that a wait can safely happen //using a long poll @@ -277,7 +285,9 @@ func (l *Lock) Unlock() error { // Release the lock explicitly kv := l.c.KV() - _, _, err := kv.Release(lockEnt, nil) + w := WriteOptions{Namespace: l.opts.Namespace} + + _, _, err := kv.Release(lockEnt, &w) if err != nil { return fmt.Errorf("failed to release lock: %v", err) } @@ -298,7 +308,9 @@ func (l *Lock) Destroy() error { // Look for an existing lock kv := l.c.KV() - pair, _, err := kv.Get(l.opts.Key, nil) + q := QueryOptions{Namespace: l.opts.Namespace} + + pair, _, err := kv.Get(l.opts.Key, &q) if err != nil { return fmt.Errorf("failed to read lock: %v", err) } @@ -319,7 +331,8 @@ func (l *Lock) Destroy() error { } // Attempt the delete - didRemove, _, err := kv.DeleteCAS(pair, nil) + w := WriteOptions{Namespace: l.opts.Namespace} + didRemove, _, err := kv.DeleteCAS(pair, &w) if err != nil { return fmt.Errorf("failed to remove lock: %v", err) } @@ -339,7 +352,8 @@ func (l *Lock) createSession() (string, error) { TTL: l.opts.SessionTTL, } } - id, _, err := session.Create(se, nil) + w := WriteOptions{Namespace: l.opts.Namespace} + id, _, err := session.Create(se, &w) if err != nil { return "", err } @@ -361,11 +375,14 @@ func (l *Lock) lockEntry(session string) *KVPair { func (l *Lock) monitorLock(session string, stopCh chan struct{}) { defer close(stopCh) kv := l.c.KV() - opts := &QueryOptions{RequireConsistent: true} + opts := QueryOptions{ + RequireConsistent: true, + Namespace: l.opts.Namespace, + } WAIT: retries := l.opts.MonitorRetries RETRY: - pair, meta, err := kv.Get(l.opts.Key, opts) + pair, meta, err := kv.Get(l.opts.Key, &opts) if err != nil { // If configured we can try to ride out a brief Consul unavailability // by doing retries. Note that we have to attempt the retry in a non- diff --git a/api/semaphore.go b/api/semaphore.go index bc4f885fec..d8c2ad2f59 100644 --- a/api/semaphore.go +++ b/api/semaphore.go @@ -73,6 +73,7 @@ type SemaphoreOptions struct { MonitorRetryTime time.Duration // Optional, defaults to DefaultMonitorRetryTime SemaphoreWaitTime time.Duration // Optional, defaults to DefaultSemaphoreWaitTime SemaphoreTryOnce bool // Optional, defaults to false which means try forever + Namespace string // Optional, defaults to API client config, namespace of ACL token, or "default" namespace } // semaphoreLock is written under the DefaultSemaphoreKey and @@ -176,14 +177,17 @@ func (s *Semaphore) Acquire(stopCh <-chan struct{}) (<-chan struct{}, error) { // Create the contender entry kv := s.c.KV() - made, _, err := kv.Acquire(s.contenderEntry(s.lockSession), nil) + wOpts := WriteOptions{Namespace: s.opts.Namespace} + + made, _, err := kv.Acquire(s.contenderEntry(s.lockSession), &wOpts) if err != nil || !made { return nil, fmt.Errorf("failed to make contender entry: %v", err) } // Setup the query options - qOpts := &QueryOptions{ + qOpts := QueryOptions{ WaitTime: s.opts.SemaphoreWaitTime, + Namespace: s.opts.Namespace, } start := time.Now() @@ -209,7 +213,7 @@ WAIT: attempts++ // Read the prefix - pairs, meta, err := kv.List(s.opts.Prefix, qOpts) + pairs, meta, err := kv.List(s.opts.Prefix, &qOpts) if err != nil { return nil, fmt.Errorf("failed to read prefix: %v", err) } @@ -247,7 +251,7 @@ WAIT: } // Attempt the acquisition - didSet, _, err := kv.CAS(newLock, nil) + didSet, _, err := kv.CAS(newLock, &wOpts) if err != nil { return nil, fmt.Errorf("failed to update lock: %v", err) } @@ -298,8 +302,12 @@ func (s *Semaphore) Release() error { // Remove ourselves as a lock holder kv := s.c.KV() key := path.Join(s.opts.Prefix, DefaultSemaphoreKey) + + wOpts := WriteOptions{Namespace: s.opts.Namespace} + qOpts := QueryOptions{Namespace: s.opts.Namespace} + READ: - pair, _, err := kv.Get(key, nil) + pair, _, err := kv.Get(key, &qOpts) if err != nil { return err } @@ -320,7 +328,7 @@ READ: } // Swap the locks - didSet, _, err := kv.CAS(newLock, nil) + didSet, _, err := kv.CAS(newLock, &wOpts) if err != nil { return fmt.Errorf("failed to update lock: %v", err) } @@ -331,7 +339,7 @@ READ: // Destroy the contender entry contenderKey := path.Join(s.opts.Prefix, lockSession) - if _, err := kv.Delete(contenderKey, nil); err != nil { + if _, err := kv.Delete(contenderKey, &wOpts); err != nil { return err } return nil @@ -351,7 +359,9 @@ func (s *Semaphore) Destroy() error { // List for the semaphore kv := s.c.KV() - pairs, _, err := kv.List(s.opts.Prefix, nil) + + q := QueryOptions{Namespace: s.opts.Namespace} + pairs, _, err := kv.List(s.opts.Prefix, &q) if err != nil { return fmt.Errorf("failed to read prefix: %v", err) } @@ -380,7 +390,8 @@ func (s *Semaphore) Destroy() error { } // Attempt the delete - didRemove, _, err := kv.DeleteCAS(lockPair, nil) + w := WriteOptions{Namespace: s.opts.Namespace} + didRemove, _, err := kv.DeleteCAS(lockPair, &w) if err != nil { return fmt.Errorf("failed to remove semaphore: %v", err) } @@ -398,7 +409,9 @@ func (s *Semaphore) createSession() (string, error) { TTL: s.opts.SessionTTL, Behavior: SessionBehaviorDelete, } - id, _, err := session.Create(se, nil) + + w := WriteOptions{Namespace: s.opts.Namespace} + id, _, err := session.Create(se, &w) if err != nil { return "", err } @@ -483,11 +496,14 @@ func (s *Semaphore) pruneDeadHolders(lock *semaphoreLock, pairs KVPairs) { func (s *Semaphore) monitorLock(session string, stopCh chan struct{}) { defer close(stopCh) kv := s.c.KV() - opts := &QueryOptions{RequireConsistent: true} + opts := QueryOptions{ + RequireConsistent: true, + Namespace: s.opts.Namespace, + } WAIT: retries := s.opts.MonitorRetries RETRY: - pairs, meta, err := kv.List(s.opts.Prefix, opts) + pairs, meta, err := kv.List(s.opts.Prefix, &opts) if err != nil { // If configured we can try to ride out a brief Consul unavailability // by doing retries. Note that we have to attempt the retry in a non- diff --git a/api/txn.go b/api/txn.go index 65d7a16ea0..eb50d55e6e 100644 --- a/api/txn.go +++ b/api/txn.go @@ -93,6 +93,19 @@ type KVTxnResponse struct { Errors TxnErrors } +// SessionOp constants give possible operations available in a transaction. +type SessionOp string + +const ( + SessionDelete SessionOp = "delete" +) + +// SessionTxnOp defines a single operation inside a transaction. +type SessionTxnOp struct { + Verb SessionOp + Session Session +} + // NodeOp constants give possible operations available in a transaction. type NodeOp string diff --git a/api/txn_test.go b/api/txn_test.go index 0d62b5460f..f454368a73 100644 --- a/api/txn_test.go +++ b/api/txn_test.go @@ -151,6 +151,7 @@ func TestAPI_ClientTxn(t *testing.T) { LockIndex: 1, CreateIndex: ret.Results[0].KV.CreateIndex, ModifyIndex: ret.Results[0].KV.ModifyIndex, + Namespace: ret.Results[0].KV.Namespace, }, }, &TxnResult{ @@ -161,6 +162,7 @@ func TestAPI_ClientTxn(t *testing.T) { LockIndex: 1, CreateIndex: ret.Results[1].KV.CreateIndex, ModifyIndex: ret.Results[1].KV.ModifyIndex, + Namespace: ret.Results[0].KV.Namespace, }, }, &TxnResult{ @@ -253,6 +255,7 @@ func TestAPI_ClientTxn(t *testing.T) { LockIndex: 1, CreateIndex: ret.Results[0].KV.CreateIndex, ModifyIndex: ret.Results[0].KV.ModifyIndex, + Namespace: ret.Results[0].KV.Namespace, }, }, &TxnResult{ diff --git a/command/flags/http.go b/command/flags/http.go index e2688fab8c..c32b700842 100644 --- a/command/flags/http.go +++ b/command/flags/http.go @@ -18,6 +18,7 @@ type HTTPFlags struct { certFile StringValue keyFile StringValue tlsServerName StringValue + namespace StringValue // server flags datacenter StringValue @@ -55,6 +56,10 @@ func (f *HTTPFlags) ClientFlags() *flag.FlagSet { fs.Var(&f.tlsServerName, "tls-server-name", "The server name to use as the SNI host when connecting via TLS. This "+ "can also be specified via the CONSUL_TLS_SERVER_NAME environment variable.") + // TODO (namespaces) Do we want to allow setting via an env var? CONSUL_NAMESPACE + fs.Var(&f.namespace, "ns", + "Specifies the namespace to query. If not provided, the namespace will be inferred +"+ + "from the request's ACL token, or will default to the `default` namespace.") return fs } @@ -135,4 +140,5 @@ func (f *HTTPFlags) MergeOntoConfig(c *api.Config) { f.keyFile.Merge(&c.TLSConfig.KeyFile) f.tlsServerName.Merge(&c.TLSConfig.Address) f.datacenter.Merge(&c.Datacenter) + f.namespace.Merge(&c.Namespace) } diff --git a/command/kv/get/kv_get.go b/command/kv/get/kv_get.go index 720f84f39f..f541d84d62 100644 --- a/command/kv/get/kv_get.go +++ b/command/kv/get/kv_get.go @@ -198,6 +198,9 @@ func prettyKVPair(w io.Writer, pair *api.KVPair, base64EncodeValue bool) error { } else { fmt.Fprintf(tw, "Session\t%s\n", pair.Session) } + if pair.Namespace != "" { + fmt.Fprintf(tw, "Namespace\t%s\n", pair.Namespace) + } if base64EncodeValue { fmt.Fprintf(tw, "Value\t%s", base64.StdEncoding.EncodeToString(pair.Value)) } else { diff --git a/command/kv/imp/kv_import.go b/command/kv/imp/kv_import.go index ad0d8d4f3a..030c7110c0 100644 --- a/command/kv/imp/kv_import.go +++ b/command/kv/imp/kv_import.go @@ -80,7 +80,8 @@ func (c *cmd) Run(args []string) int { Value: value, } - if _, err := client.KV().Put(pair, nil); err != nil { + w := api.WriteOptions{Namespace: entry.Namespace} + if _, err := client.KV().Put(pair, &w); err != nil { c.UI.Error(fmt.Sprintf("Error! Failed writing data for key %s: %s", pair.Key, err)) return 1 } diff --git a/command/kv/impexp/kvimpexp.go b/command/kv/impexp/kvimpexp.go index 221d37b2b1..ed1472785d 100644 --- a/command/kv/impexp/kvimpexp.go +++ b/command/kv/impexp/kvimpexp.go @@ -7,15 +7,17 @@ import ( ) type Entry struct { - Key string `json:"key"` - Flags uint64 `json:"flags"` - Value string `json:"value"` + Key string `json:"key"` + Flags uint64 `json:"flags"` + Value string `json:"value"` + Namespace string `json:"namespace,omitempty"` } func ToEntry(pair *api.KVPair) *Entry { return &Entry{ - Key: pair.Key, - Flags: pair.Flags, - Value: base64.StdEncoding.EncodeToString(pair.Value), + Key: pair.Key, + Flags: pair.Flags, + Value: base64.StdEncoding.EncodeToString(pair.Value), + Namespace: pair.Namespace, } } diff --git a/website/source/api/kv.html.md b/website/source/api/kv.html.md index fa635b61bb..4de89ad34e 100644 --- a/website/source/api/kv.html.md +++ b/website/source/api/kv.html.md @@ -62,11 +62,16 @@ The table below shows this endpoint's support for metadata). Specifying this implies `recurse`. This is specified as part of the URL as a query parameter. -- `separator` `(string: '')` - Specifies the string to use as a separator +- `separator` `(string: "")` - Specifies the string to use as a separator for recursive key lookups. This option is only used when paired with the `keys` parameter to limit the prefix of keys returned, only up to the given separator. This is specified as part of the URL as a query parameter. +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. + ### Sample Request ```text @@ -201,6 +206,11 @@ The table below shows this endpoint's support for will leave the `LockIndex` unmodified but will clear the associated `Session` of the key. The key must be held by this session to be unlocked. +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. + ### Sample Payload The payload is arbitrary, and is loaded directly into Consul as supplied. @@ -257,6 +267,11 @@ The table below shows this endpoint's support for index will not delete the key. If the index is non-zero, the key is only deleted if the index matches the `ModifyIndex` of that key. +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. + ### Sample Request ```text diff --git a/website/source/api/session.html.md b/website/source/api/session.html.md index 8c00636569..6329d20862 100644 --- a/website/source/api/session.html.md +++ b/website/source/api/session.html.md @@ -8,9 +8,7 @@ description: |- # Session HTTP Endpoint -The `/session` endpoints create, destroy, and query sessions in Consul. A -conceptual overview of sessions is found at the -[Session Internals](/docs/internals/sessions.html) page. +The `/session` endpoints create, destroy, and query sessions in Consul. ## Create Session @@ -33,11 +31,17 @@ The table below shows this endpoint's support for ### Parameters +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. + - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. -- `LockDelay` `(string: "15s")` - Specifies the duration for the lock delay. +- `LockDelay` `(string: "15s")` - Specifies the duration for the lock delay. This + must be greater than `0`. - `Node` `(string: "")` - Specifies the name of the node. This must refer to a node that is already registered. @@ -126,6 +130,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request @@ -167,6 +176,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request @@ -223,6 +237,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request @@ -274,6 +293,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request @@ -329,6 +353,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request diff --git a/website/source/docs/commands/_http_api_options_client.html.md b/website/source/docs/commands/_http_api_options_client.html.md index 7f67fd264e..3b33e90f66 100644 --- a/website/source/docs/commands/_http_api_options_client.html.md +++ b/website/source/docs/commands/_http_api_options_client.html.md @@ -34,3 +34,7 @@ instead of one specified via the `-token` argument or `CONSUL_HTTP_TOKEN` environment variable. This can also be specified via the `CONSUL_HTTP_TOKEN_FILE` environment variable. + +* `-ns=` - Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace.