From b13f06fcfbf01de73632dd1ca2ad35f7e481b481 Mon Sep 17 00:00:00 2001 From: bhartnett <51288821+bhartnett@users.noreply.github.com> Date: Wed, 9 Oct 2024 20:23:46 +0800 Subject: [PATCH] Fluffy state portal rpc validation (#2719) * Validate content key in portal_stateLocalContent. * Add additional validation to stateStore rpc method. --- fluffy/network/state/state_validation.nim | 1 - fluffy/network/wire/portal_protocol.nim | 5 ++- fluffy/rpc/rpc_portal_beacon_api.nim | 9 ++--- fluffy/rpc/rpc_portal_history_api.nim | 9 ++--- fluffy/rpc/rpc_portal_state_api.nim | 46 +++++++++++++---------- 5 files changed, 36 insertions(+), 34 deletions(-) diff --git a/fluffy/network/state/state_validation.nim b/fluffy/network/state/state_validation.nim index ce61cc35d..8779077cc 100644 --- a/fluffy/network/state/state_validation.nim +++ b/fluffy/network/state/state_validation.nim @@ -45,7 +45,6 @@ proc validateTrieProof*( if proof.len() == 0: return err("proof is empty") - # TODO: Remove this once the hive tests support passing in state roots from the history network if expectedRootHash.isSome(): if not proof[0].hashEquals(expectedRootHash.get()): return err("hash of proof root node doesn't match the expected root hash") diff --git a/fluffy/network/wire/portal_protocol.nim b/fluffy/network/wire/portal_protocol.nim index be3e1b481..bf4af975e 100644 --- a/fluffy/network/wire/portal_protocol.nim +++ b/fluffy/network/wire/portal_protocol.nim @@ -1590,12 +1590,15 @@ proc storeContent*( contentKey: ContentKeyByteList, contentId: ContentId, content: seq[byte], -) = +): bool {.discardable.} = # Always re-check that the key is still in the node range to make sure only # content in range is stored. if p.inRange(contentId): doAssert(p.dbPut != nil) p.dbPut(contentKey, contentId, content) + true + else: + false proc seedTable*(p: PortalProtocol) = ## Seed the table with specifically provided Portal bootstrap nodes. These are diff --git a/fluffy/rpc/rpc_portal_beacon_api.nim b/fluffy/rpc/rpc_portal_beacon_api.nim index 32714f3e6..ddcf9c988 100644 --- a/fluffy/rpc/rpc_portal_beacon_api.nim +++ b/fluffy/rpc/rpc_portal_beacon_api.nim @@ -115,13 +115,10 @@ proc installPortalBeaconApiHandlers*(rpcServer: RpcServer, p: PortalProtocol) = let key = ContentKeyByteList.init(hexToSeqByte(contentKey)) contentValueBytes = hexToSeqByte(contentValue) - contentId = p.toContentId(key) + contentId = p.toContentId(key).valueOr: + raise invalidKeyErr() - if contentId.isSome(): - p.storeContent(key, contentId.get(), contentValueBytes) - return true - else: - raise invalidKeyErr() + p.storeContent(key, contentId, contentValueBytes) rpcServer.rpc("portal_beaconLocalContent") do(contentKey: string) -> string: let diff --git a/fluffy/rpc/rpc_portal_history_api.nim b/fluffy/rpc/rpc_portal_history_api.nim index bcfbbb2ca..6d057996a 100644 --- a/fluffy/rpc/rpc_portal_history_api.nim +++ b/fluffy/rpc/rpc_portal_history_api.nim @@ -115,13 +115,10 @@ proc installPortalHistoryApiHandlers*(rpcServer: RpcServer, p: PortalProtocol) = let key = ContentKeyByteList.init(hexToSeqByte(contentKey)) contentValueBytes = hexToSeqByte(contentValue) - contentId = p.toContentId(key) + contentId = p.toContentId(key).valueOr: + raise invalidKeyErr() - if contentId.isSome(): - p.storeContent(key, contentId.get(), contentValueBytes) - return true - else: - raise invalidKeyErr() + p.storeContent(key, contentId, contentValueBytes) rpcServer.rpc("portal_historyLocalContent") do(contentKey: string) -> string: let diff --git a/fluffy/rpc/rpc_portal_state_api.nim b/fluffy/rpc/rpc_portal_state_api.nim index 3a42ea376..15b351663 100644 --- a/fluffy/rpc/rpc_portal_state_api.nim +++ b/fluffy/rpc/rpc_portal_state_api.nim @@ -13,7 +13,7 @@ import json_serialization/std/tables, stew/byteutils, ../network/wire/portal_protocol, - ../network/state/state_content, + ../network/state/[state_content, state_validation], ./rpc_types {.warning[UnusedImport]: off.} @@ -114,41 +114,47 @@ proc installPortalStateApiHandlers*(rpcServer: RpcServer, p: PortalProtocol) = contentKey: string, contentValue: string ) -> bool: let - key = ContentKeyByteList.init(hexToSeqByte(contentKey)) - contentValueBytes = hexToSeqByte(contentValue) - decodedKey = ContentKey.decode(key).valueOr: + keyBytes = ContentKeyByteList.init(hexToSeqByte(contentKey)) + key = ContentKey.decode(keyBytes).valueOr: raise invalidKeyErr() + contentId = p.toContentId(keyBytes).valueOr: + raise invalidKeyErr() + + contentBytes = hexToSeqByte(contentValue) valueToStore = - case decodedKey.contentType + case key.contentType of unused: raise invalidKeyErr() of accountTrieNode: - let offerValue = AccountTrieNodeOffer.decode(contentValueBytes).valueOr: + let offer = AccountTrieNodeOffer.decode(contentBytes).valueOr: raise invalidValueErr - offerValue.toRetrievalValue.encode() + validateOffer(Opt.none(Hash32), key.accountTrieNodeKey, offer).isOkOr: + raise invalidValueErr + offer.toRetrievalValue.encode() of contractTrieNode: - let offerValue = ContractTrieNodeOffer.decode(contentValueBytes).valueOr: + let offer = ContractTrieNodeOffer.decode(contentBytes).valueOr: raise invalidValueErr - offerValue.toRetrievalValue.encode() + validateOffer(Opt.none(Hash32), key.contractTrieNodeKey, offer).isOkOr: + raise invalidValueErr + offer.toRetrievalValue.encode() of contractCode: - let offerValue = ContractCodeOffer.decode(contentValueBytes).valueOr: + let offer = ContractCodeOffer.decode(contentBytes).valueOr: raise invalidValueErr - offerValue.toRetrievalValue.encode() + validateOffer(Opt.none(Hash32), key.contractCodeKey, offer).isOkOr: + raise invalidValueErr + offer.toRetrievalValue.encode() - let contentId = p.toContentId(key) - if contentId.isSome(): - p.storeContent(key, contentId.get(), valueToStore) - return true - else: - raise invalidKeyErr() + p.storeContent(keyBytes, contentId, valueToStore) rpcServer.rpc("portal_stateLocalContent") do(contentKey: string) -> string: let - key = ContentKeyByteList.init(hexToSeqByte(contentKey)) - contentId = p.toContentId(key).valueOr: + keyBytes = ContentKeyByteList.init(hexToSeqByte(contentKey)) + key = ContentKey.decode(keyBytes).valueOr: + raise invalidKeyErr() + contentId = p.toContentId(keyBytes).valueOr: raise invalidKeyErr() - contentResult = p.dbGet(key, contentId).valueOr: + contentResult = p.dbGet(keyBytes, contentId).valueOr: raise contentNotFoundErr() return contentResult.to0xHex()