Fluffy state portal rpc validation (#2719)

* Validate content key in portal_stateLocalContent.

* Add additional validation to stateStore rpc method.
This commit is contained in:
bhartnett 2024-10-09 20:23:46 +08:00 committed by GitHub
parent 5edb0b320f
commit b13f06fcfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 36 additions and 34 deletions

View File

@ -45,7 +45,6 @@ proc validateTrieProof*(
if proof.len() == 0: if proof.len() == 0:
return err("proof is empty") return err("proof is empty")
# TODO: Remove this once the hive tests support passing in state roots from the history network
if expectedRootHash.isSome(): if expectedRootHash.isSome():
if not proof[0].hashEquals(expectedRootHash.get()): if not proof[0].hashEquals(expectedRootHash.get()):
return err("hash of proof root node doesn't match the expected root hash") return err("hash of proof root node doesn't match the expected root hash")

View File

@ -1590,12 +1590,15 @@ proc storeContent*(
contentKey: ContentKeyByteList, contentKey: ContentKeyByteList,
contentId: ContentId, contentId: ContentId,
content: seq[byte], content: seq[byte],
) = ): bool {.discardable.} =
# Always re-check that the key is still in the node range to make sure only # Always re-check that the key is still in the node range to make sure only
# content in range is stored. # content in range is stored.
if p.inRange(contentId): if p.inRange(contentId):
doAssert(p.dbPut != nil) doAssert(p.dbPut != nil)
p.dbPut(contentKey, contentId, content) p.dbPut(contentKey, contentId, content)
true
else:
false
proc seedTable*(p: PortalProtocol) = proc seedTable*(p: PortalProtocol) =
## Seed the table with specifically provided Portal bootstrap nodes. These are ## Seed the table with specifically provided Portal bootstrap nodes. These are

View File

@ -115,13 +115,10 @@ proc installPortalBeaconApiHandlers*(rpcServer: RpcServer, p: PortalProtocol) =
let let
key = ContentKeyByteList.init(hexToSeqByte(contentKey)) key = ContentKeyByteList.init(hexToSeqByte(contentKey))
contentValueBytes = hexToSeqByte(contentValue) contentValueBytes = hexToSeqByte(contentValue)
contentId = p.toContentId(key) contentId = p.toContentId(key).valueOr:
raise invalidKeyErr()
if contentId.isSome(): p.storeContent(key, contentId, contentValueBytes)
p.storeContent(key, contentId.get(), contentValueBytes)
return true
else:
raise invalidKeyErr()
rpcServer.rpc("portal_beaconLocalContent") do(contentKey: string) -> string: rpcServer.rpc("portal_beaconLocalContent") do(contentKey: string) -> string:
let let

View File

@ -115,13 +115,10 @@ proc installPortalHistoryApiHandlers*(rpcServer: RpcServer, p: PortalProtocol) =
let let
key = ContentKeyByteList.init(hexToSeqByte(contentKey)) key = ContentKeyByteList.init(hexToSeqByte(contentKey))
contentValueBytes = hexToSeqByte(contentValue) contentValueBytes = hexToSeqByte(contentValue)
contentId = p.toContentId(key) contentId = p.toContentId(key).valueOr:
raise invalidKeyErr()
if contentId.isSome(): p.storeContent(key, contentId, contentValueBytes)
p.storeContent(key, contentId.get(), contentValueBytes)
return true
else:
raise invalidKeyErr()
rpcServer.rpc("portal_historyLocalContent") do(contentKey: string) -> string: rpcServer.rpc("portal_historyLocalContent") do(contentKey: string) -> string:
let let

View File

@ -13,7 +13,7 @@ import
json_serialization/std/tables, json_serialization/std/tables,
stew/byteutils, stew/byteutils,
../network/wire/portal_protocol, ../network/wire/portal_protocol,
../network/state/state_content, ../network/state/[state_content, state_validation],
./rpc_types ./rpc_types
{.warning[UnusedImport]: off.} {.warning[UnusedImport]: off.}
@ -114,41 +114,47 @@ proc installPortalStateApiHandlers*(rpcServer: RpcServer, p: PortalProtocol) =
contentKey: string, contentValue: string contentKey: string, contentValue: string
) -> bool: ) -> bool:
let let
key = ContentKeyByteList.init(hexToSeqByte(contentKey)) keyBytes = ContentKeyByteList.init(hexToSeqByte(contentKey))
contentValueBytes = hexToSeqByte(contentValue) key = ContentKey.decode(keyBytes).valueOr:
decodedKey = ContentKey.decode(key).valueOr:
raise invalidKeyErr() raise invalidKeyErr()
contentId = p.toContentId(keyBytes).valueOr:
raise invalidKeyErr()
contentBytes = hexToSeqByte(contentValue)
valueToStore = valueToStore =
case decodedKey.contentType case key.contentType
of unused: of unused:
raise invalidKeyErr() raise invalidKeyErr()
of accountTrieNode: of accountTrieNode:
let offerValue = AccountTrieNodeOffer.decode(contentValueBytes).valueOr: let offer = AccountTrieNodeOffer.decode(contentBytes).valueOr:
raise invalidValueErr raise invalidValueErr
offerValue.toRetrievalValue.encode() validateOffer(Opt.none(Hash32), key.accountTrieNodeKey, offer).isOkOr:
raise invalidValueErr
offer.toRetrievalValue.encode()
of contractTrieNode: of contractTrieNode:
let offerValue = ContractTrieNodeOffer.decode(contentValueBytes).valueOr: let offer = ContractTrieNodeOffer.decode(contentBytes).valueOr:
raise invalidValueErr raise invalidValueErr
offerValue.toRetrievalValue.encode() validateOffer(Opt.none(Hash32), key.contractTrieNodeKey, offer).isOkOr:
raise invalidValueErr
offer.toRetrievalValue.encode()
of contractCode: of contractCode:
let offerValue = ContractCodeOffer.decode(contentValueBytes).valueOr: let offer = ContractCodeOffer.decode(contentBytes).valueOr:
raise invalidValueErr raise invalidValueErr
offerValue.toRetrievalValue.encode() validateOffer(Opt.none(Hash32), key.contractCodeKey, offer).isOkOr:
raise invalidValueErr
offer.toRetrievalValue.encode()
let contentId = p.toContentId(key) p.storeContent(keyBytes, contentId, valueToStore)
if contentId.isSome():
p.storeContent(key, contentId.get(), valueToStore)
return true
else:
raise invalidKeyErr()
rpcServer.rpc("portal_stateLocalContent") do(contentKey: string) -> string: rpcServer.rpc("portal_stateLocalContent") do(contentKey: string) -> string:
let let
key = ContentKeyByteList.init(hexToSeqByte(contentKey)) keyBytes = ContentKeyByteList.init(hexToSeqByte(contentKey))
contentId = p.toContentId(key).valueOr: key = ContentKey.decode(keyBytes).valueOr:
raise invalidKeyErr()
contentId = p.toContentId(keyBytes).valueOr:
raise invalidKeyErr() raise invalidKeyErr()
contentResult = p.dbGet(key, contentId).valueOr: contentResult = p.dbGet(keyBytes, contentId).valueOr:
raise contentNotFoundErr() raise contentNotFoundErr()
return contentResult.to0xHex() return contentResult.to0xHex()