mirror of
https://github.com/logos-messaging/sds-go-bindings.git
synced 2026-01-02 06:03:12 +00:00
Merge pull request #1 from waku-org/feat-initial-setup
feat: initial implementation
This commit is contained in:
commit
78ea8e7f8f
25
.gitignore
vendored
Normal file
25
.gitignore
vendored
Normal file
@ -0,0 +1,25 @@
|
||||
# If you prefer the allow list template instead of the deny list, see community template:
|
||||
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
|
||||
#
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
|
||||
# Generated dependencies and cache
|
||||
third_party
|
||||
nimcache
|
||||
46
README.md
46
README.md
@ -1 +1,45 @@
|
||||
# sds-go-bindings
|
||||
# SDS Go Bindings
|
||||
|
||||
This repository provides Go bindings for the SDS library, enabling seamless integration with Go projects.
|
||||
|
||||
## Installation
|
||||
|
||||
To build the required dependencies for this module, the `make` command needs to be executed. If you are integrating this module into another project via `go get`, ensure that you navigate to the `sds-go-bindings/sds` directory and run `make`.
|
||||
|
||||
### Steps to Install
|
||||
|
||||
Follow these steps to install and set up the module:
|
||||
|
||||
1. Retrieve the module using `go get`:
|
||||
```
|
||||
go get -u github.com/waku-org/sds-go-bindings
|
||||
```
|
||||
2. Navigate to the module's directory:
|
||||
```
|
||||
cd $(go list -m -f '{{.Dir}}' github.com/waku-org/sds-go-bindings)
|
||||
```
|
||||
3. Prepare third_party directory which will clone `nim-sds`
|
||||
```
|
||||
sudo mkdir third_party
|
||||
sudo chown $USER third_party
|
||||
```
|
||||
4. Build the dependencies:
|
||||
```
|
||||
make -C sds
|
||||
```
|
||||
|
||||
Now the module is ready for use in your project.
|
||||
|
||||
### Note
|
||||
|
||||
In order to easily build the libsds library on demand, it is recommended to add the following target in your project's Makefile:
|
||||
|
||||
```
|
||||
LIBSDS_DEP_PATH=$(shell go list -m -f '{{.Dir}}' github.com/waku-org/sds-go-bindings)
|
||||
|
||||
buildlib:
|
||||
cd $(LIBSDS_DEP_PATH) &&\
|
||||
sudo mkdir -p third_party &&\
|
||||
sudo chown $(USER) third_party &&\
|
||||
make -C sds
|
||||
```
|
||||
|
||||
15
go.mod
Normal file
15
go.mod
Normal file
@ -0,0 +1,15 @@
|
||||
module github.com/waku-org/sds-go-bindings
|
||||
|
||||
go 1.22.10
|
||||
|
||||
require (
|
||||
github.com/stretchr/testify v1.8.1
|
||||
go.uber.org/zap v1.27.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
20
go.sum
Normal file
20
go.sum
Normal file
@ -0,0 +1,20 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
44
sds/Makefile
Normal file
44
sds/Makefile
Normal file
@ -0,0 +1,44 @@
|
||||
# Makefile for SDS Go Bindings
|
||||
|
||||
# Directories
|
||||
THIRD_PARTY_DIR := ../third_party
|
||||
NIM_SDS_REPO := https://github.com/waku-org/nim-sds
|
||||
NIM_SDS_DIR := $(THIRD_PARTY_DIR)/nim-sds
|
||||
|
||||
.PHONY: all clean prepare build-libsds build
|
||||
|
||||
# Default target
|
||||
all: build
|
||||
|
||||
# Prepare third_party directory and clone nim-sds
|
||||
# TODO: remove the "git checkout gabrielmer-feat-init-implementation" part
|
||||
prepare:
|
||||
@echo "Creating third_party directory..."
|
||||
@mkdir -p $(THIRD_PARTY_DIR)
|
||||
|
||||
@echo "Cloning nim-sds repository..."
|
||||
@if [ ! -d "$(NIM_SDS_DIR)" ]; then \
|
||||
cd $(THIRD_PARTY_DIR) && \
|
||||
git clone $(NIM_SDS_REPO) && \
|
||||
cd $(NIM_SDS_DIR) && \
|
||||
git checkout gabrielmer-feat-init-implementation; \
|
||||
make update; \
|
||||
else \
|
||||
echo "nim-sds repository already exists."; \
|
||||
fi
|
||||
|
||||
# Build libsds
|
||||
build-libsds: prepare
|
||||
@echo "Building libsds..."
|
||||
@cd $(NIM_SDS_DIR) && make libsds
|
||||
|
||||
# Build SDS Go Bindings
|
||||
build: build-libsds
|
||||
@echo "Building SDS Go Bindings..."
|
||||
go build ./...
|
||||
|
||||
# Clean up generated files
|
||||
clean:
|
||||
@echo "Cleaning up..."
|
||||
@rm -rf $(THIRD_PARTY_DIR)
|
||||
@rm -f sds-go-bindings
|
||||
47
sds/logging.go
Normal file
47
sds/logging.go
Normal file
@ -0,0 +1,47 @@
|
||||
package sds
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
sugar *zap.SugaredLogger
|
||||
)
|
||||
|
||||
func _getLogger() *zap.SugaredLogger {
|
||||
once.Do(func() {
|
||||
|
||||
config := zap.NewDevelopmentConfig()
|
||||
l, err := config.Build()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sugar = l.Sugar()
|
||||
})
|
||||
return sugar
|
||||
}
|
||||
|
||||
func SetLogger(newLogger *zap.Logger) {
|
||||
once.Do(func() {})
|
||||
|
||||
sugar = newLogger.Sugar()
|
||||
}
|
||||
|
||||
func Debug(msg string, args ...interface{}) {
|
||||
_getLogger().Debugf(msg, args...)
|
||||
}
|
||||
|
||||
func Info(msg string, args ...interface{}) {
|
||||
_getLogger().Infof(msg, args...)
|
||||
}
|
||||
|
||||
func Warn(msg string, args ...interface{}) {
|
||||
_getLogger().Warnf(msg, args...)
|
||||
}
|
||||
|
||||
func Error(msg string, args ...interface{}) {
|
||||
_getLogger().Errorf(msg, args...)
|
||||
}
|
||||
556
sds/sds.go
Normal file
556
sds/sds.go
Normal file
@ -0,0 +1,556 @@
|
||||
package sds
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -L../third_party/nim-sds/build/ -lsds
|
||||
#cgo LDFLAGS: -L../third_party/nim-sds -Wl,-rpath,../third_party/nim-sds/build/
|
||||
|
||||
#include "../third_party/nim-sds/library/libsds.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
extern void globalEventCallback(int ret, char* msg, size_t len, void* userData);
|
||||
|
||||
typedef struct {
|
||||
int ret;
|
||||
char* msg;
|
||||
size_t len;
|
||||
void* ffiWg;
|
||||
} Resp;
|
||||
|
||||
static void* allocResp(void* wg) {
|
||||
Resp* r = calloc(1, sizeof(Resp));
|
||||
r->ffiWg = wg;
|
||||
return r;
|
||||
}
|
||||
|
||||
static void freeResp(void* resp) {
|
||||
if (resp != NULL) {
|
||||
free(resp);
|
||||
}
|
||||
}
|
||||
|
||||
static char* getMyCharPtr(void* resp) {
|
||||
if (resp == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
Resp* m = (Resp*) resp;
|
||||
return m->msg;
|
||||
}
|
||||
|
||||
static size_t getMyCharLen(void* resp) {
|
||||
if (resp == NULL) {
|
||||
return 0;
|
||||
}
|
||||
Resp* m = (Resp*) resp;
|
||||
return m->len;
|
||||
}
|
||||
|
||||
static int getRet(void* resp) {
|
||||
if (resp == NULL) {
|
||||
return 0;
|
||||
}
|
||||
Resp* m = (Resp*) resp;
|
||||
return m->ret;
|
||||
}
|
||||
|
||||
// resp must be set != NULL in case interest on retrieving data from the callback
|
||||
void GoCallback(int ret, char* msg, size_t len, void* resp);
|
||||
|
||||
static void* cGoNewReliabilityManager(const char* channelId, void* resp) {
|
||||
// We pass NULL because we are not interested in retrieving data from this callback
|
||||
void* ret = NewReliabilityManager(channelId, (SdsCallBack) GoCallback, resp);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static void cGoSetEventCallback(void* rmCtx) {
|
||||
// The 'globalEventCallback' Go function is shared amongst all possible Reliability Manager instances.
|
||||
|
||||
// Given that the 'globalEventCallback' is shared, we pass again the
|
||||
// rmCtx instance but in this case is needed to pick up the correct method
|
||||
// that will handle the event.
|
||||
|
||||
// In other words, for every call libsds makes to globalEventCallback,
|
||||
// the 'userData' parameter will bring the context of the rm that registered
|
||||
// that globalEventCallback.
|
||||
|
||||
// This technique is needed because cgo only allows to export Go functions and not methods.
|
||||
|
||||
SetEventCallback(rmCtx, (SdsCallBack) globalEventCallback, rmCtx);
|
||||
}
|
||||
|
||||
static void cGoCleanupReliabilityManager(void* rmCtx, void* resp) {
|
||||
CleanupReliabilityManager(rmCtx, (SdsCallBack) GoCallback, resp);
|
||||
}
|
||||
|
||||
static void cGoResetReliabilityManager(void* rmCtx, void* resp) {
|
||||
ResetReliabilityManager(rmCtx, (SdsCallBack) GoCallback, resp);
|
||||
}
|
||||
|
||||
static void cGoWrapOutgoingMessage(void* rmCtx,
|
||||
void* message,
|
||||
size_t messageLen,
|
||||
const char* messageId,
|
||||
void* resp) {
|
||||
WrapOutgoingMessage(rmCtx,
|
||||
message,
|
||||
messageLen,
|
||||
messageId,
|
||||
(SdsCallBack) GoCallback,
|
||||
resp);
|
||||
}
|
||||
static void cGoUnwrapReceivedMessage(void* rmCtx,
|
||||
void* message,
|
||||
size_t messageLen,
|
||||
void* resp) {
|
||||
UnwrapReceivedMessage(rmCtx,
|
||||
message,
|
||||
messageLen,
|
||||
(SdsCallBack) GoCallback,
|
||||
resp);
|
||||
}
|
||||
|
||||
static void cGoMarkDependenciesMet(void* rmCtx,
|
||||
char** messageIDs,
|
||||
size_t count,
|
||||
void* resp) {
|
||||
MarkDependenciesMet(rmCtx,
|
||||
messageIDs,
|
||||
count,
|
||||
(SdsCallBack) GoCallback,
|
||||
resp);
|
||||
}
|
||||
|
||||
static void cGoStartPeriodicTasks(void* rmCtx, void* resp) {
|
||||
StartPeriodicTasks(rmCtx, (SdsCallBack) GoCallback, resp);
|
||||
}
|
||||
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const requestTimeout = 30 * time.Second
|
||||
const EventChanBufferSize = 1024
|
||||
|
||||
//export GoCallback
|
||||
func GoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) {
|
||||
if resp != nil {
|
||||
m := (*C.Resp)(resp)
|
||||
m.ret = ret
|
||||
m.msg = msg
|
||||
m.len = len
|
||||
wg := (*sync.WaitGroup)(m.ffiWg)
|
||||
wg.Done()
|
||||
}
|
||||
}
|
||||
|
||||
type EventCallbacks struct {
|
||||
OnMessageReady func(messageId MessageID)
|
||||
OnMessageSent func(messageId MessageID)
|
||||
OnMissingDependencies func(messageId MessageID, missingDeps []MessageID)
|
||||
OnPeriodicSync func()
|
||||
}
|
||||
|
||||
// ReliabilityManager represents an instance of a nim-sds ReliabilityManager
|
||||
type ReliabilityManager struct {
|
||||
rmCtx unsafe.Pointer
|
||||
channelId string
|
||||
callbacks EventCallbacks
|
||||
}
|
||||
|
||||
func NewReliabilityManager(channelId string) (*ReliabilityManager, error) {
|
||||
Debug("Creating new Reliability Manager")
|
||||
rm := &ReliabilityManager{
|
||||
channelId: channelId,
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
var cChannelId = C.CString(string(channelId))
|
||||
var resp = C.allocResp(unsafe.Pointer(&wg))
|
||||
|
||||
defer C.free(unsafe.Pointer(cChannelId))
|
||||
defer C.freeResp(resp)
|
||||
|
||||
if C.getRet(resp) != C.RET_OK {
|
||||
errMsg := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
||||
Error("error NewReliabilityManager: %v", errMsg)
|
||||
return nil, errors.New(errMsg)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
rm.rmCtx = C.cGoNewReliabilityManager(cChannelId, resp)
|
||||
wg.Wait()
|
||||
|
||||
C.cGoSetEventCallback(rm.rmCtx)
|
||||
registerReliabilityManager(rm)
|
||||
|
||||
Debug("Successfully created Reliability Manager")
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
// The event callback sends back the rm ctx to know to which
|
||||
// rm is the event being emited for. Since we only have a global
|
||||
// callback in the go side, We register all the rm's that we create
|
||||
// so we can later obtain which instance of `ReliabilityManager` it should
|
||||
// be invoked depending on the ctx received
|
||||
|
||||
var rmRegistry map[unsafe.Pointer]*ReliabilityManager
|
||||
|
||||
func init() {
|
||||
rmRegistry = make(map[unsafe.Pointer]*ReliabilityManager)
|
||||
}
|
||||
|
||||
func registerReliabilityManager(rm *ReliabilityManager) {
|
||||
_, ok := rmRegistry[rm.rmCtx]
|
||||
if !ok {
|
||||
rmRegistry[rm.rmCtx] = rm
|
||||
}
|
||||
}
|
||||
|
||||
func unregisterReliabilityManager(rm *ReliabilityManager) {
|
||||
delete(rmRegistry, rm.rmCtx)
|
||||
}
|
||||
|
||||
//export globalEventCallback
|
||||
func globalEventCallback(callerRet C.int, msg *C.char, len C.size_t, userData unsafe.Pointer) {
|
||||
if callerRet == C.RET_OK {
|
||||
eventStr := C.GoStringN(msg, C.int(len))
|
||||
rm, ok := rmRegistry[userData] // userData contains rm's ctx
|
||||
if ok {
|
||||
rm.OnEvent(eventStr)
|
||||
}
|
||||
} else {
|
||||
if len != 0 {
|
||||
errMsg := C.GoStringN(msg, C.int(len))
|
||||
Error("globalEventCallback retCode not ok, retCode: %v: %v", callerRet, errMsg)
|
||||
} else {
|
||||
Error("globalEventCallback retCode not ok, retCode: %v", callerRet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type jsonEvent struct {
|
||||
EventType string `json:"eventType"`
|
||||
}
|
||||
|
||||
type msgEvent struct {
|
||||
MessageId MessageID `json:"messageId"`
|
||||
}
|
||||
|
||||
type missingDepsEvent struct {
|
||||
MessageId MessageID `json:"messageId"`
|
||||
MissingDeps []MessageID `json:"missingDeps"`
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) {
|
||||
rm.callbacks = callbacks
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) OnEvent(eventStr string) {
|
||||
|
||||
jsonEvent := jsonEvent{}
|
||||
err := json.Unmarshal([]byte(eventStr), &jsonEvent)
|
||||
if err != nil {
|
||||
Error("could not unmarshal sds event string: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
switch jsonEvent.EventType {
|
||||
case "message_ready":
|
||||
rm.parseMessageReadyEvent(eventStr)
|
||||
case "message_sent":
|
||||
rm.parseMessageSentEvent(eventStr)
|
||||
case "missing_dependencies":
|
||||
rm.parseMissingDepsEvent(eventStr)
|
||||
case "periodic_sync":
|
||||
if rm.callbacks.OnPeriodicSync != nil {
|
||||
rm.callbacks.OnPeriodicSync()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) parseMessageReadyEvent(eventStr string) {
|
||||
|
||||
msgEvent := msgEvent{}
|
||||
err := json.Unmarshal([]byte(eventStr), &msgEvent)
|
||||
if err != nil {
|
||||
Error("could not parse message ready event %v", err)
|
||||
}
|
||||
|
||||
if rm.callbacks.OnMessageReady != nil {
|
||||
rm.callbacks.OnMessageReady(msgEvent.MessageId)
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) parseMessageSentEvent(eventStr string) {
|
||||
|
||||
msgEvent := msgEvent{}
|
||||
err := json.Unmarshal([]byte(eventStr), &msgEvent)
|
||||
if err != nil {
|
||||
Error("could not parse message sent event %v", err)
|
||||
}
|
||||
|
||||
if rm.callbacks.OnMessageSent != nil {
|
||||
rm.callbacks.OnMessageSent(msgEvent.MessageId)
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) parseMissingDepsEvent(eventStr string) {
|
||||
|
||||
missingDepsEvent := missingDepsEvent{}
|
||||
err := json.Unmarshal([]byte(eventStr), &missingDepsEvent)
|
||||
if err != nil {
|
||||
Error("could not parse missing dependencies event %v", err)
|
||||
}
|
||||
|
||||
if rm.callbacks.OnMissingDependencies != nil {
|
||||
rm.callbacks.OnMissingDependencies(missingDepsEvent.MessageId, missingDepsEvent.MissingDeps)
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) Cleanup() error {
|
||||
if rm == nil {
|
||||
err := errors.New("reliability manager is nil in Cleanup")
|
||||
Error("Failed to cleanup %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
Debug("Cleaning up reliability manager")
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
var resp = C.allocResp(unsafe.Pointer(&wg))
|
||||
defer C.freeResp(resp)
|
||||
|
||||
wg.Add(1)
|
||||
C.cGoCleanupReliabilityManager(rm.rmCtx, resp)
|
||||
wg.Wait()
|
||||
|
||||
if C.getRet(resp) == C.RET_OK {
|
||||
unregisterReliabilityManager(rm)
|
||||
Debug("Successfully cleaned up reliability manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
errMsg := "error CleanupReliabilityManager: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
||||
Error("Failed to cleanup reliability manager: %v", errMsg)
|
||||
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) Reset() error {
|
||||
if rm == nil {
|
||||
err := errors.New("reliability manager is nil in Reset")
|
||||
Error("Failed to reset %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
Debug("Resetting reliability manager")
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
var resp = C.allocResp(unsafe.Pointer(&wg))
|
||||
defer C.freeResp(resp)
|
||||
|
||||
wg.Add(1)
|
||||
C.cGoResetReliabilityManager(rm.rmCtx, resp)
|
||||
wg.Wait()
|
||||
|
||||
if C.getRet(resp) == C.RET_OK {
|
||||
Debug("Successfully resetted reliability manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
errMsg := "error ResetReliabilityManager: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
||||
Error("Failed to reset reliability manager: %v", errMsg)
|
||||
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId MessageID) ([]byte, error) {
|
||||
if rm == nil {
|
||||
err := errors.New("reliability manager is nil in WrapOutgoingMessage")
|
||||
Error("Failed to wrap outgoing message %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Debug("Wrapping outgoing message %v", messageId)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
var resp = C.allocResp(unsafe.Pointer(&wg))
|
||||
defer C.freeResp(resp)
|
||||
|
||||
cMessageId := C.CString(string(messageId))
|
||||
defer C.free(unsafe.Pointer(cMessageId))
|
||||
|
||||
var cMessagePtr unsafe.Pointer
|
||||
if len(message) > 0 {
|
||||
cMessagePtr = C.CBytes(message) // C.CBytes allocates memory that needs to be freed
|
||||
defer C.free(cMessagePtr)
|
||||
} else {
|
||||
cMessagePtr = nil
|
||||
}
|
||||
cMessageLen := C.size_t(len(message))
|
||||
|
||||
wg.Add(1)
|
||||
C.cGoWrapOutgoingMessage(rm.rmCtx, cMessagePtr, cMessageLen, cMessageId, resp)
|
||||
wg.Wait()
|
||||
|
||||
if C.getRet(resp) == C.RET_OK {
|
||||
resStr := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
||||
if resStr == "" {
|
||||
Debug("Received empty res string for messageId: %v", messageId)
|
||||
return nil, nil
|
||||
}
|
||||
Debug("Successfully wrapped message %s", messageId)
|
||||
|
||||
parts := strings.Split(resStr, ",")
|
||||
bytes := make([]byte, len(parts))
|
||||
|
||||
for i, part := range parts {
|
||||
n, err := strconv.Atoi(strings.TrimSpace(part))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bytes[i] = byte(n)
|
||||
}
|
||||
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
errMsg := "error WrapOutgoingMessage: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
||||
Error("Failed to wrap message %v: %v", messageId, errMsg)
|
||||
|
||||
return nil, errors.New(errMsg)
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedMessage, error) {
|
||||
if rm == nil {
|
||||
err := errors.New("reliability manager is nil in UnwrapReceivedMessage")
|
||||
Error("Failed to unwrap received message %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
var resp = C.allocResp(unsafe.Pointer(&wg))
|
||||
defer C.freeResp(resp)
|
||||
|
||||
var cMessagePtr unsafe.Pointer
|
||||
if len(message) > 0 {
|
||||
cMessagePtr = C.CBytes(message) // C.CBytes allocates memory that needs to be freed
|
||||
defer C.free(cMessagePtr)
|
||||
} else {
|
||||
cMessagePtr = nil
|
||||
}
|
||||
cMessageLen := C.size_t(len(message))
|
||||
|
||||
wg.Add(1)
|
||||
C.cGoUnwrapReceivedMessage(rm.rmCtx, cMessagePtr, cMessageLen, resp)
|
||||
wg.Wait()
|
||||
|
||||
if C.getRet(resp) == C.RET_OK {
|
||||
resStr := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
||||
if resStr == "" {
|
||||
Debug("Received empty res string")
|
||||
return nil, nil
|
||||
}
|
||||
Debug("Successfully unwrapped message")
|
||||
|
||||
unwrappedMessage := UnwrappedMessage{}
|
||||
err := json.Unmarshal([]byte(resStr), &unwrappedMessage)
|
||||
if err != nil {
|
||||
Error("Failed to unmarshal unwrapped message")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &unwrappedMessage, nil
|
||||
}
|
||||
|
||||
errMsg := "error UnwrapReceivedMessage: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
||||
Error("Failed to unwrap message: %v", errMsg)
|
||||
|
||||
return nil, errors.New(errMsg)
|
||||
}
|
||||
|
||||
// MarkDependenciesMet informs the library that dependencies are met
|
||||
func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID) error {
|
||||
if rm == nil {
|
||||
err := errors.New("reliability manager is nil in MarkDependenciesMet")
|
||||
Error("Failed to mark dependencies met %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(messageIDs) == 0 {
|
||||
return nil // Nothing to do
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
var resp = C.allocResp(unsafe.Pointer(&wg))
|
||||
defer C.freeResp(resp)
|
||||
|
||||
// Convert Go string slice to C array of C strings (char**)
|
||||
cMessageIDs := make([]*C.char, len(messageIDs))
|
||||
for i, id := range messageIDs {
|
||||
cMessageIDs[i] = C.CString(string(id))
|
||||
defer C.free(unsafe.Pointer(cMessageIDs[i])) // Ensure each CString is freed
|
||||
}
|
||||
|
||||
// Create a pointer (**C.char) to the first element of the slice
|
||||
var cMessageIDsPtr **C.char
|
||||
if len(cMessageIDs) > 0 {
|
||||
cMessageIDsPtr = &cMessageIDs[0]
|
||||
} else {
|
||||
cMessageIDsPtr = nil // Handle empty slice case
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
// Pass the pointer variable (cMessageIDsPtr) directly, which is of type **C.char
|
||||
C.cGoMarkDependenciesMet(rm.rmCtx, cMessageIDsPtr, C.size_t(len(messageIDs)), resp)
|
||||
wg.Wait()
|
||||
|
||||
if C.getRet(resp) == C.RET_OK {
|
||||
Debug("Successfully marked dependencies as met")
|
||||
return nil
|
||||
}
|
||||
|
||||
errMsg := "error MarkDependenciesMet: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
||||
Error("Failed to mark dependencies as met: %v", errMsg)
|
||||
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
func (rm *ReliabilityManager) StartPeriodicTasks() error {
|
||||
if rm == nil {
|
||||
err := errors.New("reliability manager is nil in StartPeriodicTasks")
|
||||
Error("Failed to start periodic tasks %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
Debug("Starting periodic tasks")
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
var resp = C.allocResp(unsafe.Pointer(&wg))
|
||||
defer C.freeResp(resp)
|
||||
|
||||
wg.Add(1)
|
||||
C.cGoStartPeriodicTasks(rm.rmCtx, resp)
|
||||
wg.Wait()
|
||||
|
||||
if C.getRet(resp) == C.RET_OK {
|
||||
Debug("Successfully started periodic tasks")
|
||||
return nil
|
||||
}
|
||||
|
||||
errMsg := "error StartPeriodicTasks: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
||||
Error("Failed to start periodic tasks: %v", errMsg)
|
||||
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
539
sds/sds_test.go
Normal file
539
sds/sds_test.go
Normal file
@ -0,0 +1,539 @@
|
||||
package sds
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test basic creation, cleanup, and reset
|
||||
func TestLifecycle(t *testing.T) {
|
||||
channelID := "test-lifecycle"
|
||||
rm, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm, "Expected ReliabilityManager to be not nil")
|
||||
|
||||
defer rm.Cleanup() // Ensure cleanup even on test failure
|
||||
|
||||
err = rm.Reset()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test wrapping and unwrapping a simple message
|
||||
func TestWrapUnwrap(t *testing.T) {
|
||||
channelID := "test-wrap-unwrap"
|
||||
rm, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
defer rm.Cleanup()
|
||||
|
||||
originalPayload := []byte("hello reliability")
|
||||
messageID := MessageID("msg-wrap-1")
|
||||
|
||||
wrappedMsg, err := rm.WrapOutgoingMessage(originalPayload, messageID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Greater(t, len(wrappedMsg), 0, "Expected non-empty wrapped message")
|
||||
|
||||
// Simulate receiving the wrapped message
|
||||
unwrappedMessage, err := rm.UnwrapReceivedMessage(wrappedMsg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, string(*unwrappedMessage.Message), string(originalPayload), "Expected unwrapped and original payloads to be equal")
|
||||
require.Equal(t, len(*unwrappedMessage.MissingDeps), 0, "Expexted to be no missing dependencies")
|
||||
}
|
||||
|
||||
// Test dependency handling
|
||||
func TestDependencies(t *testing.T) {
|
||||
channelID := "test-deps"
|
||||
rm, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
defer rm.Cleanup()
|
||||
|
||||
// 1. Send message 1 (will become a dependency)
|
||||
payload1 := []byte("message one")
|
||||
msgID1 := MessageID("msg-dep-1")
|
||||
wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate receiving msg1 to add it to history (implicitly acknowledges it)
|
||||
_, err = rm.UnwrapReceivedMessage(wrappedMsg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Send message 2 (depends on message 1 implicitly via causal history)
|
||||
payload2 := []byte("message two")
|
||||
msgID2 := MessageID("msg-dep-2")
|
||||
wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. Create a new manager to simulate a different peer receiving msg2 without msg1
|
||||
rm2, err := NewReliabilityManager(channelID) // Same channel ID
|
||||
require.NoError(t, err)
|
||||
defer rm2.Cleanup()
|
||||
|
||||
// 4. Unwrap message 2 on the second manager - should report msg1 as missing
|
||||
unwrappedMessage2, err := rm2.UnwrapReceivedMessage(wrappedMsg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Greater(t, len(*unwrappedMessage2.MissingDeps), 0, "Expected missing dependencies, got none")
|
||||
|
||||
foundDep1 := false
|
||||
for _, dep := range *unwrappedMessage2.MissingDeps {
|
||||
if dep == msgID1 {
|
||||
foundDep1 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundDep1, "Expected missing dependency %q, got %v", msgID1, *unwrappedMessage2.MissingDeps)
|
||||
|
||||
// 5. Mark the dependency as met
|
||||
err = rm2.MarkDependenciesMet([]MessageID{msgID1})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test OnMessageReady callback
|
||||
func TestCallback_OnMessageReady(t *testing.T) {
|
||||
channelID := "test-cb-ready"
|
||||
|
||||
// Create sender and receiver RMs
|
||||
senderRm, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
defer senderRm.Cleanup()
|
||||
|
||||
receiverRm, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
defer receiverRm.Cleanup()
|
||||
|
||||
// Use a channel for signaling
|
||||
readyChan := make(chan MessageID, 1)
|
||||
|
||||
callbacks := EventCallbacks{
|
||||
OnMessageReady: func(messageId MessageID) {
|
||||
// Non-blocking send to channel
|
||||
select {
|
||||
case readyChan <- messageId:
|
||||
default:
|
||||
// Avoid blocking if channel is full or test already timed out
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Register callback only on the receiver
|
||||
receiverRm.RegisterCallbacks(callbacks)
|
||||
|
||||
// Scenario: Wrap message on sender, unwrap on receiver
|
||||
payload := []byte("ready test")
|
||||
msgID := MessageID("cb-ready-1")
|
||||
|
||||
// Wrap on sender
|
||||
wrappedMsg, err := senderRm.WrapOutgoingMessage(payload, msgID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unwrap on receiver
|
||||
_, err = receiverRm.UnwrapReceivedMessage(wrappedMsg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verification - Wait on channel with timeout
|
||||
select {
|
||||
case receivedMsgID := <-readyChan:
|
||||
// Mark as called implicitly since we received on channel
|
||||
if receivedMsgID != msgID {
|
||||
t.Errorf("OnMessageReady called with wrong ID: got %q, want %q", receivedMsgID, msgID)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
// If timeout occurs, the channel receive failed.
|
||||
t.Errorf("Timed out waiting for OnMessageReady callback on readyChan")
|
||||
}
|
||||
}
|
||||
|
||||
// Test OnMessageSent callback (via causal history ACK)
|
||||
func TestCallback_OnMessageSent(t *testing.T) {
|
||||
channelID := "test-cb-sent"
|
||||
|
||||
// Create two RMs
|
||||
rm1, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
defer rm1.Cleanup()
|
||||
|
||||
rm2, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
defer rm2.Cleanup()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
sentCalled := false
|
||||
var sentMsgID MessageID
|
||||
var cbMutex sync.Mutex
|
||||
|
||||
callbacks := EventCallbacks{
|
||||
OnMessageSent: func(messageId MessageID) {
|
||||
cbMutex.Lock()
|
||||
sentCalled = true
|
||||
sentMsgID = messageId
|
||||
cbMutex.Unlock()
|
||||
wg.Done()
|
||||
},
|
||||
}
|
||||
|
||||
// Register callback on rm1 (the original sender)
|
||||
rm1.RegisterCallbacks(callbacks)
|
||||
|
||||
// Scenario: rm1 sends msg1, rm2 receives msg1,
|
||||
// rm2 sends msg2 (acking msg1), rm1 receives msg2.
|
||||
|
||||
// 1. rm1 sends msg1
|
||||
payload1 := []byte("sent test 1")
|
||||
msgID1 := MessageID("cb-sent-1")
|
||||
wrappedMsg1, err := rm1.WrapOutgoingMessage(payload1, msgID1)
|
||||
require.NoError(t, err)
|
||||
// Note: msg1 is now in rm1's outgoing buffer
|
||||
|
||||
// 2. rm2 receives msg1 (to update its state)
|
||||
_, err = rm2.UnwrapReceivedMessage(wrappedMsg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. rm2 sends msg2 (will include msg1 in causal history)
|
||||
payload2 := []byte("sent test 2")
|
||||
msgID2 := MessageID("cb-sent-2")
|
||||
wrappedMsg2, err := rm2.WrapOutgoingMessage(payload2, msgID2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 4. rm1 receives msg2 (should trigger ACK for msg1)
|
||||
wg.Add(1) // Expect OnMessageSent for msg1 on rm1
|
||||
_, err = rm1.UnwrapReceivedMessage(wrappedMsg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verification
|
||||
waitTimeout(&wg, 2*time.Second, t)
|
||||
|
||||
cbMutex.Lock()
|
||||
defer cbMutex.Unlock()
|
||||
if !sentCalled {
|
||||
t.Errorf("OnMessageSent was not called")
|
||||
}
|
||||
// We primarily care that msg1 was ACKed.
|
||||
if sentMsgID != msgID1 {
|
||||
t.Errorf("OnMessageSent called with wrong ID: got %q, want %q", sentMsgID, msgID1)
|
||||
}
|
||||
}
|
||||
|
||||
// Test OnMissingDependencies callback
|
||||
func TestCallback_OnMissingDependencies(t *testing.T) {
|
||||
channelID := "test-cb-missing"
|
||||
|
||||
// Use separate sender/receiver RMs explicitly
|
||||
senderRm, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
defer senderRm.Cleanup()
|
||||
|
||||
receiverRm, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
defer receiverRm.Cleanup()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
missingCalled := false
|
||||
var missingMsgID MessageID
|
||||
var missingDepsList []MessageID
|
||||
var cbMutex sync.Mutex
|
||||
|
||||
callbacks := EventCallbacks{
|
||||
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
|
||||
cbMutex.Lock()
|
||||
missingCalled = true
|
||||
missingMsgID = messageId
|
||||
missingDepsList = missingDeps // Copy slice
|
||||
cbMutex.Unlock()
|
||||
wg.Done()
|
||||
},
|
||||
}
|
||||
|
||||
// Register callback only on the receiver rm
|
||||
receiverRm.RegisterCallbacks(callbacks)
|
||||
|
||||
// Scenario: Sender sends msg1, then sender sends msg2 (depends on msg1),
|
||||
// then receiver receives msg2 (which hasn't seen msg1).
|
||||
|
||||
// 1. Sender sends msg1
|
||||
payload1 := []byte("missing test 1")
|
||||
msgID1 := MessageID("cb-miss-1")
|
||||
_, err = senderRm.WrapOutgoingMessage(payload1, msgID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Sender sends msg2 (depends on msg1)
|
||||
payload2 := []byte("missing test 2")
|
||||
msgID2 := MessageID("cb-miss-2")
|
||||
wrappedMsg2, err := senderRm.WrapOutgoingMessage(payload2, msgID2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. Receiver receives msg2 (haven't seen msg1)
|
||||
wg.Add(1) // Expect OnMissingDependencies
|
||||
_, err = receiverRm.UnwrapReceivedMessage(wrappedMsg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verification
|
||||
waitTimeout(&wg, 2*time.Second, t)
|
||||
|
||||
cbMutex.Lock()
|
||||
defer cbMutex.Unlock()
|
||||
if !missingCalled {
|
||||
t.Errorf("OnMissingDependencies was not called")
|
||||
}
|
||||
if missingMsgID != msgID2 {
|
||||
t.Errorf("OnMissingDependencies called for wrong ID: got %q, want %q", missingMsgID, msgID2)
|
||||
}
|
||||
foundDep := false
|
||||
for _, dep := range missingDepsList {
|
||||
if dep == msgID1 {
|
||||
foundDep = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundDep {
|
||||
t.Errorf("OnMissingDependencies did not report %q as missing, got: %v", msgID1, missingDepsList)
|
||||
}
|
||||
}
|
||||
|
||||
// Test OnPeriodicSync callback
|
||||
func TestCallback_OnPeriodicSync(t *testing.T) {
|
||||
channelID := "test-cb-sync"
|
||||
rm, err := NewReliabilityManager(channelID)
|
||||
require.NoError(t, err)
|
||||
defer rm.Cleanup()
|
||||
|
||||
syncCalled := false
|
||||
var cbMutex sync.Mutex
|
||||
// Use a channel to signal when the callback is hit
|
||||
syncChan := make(chan bool, 1)
|
||||
|
||||
callbacks := EventCallbacks{
|
||||
OnPeriodicSync: func() {
|
||||
cbMutex.Lock()
|
||||
if !syncCalled { // Only signal the first time
|
||||
syncCalled = true
|
||||
syncChan <- true
|
||||
}
|
||||
cbMutex.Unlock()
|
||||
},
|
||||
}
|
||||
|
||||
rm.RegisterCallbacks(callbacks)
|
||||
|
||||
// Start periodic tasks
|
||||
err = rm.StartPeriodicTasks()
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- Verification ---
|
||||
// Wait for the periodic sync callback with a timeout (needs to be longer than sync interval)
|
||||
select {
|
||||
case <-syncChan:
|
||||
// Success
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Errorf("Timed out waiting for OnPeriodicSync callback")
|
||||
}
|
||||
|
||||
cbMutex.Lock()
|
||||
defer cbMutex.Unlock()
|
||||
if !syncCalled {
|
||||
// This might happen if the timeout was too short
|
||||
t.Logf("Warning: OnPeriodicSync might not have been called within the test timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// Combined Test for multiple callbacks.
//
// Exercises three callbacks across three managers on one channel:
//   - receiverRm receives msg1           -> OnMessageReady(msg1) on receiver
//   - senderRm receives msg2 (acks msg1) -> OnMessageSent(msg1) on sender
//   - receiverRm2 receives msg3 cold     -> OnMissingDependencies reports msg1+msg2
func TestCallbacks_Combined(t *testing.T) {
	channelID := "test-cb-combined"

	// Create sender and receiver RMs
	senderRm, err := NewReliabilityManager(channelID)
	require.NoError(t, err)
	defer senderRm.Cleanup()

	receiverRm, err := NewReliabilityManager(channelID)
	require.NoError(t, err)
	defer receiverRm.Cleanup()

	// Channels for synchronization (buffered so callback sends never block)
	readyChan1 := make(chan bool, 1)
	sentChan1 := make(chan bool, 1)
	missingChan := make(chan []MessageID, 1)

	// Use maps for verification; cbMutex guards both maps against
	// concurrent callback invocations.
	receivedReady := make(map[MessageID]bool)
	receivedSent := make(map[MessageID]bool)
	var cbMutex sync.Mutex

	callbacksReceiver := EventCallbacks{
		OnMessageReady: func(messageId MessageID) {
			cbMutex.Lock()
			receivedReady[messageId] = true
			cbMutex.Unlock()
			if messageId == "cb-comb-1" {
				// Use non-blocking send
				select {
				case readyChan1 <- true:
				default:
				}
			}
		},
		OnMessageSent: func(messageId MessageID) {
			// This callback is registered on Receiver, but Sent events
			// are typically relevant to the Sender. We don't expect this.
			t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId)
		},
		OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
			// This callback is registered on Receiver, used for receiverRm2 below
		},
	}

	callbacksSender := EventCallbacks{
		OnMessageReady: func(messageId MessageID) {
			// Not expected on sender in this test flow
		},
		OnMessageSent: func(messageId MessageID) {
			cbMutex.Lock()
			receivedSent[messageId] = true
			cbMutex.Unlock()
			if messageId == "cb-comb-1" {
				select {
				case sentChan1 <- true:
				default:
				}
			}
		},
		OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
			// Not expected on sender
		},
	}

	// Register callbacks
	receiverRm.RegisterCallbacks(callbacksReceiver)
	senderRm.RegisterCallbacks(callbacksSender)

	// --- Test Scenario ---
	msgID1 := MessageID("cb-comb-1")
	msgID2 := MessageID("cb-comb-2")
	msgID3 := MessageID("cb-comb-3")
	payload1 := []byte("combined test 1")
	payload2 := []byte("combined test 2")
	payload3 := []byte("combined test 3")

	// 1. Sender sends msg1
	wrappedMsg1, err := senderRm.WrapOutgoingMessage(payload1, msgID1)
	require.NoError(t, err)

	// 2. Receiver receives msg1
	_, err = receiverRm.UnwrapReceivedMessage(wrappedMsg1)
	require.NoError(t, err)

	// 3. Receiver sends msg2 (depends on msg1 implicitly via state)
	wrappedMsg2, err := receiverRm.WrapOutgoingMessage(payload2, msgID2)
	require.NoError(t, err)

	// 4. Sender receives msg2 from Receiver (acks msg1 for sender)
	_, err = senderRm.UnwrapReceivedMessage(wrappedMsg2)
	require.NoError(t, err)

	// 5. Sender sends msg3 (depends on msg2)
	wrappedMsg3, err := senderRm.WrapOutgoingMessage(payload3, msgID3)
	require.NoError(t, err)

	// 6. Create Receiver2, register missing deps callback.
	// Receiver2 has seen nothing on this channel, so msg3 arrives with
	// unmet causal dependencies.
	receiverRm2, err := NewReliabilityManager(channelID)
	require.NoError(t, err)
	defer receiverRm2.Cleanup()

	callbacksReceiver2 := EventCallbacks{
		OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
			if messageId == msgID3 {
				select {
				case missingChan <- missingDeps:
				default:
				}
			}
		},
	}

	receiverRm2.RegisterCallbacks(callbacksReceiver2)

	// 7. Receiver2 receives msg3 (should report missing msg1, msg2)
	_, err = receiverRm2.UnwrapReceivedMessage(wrappedMsg3)
	require.NoError(t, err)

	// --- Verification ---
	// Collect all three expected signals under a single shared timeout.
	timeout := 5 * time.Second
	expectedReady1 := false
	expectedSent1 := false
	var reportedMissingDeps []MessageID
	missingDepsReceived := false

	receivedCount := 0
	expectedCount := 3 // ready1, sent1, missingDeps
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	for receivedCount < expectedCount {
		select {
		case <-readyChan1:
			if !expectedReady1 { // Avoid double counting if signaled twice
				expectedReady1 = true
				receivedCount++
			}
		case <-sentChan1:
			if !expectedSent1 {
				expectedSent1 = true
				receivedCount++
			}
		case deps := <-missingChan:
			if !missingDepsReceived {
				reportedMissingDeps = deps
				missingDepsReceived = true
				receivedCount++
			}
		case <-timer.C:
			t.Fatalf("Timed out waiting for combined callbacks (received %d out of %d)", receivedCount, expectedCount)
		}
	}

	// Check results
	cbMutex.Lock()
	defer cbMutex.Unlock()

	if !expectedReady1 || !receivedReady[msgID1] {
		t.Errorf("OnMessageReady not called/verified for %s", msgID1)
	}
	if !expectedSent1 || !receivedSent[msgID1] {
		t.Errorf("OnMessageSent not called/verified for %s", msgID1)
	}
	if !missingDepsReceived {
		t.Errorf("OnMissingDependencies not called/verified for %s", msgID3)
	} else {
		foundDep1 := false
		foundDep2 := false
		for _, dep := range reportedMissingDeps {
			if dep == msgID1 {
				foundDep1 = true
			}
			if dep == msgID2 {
				foundDep2 = true
			}
		}
		if !foundDep1 || !foundDep2 {
			t.Errorf("OnMissingDependencies for %s reported wrong deps: got %v, want %s and %s", msgID3, reportedMissingDeps, msgID1, msgID2)
		}
	}
}
|
||||
|
||||
// Helper function to wait for WaitGroup with a timeout
|
||||
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration, t *testing.T) {
|
||||
c := make(chan struct{})
|
||||
go func() {
|
||||
defer close(c)
|
||||
wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case <-c:
|
||||
// Completed normally
|
||||
case <-time.After(timeout):
|
||||
t.Fatalf("Timed out waiting for callbacks")
|
||||
}
|
||||
}
|
||||
8
sds/types.go
Normal file
8
sds/types.go
Normal file
@ -0,0 +1,8 @@
|
||||
// Package sds provides Go bindings for the SDS reliability library.
package sds

// MessageID identifies a message within the SDS reliability protocol.
type MessageID string

// UnwrappedMessage is the JSON-decoded result of unwrapping a received
// message. Pointer fields distinguish "absent in JSON" (nil) from
// "present but empty".
type UnwrappedMessage struct {
	Message     *[]byte      `json:"message"`     // application payload, if ready
	MissingDeps *[]MessageID `json:"missingDeps"` // unmet causal dependencies, if any
}
|
||||
Loading…
x
Reference in New Issue
Block a user