mirror of
https://github.com/logos-messaging/sds-go-bindings.git
synced 2026-01-02 14:13:08 +00:00
feat: updates for retrieval hint
This commit is contained in:
parent
05e797c76b
commit
3152eb6fdd
38
sds/sds.go
38
sds/sds.go
@ -10,6 +10,8 @@ package sds
|
||||
|
||||
extern void sdsGlobalEventCallback(int ret, char* msg, size_t len, void* userData);
|
||||
|
||||
extern void sdsGlobalRetrievalHintProvider(char* messageId, char** hint, size_t* hintLen, void* userData);
|
||||
|
||||
typedef struct {
|
||||
int ret;
|
||||
char* msg;
|
||||
@ -78,6 +80,10 @@ package sds
|
||||
SdsSetEventCallback(rmCtx, (SdsCallBack) sdsGlobalEventCallback, rmCtx);
|
||||
}
|
||||
|
||||
static void cGoSdsSetRetrievalHintProvider(void* rmCtx) {
|
||||
SdsSetRetrievalHintProvider(rmCtx, (SdsRetrievalHintProvider) sdsGlobalRetrievalHintProvider, rmCtx);
|
||||
}
|
||||
|
||||
static void cGoSdsCleanupReliabilityManager(void* rmCtx, void* resp) {
|
||||
SdsCleanupReliabilityManager(rmCtx, (SdsCallBack) SdsGoCallback, resp);
|
||||
}
|
||||
@ -158,8 +164,9 @@ func SdsGoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) {
|
||||
type EventCallbacks struct {
|
||||
OnMessageReady func(messageId MessageID, channelId string)
|
||||
OnMessageSent func(messageId MessageID, channelId string)
|
||||
OnMissingDependencies func(messageId MessageID, missingDeps []MessageID, channelId string)
|
||||
OnMissingDependencies func(messageId MessageID, missingDeps []HistoryEntry, channelId string)
|
||||
OnPeriodicSync func()
|
||||
RetrievalHintProvider func(messageId MessageID) []byte
|
||||
}
|
||||
|
||||
// ReliabilityManager represents an instance of a nim-sds ReliabilityManager
|
||||
@ -189,6 +196,7 @@ func NewReliabilityManager() (*ReliabilityManager, error) {
|
||||
|
||||
C.cGoSdsSetEventCallback(rm.rmCtx)
|
||||
registerReliabilityManager(rm)
|
||||
C.cGoSdsSetRetrievalHintProvider(rm.rmCtx)
|
||||
|
||||
Debug("Successfully created Reliability Manager")
|
||||
return rm, nil
|
||||
@ -246,14 +254,33 @@ type msgEvent struct {
|
||||
|
||||
type missingDepsEvent struct {
|
||||
MessageId MessageID `json:"messageId"`
|
||||
MissingDeps []MessageID `json:"missingDeps"`
|
||||
ChannelId string `json:"channelId"`
|
||||
MissingDeps []HistoryEntry `json:"missingDeps"`
|
||||
ChannelId string `json:"channelId"`
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) {
|
||||
rm.callbacks = callbacks
|
||||
}
|
||||
|
||||
//export sdsGlobalRetrievalHintProvider
|
||||
func sdsGlobalRetrievalHintProvider(messageId *C.char, hint **C.char, hintLen *C.size_t, userData unsafe.Pointer) {
|
||||
msgId := C.GoString(messageId)
|
||||
Debug("sdsGlobalRetrievalHintProvider called for messageId: %s", msgId)
|
||||
rm, ok := rmRegistry[userData]
|
||||
if ok && rm.callbacks.RetrievalHintProvider != nil {
|
||||
Debug("Found RM and callback, calling provider")
|
||||
hintBytes := rm.callbacks.RetrievalHintProvider(MessageID(msgId))
|
||||
Debug("Provider returned hint of length: %d", len(hintBytes))
|
||||
if len(hintBytes) > 0 {
|
||||
*hint = (*C.char)(C.CBytes(hintBytes))
|
||||
*hintLen = C.size_t(len(hintBytes))
|
||||
Debug("Set hint in C memory: %s", string(hintBytes))
|
||||
}
|
||||
} else {
|
||||
Debug("No RM found or no callback registered")
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) OnEvent(eventStr string) {
|
||||
|
||||
jsonEvent := jsonEvent{}
|
||||
@ -467,10 +494,11 @@ func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedM
|
||||
}
|
||||
Debug("Successfully unwrapped message")
|
||||
|
||||
unwrappedMessage := UnwrappedMessage{}
|
||||
Debug("Unwrapped message JSON: %s", resStr)
|
||||
var unwrappedMessage UnwrappedMessage
|
||||
err := json.Unmarshal([]byte(resStr), &unwrappedMessage)
|
||||
if err != nil {
|
||||
Error("Failed to unmarshal unwrapped message")
|
||||
Error("Failed to unmarshal unwrapped message: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@ -80,7 +80,7 @@ func TestDependencies(t *testing.T) {
|
||||
|
||||
foundDep1 := false
|
||||
for _, dep := range *unwrappedMessage2.MissingDeps {
|
||||
if dep == msgID1 {
|
||||
if dep.MessageID == msgID1 {
|
||||
foundDep1 = true
|
||||
break
|
||||
}
|
||||
@ -239,12 +239,15 @@ func TestCallback_OnMissingDependencies(t *testing.T) {
|
||||
var cbMutex sync.Mutex
|
||||
|
||||
callbacks := EventCallbacks{
|
||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
||||
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
|
||||
require.Equal(t, channelID, chId)
|
||||
cbMutex.Lock()
|
||||
missingCalled = true
|
||||
missingMsgID = messageId
|
||||
missingDepsList = missingDeps // Copy slice
|
||||
missingDepsList = make([]MessageID, len(missingDeps))
|
||||
for i, dep := range missingDeps {
|
||||
missingDepsList[i] = dep.MessageID
|
||||
}
|
||||
cbMutex.Unlock()
|
||||
wg.Done()
|
||||
},
|
||||
@ -383,7 +386,7 @@ func TestCallbacks_Combined(t *testing.T) {
|
||||
// are typically relevant to the Sender. We don't expect this.
|
||||
t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId)
|
||||
},
|
||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
||||
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
|
||||
// This callback is registered on Receiver, used for receiverRm2 below
|
||||
},
|
||||
}
|
||||
@ -404,7 +407,7 @@ func TestCallbacks_Combined(t *testing.T) {
|
||||
}
|
||||
}
|
||||
},
|
||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
||||
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
|
||||
// Not expected on sender
|
||||
},
|
||||
}
|
||||
@ -447,11 +450,16 @@ func TestCallbacks_Combined(t *testing.T) {
|
||||
defer receiverRm2.Cleanup()
|
||||
|
||||
callbacksReceiver2 := EventCallbacks{
|
||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
||||
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
|
||||
require.Equal(t, channelID, chId)
|
||||
if messageId == msgID3 {
|
||||
// Convert []HistoryEntry to []MessageID for the channel
|
||||
deps := make([]MessageID, len(missingDeps))
|
||||
for i, d := range missingDeps {
|
||||
deps[i] = d.MessageID
|
||||
}
|
||||
select {
|
||||
case missingChan <- missingDeps:
|
||||
case missingChan <- deps:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@ -677,3 +685,61 @@ func TestMultiChannelCallbacks(t *testing.T) {
|
||||
require.Equal(t, channel2, readyMessages[ackID2_ch2], "OnMessageReady for ack2 has incorrect channel")
|
||||
require.Len(t, readyMessages, 2, "Expected exactly 2 ready messages")
|
||||
}
|
||||
|
||||
func TestRetrievalHints(t *testing.T) {
|
||||
rm, err := NewReliabilityManager()
|
||||
require.NoError(t, err)
|
||||
defer rm.Cleanup()
|
||||
|
||||
channelID := "test-retrieval-hints"
|
||||
|
||||
// Set a retrieval hint provider
|
||||
rm.RegisterCallbacks(EventCallbacks{
|
||||
RetrievalHintProvider: func(messageId MessageID) []byte {
|
||||
return []byte("hint-for-" + messageId)
|
||||
},
|
||||
})
|
||||
|
||||
// 1. Send a message to populate the history
|
||||
payload1 := []byte("message one")
|
||||
msgID1 := MessageID("msg-hint-1")
|
||||
wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1, channelID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Receive the message to add it to history
|
||||
_, err = rm.UnwrapReceivedMessage(wrappedMsg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. Send a second message, which will include the first in its causal history
|
||||
payload2 := []byte("message two")
|
||||
msgID2 := MessageID("msg-hint-2")
|
||||
wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2, channelID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 4. Unwrap the second message to inspect its causal history
|
||||
// We need a new RM to avoid acknowledging the message
|
||||
rm2, err := NewReliabilityManager()
|
||||
require.NoError(t, err)
|
||||
defer rm2.Cleanup()
|
||||
|
||||
rm2.RegisterCallbacks(EventCallbacks{
|
||||
RetrievalHintProvider: func(messageId MessageID) []byte {
|
||||
return []byte("hint-for-" + messageId)
|
||||
},
|
||||
})
|
||||
|
||||
unwrappedMsg2, err := rm2.UnwrapReceivedMessage(wrappedMsg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 5. Check that the causal history contains the retrieval hint
|
||||
require.Greater(t, len(*unwrappedMsg2.MissingDeps), 0, "Expected missing dependencies")
|
||||
foundDep := false
|
||||
for _, dep := range *unwrappedMsg2.MissingDeps {
|
||||
if dep.MessageID == msgID1 {
|
||||
foundDep = true
|
||||
require.Equal(t, []byte("hint-for-"+msgID1), dep.RetrievalHint, "Retrieval hint does not match")
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundDep, "Expected to find dependency %s", msgID1)
|
||||
}
|
||||
|
||||
13
sds/types.go
13
sds/types.go
@ -2,8 +2,13 @@ package sds
|
||||
|
||||
type MessageID string
|
||||
|
||||
type UnwrappedMessage struct {
|
||||
Message *[]byte `json:"message"`
|
||||
MissingDeps *[]MessageID `json:"missingDeps"`
|
||||
ChannelId *string `json:"channelId"`
|
||||
type HistoryEntry struct {
|
||||
MessageID MessageID `json:"messageId"`
|
||||
RetrievalHint []byte `json:"retrievalHint"`
|
||||
}
|
||||
|
||||
type UnwrappedMessage struct {
|
||||
Message *[]byte `json:"message"`
|
||||
MissingDeps *[]HistoryEntry `json:"missingDeps"`
|
||||
ChannelId *string `json:"channelId"`
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user