feat: updates for retrieval hint

This commit is contained in:
shash256 2025-08-18 19:44:00 +05:30
parent 05e797c76b
commit 3152eb6fdd
4 changed files with 115 additions and 16 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

View File

@ -10,6 +10,8 @@ package sds
extern void sdsGlobalEventCallback(int ret, char* msg, size_t len, void* userData);
extern void sdsGlobalRetrievalHintProvider(char* messageId, char** hint, size_t* hintLen, void* userData);
typedef struct {
int ret;
char* msg;
@ -78,6 +80,10 @@ package sds
SdsSetEventCallback(rmCtx, (SdsCallBack) sdsGlobalEventCallback, rmCtx);
}
static void cGoSdsSetRetrievalHintProvider(void* rmCtx) {
SdsSetRetrievalHintProvider(rmCtx, (SdsRetrievalHintProvider) sdsGlobalRetrievalHintProvider, rmCtx);
}
static void cGoSdsCleanupReliabilityManager(void* rmCtx, void* resp) {
SdsCleanupReliabilityManager(rmCtx, (SdsCallBack) SdsGoCallback, resp);
}
@ -158,8 +164,9 @@ func SdsGoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) {
type EventCallbacks struct {
OnMessageReady func(messageId MessageID, channelId string)
OnMessageSent func(messageId MessageID, channelId string)
OnMissingDependencies func(messageId MessageID, missingDeps []MessageID, channelId string)
OnMissingDependencies func(messageId MessageID, missingDeps []HistoryEntry, channelId string)
OnPeriodicSync func()
RetrievalHintProvider func(messageId MessageID) []byte
}
// ReliabilityManager represents an instance of a nim-sds ReliabilityManager
@ -189,6 +196,7 @@ func NewReliabilityManager() (*ReliabilityManager, error) {
C.cGoSdsSetEventCallback(rm.rmCtx)
registerReliabilityManager(rm)
C.cGoSdsSetRetrievalHintProvider(rm.rmCtx)
Debug("Successfully created Reliability Manager")
return rm, nil
@ -246,14 +254,33 @@ type msgEvent struct {
type missingDepsEvent struct {
MessageId MessageID `json:"messageId"`
MissingDeps []MessageID `json:"missingDeps"`
ChannelId string `json:"channelId"`
MissingDeps []HistoryEntry `json:"missingDeps"`
ChannelId string `json:"channelId"`
}
func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) {
rm.callbacks = callbacks
}
//export sdsGlobalRetrievalHintProvider
func sdsGlobalRetrievalHintProvider(messageId *C.char, hint **C.char, hintLen *C.size_t, userData unsafe.Pointer) {
msgId := C.GoString(messageId)
Debug("sdsGlobalRetrievalHintProvider called for messageId: %s", msgId)
rm, ok := rmRegistry[userData]
if ok && rm.callbacks.RetrievalHintProvider != nil {
Debug("Found RM and callback, calling provider")
hintBytes := rm.callbacks.RetrievalHintProvider(MessageID(msgId))
Debug("Provider returned hint of length: %d", len(hintBytes))
if len(hintBytes) > 0 {
*hint = (*C.char)(C.CBytes(hintBytes))
*hintLen = C.size_t(len(hintBytes))
Debug("Set hint in C memory: %s", string(hintBytes))
}
} else {
Debug("No RM found or no callback registered")
}
}
func (rm *ReliabilityManager) OnEvent(eventStr string) {
jsonEvent := jsonEvent{}
@ -467,10 +494,11 @@ func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedM
}
Debug("Successfully unwrapped message")
unwrappedMessage := UnwrappedMessage{}
Debug("Unwrapped message JSON: %s", resStr)
var unwrappedMessage UnwrappedMessage
err := json.Unmarshal([]byte(resStr), &unwrappedMessage)
if err != nil {
Error("Failed to unmarshal unwrapped message")
Error("Failed to unmarshal unwrapped message: %v", err)
return nil, err
}

View File

@ -80,7 +80,7 @@ func TestDependencies(t *testing.T) {
foundDep1 := false
for _, dep := range *unwrappedMessage2.MissingDeps {
if dep == msgID1 {
if dep.MessageID == msgID1 {
foundDep1 = true
break
}
@ -239,12 +239,15 @@ func TestCallback_OnMissingDependencies(t *testing.T) {
var cbMutex sync.Mutex
callbacks := EventCallbacks{
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
require.Equal(t, channelID, chId)
cbMutex.Lock()
missingCalled = true
missingMsgID = messageId
missingDepsList = missingDeps // Copy slice
missingDepsList = make([]MessageID, len(missingDeps))
for i, dep := range missingDeps {
missingDepsList[i] = dep.MessageID
}
cbMutex.Unlock()
wg.Done()
},
@ -383,7 +386,7 @@ func TestCallbacks_Combined(t *testing.T) {
// are typically relevant to the Sender. We don't expect this.
t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId)
},
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
// This callback is registered on Receiver, used for receiverRm2 below
},
}
@ -404,7 +407,7 @@ func TestCallbacks_Combined(t *testing.T) {
}
}
},
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
// Not expected on sender
},
}
@ -447,11 +450,16 @@ func TestCallbacks_Combined(t *testing.T) {
defer receiverRm2.Cleanup()
callbacksReceiver2 := EventCallbacks{
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
require.Equal(t, channelID, chId)
if messageId == msgID3 {
// Convert []HistoryEntry to []MessageID for the channel
deps := make([]MessageID, len(missingDeps))
for i, d := range missingDeps {
deps[i] = d.MessageID
}
select {
case missingChan <- missingDeps:
case missingChan <- deps:
default:
}
}
@ -677,3 +685,61 @@ func TestMultiChannelCallbacks(t *testing.T) {
require.Equal(t, channel2, readyMessages[ackID2_ch2], "OnMessageReady for ack2 has incorrect channel")
require.Len(t, readyMessages, 2, "Expected exactly 2 ready messages")
}
func TestRetrievalHints(t *testing.T) {
rm, err := NewReliabilityManager()
require.NoError(t, err)
defer rm.Cleanup()
channelID := "test-retrieval-hints"
// Set a retrieval hint provider
rm.RegisterCallbacks(EventCallbacks{
RetrievalHintProvider: func(messageId MessageID) []byte {
return []byte("hint-for-" + messageId)
},
})
// 1. Send a message to populate the history
payload1 := []byte("message one")
msgID1 := MessageID("msg-hint-1")
wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1, channelID)
require.NoError(t, err)
// 2. Receive the message to add it to history
_, err = rm.UnwrapReceivedMessage(wrappedMsg1)
require.NoError(t, err)
// 3. Send a second message, which will include the first in its causal history
payload2 := []byte("message two")
msgID2 := MessageID("msg-hint-2")
wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2, channelID)
require.NoError(t, err)
// 4. Unwrap the second message to inspect its causal history
// We need a new RM to avoid acknowledging the message
rm2, err := NewReliabilityManager()
require.NoError(t, err)
defer rm2.Cleanup()
rm2.RegisterCallbacks(EventCallbacks{
RetrievalHintProvider: func(messageId MessageID) []byte {
return []byte("hint-for-" + messageId)
},
})
unwrappedMsg2, err := rm2.UnwrapReceivedMessage(wrappedMsg2)
require.NoError(t, err)
// 5. Check that the causal history contains the retrieval hint
require.Greater(t, len(*unwrappedMsg2.MissingDeps), 0, "Expected missing dependencies")
foundDep := false
for _, dep := range *unwrappedMsg2.MissingDeps {
if dep.MessageID == msgID1 {
foundDep = true
require.Equal(t, []byte("hint-for-"+msgID1), dep.RetrievalHint, "Retrieval hint does not match")
break
}
}
require.True(t, foundDep, "Expected to find dependency %s", msgID1)
}

View File

@ -2,8 +2,13 @@ package sds
type MessageID string
type UnwrappedMessage struct {
Message *[]byte `json:"message"`
MissingDeps *[]MessageID `json:"missingDeps"`
ChannelId *string `json:"channelId"`
type HistoryEntry struct {
MessageID MessageID `json:"messageId"`
RetrievalHint []byte `json:"retrievalHint"`
}
type UnwrappedMessage struct {
Message *[]byte `json:"message"`
MissingDeps *[]HistoryEntry `json:"missingDeps"`
ChannelId *string `json:"channelId"`
}