mirror of
https://github.com/logos-messaging/sds-go-bindings.git
synced 2026-01-03 22:53:07 +00:00
feat: updates for retrieval hint
This commit is contained in:
parent
05e797c76b
commit
3152eb6fdd
38
sds/sds.go
38
sds/sds.go
@ -10,6 +10,8 @@ package sds
|
|||||||
|
|
||||||
extern void sdsGlobalEventCallback(int ret, char* msg, size_t len, void* userData);
|
extern void sdsGlobalEventCallback(int ret, char* msg, size_t len, void* userData);
|
||||||
|
|
||||||
|
extern void sdsGlobalRetrievalHintProvider(char* messageId, char** hint, size_t* hintLen, void* userData);
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int ret;
|
int ret;
|
||||||
char* msg;
|
char* msg;
|
||||||
@ -78,6 +80,10 @@ package sds
|
|||||||
SdsSetEventCallback(rmCtx, (SdsCallBack) sdsGlobalEventCallback, rmCtx);
|
SdsSetEventCallback(rmCtx, (SdsCallBack) sdsGlobalEventCallback, rmCtx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void cGoSdsSetRetrievalHintProvider(void* rmCtx) {
|
||||||
|
SdsSetRetrievalHintProvider(rmCtx, (SdsRetrievalHintProvider) sdsGlobalRetrievalHintProvider, rmCtx);
|
||||||
|
}
|
||||||
|
|
||||||
static void cGoSdsCleanupReliabilityManager(void* rmCtx, void* resp) {
|
static void cGoSdsCleanupReliabilityManager(void* rmCtx, void* resp) {
|
||||||
SdsCleanupReliabilityManager(rmCtx, (SdsCallBack) SdsGoCallback, resp);
|
SdsCleanupReliabilityManager(rmCtx, (SdsCallBack) SdsGoCallback, resp);
|
||||||
}
|
}
|
||||||
@ -158,8 +164,9 @@ func SdsGoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) {
|
|||||||
type EventCallbacks struct {
|
type EventCallbacks struct {
|
||||||
OnMessageReady func(messageId MessageID, channelId string)
|
OnMessageReady func(messageId MessageID, channelId string)
|
||||||
OnMessageSent func(messageId MessageID, channelId string)
|
OnMessageSent func(messageId MessageID, channelId string)
|
||||||
OnMissingDependencies func(messageId MessageID, missingDeps []MessageID, channelId string)
|
OnMissingDependencies func(messageId MessageID, missingDeps []HistoryEntry, channelId string)
|
||||||
OnPeriodicSync func()
|
OnPeriodicSync func()
|
||||||
|
RetrievalHintProvider func(messageId MessageID) []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReliabilityManager represents an instance of a nim-sds ReliabilityManager
|
// ReliabilityManager represents an instance of a nim-sds ReliabilityManager
|
||||||
@ -189,6 +196,7 @@ func NewReliabilityManager() (*ReliabilityManager, error) {
|
|||||||
|
|
||||||
C.cGoSdsSetEventCallback(rm.rmCtx)
|
C.cGoSdsSetEventCallback(rm.rmCtx)
|
||||||
registerReliabilityManager(rm)
|
registerReliabilityManager(rm)
|
||||||
|
C.cGoSdsSetRetrievalHintProvider(rm.rmCtx)
|
||||||
|
|
||||||
Debug("Successfully created Reliability Manager")
|
Debug("Successfully created Reliability Manager")
|
||||||
return rm, nil
|
return rm, nil
|
||||||
@ -246,14 +254,33 @@ type msgEvent struct {
|
|||||||
|
|
||||||
type missingDepsEvent struct {
|
type missingDepsEvent struct {
|
||||||
MessageId MessageID `json:"messageId"`
|
MessageId MessageID `json:"messageId"`
|
||||||
MissingDeps []MessageID `json:"missingDeps"`
|
MissingDeps []HistoryEntry `json:"missingDeps"`
|
||||||
ChannelId string `json:"channelId"`
|
ChannelId string `json:"channelId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) {
|
func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) {
|
||||||
rm.callbacks = callbacks
|
rm.callbacks = callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//export sdsGlobalRetrievalHintProvider
|
||||||
|
func sdsGlobalRetrievalHintProvider(messageId *C.char, hint **C.char, hintLen *C.size_t, userData unsafe.Pointer) {
|
||||||
|
msgId := C.GoString(messageId)
|
||||||
|
Debug("sdsGlobalRetrievalHintProvider called for messageId: %s", msgId)
|
||||||
|
rm, ok := rmRegistry[userData]
|
||||||
|
if ok && rm.callbacks.RetrievalHintProvider != nil {
|
||||||
|
Debug("Found RM and callback, calling provider")
|
||||||
|
hintBytes := rm.callbacks.RetrievalHintProvider(MessageID(msgId))
|
||||||
|
Debug("Provider returned hint of length: %d", len(hintBytes))
|
||||||
|
if len(hintBytes) > 0 {
|
||||||
|
*hint = (*C.char)(C.CBytes(hintBytes))
|
||||||
|
*hintLen = C.size_t(len(hintBytes))
|
||||||
|
Debug("Set hint in C memory: %s", string(hintBytes))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Debug("No RM found or no callback registered")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (rm *ReliabilityManager) OnEvent(eventStr string) {
|
func (rm *ReliabilityManager) OnEvent(eventStr string) {
|
||||||
|
|
||||||
jsonEvent := jsonEvent{}
|
jsonEvent := jsonEvent{}
|
||||||
@ -467,10 +494,11 @@ func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedM
|
|||||||
}
|
}
|
||||||
Debug("Successfully unwrapped message")
|
Debug("Successfully unwrapped message")
|
||||||
|
|
||||||
unwrappedMessage := UnwrappedMessage{}
|
Debug("Unwrapped message JSON: %s", resStr)
|
||||||
|
var unwrappedMessage UnwrappedMessage
|
||||||
err := json.Unmarshal([]byte(resStr), &unwrappedMessage)
|
err := json.Unmarshal([]byte(resStr), &unwrappedMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Error("Failed to unmarshal unwrapped message")
|
Error("Failed to unmarshal unwrapped message: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -80,7 +80,7 @@ func TestDependencies(t *testing.T) {
|
|||||||
|
|
||||||
foundDep1 := false
|
foundDep1 := false
|
||||||
for _, dep := range *unwrappedMessage2.MissingDeps {
|
for _, dep := range *unwrappedMessage2.MissingDeps {
|
||||||
if dep == msgID1 {
|
if dep.MessageID == msgID1 {
|
||||||
foundDep1 = true
|
foundDep1 = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -239,12 +239,15 @@ func TestCallback_OnMissingDependencies(t *testing.T) {
|
|||||||
var cbMutex sync.Mutex
|
var cbMutex sync.Mutex
|
||||||
|
|
||||||
callbacks := EventCallbacks{
|
callbacks := EventCallbacks{
|
||||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
|
||||||
require.Equal(t, channelID, chId)
|
require.Equal(t, channelID, chId)
|
||||||
cbMutex.Lock()
|
cbMutex.Lock()
|
||||||
missingCalled = true
|
missingCalled = true
|
||||||
missingMsgID = messageId
|
missingMsgID = messageId
|
||||||
missingDepsList = missingDeps // Copy slice
|
missingDepsList = make([]MessageID, len(missingDeps))
|
||||||
|
for i, dep := range missingDeps {
|
||||||
|
missingDepsList[i] = dep.MessageID
|
||||||
|
}
|
||||||
cbMutex.Unlock()
|
cbMutex.Unlock()
|
||||||
wg.Done()
|
wg.Done()
|
||||||
},
|
},
|
||||||
@ -383,7 +386,7 @@ func TestCallbacks_Combined(t *testing.T) {
|
|||||||
// are typically relevant to the Sender. We don't expect this.
|
// are typically relevant to the Sender. We don't expect this.
|
||||||
t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId)
|
t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId)
|
||||||
},
|
},
|
||||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
|
||||||
// This callback is registered on Receiver, used for receiverRm2 below
|
// This callback is registered on Receiver, used for receiverRm2 below
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -404,7 +407,7 @@ func TestCallbacks_Combined(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
|
||||||
// Not expected on sender
|
// Not expected on sender
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -447,11 +450,16 @@ func TestCallbacks_Combined(t *testing.T) {
|
|||||||
defer receiverRm2.Cleanup()
|
defer receiverRm2.Cleanup()
|
||||||
|
|
||||||
callbacksReceiver2 := EventCallbacks{
|
callbacksReceiver2 := EventCallbacks{
|
||||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
|
||||||
require.Equal(t, channelID, chId)
|
require.Equal(t, channelID, chId)
|
||||||
if messageId == msgID3 {
|
if messageId == msgID3 {
|
||||||
|
// Convert []HistoryEntry to []MessageID for the channel
|
||||||
|
deps := make([]MessageID, len(missingDeps))
|
||||||
|
for i, d := range missingDeps {
|
||||||
|
deps[i] = d.MessageID
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case missingChan <- missingDeps:
|
case missingChan <- deps:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -677,3 +685,61 @@ func TestMultiChannelCallbacks(t *testing.T) {
|
|||||||
require.Equal(t, channel2, readyMessages[ackID2_ch2], "OnMessageReady for ack2 has incorrect channel")
|
require.Equal(t, channel2, readyMessages[ackID2_ch2], "OnMessageReady for ack2 has incorrect channel")
|
||||||
require.Len(t, readyMessages, 2, "Expected exactly 2 ready messages")
|
require.Len(t, readyMessages, 2, "Expected exactly 2 ready messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRetrievalHints(t *testing.T) {
|
||||||
|
rm, err := NewReliabilityManager()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rm.Cleanup()
|
||||||
|
|
||||||
|
channelID := "test-retrieval-hints"
|
||||||
|
|
||||||
|
// Set a retrieval hint provider
|
||||||
|
rm.RegisterCallbacks(EventCallbacks{
|
||||||
|
RetrievalHintProvider: func(messageId MessageID) []byte {
|
||||||
|
return []byte("hint-for-" + messageId)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 1. Send a message to populate the history
|
||||||
|
payload1 := []byte("message one")
|
||||||
|
msgID1 := MessageID("msg-hint-1")
|
||||||
|
wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1, channelID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 2. Receive the message to add it to history
|
||||||
|
_, err = rm.UnwrapReceivedMessage(wrappedMsg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 3. Send a second message, which will include the first in its causal history
|
||||||
|
payload2 := []byte("message two")
|
||||||
|
msgID2 := MessageID("msg-hint-2")
|
||||||
|
wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2, channelID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 4. Unwrap the second message to inspect its causal history
|
||||||
|
// We need a new RM to avoid acknowledging the message
|
||||||
|
rm2, err := NewReliabilityManager()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rm2.Cleanup()
|
||||||
|
|
||||||
|
rm2.RegisterCallbacks(EventCallbacks{
|
||||||
|
RetrievalHintProvider: func(messageId MessageID) []byte {
|
||||||
|
return []byte("hint-for-" + messageId)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
unwrappedMsg2, err := rm2.UnwrapReceivedMessage(wrappedMsg2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 5. Check that the causal history contains the retrieval hint
|
||||||
|
require.Greater(t, len(*unwrappedMsg2.MissingDeps), 0, "Expected missing dependencies")
|
||||||
|
foundDep := false
|
||||||
|
for _, dep := range *unwrappedMsg2.MissingDeps {
|
||||||
|
if dep.MessageID == msgID1 {
|
||||||
|
foundDep = true
|
||||||
|
require.Equal(t, []byte("hint-for-"+msgID1), dep.RetrievalHint, "Retrieval hint does not match")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.True(t, foundDep, "Expected to find dependency %s", msgID1)
|
||||||
|
}
|
||||||
|
|||||||
13
sds/types.go
13
sds/types.go
@ -2,8 +2,13 @@ package sds
|
|||||||
|
|
||||||
type MessageID string
|
type MessageID string
|
||||||
|
|
||||||
type UnwrappedMessage struct {
|
type HistoryEntry struct {
|
||||||
Message *[]byte `json:"message"`
|
MessageID MessageID `json:"messageId"`
|
||||||
MissingDeps *[]MessageID `json:"missingDeps"`
|
RetrievalHint []byte `json:"retrievalHint"`
|
||||||
ChannelId *string `json:"channelId"`
|
}
|
||||||
|
|
||||||
|
type UnwrappedMessage struct {
|
||||||
|
Message *[]byte `json:"message"`
|
||||||
|
MissingDeps *[]HistoryEntry `json:"missingDeps"`
|
||||||
|
ChannelId *string `json:"channelId"`
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user