diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..eea2e3e --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work + +# Generated dependencies and cache +third_party +nimcache \ No newline at end of file diff --git a/README.md b/README.md index c3d6c7a..4178753 100644 --- a/README.md +++ b/README.md @@ -1 +1,45 @@ -# sds-go-bindings \ No newline at end of file +# SDS Go Bindings + +This repository provides Go bindings for the SDS library, enabling seamless integration with Go projects. + +## Installation + +To build the required dependencies for this module, the `make` command needs to be executed. If you are integrating this module into another project via `go get`, ensure that you navigate to the `sds-go-bindings/sds` directory and run `make`. + +### Steps to Install + +Follow these steps to install and set up the module: + +1. Retrieve the module using `go get`: + ``` + go get -u github.com/waku-org/sds-go-bindings + ``` +2. Navigate to the module's directory: + ``` + cd $(go list -m -f '{{.Dir}}' github.com/waku-org/sds-go-bindings) + ``` +3. Prepare third_party directory which will clone `nim-sds` + ``` + sudo mkdir third_party + sudo chown $USER third_party + ``` +4. Build the dependencies: + ``` + make -C sds + ``` + +Now the module is ready for use in your project. + +### Note + +In order to easily build the libsds library on demand, it is recommended to add the following target in your project's Makefile: + +``` +LIBSDS_DEP_PATH=$(shell go list -m -f '{{.Dir}}' github.com/waku-org/sds-go-bindings) + +buildlib: + cd $(LIBSDS_DEP_PATH) &&\ + sudo mkdir -p third_party &&\ + sudo chown $(USER) third_party &&\ + make -C sds +``` diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9c1a6e2 --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module github.com/waku-org/sds-go-bindings + +go 1.22.10 + +require ( + github.com/stretchr/testify v1.8.1 + go.uber.org/zap v1.27.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ea719a5 --- /dev/null +++ b/go.sum @@ -0,0 +1,20 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sds/Makefile b/sds/Makefile new file mode 100644 index 0000000..edd1bf8 --- /dev/null +++ b/sds/Makefile @@ -0,0 +1,44 @@ +# Makefile for SDS Go Bindings + +# Directories +THIRD_PARTY_DIR := ../third_party +NIM_SDS_REPO := https://github.com/waku-org/nim-sds +NIM_SDS_DIR := $(THIRD_PARTY_DIR)/nim-sds + +.PHONY: all clean prepare build-libsds build + +# Default target +all: build + +# Prepare third_party directory and clone nim-sds +# TODO: remove the "git checkout gabrielmer-feat-init-implementation" part +prepare: + @echo "Creating third_party directory..." + @mkdir -p $(THIRD_PARTY_DIR) + + @echo "Cloning nim-sds repository..." + @if [ ! -d "$(NIM_SDS_DIR)" ]; then \ + cd $(THIRD_PARTY_DIR) && \ + git clone $(NIM_SDS_REPO) && \ + cd $(NIM_SDS_DIR) && \ + git checkout gabrielmer-feat-init-implementation; \ + make update; \ + else \ + echo "nim-sds repository already exists."; \ + fi + +# Build libsds +build-libsds: prepare + @echo "Building libsds..." + @cd $(NIM_SDS_DIR) && make libsds + +# Build SDS Go Bindings +build: build-libsds + @echo "Building SDS Go Bindings..." + go build ./... + +# Clean up generated files +clean: + @echo "Cleaning up..." + @rm -rf $(THIRD_PARTY_DIR) + @rm -f sds-go-bindings \ No newline at end of file diff --git a/sds/logging.go b/sds/logging.go new file mode 100644 index 0000000..71e7295 --- /dev/null +++ b/sds/logging.go @@ -0,0 +1,47 @@ +package sds + +import ( + "sync" + + "go.uber.org/zap" +) + +var ( + once sync.Once + sugar *zap.SugaredLogger +) + +func _getLogger() *zap.SugaredLogger { + once.Do(func() { + + config := zap.NewDevelopmentConfig() + l, err := config.Build() + if err != nil { + panic(err) + } + sugar = l.Sugar() + }) + return sugar +} + +func SetLogger(newLogger *zap.Logger) { + once.Do(func() {}) + + sugar = newLogger.Sugar() +} + +func Debug(msg string, args ...interface{}) { + _getLogger().Debugf(msg, args...) +} + +func Info(msg string, args ...interface{}) { + _getLogger().Infof(msg, args...) +} + +func Warn(msg string, args ...interface{}) { + _getLogger().Warnf(msg, args...) +} + +func Error(msg string, args ...interface{}) { + _getLogger().Errorf(msg, args...) +} diff --git a/sds/sds.go b/sds/sds.go new file mode 100644 index 0000000..fec7caf --- /dev/null +++ b/sds/sds.go @@ -0,0 +1,556 @@ +package sds + +/* + #cgo LDFLAGS: -L../third_party/nim-sds/build/ -lsds + #cgo LDFLAGS: -L../third_party/nim-sds -Wl,-rpath,../third_party/nim-sds/build/ + + #include "../third_party/nim-sds/library/libsds.h" + #include + #include + + extern void globalEventCallback(int ret, char* msg, size_t len, void* userData); + + typedef struct { + int ret; + char* msg; + size_t len; + void* ffiWg; + } Resp; + + static void* allocResp(void* wg) { + Resp* r = calloc(1, sizeof(Resp)); + r->ffiWg = wg; + return r; + } + + static void freeResp(void* resp) { + if (resp != NULL) { + free(resp); + } + } + + static char* getMyCharPtr(void* resp) { + if (resp == NULL) { + return NULL; + } + Resp* m = (Resp*) resp; + return m->msg; + } + + static size_t getMyCharLen(void* resp) { + if (resp == NULL) { + return 0; + } + Resp* m = (Resp*) resp; + return m->len; + } + + static int getRet(void* resp) { + if (resp == NULL) { + return 0; + } + Resp* m = (Resp*) resp; + return m->ret; + } + + // resp must be set != NULL in case interest on retrieving data from the callback + void GoCallback(int ret, char* msg, size_t len, void* resp); + + static void* cGoNewReliabilityManager(const char* channelId, void* resp) { + // We pass NULL because we are not interested in retrieving data from this callback + void* ret = NewReliabilityManager(channelId, (SdsCallBack) GoCallback, resp); + return ret; + } + + static void cGoSetEventCallback(void* rmCtx) { + // The 'globalEventCallback' Go function is shared amongst all possible Reliability Manager instances. + + // Given that the 'globalEventCallback' is shared, we pass again the + // rmCtx instance but in this case is needed to pick up the correct method + // that will handle the event. + + // In other words, for every call libsds makes to globalEventCallback, + // the 'userData' parameter will bring the context of the rm that registered + // that globalEventCallback. + + // This technique is needed because cgo only allows to export Go functions and not methods. + + SetEventCallback(rmCtx, (SdsCallBack) globalEventCallback, rmCtx); + } + + static void cGoCleanupReliabilityManager(void* rmCtx, void* resp) { + CleanupReliabilityManager(rmCtx, (SdsCallBack) GoCallback, resp); + } + + static void cGoResetReliabilityManager(void* rmCtx, void* resp) { + ResetReliabilityManager(rmCtx, (SdsCallBack) GoCallback, resp); + } + + static void cGoWrapOutgoingMessage(void* rmCtx, + void* message, + size_t messageLen, + const char* messageId, + void* resp) { + WrapOutgoingMessage(rmCtx, + message, + messageLen, + messageId, + (SdsCallBack) GoCallback, + resp); + } + static void cGoUnwrapReceivedMessage(void* rmCtx, + void* message, + size_t messageLen, + void* resp) { + UnwrapReceivedMessage(rmCtx, + message, + messageLen, + (SdsCallBack) GoCallback, + resp); + } + + static void cGoMarkDependenciesMet(void* rmCtx, + char** messageIDs, + size_t count, + void* resp) { + MarkDependenciesMet(rmCtx, + messageIDs, + count, + (SdsCallBack) GoCallback, + resp); + } + + static void cGoStartPeriodicTasks(void* rmCtx, void* resp) { + StartPeriodicTasks(rmCtx, (SdsCallBack) GoCallback, resp); + } + +*/ +import "C" +import ( + "encoding/json" + "errors" + "strconv" + "strings" + "sync" + "time" + "unsafe" +) + +const requestTimeout = 30 * time.Second +const EventChanBufferSize = 1024 + +//export GoCallback +func GoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) { + if resp != nil { + m := (*C.Resp)(resp) + m.ret = ret + m.msg = msg + m.len = len + wg := (*sync.WaitGroup)(m.ffiWg) + wg.Done() + } +} + +type EventCallbacks struct { + OnMessageReady func(messageId MessageID) + OnMessageSent func(messageId MessageID) + OnMissingDependencies func(messageId MessageID, missingDeps []MessageID) + OnPeriodicSync func() +} + +// ReliabilityManager represents an instance of a nim-sds ReliabilityManager +type ReliabilityManager struct { + rmCtx unsafe.Pointer + channelId string + callbacks EventCallbacks +} + +func NewReliabilityManager(channelId string) (*ReliabilityManager, error) { + Debug("Creating new Reliability Manager") + rm := &ReliabilityManager{ + channelId: channelId, + } + + wg := sync.WaitGroup{} + + var cChannelId = C.CString(string(channelId)) + var resp = C.allocResp(unsafe.Pointer(&wg)) + + defer C.free(unsafe.Pointer(cChannelId)) + defer C.freeResp(resp) + + if C.getRet(resp) != C.RET_OK { + errMsg := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp))) + Error("error NewReliabilityManager: %v", errMsg) + return nil, errors.New(errMsg) + } + + wg.Add(1) + rm.rmCtx = C.cGoNewReliabilityManager(cChannelId, resp) + wg.Wait() + + C.cGoSetEventCallback(rm.rmCtx) + registerReliabilityManager(rm) + + Debug("Successfully created Reliability Manager") + return rm, nil +} + +// The event callback sends back the rm ctx to know to which +// rm is the event being emited for. Since we only have a global +// callback in the go side, We register all the rm's that we create +// so we can later obtain which instance of `ReliabilityManager` it should +// be invoked depending on the ctx received + +var rmRegistry map[unsafe.Pointer]*ReliabilityManager + +func init() { + rmRegistry = make(map[unsafe.Pointer]*ReliabilityManager) +} + +func registerReliabilityManager(rm *ReliabilityManager) { + _, ok := rmRegistry[rm.rmCtx] + if !ok { + rmRegistry[rm.rmCtx] = rm + } +} + +func unregisterReliabilityManager(rm *ReliabilityManager) { + delete(rmRegistry, rm.rmCtx) +} + +//export globalEventCallback +func globalEventCallback(callerRet C.int, msg *C.char, len C.size_t, userData unsafe.Pointer) { + if callerRet == C.RET_OK { + eventStr := C.GoStringN(msg, C.int(len)) + rm, ok := rmRegistry[userData] // userData contains rm's ctx + if ok { + rm.OnEvent(eventStr) + } + } else { + if len != 0 { + errMsg := C.GoStringN(msg, C.int(len)) + Error("globalEventCallback retCode not ok, retCode: %v: %v", callerRet, errMsg) + } else { + Error("globalEventCallback retCode not ok, retCode: %v", callerRet) + } + } +} + +type jsonEvent struct { + EventType string `json:"eventType"` +} + +type msgEvent struct { + MessageId MessageID `json:"messageId"` +} + +type missingDepsEvent struct { + MessageId MessageID `json:"messageId"` + MissingDeps []MessageID `json:"missingDeps"` +} + +func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) { + rm.callbacks = callbacks +} + +func (rm *ReliabilityManager) OnEvent(eventStr string) { + + jsonEvent := jsonEvent{} + err := json.Unmarshal([]byte(eventStr), &jsonEvent) + if err != nil { + Error("could not unmarshal sds event string: %v", err) + + return + } + + switch jsonEvent.EventType { + case "message_ready": + rm.parseMessageReadyEvent(eventStr) + case "message_sent": + rm.parseMessageSentEvent(eventStr) + case "missing_dependencies": + rm.parseMissingDepsEvent(eventStr) + case "periodic_sync": + if rm.callbacks.OnPeriodicSync != nil { + rm.callbacks.OnPeriodicSync() + } + } + +} + +func (rm *ReliabilityManager) parseMessageReadyEvent(eventStr string) { + + msgEvent := msgEvent{} + err := json.Unmarshal([]byte(eventStr), &msgEvent) + if err != nil { + Error("could not parse message ready event %v", err) + } + + if rm.callbacks.OnMessageReady != nil { + rm.callbacks.OnMessageReady(msgEvent.MessageId) + } +} + +func (rm *ReliabilityManager) parseMessageSentEvent(eventStr string) { + + msgEvent := msgEvent{} + err := json.Unmarshal([]byte(eventStr), &msgEvent) + if err != nil { + Error("could not parse message sent event %v", err) + } + + if rm.callbacks.OnMessageSent != nil { + rm.callbacks.OnMessageSent(msgEvent.MessageId) + } +} + +func (rm *ReliabilityManager) parseMissingDepsEvent(eventStr string) { + + missingDepsEvent := missingDepsEvent{} + err := json.Unmarshal([]byte(eventStr), &missingDepsEvent) + if err != nil { + Error("could not parse missing dependencies event %v", err) + } + + if rm.callbacks.OnMissingDependencies != nil { + rm.callbacks.OnMissingDependencies(missingDepsEvent.MessageId, missingDepsEvent.MissingDeps) + } +} + +func (rm *ReliabilityManager) Cleanup() error { + if rm == nil { + err := errors.New("reliability manager is nil in Cleanup") + Error("Failed to cleanup %v", err) + return err + } + + Debug("Cleaning up reliability manager") + + wg := sync.WaitGroup{} + var resp = C.allocResp(unsafe.Pointer(&wg)) + defer C.freeResp(resp) + + wg.Add(1) + C.cGoCleanupReliabilityManager(rm.rmCtx, resp) + wg.Wait() + + if C.getRet(resp) == C.RET_OK { + unregisterReliabilityManager(rm) + Debug("Successfully cleaned up reliability manager") + return nil + } + + errMsg := "error CleanupReliabilityManager: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp))) + Error("Failed to cleanup reliability manager: %v", errMsg) + + return errors.New(errMsg) +} + +func (rm *ReliabilityManager) Reset() error { + if rm == nil { + err := errors.New("reliability manager is nil in Reset") + Error("Failed to reset %v", err) + return err + } + + Debug("Resetting reliability manager") + + wg := sync.WaitGroup{} + var resp = C.allocResp(unsafe.Pointer(&wg)) + defer C.freeResp(resp) + + wg.Add(1) + C.cGoResetReliabilityManager(rm.rmCtx, resp) + wg.Wait() + + if C.getRet(resp) == C.RET_OK { + Debug("Successfully resetted reliability manager") + return nil + } + + errMsg := "error ResetReliabilityManager: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp))) + Error("Failed to reset reliability manager: %v", errMsg) + + return errors.New(errMsg) +} + +func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId MessageID) ([]byte, error) { + if rm == nil { + err := errors.New("reliability manager is nil in WrapOutgoingMessage") + Error("Failed to wrap outgoing message %v", err) + return nil, err + } + + Debug("Wrapping outgoing message %v", messageId) + + wg := sync.WaitGroup{} + var resp = C.allocResp(unsafe.Pointer(&wg)) + defer C.freeResp(resp) + + cMessageId := C.CString(string(messageId)) + defer C.free(unsafe.Pointer(cMessageId)) + + var cMessagePtr unsafe.Pointer + if len(message) > 0 { + cMessagePtr = C.CBytes(message) // C.CBytes allocates memory that needs to be freed + defer C.free(cMessagePtr) + } else { + cMessagePtr = nil + } + cMessageLen := C.size_t(len(message)) + + wg.Add(1) + C.cGoWrapOutgoingMessage(rm.rmCtx, cMessagePtr, cMessageLen, cMessageId, resp) + wg.Wait() + + if C.getRet(resp) == C.RET_OK { + resStr := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp))) + if resStr == "" { + Debug("Received empty res string for messageId: %v", messageId) + return nil, nil + } + Debug("Successfully wrapped message %s", messageId) + + parts := strings.Split(resStr, ",") + bytes := make([]byte, len(parts)) + + for i, part := range parts { + n, err := strconv.Atoi(strings.TrimSpace(part)) + if err != nil { + panic(err) + } + bytes[i] = byte(n) + } + + return bytes, nil + } + + errMsg := "error WrapOutgoingMessage: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp))) + Error("Failed to wrap message %v: %v", messageId, errMsg) + + return nil, errors.New(errMsg) +} + +func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedMessage, error) { + if rm == nil { + err := errors.New("reliability manager is nil in UnwrapReceivedMessage") + Error("Failed to unwrap received message %v", err) + return nil, err + } + + wg := sync.WaitGroup{} + var resp = C.allocResp(unsafe.Pointer(&wg)) + defer C.freeResp(resp) + + var cMessagePtr unsafe.Pointer + if len(message) > 0 { + cMessagePtr = C.CBytes(message) // C.CBytes allocates memory that needs to be freed + defer C.free(cMessagePtr) + } else { + cMessagePtr = nil + } + cMessageLen := C.size_t(len(message)) + + wg.Add(1) + C.cGoUnwrapReceivedMessage(rm.rmCtx, cMessagePtr, cMessageLen, resp) + wg.Wait() + + if C.getRet(resp) == C.RET_OK { + resStr := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp))) + if resStr == "" { + Debug("Received empty res string") + return nil, nil + } + Debug("Successfully unwrapped message") + + unwrappedMessage := UnwrappedMessage{} + err := json.Unmarshal([]byte(resStr), &unwrappedMessage) + if err != nil { + Error("Failed to unmarshal unwrapped message") + return nil, err + } + + return &unwrappedMessage, nil + } + + errMsg := "error UnwrapReceivedMessage: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp))) + Error("Failed to unwrap message: %v", errMsg) + + return nil, errors.New(errMsg) +} + +// MarkDependenciesMet informs the library that dependencies are met +func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID) error { + if rm == nil { + err := errors.New("reliability manager is nil in MarkDependenciesMet") + Error("Failed to mark dependencies met %v", err) + return err + } + + if len(messageIDs) == 0 { + return nil // Nothing to do + } + + wg := sync.WaitGroup{} + var resp = C.allocResp(unsafe.Pointer(&wg)) + defer C.freeResp(resp) + + // Convert Go string slice to C array of C strings (char**) + cMessageIDs := make([]*C.char, len(messageIDs)) + for i, id := range messageIDs { + cMessageIDs[i] = C.CString(string(id)) + defer C.free(unsafe.Pointer(cMessageIDs[i])) // Ensure each CString is freed + } + + // Create a pointer (**C.char) to the first element of the slice + var cMessageIDsPtr **C.char + if len(cMessageIDs) > 0 { + cMessageIDsPtr = &cMessageIDs[0] + } else { + cMessageIDsPtr = nil // Handle empty slice case + } + + wg.Add(1) + // Pass the pointer variable (cMessageIDsPtr) directly, which is of type **C.char + C.cGoMarkDependenciesMet(rm.rmCtx, cMessageIDsPtr, C.size_t(len(messageIDs)), resp) + wg.Wait() + + if C.getRet(resp) == C.RET_OK { + Debug("Successfully marked dependencies as met") + return nil + } + + errMsg := "error MarkDependenciesMet: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp))) + Error("Failed to mark dependencies as met: %v", errMsg) + + return errors.New(errMsg) +} + +func (rm *ReliabilityManager) StartPeriodicTasks() error { + if rm == nil { + err := errors.New("reliability manager is nil in StartPeriodicTasks") + Error("Failed to start periodic tasks %v", err) + return err + } + + Debug("Starting periodic tasks") + + wg := sync.WaitGroup{} + var resp = C.allocResp(unsafe.Pointer(&wg)) + defer C.freeResp(resp) + + wg.Add(1) + C.cGoStartPeriodicTasks(rm.rmCtx, resp) + wg.Wait() + + if C.getRet(resp) == C.RET_OK { + Debug("Successfully started periodic tasks") + return nil + } + + errMsg := "error StartPeriodicTasks: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp))) + Error("Failed to start periodic tasks: %v", errMsg) + + return errors.New(errMsg) +} diff --git a/sds/sds_test.go b/sds/sds_test.go new file mode 100644 index 0000000..36e37f2 --- /dev/null +++ b/sds/sds_test.go @@ -0,0 +1,539 @@ +package sds + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// Test basic creation, cleanup, and reset +func TestLifecycle(t *testing.T) { + channelID := "test-lifecycle" + rm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + require.NotNil(t, rm, "Expected ReliabilityManager to be not nil") + + defer rm.Cleanup() // Ensure cleanup even on test failure + + err = rm.Reset() + require.NoError(t, err) +} + +// Test wrapping and unwrapping a simple message +func TestWrapUnwrap(t *testing.T) { + channelID := "test-wrap-unwrap" + rm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer rm.Cleanup() + + originalPayload := []byte("hello reliability") + messageID := MessageID("msg-wrap-1") + + wrappedMsg, err := rm.WrapOutgoingMessage(originalPayload, messageID) + require.NoError(t, err) + + require.Greater(t, len(wrappedMsg), 0, "Expected non-empty wrapped message") + + // Simulate receiving the wrapped message + unwrappedMessage, err := rm.UnwrapReceivedMessage(wrappedMsg) + require.NoError(t, err) + + require.Equal(t, string(*unwrappedMessage.Message), string(originalPayload), "Expected unwrapped and original payloads to be equal") + require.Equal(t, len(*unwrappedMessage.MissingDeps), 0, "Expexted to be no missing dependencies") +} + +// Test dependency handling +func TestDependencies(t *testing.T) { + channelID := "test-deps" + rm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer rm.Cleanup() + + // 1. Send message 1 (will become a dependency) + payload1 := []byte("message one") + msgID1 := MessageID("msg-dep-1") + wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1) + require.NoError(t, err) + + // Simulate receiving msg1 to add it to history (implicitly acknowledges it) + _, err = rm.UnwrapReceivedMessage(wrappedMsg1) + require.NoError(t, err) + + // 2. Send message 2 (depends on message 1 implicitly via causal history) + payload2 := []byte("message two") + msgID2 := MessageID("msg-dep-2") + wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2) + require.NoError(t, err) + + // 3. Create a new manager to simulate a different peer receiving msg2 without msg1 + rm2, err := NewReliabilityManager(channelID) // Same channel ID + require.NoError(t, err) + defer rm2.Cleanup() + + // 4. Unwrap message 2 on the second manager - should report msg1 as missing + unwrappedMessage2, err := rm2.UnwrapReceivedMessage(wrappedMsg2) + require.NoError(t, err) + + require.Greater(t, len(*unwrappedMessage2.MissingDeps), 0, "Expected missing dependencies, got none") + + foundDep1 := false + for _, dep := range *unwrappedMessage2.MissingDeps { + if dep == msgID1 { + foundDep1 = true + break + } + } + require.True(t, foundDep1, "Expected missing dependency %q, got %v", msgID1, *unwrappedMessage2.MissingDeps) + + // 5. Mark the dependency as met + err = rm2.MarkDependenciesMet([]MessageID{msgID1}) + require.NoError(t, err) +} + +// Test OnMessageReady callback +func TestCallback_OnMessageReady(t *testing.T) { + channelID := "test-cb-ready" + + // Create sender and receiver RMs + senderRm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer senderRm.Cleanup() + + receiverRm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer receiverRm.Cleanup() + + // Use a channel for signaling + readyChan := make(chan MessageID, 1) + + callbacks := EventCallbacks{ + OnMessageReady: func(messageId MessageID) { + // Non-blocking send to channel + select { + case readyChan <- messageId: + default: + // Avoid blocking if channel is full or test already timed out + } + }, + } + + // Register callback only on the receiver + receiverRm.RegisterCallbacks(callbacks) + + // Scenario: Wrap message on sender, unwrap on receiver + payload := []byte("ready test") + msgID := MessageID("cb-ready-1") + + // Wrap on sender + wrappedMsg, err := senderRm.WrapOutgoingMessage(payload, msgID) + require.NoError(t, err) + + // Unwrap on receiver + _, err = receiverRm.UnwrapReceivedMessage(wrappedMsg) + require.NoError(t, err) + + // Verification - Wait on channel with timeout + select { + case receivedMsgID := <-readyChan: + // Mark as called implicitly since we received on channel + if receivedMsgID != msgID { + t.Errorf("OnMessageReady called with wrong ID: got %q, want %q", receivedMsgID, msgID) + } + case <-time.After(2 * time.Second): + // If timeout occurs, the channel receive failed. + t.Errorf("Timed out waiting for OnMessageReady callback on readyChan") + } +} + +// Test OnMessageSent callback (via causal history ACK) +func TestCallback_OnMessageSent(t *testing.T) { + channelID := "test-cb-sent" + + // Create two RMs + rm1, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer rm1.Cleanup() + + rm2, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer rm2.Cleanup() + + var wg sync.WaitGroup + sentCalled := false + var sentMsgID MessageID + var cbMutex sync.Mutex + + callbacks := EventCallbacks{ + OnMessageSent: func(messageId MessageID) { + cbMutex.Lock() + sentCalled = true + sentMsgID = messageId + cbMutex.Unlock() + wg.Done() + }, + } + + // Register callback on rm1 (the original sender) + rm1.RegisterCallbacks(callbacks) + + // Scenario: rm1 sends msg1, rm2 receives msg1, + // rm2 sends msg2 (acking msg1), rm1 receives msg2. + + // 1. rm1 sends msg1 + payload1 := []byte("sent test 1") + msgID1 := MessageID("cb-sent-1") + wrappedMsg1, err := rm1.WrapOutgoingMessage(payload1, msgID1) + require.NoError(t, err) + // Note: msg1 is now in rm1's outgoing buffer + + // 2. rm2 receives msg1 (to update its state) + _, err = rm2.UnwrapReceivedMessage(wrappedMsg1) + require.NoError(t, err) + + // 3. rm2 sends msg2 (will include msg1 in causal history) + payload2 := []byte("sent test 2") + msgID2 := MessageID("cb-sent-2") + wrappedMsg2, err := rm2.WrapOutgoingMessage(payload2, msgID2) + require.NoError(t, err) + + // 4. rm1 receives msg2 (should trigger ACK for msg1) + wg.Add(1) // Expect OnMessageSent for msg1 on rm1 + _, err = rm1.UnwrapReceivedMessage(wrappedMsg2) + require.NoError(t, err) + + // Verification + waitTimeout(&wg, 2*time.Second, t) + + cbMutex.Lock() + defer cbMutex.Unlock() + if !sentCalled { + t.Errorf("OnMessageSent was not called") + } + // We primarily care that msg1 was ACKed. + if sentMsgID != msgID1 { + t.Errorf("OnMessageSent called with wrong ID: got %q, want %q", sentMsgID, msgID1) + } +} + +// Test OnMissingDependencies callback +func TestCallback_OnMissingDependencies(t *testing.T) { + channelID := "test-cb-missing" + + // Use separate sender/receiver RMs explicitly + senderRm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer senderRm.Cleanup() + + receiverRm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer receiverRm.Cleanup() + + var wg sync.WaitGroup + missingCalled := false + var missingMsgID MessageID + var missingDepsList []MessageID + var cbMutex sync.Mutex + + callbacks := EventCallbacks{ + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + cbMutex.Lock() + missingCalled = true + missingMsgID = messageId + missingDepsList = missingDeps // Copy slice + cbMutex.Unlock() + wg.Done() + }, + } + + // Register callback only on the receiver rm + receiverRm.RegisterCallbacks(callbacks) + + // Scenario: Sender sends msg1, then sender sends msg2 (depends on msg1), + // then receiver receives msg2 (which hasn't seen msg1). + + // 1. Sender sends msg1 + payload1 := []byte("missing test 1") + msgID1 := MessageID("cb-miss-1") + _, err = senderRm.WrapOutgoingMessage(payload1, msgID1) + require.NoError(t, err) + + // 2. Sender sends msg2 (depends on msg1) + payload2 := []byte("missing test 2") + msgID2 := MessageID("cb-miss-2") + wrappedMsg2, err := senderRm.WrapOutgoingMessage(payload2, msgID2) + require.NoError(t, err) + + // 3. Receiver receives msg2 (haven't seen msg1) + wg.Add(1) // Expect OnMissingDependencies + _, err = receiverRm.UnwrapReceivedMessage(wrappedMsg2) + require.NoError(t, err) + + // Verification + waitTimeout(&wg, 2*time.Second, t) + + cbMutex.Lock() + defer cbMutex.Unlock() + if !missingCalled { + t.Errorf("OnMissingDependencies was not called") + } + if missingMsgID != msgID2 { + t.Errorf("OnMissingDependencies called for wrong ID: got %q, want %q", missingMsgID, msgID2) + } + foundDep := false + for _, dep := range missingDepsList { + if dep == msgID1 { + foundDep = true + break + } + } + if !foundDep { + t.Errorf("OnMissingDependencies did not report %q as missing, got: %v", msgID1, missingDepsList) + } +} + +// Test OnPeriodicSync callback +func TestCallback_OnPeriodicSync(t *testing.T) { + channelID := "test-cb-sync" + rm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer rm.Cleanup() + + syncCalled := false + var cbMutex sync.Mutex + // Use a channel to signal when the callback is hit + syncChan := make(chan bool, 1) + + callbacks := EventCallbacks{ + OnPeriodicSync: func() { + cbMutex.Lock() + if !syncCalled { // Only signal the first time + syncCalled = true + syncChan <- true + } + cbMutex.Unlock() + }, + } + + rm.RegisterCallbacks(callbacks) + + // Start periodic tasks + err = rm.StartPeriodicTasks() + require.NoError(t, err) + + // --- Verification --- + // Wait for the periodic sync callback with a timeout (needs to be longer than sync interval) + select { + case <-syncChan: + // Success + case <-time.After(10 * time.Second): + t.Errorf("Timed out waiting for OnPeriodicSync callback") + } + + cbMutex.Lock() + defer cbMutex.Unlock() + if !syncCalled { + // This might happen if the timeout was too short + t.Logf("Warning: OnPeriodicSync might not have been called within the test timeout") + } +} + +// Combined Test for multiple callbacks +func TestCallbacks_Combined(t *testing.T) { + channelID := "test-cb-combined" + + // Create sender and receiver RMs + senderRm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer senderRm.Cleanup() + + receiverRm, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer receiverRm.Cleanup() + + // Channels for synchronization + readyChan1 := make(chan bool, 1) + sentChan1 := make(chan bool, 1) + missingChan := make(chan []MessageID, 1) + + // Use maps for verification + receivedReady := make(map[MessageID]bool) + receivedSent := make(map[MessageID]bool) + var cbMutex sync.Mutex + + callbacksReceiver := EventCallbacks{ + OnMessageReady: func(messageId MessageID) { + cbMutex.Lock() + receivedReady[messageId] = true + cbMutex.Unlock() + if messageId == "cb-comb-1" { + // Use non-blocking send + select { + case readyChan1 <- true: + default: + } + } + }, + OnMessageSent: func(messageId MessageID) { + // This callback is registered on Receiver, but Sent events + // are typically relevant to the Sender. We don't expect this. + t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId) + }, + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + // This callback is registered on Receiver, used for receiverRm2 below + }, + } + + callbacksSender := EventCallbacks{ + OnMessageReady: func(messageId MessageID) { + // Not expected on sender in this test flow + }, + OnMessageSent: func(messageId MessageID) { + cbMutex.Lock() + receivedSent[messageId] = true + cbMutex.Unlock() + if messageId == "cb-comb-1" { + select { + case sentChan1 <- true: + default: + } + } + }, + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + // Not expected on sender + }, + } + + // Register callbacks + receiverRm.RegisterCallbacks(callbacksReceiver) + senderRm.RegisterCallbacks(callbacksSender) + + // --- Test Scenario --- + msgID1 := MessageID("cb-comb-1") + msgID2 := MessageID("cb-comb-2") + msgID3 := MessageID("cb-comb-3") + payload1 := []byte("combined test 1") + payload2 := []byte("combined test 2") + payload3 := []byte("combined test 3") + + // 1. Sender sends msg1 + wrappedMsg1, err := senderRm.WrapOutgoingMessage(payload1, msgID1) + require.NoError(t, err) + + // 2. Receiver receives msg1 + _, err = receiverRm.UnwrapReceivedMessage(wrappedMsg1) + require.NoError(t, err) + + // 3. Receiver sends msg2 (depends on msg1 implicitly via state) + wrappedMsg2, err := receiverRm.WrapOutgoingMessage(payload2, msgID2) + require.NoError(t, err) + + // 4. Sender receives msg2 from Receiver (acks msg1 for sender) + _, err = senderRm.UnwrapReceivedMessage(wrappedMsg2) + require.NoError(t, err) + + // 5. Sender sends msg3 (depends on msg2) + wrappedMsg3, err := senderRm.WrapOutgoingMessage(payload3, msgID3) + require.NoError(t, err) + + // 6. Create Receiver2, register missing deps callback + receiverRm2, err := NewReliabilityManager(channelID) + require.NoError(t, err) + defer receiverRm2.Cleanup() + + callbacksReceiver2 := EventCallbacks{ + OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) { + if messageId == msgID3 { + select { + case missingChan <- missingDeps: + default: + } + } + }, + } + + receiverRm2.RegisterCallbacks(callbacksReceiver2) + + // 7. Receiver2 receives msg3 (should report missing msg1, msg2) + _, err = receiverRm2.UnwrapReceivedMessage(wrappedMsg3) + require.NoError(t, err) + + // --- Verification --- + timeout := 5 * time.Second + expectedReady1 := false + expectedSent1 := false + var reportedMissingDeps []MessageID + missingDepsReceived := false + + receivedCount := 0 + expectedCount := 3 // ready1, sent1, missingDeps + timer := time.NewTimer(timeout) + defer timer.Stop() + + for receivedCount < expectedCount { + select { + case <-readyChan1: + if !expectedReady1 { // Avoid double counting if signaled twice + expectedReady1 = true + receivedCount++ + } + case <-sentChan1: + if !expectedSent1 { + expectedSent1 = true + receivedCount++ + } + case deps := <-missingChan: + if !missingDepsReceived { + reportedMissingDeps = deps + missingDepsReceived = true + receivedCount++ + } + case <-timer.C: + t.Fatalf("Timed out waiting for combined callbacks (received %d out of %d)", receivedCount, expectedCount) + } + } + + // Check results + cbMutex.Lock() + defer cbMutex.Unlock() + + if !expectedReady1 || !receivedReady[msgID1] { + t.Errorf("OnMessageReady not called/verified for %s", msgID1) + } + if !expectedSent1 || !receivedSent[msgID1] { + t.Errorf("OnMessageSent not called/verified for %s", msgID1) + } + if !missingDepsReceived { + t.Errorf("OnMissingDependencies not called/verified for %s", msgID3) + } else { + foundDep1 := false + foundDep2 := false + for _, dep := range reportedMissingDeps { + if dep == msgID1 { + foundDep1 = true + } + if dep == msgID2 { + foundDep2 = true + } + } + if !foundDep1 || !foundDep2 { + t.Errorf("OnMissingDependencies for %s reported wrong deps: got %v, want %s and %s", msgID3, reportedMissingDeps, msgID1, msgID2) + } + } +} + +// Helper function to wait for WaitGroup with a timeout +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration, t *testing.T) { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: + // Completed normally + case <-time.After(timeout): + t.Fatalf("Timed out waiting for callbacks") + } +} diff --git a/sds/types.go b/sds/types.go new file mode 100644 index 0000000..cbe3b9b --- /dev/null +++ b/sds/types.go @@ -0,0 +1,8 @@ +package sds + +type MessageID string + +type UnwrappedMessage struct { + Message *[]byte `json:"message"` + MissingDeps *[]MessageID `json:"missingDeps"` +}