Merge pull request #1 from waku-org/feat-initial-setup

feat: initial implementation
This commit is contained in:
gabrielmer 2025-05-28 12:26:54 +02:00 committed by GitHub
commit 78ea8e7f8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1299 additions and 1 deletions

25
.gitignore vendored Normal file
View File

@ -0,0 +1,25 @@
# If you prefer the allow list template instead of the deny list, see community template:
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
#
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
# Go workspace file
go.work
# Generated dependencies and cache
third_party
nimcache

View File

@ -1 +1,45 @@
# sds-go-bindings
# SDS Go Bindings
This repository provides Go bindings for the SDS library, enabling seamless integration with Go projects.
## Installation
To build the required dependencies for this module, the `make` command needs to be executed. If you are integrating this module into another project via `go get`, ensure that you navigate to the `sds-go-bindings/sds` directory and run `make`.
### Steps to Install
Follow these steps to install and set up the module:
1. Retrieve the module using `go get`:
```
go get -u github.com/waku-org/sds-go-bindings
```
2. Navigate to the module's directory:
```
cd $(go list -m -f '{{.Dir}}' github.com/waku-org/sds-go-bindings)
```
3. Prepare third_party directory which will clone `nim-sds`
```
sudo mkdir third_party
sudo chown $USER third_party
```
4. Build the dependencies:
```
make -C sds
```
Now the module is ready for use in your project.
### Note
In order to easily build the libsds library on demand, it is recommended to add the following target in your project's Makefile:
```
LIBSDS_DEP_PATH=$(shell go list -m -f '{{.Dir}}' github.com/waku-org/sds-go-bindings)
buildlib:
cd $(LIBSDS_DEP_PATH) &&\
sudo mkdir -p third_party &&\
sudo chown $(USER) third_party &&\
make -C sds
```

15
go.mod Normal file
View File

@ -0,0 +1,15 @@
module github.com/waku-org/sds-go-bindings
go 1.22.10
require (
github.com/stretchr/testify v1.8.1
go.uber.org/zap v1.27.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

20
go.sum Normal file
View File

@ -0,0 +1,20 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

44
sds/Makefile Normal file
View File

@ -0,0 +1,44 @@
# Makefile for SDS Go Bindings
# Directories
THIRD_PARTY_DIR := ../third_party
NIM_SDS_REPO := https://github.com/waku-org/nim-sds
NIM_SDS_DIR := $(THIRD_PARTY_DIR)/nim-sds
.PHONY: all clean prepare build-libsds build
# Default target
all: build
# Prepare third_party directory and clone nim-sds
# TODO: remove the "git checkout gabrielmer-feat-init-implementation" part
prepare:
@echo "Creating third_party directory..."
@mkdir -p $(THIRD_PARTY_DIR)
@echo "Cloning nim-sds repository..."
@if [ ! -d "$(NIM_SDS_DIR)" ]; then \
cd $(THIRD_PARTY_DIR) && \
git clone $(NIM_SDS_REPO) && \
cd $(NIM_SDS_DIR) && \
git checkout gabrielmer-feat-init-implementation; \
make update; \
else \
echo "nim-sds repository already exists."; \
fi
# Build libsds
build-libsds: prepare
@echo "Building libsds..."
@cd $(NIM_SDS_DIR) && make libsds
# Build SDS Go Bindings
build: build-libsds
@echo "Building SDS Go Bindings..."
go build ./...
# Clean up generated files
clean:
@echo "Cleaning up..."
@rm -rf $(THIRD_PARTY_DIR)
@rm -f sds-go-bindings

47
sds/logging.go Normal file
View File

@ -0,0 +1,47 @@
package sds
import (
"sync"
"go.uber.org/zap"
)
var (
once sync.Once
sugar *zap.SugaredLogger
)
func _getLogger() *zap.SugaredLogger {
once.Do(func() {
config := zap.NewDevelopmentConfig()
l, err := config.Build()
if err != nil {
panic(err)
}
sugar = l.Sugar()
})
return sugar
}
func SetLogger(newLogger *zap.Logger) {
once.Do(func() {})
sugar = newLogger.Sugar()
}
func Debug(msg string, args ...interface{}) {
_getLogger().Debugf(msg, args...)
}
func Info(msg string, args ...interface{}) {
_getLogger().Infof(msg, args...)
}
func Warn(msg string, args ...interface{}) {
_getLogger().Warnf(msg, args...)
}
func Error(msg string, args ...interface{}) {
_getLogger().Errorf(msg, args...)
}

556
sds/sds.go Normal file
View File

@ -0,0 +1,556 @@
package sds
/*
#cgo LDFLAGS: -L../third_party/nim-sds/build/ -lsds
#cgo LDFLAGS: -L../third_party/nim-sds -Wl,-rpath,../third_party/nim-sds/build/
#include "../third_party/nim-sds/library/libsds.h"
#include <stdio.h>
#include <stdlib.h>
extern void globalEventCallback(int ret, char* msg, size_t len, void* userData);
typedef struct {
int ret;
char* msg;
size_t len;
void* ffiWg;
} Resp;
static void* allocResp(void* wg) {
Resp* r = calloc(1, sizeof(Resp));
r->ffiWg = wg;
return r;
}
static void freeResp(void* resp) {
if (resp != NULL) {
free(resp);
}
}
static char* getMyCharPtr(void* resp) {
if (resp == NULL) {
return NULL;
}
Resp* m = (Resp*) resp;
return m->msg;
}
static size_t getMyCharLen(void* resp) {
if (resp == NULL) {
return 0;
}
Resp* m = (Resp*) resp;
return m->len;
}
static int getRet(void* resp) {
if (resp == NULL) {
return 0;
}
Resp* m = (Resp*) resp;
return m->ret;
}
// resp must be set != NULL in case interest on retrieving data from the callback
void GoCallback(int ret, char* msg, size_t len, void* resp);
static void* cGoNewReliabilityManager(const char* channelId, void* resp) {
// We pass NULL because we are not interested in retrieving data from this callback
void* ret = NewReliabilityManager(channelId, (SdsCallBack) GoCallback, resp);
return ret;
}
static void cGoSetEventCallback(void* rmCtx) {
// The 'globalEventCallback' Go function is shared amongst all possible Reliability Manager instances.
// Given that the 'globalEventCallback' is shared, we pass again the
// rmCtx instance but in this case is needed to pick up the correct method
// that will handle the event.
// In other words, for every call libsds makes to globalEventCallback,
// the 'userData' parameter will bring the context of the rm that registered
// that globalEventCallback.
// This technique is needed because cgo only allows to export Go functions and not methods.
SetEventCallback(rmCtx, (SdsCallBack) globalEventCallback, rmCtx);
}
static void cGoCleanupReliabilityManager(void* rmCtx, void* resp) {
CleanupReliabilityManager(rmCtx, (SdsCallBack) GoCallback, resp);
}
static void cGoResetReliabilityManager(void* rmCtx, void* resp) {
ResetReliabilityManager(rmCtx, (SdsCallBack) GoCallback, resp);
}
static void cGoWrapOutgoingMessage(void* rmCtx,
void* message,
size_t messageLen,
const char* messageId,
void* resp) {
WrapOutgoingMessage(rmCtx,
message,
messageLen,
messageId,
(SdsCallBack) GoCallback,
resp);
}
static void cGoUnwrapReceivedMessage(void* rmCtx,
void* message,
size_t messageLen,
void* resp) {
UnwrapReceivedMessage(rmCtx,
message,
messageLen,
(SdsCallBack) GoCallback,
resp);
}
static void cGoMarkDependenciesMet(void* rmCtx,
char** messageIDs,
size_t count,
void* resp) {
MarkDependenciesMet(rmCtx,
messageIDs,
count,
(SdsCallBack) GoCallback,
resp);
}
static void cGoStartPeriodicTasks(void* rmCtx, void* resp) {
StartPeriodicTasks(rmCtx, (SdsCallBack) GoCallback, resp);
}
*/
import "C"
import (
"encoding/json"
"errors"
"strconv"
"strings"
"sync"
"time"
"unsafe"
)
const requestTimeout = 30 * time.Second
const EventChanBufferSize = 1024
//export GoCallback
func GoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) {
if resp != nil {
m := (*C.Resp)(resp)
m.ret = ret
m.msg = msg
m.len = len
wg := (*sync.WaitGroup)(m.ffiWg)
wg.Done()
}
}
type EventCallbacks struct {
OnMessageReady func(messageId MessageID)
OnMessageSent func(messageId MessageID)
OnMissingDependencies func(messageId MessageID, missingDeps []MessageID)
OnPeriodicSync func()
}
// ReliabilityManager represents an instance of a nim-sds ReliabilityManager
type ReliabilityManager struct {
rmCtx unsafe.Pointer
channelId string
callbacks EventCallbacks
}
func NewReliabilityManager(channelId string) (*ReliabilityManager, error) {
Debug("Creating new Reliability Manager")
rm := &ReliabilityManager{
channelId: channelId,
}
wg := sync.WaitGroup{}
var cChannelId = C.CString(string(channelId))
var resp = C.allocResp(unsafe.Pointer(&wg))
defer C.free(unsafe.Pointer(cChannelId))
defer C.freeResp(resp)
if C.getRet(resp) != C.RET_OK {
errMsg := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
Error("error NewReliabilityManager: %v", errMsg)
return nil, errors.New(errMsg)
}
wg.Add(1)
rm.rmCtx = C.cGoNewReliabilityManager(cChannelId, resp)
wg.Wait()
C.cGoSetEventCallback(rm.rmCtx)
registerReliabilityManager(rm)
Debug("Successfully created Reliability Manager")
return rm, nil
}
// The event callback sends back the rm ctx to know to which
// rm is the event being emited for. Since we only have a global
// callback in the go side, We register all the rm's that we create
// so we can later obtain which instance of `ReliabilityManager` it should
// be invoked depending on the ctx received
var rmRegistry map[unsafe.Pointer]*ReliabilityManager
func init() {
rmRegistry = make(map[unsafe.Pointer]*ReliabilityManager)
}
func registerReliabilityManager(rm *ReliabilityManager) {
_, ok := rmRegistry[rm.rmCtx]
if !ok {
rmRegistry[rm.rmCtx] = rm
}
}
func unregisterReliabilityManager(rm *ReliabilityManager) {
delete(rmRegistry, rm.rmCtx)
}
//export globalEventCallback
func globalEventCallback(callerRet C.int, msg *C.char, len C.size_t, userData unsafe.Pointer) {
if callerRet == C.RET_OK {
eventStr := C.GoStringN(msg, C.int(len))
rm, ok := rmRegistry[userData] // userData contains rm's ctx
if ok {
rm.OnEvent(eventStr)
}
} else {
if len != 0 {
errMsg := C.GoStringN(msg, C.int(len))
Error("globalEventCallback retCode not ok, retCode: %v: %v", callerRet, errMsg)
} else {
Error("globalEventCallback retCode not ok, retCode: %v", callerRet)
}
}
}
type jsonEvent struct {
EventType string `json:"eventType"`
}
type msgEvent struct {
MessageId MessageID `json:"messageId"`
}
type missingDepsEvent struct {
MessageId MessageID `json:"messageId"`
MissingDeps []MessageID `json:"missingDeps"`
}
func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) {
rm.callbacks = callbacks
}
func (rm *ReliabilityManager) OnEvent(eventStr string) {
jsonEvent := jsonEvent{}
err := json.Unmarshal([]byte(eventStr), &jsonEvent)
if err != nil {
Error("could not unmarshal sds event string: %v", err)
return
}
switch jsonEvent.EventType {
case "message_ready":
rm.parseMessageReadyEvent(eventStr)
case "message_sent":
rm.parseMessageSentEvent(eventStr)
case "missing_dependencies":
rm.parseMissingDepsEvent(eventStr)
case "periodic_sync":
if rm.callbacks.OnPeriodicSync != nil {
rm.callbacks.OnPeriodicSync()
}
}
}
func (rm *ReliabilityManager) parseMessageReadyEvent(eventStr string) {
msgEvent := msgEvent{}
err := json.Unmarshal([]byte(eventStr), &msgEvent)
if err != nil {
Error("could not parse message ready event %v", err)
}
if rm.callbacks.OnMessageReady != nil {
rm.callbacks.OnMessageReady(msgEvent.MessageId)
}
}
func (rm *ReliabilityManager) parseMessageSentEvent(eventStr string) {
msgEvent := msgEvent{}
err := json.Unmarshal([]byte(eventStr), &msgEvent)
if err != nil {
Error("could not parse message sent event %v", err)
}
if rm.callbacks.OnMessageSent != nil {
rm.callbacks.OnMessageSent(msgEvent.MessageId)
}
}
func (rm *ReliabilityManager) parseMissingDepsEvent(eventStr string) {
missingDepsEvent := missingDepsEvent{}
err := json.Unmarshal([]byte(eventStr), &missingDepsEvent)
if err != nil {
Error("could not parse missing dependencies event %v", err)
}
if rm.callbacks.OnMissingDependencies != nil {
rm.callbacks.OnMissingDependencies(missingDepsEvent.MessageId, missingDepsEvent.MissingDeps)
}
}
func (rm *ReliabilityManager) Cleanup() error {
if rm == nil {
err := errors.New("reliability manager is nil in Cleanup")
Error("Failed to cleanup %v", err)
return err
}
Debug("Cleaning up reliability manager")
wg := sync.WaitGroup{}
var resp = C.allocResp(unsafe.Pointer(&wg))
defer C.freeResp(resp)
wg.Add(1)
C.cGoCleanupReliabilityManager(rm.rmCtx, resp)
wg.Wait()
if C.getRet(resp) == C.RET_OK {
unregisterReliabilityManager(rm)
Debug("Successfully cleaned up reliability manager")
return nil
}
errMsg := "error CleanupReliabilityManager: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
Error("Failed to cleanup reliability manager: %v", errMsg)
return errors.New(errMsg)
}
func (rm *ReliabilityManager) Reset() error {
if rm == nil {
err := errors.New("reliability manager is nil in Reset")
Error("Failed to reset %v", err)
return err
}
Debug("Resetting reliability manager")
wg := sync.WaitGroup{}
var resp = C.allocResp(unsafe.Pointer(&wg))
defer C.freeResp(resp)
wg.Add(1)
C.cGoResetReliabilityManager(rm.rmCtx, resp)
wg.Wait()
if C.getRet(resp) == C.RET_OK {
Debug("Successfully resetted reliability manager")
return nil
}
errMsg := "error ResetReliabilityManager: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
Error("Failed to reset reliability manager: %v", errMsg)
return errors.New(errMsg)
}
func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId MessageID) ([]byte, error) {
if rm == nil {
err := errors.New("reliability manager is nil in WrapOutgoingMessage")
Error("Failed to wrap outgoing message %v", err)
return nil, err
}
Debug("Wrapping outgoing message %v", messageId)
wg := sync.WaitGroup{}
var resp = C.allocResp(unsafe.Pointer(&wg))
defer C.freeResp(resp)
cMessageId := C.CString(string(messageId))
defer C.free(unsafe.Pointer(cMessageId))
var cMessagePtr unsafe.Pointer
if len(message) > 0 {
cMessagePtr = C.CBytes(message) // C.CBytes allocates memory that needs to be freed
defer C.free(cMessagePtr)
} else {
cMessagePtr = nil
}
cMessageLen := C.size_t(len(message))
wg.Add(1)
C.cGoWrapOutgoingMessage(rm.rmCtx, cMessagePtr, cMessageLen, cMessageId, resp)
wg.Wait()
if C.getRet(resp) == C.RET_OK {
resStr := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
if resStr == "" {
Debug("Received empty res string for messageId: %v", messageId)
return nil, nil
}
Debug("Successfully wrapped message %s", messageId)
parts := strings.Split(resStr, ",")
bytes := make([]byte, len(parts))
for i, part := range parts {
n, err := strconv.Atoi(strings.TrimSpace(part))
if err != nil {
panic(err)
}
bytes[i] = byte(n)
}
return bytes, nil
}
errMsg := "error WrapOutgoingMessage: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
Error("Failed to wrap message %v: %v", messageId, errMsg)
return nil, errors.New(errMsg)
}
func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedMessage, error) {
if rm == nil {
err := errors.New("reliability manager is nil in UnwrapReceivedMessage")
Error("Failed to unwrap received message %v", err)
return nil, err
}
wg := sync.WaitGroup{}
var resp = C.allocResp(unsafe.Pointer(&wg))
defer C.freeResp(resp)
var cMessagePtr unsafe.Pointer
if len(message) > 0 {
cMessagePtr = C.CBytes(message) // C.CBytes allocates memory that needs to be freed
defer C.free(cMessagePtr)
} else {
cMessagePtr = nil
}
cMessageLen := C.size_t(len(message))
wg.Add(1)
C.cGoUnwrapReceivedMessage(rm.rmCtx, cMessagePtr, cMessageLen, resp)
wg.Wait()
if C.getRet(resp) == C.RET_OK {
resStr := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
if resStr == "" {
Debug("Received empty res string")
return nil, nil
}
Debug("Successfully unwrapped message")
unwrappedMessage := UnwrappedMessage{}
err := json.Unmarshal([]byte(resStr), &unwrappedMessage)
if err != nil {
Error("Failed to unmarshal unwrapped message")
return nil, err
}
return &unwrappedMessage, nil
}
errMsg := "error UnwrapReceivedMessage: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
Error("Failed to unwrap message: %v", errMsg)
return nil, errors.New(errMsg)
}
// MarkDependenciesMet informs the library that dependencies are met
func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID) error {
if rm == nil {
err := errors.New("reliability manager is nil in MarkDependenciesMet")
Error("Failed to mark dependencies met %v", err)
return err
}
if len(messageIDs) == 0 {
return nil // Nothing to do
}
wg := sync.WaitGroup{}
var resp = C.allocResp(unsafe.Pointer(&wg))
defer C.freeResp(resp)
// Convert Go string slice to C array of C strings (char**)
cMessageIDs := make([]*C.char, len(messageIDs))
for i, id := range messageIDs {
cMessageIDs[i] = C.CString(string(id))
defer C.free(unsafe.Pointer(cMessageIDs[i])) // Ensure each CString is freed
}
// Create a pointer (**C.char) to the first element of the slice
var cMessageIDsPtr **C.char
if len(cMessageIDs) > 0 {
cMessageIDsPtr = &cMessageIDs[0]
} else {
cMessageIDsPtr = nil // Handle empty slice case
}
wg.Add(1)
// Pass the pointer variable (cMessageIDsPtr) directly, which is of type **C.char
C.cGoMarkDependenciesMet(rm.rmCtx, cMessageIDsPtr, C.size_t(len(messageIDs)), resp)
wg.Wait()
if C.getRet(resp) == C.RET_OK {
Debug("Successfully marked dependencies as met")
return nil
}
errMsg := "error MarkDependenciesMet: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
Error("Failed to mark dependencies as met: %v", errMsg)
return errors.New(errMsg)
}
func (rm *ReliabilityManager) StartPeriodicTasks() error {
if rm == nil {
err := errors.New("reliability manager is nil in StartPeriodicTasks")
Error("Failed to start periodic tasks %v", err)
return err
}
Debug("Starting periodic tasks")
wg := sync.WaitGroup{}
var resp = C.allocResp(unsafe.Pointer(&wg))
defer C.freeResp(resp)
wg.Add(1)
C.cGoStartPeriodicTasks(rm.rmCtx, resp)
wg.Wait()
if C.getRet(resp) == C.RET_OK {
Debug("Successfully started periodic tasks")
return nil
}
errMsg := "error StartPeriodicTasks: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
Error("Failed to start periodic tasks: %v", errMsg)
return errors.New(errMsg)
}

539
sds/sds_test.go Normal file
View File

@ -0,0 +1,539 @@
package sds
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// Test basic creation, cleanup, and reset
func TestLifecycle(t *testing.T) {
channelID := "test-lifecycle"
rm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
require.NotNil(t, rm, "Expected ReliabilityManager to be not nil")
defer rm.Cleanup() // Ensure cleanup even on test failure
err = rm.Reset()
require.NoError(t, err)
}
// Test wrapping and unwrapping a simple message
func TestWrapUnwrap(t *testing.T) {
channelID := "test-wrap-unwrap"
rm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer rm.Cleanup()
originalPayload := []byte("hello reliability")
messageID := MessageID("msg-wrap-1")
wrappedMsg, err := rm.WrapOutgoingMessage(originalPayload, messageID)
require.NoError(t, err)
require.Greater(t, len(wrappedMsg), 0, "Expected non-empty wrapped message")
// Simulate receiving the wrapped message
unwrappedMessage, err := rm.UnwrapReceivedMessage(wrappedMsg)
require.NoError(t, err)
require.Equal(t, string(*unwrappedMessage.Message), string(originalPayload), "Expected unwrapped and original payloads to be equal")
require.Equal(t, len(*unwrappedMessage.MissingDeps), 0, "Expexted to be no missing dependencies")
}
// Test dependency handling
func TestDependencies(t *testing.T) {
channelID := "test-deps"
rm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer rm.Cleanup()
// 1. Send message 1 (will become a dependency)
payload1 := []byte("message one")
msgID1 := MessageID("msg-dep-1")
wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1)
require.NoError(t, err)
// Simulate receiving msg1 to add it to history (implicitly acknowledges it)
_, err = rm.UnwrapReceivedMessage(wrappedMsg1)
require.NoError(t, err)
// 2. Send message 2 (depends on message 1 implicitly via causal history)
payload2 := []byte("message two")
msgID2 := MessageID("msg-dep-2")
wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2)
require.NoError(t, err)
// 3. Create a new manager to simulate a different peer receiving msg2 without msg1
rm2, err := NewReliabilityManager(channelID) // Same channel ID
require.NoError(t, err)
defer rm2.Cleanup()
// 4. Unwrap message 2 on the second manager - should report msg1 as missing
unwrappedMessage2, err := rm2.UnwrapReceivedMessage(wrappedMsg2)
require.NoError(t, err)
require.Greater(t, len(*unwrappedMessage2.MissingDeps), 0, "Expected missing dependencies, got none")
foundDep1 := false
for _, dep := range *unwrappedMessage2.MissingDeps {
if dep == msgID1 {
foundDep1 = true
break
}
}
require.True(t, foundDep1, "Expected missing dependency %q, got %v", msgID1, *unwrappedMessage2.MissingDeps)
// 5. Mark the dependency as met
err = rm2.MarkDependenciesMet([]MessageID{msgID1})
require.NoError(t, err)
}
// Test OnMessageReady callback
func TestCallback_OnMessageReady(t *testing.T) {
channelID := "test-cb-ready"
// Create sender and receiver RMs
senderRm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer senderRm.Cleanup()
receiverRm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer receiverRm.Cleanup()
// Use a channel for signaling
readyChan := make(chan MessageID, 1)
callbacks := EventCallbacks{
OnMessageReady: func(messageId MessageID) {
// Non-blocking send to channel
select {
case readyChan <- messageId:
default:
// Avoid blocking if channel is full or test already timed out
}
},
}
// Register callback only on the receiver
receiverRm.RegisterCallbacks(callbacks)
// Scenario: Wrap message on sender, unwrap on receiver
payload := []byte("ready test")
msgID := MessageID("cb-ready-1")
// Wrap on sender
wrappedMsg, err := senderRm.WrapOutgoingMessage(payload, msgID)
require.NoError(t, err)
// Unwrap on receiver
_, err = receiverRm.UnwrapReceivedMessage(wrappedMsg)
require.NoError(t, err)
// Verification - Wait on channel with timeout
select {
case receivedMsgID := <-readyChan:
// Mark as called implicitly since we received on channel
if receivedMsgID != msgID {
t.Errorf("OnMessageReady called with wrong ID: got %q, want %q", receivedMsgID, msgID)
}
case <-time.After(2 * time.Second):
// If timeout occurs, the channel receive failed.
t.Errorf("Timed out waiting for OnMessageReady callback on readyChan")
}
}
// Test OnMessageSent callback (via causal history ACK)
func TestCallback_OnMessageSent(t *testing.T) {
channelID := "test-cb-sent"
// Create two RMs
rm1, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer rm1.Cleanup()
rm2, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer rm2.Cleanup()
var wg sync.WaitGroup
sentCalled := false
var sentMsgID MessageID
var cbMutex sync.Mutex
callbacks := EventCallbacks{
OnMessageSent: func(messageId MessageID) {
cbMutex.Lock()
sentCalled = true
sentMsgID = messageId
cbMutex.Unlock()
wg.Done()
},
}
// Register callback on rm1 (the original sender)
rm1.RegisterCallbacks(callbacks)
// Scenario: rm1 sends msg1, rm2 receives msg1,
// rm2 sends msg2 (acking msg1), rm1 receives msg2.
// 1. rm1 sends msg1
payload1 := []byte("sent test 1")
msgID1 := MessageID("cb-sent-1")
wrappedMsg1, err := rm1.WrapOutgoingMessage(payload1, msgID1)
require.NoError(t, err)
// Note: msg1 is now in rm1's outgoing buffer
// 2. rm2 receives msg1 (to update its state)
_, err = rm2.UnwrapReceivedMessage(wrappedMsg1)
require.NoError(t, err)
// 3. rm2 sends msg2 (will include msg1 in causal history)
payload2 := []byte("sent test 2")
msgID2 := MessageID("cb-sent-2")
wrappedMsg2, err := rm2.WrapOutgoingMessage(payload2, msgID2)
require.NoError(t, err)
// 4. rm1 receives msg2 (should trigger ACK for msg1)
wg.Add(1) // Expect OnMessageSent for msg1 on rm1
_, err = rm1.UnwrapReceivedMessage(wrappedMsg2)
require.NoError(t, err)
// Verification
waitTimeout(&wg, 2*time.Second, t)
cbMutex.Lock()
defer cbMutex.Unlock()
if !sentCalled {
t.Errorf("OnMessageSent was not called")
}
// We primarily care that msg1 was ACKed.
if sentMsgID != msgID1 {
t.Errorf("OnMessageSent called with wrong ID: got %q, want %q", sentMsgID, msgID1)
}
}
// Test OnMissingDependencies callback
func TestCallback_OnMissingDependencies(t *testing.T) {
channelID := "test-cb-missing"
// Use separate sender/receiver RMs explicitly
senderRm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer senderRm.Cleanup()
receiverRm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer receiverRm.Cleanup()
var wg sync.WaitGroup
missingCalled := false
var missingMsgID MessageID
var missingDepsList []MessageID
var cbMutex sync.Mutex
callbacks := EventCallbacks{
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
cbMutex.Lock()
missingCalled = true
missingMsgID = messageId
missingDepsList = missingDeps // Copy slice
cbMutex.Unlock()
wg.Done()
},
}
// Register callback only on the receiver rm
receiverRm.RegisterCallbacks(callbacks)
// Scenario: Sender sends msg1, then sender sends msg2 (depends on msg1),
// then receiver receives msg2 (which hasn't seen msg1).
// 1. Sender sends msg1
payload1 := []byte("missing test 1")
msgID1 := MessageID("cb-miss-1")
_, err = senderRm.WrapOutgoingMessage(payload1, msgID1)
require.NoError(t, err)
// 2. Sender sends msg2 (depends on msg1)
payload2 := []byte("missing test 2")
msgID2 := MessageID("cb-miss-2")
wrappedMsg2, err := senderRm.WrapOutgoingMessage(payload2, msgID2)
require.NoError(t, err)
// 3. Receiver receives msg2 (haven't seen msg1)
wg.Add(1) // Expect OnMissingDependencies
_, err = receiverRm.UnwrapReceivedMessage(wrappedMsg2)
require.NoError(t, err)
// Verification
waitTimeout(&wg, 2*time.Second, t)
cbMutex.Lock()
defer cbMutex.Unlock()
if !missingCalled {
t.Errorf("OnMissingDependencies was not called")
}
if missingMsgID != msgID2 {
t.Errorf("OnMissingDependencies called for wrong ID: got %q, want %q", missingMsgID, msgID2)
}
foundDep := false
for _, dep := range missingDepsList {
if dep == msgID1 {
foundDep = true
break
}
}
if !foundDep {
t.Errorf("OnMissingDependencies did not report %q as missing, got: %v", msgID1, missingDepsList)
}
}
// Test OnPeriodicSync callback
func TestCallback_OnPeriodicSync(t *testing.T) {
channelID := "test-cb-sync"
rm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer rm.Cleanup()
syncCalled := false
var cbMutex sync.Mutex
// Use a channel to signal when the callback is hit
syncChan := make(chan bool, 1)
callbacks := EventCallbacks{
OnPeriodicSync: func() {
cbMutex.Lock()
if !syncCalled { // Only signal the first time
syncCalled = true
syncChan <- true
}
cbMutex.Unlock()
},
}
rm.RegisterCallbacks(callbacks)
// Start periodic tasks
err = rm.StartPeriodicTasks()
require.NoError(t, err)
// --- Verification ---
// Wait for the periodic sync callback with a timeout (needs to be longer than sync interval)
select {
case <-syncChan:
// Success
case <-time.After(10 * time.Second):
t.Errorf("Timed out waiting for OnPeriodicSync callback")
}
cbMutex.Lock()
defer cbMutex.Unlock()
if !syncCalled {
// This might happen if the timeout was too short
t.Logf("Warning: OnPeriodicSync might not have been called within the test timeout")
}
}
// Combined Test for multiple callbacks
func TestCallbacks_Combined(t *testing.T) {
channelID := "test-cb-combined"
// Create sender and receiver RMs
senderRm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer senderRm.Cleanup()
receiverRm, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer receiverRm.Cleanup()
// Channels for synchronization
readyChan1 := make(chan bool, 1)
sentChan1 := make(chan bool, 1)
missingChan := make(chan []MessageID, 1)
// Use maps for verification
receivedReady := make(map[MessageID]bool)
receivedSent := make(map[MessageID]bool)
var cbMutex sync.Mutex
callbacksReceiver := EventCallbacks{
OnMessageReady: func(messageId MessageID) {
cbMutex.Lock()
receivedReady[messageId] = true
cbMutex.Unlock()
if messageId == "cb-comb-1" {
// Use non-blocking send
select {
case readyChan1 <- true:
default:
}
}
},
OnMessageSent: func(messageId MessageID) {
// This callback is registered on Receiver, but Sent events
// are typically relevant to the Sender. We don't expect this.
t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId)
},
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
// This callback is registered on Receiver, used for receiverRm2 below
},
}
callbacksSender := EventCallbacks{
OnMessageReady: func(messageId MessageID) {
// Not expected on sender in this test flow
},
OnMessageSent: func(messageId MessageID) {
cbMutex.Lock()
receivedSent[messageId] = true
cbMutex.Unlock()
if messageId == "cb-comb-1" {
select {
case sentChan1 <- true:
default:
}
}
},
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
// Not expected on sender
},
}
// Register callbacks
receiverRm.RegisterCallbacks(callbacksReceiver)
senderRm.RegisterCallbacks(callbacksSender)
// --- Test Scenario ---
msgID1 := MessageID("cb-comb-1")
msgID2 := MessageID("cb-comb-2")
msgID3 := MessageID("cb-comb-3")
payload1 := []byte("combined test 1")
payload2 := []byte("combined test 2")
payload3 := []byte("combined test 3")
// 1. Sender sends msg1
wrappedMsg1, err := senderRm.WrapOutgoingMessage(payload1, msgID1)
require.NoError(t, err)
// 2. Receiver receives msg1
_, err = receiverRm.UnwrapReceivedMessage(wrappedMsg1)
require.NoError(t, err)
// 3. Receiver sends msg2 (depends on msg1 implicitly via state)
wrappedMsg2, err := receiverRm.WrapOutgoingMessage(payload2, msgID2)
require.NoError(t, err)
// 4. Sender receives msg2 from Receiver (acks msg1 for sender)
_, err = senderRm.UnwrapReceivedMessage(wrappedMsg2)
require.NoError(t, err)
// 5. Sender sends msg3 (depends on msg2)
wrappedMsg3, err := senderRm.WrapOutgoingMessage(payload3, msgID3)
require.NoError(t, err)
// 6. Create Receiver2, register missing deps callback
receiverRm2, err := NewReliabilityManager(channelID)
require.NoError(t, err)
defer receiverRm2.Cleanup()
callbacksReceiver2 := EventCallbacks{
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID) {
if messageId == msgID3 {
select {
case missingChan <- missingDeps:
default:
}
}
},
}
receiverRm2.RegisterCallbacks(callbacksReceiver2)
// 7. Receiver2 receives msg3 (should report missing msg1, msg2)
_, err = receiverRm2.UnwrapReceivedMessage(wrappedMsg3)
require.NoError(t, err)
// --- Verification ---
timeout := 5 * time.Second
expectedReady1 := false
expectedSent1 := false
var reportedMissingDeps []MessageID
missingDepsReceived := false
receivedCount := 0
expectedCount := 3 // ready1, sent1, missingDeps
timer := time.NewTimer(timeout)
defer timer.Stop()
for receivedCount < expectedCount {
select {
case <-readyChan1:
if !expectedReady1 { // Avoid double counting if signaled twice
expectedReady1 = true
receivedCount++
}
case <-sentChan1:
if !expectedSent1 {
expectedSent1 = true
receivedCount++
}
case deps := <-missingChan:
if !missingDepsReceived {
reportedMissingDeps = deps
missingDepsReceived = true
receivedCount++
}
case <-timer.C:
t.Fatalf("Timed out waiting for combined callbacks (received %d out of %d)", receivedCount, expectedCount)
}
}
// Check results
cbMutex.Lock()
defer cbMutex.Unlock()
if !expectedReady1 || !receivedReady[msgID1] {
t.Errorf("OnMessageReady not called/verified for %s", msgID1)
}
if !expectedSent1 || !receivedSent[msgID1] {
t.Errorf("OnMessageSent not called/verified for %s", msgID1)
}
if !missingDepsReceived {
t.Errorf("OnMissingDependencies not called/verified for %s", msgID3)
} else {
foundDep1 := false
foundDep2 := false
for _, dep := range reportedMissingDeps {
if dep == msgID1 {
foundDep1 = true
}
if dep == msgID2 {
foundDep2 = true
}
}
if !foundDep1 || !foundDep2 {
t.Errorf("OnMissingDependencies for %s reported wrong deps: got %v, want %s and %s", msgID3, reportedMissingDeps, msgID1, msgID2)
}
}
}
// Helper function to wait for WaitGroup with a timeout
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration, t *testing.T) {
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
// Completed normally
case <-time.After(timeout):
t.Fatalf("Timed out waiting for callbacks")
}
}

8
sds/types.go Normal file
View File

@ -0,0 +1,8 @@
package sds
type MessageID string
type UnwrappedMessage struct {
Message *[]byte `json:"message"`
MissingDeps *[]MessageID `json:"missingDeps"`
}