diff --git a/consul/mdb_table.go b/consul/mdb_table.go index 0968d1aa41..800e260a28 100644 --- a/consul/mdb_table.go +++ b/consul/mdb_table.go @@ -60,13 +60,13 @@ func (t *MDBTxn) Commit() error { } type RowID uint64 -type IndexFunc func(...string) string +type IndexFunc func([]string) string // DefaultIndexFunc is used if no IdxFunc is provided. It joins // the columns using '||' which is reasonably unlikely to occur. // We also prefix with a byte to ensure we never have a zero length // key -func DefaultIndexFunc(parts ...string) string { +func DefaultIndexFunc(parts []string) string { return "_" + strings.Join(parts, "||") } @@ -241,21 +241,12 @@ func (t *MDBTable) Insert(obj interface{}) error { // Get is used to lookup one or more rows. An index an appropriate // fields are specified. The fields can be a prefix of the index. func (t *MDBTable) Get(index string, parts ...string) ([]interface{}, error) { - // Get the index - idx, ok := t.Indexes[index] - if !ok { - return nil, noIndex + // Get the associated index + idx, key, err := t.getIndex(index, parts) + if err != nil { + return nil, err } - // Check the arity - arity := idx.arity() - if len(parts) > arity { - return nil, tooManyFields - } - - // Construct the key - key := []byte(idx.IdxFunc(parts...)) - // Start a readonly txn tx, err := t.StartTxn(true) if err != nil { @@ -273,11 +264,50 @@ func (t *MDBTable) Get(index string, parts ...string) ([]interface{}, error) { return results, err } +// getIndex is used to get the proper index, and also check the arity +func (t *MDBTable) getIndex(index string, parts []string) (*MDBIndex, []byte, error) { + // Get the index + idx, ok := t.Indexes[index] + if !ok { + return nil, nil, noIndex + } + + // Check the arity + arity := idx.arity() + if len(parts) > arity { + return nil, nil, tooManyFields + } + + // Construct the key + key := []byte(idx.IdxFunc(parts)) + return idx, key, nil +} + // Delete is used to delete one or more rows. An index an appropriate // fields are specified. The fields can be a prefix of the index. // Returns the rows deleted or an error. func (t *MDBTable) Delete(index string, parts ...string) (int, error) { - return 0, nil + // Get the associated index + idx, key, err := t.getIndex(index, parts) + if err != nil { + return 0, err + } + + // Start a write txn + tx, err := t.StartTxn(false) + if err != nil { + return 0, err + } + defer tx.Abort() + + // Accumulate the results + num := 0 + err = idx.iterate(tx, key, func(res []byte) { + num++ + }) + + // Return the deleted count + return num, tx.Commit() } // Initializes an index and returns a potential error @@ -344,7 +374,7 @@ func (i *MDBIndex) keyFromObject(obj interface{}) ([]byte, error) { } parts = append(parts, val) } - key := i.IdxFunc(parts...) + key := i.IdxFunc(parts) return []byte(key), nil } diff --git a/consul/mdb_table_test.go b/consul/mdb_table_test.go index 95b270d34d..54577aadb4 100644 --- a/consul/mdb_table_test.go +++ b/consul/mdb_table_test.go @@ -161,6 +161,23 @@ func TestMDBTableInsert(t *testing.T) { if !reflect.DeepEqual(res[0], objs[2]) { t.Fatalf("bad: %#v", res[2]) } + + res, err = table.Get("id") + if err != nil { + t.Fatalf("err: %v", err) + } + if len(res) != 3 { + t.Fatalf("expect 2 result: %#v", res) + } + if !reflect.DeepEqual(res[0], objs[0]) { + t.Fatalf("bad: %#v", res[0]) + } + if !reflect.DeepEqual(res[1], objs[1]) { + t.Fatalf("bad: %#v", res[1]) + } + if !reflect.DeepEqual(res[2], objs[2]) { + t.Fatalf("bad: %#v", res[2]) + } } func TestMDBTableInsert_MissingFields(t *testing.T) {