From 271d6c2bf57b8330fe73797ee68613f0f22657db Mon Sep 17 00:00:00 2001 From: shash256 <111925100+shash256@users.noreply.github.com> Date: Sun, 20 Apr 2025 20:42:52 +0530 Subject: [PATCH] feat: add combined test --- sds_wrapper_test.go | 210 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) diff --git a/sds_wrapper_test.go b/sds_wrapper_test.go index fa75c40..7c418b8 100644 --- a/sds_wrapper_test.go +++ b/sds_wrapper_test.go @@ -435,6 +435,216 @@ func TestCallback_OnPeriodicSync(t *testing.T) { } } +// Combined Test for multiple callbacks +func TestCallbacks_Combined(t *testing.T) { + channelID := "test-cb-combined" + + // Create sender and receiver handles + handleSender, err := NewReliabilityManager(channelID) + if err != nil { + t.Fatalf("NewReliabilityManager (sender) failed: %v", err) + } + defer CleanupReliabilityManager(handleSender) + + handleReceiver, err := NewReliabilityManager(channelID) + if err != nil { + t.Fatalf("NewReliabilityManager (receiver) failed: %v", err) + } + defer CleanupReliabilityManager(handleReceiver) + + // Channels for synchronization + readyChan1 := make(chan bool, 1) + sentChan1 := make(chan bool, 1) + missingChan := make(chan []MessageID, 1) + + // Use maps for verification + receivedReady := make(map[MessageID]bool) + receivedSent := make(map[MessageID]bool) + var cbMutex sync.Mutex + + callbacksReceiver := Callbacks{ + OnMessageReady: func(messageId MessageID) { + cbMutex.Lock() + receivedReady[messageId] = true + cbMutex.Unlock() + if messageId == "cb-comb-1" { + // Use non-blocking send + select { + case readyChan1 <- true: + default: + } + } + }, + OnMessageSent: func(messageId MessageID) { + // This callback is registered on Receiver, but Sent events + // are typically relevant to the Sender. We don't expect this. + t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId) + }, + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + // This callback is registered on Receiver, used for handleReceiver2 below + }, + } + + callbacksSender := Callbacks{ + OnMessageReady: func(messageId MessageID) { + // Not expected on sender in this test flow + }, + OnMessageSent: func(messageId MessageID) { + cbMutex.Lock() + receivedSent[messageId] = true + cbMutex.Unlock() + if messageId == "cb-comb-1" { + select { + case sentChan1 <- true: + default: + } + } + }, + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + // Not expected on sender + }, + } + + // Register callbacks + err = RegisterCallback(handleReceiver, callbacksReceiver) + if err != nil { + t.Fatalf("RegisterCallback (Receiver) failed: %v", err) + } + err = RegisterCallback(handleSender, callbacksSender) + if err != nil { + t.Fatalf("RegisterCallback (Sender) failed: %v", err) + } + + // --- Test Scenario --- + msgID1 := MessageID("cb-comb-1") + msgID2 := MessageID("cb-comb-2") + msgID3 := MessageID("cb-comb-3") + payload1 := []byte("combined test 1") + payload2 := []byte("combined test 2") + payload3 := []byte("combined test 3") + + // 1. Sender sends msg1 + wrappedMsg1, err := WrapOutgoingMessage(handleSender, payload1, msgID1) + if err != nil { + t.Fatalf("WrapOutgoingMessage (1) failed: %v", err) + } + + // 2. Receiver receives msg1 + _, _, err = UnwrapReceivedMessage(handleReceiver, wrappedMsg1) + if err != nil { + t.Fatalf("UnwrapReceivedMessage (1) failed: %v", err) + } + + // 3. Receiver sends msg2 (depends on msg1 implicitly via state) + wrappedMsg2, err := WrapOutgoingMessage(handleReceiver, payload2, msgID2) + if err != nil { + t.Fatalf("WrapOutgoingMessage (2) on Receiver failed: %v", err) + } + + // 4. Sender receives msg2 from Receiver (acks msg1 for sender) + _, _, err = UnwrapReceivedMessage(handleSender, wrappedMsg2) + if err != nil { + t.Fatalf("UnwrapReceivedMessage (2) on Sender failed: %v", err) + } + + // 5. Sender sends msg3 (depends on msg2) + wrappedMsg3, err := WrapOutgoingMessage(handleSender, payload3, msgID3) + if err != nil { + t.Fatalf("WrapOutgoingMessage (3) failed: %v", err) + } + + // 6. Create Receiver2, register missing deps callback + handleReceiver2, err := NewReliabilityManager(channelID) + if err != nil { + t.Fatalf("NewReliabilityManager (Receiver2) failed: %v", err) + } + defer CleanupReliabilityManager(handleReceiver2) + + callbacksReceiver2 := Callbacks{ + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + if messageId == msgID3 { + select { + case missingChan <- missingDeps: + default: + } + } + }, + } + err = RegisterCallback(handleReceiver2, callbacksReceiver2) + if err != nil { + t.Fatalf("RegisterCallback (Receiver2) failed: %v", err) + } + + // 7. Receiver2 receives msg3 (should report missing msg1, msg2) + _, _, err = UnwrapReceivedMessage(handleReceiver2, wrappedMsg3) + if err != nil { + t.Fatalf("UnwrapReceivedMessage (3) on Receiver2 failed: %v", err) + } + + // --- Verification --- + timeout := 5 * time.Second + expectedReady1 := false + expectedSent1 := false + var reportedMissingDeps []MessageID + missingDepsReceived := false + + receivedCount := 0 + expectedCount := 3 // ready1, sent1, missingDeps + timer := time.NewTimer(timeout) + defer timer.Stop() + + for receivedCount < expectedCount { + select { + case <-readyChan1: + if !expectedReady1 { // Avoid double counting if signaled twice + expectedReady1 = true + receivedCount++ + } + case <-sentChan1: + if !expectedSent1 { + expectedSent1 = true + receivedCount++ + } + case deps := <-missingChan: + if !missingDepsReceived { + reportedMissingDeps = deps + missingDepsReceived = true + receivedCount++ + } + case <-timer.C: + t.Fatalf("Timed out waiting for combined callbacks (received %d out of %d)", receivedCount, expectedCount) + } + } + + // Check results + cbMutex.Lock() + defer cbMutex.Unlock() + + if !expectedReady1 || !receivedReady[msgID1] { + t.Errorf("OnMessageReady not called/verified for %s", msgID1) + } + if !expectedSent1 || !receivedSent[msgID1] { + t.Errorf("OnMessageSent not called/verified for %s", msgID1) + } + if !missingDepsReceived { + t.Errorf("OnMissingDependencies not called/verified for %s", msgID3) + } else { + foundDep1 := false + foundDep2 := false + for _, dep := range reportedMissingDeps { + if dep == msgID1 { + foundDep1 = true + } + if dep == msgID2 { + foundDep2 = true + } + } + if !foundDep1 || !foundDep2 { + t.Errorf("OnMissingDependencies for %s reported wrong deps: got %v, want %s and %s", msgID3, reportedMissingDeps, msgID1, msgID2) + } + } +} + // Helper function to wait for WaitGroup with a timeout func waitTimeout(wg *sync.WaitGroup, timeout time.Duration, t *testing.T) { c := make(chan struct{})