From 17fa4997556d0a69814144744e389cfa97c8f5ee Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Sat, 4 May 2019 08:10:45 -0600 Subject: [PATCH] cache state root a bit more aggressively (#260) needed to be able to run state sim, too slow otherwise --- beacon_chain/beacon_node.nim | 17 ++++--- beacon_chain/beacon_node_types.nim | 6 +-- beacon_chain/block_pool.nim | 74 ++++++++++++++---------------- beacon_chain/spec/datatypes.nim | 5 ++ beacon_chain/state_transition.nim | 58 +++++++++++++++++++++++ tests/test_attestation_pool.nim | 22 ++++----- tests/test_block_pool.nim | 14 +++--- 7 files changed, 125 insertions(+), 71 deletions(-) diff --git a/beacon_chain/beacon_node.nim b/beacon_chain/beacon_node.nim index f0174a2f5..cdbb254a4 100644 --- a/beacon_chain/beacon_node.nim +++ b/beacon_chain/beacon_node.nim @@ -186,14 +186,12 @@ proc init*(T: type BeaconNode, conf: BeaconNodeConf): Future[BeaconNode] {.async sync.node = result sync.db = result.db - let head = result.blockPool.get(result.db.getHeadBlock().get()) - result.stateCache = result.blockPool.loadTailState() result.justifiedStateCache = result.stateCache let addressFile = string(conf.dataDir) / "beacon_node.address" result.network.saveConnectionAddressFile(addressFile) - result.beaconClock = BeaconClock.init(result.stateCache.data) + result.beaconClock = BeaconClock.init(result.stateCache.data.data) template withState( pool: BlockPool, cache: var StateData, blockSlot: BlockSlot, body: untyped): untyped = @@ -204,9 +202,10 @@ template withState( updateStateData(pool, cache, blockSlot) - template state(): BeaconState {.inject.} = cache.data + template hashedState(): HashedBeaconState {.inject.} = cache.data + template state(): BeaconState {.inject.} = cache.data.data template blck(): BlockRef {.inject.} = cache.blck - template root(): Eth2Digest {.inject.} = cache.root + template root(): Eth2Digest {.inject.} = cache.data.root body @@ -369,12 +368,12 @@ proc proposeBlock(node: BeaconNode, signature: 
ValidatorSig(), # we need the rest of the block first! ) - var tmpState = state + var tmpState = hashedState let ok = updateState(tmpState, newBlock, {skipValidation}) doAssert ok # TODO: err, could this fail somehow? - newBlock.state_root = hash_tree_root(tmpState) + newBlock.state_root = tmpState.root let blockRoot = signed_root(newBlock) @@ -791,6 +790,6 @@ when isMainModule: # TODO slightly ugly to rely on node.stateCache state here.. if node.nickname != "": - dynamicLogScope(node = node.nickname): node.start(node.stateCache.data) + dynamicLogScope(node = node.nickname): node.start(node.stateCache.data.data) else: - node.start(node.stateCache.data) + node.start(node.stateCache.data.data) diff --git a/beacon_chain/beacon_node_types.nim b/beacon_chain/beacon_node_types.nim index 719dac5ff..b7da13a9f 100644 --- a/beacon_chain/beacon_node_types.nim +++ b/beacon_chain/beacon_node_types.nim @@ -210,13 +210,11 @@ type refs*: BlockRef StateData* = object - data*: BeaconState - root*: Eth2Digest ##\ - ## Root of above data (cache) + data*: HashedBeaconState blck*: BlockRef ##\ ## The block associated with the state found in data - in particular, - ## blck.state_root == root + ## blck.state_root == data.root StateCache* = object crosslink_committee_cache*: diff --git a/beacon_chain/block_pool.nim b/beacon_chain/block_pool.nim index b89f02686..81759fb7a 100644 --- a/beacon_chain/block_pool.nim +++ b/beacon_chain/block_pool.nim @@ -158,13 +158,12 @@ proc addResolvedBlock( # set the rest here - need a blockRef to update it. Clean this up - # hopefully it won't be necessary by the time hash caching and the rest # is done..
- doAssert state.data.slot == blockRef.slot - state.root = blck.state_root + doAssert state.data.data.slot == blockRef.slot state.blck = blockRef # This block *might* have caused a justification - make sure we stow away # that information: - let justifiedSlot = state.data.current_justified_epoch.get_epoch_start_slot() + let justifiedSlot = state.data.data.current_justified_epoch.get_epoch_start_slot() var foundHead: Option[Head] for head in pool.heads.mitems(): @@ -345,8 +344,8 @@ proc checkMissing*(pool: var BlockPool): seq[FetchRecord] = result.add(FetchRecord(root: k, historySlots: v.slots)) proc skipAndUpdateState( - state: var BeaconState, blck: BeaconBlock, flags: UpdateFlags, - afterUpdate: proc (state: BeaconState)): bool = + state: var HashedBeaconState, blck: BeaconBlock, flags: UpdateFlags, + afterUpdate: proc (state: HashedBeaconState)): bool = skipSlots(state, blck.slot - 1, afterUpdate) let ok = updateState(state, blck, flags) @@ -354,27 +353,25 @@ proc skipAndUpdateState( ok -proc maybePutState(pool: BlockPool, state: BeaconState, blck: BlockRef) = +proc maybePutState(pool: BlockPool, state: HashedBeaconState, blck: BlockRef) = # TODO we save state at every epoch start but never remove them - we also # potentially save multiple states per slot if reorgs happen, meaning # we could easily see a state explosion - if state.slot mod SLOTS_PER_EPOCH == 0: - let root = hash_tree_root(state) - - if not pool.db.containsState(root): + if state.data.slot mod SLOTS_PER_EPOCH == 0: + if not pool.db.containsState(state.root): info "Storing state", - stateSlot = humaneSlotNum(state.slot), - stateRoot = shortLog(root) - pool.db.putState(root, state) + stateSlot = humaneSlotNum(state.data.slot), + stateRoot = shortLog(state.root) + pool.db.putState(state.root, state.data) # TODO this should be atomic with the above write.. 
- pool.db.putStateRoot(blck.root, state.slot, root) + pool.db.putStateRoot(blck.root, state.data.slot, state.root) proc rewindState(pool: BlockPool, state: var StateData, bs: BlockSlot): seq[BlockData] = var ancestors = @[pool.get(bs.blck)] # Common case: the last block applied is the parent of the block to apply: if not bs.blck.parent.isNil and state.blck.root == bs.blck.parent.root and - state.data.slot < bs.slot: + state.data.data.slot < bs.slot: return ancestors # It appears that the parent root of the proposed new block is different from @@ -424,14 +421,16 @@ proc rewindState(pool: BlockPool, state: var StateData, bs: BlockSlot): doAssert false, "Oh noes, we passed big bang!" debug "Replaying state transitions", - stateSlot = humaneSlotNum(state.data.slot), + stateSlot = humaneSlotNum(state.data.data.slot), ancestorStateRoot = shortLog(ancestor.data.state_root), ancestorStateSlot = humaneSlotNum(ancestorState.get().slot), slot = humaneSlotNum(bs.slot), blockRoot = shortLog(bs.blck.root), ancestors = ancestors.len - state.data = ancestorState.get() + state.data.data = ancestorState.get() + state.data.root = stateRoot.get() + state.blck = ancestor.refs ancestors @@ -444,14 +443,12 @@ proc updateStateData*(pool: BlockPool, state: var StateData, bs: BlockSlot) = # We need to check the slot because the state might have moved forwards # without blocks - if state.blck.root == bs.blck.root and state.data.slot <= bs.slot: - if state.data.slot != bs.slot: + if state.blck.root == bs.blck.root and state.data.data.slot <= bs.slot: + if state.data.data.slot != bs.slot: # Might be that we're moving to the same block but later slot - skipSlots(state.data, bs.slot) do (state: BeaconState): + skipSlots(state.data, bs.slot) do (state: HashedBeaconState): pool.maybePutState(state, bs.blck) - state.root = hash_tree_root(state.data) - return # State already at the right spot let ancestors = rewindState(pool, state, bs) @@ -467,25 +464,22 @@ proc updateStateData*(pool: BlockPool, 
state: var StateData, bs: BlockSlot) = for i in countdown(ancestors.len - 2, 0): let ok = skipAndUpdateState(state.data, ancestors[i].data, {skipValidation}) do( - state: BeaconState): + state: HashedBeaconState): pool.maybePutState(state, ancestors[i].refs) doAssert ok, "Blocks in database should never fail to apply.." - skipSlots(state.data, bs.slot) do (state: BeaconState): + skipSlots(state.data, bs.slot) do (state: HashedBeaconState): pool.maybePutState(state, bs.blck) - # TODO could perhaps avoi a hash_tree_root if putState happens.. hmm.. state.blck = bs.blck - state.root = - if state.data.slot == ancestors[0].data.slot: ancestors[0].data.state_root - else: hash_tree_root(state.data) proc loadTailState*(pool: BlockPool): StateData = ## Load the state associated with the current tail in the pool let stateRoot = pool.db.getBlock(pool.tail.root).get().state_root StateData( - data: pool.db.getState(stateRoot).get(), - root: stateRoot, + data: HashedBeaconState( + data: pool.db.getState(stateRoot).get(), + root: stateRoot), blck: pool.tail ) @@ -517,30 +511,30 @@ proc updateHead*(pool: BlockPool, state: var StateData, blck: BlockRef) = # Start off by making sure we have the right state updateStateData(pool, state, BlockSlot(blck: blck, slot: blck.slot)) - let justifiedSlot = state.data.current_justified_epoch.get_epoch_start_slot() + let justifiedSlot = state.data.data.current_justified_epoch.get_epoch_start_slot() pool.head = Head(blck: blck, justified: blck.findAncestorBySlot(justifiedSlot)) if lastHead.blck != blck.parent: notice "Updated head with new parent", lastHeadRoot = shortLog(lastHead.blck.root), parentRoot = shortLog(blck.parent.root), - stateRoot = shortLog(state.root), + stateRoot = shortLog(state.data.root), headBlockRoot = shortLog(state.blck.root), - stateSlot = humaneSlotNum(state.data.slot), - justifiedEpoch = humaneEpochNum(state.data.current_justified_epoch), - finalizedEpoch = humaneEpochNum(state.data.finalized_epoch) + stateSlot = 
humaneSlotNum(state.data.data.slot), + justifiedEpoch = humaneEpochNum(state.data.data.current_justified_epoch), + finalizedEpoch = humaneEpochNum(state.data.data.finalized_epoch) else: info "Updated head", - stateRoot = shortLog(state.root), + stateRoot = shortLog(state.data.root), headBlockRoot = shortLog(state.blck.root), - stateSlot = humaneSlotNum(state.data.slot), - justifiedEpoch = humaneEpochNum(state.data.current_justified_epoch), - finalizedEpoch = humaneEpochNum(state.data.finalized_epoch) + stateSlot = humaneSlotNum(state.data.data.slot), + justifiedEpoch = humaneEpochNum(state.data.data.current_justified_epoch), + finalizedEpoch = humaneEpochNum(state.data.data.finalized_epoch) let # TODO there might not be a block at the epoch boundary - what then? finalizedHead = - blck.findAncestorBySlot(state.data.finalized_epoch.get_epoch_start_slot()) + blck.findAncestorBySlot(state.data.data.finalized_epoch.get_epoch_start_slot()) doAssert (not finalizedHead.blck.isNil), "Block graph should always lead to a finalized block" diff --git a/beacon_chain/spec/datatypes.nim b/beacon_chain/spec/datatypes.nim index 285d1fa99..ba0b49543 100644 --- a/beacon_chain/spec/datatypes.nim +++ b/beacon_chain/spec/datatypes.nim @@ -502,6 +502,11 @@ type # TODO: not in spec CrosslinkCommittee* = tuple[committee: seq[ValidatorIndex], shard: uint64] + # TODO to be replaced with some magic hash caching + HashedBeaconState* = object + data*: BeaconState + root*: Eth2Digest # hash_tree_root (not signed_root!) 
+ func shortValidatorKey*(state: BeaconState, validatorIdx: int): string = ($state.validator_registry[validatorIdx].pubkey)[0..7] diff --git a/beacon_chain/state_transition.nim b/beacon_chain/state_transition.nim index 2cd5debf3..c01ea8eab 100644 --- a/beacon_chain/state_transition.nim +++ b/beacon_chain/state_transition.nim @@ -1212,6 +1212,64 @@ proc skipSlots*(state: var BeaconState, slot: Slot, if not afterSlot.isNil: afterSlot(state) +# TODO hashed versions of above - not in spec + +# https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#state-caching +func cacheState(state: var HashedBeaconState) = + let previous_slot_state_root = state.root + + # store the previous slot's post state transition root + state.data.latest_state_roots[state.data.slot mod SLOTS_PER_HISTORICAL_ROOT] = + previous_slot_state_root + + # cache state root in stored latest_block_header if empty + if state.data.latest_block_header.state_root == ZERO_HASH: + state.data.latest_block_header.state_root = previous_slot_state_root + + # store latest known block for previous slot + state.data.latest_block_roots[state.data.slot mod SLOTS_PER_HISTORICAL_ROOT] = + signed_root(state.data.latest_block_header) + +proc advanceState*(state: var HashedBeaconState) = + cacheState(state) + processEpoch(state.data) + advance_slot(state.data) + +proc updateState*( + state: var HashedBeaconState, blck: BeaconBlock, flags: UpdateFlags): bool = + var old_state = state + advanceState(state) + + if processBlock(state.data, blck, flags): + if skipValidation in flags or verifyStateRoot(state.data, blck): + # State root is what it should be - we're done! + + # TODO when creating a new block, state_root is not yet set.. 
comparing + # with zero hash here is a bit fragile however, but this whole thing + # should go away with proper hash caching + state.root = + if blck.state_root == Eth2Digest(): hash_tree_root(state.data) + else: blck.state_root + + return true + + # Block processing failed, roll back changes + state = old_state + false + +proc skipSlots*(state: var HashedBeaconState, slot: Slot, + afterSlot: proc (state: HashedBeaconState) = nil) = + if state.data.slot < slot: + debug "Advancing state with empty slots", + targetSlot = humaneSlotNum(slot), + stateSlot = humaneSlotNum(state.data.slot) + + while state.data.slot < slot: + advanceState(state) + + if not afterSlot.isNil: + afterSlot(state) + # TODO document this: # Jacek Sieka diff --git a/tests/test_attestation_pool.nim b/tests/test_attestation_pool.nim index 7f4c5057c..d0208a48f 100644 --- a/tests/test_attestation_pool.nim +++ b/tests/test_attestation_pool.nim @@ -33,14 +33,14 @@ suite "Attestation pool processing": let # Create an attestation for slot 1 signed by the only attester we have! crosslink_committees = - get_crosslink_committees_at_slot(state.data, state.data.slot) + get_crosslink_committees_at_slot(state.data.data, state.data.data.slot) attestation = makeAttestation( - state.data, state.blck.root, crosslink_committees[0].committee[0]) + state.data.data, state.blck.root, crosslink_committees[0].committee[0]) - pool.add(state.data, attestation) + pool.add(state.data.data, attestation) let attestations = pool.getAttestationsForBlock( - state.data, state.data.slot + MIN_ATTESTATION_INCLUSION_DELAY) + state.data.data, state.data.data.slot + MIN_ATTESTATION_INCLUSION_DELAY) # TODO test needs fixing for new attestation validation # check: @@ -57,24 +57,24 @@ suite "Attestation pool processing": let # Create an attestation for slot 1 signed by the only attester we have! 
crosslink_committees1 = - get_crosslink_committees_at_slot(state.data, state.data.slot) + get_crosslink_committees_at_slot(state.data.data, state.data.data.slot) attestation1 = makeAttestation( - state.data, state.blck.root, crosslink_committees1[0].committee[0]) + state.data.data, state.blck.root, crosslink_committees1[0].committee[0]) advanceState(state.data) let crosslink_committees2 = - get_crosslink_committees_at_slot(state.data, state.data.slot) + get_crosslink_committees_at_slot(state.data.data, state.data.data.slot) attestation2 = makeAttestation( - state.data, state.blck.root, crosslink_committees2[0].committee[0]) + state.data.data, state.blck.root, crosslink_committees2[0].committee[0]) # test reverse order - pool.add(state.data, attestation2) - pool.add(state.data, attestation1) + pool.add(state.data.data, attestation2) + pool.add(state.data.data, attestation1) let attestations = pool.getAttestationsForBlock( - state.data, state.data.slot + MIN_ATTESTATION_INCLUSION_DELAY) + state.data.data, state.data.data.slot + MIN_ATTESTATION_INCLUSION_DELAY) # TODO test needs fixing for new attestation validation # check: diff --git a/tests/test_block_pool.nim b/tests/test_block_pool.nim index 6fcb473ce..d8391bbe3 100644 --- a/tests/test_block_pool.nim +++ b/tests/test_block_pool.nim @@ -25,7 +25,7 @@ suite "Block pool processing": b0 = pool.get(state.blck.root) check: - state.data.slot == GENESIS_SLOT + state.data.data.slot == GENESIS_SLOT b0.isSome() toSeq(pool.blockRootsForSlot(GENESIS_SLOT)) == @[state.blck.root] @@ -35,7 +35,7 @@ suite "Block pool processing": state = pool.loadTailState() let - b1 = makeBlock(state.data, state.blck.root, BeaconBlockBody()) + b1 = makeBlock(state.data.data, state.blck.root, BeaconBlockBody()) b1Root = signed_root(b1) # TODO the return value is ugly here, need to fix and test.. 
@@ -46,7 +46,7 @@ suite "Block pool processing": check: b1Ref.isSome() b1Ref.get().refs.root == b1Root - hash_tree_root(state.data) == state.root + hash_tree_root(state.data.data) == state.data.root test "Reverse order block add & get": var @@ -55,9 +55,9 @@ suite "Block pool processing": state = pool.loadTailState() let - b1 = addBlock(state.data, state.blck.root, BeaconBlockBody(), {}) + b1 = addBlock(state.data.data, state.blck.root, BeaconBlockBody(), {}) b1Root = signed_root(b1) - b2 = addBlock(state.data, b1Root, BeaconBlockBody(), {}) + b2 = addBlock(state.data.data, b1Root, BeaconBlockBody(), {}) b2Root = signed_root(b2) discard pool.add(state, b2Root, b2) @@ -68,7 +68,7 @@ suite "Block pool processing": discard pool.add(state, b1Root, b1) - check: hash_tree_root(state.data) == state.root + check: hash_tree_root(state.data.data) == state.data.root let b1r = pool.get(b1Root) @@ -90,6 +90,6 @@ suite "Block pool processing": pool2 = BlockPool.init(db) check: - hash_tree_root(state.data) == state.root + hash_tree_root(state.data.data) == state.data.root pool2.get(b1Root).isSome() pool2.get(b2Root).isSome()