diff --git a/src/node/wrapper_helpers.py b/src/node/wrapper_helpers.py index 8710ed556..7c9dd1a2d 100644 --- a/src/node/wrapper_helpers.py +++ b/src/node/wrapper_helpers.py @@ -45,6 +45,17 @@ class EventCollector: with self._lock: return [e for e in self.events if e.get("requestId") == request_id] + def snapshot(self) -> list[dict]: + """Return a thread-safe copy of all collected events. + + Use this whenever you need to iterate over every event (rather than + events for a single request_id). Iterating `self.events` directly is + unsafe because `event_callback` appends from the wrapper's event + thread. + """ + with self._lock: + return list(self.events) + def is_propagated_event(event: dict) -> bool: return event.get("eventType") == EVENT_PROPAGATED @@ -120,10 +131,9 @@ def wait_for_connected( """Wait until a connection_status_change event with PartiallyConnected or Connected arrives.""" deadline = time.monotonic() + timeout_s while time.monotonic() < deadline: - with collector._lock: - for event in collector.events: - if event.get("eventType") == "connection_status_change" and event.get("connectionStatus") in ("PartiallyConnected", "Connected"): - return event + for event in collector.snapshot(): + if event.get("eventType") == "connection_status_change" and event.get("connectionStatus") in ("PartiallyConnected", "Connected"): + return event time.sleep(poll_interval_s) return None @@ -201,7 +211,7 @@ def assert_no_unknown_request_ids(collector: EventCollector, issued_request_ids) the wrong request id under concurrency. """ issued = set(issued_request_ids) - for event in collector.events: + for event in collector.snapshot(): event_request_id = event.get("requestId") if event_request_id is None: continue diff --git a/tests/wrappers_tests/test_send_e2e_part1.py b/tests/wrappers_tests/test_send_e2e_part1.py index 7ef272f0d..a4bd86c6a 100644 --- a/tests/wrappers_tests/test_send_e2e_part1.py +++ b/tests/wrappers_tests/test_send_e2e_part1.py @@ -1,4 +1,5 @@ from concurrent.futures import ThreadPoolExecutor + import pytest from src.env_vars import NODE_2 from src.steps.common import StepsCommon @@ -805,7 +806,7 @@ class TestSendBeforeRelay(StepsStore): # Cross-association guard: every event with a requestId must # belong to exactly one of the request ids we issued. issued = set(request_ids) - for event in sender_collector.events: + for event in sender_collector.snapshot(): event_request_id = event.get("requestId") if event_request_id is None: continue @@ -917,7 +918,7 @@ class TestSendBeforeRelay(StepsStore): assert error_event is None, f"Unexpected terminal message_error for phase-2 " f"request_id={request_id} after recovery: {error_event}" issued = set(all_request_ids) - for event in sender_collector.events: + for event in sender_collector.snapshot(): event_request_id = event.get("requestId") if event_request_id is None: continue