diff --git a/src/db/blocks.py b/src/db/blocks.py
index adadf8d..7ab98e9 100644
--- a/src/db/blocks.py
+++ b/src/db/blocks.py
@@ -16,9 +16,42 @@ from models.block import Block
 
 logger = logging.getLogger(__name__)
 
 
+def chain_block_ids_cte(*, fork: int):
+    """
+    Recursive CTE that collects all block IDs on the chain from the tip
+    of the given fork back to genesis, following parent_block links.
+
+    This correctly traverses across fork boundaries — e.g. if fork 1 diverged
+    from fork 0 at height 50, the CTE returns fork 1 blocks (50+) AND the
+    ancestor fork 0 blocks (0–49).
+    """
+    tip_hash = (
+        select(Block.hash)
+        .where(Block.fork == fork)
+        .order_by(Block.height.desc())
+        .limit(1)
+        .scalar_subquery()
+    )
+
+    base = select(Block.id, Block.hash, Block.parent_block).where(
+        Block.hash == tip_hash
+    )
+    cte = base.cte(name="chain", recursive=True)
+
+    recursive = select(Block.id, Block.hash, Block.parent_block).where(
+        Block.hash == cte.c.parent_block
+    )
+    return cte.union_all(recursive)
+
+
 def get_latest_statement(limit: int, *, fork: int, output_ascending: bool = True) -> Select:
-    # Fetch the latest N blocks in descending height order
-    base = select(Block).where(Block.fork == fork).order_by(Block.height.desc()).limit(limit)
+    chain = chain_block_ids_cte(fork=fork)
+    base = (
+        select(Block)
+        .join(chain, Block.id == chain.c.id)
+        .order_by(Block.height.desc())
+        .limit(limit)
+    )
 
     if not output_ascending:
         return base
@@ -249,19 +282,19 @@ class BlockRepository:
     async def get_paginated(self, page: int, page_size: int, *, fork: int) -> tuple[List[Block], int]:
         """
         Get blocks with pagination, ordered by height descending (newest first).
+        Follows the chain from the fork's tip back to genesis across fork boundaries.
 
         Returns a tuple of (blocks, total_count).
         """
         offset = page * page_size
+        chain = chain_block_ids_cte(fork=fork)
 
         with self.client.session() as session:
-            # Get total count for this fork
-            count_statement = select(sa_func.count()).select_from(Block).where(Block.fork == fork)
+            count_statement = select(sa_func.count()).select_from(chain)
             total_count = session.exec(count_statement).one()
 
-            # Get paginated blocks
             statement = (
                 select(Block)
-                .where(Block.fork == fork)
+                .join(chain, Block.id == chain.c.id)
                 .order_by(Block.height.desc())
                 .offset(offset)
                 .limit(page_size)
diff --git a/src/db/transaction.py b/src/db/transaction.py
index b01b001..9aee8f7 100644
--- a/src/db/transaction.py
+++ b/src/db/transaction.py
@@ -6,17 +6,18 @@ from sqlalchemy import Result, Select, func as sa_func
 from sqlalchemy.orm import aliased, selectinload
 from sqlmodel import select
 
+from db.blocks import chain_block_ids_cte
 from db.clients import DbClient
 from models.block import Block
 from models.transactions.transaction import Transaction
 
 
 def get_latest_statement(limit: int, *, fork: int, output_ascending: bool, preload_relationships: bool) -> Select:
-    # Join with Block to order by Block's height and fetch the latest N transactions in descending order
+    chain = chain_block_ids_cte(fork=fork)
     base = (
         select(Transaction, Block.height.label("block__height"))
         .join(Block, Transaction.block_id == Block.id)
-        .where(Block.fork == fork)
+        .join(chain, Block.id == chain.c.id)
         .order_by(Block.height.desc(), Transaction.id.desc())
         .limit(limit)
     )
@@ -52,10 +53,12 @@ class TransactionRepository:
             return Empty()
 
     async def get_by_hash(self, transaction_hash: bytes, *, fork: int) -> Option[Transaction]:
+        chain = chain_block_ids_cte(fork=fork)
         statement = (
             select(Transaction)
             .join(Block, Transaction.block_id == Block.id)
-            .where(Transaction.hash == transaction_hash, Block.fork == fork)
+            .join(chain, Block.id == chain.c.id)
+            .where(Transaction.hash == transaction_hash)
         )
 
         with self.client.session() as session:
@@ -80,26 +83,26 @@ class TransactionRepository:
     async def get_paginated(self, page: int, page_size: int, *, fork: int) -> tuple[List[Transaction], int]:
         """
         Get transactions with pagination, ordered by block height descending (newest first).
+        Follows the chain from the fork's tip back to genesis across fork boundaries.
 
         Returns a tuple of (transactions, total_count).
         """
         offset = page * page_size
+        chain = chain_block_ids_cte(fork=fork)
 
         with self.client.session() as session:
-            # Get total count for this fork
             count_statement = (
                 select(sa_func.count())
                 .select_from(Transaction)
                 .join(Block, Transaction.block_id == Block.id)
-                .where(Block.fork == fork)
+                .join(chain, Block.id == chain.c.id)
             )
             total_count = session.exec(count_statement).one()
 
-            # Get paginated transactions
             statement = (
                 select(Transaction)
                 .options(selectinload(Transaction.block))
                 .join(Block, Transaction.block_id == Block.id)
-                .where(Block.fork == fork)
+                .join(chain, Block.id == chain.c.id)
                 .order_by(Block.height.desc(), Transaction.id.desc())
                 .offset(offset)
                 .limit(page_size)
diff --git a/tests/test_block_forks.py b/tests/test_block_forks.py
index 2da05bd..65bdd10 100644
--- a/tests/test_block_forks.py
+++ b/tests/test_block_forks.py
@@ -327,8 +327,17 @@ def test_fork_choice_switches_on_overtake(client, repo):
 
 # --- Fork-filtered query tests ---
 
-def test_get_latest_filters_by_fork(client, repo):
-    """get_latest with fork filter only returns blocks from that fork."""
+def test_get_latest_follows_chain(client, repo):
+    """
+    get_latest follows the chain from the fork's tip back to genesis,
+    crossing fork boundaries.
+
+    genesis (fork 0) -> A (fork 0)
+                   \\-> B (fork 1)
+
+    Fork 0 chain: genesis, A
+    Fork 1 chain: genesis, B (genesis is a shared ancestor)
+    """
     genesis = make_block(b"\x01", parent=b"\x00", slot=0)
     asyncio.run(repo.create(genesis))
@@ -338,24 +347,29 @@ def test_get_latest_filters_by_fork(client, repo):
     a = make_block(b"\x02", parent=b"\x01", slot=1)
     asyncio.run(repo.create(a))
 
     b = make_block(b"\x03", parent=b"\x01", slot=1)
     asyncio.run(repo.create(b))
 
-    # Fork 0: genesis, A. Fork 1: B (but B also shares genesis at fork 0... no, genesis is fork 0)
-    # Actually: genesis=fork0, A=fork0, B=fork1
     fork0_blocks = asyncio.run(repo.get_latest(10, fork=0))
     fork1_blocks = asyncio.run(repo.get_latest(10, fork=1))
 
     fork0_hashes = {b.hash for b in fork0_blocks}
     fork1_hashes = {b.hash for b in fork1_blocks}
 
-    assert b"\x01" in fork0_hashes  # genesis
-    assert b"\x02" in fork0_hashes  # A
-    assert b"\x03" not in fork0_hashes
+    # Fork 0 chain: genesis + A
+    assert fork0_hashes == {b"\x01", b"\x02"}
 
-    assert b"\x03" in fork1_hashes  # B
-    assert b"\x02" not in fork1_hashes
+    # Fork 1 chain: genesis + B (crosses fork boundary to include genesis)
+    assert fork1_hashes == {b"\x01", b"\x03"}
 
 
-def test_get_paginated_filters_by_fork(client, repo):
-    """get_paginated with fork filter only returns blocks from that fork."""
+def test_get_paginated_follows_chain(client, repo):
+    """
+    get_paginated follows the chain from the fork's tip back to genesis.
+
+    genesis (fork 0) -> A (fork 0) -> C (fork 0)
+                   \\-> B (fork 1)
+
+    Fork 0 chain: genesis, A, C (count=3)
+    Fork 1 chain: genesis, B (count=2, crosses fork boundary)
+    """
     genesis = make_block(b"\x01", parent=b"\x00", slot=0)
     asyncio.run(repo.create(genesis))
@@ -365,10 +379,13 @@ def test_get_paginated_filters_by_fork(client, repo):
     a = make_block(b"\x02", parent=b"\x01", slot=1)
     asyncio.run(repo.create(a))
 
     b = make_block(b"\x03", parent=b"\x01", slot=1)
     asyncio.run(repo.create(b))
 
+    c = make_block(b"\x04", parent=b"\x02", slot=2)
+    asyncio.run(repo.create(c))
+
     blocks_f0, count_f0 = asyncio.run(repo.get_paginated(0, 10, fork=0))
     blocks_f1, count_f1 = asyncio.run(repo.get_paginated(0, 10, fork=1))
 
-    assert count_f0 == 2  # genesis + A
-    assert count_f1 == 1  # B only
-    assert {b.hash for b in blocks_f0} == {b"\x01", b"\x02"}
-    assert {b.hash for b in blocks_f1} == {b"\x03"}
+    assert count_f0 == 3  # genesis + A + C
+    assert count_f1 == 2  # genesis + B (crosses fork boundary)
+    assert {b.hash for b in blocks_f0} == {b"\x01", b"\x02", b"\x04"}
+    assert {b.hash for b in blocks_f1} == {b"\x01", b"\x03"}
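For reviewers who want to poke at the traversal in isolation, here is a minimal, self-contained sketch of the same recursive-CTE chain walk against an in-memory SQLite database. The stripped-down `Block` model and the `chain_cte` helper below are illustrative assumptions that mirror `chain_block_ids_cte` above; they are not the project's actual `models.block.Block` or repository code.

```python
# Illustrative sketch only: a simplified Block schema, not the project's real model.
from typing import Optional

from sqlmodel import Field, Session, SQLModel, create_engine, select


class Block(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    hash: bytes
    parent_block: bytes
    height: int
    fork: int


def chain_cte(*, fork: int):
    # Anchor: the highest block recorded for this fork (the fork's tip).
    tip_hash = (
        select(Block.hash)
        .where(Block.fork == fork)
        .order_by(Block.height.desc())
        .limit(1)
        .scalar_subquery()
    )
    base = select(Block.id, Block.hash, Block.parent_block).where(Block.hash == tip_hash)
    cte = base.cte(name="chain", recursive=True)
    # Recursive step: follow parent_block pointers until no parent row exists,
    # regardless of which fork the ancestor rows are tagged with.
    step = select(Block.id, Block.hash, Block.parent_block).where(Block.hash == cte.c.parent_block)
    return cte.union_all(step)


engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            Block(hash=b"\x01", parent_block=b"\x00", height=0, fork=0),  # genesis
            Block(hash=b"\x02", parent_block=b"\x01", height=1, fork=0),  # A
            Block(hash=b"\x03", parent_block=b"\x01", height=1, fork=1),  # B
        ]
    )
    session.commit()

    chain = chain_cte(fork=1)
    blocks = session.exec(
        select(Block).join(chain, Block.id == chain.c.id).order_by(Block.height.desc())
    ).all()
    # Fork 1's chain crosses the fork boundary: B, then the shared genesis.
    print([blk.hash for blk in blocks])  # [b'\x03', b'\x01']
```

The recursive step matches on `parent_block` alone, so once the walk leaves the fork's tip it ignores the `fork` column entirely; that is what lets fork 1's chain pick up the shared genesis block in the tests above.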
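If it helps during review, the generated SQL (a `WITH RECURSIVE chain AS (...)` statement) can be printed directly from the statement builders in this PR. The dialect below is only an example; substitute whatever dialect the project actually runs against.

```python
# Hypothetical inspection snippet: render the SQL emitted by the CTE-backed
# block query from src/db/blocks.py. The postgresql dialect is an assumption.
from sqlalchemy.dialects import postgresql

from db.blocks import get_latest_statement

print(get_latest_statement(10, fork=0).compile(dialect=postgresql.dialect()))
```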