diff --git a/consul/mdb_table.go b/consul/mdb_table.go index 800e260a28..61c5f6f935 100644 --- a/consul/mdb_table.go +++ b/consul/mdb_table.go @@ -195,17 +195,27 @@ func (t *MDBTable) StartTxn(readonly bool) (*MDBTxn, error) { return mdbTxn, nil } -// Insert is used to insert or update an object -func (t *MDBTable) Insert(obj interface{}) error { +// objIndexKeys builds the indexes for a given object +func (t *MDBTable) objIndexKeys(obj interface{}) (map[string][]byte, error) { // Construct the indexes keys indexes := make(map[string][]byte) for name, index := range t.Indexes { key, err := index.keyFromObject(obj) if err != nil { - return err + return nil, err } indexes[name] = key } + return indexes, nil +} + +// Insert is used to insert or update an object +func (t *MDBTable) Insert(obj interface{}) error { + // Construct the indexes keys + indexes, err := t.objIndexKeys(obj) + if err != nil { + return err + } // Encode the obj raw := t.Encoder(obj) @@ -256,9 +266,10 @@ func (t *MDBTable) Get(index string, parts ...string) ([]interface{}, error) { // Accumulate the results var results []interface{} - err = idx.iterate(tx, key, func(res []byte) { + err = idx.iterate(tx, key, func(encRowId, res []byte) bool { obj := t.Decoder(res) results = append(results, obj) + return false }) return results, err @@ -286,7 +297,7 @@ func (t *MDBTable) getIndex(index string, parts []string) (*MDBIndex, []byte, er // Delete is used to delete one or more rows. An index an appropriate // fields are specified. The fields can be a prefix of the index. // Returns the rows deleted or an error. -func (t *MDBTable) Delete(index string, parts ...string) (int, error) { +func (t *MDBTable) Delete(index string, parts ...string) (num int, err error) { // Get the associated index idx, key, err := t.getIndex(index, parts) if err != nil { @@ -300,11 +311,48 @@ func (t *MDBTable) Delete(index string, parts ...string) (int, error) { } defer tx.Abort() - // Accumulate the results - num := 0 - err = idx.iterate(tx, key, func(res []byte) { + // Handle an error while deleting + defer func() { + if r := recover(); r != nil { + num = 0 + err = err + } + }() + + // Delete everything as we iterate + err = idx.iterate(tx, key, func(encRowId, res []byte) bool { + // Get the object + obj := t.Decoder(res) + + // Build index values + indexes, err := t.objIndexKeys(obj) + if err != nil { + panic(err) + } + + // Delete the indexes we are not iterating + for name, otherIdx := range t.Indexes { + if name == index { + continue + } + dbi := tx.dbis[otherIdx.dbiName] + if err := tx.tx.Del(dbi, indexes[name], encRowId); err != nil { + panic(err) + } + } + + // Delete the data row + if err := tx.tx.Del(tx.dbis[t.Name], encRowId, nil); err != nil { + panic(err) + } + + // Delete the object num++ + return true }) + if err != nil { + return 0, err + } // Return the deleted count return num, tx.Commit() @@ -381,7 +429,8 @@ func (i *MDBIndex) keyFromObject(obj interface{}) ([]byte, error) { // iterate is used to iterate over keys matching the prefix, // and invoking the cb with each row. We dereference the rowid, // and only return the object row -func (i *MDBIndex) iterate(tx *MDBTxn, prefix []byte, cb func(res []byte)) error { +func (i *MDBIndex) iterate(tx *MDBTxn, prefix []byte, + cb func(encRowId, res []byte) bool) error { table := tx.dbis[i.table.Name] dbi := tx.dbis[i.dbiName] @@ -390,24 +439,28 @@ func (i *MDBIndex) iterate(tx *MDBTxn, prefix []byte, cb func(res []byte)) error return err } - var key, val []byte + var key, encRowId, objBytes []byte first := true + shouldDelete := false for { if first && len(prefix) > 0 { first = false - key, val, err = cursor.Get(prefix, mdb.SET_RANGE) + key, encRowId, err = cursor.Get(prefix, mdb.SET_RANGE) + } else if shouldDelete { + key, encRowId, err = cursor.Get(nil, 0) + shouldDelete = false } else if i.Unique { - key, val, err = cursor.Get(nil, mdb.NEXT) + key, encRowId, err = cursor.Get(nil, mdb.NEXT) } else { - key, val, err = cursor.Get(nil, mdb.NEXT_DUP) + key, encRowId, err = cursor.Get(nil, mdb.NEXT_DUP) if err == mdb.NotFound { - key, val, err = cursor.Get(nil, mdb.NEXT) + key, encRowId, err = cursor.Get(nil, mdb.NEXT) } } if err == mdb.NotFound { break } else if err != nil { - return err + return fmt.Errorf("iterate failed: %v", err) } // Bail if this does not match our filter @@ -416,13 +469,17 @@ func (i *MDBIndex) iterate(tx *MDBTxn, prefix []byte, cb func(res []byte)) error } // Lookup the actual object - objBytes, err := tx.tx.Get(table, val) + objBytes, err = tx.tx.Get(table, encRowId) if err != nil { - return err + return fmt.Errorf("rowid lookup failed: %v (%v)", err, encRowId) } // Invoke the cb - cb(objBytes) + if shouldDelete = cb(encRowId, objBytes); shouldDelete { + if err := cursor.Del(0); err != nil { + return fmt.Errorf("delete failed: %v", err) + } + } } return nil } diff --git a/consul/mdb_table_test.go b/consul/mdb_table_test.go index 54577aadb4..15fd5b356a 100644 --- a/consul/mdb_table_test.go +++ b/consul/mdb_table_test.go @@ -282,3 +282,95 @@ func TestMDBTableInsert_AllowBlank(t *testing.T) { } } } + +func TestMDBTableDelete(t *testing.T) { + dir, env := testMDBEnv(t) + defer os.RemoveAll(dir) + defer env.Close() + + table := &MDBTable{ + Env: env, + Name: "test", + Indexes: map[string]*MDBIndex{ + "id": &MDBIndex{ + Unique: true, + Fields: []string{"Key"}, + }, + "name": &MDBIndex{ + Fields: []string{"First", "Last"}, + }, + "country": &MDBIndex{ + Fields: []string{"Country"}, + }, + }, + Encoder: MockEncoder, + Decoder: MockDecoder, + } + if err := table.Init(); err != nil { + t.Fatalf("err: %v", err) + } + + objs := []*MockData{ + &MockData{ + Key: "1", + First: "Kevin", + Last: "Smith", + Country: "USA", + }, + &MockData{ + Key: "2", + First: "Kevin", + Last: "Wang", + Country: "USA", + }, + &MockData{ + Key: "3", + First: "Bernardo", + Last: "Torres", + Country: "Mexico", + }, + } + + // Insert some mock objects + for _, obj := range objs { + if err := table.Insert(obj); err != nil { + t.Fatalf("err: %v", err) + } + } + + _, err := table.Get("id", "3") + if err != nil { + t.Fatalf("err: %v", err) + } + + // Verify with some gets + num, err := table.Delete("id", "3") + if err != nil { + t.Fatalf("err: %v", err) + } + if num != 1 { + t.Fatalf("expect 1 delete: %#v", num) + } + res, err := table.Get("id", "3") + if err != nil { + t.Fatalf("err: %v", err) + } + if len(res) != 0 { + t.Fatalf("expect 0 result: %#v", res) + } + + num, err = table.Delete("name", "Kevin") + if err != nil { + t.Fatalf("err: %v", err) + } + if num != 2 { + t.Fatalf("expect 2 deletes: %#v", num) + } + res, err = table.Get("name", "Kevin") + if err != nil { + t.Fatalf("err: %v", err) + } + if len(res) != 0 { + t.Fatalf("expect 0 results: %#v", res) + } +}