mirror of
https://github.com/logos-messaging/sds-go-bindings.git
synced 2026-01-04 07:03:09 +00:00
feat: Add Support for Multiple Channels in Single Reliability Manager (#5)
This commit is contained in:
parent
be7f61d809
commit
05e797c76b
80
sds/sds.go
80
sds/sds.go
@ -56,9 +56,9 @@ package sds
|
|||||||
// resp must be set != NULL in case interest on retrieving data from the callback
|
// resp must be set != NULL in case interest on retrieving data from the callback
|
||||||
void SdsGoCallback(int ret, char* msg, size_t len, void* resp);
|
void SdsGoCallback(int ret, char* msg, size_t len, void* resp);
|
||||||
|
|
||||||
static void* cGoSdsNewReliabilityManager(const char* channelId, void* resp) {
|
static void* cGoSdsNewReliabilityManager(void* resp) {
|
||||||
// We pass NULL because we are not interested in retrieving data from this callback
|
// We pass NULL because we are not interested in retrieving data from this callback
|
||||||
void* ret = SdsNewReliabilityManager(channelId, (SdsCallBack) SdsGoCallback, resp);
|
void* ret = SdsNewReliabilityManager((SdsCallBack) SdsGoCallback, resp);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,16 +87,18 @@ package sds
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void cGoSdsWrapOutgoingMessage(void* rmCtx,
|
static void cGoSdsWrapOutgoingMessage(void* rmCtx,
|
||||||
void* message,
|
void* message,
|
||||||
size_t messageLen,
|
size_t messageLen,
|
||||||
const char* messageId,
|
const char* messageId,
|
||||||
void* resp) {
|
const char* channelId,
|
||||||
|
void* resp) {
|
||||||
SdsWrapOutgoingMessage(rmCtx,
|
SdsWrapOutgoingMessage(rmCtx,
|
||||||
message,
|
message,
|
||||||
messageLen,
|
messageLen,
|
||||||
messageId,
|
messageId,
|
||||||
(SdsCallBack) SdsGoCallback,
|
channelId,
|
||||||
resp);
|
(SdsCallBack) SdsGoCallback,
|
||||||
|
resp);
|
||||||
}
|
}
|
||||||
static void cGoSdsUnwrapReceivedMessage(void* rmCtx,
|
static void cGoSdsUnwrapReceivedMessage(void* rmCtx,
|
||||||
void* message,
|
void* message,
|
||||||
@ -110,14 +112,16 @@ package sds
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void cGoSdsMarkDependenciesMet(void* rmCtx,
|
static void cGoSdsMarkDependenciesMet(void* rmCtx,
|
||||||
char** messageIDs,
|
char** messageIDs,
|
||||||
size_t count,
|
size_t count,
|
||||||
void* resp) {
|
const char* channelId,
|
||||||
|
void* resp) {
|
||||||
SdsMarkDependenciesMet(rmCtx,
|
SdsMarkDependenciesMet(rmCtx,
|
||||||
messageIDs,
|
messageIDs,
|
||||||
count,
|
count,
|
||||||
(SdsCallBack) SdsGoCallback,
|
channelId,
|
||||||
resp);
|
(SdsCallBack) SdsGoCallback,
|
||||||
|
resp);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void cGoSdsStartPeriodicTasks(void* rmCtx, void* resp) {
|
static void cGoSdsStartPeriodicTasks(void* rmCtx, void* resp) {
|
||||||
@ -152,31 +156,25 @@ func SdsGoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EventCallbacks struct {
|
type EventCallbacks struct {
|
||||||
OnMessageReady func(messageId MessageID)
|
OnMessageReady func(messageId MessageID, channelId string)
|
||||||
OnMessageSent func(messageId MessageID)
|
OnMessageSent func(messageId MessageID, channelId string)
|
||||||
OnMissingDependencies func(messageId MessageID, missingDeps []MessageID)
|
OnMissingDependencies func(messageId MessageID, missingDeps []MessageID, channelId string)
|
||||||
OnPeriodicSync func()
|
OnPeriodicSync func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReliabilityManager represents an instance of a nim-sds ReliabilityManager
|
// ReliabilityManager represents an instance of a nim-sds ReliabilityManager
|
||||||
type ReliabilityManager struct {
|
type ReliabilityManager struct {
|
||||||
rmCtx unsafe.Pointer
|
rmCtx unsafe.Pointer
|
||||||
channelId string
|
|
||||||
callbacks EventCallbacks
|
callbacks EventCallbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewReliabilityManager(channelId string) (*ReliabilityManager, error) {
|
func NewReliabilityManager() (*ReliabilityManager, error) {
|
||||||
Debug("Creating new Reliability Manager")
|
Debug("Creating new Reliability Manager")
|
||||||
rm := &ReliabilityManager{
|
rm := &ReliabilityManager{}
|
||||||
channelId: channelId,
|
|
||||||
}
|
|
||||||
|
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
|
|
||||||
var cChannelId = C.CString(string(channelId))
|
|
||||||
var resp = C.allocResp(unsafe.Pointer(&wg))
|
var resp = C.allocResp(unsafe.Pointer(&wg))
|
||||||
|
|
||||||
defer C.free(unsafe.Pointer(cChannelId))
|
|
||||||
defer C.freeResp(resp)
|
defer C.freeResp(resp)
|
||||||
|
|
||||||
if C.getRet(resp) != C.RET_OK {
|
if C.getRet(resp) != C.RET_OK {
|
||||||
@ -186,7 +184,7 @@ func NewReliabilityManager(channelId string) (*ReliabilityManager, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
rm.rmCtx = C.cGoSdsNewReliabilityManager(cChannelId, resp)
|
rm.rmCtx = C.cGoSdsNewReliabilityManager(resp)
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
C.cGoSdsSetEventCallback(rm.rmCtx)
|
C.cGoSdsSetEventCallback(rm.rmCtx)
|
||||||
@ -243,11 +241,13 @@ type jsonEvent struct {
|
|||||||
|
|
||||||
type msgEvent struct {
|
type msgEvent struct {
|
||||||
MessageId MessageID `json:"messageId"`
|
MessageId MessageID `json:"messageId"`
|
||||||
|
ChannelId string `json:"channelId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type missingDepsEvent struct {
|
type missingDepsEvent struct {
|
||||||
MessageId MessageID `json:"messageId"`
|
MessageId MessageID `json:"messageId"`
|
||||||
MissingDeps []MessageID `json:"missingDeps"`
|
MissingDeps []MessageID `json:"missingDeps"`
|
||||||
|
ChannelId string `json:"channelId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) {
|
func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) {
|
||||||
@ -288,7 +288,7 @@ func (rm *ReliabilityManager) parseMessageReadyEvent(eventStr string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rm.callbacks.OnMessageReady != nil {
|
if rm.callbacks.OnMessageReady != nil {
|
||||||
rm.callbacks.OnMessageReady(msgEvent.MessageId)
|
rm.callbacks.OnMessageReady(msgEvent.MessageId, msgEvent.ChannelId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -301,7 +301,7 @@ func (rm *ReliabilityManager) parseMessageSentEvent(eventStr string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rm.callbacks.OnMessageSent != nil {
|
if rm.callbacks.OnMessageSent != nil {
|
||||||
rm.callbacks.OnMessageSent(msgEvent.MessageId)
|
rm.callbacks.OnMessageSent(msgEvent.MessageId, msgEvent.ChannelId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -314,7 +314,7 @@ func (rm *ReliabilityManager) parseMissingDepsEvent(eventStr string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rm.callbacks.OnMissingDependencies != nil {
|
if rm.callbacks.OnMissingDependencies != nil {
|
||||||
rm.callbacks.OnMissingDependencies(missingDepsEvent.MessageId, missingDepsEvent.MissingDeps)
|
rm.callbacks.OnMissingDependencies(missingDepsEvent.MessageId, missingDepsEvent.MissingDeps, missingDepsEvent.ChannelId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -375,7 +375,7 @@ func (rm *ReliabilityManager) Reset() error {
|
|||||||
return errors.New(errMsg)
|
return errors.New(errMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId MessageID) ([]byte, error) {
|
func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId MessageID, channelId string) ([]byte, error) {
|
||||||
if rm == nil {
|
if rm == nil {
|
||||||
err := errors.New("reliability manager is nil in WrapOutgoingMessage")
|
err := errors.New("reliability manager is nil in WrapOutgoingMessage")
|
||||||
Error("Failed to wrap outgoing message %v", err)
|
Error("Failed to wrap outgoing message %v", err)
|
||||||
@ -400,8 +400,11 @@ func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId Mess
|
|||||||
}
|
}
|
||||||
cMessageLen := C.size_t(len(message))
|
cMessageLen := C.size_t(len(message))
|
||||||
|
|
||||||
|
cChannelId := C.CString(channelId)
|
||||||
|
defer C.free(unsafe.Pointer(cChannelId))
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
C.cGoSdsWrapOutgoingMessage(rm.rmCtx, cMessagePtr, cMessageLen, cMessageId, resp)
|
C.cGoSdsWrapOutgoingMessage(rm.rmCtx, cMessagePtr, cMessageLen, cMessageId, cChannelId, resp)
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
if C.getRet(resp) == C.RET_OK {
|
if C.getRet(resp) == C.RET_OK {
|
||||||
@ -481,7 +484,7 @@ func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedM
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkDependenciesMet informs the library that dependencies are met
|
// MarkDependenciesMet informs the library that dependencies are met
|
||||||
func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID) error {
|
func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID, channelId string) error {
|
||||||
if rm == nil {
|
if rm == nil {
|
||||||
err := errors.New("reliability manager is nil in MarkDependenciesMet")
|
err := errors.New("reliability manager is nil in MarkDependenciesMet")
|
||||||
Error("Failed to mark dependencies met %v", err)
|
Error("Failed to mark dependencies met %v", err)
|
||||||
@ -512,8 +515,11 @@ func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
cChannelId := C.CString(channelId)
|
||||||
|
defer C.free(unsafe.Pointer(cChannelId))
|
||||||
|
|
||||||
// Pass the pointer variable (cMessageIDsPtr) directly, which is of type **C.char
|
// Pass the pointer variable (cMessageIDsPtr) directly, which is of type **C.char
|
||||||
C.cGoSdsMarkDependenciesMet(rm.rmCtx, cMessageIDsPtr, C.size_t(len(messageIDs)), resp)
|
C.cGoSdsMarkDependenciesMet(rm.rmCtx, cMessageIDsPtr, C.size_t(len(messageIDs)), cChannelId, resp)
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
if C.getRet(resp) == C.RET_OK {
|
if C.getRet(resp) == C.RET_OK {
|
||||||
|
|||||||
236
sds/sds_test.go
236
sds/sds_test.go
@ -10,8 +10,7 @@ import (
|
|||||||
|
|
||||||
// Test basic creation, cleanup, and reset
|
// Test basic creation, cleanup, and reset
|
||||||
func TestLifecycle(t *testing.T) {
|
func TestLifecycle(t *testing.T) {
|
||||||
channelID := "test-lifecycle"
|
rm, err := NewReliabilityManager()
|
||||||
rm, err := NewReliabilityManager(channelID)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, rm, "Expected ReliabilityManager to be not nil")
|
require.NotNil(t, rm, "Expected ReliabilityManager to be not nil")
|
||||||
|
|
||||||
@ -23,15 +22,15 @@ func TestLifecycle(t *testing.T) {
|
|||||||
|
|
||||||
// Test wrapping and unwrapping a simple message
|
// Test wrapping and unwrapping a simple message
|
||||||
func TestWrapUnwrap(t *testing.T) {
|
func TestWrapUnwrap(t *testing.T) {
|
||||||
channelID := "test-wrap-unwrap"
|
rm, err := NewReliabilityManager()
|
||||||
rm, err := NewReliabilityManager(channelID)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer rm.Cleanup()
|
defer rm.Cleanup()
|
||||||
|
|
||||||
|
channelID := "test-wrap-unwrap"
|
||||||
originalPayload := []byte("hello reliability")
|
originalPayload := []byte("hello reliability")
|
||||||
messageID := MessageID("msg-wrap-1")
|
messageID := MessageID("msg-wrap-1")
|
||||||
|
|
||||||
wrappedMsg, err := rm.WrapOutgoingMessage(originalPayload, messageID)
|
wrappedMsg, err := rm.WrapOutgoingMessage(originalPayload, messageID, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Greater(t, len(wrappedMsg), 0, "Expected non-empty wrapped message")
|
require.Greater(t, len(wrappedMsg), 0, "Expected non-empty wrapped message")
|
||||||
@ -46,15 +45,16 @@ func TestWrapUnwrap(t *testing.T) {
|
|||||||
|
|
||||||
// Test dependency handling
|
// Test dependency handling
|
||||||
func TestDependencies(t *testing.T) {
|
func TestDependencies(t *testing.T) {
|
||||||
channelID := "test-deps"
|
rm, err := NewReliabilityManager()
|
||||||
rm, err := NewReliabilityManager(channelID)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer rm.Cleanup()
|
defer rm.Cleanup()
|
||||||
|
|
||||||
|
channelID := "test-deps"
|
||||||
|
|
||||||
// 1. Send message 1 (will become a dependency)
|
// 1. Send message 1 (will become a dependency)
|
||||||
payload1 := []byte("message one")
|
payload1 := []byte("message one")
|
||||||
msgID1 := MessageID("msg-dep-1")
|
msgID1 := MessageID("msg-dep-1")
|
||||||
wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1)
|
wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Simulate receiving msg1 to add it to history (implicitly acknowledges it)
|
// Simulate receiving msg1 to add it to history (implicitly acknowledges it)
|
||||||
@ -64,11 +64,11 @@ func TestDependencies(t *testing.T) {
|
|||||||
// 2. Send message 2 (depends on message 1 implicitly via causal history)
|
// 2. Send message 2 (depends on message 1 implicitly via causal history)
|
||||||
payload2 := []byte("message two")
|
payload2 := []byte("message two")
|
||||||
msgID2 := MessageID("msg-dep-2")
|
msgID2 := MessageID("msg-dep-2")
|
||||||
wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2)
|
wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 3. Create a new manager to simulate a different peer receiving msg2 without msg1
|
// 3. Create a new manager to simulate a different peer receiving msg2 without msg1
|
||||||
rm2, err := NewReliabilityManager(channelID) // Same channel ID
|
rm2, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer rm2.Cleanup()
|
defer rm2.Cleanup()
|
||||||
|
|
||||||
@ -88,28 +88,29 @@ func TestDependencies(t *testing.T) {
|
|||||||
require.True(t, foundDep1, "Expected missing dependency %q, got %v", msgID1, *unwrappedMessage2.MissingDeps)
|
require.True(t, foundDep1, "Expected missing dependency %q, got %v", msgID1, *unwrappedMessage2.MissingDeps)
|
||||||
|
|
||||||
// 5. Mark the dependency as met
|
// 5. Mark the dependency as met
|
||||||
err = rm2.MarkDependenciesMet([]MessageID{msgID1})
|
err = rm2.MarkDependenciesMet([]MessageID{msgID1}, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test OnMessageReady callback
|
// Test OnMessageReady callback
|
||||||
func TestCallback_OnMessageReady(t *testing.T) {
|
func TestCallback_OnMessageReady(t *testing.T) {
|
||||||
channelID := "test-cb-ready"
|
|
||||||
|
|
||||||
// Create sender and receiver RMs
|
// Create sender and receiver RMs
|
||||||
senderRm, err := NewReliabilityManager(channelID)
|
senderRm, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer senderRm.Cleanup()
|
defer senderRm.Cleanup()
|
||||||
|
|
||||||
receiverRm, err := NewReliabilityManager(channelID)
|
receiverRm, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer receiverRm.Cleanup()
|
defer receiverRm.Cleanup()
|
||||||
|
|
||||||
|
channelID := "test-cb-ready"
|
||||||
|
|
||||||
// Use a channel for signaling
|
// Use a channel for signaling
|
||||||
readyChan := make(chan MessageID, 1)
|
readyChan := make(chan MessageID, 1)
|
||||||
|
|
||||||
callbacks := EventCallbacks{
|
callbacks := EventCallbacks{
|
||||||
OnMessageReady: func(messageId MessageID) {
|
OnMessageReady: func(messageId MessageID, chId string) {
|
||||||
|
require.Equal(t, channelID, chId)
|
||||||
// Non-blocking send to channel
|
// Non-blocking send to channel
|
||||||
select {
|
select {
|
||||||
case readyChan <- messageId:
|
case readyChan <- messageId:
|
||||||
@ -127,7 +128,7 @@ func TestCallback_OnMessageReady(t *testing.T) {
|
|||||||
msgID := MessageID("cb-ready-1")
|
msgID := MessageID("cb-ready-1")
|
||||||
|
|
||||||
// Wrap on sender
|
// Wrap on sender
|
||||||
wrappedMsg, err := senderRm.WrapOutgoingMessage(payload, msgID)
|
wrappedMsg, err := senderRm.WrapOutgoingMessage(payload, msgID, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Unwrap on receiver
|
// Unwrap on receiver
|
||||||
@ -149,24 +150,25 @@ func TestCallback_OnMessageReady(t *testing.T) {
|
|||||||
|
|
||||||
// Test OnMessageSent callback (via causal history ACK)
|
// Test OnMessageSent callback (via causal history ACK)
|
||||||
func TestCallback_OnMessageSent(t *testing.T) {
|
func TestCallback_OnMessageSent(t *testing.T) {
|
||||||
channelID := "test-cb-sent"
|
|
||||||
|
|
||||||
// Create two RMs
|
// Create two RMs
|
||||||
rm1, err := NewReliabilityManager(channelID)
|
rm1, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer rm1.Cleanup()
|
defer rm1.Cleanup()
|
||||||
|
|
||||||
rm2, err := NewReliabilityManager(channelID)
|
rm2, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer rm2.Cleanup()
|
defer rm2.Cleanup()
|
||||||
|
|
||||||
|
channelID := "test-cb-sent"
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
sentCalled := false
|
sentCalled := false
|
||||||
var sentMsgID MessageID
|
var sentMsgID MessageID
|
||||||
var cbMutex sync.Mutex
|
var cbMutex sync.Mutex
|
||||||
|
|
||||||
callbacks := EventCallbacks{
|
callbacks := EventCallbacks{
|
||||||
OnMessageSent: func(messageId MessageID) {
|
OnMessageSent: func(messageId MessageID, chId string) {
|
||||||
|
require.Equal(t, channelID, chId)
|
||||||
cbMutex.Lock()
|
cbMutex.Lock()
|
||||||
sentCalled = true
|
sentCalled = true
|
||||||
sentMsgID = messageId
|
sentMsgID = messageId
|
||||||
@ -184,7 +186,7 @@ func TestCallback_OnMessageSent(t *testing.T) {
|
|||||||
// 1. rm1 sends msg1
|
// 1. rm1 sends msg1
|
||||||
payload1 := []byte("sent test 1")
|
payload1 := []byte("sent test 1")
|
||||||
msgID1 := MessageID("cb-sent-1")
|
msgID1 := MessageID("cb-sent-1")
|
||||||
wrappedMsg1, err := rm1.WrapOutgoingMessage(payload1, msgID1)
|
wrappedMsg1, err := rm1.WrapOutgoingMessage(payload1, msgID1, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Note: msg1 is now in rm1's outgoing buffer
|
// Note: msg1 is now in rm1's outgoing buffer
|
||||||
|
|
||||||
@ -195,7 +197,7 @@ func TestCallback_OnMessageSent(t *testing.T) {
|
|||||||
// 3. rm2 sends msg2 (will include msg1 in causal history)
|
// 3. rm2 sends msg2 (will include msg1 in causal history)
|
||||||
payload2 := []byte("sent test 2")
|
payload2 := []byte("sent test 2")
|
||||||
msgID2 := MessageID("cb-sent-2")
|
msgID2 := MessageID("cb-sent-2")
|
||||||
wrappedMsg2, err := rm2.WrapOutgoingMessage(payload2, msgID2)
|
wrappedMsg2, err := rm2.WrapOutgoingMessage(payload2, msgID2, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 4. rm1 receives msg2 (should trigger ACK for msg1)
|
// 4. rm1 receives msg2 (should trigger ACK for msg1)
|
||||||
@ -219,17 +221,17 @@ func TestCallback_OnMessageSent(t *testing.T) {
|
|||||||
|
|
||||||
// Test OnMissingDependencies callback
|
// Test OnMissingDependencies callback
|
||||||
func TestCallback_OnMissingDependencies(t *testing.T) {
|
func TestCallback_OnMissingDependencies(t *testing.T) {
|
||||||
channelID := "test-cb-missing"
|
|
||||||
|
|
||||||
// Use separate sender/receiver RMs explicitly
|
// Use separate sender/receiver RMs explicitly
|
||||||
senderRm, err := NewReliabilityManager(channelID)
|
senderRm, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer senderRm.Cleanup()
|
defer senderRm.Cleanup()
|
||||||
|
|
||||||
receiverRm, err := NewReliabilityManager(channelID)
|
receiverRm, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer receiverRm.Cleanup()
|
defer receiverRm.Cleanup()
|
||||||
|
|
||||||
|
channelID := "test-cb-missing"
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
missingCalled := false
|
missingCalled := false
|
||||||
var missingMsgID MessageID
|
var missingMsgID MessageID
|
||||||
@ -237,7 +239,8 @@ func TestCallback_OnMissingDependencies(t *testing.T) {
|
|||||||
var cbMutex sync.Mutex
|
var cbMutex sync.Mutex
|
||||||
|
|
||||||
callbacks := EventCallbacks{
|
callbacks := EventCallbacks{
|
||||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
|
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
||||||
|
require.Equal(t, channelID, chId)
|
||||||
cbMutex.Lock()
|
cbMutex.Lock()
|
||||||
missingCalled = true
|
missingCalled = true
|
||||||
missingMsgID = messageId
|
missingMsgID = messageId
|
||||||
@ -256,13 +259,13 @@ func TestCallback_OnMissingDependencies(t *testing.T) {
|
|||||||
// 1. Sender sends msg1
|
// 1. Sender sends msg1
|
||||||
payload1 := []byte("missing test 1")
|
payload1 := []byte("missing test 1")
|
||||||
msgID1 := MessageID("cb-miss-1")
|
msgID1 := MessageID("cb-miss-1")
|
||||||
_, err = senderRm.WrapOutgoingMessage(payload1, msgID1)
|
_, err = senderRm.WrapOutgoingMessage(payload1, msgID1, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 2. Sender sends msg2 (depends on msg1)
|
// 2. Sender sends msg2 (depends on msg1)
|
||||||
payload2 := []byte("missing test 2")
|
payload2 := []byte("missing test 2")
|
||||||
msgID2 := MessageID("cb-miss-2")
|
msgID2 := MessageID("cb-miss-2")
|
||||||
wrappedMsg2, err := senderRm.WrapOutgoingMessage(payload2, msgID2)
|
wrappedMsg2, err := senderRm.WrapOutgoingMessage(payload2, msgID2, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 3. Receiver receives msg2 (haven't seen msg1)
|
// 3. Receiver receives msg2 (haven't seen msg1)
|
||||||
@ -295,8 +298,7 @@ func TestCallback_OnMissingDependencies(t *testing.T) {
|
|||||||
|
|
||||||
// Test OnPeriodicSync callback
|
// Test OnPeriodicSync callback
|
||||||
func TestCallback_OnPeriodicSync(t *testing.T) {
|
func TestCallback_OnPeriodicSync(t *testing.T) {
|
||||||
channelID := "test-cb-sync"
|
rm, err := NewReliabilityManager()
|
||||||
rm, err := NewReliabilityManager(channelID)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer rm.Cleanup()
|
defer rm.Cleanup()
|
||||||
|
|
||||||
@ -341,17 +343,17 @@ func TestCallback_OnPeriodicSync(t *testing.T) {
|
|||||||
|
|
||||||
// Combined Test for multiple callbacks
|
// Combined Test for multiple callbacks
|
||||||
func TestCallbacks_Combined(t *testing.T) {
|
func TestCallbacks_Combined(t *testing.T) {
|
||||||
channelID := "test-cb-combined"
|
|
||||||
|
|
||||||
// Create sender and receiver RMs
|
// Create sender and receiver RMs
|
||||||
senderRm, err := NewReliabilityManager(channelID)
|
senderRm, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer senderRm.Cleanup()
|
defer senderRm.Cleanup()
|
||||||
|
|
||||||
receiverRm, err := NewReliabilityManager(channelID)
|
receiverRm, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer receiverRm.Cleanup()
|
defer receiverRm.Cleanup()
|
||||||
|
|
||||||
|
channelID := "test-cb-combined"
|
||||||
|
|
||||||
// Channels for synchronization
|
// Channels for synchronization
|
||||||
readyChan1 := make(chan bool, 1)
|
readyChan1 := make(chan bool, 1)
|
||||||
sentChan1 := make(chan bool, 1)
|
sentChan1 := make(chan bool, 1)
|
||||||
@ -363,7 +365,8 @@ func TestCallbacks_Combined(t *testing.T) {
|
|||||||
var cbMutex sync.Mutex
|
var cbMutex sync.Mutex
|
||||||
|
|
||||||
callbacksReceiver := EventCallbacks{
|
callbacksReceiver := EventCallbacks{
|
||||||
OnMessageReady: func(messageId MessageID) {
|
OnMessageReady: func(messageId MessageID, chId string) {
|
||||||
|
require.Equal(t, channelID, chId)
|
||||||
cbMutex.Lock()
|
cbMutex.Lock()
|
||||||
receivedReady[messageId] = true
|
receivedReady[messageId] = true
|
||||||
cbMutex.Unlock()
|
cbMutex.Unlock()
|
||||||
@ -375,21 +378,22 @@ func TestCallbacks_Combined(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
OnMessageSent: func(messageId MessageID) {
|
OnMessageSent: func(messageId MessageID, chId string) {
|
||||||
// This callback is registered on Receiver, but Sent events
|
// This callback is registered on Receiver, but Sent events
|
||||||
// are typically relevant to the Sender. We don't expect this.
|
// are typically relevant to the Sender. We don't expect this.
|
||||||
t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId)
|
t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId)
|
||||||
},
|
},
|
||||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
|
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
||||||
// This callback is registered on Receiver, used for receiverRm2 below
|
// This callback is registered on Receiver, used for receiverRm2 below
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
callbacksSender := EventCallbacks{
|
callbacksSender := EventCallbacks{
|
||||||
OnMessageReady: func(messageId MessageID) {
|
OnMessageReady: func(messageId MessageID, chId string) {
|
||||||
// Not expected on sender in this test flow
|
// Not expected on sender in this test flow
|
||||||
},
|
},
|
||||||
OnMessageSent: func(messageId MessageID) {
|
OnMessageSent: func(messageId MessageID, chId string) {
|
||||||
|
require.Equal(t, channelID, chId)
|
||||||
cbMutex.Lock()
|
cbMutex.Lock()
|
||||||
receivedSent[messageId] = true
|
receivedSent[messageId] = true
|
||||||
cbMutex.Unlock()
|
cbMutex.Unlock()
|
||||||
@ -400,7 +404,7 @@ func TestCallbacks_Combined(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
|
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
||||||
// Not expected on sender
|
// Not expected on sender
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -418,7 +422,7 @@ func TestCallbacks_Combined(t *testing.T) {
|
|||||||
payload3 := []byte("combined test 3")
|
payload3 := []byte("combined test 3")
|
||||||
|
|
||||||
// 1. Sender sends msg1
|
// 1. Sender sends msg1
|
||||||
wrappedMsg1, err := senderRm.WrapOutgoingMessage(payload1, msgID1)
|
wrappedMsg1, err := senderRm.WrapOutgoingMessage(payload1, msgID1, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 2. Receiver receives msg1
|
// 2. Receiver receives msg1
|
||||||
@ -426,7 +430,7 @@ func TestCallbacks_Combined(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 3. Receiver sends msg2 (depends on msg1 implicitly via state)
|
// 3. Receiver sends msg2 (depends on msg1 implicitly via state)
|
||||||
wrappedMsg2, err := receiverRm.WrapOutgoingMessage(payload2, msgID2)
|
wrappedMsg2, err := receiverRm.WrapOutgoingMessage(payload2, msgID2, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 4. Sender receives msg2 from Receiver (acks msg1 for sender)
|
// 4. Sender receives msg2 from Receiver (acks msg1 for sender)
|
||||||
@ -434,16 +438,17 @@ func TestCallbacks_Combined(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 5. Sender sends msg3 (depends on msg2)
|
// 5. Sender sends msg3 (depends on msg2)
|
||||||
wrappedMsg3, err := senderRm.WrapOutgoingMessage(payload3, msgID3)
|
wrappedMsg3, err := senderRm.WrapOutgoingMessage(payload3, msgID3, channelID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 6. Create Receiver2, register missing deps callback
|
// 6. Create Receiver2, register missing deps callback
|
||||||
receiverRm2, err := NewReliabilityManager(channelID)
|
receiverRm2, err := NewReliabilityManager()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer receiverRm2.Cleanup()
|
defer receiverRm2.Cleanup()
|
||||||
|
|
||||||
callbacksReceiver2 := EventCallbacks{
|
callbacksReceiver2 := EventCallbacks{
|
||||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
|
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
|
||||||
|
require.Equal(t, channelID, chId)
|
||||||
if messageId == msgID3 {
|
if messageId == msgID3 {
|
||||||
select {
|
select {
|
||||||
case missingChan <- missingDeps:
|
case missingChan <- missingDeps:
|
||||||
@ -537,3 +542,138 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration, t *testing.T) {
|
|||||||
t.Fatalf("Timed out waiting for callbacks")
|
t.Fatalf("Timed out waiting for callbacks")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test multi-channel functionality - one RM can handle messages from different channels
|
||||||
|
func TestMultiChannel_SingleRM(t *testing.T) {
|
||||||
|
rm, err := NewReliabilityManager()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rm.Cleanup()
|
||||||
|
|
||||||
|
// Test with two different channels
|
||||||
|
channel1 := "test-channel-1"
|
||||||
|
channel2 := "test-channel-2"
|
||||||
|
|
||||||
|
// Create and wrap messages for different channels
|
||||||
|
msg1 := []byte("message for channel 1")
|
||||||
|
msgID1 := MessageID("msg1")
|
||||||
|
wrappedMsg1, err := rm.WrapOutgoingMessage(msg1, msgID1, channel1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg2 := []byte("message for channel 2")
|
||||||
|
msgID2 := MessageID("msg2")
|
||||||
|
wrappedMsg2, err := rm.WrapOutgoingMessage(msg2, msgID2, channel2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Unwrap messages - should extract channel ID and route correctly
|
||||||
|
unwrapped1, err := rm.UnwrapReceivedMessage(wrappedMsg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, msg1, *unwrapped1.Message)
|
||||||
|
|
||||||
|
// Verify channel ID is extracted correctly
|
||||||
|
require.NotNil(t, unwrapped1.ChannelId, "Expected ChannelId to be not nil")
|
||||||
|
require.Equal(t, channel1, *unwrapped1.ChannelId)
|
||||||
|
|
||||||
|
unwrapped2, err := rm.UnwrapReceivedMessage(wrappedMsg2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, msg2, *unwrapped2.Message)
|
||||||
|
|
||||||
|
// Verify channel ID is extracted correctly
|
||||||
|
require.NotNil(t, unwrapped2.ChannelId, "Expected ChannelId to be not nil")
|
||||||
|
require.Equal(t, channel2, *unwrapped2.ChannelId)
|
||||||
|
|
||||||
|
// Test that the same RM can handle dependencies for different channels
|
||||||
|
err = rm.MarkDependenciesMet([]MessageID{msgID1}, channel1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = rm.MarkDependenciesMet([]MessageID{msgID2}, channel2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that callbacks are correctly triggered for multiple channels
|
||||||
|
func TestMultiChannelCallbacks(t *testing.T) {
|
||||||
|
// rm1 is the manager we are testing callbacks on
|
||||||
|
rm1, err := NewReliabilityManager()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rm1.Cleanup()
|
||||||
|
|
||||||
|
// rm2 simulates another peer
|
||||||
|
rm2, err := NewReliabilityManager()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rm2.Cleanup()
|
||||||
|
|
||||||
|
channel1 := "callback-channel-1"
|
||||||
|
channel2 := "callback-channel-2"
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var cbMutex sync.Mutex
|
||||||
|
readyMessages := make(map[MessageID]string)
|
||||||
|
sentMessages := make(map[MessageID]string)
|
||||||
|
|
||||||
|
callbacks := EventCallbacks{
|
||||||
|
OnMessageReady: func(messageId MessageID, channelId string) {
|
||||||
|
// This will be triggered when rm1 receives messages from rm2
|
||||||
|
cbMutex.Lock()
|
||||||
|
readyMessages[messageId] = channelId
|
||||||
|
cbMutex.Unlock()
|
||||||
|
wg.Done()
|
||||||
|
},
|
||||||
|
OnMessageSent: func(messageId MessageID, channelId string) {
|
||||||
|
// This will be triggered when rm1's messages are acked by rm2
|
||||||
|
cbMutex.Lock()
|
||||||
|
sentMessages[messageId] = channelId
|
||||||
|
cbMutex.Unlock()
|
||||||
|
wg.Done()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
rm1.RegisterCallbacks(callbacks)
|
||||||
|
|
||||||
|
// --- Test Scenario ---
|
||||||
|
// 1. rm1 sends one message on each channel.
|
||||||
|
msgID1_ch1 := MessageID("msg-on-ch1")
|
||||||
|
wrappedMsg1_ch1, err := rm1.WrapOutgoingMessage([]byte("payload1"), msgID1_ch1, channel1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msgID2_ch2 := MessageID("msg-on-ch2")
|
||||||
|
wrappedMsg2_ch2, err := rm1.WrapOutgoingMessage([]byte("payload2"), msgID2_ch2, channel2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 2. rm2 receives these messages to update its history.
|
||||||
|
_, err = rm2.UnwrapReceivedMessage(wrappedMsg1_ch1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = rm2.UnwrapReceivedMessage(wrappedMsg2_ch2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 3. rm2 sends messages back on each channel, which will act as ACKs for rm1's messages.
|
||||||
|
ackID1_ch1 := MessageID("ack-on-ch1")
|
||||||
|
wrappedAck1_ch1, err := rm2.WrapOutgoingMessage([]byte("ack_payload1"), ackID1_ch1, channel1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ackID2_ch2 := MessageID("ack-on-ch2")
|
||||||
|
wrappedAck2_ch2, err := rm2.WrapOutgoingMessage([]byte("ack_payload2"), ackID2_ch2, channel2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 4. rm1 receives the "ack" messages. This should trigger:
|
||||||
|
// - OnMessageSent for msgID1_ch1 and msgID2_ch2
|
||||||
|
// - OnMessageReady for ackID1_ch1 and ackID2_ch2
|
||||||
|
wg.Add(4) // Expect 2 sent and 2 ready callbacks
|
||||||
|
_, err = rm1.UnwrapReceivedMessage(wrappedAck1_ch1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = rm1.UnwrapReceivedMessage(wrappedAck2_ch2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// --- Verification ---
|
||||||
|
waitTimeout(&wg, 5*time.Second, t)
|
||||||
|
|
||||||
|
cbMutex.Lock()
|
||||||
|
defer cbMutex.Unlock()
|
||||||
|
|
||||||
|
// Check that both original messages were marked as sent and on the correct channel
|
||||||
|
require.Equal(t, channel1, sentMessages[msgID1_ch1], "OnMessageSent for msg1 has incorrect channel")
|
||||||
|
require.Equal(t, channel2, sentMessages[msgID2_ch2], "OnMessageSent for msg2 has incorrect channel")
|
||||||
|
require.Len(t, sentMessages, 2, "Expected exactly 2 sent messages")
|
||||||
|
|
||||||
|
// Check that both ack messages were marked as ready and on the correct channel
|
||||||
|
require.Equal(t, channel1, readyMessages[ackID1_ch1], "OnMessageReady for ack1 has incorrect channel")
|
||||||
|
require.Equal(t, channel2, readyMessages[ackID2_ch2], "OnMessageReady for ack2 has incorrect channel")
|
||||||
|
require.Len(t, readyMessages, 2, "Expected exactly 2 ready messages")
|
||||||
|
}
|
||||||
|
|||||||
@ -5,4 +5,5 @@ type MessageID string
|
|||||||
type UnwrappedMessage struct {
|
type UnwrappedMessage struct {
|
||||||
Message *[]byte `json:"message"`
|
Message *[]byte `json:"message"`
|
||||||
MissingDeps *[]MessageID `json:"missingDeps"`
|
MissingDeps *[]MessageID `json:"missingDeps"`
|
||||||
|
ChannelId *string `json:"channelId"`
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user