diff --git a/sds/sds.go b/sds/sds.go index b1ed01c..8d77e82 100644 --- a/sds/sds.go +++ b/sds/sds.go @@ -257,7 +257,7 @@ func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) { func (rm *ReliabilityManager) OnEvent(eventStr string) { - fmt.Println("------------------- received event: ", eventStr) + fmt.Println("------------------- received event: ", eventStr) // TODO: remove after debugging jsonEvent := jsonEvent{} err := json.Unmarshal([]byte(eventStr), &jsonEvent) @@ -385,7 +385,7 @@ func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId Mess return nil, err } - Debug("Wraping outgoing message %v", messageId) + Debug("Wrapping outgoing message %v", messageId) wg := sync.WaitGroup{} var resp = C.allocResp(unsafe.Pointer(&wg)) diff --git a/sds/sds_test.go b/sds/sds_test.go index cab0155..7407fd3 100644 --- a/sds/sds_test.go +++ b/sds/sds_test.go @@ -1,7 +1,10 @@ package sds import ( + "fmt" + "sync" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -89,3 +92,124 @@ func TestDependencies(t *testing.T) { err = rm2.MarkDependenciesMet([]MessageID{msgID1}) require.NoError(t, err) } + +// Test callbacks +func TestCallbacks(t *testing.T) { + channelID := "test-callbacks" + rm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer rm.Cleanup() + + var wg sync.WaitGroup + receivedReady := make(map[MessageID]bool) + receivedSent := make(map[MessageID]bool) + receivedMissing := make(map[MessageID][]MessageID) + syncRequested := false + var cbMutex sync.Mutex // Protect access to callback tracking maps/vars + + callbacks := EventCallbacks{ + OnMessageReady: func(messageId MessageID) { + fmt.Printf("Test: OnMessageReady received: %s\n", messageId) + cbMutex.Lock() + receivedReady[messageId] = true + cbMutex.Unlock() + wg.Done() + }, + OnMessageSent: func(messageId MessageID) { + fmt.Printf("Test: OnMessageSent received: %s\n", messageId) + cbMutex.Lock() + receivedSent[messageId] = true + cbMutex.Unlock() + wg.Done() + }, + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + fmt.Printf("Test: OnMissingDependencies received for %s: %v\n", messageId, missingDeps) + cbMutex.Lock() + receivedMissing[messageId] = missingDeps + cbMutex.Unlock() + wg.Done() + }, + OnPeriodicSync: func() { + fmt.Println("Test: OnPeriodicSync received") + cbMutex.Lock() + syncRequested = true + cbMutex.Unlock() + // Don't wg.Done() here, it might be called multiple times + }, + } + + rm.RegisterCallbacks(callbacks) + + // Start tasks AFTER registering callbacks + err = rm.StartPeriodicTasks() + require.NoError(t, err) + + // --- Test Scenario --- + + // 1. Send msg1 + wg.Add(1) // Expect OnMessageSent for msg1 eventually + payload1 := []byte("callback test 1") + msgID1 := MessageID("cb-msg-1") + wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1) + require.NoError(t, err) + + // 2. Receive msg1 (triggers OnMessageReady for msg1, OnMessageSent for msg1 via causal history) + wg.Add(1) // Expect OnMessageReady for msg1 + _, err = rm.UnwrapReceivedMessage(wrappedMsg1) + require.NoError(t, err) + + // 3. Send msg2 (depends on msg1) + wg.Add(1) // Expect OnMessageSent for msg2 eventually + payload2 := []byte("callback test 2") + msgID2 := MessageID("cb-msg-2") + wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2) + require.NoError(t, err) + + // 4. Receive msg2 (triggers OnMessageReady for msg2, OnMessageSent for msg2) + wg.Add(1) // Expect OnMessageReady for msg2 + _, err = rm.UnwrapReceivedMessage(wrappedMsg2) + require.NoError(t, err) + + // --- Verification --- + // Wait for expected callbacks with a timeout + waitTimeout(&wg, 5*time.Second, t) + + cbMutex.Lock() + defer cbMutex.Unlock() + + if !receivedReady[msgID1] { + t.Errorf("OnMessageReady not called for %s", msgID1) + } + if !receivedReady[msgID2] { + t.Errorf("OnMessageReady not called for %s", msgID2) + } + if !receivedSent[msgID1] { + t.Errorf("OnMessageSent not called for %s", msgID1) + } + if !receivedSent[msgID2] { + t.Errorf("OnMessageSent not called for %s", msgID2) + } + // We didn't explicitly test missing deps in this path + if len(receivedMissing) > 0 { + t.Errorf("Unexpected OnMissingDependencies calls: %v", receivedMissing) + } + // Periodic sync is harder to guarantee in a short test, just check if it was ever true + if !syncRequested { + t.Logf("Warning: OnPeriodicSync might not have been called within the test timeout") + } +} + +// Helper function to wait for WaitGroup with a timeout +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration, t *testing.T) { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: + // Completed normally + case <-time.After(timeout): + t.Fatalf("Timed out waiting for callbacks") + } +}