feat: implement sliced scrolling for logstash source

This commit is contained in:
gmega 2025-01-22 20:03:10 -03:00
parent 8096c9f4e0
commit 6be89f02f0
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
4 changed files with 149 additions and 1 deletions

View File

@@ -0,0 +1,56 @@
from concurrent.futures.thread import ThreadPoolExecutor
from queue import Queue
from typing import Iterable, Iterator, List
from typing_extensions import TypeVar
class _End:
    # Sentinel enqueued by a worker to signal that its task is exhausted.
    pass


T = TypeVar("T")


def pflatmap(
    tasks: List[Iterable[T]], workers: int, max_queue_size: int = 0
) -> Iterator[T]:
    """
    Parallel flatmap: consume several iterables concurrently and yield their
    items as a single stream (interleaving order is unspecified).

    :param tasks: Iterables to be run in separate threads. Typically generators.
    :param workers: Number of workers to use.
    :param max_queue_size: Maximum size of backlogged items (0 = unbounded).
    :return: An iterator over the items produced by the tasks.
    :raises: re-raises, after the stream is drained, any exception raised
        inside a task.
    """
    q: Queue[T | _End] = Queue(max_queue_size)

    def _consume(task: Iterable[T]) -> None:
        try:
            for item in task:
                q.put(item)
        finally:
            # Always signal completion, even when the task raises, so the
            # consumer loop below can still terminate.
            q.put(_End())

    executor = ThreadPoolExecutor(max_workers=workers)
    try:
        task_futures = [executor.submit(_consume, task) for task in tasks]
        active_tasks = len(task_futures)
        # Guard on active_tasks rather than `while True`: with an empty task
        # list nothing is ever enqueued, and an unconditional q.get() would
        # block forever.
        while active_tasks > 0:
            item = q.get()
            if isinstance(item, _End):
                active_tasks -= 1
            else:
                yield item
        # This will cause any exceptions thrown in tasks to be re-raised.
        for future in task_futures:
            future.result()
    finally:
        executor.shutdown(wait=True)

View File

@@ -0,0 +1,50 @@
from threading import Semaphore
from typing import Iterable
import pytest
from benchmarks.core.concurrency import pflatmap
def test_should_run_iterators_in_separate_threads():
    """Each iterable must run on its own thread; a semaphore gates when each may start."""
    gate = Semaphore(0)

    def gated() -> Iterable[int]:
        # Block until the test explicitly lets one task proceed.
        assert gate.acquire(timeout=10)
        yield from range(10)

    results = pflatmap([gated(), gated()], workers=2)

    # Let exactly one task run: its ten items arrive, in order.
    gate.release()
    assert [next(results) for _ in range(10)] == list(range(10))

    # Let the other task run: same ten items again.
    gate.release()
    assert [next(results) for _ in range(10)] == list(range(10))

    with pytest.raises(StopIteration):
        next(results)
def test_should_raise_exceptions_raised_by_tasks_at_the_end():
    """A task's exception surfaces from the stream only after all items are drained."""

    def well_behaved() -> Iterable[int]:
        yield from range(10)

    def faulty_task():
        yield "yay"
        raise ValueError("I'm very faulty")

    expected = set(range(10)) | {"yay"}
    collected = set()

    with pytest.raises(ValueError):
        for item in pflatmap([well_behaved(), faulty_task()], workers=2):
            collected.add(item)

    assert collected == expected

View File

@@ -5,6 +5,7 @@ from typing import Optional, Tuple, Any, Dict, List
from elasticsearch import Elasticsearch
from benchmarks.core.concurrency import pflatmap
from benchmarks.logging.sources.sources import LogSource, ExperimentId, NodeId, RawLine
GROUP_LABEL = "app.kubernetes.io/part-of"
@@ -24,6 +25,7 @@ class LogstashSource(LogSource):
client: Elasticsearch,
structured_only: bool = False,
chronological: bool = False,
slices: int = 1,
horizon: int = DEFAULT_HORIZON,
today: Optional[datetime.date] = None,
):
@@ -36,6 +38,7 @@ class LogstashSource(LogSource):
self.client = client
self.structured_only = structured_only
self.chronological = chronological
self.slices = slices
self._indexes = self._generate_indexes(today, horizon)
def __enter__(self):
@@ -89,13 +92,34 @@ class LogstashSource(LogSource):
if self.chronological:
query["sort"] = [{"@timestamp": {"order": "asc"}}]
else:
# More efficient, as per https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html#scroll-search-results
query["sort"] = ["_doc"]
# We can probably cache this, but for now OK.
actual_indexes = [
index for index in self.indexes if self.client.indices.exists(index=index)
]
# Scrolls are much cheaper than queries.
if self.slices > 1:
yield from pflatmap(
[
self._run_scroll(sliced_query, actual_indexes)
for sliced_query in self._sliced_queries(query)
],
workers=self.slices,
max_queue_size=100_000,
)
else:
yield from self._run_scroll(query, actual_indexes)
def _sliced_queries(self, query: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
for i in range(self.slices):
query_slice = query.copy()
query_slice["slice"] = {"id": i, "max": self.slices}
yield query_slice
def _run_scroll(self, query: Dict[str, Any], actual_indexes: List[str]):
scroll_response = self.client.search(
index=actual_indexes, body=query, scroll="2m", size=ES_MAX_BATCH_SIZE
)
@@ -104,6 +128,7 @@ class LogstashSource(LogSource):
try:
while True:
hits = scroll_response["hits"]["hits"]
logger.info(f"Retrieved {len(hits)} log entries.")
if not hits:
break
@@ -132,6 +157,12 @@ class LogstashSource(LogSource):
# Clean up scroll context
self.client.clear_scroll(scroll_id=scroll_id)
def __str__(self):
return (
f"LogstashSource(client={self.client}, structured_only={self.structured_only}, "
f"chronological={self.chronological}, indexes={self.indexes})"
)
def _generate_indexes(self, today: Optional[datetime.date], horizon: int):
if today is None:
today = datetime.date.today()

View File

@@ -38,6 +38,17 @@ def test_should_retrieve_unstructured_log_messages(benchmark_logs_client):
assert not all(">>" in line for line in lines)
@pytest.mark.integration
def test_should_retrieve_the_same_results_when_slicing(benchmark_logs_client):
    """Sliced scrolling must produce exactly the same log set as a single scroll."""
    baseline = set(
        LogstashSource(benchmark_logs_client, chronological=True).logs(group_id="g3")
    )
    sliced = set(
        LogstashSource(benchmark_logs_client, chronological=True, slices=2).logs(
            group_id="g3"
        )
    )
    assert baseline == sliced
@pytest.mark.integration
def test_filter_out_unstructured_log_messages(benchmark_logs_client):
source = LogstashSource(