diff --git a/eth/trie/db.nim b/eth/trie/db.nim index 050bd30..7e380b5 100644 --- a/eth/trie/db.nim +++ b/eth/trie/db.nim @@ -34,11 +34,15 @@ type containsProc: ContainsProc mostInnerTransaction: DbTransaction + TransactionFlags = enum + Committed + RolledBack + DbTransaction* = ref object db: TrieDatabaseRef parentTransaction: DbTransaction modifications: MemoryLayer - committed: bool + flags: set[TransactionFlags] proc put*(db: TrieDatabaseRef, key, val: openarray[byte]) {.gcsafe.} proc get*(db: TrieDatabaseRef, key: openarray[byte]): Bytes {.gcsafe.} @@ -137,24 +141,32 @@ proc rollback*(t: DbTransaction) = # Transactions should be handled in a strictly nested fashion. # Any child transaction must be committed or rolled-back before # its parent transactions: - doAssert t.db.mostInnerTransaction == t and not t.committed + doAssert t.db.mostInnerTransaction == t and + Committed notin t.flags and + RolledBack notin t.flags t.db.mostInnerTransaction = t.parentTransaction + t.flags.incl RolledBack proc commit*(t: DbTransaction) = # Transactions should be handled in a strictly nested fashion. # Any child transaction must be committed or rolled-back before # its parent transactions: - doAssert t.db.mostInnerTransaction == t and not t.committed + doAssert t.db.mostInnerTransaction == t and + Committed notin t.flags and + RolledBack notin t.flags t.db.mostInnerTransaction = t.parentTransaction t.modifications.commit(t.db) - t.committed = true + t.flags.incl Committed proc dispose*(t: DbTransaction) {.inline.} = - if not t.committed: + if Committed notin t.flags and + RolledBack notin t.flags: t.rollback() proc safeDispose*(t: DbTransaction) {.inline.} = - if t != nil and not t.committed: + if t != nil and + Committed notin t.flags and + RolledBack notin t.flags: t.rollback() proc putImpl[T](db: RootRef, key, val: openarray[byte]) = diff --git a/tests/trie/test_transaction_db.nim b/tests/trie/test_transaction_db.nim new file mode 100644 index 0000000..f55926f --- /dev/null +++ b/tests/trie/test_transaction_db.nim @@ -0,0 +1,179 @@ +import + unittest, strutils, sequtils, os, + eth/trie/[db, trie_defs], ./testutils, + eth/rlp/types as rlpTypes + +suite "transaction db": + setup: + const + listLength = 30 + + var + keysA = randList(Bytes, randGen(3, 33), randGen(listLength)) + valuesA = randList(Bytes, randGen(5, 77), randGen(listLength)) + keysB = randList(Bytes, randGen(3, 33), randGen(listLength)) + valuesB = randList(Bytes, randGen(5, 77), randGen(listLength)) + + proc populateA(db: TrieDatabaseRef) = + for i in 0 ..< listLength: + db.put(keysA[i], valuesA[i]) + + proc checkContentsA(db: TrieDatabaseRef): bool = + for i in 0 ..< listLength: + let v = db.get(keysA[i]) + if v != valuesA[i]: return false + result = true + + proc checkEmptyContentsA(db: TrieDatabaseRef): bool {.used.} = + for i in 0 ..< listLength: + let v = db.get(keysA[i]) + if v.len != 0: return false + result = true + + proc populateB(db: TrieDatabaseRef) {.used.} = + for i in 0 ..< listLength: + db.put(keysB[i], valuesB[i]) + + proc checkContentsB(db: TrieDatabaseRef): bool {.used.} = + for i in 0 ..< listLength: + let v = db.get(keysB[i]) + if v != valuesB[i]: return false + result = true + + proc checkEmptyContentsB(db: TrieDatabaseRef): bool {.used.} = + for i in 0 ..< listLength: + let v = db.get(keysB[i]) + if v.len != 0: return false + result = true + + test "commit": + var db = newMemoryDB() + var tx = db.beginTransaction() + db.populateA() + check checkContentsA(db) + tx.commit() + check checkContentsA(db) + + test "rollback": + var db = newMemoryDB() + var tx = db.beginTransaction() + db.populateA() + check checkContentsA(db) + tx.rollback() + check checkEmptyContentsA(db) + + test "dispose": + var db = newMemoryDB() + var tx = db.beginTransaction() + db.populateA() + check checkContentsA(db) + tx.dispose() + check checkEmptyContentsA(db) + + test "commit dispose": + var db = newMemoryDB() + var tx = db.beginTransaction() + db.populateA() + check checkContentsA(db) + tx.commit() + tx.dispose() + check checkContentsA(db) + + test "rollback dispose": + var db = newMemoryDB() + var tx = db.beginTransaction() + db.populateA() + check checkContentsA(db) + tx.rollback() + tx.dispose() + check checkEmptyContentsA(db) + + test "dispose dispose": + var db = newMemoryDB() + var tx = db.beginTransaction() + db.populateA() + check checkContentsA(db) + tx.dispose() + tx.dispose() + check checkEmptyContentsA(db) + + test "commit commit": + var db = newMemoryDB() + var txA = db.beginTransaction() + db.populateA() + var txB = db.beginTransaction() + db.populateB() + + check checkContentsA(db) + check checkContentsB(db) + + txB.commit() + txA.commit() + + check checkContentsA(db) + check checkContentsB(db) + + test "commit rollback": + var db = newMemoryDB() + var txA = db.beginTransaction() + db.populateA() + var txB = db.beginTransaction() + db.populateB() + + check checkContentsA(db) + check checkContentsB(db) + + txB.rollback() + txA.commit() + + check checkContentsA(db) + check checkEmptyContentsB(db) + + test "rollback commit": + var db = newMemoryDB() + var txA = db.beginTransaction() + db.populateA() + var txB = db.beginTransaction() + db.populateB() + + check checkContentsA(db) + check checkContentsB(db) + + txB.commit() + txA.rollback() + + check checkEmptyContentsB(db) + check checkEmptyContentsA(db) + + test "rollback rollback": + var db = newMemoryDB() + var txA = db.beginTransaction() + db.populateA() + var txB = db.beginTransaction() + db.populateB() + + check checkContentsA(db) + check checkContentsB(db) + + txB.rollback() + txA.rollback() + + check checkEmptyContentsB(db) + check checkEmptyContentsA(db) + + test "commit rollback dispose": + var db = newMemoryDB() + var txA = db.beginTransaction() + db.populateA() + var txB = db.beginTransaction() + db.populateB() + + check checkContentsA(db) + check checkContentsB(db) + + txB.rollback() + txA.commit() + txA.dispose() + + check checkContentsA(db) + check checkEmptyContentsB(db) diff --git a/tests/trie/testutils.nim b/tests/trie/testutils.nim index b6f7075..51c8ebe 100644 --- a/tests/trie/testutils.nim +++ b/tests/trie/testutils.nim @@ -27,6 +27,11 @@ proc randString*(len: int): string = for i in 0..