From 6be89f02f01f690dc8cafbb44e949381e513c7e3 Mon Sep 17 00:00:00 2001
From: gmega
Date: Wed, 22 Jan 2025 20:03:10 -0300
Subject: [PATCH] feat: implement sliced scrolling for logstash source

---
 benchmarks/core/concurrency.py             | 56 +++++++++++++++++++
 benchmarks/core/tests/test_concurrency.py  | 50 +++++++++++++++++
 benchmarks/logging/sources/logstash.py     | 33 ++++++++++-
 .../sources/tests/test_logstash_source.py  | 11 ++++
 4 files changed, 149 insertions(+), 1 deletion(-)
 create mode 100644 benchmarks/core/concurrency.py
 create mode 100644 benchmarks/core/tests/test_concurrency.py

diff --git a/benchmarks/core/concurrency.py b/benchmarks/core/concurrency.py
new file mode 100644
index 0000000..a2183b7
--- /dev/null
+++ b/benchmarks/core/concurrency.py
@@ -0,0 +1,56 @@
+from concurrent.futures.thread import ThreadPoolExecutor
+from queue import Queue
+from typing import Iterable, Iterator, List
+
+from typing_extensions import TypeVar
+
+
+class _End:
+    pass
+
+
+T = TypeVar("T")
+
+
+def pflatmap(
+    tasks: List[Iterable[T]], workers: int, max_queue_size: int = 0
+) -> Iterator[T]:
+    """
+    Parallel flatmap.
+
+    :param tasks: Iterables to be run in separate threads. Typically generators.
+    :param workers: Number of workers to use.
+    :param max_queue_size: Maximum size of backlogged items.
+
+    :return: An iterator over the items produced by the tasks.
+    """
+
+    q = Queue[T | _End](max_queue_size)
+
+    def _consume(task: Iterable[T]) -> None:
+        try:
+            for item in task:
+                q.put(item)
+        finally:
+            q.put(_End())
+
+    executor = ThreadPoolExecutor(max_workers=workers)
+    try:
+        task_futures = [executor.submit(_consume, task) for task in tasks]
+        active_tasks = len(task_futures)
+
+        while True:
+            item = q.get()
+            if isinstance(item, _End):
+                active_tasks -= 1
+                if active_tasks == 0:
+                    break
+            else:
+                yield item
+
+        # This will cause any exceptions thrown in tasks to be re-raised.
+        for future in task_futures:
+            future.result()
+
+    finally:
+        executor.shutdown(wait=True)
diff --git a/benchmarks/core/tests/test_concurrency.py b/benchmarks/core/tests/test_concurrency.py
new file mode 100644
index 0000000..e330133
--- /dev/null
+++ b/benchmarks/core/tests/test_concurrency.py
@@ -0,0 +1,50 @@
+from threading import Semaphore
+from typing import Iterable
+
+import pytest
+
+from benchmarks.core.concurrency import pflatmap
+
+
+def test_should_run_iterators_in_separate_threads():
+    sema = Semaphore(0)
+
+    def task() -> Iterable[int]:
+        assert sema.acquire(timeout=10)
+        yield from range(10)
+
+    it = pflatmap([task(), task()], workers=2)
+
+    sema.release()
+    for i in range(10):
+        assert next(it) == i
+
+    sema.release()
+    for i in range(10):
+        assert next(it) == i
+
+    with pytest.raises(StopIteration):
+        next(it)
+
+
+def test_should_raise_exceptions_raised_by_tasks_at_the_end():
+    def task() -> Iterable[int]:
+        yield from range(10)
+
+    def faulty_task():
+        yield "yay"
+        raise ValueError("I'm very faulty")
+
+    reference_vals = set(list(range(10)) + ["yay"])
+    actual_vals = set()
+
+    it = pflatmap([task(), faulty_task()], workers=2)
+
+    try:
+        for val in it:
+            actual_vals.add(val)
+        assert False, "ValueError was not raised"
+    except ValueError:
+        pass
+
+    assert actual_vals == reference_vals
diff --git a/benchmarks/logging/sources/logstash.py b/benchmarks/logging/sources/logstash.py
index e6f05e4..384d055 100644
--- a/benchmarks/logging/sources/logstash.py
+++ b/benchmarks/logging/sources/logstash.py
@@ -5,6 +5,7 @@ from typing import Optional, Tuple, Any, Dict, List
 
 from elasticsearch import Elasticsearch
 
+from benchmarks.core.concurrency import pflatmap
 from benchmarks.logging.sources.sources import LogSource, ExperimentId, NodeId, RawLine
 
 GROUP_LABEL = "app.kubernetes.io/part-of"
@@ -24,6 +25,7 @@ class LogstashSource(LogSource):
         client: Elasticsearch,
         structured_only: bool = False,
         chronological: bool = False,
+        slices: int = 1,
         horizon: int = DEFAULT_HORIZON,
         today: Optional[datetime.date] = None,
     ):
@@ -36,6 +38,7 @@ class LogstashSource(LogSource):
         self.client = client
         self.structured_only = structured_only
         self.chronological = chronological
+        self.slices = slices
         self._indexes = self._generate_indexes(today, horizon)
 
     def __enter__(self):
@@ -89,13 +92,34 @@ class LogstashSource(LogSource):
 
         if self.chronological:
             query["sort"] = [{"@timestamp": {"order": "asc"}}]
+        else:
+            # More efficient, as per https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html#scroll-search-results
+            query["sort"] = ["_doc"]
 
         # We can probably cache this, but for now OK.
         actual_indexes = [
             index for index in self.indexes if self.client.indices.exists(index=index)
         ]
 
-        # Scrolls are much cheaper than queries.
+        if self.slices > 1:
+            yield from pflatmap(
+                [
+                    self._run_scroll(sliced_query, actual_indexes)
+                    for sliced_query in self._sliced_queries(query)
+                ],
+                workers=self.slices,
+                max_queue_size=100_000,
+            )
+        else:
+            yield from self._run_scroll(query, actual_indexes)
+
+    def _sliced_queries(self, query: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+        for i in range(self.slices):
+            query_slice = query.copy()
+            query_slice["slice"] = {"id": i, "max": self.slices}
+            yield query_slice
+
+    def _run_scroll(self, query: Dict[str, Any], actual_indexes: List[str]):
         scroll_response = self.client.search(
             index=actual_indexes, body=query, scroll="2m", size=ES_MAX_BATCH_SIZE
         )
@@ -104,6 +128,7 @@
         try:
             while True:
                 hits = scroll_response["hits"]["hits"]
+                logger.info(f"Retrieved {len(hits)} log entries.")
                 if not hits:
                     break
 
@@ -132,6 +157,12 @@
             # Clean up scroll context
             self.client.clear_scroll(scroll_id=scroll_id)
 
+    def __str__(self):
+        return (
+            f"LogstashSource(client={self.client}, structured_only={self.structured_only}, "
+            f"chronological={self.chronological}, indexes={self.indexes})"
+        )
+
     def _generate_indexes(self, today: Optional[datetime.date], horizon: int):
         if today is None:
             today = datetime.date.today()
diff --git a/benchmarks/logging/sources/tests/test_logstash_source.py b/benchmarks/logging/sources/tests/test_logstash_source.py
index 2d90466..5a01f71 100644
--- a/benchmarks/logging/sources/tests/test_logstash_source.py
+++ b/benchmarks/logging/sources/tests/test_logstash_source.py
@@ -38,6 +38,17 @@ def test_should_retrieve_unstructured_log_messages(benchmark_logs_client):
     assert not all(">>" in line for line in lines)
 
 
+@pytest.mark.integration
+def test_should_retrieve_the_same_results_when_slicing(benchmark_logs_client):
+    source = LogstashSource(benchmark_logs_client, chronological=True)
+    unsliced = set(source.logs(group_id="g3"))
+
+    source = LogstashSource(benchmark_logs_client, chronological=True, slices=2)
+    sliced = set(source.logs(group_id="g3"))
+
+    assert unsliced == sliced
+
+
 @pytest.mark.integration
 def test_filter_out_unstructured_log_messages(benchmark_logs_client):
     source = LogstashSource(
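
Usage sketch (not part of the patch): a minimal, hypothetical illustration of how the new pflatmap helper is meant to be driven, assuming the benchmarks package from this patch is on the import path. The numbers() generator below is a stand-in for one scroll slice; only pflatmap itself comes from the patch.

from benchmarks.core.concurrency import pflatmap


def numbers(start: int, count: int):
    # Hypothetical generator standing in for one scroll slice.
    for i in range(start, start + count):
        yield i


# Two worker threads drain both generators concurrently; items arrive in an
# arbitrary interleaving, and any exception raised by a task is re-raised only
# after all produced items have been consumed.
for item in pflatmap([numbers(0, 5), numbers(100, 5)], workers=2):
    print(item)

# Sliced scrolling is enabled the same way the new integration test does it,
# e.g. LogstashSource(client, chronological=True, slices=2); each slice is then
# scrolled on its own worker thread (requires a reachable Elasticsearch).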