follow chain through forks

David Rusu 2026-02-17 10:10:47 +04:00
parent 15da0cc2aa
commit 8e5ec0647c
3 changed files with 81 additions and 28 deletions


@@ -16,9 +16,42 @@ from models.block import Block
logger = logging.getLogger(__name__)
def chain_block_ids_cte(*, fork: int):
"""
Recursive CTE that collects all block IDs on the chain from the tip
of the given fork back to genesis, following parent_block links.
This correctly traverses across fork boundaries, e.g. if fork 1 diverged
from fork 0 at height 50, the CTE returns the fork 1 blocks (heights 50+) AND
the ancestor fork 0 blocks (heights 0–49).
"""
tip_hash = (
select(Block.hash)
.where(Block.fork == fork)
.order_by(Block.height.desc())
.limit(1)
.scalar_subquery()
)
base = select(Block.id, Block.hash, Block.parent_block).where(
Block.hash == tip_hash
)
cte = base.cte(name="chain", recursive=True)
recursive = select(Block.id, Block.hash, Block.parent_block).where(
Block.hash == cte.c.parent_block
)
return cte.union_all(recursive)
def get_latest_statement(limit: int, *, fork: int, output_ascending: bool = True) -> Select:
# Fetch the latest N blocks in descending height order
base = select(Block).where(Block.fork == fork).order_by(Block.height.desc()).limit(limit)
chain = chain_block_ids_cte(fork=fork)
base = (
select(Block)
.join(chain, Block.id == chain.c.id)
.order_by(Block.height.desc())
.limit(limit)
)
if not output_ascending:
return base
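For intuition, the recursive CTE above is the SQL equivalent of the following in-memory walk (illustration only, not part of this commit; blocks_by_hash is a hypothetical dict keyed by block hash):

# Illustration only: an in-memory equivalent of the walk the recursive CTE
# performs in SQL.
def chain_hashes(tip_hash: bytes, blocks_by_hash: dict[bytes, Block]) -> list[bytes]:
    chain = []
    current = blocks_by_hash.get(tip_hash)
    while current is not None:
        chain.append(current.hash)
        # Follow parent_block regardless of which fork the parent was created
        # on, so the walk crosses fork boundaries on its way back to genesis.
        current = blocks_by_hash.get(current.parent_block)
    return chain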
@@ -249,19 +282,19 @@ class BlockRepository:
async def get_paginated(self, page: int, page_size: int, *, fork: int) -> tuple[List[Block], int]:
"""
Get blocks with pagination, ordered by height descending (newest first).
Follows the chain from the fork's tip back to genesis across fork boundaries.
Returns a tuple of (blocks, total_count).
"""
offset = page * page_size
chain = chain_block_ids_cte(fork=fork)
with self.client.session() as session:
# Get total count for this fork
count_statement = select(sa_func.count()).select_from(Block).where(Block.fork == fork)
count_statement = select(sa_func.count()).select_from(chain)
total_count = session.exec(count_statement).one()
# Get paginated blocks
statement = (
select(Block)
.where(Block.fork == fork)
.join(chain, Block.id == chain.c.id)
.order_by(Block.height.desc())
.offset(offset)
.limit(page_size)
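The CTE composes with other block queries in the same way; a hedged sketch (not in this commit) of looking up the block at a given height on a fork's chain, reusing the imports already present in this module:

# Sketch only, not part of this commit: fetch the block at `height` on the
# fork's chain, including ancestor blocks created on an earlier fork.
def get_by_height_statement(height: int, *, fork: int) -> Select:
    chain = chain_block_ids_cte(fork=fork)
    return (
        select(Block)
        .join(chain, Block.id == chain.c.id)
        .where(Block.height == height)
    )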


@@ -6,17 +6,18 @@ from sqlalchemy import Result, Select, func as sa_func
from sqlalchemy.orm import aliased, selectinload
from sqlmodel import select
from db.blocks import chain_block_ids_cte
from db.clients import DbClient
from models.block import Block
from models.transactions.transaction import Transaction
def get_latest_statement(limit: int, *, fork: int, output_ascending: bool, preload_relationships: bool) -> Select:
# Join with Block to order by Block's height and fetch the latest N transactions in descending order
chain = chain_block_ids_cte(fork=fork)
base = (
select(Transaction, Block.height.label("block__height"))
.join(Block, Transaction.block_id == Block.id)
.where(Block.fork == fork)
.join(chain, Block.id == chain.c.id)
.order_by(Block.height.desc(), Transaction.id.desc())
.limit(limit)
)
@@ -52,10 +53,12 @@ class TransactionRepository:
return Empty()
async def get_by_hash(self, transaction_hash: bytes, *, fork: int) -> Option[Transaction]:
chain = chain_block_ids_cte(fork=fork)
statement = (
select(Transaction)
.join(Block, Transaction.block_id == Block.id)
.where(Transaction.hash == transaction_hash, Block.fork == fork)
.join(chain, Block.id == chain.c.id)
.where(Transaction.hash == transaction_hash)
)
with self.client.session() as session:
@@ -80,26 +83,26 @@ class TransactionRepository:
async def get_paginated(self, page: int, page_size: int, *, fork: int) -> tuple[List[Transaction], int]:
"""
Get transactions with pagination, ordered by block height descending (newest first).
Follows the chain from the fork's tip back to genesis across fork boundaries.
Returns a tuple of (transactions, total_count).
"""
offset = page * page_size
chain = chain_block_ids_cte(fork=fork)
with self.client.session() as session:
# Get total count for this fork
count_statement = (
select(sa_func.count())
.select_from(Transaction)
.join(Block, Transaction.block_id == Block.id)
.where(Block.fork == fork)
.join(chain, Block.id == chain.c.id)
)
total_count = session.exec(count_statement).one()
# Get paginated transactions
statement = (
select(Transaction)
.options(selectinload(Transaction.block))
.join(Block, Transaction.block_id == Block.id)
.where(Block.fork == fork)
.join(chain, Block.id == chain.c.id)
.order_by(Block.height.desc(), Transaction.id.desc())
.offset(offset)
.limit(page_size)
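A hypothetical caller sketch (not part of this commit) of how the paginated API is typically consumed; totals now include transactions in ancestor blocks shared with earlier forks:

# Hypothetical usage sketch: drain a fork's transactions newest-first.
async def dump_fork_transactions(repo: TransactionRepository, *, fork: int) -> list[Transaction]:
    page, page_size = 0, 100
    collected: list[Transaction] = []
    while True:
        txs, total = await repo.get_paginated(page, page_size, fork=fork)
        collected.extend(txs)
        page += 1
        if not txs or page * page_size >= total:
            break
    return collected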


@@ -327,8 +327,17 @@ def test_fork_choice_switches_on_overtake(client, repo):
# --- Fork-filtered query tests ---
def test_get_latest_filters_by_fork(client, repo):
"""get_latest with fork filter only returns blocks from that fork."""
def test_get_latest_follows_chain(client, repo):
"""
get_latest follows the chain from the fork's tip back to genesis,
crossing fork boundaries.
genesis (fork 0) -> A (fork 0)
\\-> B (fork 1)
Fork 0 chain: genesis, A
Fork 1 chain: genesis, B (genesis is a shared ancestor)
"""
genesis = make_block(b"\x01", parent=b"\x00", slot=0)
asyncio.run(repo.create(genesis))
@@ -338,24 +347,29 @@ def test_get_latest_filters_by_fork(client, repo):
b = make_block(b"\x03", parent=b"\x01", slot=1)
asyncio.run(repo.create(b))
# Fork 0: genesis, A. Fork 1: B (but B also shares genesis at fork 0... no, genesis is fork 0)
# Actually: genesis=fork0, A=fork0, B=fork1
fork0_blocks = asyncio.run(repo.get_latest(10, fork=0))
fork1_blocks = asyncio.run(repo.get_latest(10, fork=1))
fork0_hashes = {b.hash for b in fork0_blocks}
fork1_hashes = {b.hash for b in fork1_blocks}
assert b"\x01" in fork0_hashes # genesis
assert b"\x02" in fork0_hashes # A
assert b"\x03" not in fork0_hashes
# Fork 0 chain: genesis + A
assert fork0_hashes == {b"\x01", b"\x02"}
assert b"\x03" in fork1_hashes # B
assert b"\x02" not in fork1_hashes
# Fork 1 chain: genesis + B (crosses fork boundary to include genesis)
assert fork1_hashes == {b"\x01", b"\x03"}
def test_get_paginated_filters_by_fork(client, repo):
"""get_paginated with fork filter only returns blocks from that fork."""
def test_get_paginated_follows_chain(client, repo):
"""
get_paginated follows the chain from the fork's tip back to genesis.
genesis (fork 0) -> A (fork 0) -> C (fork 0)
\\-> B (fork 1)
Fork 0 chain: genesis, A, C (count=3)
Fork 1 chain: genesis, B (count=2, crosses fork boundary)
"""
genesis = make_block(b"\x01", parent=b"\x00", slot=0)
asyncio.run(repo.create(genesis))
@@ -365,10 +379,13 @@ def test_get_paginated_filters_by_fork(client, repo):
b = make_block(b"\x03", parent=b"\x01", slot=1)
asyncio.run(repo.create(b))
c = make_block(b"\x04", parent=b"\x02", slot=2)
asyncio.run(repo.create(c))
blocks_f0, count_f0 = asyncio.run(repo.get_paginated(0, 10, fork=0))
blocks_f1, count_f1 = asyncio.run(repo.get_paginated(0, 10, fork=1))
assert count_f0 == 2 # genesis + A
assert count_f1 == 1 # B only
assert {b.hash for b in blocks_f0} == {b"\x01", b"\x02"}
assert {b.hash for b in blocks_f1} == {b"\x03"}
assert count_f0 == 3 # genesis + A + C
assert count_f1 == 2 # genesis + B (crosses fork boundary)
assert {b.hash for b in blocks_f0} == {b"\x01", b"\x02", b"\x04"}
assert {b.hash for b in blocks_f1} == {b"\x01", b"\x03"}
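The tests lean on a make_block helper defined earlier in the test module; a minimal sketch of the shape it presumably has (the field names and the height-from-slot mapping are assumptions, and fork assignment appears to happen inside the repository's fork-choice logic rather than in the helper):

# Assumed shape of the existing test helper; the real make_block lives
# elsewhere in the test module and may set additional fields.
def make_block(block_hash: bytes, *, parent: bytes, slot: int) -> Block:
    return Block(hash=block_hash, parent_block=parent, slot=slot, height=slot)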