From 05e797c76b8c02f805d596c2f9e146ecd10a7609 Mon Sep 17 00:00:00 2001 From: Akhil <111925100+shash256@users.noreply.github.com> Date: Thu, 24 Jul 2025 15:22:49 +0530 Subject: [PATCH] feat: Add Support for Multiple Channels in Single Reliability Manager (#5) --- sds/sds.go | 80 ++++++++-------- sds/sds_test.go | 236 ++++++++++++++++++++++++++++++++++++++---------- sds/types.go | 1 + 3 files changed, 232 insertions(+), 85 deletions(-) diff --git a/sds/sds.go b/sds/sds.go index e111ef0..cd8168e 100644 --- a/sds/sds.go +++ b/sds/sds.go @@ -56,9 +56,9 @@ package sds // resp must be set != NULL in case interest on retrieving data from the callback void SdsGoCallback(int ret, char* msg, size_t len, void* resp); - static void* cGoSdsNewReliabilityManager(const char* channelId, void* resp) { + static void* cGoSdsNewReliabilityManager(void* resp) { // We pass NULL because we are not interested in retrieving data from this callback - void* ret = SdsNewReliabilityManager(channelId, (SdsCallBack) SdsGoCallback, resp); + void* ret = SdsNewReliabilityManager((SdsCallBack) SdsGoCallback, resp); return ret; } @@ -87,16 +87,18 @@ package sds } static void cGoSdsWrapOutgoingMessage(void* rmCtx, - void* message, - size_t messageLen, - const char* messageId, - void* resp) { + void* message, + size_t messageLen, + const char* messageId, + const char* channelId, + void* resp) { SdsWrapOutgoingMessage(rmCtx, - message, - messageLen, - messageId, - (SdsCallBack) SdsGoCallback, - resp); + message, + messageLen, + messageId, + channelId, + (SdsCallBack) SdsGoCallback, + resp); } static void cGoSdsUnwrapReceivedMessage(void* rmCtx, void* message, @@ -110,14 +112,16 @@ package sds } static void cGoSdsMarkDependenciesMet(void* rmCtx, - char** messageIDs, - size_t count, - void* resp) { + char** messageIDs, + size_t count, + const char* channelId, + void* resp) { SdsMarkDependenciesMet(rmCtx, - messageIDs, - count, - (SdsCallBack) SdsGoCallback, - resp); + messageIDs, + count, + channelId, + (SdsCallBack) SdsGoCallback, + resp); } static void cGoSdsStartPeriodicTasks(void* rmCtx, void* resp) { @@ -152,31 +156,25 @@ func SdsGoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) { } type EventCallbacks struct { - OnMessageReady func(messageId MessageID) - OnMessageSent func(messageId MessageID) - OnMissingDependencies func(messageId MessageID, missingDeps []MessageID) + OnMessageReady func(messageId MessageID, channelId string) + OnMessageSent func(messageId MessageID, channelId string) + OnMissingDependencies func(messageId MessageID, missingDeps []MessageID, channelId string) OnPeriodicSync func() } // ReliabilityManager represents an instance of a nim-sds ReliabilityManager type ReliabilityManager struct { rmCtx unsafe.Pointer - channelId string callbacks EventCallbacks } -func NewReliabilityManager(channelId string) (*ReliabilityManager, error) { +func NewReliabilityManager() (*ReliabilityManager, error) { Debug("Creating new Reliability Manager") - rm := &ReliabilityManager{ - channelId: channelId, - } + rm := &ReliabilityManager{} wg := sync.WaitGroup{} - var cChannelId = C.CString(string(channelId)) var resp = C.allocResp(unsafe.Pointer(&wg)) - - defer C.free(unsafe.Pointer(cChannelId)) defer C.freeResp(resp) if C.getRet(resp) != C.RET_OK { @@ -186,7 +184,7 @@ func NewReliabilityManager(channelId string) (*ReliabilityManager, error) { } wg.Add(1) - rm.rmCtx = C.cGoSdsNewReliabilityManager(cChannelId, resp) + rm.rmCtx = C.cGoSdsNewReliabilityManager(resp) wg.Wait() C.cGoSdsSetEventCallback(rm.rmCtx) @@ -243,11 +241,13 @@ type jsonEvent struct { type msgEvent struct { MessageId MessageID `json:"messageId"` + ChannelId string `json:"channelId"` } type missingDepsEvent struct { MessageId MessageID `json:"messageId"` MissingDeps []MessageID `json:"missingDeps"` + ChannelId string `json:"channelId"` } func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) { @@ -288,7 +288,7 @@ func (rm *ReliabilityManager) parseMessageReadyEvent(eventStr string) { } if rm.callbacks.OnMessageReady != nil { - rm.callbacks.OnMessageReady(msgEvent.MessageId) + rm.callbacks.OnMessageReady(msgEvent.MessageId, msgEvent.ChannelId) } } @@ -301,7 +301,7 @@ func (rm *ReliabilityManager) parseMessageSentEvent(eventStr string) { } if rm.callbacks.OnMessageSent != nil { - rm.callbacks.OnMessageSent(msgEvent.MessageId) + rm.callbacks.OnMessageSent(msgEvent.MessageId, msgEvent.ChannelId) } } @@ -314,7 +314,7 @@ func (rm *ReliabilityManager) parseMissingDepsEvent(eventStr string) { } if rm.callbacks.OnMissingDependencies != nil { - rm.callbacks.OnMissingDependencies(missingDepsEvent.MessageId, missingDepsEvent.MissingDeps) + rm.callbacks.OnMissingDependencies(missingDepsEvent.MessageId, missingDepsEvent.MissingDeps, missingDepsEvent.ChannelId) } } @@ -375,7 +375,7 @@ func (rm *ReliabilityManager) Reset() error { return errors.New(errMsg) } -func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId MessageID) ([]byte, error) { +func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId MessageID, channelId string) ([]byte, error) { if rm == nil { err := errors.New("reliability manager is nil in WrapOutgoingMessage") Error("Failed to wrap outgoing message %v", err) @@ -400,8 +400,11 @@ func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId Mess } cMessageLen := C.size_t(len(message)) + cChannelId := C.CString(channelId) + defer C.free(unsafe.Pointer(cChannelId)) + wg.Add(1) - C.cGoSdsWrapOutgoingMessage(rm.rmCtx, cMessagePtr, cMessageLen, cMessageId, resp) + C.cGoSdsWrapOutgoingMessage(rm.rmCtx, cMessagePtr, cMessageLen, cMessageId, cChannelId, resp) wg.Wait() if C.getRet(resp) == C.RET_OK { @@ -481,7 +484,7 @@ func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedM } // MarkDependenciesMet informs the library that dependencies are met -func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID) error { +func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID, channelId string) error { if rm == nil { err := errors.New("reliability manager is nil in MarkDependenciesMet") Error("Failed to mark dependencies met %v", err) @@ -512,8 +515,11 @@ func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID) error } wg.Add(1) + cChannelId := C.CString(channelId) + defer C.free(unsafe.Pointer(cChannelId)) + // Pass the pointer variable (cMessageIDsPtr) directly, which is of type **C.char - C.cGoSdsMarkDependenciesMet(rm.rmCtx, cMessageIDsPtr, C.size_t(len(messageIDs)), resp) + C.cGoSdsMarkDependenciesMet(rm.rmCtx, cMessageIDsPtr, C.size_t(len(messageIDs)), cChannelId, resp) wg.Wait() if C.getRet(resp) == C.RET_OK { diff --git a/sds/sds_test.go b/sds/sds_test.go index 36e37f2..5e80cf7 100644 --- a/sds/sds_test.go +++ b/sds/sds_test.go @@ -10,8 +10,7 @@ import ( // Test basic creation, cleanup, and reset func TestLifecycle(t *testing.T) { - channelID := "test-lifecycle" - rm, err := NewReliabilityManager(channelID) + rm, err := NewReliabilityManager() require.NoError(t, err) require.NotNil(t, rm, "Expected ReliabilityManager to be not nil") @@ -23,15 +22,15 @@ func TestLifecycle(t *testing.T) { // Test wrapping and unwrapping a simple message func TestWrapUnwrap(t *testing.T) { - channelID := "test-wrap-unwrap" - rm, err := NewReliabilityManager(channelID) + rm, err := NewReliabilityManager() require.NoError(t, err) defer rm.Cleanup() + channelID := "test-wrap-unwrap" originalPayload := []byte("hello reliability") messageID := MessageID("msg-wrap-1") - wrappedMsg, err := rm.WrapOutgoingMessage(originalPayload, messageID) + wrappedMsg, err := rm.WrapOutgoingMessage(originalPayload, messageID, channelID) require.NoError(t, err) require.Greater(t, len(wrappedMsg), 0, "Expected non-empty wrapped message") @@ -46,15 +45,16 @@ func TestWrapUnwrap(t *testing.T) { // Test dependency handling func TestDependencies(t *testing.T) { - channelID := "test-deps" - rm, err := NewReliabilityManager(channelID) + rm, err := NewReliabilityManager() require.NoError(t, err) defer rm.Cleanup() + channelID := "test-deps" + // 1. Send message 1 (will become a dependency) payload1 := []byte("message one") msgID1 := MessageID("msg-dep-1") - wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1) + wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1, channelID) require.NoError(t, err) // Simulate receiving msg1 to add it to history (implicitly acknowledges it) @@ -64,11 +64,11 @@ func TestDependencies(t *testing.T) { // 2. Send message 2 (depends on message 1 implicitly via causal history) payload2 := []byte("message two") msgID2 := MessageID("msg-dep-2") - wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2) + wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2, channelID) require.NoError(t, err) // 3. Create a new manager to simulate a different peer receiving msg2 without msg1 - rm2, err := NewReliabilityManager(channelID) // Same channel ID + rm2, err := NewReliabilityManager() require.NoError(t, err) defer rm2.Cleanup() @@ -88,28 +88,29 @@ func TestDependencies(t *testing.T) { require.True(t, foundDep1, "Expected missing dependency %q, got %v", msgID1, *unwrappedMessage2.MissingDeps) // 5. Mark the dependency as met - err = rm2.MarkDependenciesMet([]MessageID{msgID1}) + err = rm2.MarkDependenciesMet([]MessageID{msgID1}, channelID) require.NoError(t, err) } // Test OnMessageReady callback func TestCallback_OnMessageReady(t *testing.T) { - channelID := "test-cb-ready" - // Create sender and receiver RMs - senderRm, err := NewReliabilityManager(channelID) + senderRm, err := NewReliabilityManager() require.NoError(t, err) defer senderRm.Cleanup() - receiverRm, err := NewReliabilityManager(channelID) + receiverRm, err := NewReliabilityManager() require.NoError(t, err) defer receiverRm.Cleanup() + channelID := "test-cb-ready" + // Use a channel for signaling readyChan := make(chan MessageID, 1) callbacks := EventCallbacks{ - OnMessageReady: func(messageId MessageID) { + OnMessageReady: func(messageId MessageID, chId string) { + require.Equal(t, channelID, chId) // Non-blocking send to channel select { case readyChan <- messageId: @@ -127,7 +128,7 @@ func TestCallback_OnMessageReady(t *testing.T) { msgID := MessageID("cb-ready-1") // Wrap on sender - wrappedMsg, err := senderRm.WrapOutgoingMessage(payload, msgID) + wrappedMsg, err := senderRm.WrapOutgoingMessage(payload, msgID, channelID) require.NoError(t, err) // Unwrap on receiver @@ -149,24 +150,25 @@ func TestCallback_OnMessageReady(t *testing.T) { // Test OnMessageSent callback (via causal history ACK) func TestCallback_OnMessageSent(t *testing.T) { - channelID := "test-cb-sent" - // Create two RMs - rm1, err := NewReliabilityManager(channelID) + rm1, err := NewReliabilityManager() require.NoError(t, err) defer rm1.Cleanup() - rm2, err := NewReliabilityManager(channelID) + rm2, err := NewReliabilityManager() require.NoError(t, err) defer rm2.Cleanup() + channelID := "test-cb-sent" + var wg sync.WaitGroup sentCalled := false var sentMsgID MessageID var cbMutex sync.Mutex callbacks := EventCallbacks{ - OnMessageSent: func(messageId MessageID) { + OnMessageSent: func(messageId MessageID, chId string) { + require.Equal(t, channelID, chId) cbMutex.Lock() sentCalled = true sentMsgID = messageId @@ -184,7 +186,7 @@ func TestCallback_OnMessageSent(t *testing.T) { // 1. rm1 sends msg1 payload1 := []byte("sent test 1") msgID1 := MessageID("cb-sent-1") - wrappedMsg1, err := rm1.WrapOutgoingMessage(payload1, msgID1) + wrappedMsg1, err := rm1.WrapOutgoingMessage(payload1, msgID1, channelID) require.NoError(t, err) // Note: msg1 is now in rm1's outgoing buffer @@ -195,7 +197,7 @@ func TestCallback_OnMessageSent(t *testing.T) { // 3. rm2 sends msg2 (will include msg1 in causal history) payload2 := []byte("sent test 2") msgID2 := MessageID("cb-sent-2") - wrappedMsg2, err := rm2.WrapOutgoingMessage(payload2, msgID2) + wrappedMsg2, err := rm2.WrapOutgoingMessage(payload2, msgID2, channelID) require.NoError(t, err) // 4. rm1 receives msg2 (should trigger ACK for msg1) @@ -219,17 +221,17 @@ func TestCallback_OnMessageSent(t *testing.T) { // Test OnMissingDependencies callback func TestCallback_OnMissingDependencies(t *testing.T) { - channelID := "test-cb-missing" - // Use separate sender/receiver RMs explicitly - senderRm, err := NewReliabilityManager(channelID) + senderRm, err := NewReliabilityManager() require.NoError(t, err) defer senderRm.Cleanup() - receiverRm, err := NewReliabilityManager(channelID) + receiverRm, err := NewReliabilityManager() require.NoError(t, err) defer receiverRm.Cleanup() + channelID := "test-cb-missing" + var wg sync.WaitGroup missingCalled := false var missingMsgID MessageID @@ -237,7 +239,8 @@ func TestCallback_OnMissingDependencies(t *testing.T) { var cbMutex sync.Mutex callbacks := EventCallbacks{ - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + require.Equal(t, channelID, chId) cbMutex.Lock() missingCalled = true missingMsgID = messageId @@ -256,13 +259,13 @@ func TestCallback_OnMissingDependencies(t *testing.T) { // 1. Sender sends msg1 payload1 := []byte("missing test 1") msgID1 := MessageID("cb-miss-1") - _, err = senderRm.WrapOutgoingMessage(payload1, msgID1) + _, err = senderRm.WrapOutgoingMessage(payload1, msgID1, channelID) require.NoError(t, err) // 2. Sender sends msg2 (depends on msg1) payload2 := []byte("missing test 2") msgID2 := MessageID("cb-miss-2") - wrappedMsg2, err := senderRm.WrapOutgoingMessage(payload2, msgID2) + wrappedMsg2, err := senderRm.WrapOutgoingMessage(payload2, msgID2, channelID) require.NoError(t, err) // 3. Receiver receives msg2 (haven't seen msg1) @@ -295,8 +298,7 @@ func TestCallback_OnMissingDependencies(t *testing.T) { // Test OnPeriodicSync callback func TestCallback_OnPeriodicSync(t *testing.T) { - channelID := "test-cb-sync" - rm, err := NewReliabilityManager(channelID) + rm, err := NewReliabilityManager() require.NoError(t, err) defer rm.Cleanup() @@ -341,17 +343,17 @@ func TestCallback_OnPeriodicSync(t *testing.T) { // Combined Test for multiple callbacks func TestCallbacks_Combined(t *testing.T) { - channelID := "test-cb-combined" - // Create sender and receiver RMs - senderRm, err := NewReliabilityManager(channelID) + senderRm, err := NewReliabilityManager() require.NoError(t, err) defer senderRm.Cleanup() - receiverRm, err := NewReliabilityManager(channelID) + receiverRm, err := NewReliabilityManager() require.NoError(t, err) defer receiverRm.Cleanup() + channelID := "test-cb-combined" + // Channels for synchronization readyChan1 := make(chan bool, 1) sentChan1 := make(chan bool, 1) @@ -363,7 +365,8 @@ func TestCallbacks_Combined(t *testing.T) { var cbMutex sync.Mutex callbacksReceiver := EventCallbacks{ - OnMessageReady: func(messageId MessageID) { + OnMessageReady: func(messageId MessageID, chId string) { + require.Equal(t, channelID, chId) cbMutex.Lock() receivedReady[messageId] = true cbMutex.Unlock() @@ -375,21 +378,22 @@ func TestCallbacks_Combined(t *testing.T) { } } }, - OnMessageSent: func(messageId MessageID) { + OnMessageSent: func(messageId MessageID, chId string) { // This callback is registered on Receiver, but Sent events // are typically relevant to the Sender. We don't expect this. t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId) }, - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { // This callback is registered on Receiver, used for receiverRm2 below }, } callbacksSender := EventCallbacks{ - OnMessageReady: func(messageId MessageID) { + OnMessageReady: func(messageId MessageID, chId string) { // Not expected on sender in this test flow }, - OnMessageSent: func(messageId MessageID) { + OnMessageSent: func(messageId MessageID, chId string) { + require.Equal(t, channelID, chId) cbMutex.Lock() receivedSent[messageId] = true cbMutex.Unlock() @@ -400,7 +404,7 @@ func TestCallbacks_Combined(t *testing.T) { } } }, - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { // Not expected on sender }, } @@ -418,7 +422,7 @@ func TestCallbacks_Combined(t *testing.T) { payload3 := []byte("combined test 3") // 1. Sender sends msg1 - wrappedMsg1, err := senderRm.WrapOutgoingMessage(payload1, msgID1) + wrappedMsg1, err := senderRm.WrapOutgoingMessage(payload1, msgID1, channelID) require.NoError(t, err) // 2. Receiver receives msg1 @@ -426,7 +430,7 @@ func TestCallbacks_Combined(t *testing.T) { require.NoError(t, err) // 3. Receiver sends msg2 (depends on msg1 implicitly via state) - wrappedMsg2, err := receiverRm.WrapOutgoingMessage(payload2, msgID2) + wrappedMsg2, err := receiverRm.WrapOutgoingMessage(payload2, msgID2, channelID) require.NoError(t, err) // 4. Sender receives msg2 from Receiver (acks msg1 for sender) @@ -434,16 +438,17 @@ func TestCallbacks_Combined(t *testing.T) { require.NoError(t, err) // 5. Sender sends msg3 (depends on msg2) - wrappedMsg3, err := senderRm.WrapOutgoingMessage(payload3, msgID3) + wrappedMsg3, err := senderRm.WrapOutgoingMessage(payload3, msgID3, channelID) require.NoError(t, err) // 6. Create Receiver2, register missing deps callback - receiverRm2, err := NewReliabilityManager(channelID) + receiverRm2, err := NewReliabilityManager() require.NoError(t, err) defer receiverRm2.Cleanup() callbacksReceiver2 := EventCallbacks{ - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + require.Equal(t, channelID, chId) if messageId == msgID3 { select { case missingChan <- missingDeps: @@ -537,3 +542,138 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration, t *testing.T) { t.Fatalf("Timed out waiting for callbacks") } } + +// Test multi-channel functionality - one RM can handle messages from different channels +func TestMultiChannel_SingleRM(t *testing.T) { + rm, err := NewReliabilityManager() + require.NoError(t, err) + defer rm.Cleanup() + + // Test with two different channels + channel1 := "test-channel-1" + channel2 := "test-channel-2" + + // Create and wrap messages for different channels + msg1 := []byte("message for channel 1") + msgID1 := MessageID("msg1") + wrappedMsg1, err := rm.WrapOutgoingMessage(msg1, msgID1, channel1) + require.NoError(t, err) + + msg2 := []byte("message for channel 2") + msgID2 := MessageID("msg2") + wrappedMsg2, err := rm.WrapOutgoingMessage(msg2, msgID2, channel2) + require.NoError(t, err) + + // Unwrap messages - should extract channel ID and route correctly + unwrapped1, err := rm.UnwrapReceivedMessage(wrappedMsg1) + require.NoError(t, err) + require.Equal(t, msg1, *unwrapped1.Message) + + // Verify channel ID is extracted correctly + require.NotNil(t, unwrapped1.ChannelId, "Expected ChannelId to be not nil") + require.Equal(t, channel1, *unwrapped1.ChannelId) + + unwrapped2, err := rm.UnwrapReceivedMessage(wrappedMsg2) + require.NoError(t, err) + require.Equal(t, msg2, *unwrapped2.Message) + + // Verify channel ID is extracted correctly + require.NotNil(t, unwrapped2.ChannelId, "Expected ChannelId to be not nil") + require.Equal(t, channel2, *unwrapped2.ChannelId) + + // Test that the same RM can handle dependencies for different channels + err = rm.MarkDependenciesMet([]MessageID{msgID1}, channel1) + require.NoError(t, err) + + err = rm.MarkDependenciesMet([]MessageID{msgID2}, channel2) + require.NoError(t, err) +} + +// Test that callbacks are correctly triggered for multiple channels +func TestMultiChannelCallbacks(t *testing.T) { + // rm1 is the manager we are testing callbacks on + rm1, err := NewReliabilityManager() + require.NoError(t, err) + defer rm1.Cleanup() + + // rm2 simulates another peer + rm2, err := NewReliabilityManager() + require.NoError(t, err) + defer rm2.Cleanup() + + channel1 := "callback-channel-1" + channel2 := "callback-channel-2" + + var wg sync.WaitGroup + var cbMutex sync.Mutex + readyMessages := make(map[MessageID]string) + sentMessages := make(map[MessageID]string) + + callbacks := EventCallbacks{ + OnMessageReady: func(messageId MessageID, channelId string) { + // This will be triggered when rm1 receives messages from rm2 + cbMutex.Lock() + readyMessages[messageId] = channelId + cbMutex.Unlock() + wg.Done() + }, + OnMessageSent: func(messageId MessageID, channelId string) { + // This will be triggered when rm1's messages are acked by rm2 + cbMutex.Lock() + sentMessages[messageId] = channelId + cbMutex.Unlock() + wg.Done() + }, + } + rm1.RegisterCallbacks(callbacks) + + // --- Test Scenario --- + // 1. rm1 sends one message on each channel. + msgID1_ch1 := MessageID("msg-on-ch1") + wrappedMsg1_ch1, err := rm1.WrapOutgoingMessage([]byte("payload1"), msgID1_ch1, channel1) + require.NoError(t, err) + + msgID2_ch2 := MessageID("msg-on-ch2") + wrappedMsg2_ch2, err := rm1.WrapOutgoingMessage([]byte("payload2"), msgID2_ch2, channel2) + require.NoError(t, err) + + // 2. rm2 receives these messages to update its history. + _, err = rm2.UnwrapReceivedMessage(wrappedMsg1_ch1) + require.NoError(t, err) + _, err = rm2.UnwrapReceivedMessage(wrappedMsg2_ch2) + require.NoError(t, err) + + // 3. rm2 sends messages back on each channel, which will act as ACKs for rm1's messages. + ackID1_ch1 := MessageID("ack-on-ch1") + wrappedAck1_ch1, err := rm2.WrapOutgoingMessage([]byte("ack_payload1"), ackID1_ch1, channel1) + require.NoError(t, err) + + ackID2_ch2 := MessageID("ack-on-ch2") + wrappedAck2_ch2, err := rm2.WrapOutgoingMessage([]byte("ack_payload2"), ackID2_ch2, channel2) + require.NoError(t, err) + + // 4. rm1 receives the "ack" messages. This should trigger: + // - OnMessageSent for msgID1_ch1 and msgID2_ch2 + // - OnMessageReady for ackID1_ch1 and ackID2_ch2 + wg.Add(4) // Expect 2 sent and 2 ready callbacks + _, err = rm1.UnwrapReceivedMessage(wrappedAck1_ch1) + require.NoError(t, err) + _, err = rm1.UnwrapReceivedMessage(wrappedAck2_ch2) + require.NoError(t, err) + + // --- Verification --- + waitTimeout(&wg, 5*time.Second, t) + + cbMutex.Lock() + defer cbMutex.Unlock() + + // Check that both original messages were marked as sent and on the correct channel + require.Equal(t, channel1, sentMessages[msgID1_ch1], "OnMessageSent for msg1 has incorrect channel") + require.Equal(t, channel2, sentMessages[msgID2_ch2], "OnMessageSent for msg2 has incorrect channel") + require.Len(t, sentMessages, 2, "Expected exactly 2 sent messages") + + // Check that both ack messages were marked as ready and on the correct channel + require.Equal(t, channel1, readyMessages[ackID1_ch1], "OnMessageReady for ack1 has incorrect channel") + require.Equal(t, channel2, readyMessages[ackID2_ch2], "OnMessageReady for ack2 has incorrect channel") + require.Len(t, readyMessages, 2, "Expected exactly 2 ready messages") +} diff --git a/sds/types.go b/sds/types.go index cbe3b9b..8d788ca 100644 --- a/sds/types.go +++ b/sds/types.go @@ -5,4 +5,5 @@ type MessageID string type UnwrappedMessage struct { Message *[]byte `json:"message"` MissingDeps *[]MessageID `json:"missingDeps"` + ChannelId *string `json:"channelId"` }