From 58d77153fc84d9396606c2d27a5e2b9496c4d0af Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Thu, 13 Aug 2020 11:50:05 +0200 Subject: [PATCH] fix invalid state root being written to database (#1493) * fix invalid state root being written to database When rewinding state data, the wrong block reference would be used when saving the state root - this would cause state loading to fail by loading a different state than expected, preventing blocks to be applied. * refactor state loading and saving to consistently use and set StateData block * avoid rollback when state is missing from database (as opposed to being partially overwritten and therefore in need of rollback) * don't store state roots for empty slots - previously, these were used as a cache to avoid recalculating them in state transition, but this has been superceded by hash tree root caching * don't attempt loading states / state roots for non-epoch slots, these are not saved to the database * simplify rewinder and clean up funcitions after caches have been reworked * fix chaindag logscope * add database reload metric * re-enable clearance epoch tests * names --- beacon_chain/beacon_chain_db.nim | 36 ++-- beacon_chain/block_pools/chain_dag.nim | 254 +++++++++++-------------- beacon_chain/block_pools/clearance.nim | 2 +- tests/test_block_pool.nim | 109 +++++++---- 4 files changed, 203 insertions(+), 198 deletions(-) diff --git a/beacon_chain/beacon_chain_db.nim b/beacon_chain/beacon_chain_db.nim index 66d73915a..9e27e26d2 100644 --- a/beacon_chain/beacon_chain_db.nim +++ b/beacon_chain/beacon_chain_db.nim @@ -94,26 +94,31 @@ proc get(db: BeaconChainDB, key: openArray[byte], T: type Eth2Digest): Opt[T] = res -proc get(db: BeaconChainDB, key: openArray[byte], res: var auto): bool = - var found = false +type GetResult = enum + found + notFound + corrupted + +proc get(db: BeaconChainDB, key: openArray[byte], output: var auto): GetResult = + var status = GetResult.notFound # TODO address is needed because there's no way to express lifetimes in nim # we'll use unsafeAddr to find the code later - var resPtr = unsafeAddr res # callback is local, ptr wont escape + var outputPtr = unsafeAddr output # callback is local, ptr wont escape proc decode(data: openArray[byte]) = try: - resPtr[] = SSZ.decode(snappy.decode(data), type res) - found = true + outputPtr[] = SSZ.decode(snappy.decode(data), type output) + status = GetResult.found except SerializationError as e: # If the data can't be deserialized, it could be because it's from a # version of the software that uses a different SSZ encoding warn "Unable to deserialize data, old database?", - err = e.msg, typ = name(type res), dataLen = data.len - discard + err = e.msg, typ = name(type output), dataLen = data.len + status = GetResult.corrupted discard db.backend.get(key, decode).expect("working database") - found + status proc putBlock*(db: BeaconChainDB, value: SignedBeaconBlock) = db.put(subkey(type value, value.root), value) @@ -152,7 +157,7 @@ proc putTailBlock*(db: BeaconChainDB, key: Eth2Digest) = proc getBlock*(db: BeaconChainDB, key: Eth2Digest): Opt[TrustedSignedBeaconBlock] = # We only store blocks that we trust in the database result.ok(TrustedSignedBeaconBlock(root: key)) - if not db.get(subkey(SignedBeaconBlock, key), result.get): + if db.get(subkey(SignedBeaconBlock, key), result.get) != GetResult.found: result.err() proc getState*( @@ -162,15 +167,20 @@ proc getState*( ## re-allocating it if possible ## Return `true` iff the entry was found in the database and `output` was ## overwritten. + ## Rollback will be called only if output was partially written - if it was + ## not found at all, rollback will not be called # TODO rollback is needed to deal with bug - use `noRollback` to ignore: # https://github.com/nim-lang/Nim/issues/14126 # TODO RVO is inefficient for large objects: # https://github.com/nim-lang/Nim/issues/13879 - if not db.get(subkey(BeaconState, key), output): + case db.get(subkey(BeaconState, key), output) + of GetResult.found: + true + of GetResult.notFound: + false + of GetResult.corrupted: rollback(output) false - else: - true proc getStateRoot*(db: BeaconChainDB, root: Eth2Digest, @@ -198,6 +208,6 @@ iterator getAncestors*(db: BeaconChainDB, root: Eth2Digest): var res: TrustedSignedBeaconBlock res.root = root - while db.get(subkey(SignedBeaconBlock, res.root), res): + while db.get(subkey(SignedBeaconBlock, res.root), res) == GetResult.found: yield res res.root = res.message.parent_root diff --git a/beacon_chain/block_pools/chain_dag.nim b/beacon_chain/block_pools/chain_dag.nim index 3d0091c95..1d6a03917 100644 --- a/beacon_chain/block_pools/chain_dag.nim +++ b/beacon_chain/block_pools/chain_dag.nim @@ -24,8 +24,9 @@ export block_pools_types declareCounter beacon_reorgs_total, "Total occurrences of reorganizations of the chain" # On fork choice declareCounter beacon_state_data_cache_hits, "EpochRef hits" declareCounter beacon_state_data_cache_misses, "EpochRef misses" +declareCounter beacon_state_rewinds, "State database rewinds" -logScope: topics = "hotdb" +logScope: topics = "chaindag" proc putBlock*( dag: var ChainDAGRef, signedBlock: SignedBeaconBlock) = @@ -382,11 +383,11 @@ proc getEpochRef*(dag: ChainDAGRef, blck: BlockRef, epoch: Epoch): EpochRef = getEpochInfo(blck, state, cache) proc getState( - dag: ChainDAGRef, db: BeaconChainDB, stateRoot: Eth2Digest, blck: BlockRef, - output: var StateData): bool = - let outputAddr = unsafeAddr output # local scope + dag: ChainDAGRef, state: var StateData, stateRoot: Eth2Digest, + blck: BlockRef): bool = + let stateAddr = unsafeAddr state # local scope func restore(v: var BeaconState) = - if outputAddr == (unsafeAddr dag.headState): + if stateAddr == (unsafeAddr dag.headState): # TODO seeing the headState in the restore shouldn't happen - we load # head states only when updating the head position, and by that time # the database will have gone through enough sanity checks that @@ -394,40 +395,55 @@ proc getState( # Nonetheless, this is an ugly workaround that needs to go away doAssert false, "Cannot alias headState" - assign(outputAddr[], dag.headState) + assign(stateAddr[], dag.headState) - if not db.getState(stateRoot, output.data.data, restore): + if not dag.db.getState(stateRoot, state.data.data, restore): return false - output.blck = blck - output.data.root = stateRoot + state.blck = blck + state.data.root = stateRoot true -proc putState*(dag: ChainDAGRef, state: HashedBeaconState, blck: BlockRef) = +proc getState(dag: ChainDAGRef, state: var StateData, bs: BlockSlot): bool = + ## Load a state from the database given a block and a slot - this will first + ## lookup the state root in the state root table then load the corresponding + ## state, if it exists + if not bs.slot.isEpoch: + return false # We only ever save epoch states - no need to hit database + + if (let stateRoot = dag.db.getStateRoot(bs.blck.root, bs.slot); + stateRoot.isSome()): + return dag.getState(state, stateRoot.get(), bs.blck) + + false + +proc putState*(dag: ChainDAGRef, state: StateData) = + # Store a state and its root # TODO we save state at every epoch start but never remove them - we also # potentially save multiple states per slot if reorgs happen, meaning # we could easily see a state explosion logScope: pcs = "save_state_at_epoch_start" - var rootWritten = false - if state.data.slot != blck.slot: - # This is a state that was produced by a skip slot for which there is no - # block - we'll save the state root in the database in case we need to - # replay the skip - dag.db.putStateRoot(blck.root, state.data.slot, state.root) - rootWritten = true + if not state.data.data.slot.isEpoch: + # As a policy, we only store epoch boundary states - the rest can be + # reconstructed by loading an epoch boundary state and applying the + # missing blocks + return - if state.data.slot.isEpoch: - if not dag.db.containsState(state.root): - info "Storing state", - blck = shortLog(blck), - stateSlot = shortLog(state.data.slot), - stateRoot = shortLog(state.root) + if dag.db.containsState(state.data.root): + return - dag.db.putState(state.root, state.data) - if not rootWritten: - dag.db.putStateRoot(blck.root, state.data.slot, state.root) + info "Storing state", + blck = shortLog(state.blck), + stateSlot = shortLog(state.data.data.slot), + stateRoot = shortLog(state.data.root) + + # Ideally we would save the state and the root lookup cache in a single + # transaction to prevent database inconsistencies, but the state loading code + # is resilient against one or the other going missing + dag.db.putState(state.data.root, state.data.data) + dag.db.putStateRoot(state.blck.root, state.data.data.slot, state.data.root) func getRef*(dag: ChainDAGRef, root: Eth2Digest): BlockRef = ## Retrieve a resolved block reference, if available @@ -500,122 +516,48 @@ proc get*(dag: ChainDAGRef, root: Eth2Digest): Option[BlockData] = else: none(BlockData) -proc skipAndUpdateState( - dag: ChainDAGRef, - state: var HashedBeaconState, blck: BlockRef, slot: Slot, save: bool) = - while state.data.slot < slot: +proc advanceSlots( + dag: ChainDAGRef, state: var StateData, slot: Slot, save: bool) = + # Given a state, advance it zero or more slots by applying empty slot + # processing + doAssert state.data.data.slot <= slot + + while state.data.data.slot < slot: # Process slots one at a time in case afterUpdate needs to see empty states - var stateCache = getEpochCache(blck, state.data) - advance_slot(state, dag.updateFlags, stateCache) + var cache = getEpochCache(state.blck, state.data.data) + advance_slot(state.data, dag.updateFlags, cache) if save: - dag.putState(state, blck) + dag.putState(state) -proc skipAndUpdateState( +proc applyBlock( dag: ChainDAGRef, state: var StateData, blck: BlockData, flags: UpdateFlags, save: bool): bool = + # Apply a single block to the state - the state must be positioned at the + # parent of the block with a slot lower than the one of the block being + # applied + doAssert state.blck == blck.refs.parent - dag.skipAndUpdateState( - state.data, blck.refs, blck.data.message.slot - 1, save) + # `state_transition` can handle empty slots, but we want to potentially save + # some of the empty slot states + dag.advanceSlots(state, blck.data.message.slot - 1, save) var statePtr = unsafeAddr state # safe because `restore` is locally scoped func restore(v: var HashedBeaconState) = doAssert (addr(statePtr.data) == addr v) statePtr[] = dag.headState - var stateCache = getEpochCache(blck.refs, state.data.data) + var cache = getEpochCache(blck.refs, state.data.data) + let ok = state_transition( dag.runtimePreset, state.data, blck.data, - stateCache, flags + dag.updateFlags, restore) - - if ok and save: - dag.putState(state.data, blck.refs) + cache, flags + dag.updateFlags, restore) + if ok: + state.blck = blck.refs + dag.putState(state) ok -proc rewindState( - dag: ChainDAGRef, state: var StateData, bs: BlockSlot): seq[BlockRef] = - logScope: - blockSlot = shortLog(bs) - pcs = "replay_state" - - var ancestors = @[bs.blck] - # Common case: the last block applied is the parent of the block to apply: - if not bs.blck.parent.isNil and state.blck.root == bs.blck.parent.root and - state.data.data.slot < bs.blck.slot: - return ancestors - - # It appears that the parent root of the proposed new block is different from - # what we expected. We will have to rewind the state to a point along the - # chain of ancestors of the new block. We will do this by loading each - # successive parent block and checking if we can find the corresponding state - # in the database. - var - stateRoot = block: - let tmp = dag.db.getStateRoot(bs.blck.root, bs.slot) - if tmp.isSome() and dag.db.containsState(tmp.get()): - tmp - else: - # State roots are sometimes kept in database even though state is not - err(Opt[Eth2Digest]) - curBs = bs - - while stateRoot.isNone(): - let parBs = curBs.parent() - if parBs.blck.isNil: - break # Bug probably! - - if parBs.blck != curBs.blck: - ancestors.add(parBs.blck) - - if (let tmp = dag.db.getStateRoot(parBs.blck.root, parBs.slot); tmp.isSome()): - if dag.db.containsState(tmp.get): - stateRoot = tmp - break - - curBs = parBs - - if stateRoot.isNone(): - # TODO this should only happen if the database is corrupt - we walked the - # list of parent blocks and couldn't find a corresponding state in the - # database, which should never happen (at least we should have the - # tail state in there!) - fatal "Couldn't find ancestor state root!" - doAssert false, "Oh noes, we passed big bang!" - - let - ancestor = ancestors.pop() - root = stateRoot.get() - found = dag.getState(dag.db, root, ancestor, state) - - if not found: - # TODO this should only happen if the database is corrupt - we walked the - # list of parent blocks and couldn't find a corresponding state in the - # database, which should never happen (at least we should have the - # tail state in there!) - fatal "Couldn't find ancestor state or block parent missing!" - doAssert false, "Oh noes, we passed big bang!" - - trace "Replaying state transitions", - stateSlot = shortLog(state.data.data.slot), - ancestors = ancestors.len - - ancestors - -proc getStateDataCached( - dag: ChainDAGRef, state: var StateData, bs: BlockSlot): bool = - # This pointedly does not run rewindState or state_transition, but otherwise - # mostly matches updateStateData(...), because it's too expensive to run the - # rewindState(...)/skipAndUpdateState(...)/state_transition(...) procs, when - # each hash_tree_root(...) consumes a nontrivial fraction of a second. - - # In-memory caches didn't hit. Try main block pool database. This is slower - # than the caches due to SSZ (de)serializing and disk I/O, so prefer them. - if (let tmp = dag.db.getStateRoot(bs.blck.root, bs.slot); tmp.isSome()): - return dag.getState(dag.db, tmp.get(), bs.blck, state) - - false - proc updateStateData*( dag: ChainDAGRef, state: var StateData, bs: BlockSlot) = ## Rewind or advance state such that it matches the given block and slot - @@ -624,56 +566,72 @@ proc updateStateData*( ## If slot is higher than blck.slot, replay will fill in with empty/non-block ## slots, else it is ignored - # We need to check the slot because the state might have moved forwards - # without blocks - if state.blck.root == bs.blck.root and state.data.data.slot <= bs.slot: - if state.data.data.slot != bs.slot: - # Might be that we're moving to the same block but later slot - dag.skipAndUpdateState(state.data, bs.blck, bs.slot, true) + # First, see if we're already at the requested block. If we are, also check + # that the state has not been advanced past the desired block - if it has, + # an earlier state must be loaded since there's no way to undo the slot + # transitions + if state.blck == bs.blck and state.data.data.slot <= bs.slot: + # The block is the same and we're at an early enough slot - advance the + # state with empty slot processing until the slot is correct + dag.advanceSlots(state, bs.slot, true) - return # State already at the right spot - - if dag.getStateDataCached(state, bs): return - let ancestors = rewindState(dag, state, bs) + # Either the state is too new or was created by applying a different block. + # We'll now resort to loading the state from the database then reapplying + # blocks until we reach the desired point in time. - # If we come this far, we found the state root. The last block on the stack - # is the one that produced this particular state, so we can pop it - # TODO it might be possible to use the latest block hashes from the state to - # do this more efficiently.. whatever! + var + ancestors: seq[BlockRef] + cur = bs + # Look for a state in the database and load it - as long as it cannot be + # found, keep track of the blocks that are needed to reach it from the + # state that eventually will be found + while not dag.getState(state, cur): + # There's no state saved for this particular BlockSlot combination, keep + # looking... + if cur.slot == cur.blck.slot: + # This is not an empty slot, so the block will need to be applied to + # eventually reach bs + ancestors.add(cur.blck) - # Time to replay all the blocks between then and now. We skip one because - # it's the one that we found the state with, and it has already been - # applied. Pathologically quadratic in slot number, naïvely. + # Moves back slot by slot, in case a state for an empty slot was saved + cur = cur.parent + + # Time to replay all the blocks between then and now for i in countdown(ancestors.len - 1, 0): # Because the ancestors are in the database, there's no need to persist them # again. Also, because we're applying blocks that were loaded from the # database, we can skip certain checks that have already been performed - # before adding the block to the database. In particular, this means that - # no state root calculation will take place here, because we can load - # the final state root from the block itself. + # before adding the block to the database. let ok = - dag.skipAndUpdateState(state, dag.get(ancestors[i]), {}, false) + dag.applyBlock(state, dag.get(ancestors[i]), {}, false) doAssert ok, "Blocks in database should never fail to apply.." # We save states here - blocks were guaranteed to have passed through the save # function once at least, but not so for empty slots! - dag.skipAndUpdateState(state.data, bs.blck, bs.slot, true) + dag.advanceSlots(state, bs.slot, true) - state.blck = bs.blck + beacon_state_rewinds.inc() + + debug "State reloaded from database", + blocks = ancestors.len, stateRoot = shortLog(state.data.root), + blck = shortLog(bs) proc loadTailState*(dag: ChainDAGRef): StateData = ## Load the state associated with the current tail in the dag let stateRoot = dag.db.getBlock(dag.tail.root).get().message.state_root - let found = dag.getState(dag.db, stateRoot, dag.tail, result) + let found = dag.getState(result, stateRoot, dag.tail) # TODO turn into regular error, this can happen doAssert found, "Failed to load tail state, database corrupt?" proc delState(dag: ChainDAGRef, bs: BlockSlot) = # Delete state state and mapping for a particular block+slot + if not bs.slot.isEpoch: + return # We only ever save epoch states if (let root = dag.db.getStateRoot(bs.blck.root, bs.slot); root.isSome()): dag.db.delState(root.get()) + dag.db.delStateRoot(bs.blck.root, bs.slot) proc updateHead*(dag: ChainDAGRef, newHead: BlockRef) = ## Update what we consider to be the current head, as given by the fork diff --git a/beacon_chain/block_pools/clearance.nim b/beacon_chain/block_pools/clearance.nim index bdf89a1d7..4f62a1c31 100644 --- a/beacon_chain/block_pools/clearance.nim +++ b/beacon_chain/block_pools/clearance.nim @@ -216,7 +216,7 @@ proc addRawBlock*( onBlockAdded ) - dag.putState(dag.clearanceState.data, dag.clearanceState.blck) + dag.putState(dag.clearanceState) return ok dag.clearanceState.blck diff --git a/tests/test_block_pool.nim b/tests/test_block_pool.nim index 4bc106bb7..07e7f5f2e 100644 --- a/tests/test_block_pool.nim +++ b/tests/test_block_pool.nim @@ -351,45 +351,82 @@ suiteReport "chain DAG finalization tests" & preset(): hash_tree_root(dag2.headState.data.data) == hash_tree_root(dag.headState.data.data) - # timedTest "init with gaps" & preset(): - # var cache = StateCache() - # for i in 0 ..< (SLOTS_PER_EPOCH * 6 - 2): - # var - # blck = makeTestBlock( - # dag.headState.data, pool.head.blck.root, cache, - # attestations = makeFullAttestations( - # dag.headState.data.data, pool.head.blck.root, - # dag.headState.data.data.slot, cache, {})) + timedTest "orphaned epoch block" & preset(): + var prestate = (ref HashedBeaconState)() + for i in 0 ..< SLOTS_PER_EPOCH: + if i == SLOTS_PER_EPOCH - 1: + assign(prestate[], dag.headState.data) - # let added = dag.addRawBlock(quarantine, hash_tree_root(blck.message), blck) do (validBlock: BlockRef): - # discard - # check: added.isOk() - # dag.updateHead(added[]) + let blck = makeTestBlock( + dag.headState.data, dag.head.root, cache) + let added = dag.addRawBlock(quarantine, blck, nil) + check: added.isOk() + dag.updateHead(added[]) - # # Advance past epoch so that the epoch transition is gapped - # check: - # process_slots( - # dag.headState.data, Slot(SLOTS_PER_EPOCH * 6 + 2) ) + check: + dag.heads.len() == 1 - # var blck = makeTestBlock( - # dag.headState.data, pool.head.blck.root, cache, - # attestations = makeFullAttestations( - # dag.headState.data.data, pool.head.blck.root, - # dag.headState.data.data.slot, cache, {})) + advance_slot(prestate[], {}, cache) - # let added = dag.addRawBlock(quarantine, hash_tree_root(blck.message), blck) do (validBlock: BlockRef): - # discard - # check: added.isOk() - # dag.updateHead(added[]) + # create another block, orphaning the head + let blck = makeTestBlock( + prestate[], dag.head.parent.root, cache) - # let - # pool2 = BlockPool.init(db) + # Add block, but don't update head + let added = dag.addRawBlock(quarantine, blck, nil) + check: added.isOk() - # # check that the state reloaded from database resembles what we had before - # check: - # pool2.dag.tail.root == dag.tail.root - # pool2.dag.head.blck.root == dag.head.blck.root - # pool2.dag.finalizedHead.blck.root == dag.finalizedHead.blck.root - # pool2.dag.finalizedHead.slot == dag.finalizedHead.slot - # hash_tree_root(pool2.headState.data.data) == - # hash_tree_root(dag.headState.data.data) + var + dag2 = init(ChainDAGRef, defaultRuntimePreset, db) + + # check that we can apply the block after the orphaning + let added2 = dag2.addRawBlock(quarantine, blck, nil) + check: added2.isOk() + +suiteReport "chain DAG finalization tests" & preset(): + setup: + var + db = makeTestDB(SLOTS_PER_EPOCH) + dag = init(ChainDAGRef, defaultRuntimePreset, db) + quarantine = QuarantineRef() + cache = StateCache() + + timedTest "init with gaps" & preset(): + for i in 0 ..< (SLOTS_PER_EPOCH * 6 - 2): + var + blck = makeTestBlock( + dag.headState.data, dag.head.root, cache, + attestations = makeFullAttestations( + dag.headState.data.data, dag.head.root, + dag.headState.data.data.slot, cache, {})) + + let added = dag.addRawBlock(quarantine, blck, nil) + check: added.isOk() + dag.updateHead(added[]) + + # Advance past epoch so that the epoch transition is gapped + check: + process_slots( + dag.headState.data, Slot(SLOTS_PER_EPOCH * 6 + 2) ) + + var blck = makeTestBlock( + dag.headState.data, dag.head.root, cache, + attestations = makeFullAttestations( + dag.headState.data.data, dag.head.root, + dag.headState.data.data.slot, cache, {})) + + let added = dag.addRawBlock(quarantine, blck, nil) + check: added.isOk() + dag.updateHead(added[]) + + let + dag2 = init(ChainDAGRef, defaultRuntimePreset, db) + + # check that the state reloaded from database resembles what we had before + check: + dag2.tail.root == dag.tail.root + dag2.head.root == dag.head.root + dag2.finalizedHead.blck.root == dag.finalizedHead.blck.root + dag2.finalizedHead.slot == dag.finalizedHead.slot + hash_tree_root(dag2.headState.data.data) == + hash_tree_root(dag.headState.data.data)