diff --git a/sds/sds.go b/sds/sds.go index 30d4e8f..66d1bb1 100644 --- a/sds/sds.go +++ b/sds/sds.go @@ -9,6 +9,8 @@ package sds extern void sdsGlobalEventCallback(int ret, char* msg, size_t len, void* userData); + extern void sdsGlobalRetrievalHintProvider(char* messageId, char** hint, size_t* hintLen, void* userData); + typedef struct { int ret; char* msg; @@ -77,6 +79,10 @@ package sds SdsSetEventCallback(rmCtx, (SdsCallBack) sdsGlobalEventCallback, rmCtx); } + static void cGoSdsSetRetrievalHintProvider(void* rmCtx) { + SdsSetRetrievalHintProvider(rmCtx, (SdsRetrievalHintProvider) sdsGlobalRetrievalHintProvider, rmCtx); + } + static void cGoSdsCleanupReliabilityManager(void* rmCtx, void* resp) { SdsCleanupReliabilityManager(rmCtx, (SdsCallBack) SdsGoCallback, resp); } @@ -184,6 +190,7 @@ func NewReliabilityManager(logger *zap.Logger) (*ReliabilityManager, error) { C.cGoSdsSetEventCallback(rm.rmCtx) registerReliabilityManager(rm) + C.cGoSdsSetRetrievalHintProvider(rm.rmCtx) rm.logger.Debug("successfully created reliability manager") return rm, nil @@ -204,6 +211,22 @@ func sdsGlobalEventCallback(callerRet C.int, msg *C.char, len C.size_t, userData } } +//export sdsGlobalRetrievalHintProvider +func sdsGlobalRetrievalHintProvider(messageId *C.char, hint **C.char, hintLen *C.size_t, userData unsafe.Pointer) { + msgId := C.GoString(messageId) + Debug("sdsGlobalRetrievalHintProvider called for messageId: %s", msgId) + rm, ok := rmRegistry[userData] + if ok { + if rm.callbacks.RetrievalHintProvider != nil { + hintBytes := rm.callbacks.RetrievalHintProvider(MessageID(msgId)) + if len(hintBytes) > 0 { + *hint = (*C.char)(C.CBytes(hintBytes)) + *hintLen = C.size_t(len(hintBytes)) + } + } + } +} + func (rm *ReliabilityManager) Cleanup() error { if rm == nil { return errEmptyReliabilityManager @@ -340,7 +363,8 @@ func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedM } rm.logger.Debug("successfully unwrapped message") - unwrappedMessage := UnwrappedMessage{} + Debug("Unwrapped message JSON: %s", resStr) + var unwrappedMessage UnwrappedMessage err := json.Unmarshal([]byte(resStr), &unwrappedMessage) if err != nil { return nil, errorspkg.Wrap(err, "failed to unmarshal unwrapped message") diff --git a/sds/sds_common.go b/sds/sds_common.go index 16fe196..f6f6300 100644 --- a/sds/sds_common.go +++ b/sds/sds_common.go @@ -14,8 +14,9 @@ const EventChanBufferSize = 1024 type EventCallbacks struct { OnMessageReady func(messageId MessageID, channelId string) OnMessageSent func(messageId MessageID, channelId string) - OnMissingDependencies func(messageId MessageID, missingDeps []MessageID, channelId string) + OnMissingDependencies func(messageId MessageID, missingDeps []HistoryEntry, channelId string) OnPeriodicSync func() + RetrievalHintProvider func(messageId MessageID) []byte } // ReliabilityManager represents an instance of a nim-sds ReliabilityManager @@ -58,8 +59,8 @@ type msgEvent struct { type missingDepsEvent struct { MessageId MessageID `json:"messageId"` - MissingDeps []MessageID `json:"missingDeps"` - ChannelId string `json:"channelId"` + MissingDeps []HistoryEntry `json:"missingDeps"` + ChannelId string `json:"channelId"` } func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) { diff --git a/sds/sds_test.go b/sds/sds_test.go index 5e80cf7..4c95010 100644 --- a/sds/sds_test.go +++ b/sds/sds_test.go @@ -10,7 +10,7 @@ import ( // Test basic creation, cleanup, and reset func TestLifecycle(t *testing.T) { - rm, err := NewReliabilityManager() + rm, err := NewReliabilityManager(nil) require.NoError(t, err) require.NotNil(t, rm, "Expected ReliabilityManager to be not nil") @@ -22,7 +22,7 @@ func TestLifecycle(t *testing.T) { // Test wrapping and unwrapping a simple message func TestWrapUnwrap(t *testing.T) { - rm, err := NewReliabilityManager() + rm, err := NewReliabilityManager(nil) require.NoError(t, err) defer rm.Cleanup() @@ -45,7 +45,7 @@ func TestWrapUnwrap(t *testing.T) { // Test dependency handling func TestDependencies(t *testing.T) { - rm, err := NewReliabilityManager() + rm, err := NewReliabilityManager(nil) require.NoError(t, err) defer rm.Cleanup() @@ -68,7 +68,7 @@ func TestDependencies(t *testing.T) { require.NoError(t, err) // 3. Create a new manager to simulate a different peer receiving msg2 without msg1 - rm2, err := NewReliabilityManager() + rm2, err := NewReliabilityManager(nil) require.NoError(t, err) defer rm2.Cleanup() @@ -80,7 +80,7 @@ func TestDependencies(t *testing.T) { foundDep1 := false for _, dep := range *unwrappedMessage2.MissingDeps { - if dep == msgID1 { + if dep.MessageID == msgID1 { foundDep1 = true break } @@ -95,11 +95,11 @@ func TestDependencies(t *testing.T) { // Test OnMessageReady callback func TestCallback_OnMessageReady(t *testing.T) { // Create sender and receiver RMs - senderRm, err := NewReliabilityManager() + senderRm, err := NewReliabilityManager(nil) require.NoError(t, err) defer senderRm.Cleanup() - receiverRm, err := NewReliabilityManager() + receiverRm, err := NewReliabilityManager(nil) require.NoError(t, err) defer receiverRm.Cleanup() @@ -151,11 +151,11 @@ func TestCallback_OnMessageReady(t *testing.T) { // Test OnMessageSent callback (via causal history ACK) func TestCallback_OnMessageSent(t *testing.T) { // Create two RMs - rm1, err := NewReliabilityManager() + rm1, err := NewReliabilityManager(nil) require.NoError(t, err) defer rm1.Cleanup() - rm2, err := NewReliabilityManager() + rm2, err := NewReliabilityManager(nil) require.NoError(t, err) defer rm2.Cleanup() @@ -222,11 +222,11 @@ func TestCallback_OnMessageSent(t *testing.T) { // Test OnMissingDependencies callback func TestCallback_OnMissingDependencies(t *testing.T) { // Use separate sender/receiver RMs explicitly - senderRm, err := NewReliabilityManager() + senderRm, err := NewReliabilityManager(nil) require.NoError(t, err) defer senderRm.Cleanup() - receiverRm, err := NewReliabilityManager() + receiverRm, err := NewReliabilityManager(nil) require.NoError(t, err) defer receiverRm.Cleanup() @@ -239,12 +239,15 @@ func TestCallback_OnMissingDependencies(t *testing.T) { var cbMutex sync.Mutex callbacks := EventCallbacks{ - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) { require.Equal(t, channelID, chId) cbMutex.Lock() missingCalled = true missingMsgID = messageId - missingDepsList = missingDeps // Copy slice + missingDepsList = make([]MessageID, len(missingDeps)) + for i, dep := range missingDeps { + missingDepsList[i] = dep.MessageID + } cbMutex.Unlock() wg.Done() }, @@ -298,7 +301,7 @@ func TestCallback_OnMissingDependencies(t *testing.T) { // Test OnPeriodicSync callback func TestCallback_OnPeriodicSync(t *testing.T) { - rm, err := NewReliabilityManager() + rm, err := NewReliabilityManager(nil) require.NoError(t, err) defer rm.Cleanup() @@ -344,11 +347,11 @@ func TestCallback_OnPeriodicSync(t *testing.T) { // Combined Test for multiple callbacks func TestCallbacks_Combined(t *testing.T) { // Create sender and receiver RMs - senderRm, err := NewReliabilityManager() + senderRm, err := NewReliabilityManager(nil) require.NoError(t, err) defer senderRm.Cleanup() - receiverRm, err := NewReliabilityManager() + receiverRm, err := NewReliabilityManager(nil) require.NoError(t, err) defer receiverRm.Cleanup() @@ -383,7 +386,7 @@ func TestCallbacks_Combined(t *testing.T) { // are typically relevant to the Sender. We don't expect this. t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId) }, - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) { // This callback is registered on Receiver, used for receiverRm2 below }, } @@ -404,7 +407,7 @@ func TestCallbacks_Combined(t *testing.T) { } } }, - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) { // Not expected on sender }, } @@ -442,16 +445,21 @@ func TestCallbacks_Combined(t *testing.T) { require.NoError(t, err) // 6. Create Receiver2, register missing deps callback - receiverRm2, err := NewReliabilityManager() + receiverRm2, err := NewReliabilityManager(nil) require.NoError(t, err) defer receiverRm2.Cleanup() callbacksReceiver2 := EventCallbacks{ - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) { require.Equal(t, channelID, chId) if messageId == msgID3 { + // Convert []HistoryEntry to []MessageID for the channel + deps := make([]MessageID, len(missingDeps)) + for i, d := range missingDeps { + deps[i] = d.MessageID + } select { - case missingChan <- missingDeps: + case missingChan <- deps: default: } } @@ -545,7 +553,7 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration, t *testing.T) { // Test multi-channel functionality - one RM can handle messages from different channels func TestMultiChannel_SingleRM(t *testing.T) { - rm, err := NewReliabilityManager() + rm, err := NewReliabilityManager(nil) require.NoError(t, err) defer rm.Cleanup() @@ -592,12 +600,12 @@ func TestMultiChannel_SingleRM(t *testing.T) { // Test that callbacks are correctly triggered for multiple channels func TestMultiChannelCallbacks(t *testing.T) { // rm1 is the manager we are testing callbacks on - rm1, err := NewReliabilityManager() + rm1, err := NewReliabilityManager(nil) require.NoError(t, err) defer rm1.Cleanup() // rm2 simulates another peer - rm2, err := NewReliabilityManager() + rm2, err := NewReliabilityManager(nil) require.NoError(t, err) defer rm2.Cleanup() @@ -677,3 +685,61 @@ func TestMultiChannelCallbacks(t *testing.T) { require.Equal(t, channel2, readyMessages[ackID2_ch2], "OnMessageReady for ack2 has incorrect channel") require.Len(t, readyMessages, 2, "Expected exactly 2 ready messages") } + +func TestRetrievalHints(t *testing.T) { + rm, err := NewReliabilityManager(nil) + require.NoError(t, err) + defer rm.Cleanup() + + channelID := "test-retrieval-hints" + + // Set a retrieval hint provider + rm.RegisterCallbacks(EventCallbacks{ + RetrievalHintProvider: func(messageId MessageID) []byte { + return []byte("hint-for-" + messageId) + }, + }) + + // 1. Send a message to populate the history + payload1 := []byte("message one") + msgID1 := MessageID("msg-hint-1") + wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1, channelID) + require.NoError(t, err) + + // 2. Receive the message to add it to history + _, err = rm.UnwrapReceivedMessage(wrappedMsg1) + require.NoError(t, err) + + // 3. Send a second message, which will include the first in its causal history + payload2 := []byte("message two") + msgID2 := MessageID("msg-hint-2") + wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2, channelID) + require.NoError(t, err) + + // 4. Unwrap the second message to inspect its causal history + // We need a new RM to avoid acknowledging the message + rm2, err := NewReliabilityManager(nil) + require.NoError(t, err) + defer rm2.Cleanup() + + rm2.RegisterCallbacks(EventCallbacks{ + RetrievalHintProvider: func(messageId MessageID) []byte { + return []byte("hint-for-" + messageId) + }, + }) + + unwrappedMsg2, err := rm2.UnwrapReceivedMessage(wrappedMsg2) + require.NoError(t, err) + + // 5. Check that the causal history contains the retrieval hint + require.Greater(t, len(*unwrappedMsg2.MissingDeps), 0, "Expected missing dependencies") + foundDep := false + for _, dep := range *unwrappedMsg2.MissingDeps { + if dep.MessageID == msgID1 { + foundDep = true + require.Equal(t, []byte("hint-for-"+msgID1), dep.RetrievalHint, "Retrieval hint does not match") + break + } + } + require.True(t, foundDep, "Expected to find dependency %s", msgID1) +} diff --git a/sds/types.go b/sds/types.go index 8d788ca..6606fd4 100644 --- a/sds/types.go +++ b/sds/types.go @@ -2,8 +2,13 @@ package sds type MessageID string -type UnwrappedMessage struct { - Message *[]byte `json:"message"` - MissingDeps *[]MessageID `json:"missingDeps"` - ChannelId *string `json:"channelId"` +type HistoryEntry struct { + MessageID MessageID `json:"messageId"` + RetrievalHint []byte `json:"retrievalHint"` +} + +type UnwrappedMessage struct { + Message *[]byte `json:"message"` + MissingDeps *[]HistoryEntry `json:"missingDeps"` + ChannelId *string `json:"channelId"` }