diff --git a/src/api/v1/blocks.py b/src/api/v1/blocks.py index 3576fc3..9138443 100644 --- a/src/api/v1/blocks.py +++ b/src/api/v1/blocks.py @@ -19,8 +19,9 @@ async def list_blocks( request: NBERequest, page: int = Query(0, ge=0), page_size: int = Query(10, ge=1, le=100, alias="page-size"), + fork: int = Query(...), ) -> Response: - blocks, total_count = await request.app.state.block_repository.get_paginated(page, page_size) + blocks, total_count = await request.app.state.block_repository.get_paginated(page, page_size, fork=fork) total_pages = (total_count + page_size - 1) // page_size # ceiling division return JSONResponse({ @@ -32,18 +33,24 @@ async def list_blocks( }) -async def _get_blocks_stream_serialized(app: "NBE", block_from: Option[Block]) -> AsyncIterator[List[BlockRead]]: - _stream = app.state.block_repository.updates_stream(block_from) +async def _get_blocks_stream_serialized( + app: "NBE", block_from: Option[Block], *, fork: int +) -> AsyncIterator[List[BlockRead]]: + _stream = app.state.block_repository.updates_stream(block_from, fork=fork) async for blocks in _stream: yield [BlockRead.from_block(block) for block in blocks] -async def stream(request: NBERequest, prefetch_limit: int = Query(0, alias="prefetch-limit", ge=0)) -> Response: - latest_blocks = await request.app.state.block_repository.get_latest(prefetch_limit) +async def stream( + request: NBERequest, + prefetch_limit: int = Query(0, alias="prefetch-limit", ge=0), + fork: int = Query(...), +) -> Response: + latest_blocks = await request.app.state.block_repository.get_latest(prefetch_limit, fork=fork) latest_block = Some(latest_blocks[-1]) if latest_blocks else Empty() bootstrap_blocks: List[BlockRead] = [BlockRead.from_block(block) for block in latest_blocks] - blocks_stream: AsyncIterator[List[BlockRead]] = _get_blocks_stream_serialized(request.app, latest_block) + blocks_stream: AsyncIterator[List[BlockRead]] = _get_blocks_stream_serialized(request.app, latest_block, fork=fork) ndjson_blocks_stream = into_ndjson_stream(blocks_stream, bootstrap_data=bootstrap_blocks) return NDJsonStreamingResponse(ndjson_blocks_stream) diff --git a/src/api/v1/fork_choice.py b/src/api/v1/fork_choice.py new file mode 100644 index 0000000..1bc770e --- /dev/null +++ b/src/api/v1/fork_choice.py @@ -0,0 +1,10 @@ +from http.client import NOT_FOUND + +from starlette.responses import JSONResponse, Response + +from core.api import NBERequest + + +async def get(request: NBERequest) -> Response: + fork = await request.app.state.block_repository.get_fork_choice() + return fork.map(lambda f: JSONResponse({"fork": f})).unwrap_or_else(lambda: Response(status_code=NOT_FOUND)) diff --git a/src/api/v1/router.py b/src/api/v1/router.py index 5e6e779..6e3d36f 100644 --- a/src/api/v1/router.py +++ b/src/api/v1/router.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from . import blocks, health, index, transactions +from . import blocks, fork_choice, health, index, transactions def create_v1_router() -> APIRouter: @@ -17,6 +17,8 @@ def create_v1_router() -> APIRouter: router.add_api_route("/transactions/stream", transactions.stream, methods=["GET"]) router.add_api_route("/transactions/{transaction_hash:str}", transactions.get, methods=["GET"]) + router.add_api_route("/fork-choice", fork_choice.get, methods=["GET"]) + router.add_api_route("/blocks/stream", blocks.stream, methods=["GET"]) router.add_api_route("/blocks/list", blocks.list_blocks, methods=["GET"]) router.add_api_route("/blocks/{block_hash:str}", blocks.get, methods=["GET"]) diff --git a/src/api/v1/serializers/blocks.py b/src/api/v1/serializers/blocks.py index 3c9cb48..3a01fa5 100644 --- a/src/api/v1/serializers/blocks.py +++ b/src/api/v1/serializers/blocks.py @@ -13,6 +13,7 @@ class BlockRead(NbeSchema): parent_block_hash: HexBytes slot: int height: int + fork: int block_root: HexBytes proof_of_leadership: ProofOfLeadership transactions: List[Transaction] @@ -25,6 +26,7 @@ class BlockRead(NbeSchema): parent_block_hash=block.parent_block, slot=block.slot, height=block.height, + fork=block.fork, block_root=block.block_root, proof_of_leadership=block.proof_of_leadership, transactions=block.transactions, diff --git a/src/api/v1/transactions.py b/src/api/v1/transactions.py index 1a00578..4cd4e67 100644 --- a/src/api/v1/transactions.py +++ b/src/api/v1/transactions.py @@ -16,22 +16,26 @@ if TYPE_CHECKING: async def _get_transactions_stream_serialized( - app: "NBE", transaction_from: Option[Transaction] + app: "NBE", transaction_from: Option[Transaction], *, fork: int ) -> AsyncIterator[List[TransactionRead]]: - _stream = app.state.transaction_repository.updates_stream(transaction_from) + _stream = app.state.transaction_repository.updates_stream(transaction_from, fork=fork) async for transactions in _stream: yield [TransactionRead.from_transaction(transaction) for transaction in transactions] -async def stream(request: NBERequest, prefetch_limit: int = Query(0, alias="prefetch-limit", ge=0)) -> Response: +async def stream( + request: NBERequest, + prefetch_limit: int = Query(0, alias="prefetch-limit", ge=0), + fork: int = Query(...), +) -> Response: latest_transactions: List[Transaction] = await request.app.state.transaction_repository.get_latest( - prefetch_limit, ascending=True, preload_relationships=True + prefetch_limit, fork=fork, ascending=True, preload_relationships=True ) latest_transaction = Some(latest_transactions[-1]) if latest_transactions else Empty() bootstrap_transactions = [TransactionRead.from_transaction(transaction) for transaction in latest_transactions] transactions_stream: AsyncIterator[List[TransactionRead]] = _get_transactions_stream_serialized( - request.app, latest_transaction + request.app, latest_transaction, fork=fork ) ndjson_transactions_stream = into_ndjson_stream(transactions_stream, bootstrap_data=bootstrap_transactions) return NDJsonStreamingResponse(ndjson_transactions_stream) diff --git a/src/db/blocks.py b/src/db/blocks.py index 9374d65..adadf8d 100644 --- a/src/db/blocks.py +++ b/src/db/blocks.py @@ -16,9 +16,9 @@ from models.block import Block logger = logging.getLogger(__name__) -def get_latest_statement(limit: int, *, output_ascending: bool = True) -> Select: +def get_latest_statement(limit: int, *, fork: int, output_ascending: bool = True) -> Select: # Fetch the latest N blocks in descending height order - base = select(Block).order_by(Block.height.desc()).limit(limit) + base = select(Block).where(Block.fork == fork).order_by(Block.height.desc()).limit(limit) if not output_ascending: return base @@ -215,17 +215,27 @@ class BlockRepository: else: return Empty() - async def get_latest(self, limit: int, *, ascending: bool = True) -> List[Block]: + async def get_latest(self, limit: int, *, fork: int, ascending: bool = True) -> List[Block]: if limit == 0: return [] - statement = get_latest_statement(limit, output_ascending=ascending) + statement = get_latest_statement(limit, fork=fork, output_ascending=ascending) with self.client.session() as session: results: Result[Block] = session.exec(statement) b = results.all() return b + async def get_fork_choice(self) -> Option[int]: + """Return the fork number of the longest chain (block with max height).""" + statement = select(Block.fork).order_by(Block.height.desc()).limit(1) + with self.client.session() as session: + result = session.exec(statement).one_or_none() + if result is not None: + return Some(result) + else: + return Empty() + async def get_earliest(self) -> Option[Block]: statement = select(Block).order_by(Block.height.asc()).limit(1) @@ -236,7 +246,7 @@ class BlockRepository: else: return Empty() - async def get_paginated(self, page: int, page_size: int) -> tuple[List[Block], int]: + async def get_paginated(self, page: int, page_size: int, *, fork: int) -> tuple[List[Block], int]: """ Get blocks with pagination, ordered by height descending (newest first). Returns a tuple of (blocks, total_count). @@ -244,14 +254,14 @@ class BlockRepository: offset = page * page_size with self.client.session() as session: - # Get total count - from sqlalchemy import func - count_statement = select(func.count()).select_from(Block) + # Get total count for this fork + count_statement = select(sa_func.count()).select_from(Block).where(Block.fork == fork) total_count = session.exec(count_statement).one() # Get paginated blocks statement = ( select(Block) + .where(Block.fork == fork) .order_by(Block.height.desc()) .offset(offset) .limit(page_size) @@ -261,14 +271,14 @@ class BlockRepository: return blocks, total_count async def updates_stream( - self, block_from: Option[Block], *, timeout_seconds: int = 1 + self, block_from: Option[Block], *, fork: int, timeout_seconds: int = 1 ) -> AsyncIterator[List[Block]]: height_cursor: int = block_from.map(lambda block: block.height + 1).unwrap_or(0) while True: statement = ( select(Block) - .where(Block.height >= height_cursor) + .where(Block.fork == fork, Block.height >= height_cursor) .order_by(Block.height.asc()) ) diff --git a/src/db/transaction.py b/src/db/transaction.py index 698bbe4..ca318c8 100644 --- a/src/db/transaction.py +++ b/src/db/transaction.py @@ -11,11 +11,12 @@ from models.block import Block from models.transactions.transaction import Transaction -def get_latest_statement(limit: int, *, output_ascending: bool, preload_relationships: bool) -> Select: +def get_latest_statement(limit: int, *, fork: int, output_ascending: bool, preload_relationships: bool) -> Select: # Join with Block to order by Block's height and fetch the latest N transactions in descending order base = ( select(Transaction, Block.height.label("block__height")) .join(Block, Transaction.block_id == Block.id) + .where(Block.fork == fork) .order_by(Block.height.desc(), Transaction.id.desc()) .limit(limit) ) @@ -61,19 +62,19 @@ class TransactionRepository: return Empty() async def get_latest( - self, limit: int, *, ascending: bool = False, preload_relationships: bool = False + self, limit: int, *, fork: int, ascending: bool = False, preload_relationships: bool = False ) -> List[Transaction]: if limit == 0: return [] - statement = get_latest_statement(limit, output_ascending=ascending, preload_relationships=preload_relationships) + statement = get_latest_statement(limit, fork=fork, output_ascending=ascending, preload_relationships=preload_relationships) with self.client.session() as session: results: Result[Transaction] = session.exec(statement) return results.all() async def updates_stream( - self, transaction_from: Option[Transaction], *, timeout_seconds: int = 1 + self, transaction_from: Option[Transaction], *, fork: int, timeout_seconds: int = 1 ) -> AsyncIterator[List[Transaction]]: height_cursor = transaction_from.map(lambda transaction: transaction.block.height).unwrap_or(0) transaction_id_cursor = transaction_from.map(lambda transaction: transaction.id + 1).unwrap_or(0) @@ -84,6 +85,7 @@ class TransactionRepository: .options(selectinload(Transaction.block)) .join(Block, Transaction.block_id == Block.id) .where( + Block.fork == fork, Block.height >= height_cursor, Transaction.id >= transaction_id_cursor, ) diff --git a/static/components/BlocksTable.js b/static/components/BlocksTable.js index 9765b43..72524f1 100644 --- a/static/components/BlocksTable.js +++ b/static/components/BlocksTable.js @@ -4,6 +4,7 @@ import { useEffect, useState, useCallback, useRef } from 'preact/hooks'; import { PAGE, API } from '../lib/api.js'; import { TABLE_SIZE } from '../lib/constants.js'; import { shortenHex, streamNdjson } from '../lib/utils.js'; +import { subscribeFork } from '../lib/fork.js'; const normalize = (raw) => { const header = raw.header ?? null; @@ -32,12 +33,18 @@ export default function BlocksTable() { const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const [live, setLive] = useState(true); // Start in live mode + const [fork, setFork] = useState(null); const abortRef = useRef(null); const seenKeysRef = useRef(new Set()); + // Subscribe to fork-choice changes + useEffect(() => { + return subscribeFork((newFork) => setFork(newFork)); + }, []); + // Fetch paginated blocks - const fetchBlocks = useCallback(async (pageNum) => { + const fetchBlocks = useCallback(async (pageNum, currentFork) => { // Stop any live stream abortRef.current?.abort(); seenKeysRef.current.clear(); @@ -45,7 +52,7 @@ export default function BlocksTable() { setLoading(true); setError(null); try { - const res = await fetch(API.BLOCKS_LIST(pageNum, TABLE_SIZE)); + const res = await fetch(API.BLOCKS_LIST(pageNum, TABLE_SIZE, currentFork)); if (!res.ok) throw new Error(`HTTP ${res.status}`); const data = await res.json(); setBlocks(data.blocks.map(normalize)); @@ -60,7 +67,7 @@ export default function BlocksTable() { }, []); // Start live streaming - const startLiveStream = useCallback(() => { + const startLiveStream = useCallback((currentFork) => { abortRef.current?.abort(); abortRef.current = new AbortController(); seenKeysRef.current.clear(); @@ -70,8 +77,9 @@ export default function BlocksTable() { let liveBlocks = []; + const url = `${API.BLOCKS_STREAM(currentFork)}&prefetch-limit=${encodeURIComponent(TABLE_SIZE)}`; streamNdjson( - `${API.BLOCKS_STREAM}?prefetch-limit=${encodeURIComponent(TABLE_SIZE)}`, + url, (raw) => { const b = normalize(raw); const key = `${b.id}:${b.slot}`; @@ -96,19 +104,22 @@ export default function BlocksTable() { ); }, []); - // Handle live mode changes + // Handle live mode and fork changes useEffect(() => { + if (fork == null) return; if (live) { - startLiveStream(); + startLiveStream(fork); + } else { + fetchBlocks(page, fork); } return () => abortRef.current?.abort(); - }, [live, startLiveStream]); + }, [live, fork, startLiveStream]); // Go to a page (turns off live mode) const goToPage = (newPage) => { - if (newPage >= 0) { + if (newPage >= 0 && fork != null) { setLive(false); - fetchBlocks(newPage); + fetchBlocks(newPage, fork); } }; diff --git a/static/components/TransactionsTable.js b/static/components/TransactionsTable.js index 4de57e1..ee257cb 100644 --- a/static/components/TransactionsTable.js +++ b/static/components/TransactionsTable.js @@ -1,6 +1,6 @@ -// static/pages/TransactionsTable.js +// static/components/TransactionsTable.js import { h } from 'preact'; -import { useEffect, useRef } from 'preact/hooks'; +import { useEffect, useRef, useState } from 'preact/hooks'; import { API, PAGE } from '../lib/api.js'; import { TABLE_SIZE } from '../lib/constants.js'; import { @@ -9,6 +9,7 @@ import { shortenHex, // (kept in case you want to use later) withBenignFilter, } from '../lib/utils.js'; +import { subscribeFork } from '../lib/fork.js'; const OPERATIONS_PREVIEW_LIMIT = 2; @@ -155,18 +156,31 @@ export default function TransactionsTable() { const countRef = useRef(null); const abortRef = useRef(null); const totalCountRef = useRef(0); + const [fork, setFork] = useState(null); + + // Subscribe to fork-choice changes + useEffect(() => { + return subscribeFork((newFork) => setFork(newFork)); + }, []); useEffect(() => { + if (fork == null) return; + const body = bodyRef.current; const counter = countRef.current; + // Clear existing rows on fork change + while (body.rows.length > 0) body.deleteRow(0); + totalCountRef.current = 0; + counter.textContent = '0'; + // 3 columns: Hash | Operations | Outputs ensureFixedRowCount(body, 3, TABLE_SIZE); abortRef.current?.abort(); abortRef.current = new AbortController(); - const url = `${API.TRANSACTIONS_STREAM}?prefetch-limit=${encodeURIComponent(TABLE_SIZE)}`; + const url = `${API.TRANSACTIONS_STREAM_WITH_FORK(fork)}&prefetch-limit=${encodeURIComponent(TABLE_SIZE)}`; streamNdjson( url, @@ -196,7 +210,7 @@ export default function TransactionsTable() { }); return () => abortRef.current?.abort(); - }, []); + }, [fork]); return h( 'div', diff --git a/static/lib/api.js b/static/lib/api.js index 927b859..24a9e63 100644 --- a/static/lib/api.js +++ b/static/lib/api.js @@ -11,15 +11,23 @@ const HEALTH_ENDPOINT = joinUrl(API_PREFIX, 'health/stream'); const TRANSACTION_DETAIL_BY_HASH = (hash) => joinUrl(API_PREFIX, 'transactions', encodeHash(hash)); const TRANSACTIONS_STREAM = joinUrl(API_PREFIX, 'transactions/stream'); +const FORK_CHOICE = joinUrl(API_PREFIX, 'fork-choice'); + const BLOCK_DETAIL_BY_HASH = (hash) => joinUrl(API_PREFIX, 'blocks', encodeHash(hash)); -const BLOCKS_STREAM = joinUrl(API_PREFIX, 'blocks/stream'); -const BLOCKS_LIST = (page, pageSize) => - `${joinUrl(API_PREFIX, 'blocks/list')}?page=${encodeURIComponent(page)}&page-size=${encodeURIComponent(pageSize)}`; +const BLOCKS_STREAM = (fork) => + `${joinUrl(API_PREFIX, 'blocks/stream')}?fork=${encodeURIComponent(fork)}`; +const BLOCKS_LIST = (page, pageSize, fork) => + `${joinUrl(API_PREFIX, 'blocks/list')}?page=${encodeURIComponent(page)}&page-size=${encodeURIComponent(pageSize)}&fork=${encodeURIComponent(fork)}`; + +const TRANSACTIONS_STREAM_WITH_FORK = (fork) => + `${joinUrl(API_PREFIX, 'transactions/stream')}?fork=${encodeURIComponent(fork)}`; export const API = { HEALTH_ENDPOINT, + FORK_CHOICE, TRANSACTION_DETAIL_BY_HASH, TRANSACTIONS_STREAM, + TRANSACTIONS_STREAM_WITH_FORK, BLOCK_DETAIL_BY_HASH, BLOCKS_STREAM, BLOCKS_LIST, diff --git a/static/lib/fork.js b/static/lib/fork.js new file mode 100644 index 0000000..422d572 --- /dev/null +++ b/static/lib/fork.js @@ -0,0 +1,50 @@ +import { API } from './api.js'; + +const POLL_INTERVAL_MS = 3000; + +let subscribers = new Set(); +let currentFork = null; +let pollTimer = null; + +async function poll() { + try { + const res = await fetch(API.FORK_CHOICE, { cache: 'no-cache' }); + if (!res.ok) return; + const data = await res.json(); + const newFork = data.fork; + if (newFork !== currentFork) { + currentFork = newFork; + for (const cb of subscribers) cb(currentFork); + } + } catch { + // ignore transient errors + } +} + +function startPolling() { + if (pollTimer != null) return; + poll(); // immediate first poll + pollTimer = setInterval(poll, POLL_INTERVAL_MS); +} + +function stopPolling() { + if (pollTimer == null) return; + clearInterval(pollTimer); + pollTimer = null; +} + +/** + * Subscribe to fork-choice changes. + * The callback is invoked immediately if a fork is already known, + * and again whenever the fork changes. + * Returns an unsubscribe function. + */ +export function subscribeFork(callback) { + subscribers.add(callback); + if (subscribers.size === 1) startPolling(); + if (currentFork != null) callback(currentFork); + return () => { + subscribers.delete(callback); + if (subscribers.size === 0) stopPolling(); + }; +} diff --git a/tests/test_block_forks.py b/tests/test_block_forks.py index 97c53a8..2da05bd 100644 --- a/tests/test_block_forks.py +++ b/tests/test_block_forks.py @@ -251,3 +251,124 @@ def test_batch_with_fork_and_chain(client, repo): assert forks[b"\x02"] == 0 # A inherits from genesis assert forks[b"\x03"] == 1 # B forks assert forks[b"\x04"] == 0 # C inherits from A + + +# --- Fork choice tests --- + + +def test_fork_choice_empty_db(client, repo): + """Fork choice returns Empty when no blocks exist.""" + from rusty_results import Empty + result = asyncio.run(repo.get_fork_choice()) + assert isinstance(result, Empty) + + +def test_fork_choice_single_chain(client, repo): + """Fork choice returns fork 0 for a single linear chain.""" + genesis = make_block(b"\x01", parent=b"\x00", slot=0) + a = make_block(b"\x02", parent=b"\x01", slot=1) + asyncio.run(repo.create(genesis, a)) + + result = asyncio.run(repo.get_fork_choice()) + assert result.unwrap() == 0 + + +def test_fork_choice_returns_longest_fork(client, repo): + """ + Fork choice returns the fork with the highest block. + + genesis -> A -> C (fork 0, height 2) + \\-> B (fork 1, height 1) + + Fork 0 is longer, so fork choice should return 0. + """ + genesis = make_block(b"\x01", parent=b"\x00", slot=0) + asyncio.run(repo.create(genesis)) + + a = make_block(b"\x02", parent=b"\x01", slot=1) + asyncio.run(repo.create(a)) + + b = make_block(b"\x03", parent=b"\x01", slot=1) + asyncio.run(repo.create(b)) + + c = make_block(b"\x04", parent=b"\x02", slot=2) + asyncio.run(repo.create(c)) + + result = asyncio.run(repo.get_fork_choice()) + assert result.unwrap() == 0 + + +def test_fork_choice_switches_on_overtake(client, repo): + """ + Fork choice switches when the alternative fork grows longer. + + genesis -> A (fork 0, height 1) + \\-> B -> C (fork 1, height 2) + + Fork 1 is longer, so fork choice should return 1. + """ + genesis = make_block(b"\x01", parent=b"\x00", slot=0) + asyncio.run(repo.create(genesis)) + + a = make_block(b"\x02", parent=b"\x01", slot=1) + asyncio.run(repo.create(a)) + + b = make_block(b"\x03", parent=b"\x01", slot=1) + asyncio.run(repo.create(b)) + + # Fork 0 has height 1 (block A). Now extend fork 1 past it. + c = make_block(b"\x04", parent=b"\x03", slot=2) + asyncio.run(repo.create(c)) + + result = asyncio.run(repo.get_fork_choice()) + assert result.unwrap() == 1 + + +# --- Fork-filtered query tests --- + + +def test_get_latest_filters_by_fork(client, repo): + """get_latest with fork filter only returns blocks from that fork.""" + genesis = make_block(b"\x01", parent=b"\x00", slot=0) + asyncio.run(repo.create(genesis)) + + a = make_block(b"\x02", parent=b"\x01", slot=1) + asyncio.run(repo.create(a)) + + b = make_block(b"\x03", parent=b"\x01", slot=1) + asyncio.run(repo.create(b)) + + # Fork 0: genesis, A. Fork 1: B (but B also shares genesis at fork 0... no, genesis is fork 0) + # Actually: genesis=fork0, A=fork0, B=fork1 + fork0_blocks = asyncio.run(repo.get_latest(10, fork=0)) + fork1_blocks = asyncio.run(repo.get_latest(10, fork=1)) + + fork0_hashes = {b.hash for b in fork0_blocks} + fork1_hashes = {b.hash for b in fork1_blocks} + + assert b"\x01" in fork0_hashes # genesis + assert b"\x02" in fork0_hashes # A + assert b"\x03" not in fork0_hashes + + assert b"\x03" in fork1_hashes # B + assert b"\x02" not in fork1_hashes + + +def test_get_paginated_filters_by_fork(client, repo): + """get_paginated with fork filter only returns blocks from that fork.""" + genesis = make_block(b"\x01", parent=b"\x00", slot=0) + asyncio.run(repo.create(genesis)) + + a = make_block(b"\x02", parent=b"\x01", slot=1) + asyncio.run(repo.create(a)) + + b = make_block(b"\x03", parent=b"\x01", slot=1) + asyncio.run(repo.create(b)) + + blocks_f0, count_f0 = asyncio.run(repo.get_paginated(0, 10, fork=0)) + blocks_f1, count_f1 = asyncio.run(repo.get_paginated(0, 10, fork=1)) + + assert count_f0 == 2 # genesis + A + assert count_f1 == 1 # B only + assert {b.hash for b in blocks_f0} == {b"\x01", b"\x02"} + assert {b.hash for b in blocks_f1} == {b"\x03"}