2026-02-16 23:58:06 +04:00

139 lines
5.3 KiB
Python

from asyncio import sleep
from typing import AsyncIterator, List
from rusty_results import Empty, Option, Some
from sqlalchemy import Result, Select, func as sa_func
from sqlalchemy.orm import aliased, selectinload
from sqlmodel import select
from db.clients import DbClient
from models.block import Block
from models.transactions.transaction import Transaction
def get_latest_statement(limit: int, *, fork: int, output_ascending: bool, preload_relationships: bool) -> Select:
    """Build the query selecting the newest ``limit`` transactions of a fork.

    "Newest" is decided by ``Block.height`` descending, then ``Transaction.id``
    descending. In both output orders the statement selects ``Transaction``
    entities only, so executing it yields transactions rather than
    ``(Transaction, height)`` rows.

    Args:
        limit: Maximum number of transactions to select.
        fork: Only transactions whose block has this ``fork`` value are considered.
        output_ascending: When true, the selected newest rows are re-sorted
            oldest-first (height, then id, ascending); otherwise they are
            returned newest-first.
        preload_relationships: When true, ``Transaction.block`` is loaded
            eagerly through ``selectinload``.

    Returns:
        The un-executed ``Select`` statement.
    """
    if not output_ascending:
        # Newest-first output needs no re-sorting, so the block height does not
        # have to be carried as an extra column: select the entity alone.
        statement = (
            select(Transaction)
            .join(Block, Transaction.block_id == Block.id)
            .where(Block.fork == fork)
            .order_by(Block.height.desc(), Transaction.id.desc())
            .limit(limit)
        )
        if preload_relationships:
            statement = statement.options(selectinload(Transaction.block))
        return statement

    # Join with Block to order by Block's height and fetch the latest N
    # transactions in descending order; the height is labelled so the outer
    # query can sort on it.
    newest = (
        select(Transaction, Block.height.label("block__height"))
        .join(Block, Transaction.block_id == Block.id)
        .where(Block.fork == fork)
        .order_by(Block.height.desc(), Transaction.id.desc())
        .limit(limit)
    )

    # Reorder for output: wrap the limited set in a subquery and sort it ascending.
    inner = newest.subquery()
    latest = aliased(Transaction, inner)
    statement = select(latest).order_by(inner.c.block__height.asc(), latest.id.asc())
    if preload_relationships:
        statement = statement.options(selectinload(latest.block))
    return statement
class TransactionRepository:
    """Query and persistence helpers for ``Transaction`` rows.

    Each method obtains its own session from the injected ``DbClient``.
    NOTE(review): the session calls are plain (non-awaited) calls inside
    ``async`` methods; presumably they block the event loop while running —
    confirm this is acceptable for the callers.
    """

    def __init__(self, client: DbClient):
        self.client = client

    async def create(self, *transaction: Transaction) -> None:
        """Add all given transactions to one session and commit them together."""
        pending = list(transaction)
        with self.client.session() as db:
            db.add_all(pending)
            db.commit()

    async def get_by_id(self, transaction_id: int) -> Option[Transaction]:
        """Look up a transaction by its ``id``.

        Returns ``Some(transaction)`` when a row matches, ``Empty()`` otherwise.
        """
        query = select(Transaction).where(Transaction.id == transaction_id)
        with self.client.session() as db:
            found = db.exec(query).one_or_none()
            return Empty() if found is None else Some(found)

    async def get_by_hash(self, transaction_hash: bytes, *, fork: int) -> Option[Transaction]:
        """Look up a transaction by hash among the blocks of ``fork``.

        Returns ``Some`` holding the first matching row, or ``Empty()`` when
        nothing matches.
        """
        query = (
            select(Transaction)
            .join(Block, Transaction.block_id == Block.id)
            .where(Transaction.hash == transaction_hash, Block.fork == fork)
        )
        with self.client.session() as db:
            found = db.exec(query).first()
            return Empty() if found is None else Some(found)

    async def get_latest(
        self, limit: int, *, fork: int, ascending: bool = False, preload_relationships: bool = False
    ) -> List[Transaction]:
        """Run the statement built by ``get_latest_statement`` and return all its rows.

        A ``limit`` of zero short-circuits to an empty list without touching
        the database.
        """
        if limit == 0:
            return []
        query = get_latest_statement(
            limit,
            fork=fork,
            output_ascending=ascending,
            preload_relationships=preload_relationships,
        )
        with self.client.session() as db:
            return db.exec(query).all()

    async def get_paginated(self, page: int, page_size: int, *, fork: int) -> tuple[List[Transaction], int]:
        """
        Fetch one page of a fork's transactions, newest first (block height, then id, descending).
        ``page`` is zero-based: the page starts at row ``page * page_size``.
        Returns a tuple of (transactions, total_count), where total_count covers the whole fork.
        """
        # Total number of transactions in this fork
        count_query = (
            select(sa_func.count())
            .select_from(Transaction)
            .join(Block, Transaction.block_id == Block.id)
            .where(Block.fork == fork)
        )
        # Requested page, with each transaction's block loaded eagerly
        page_query = (
            select(Transaction)
            .join(Block, Transaction.block_id == Block.id)
            .where(Block.fork == fork)
            .order_by(Block.height.desc(), Transaction.id.desc())
            .offset(page * page_size)
            .limit(page_size)
            .options(selectinload(Transaction.block))
        )
        with self.client.session() as db:
            total_count = db.exec(count_query).one()
            transactions = db.exec(page_query).all()
        return transactions, total_count

    async def updates_stream(
        self, transaction_from: Option[Transaction], *, fork: int, timeout_seconds: int = 1
    ) -> AsyncIterator[List[Transaction]]:
        """Endlessly yield batches of transactions that appear after ``transaction_from``.

        The cursor is a pair (minimum block height, minimum transaction id). It
        starts at ``transaction_from``'s block height and ``id + 1``, or at
        ``(0, 0)`` when ``transaction_from`` is empty, and advances to the last
        row of every yielded batch. Batches are ordered by block height, then
        id, ascending. When a poll finds nothing, the generator sleeps for
        ``timeout_seconds`` before polling again.
        """
        min_height = transaction_from.map(lambda transaction: transaction.block.height).unwrap_or(0)
        min_id = transaction_from.map(lambda transaction: transaction.id + 1).unwrap_or(0)

        while True:
            # Rebuilt on every poll because the cursor values change.
            query = (
                select(Transaction)
                .join(Block, Transaction.block_id == Block.id)
                .where(
                    Block.fork == fork,
                    Block.height >= min_height,
                    Transaction.id >= min_id,
                )
                .order_by(Block.height.asc(), Transaction.id.asc())
                .options(selectinload(Transaction.block))
            )
            with self.client.session() as db:
                batch: List[Transaction] = db.exec(query).all()

            if not batch:
                await sleep(timeout_seconds)
                continue

            # The batch is sorted ascending, so its last row carries the new cursor.
            newest = batch[-1]
            min_height = newest.block.height
            min_id = newest.id + 1
            yield batch