mirror of
https://github.com/logos-messaging/sds-go-bindings.git
synced 2026-01-02 06:03:12 +00:00
425 lines
10 KiB
Go
425 lines
10 KiB
Go
//go:build !lint
|
|
|
|
package sds
|
|
|
|
/*
#include <libsds.h>
#include <stdio.h>
#include <stdlib.h>

// Go-implemented event handler shared by every reliability manager. It is
// registered with libsds through SdsSetEventCallback (see
// cGoSdsSetEventCallback below).
extern void sdsGlobalEventCallback(int ret, char* msg, size_t len, void* userData);

// SdsResp collects the outcome of one libsds request. SdsGoCallback stores the
// return code, the message pointer and the message length here, then calls
// Done on the Go sync.WaitGroup referenced by ffiWg.
// NOTE(review): msg is the pointer libsds handed to the callback, not a copy.
// The Go side reads it after the callback has returned, which assumes libsds
// keeps that buffer alive. Confirm against libsds.
// NOTE(review): ffiWg stores a Go pointer in C memory. The cgo pointer-passing
// rules do not allow this; runtime/cgo.Handle is the compliant alternative.
typedef struct {
	int ret;
	char* msg;
	size_t len;
	void* ffiWg;
} SdsResp;

// allocResp allocates a zero-initialised SdsResp whose ffiWg is set to wg.
// The caller releases it with freeResp.
// NOTE(review): the calloc result is not checked for NULL.
static void* allocResp(void* wg) {
	SdsResp* r = calloc(1, sizeof(SdsResp));
	r->ffiWg = wg;
	return r;
}

// freeResp frees an SdsResp obtained from allocResp. NULL is a no-op. Only the
// struct itself is freed, not the buffer its msg field points to.
static void freeResp(void* resp) {
	if (resp != NULL) {
		free(resp);
	}
}

// getMyCharPtr returns the msg field of resp, or NULL when resp is NULL.
static char* getMyCharPtr(void* resp) {
	if (resp == NULL) {
		return NULL;
	}
	SdsResp* m = (SdsResp*) resp;
	return m->msg;
}

// getMyCharLen returns the len field of resp, or 0 when resp is NULL.
static size_t getMyCharLen(void* resp) {
	if (resp == NULL) {
		return 0;
	}
	SdsResp* m = (SdsResp*) resp;
	return m->len;
}

// getRet returns the ret field of resp, or 0 when resp is NULL.
// NOTE(review): if RET_OK is 0, a NULL resp is indistinguishable from success.
static int getRet(void* resp) {
	if (resp == NULL) {
		return 0;
	}
	SdsResp* m = (SdsResp*) resp;
	return m->ret;
}

// SdsGoCallback is implemented in Go and exported to C by cgo. Pass a
// non-NULL resp (obtained from allocResp) when the result of the request is
// needed. The Go implementation ignores a NULL resp.
void SdsGoCallback(int ret, char* msg, size_t len, void* resp);

// Creates a reliability manager and returns the context pointer produced by
// SdsNewReliabilityManager. resp is forwarded to libsds together with
// SdsGoCallback so the Go caller can wait for, and read, the request result.
static void* cGoSdsNewReliabilityManager(void* resp) {
	void* ret = SdsNewReliabilityManager((SdsCallBack) SdsGoCallback, resp);
	return ret;
}

// Registers the shared event handler for the manager identified by rmCtx.
static void cGoSdsSetEventCallback(void* rmCtx) {
	// The 'sdsGlobalEventCallback' Go function is shared by all Reliability
	// Manager instances. Because it is shared, rmCtx is also passed as the
	// user data, so the handler can pick the manager that should handle the
	// event.
	//
	// In other words, for every call libsds makes to sdsGlobalEventCallback,
	// the 'userData' parameter carries the context of the rm that registered
	// that sdsGlobalEventCallback.
	//
	// This technique is needed because cgo can only export Go functions, not
	// methods.
	SdsSetEventCallback(rmCtx, (SdsCallBack) sdsGlobalEventCallback, rmCtx);
}

// Wrapper for SdsCleanupReliabilityManager. The result is delivered to
// SdsGoCallback with resp.
static void cGoSdsCleanupReliabilityManager(void* rmCtx, void* resp) {
	SdsCleanupReliabilityManager(rmCtx, (SdsCallBack) SdsGoCallback, resp);
}

// Wrapper for SdsResetReliabilityManager. The result is delivered to
// SdsGoCallback with resp.
static void cGoSdsResetReliabilityManager(void* rmCtx, void* resp) {
	SdsResetReliabilityManager(rmCtx, (SdsCallBack) SdsGoCallback, resp);
}

// Wrapper for SdsWrapOutgoingMessage. The result is delivered to
// SdsGoCallback with resp.
static void cGoSdsWrapOutgoingMessage(void* rmCtx,
		void* message,
		size_t messageLen,
		const char* messageId,
		const char* channelId,
		void* resp) {
	SdsWrapOutgoingMessage(rmCtx,
		message,
		messageLen,
		messageId,
		channelId,
		(SdsCallBack) SdsGoCallback,
		resp);
}

// Wrapper for SdsUnwrapReceivedMessage. The result is delivered to
// SdsGoCallback with resp.
static void cGoSdsUnwrapReceivedMessage(void* rmCtx,
		void* message,
		size_t messageLen,
		void* resp) {
	SdsUnwrapReceivedMessage(rmCtx,
		message,
		messageLen,
		(SdsCallBack) SdsGoCallback,
		resp);
}

// Wrapper for SdsMarkDependenciesMet. messageIDs is an array of count C
// strings. The result is delivered to SdsGoCallback with resp.
static void cGoSdsMarkDependenciesMet(void* rmCtx,
		char** messageIDs,
		size_t count,
		const char* channelId,
		void* resp) {
	SdsMarkDependenciesMet(rmCtx,
		messageIDs,
		count,
		channelId,
		(SdsCallBack) SdsGoCallback,
		resp);
}

// Wrapper for SdsStartPeriodicTasks. The result is delivered to
// SdsGoCallback with resp.
static void cGoSdsStartPeriodicTasks(void* rmCtx, void* resp) {
	SdsStartPeriodicTasks(rmCtx, (SdsCallBack) SdsGoCallback, resp);
}
*/
|
|
import "C"
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"unsafe"
|
|
|
|
errorspkg "github.com/pkg/errors"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// errEmptyReliabilityManager is returned by every ReliabilityManager method
// when it is invoked on a nil receiver.
var errEmptyReliabilityManager = errors.New("empty reliability manager")
|
|
|
|
//export SdsGoCallback
|
|
func SdsGoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) {
|
|
if resp != nil {
|
|
m := (*C.SdsResp)(resp)
|
|
m.ret = ret
|
|
m.msg = msg
|
|
m.len = len
|
|
wg := (*sync.WaitGroup)(m.ffiWg)
|
|
wg.Done()
|
|
}
|
|
}
|
|
|
|
func NewReliabilityManager(logger *zap.Logger) (*ReliabilityManager, error) {
|
|
if logger == nil {
|
|
logger = zap.NewNop()
|
|
}
|
|
|
|
rm := &ReliabilityManager{
|
|
logger: logger,
|
|
}
|
|
|
|
rm.logger.Info("creating new reliability manager")
|
|
|
|
wg := sync.WaitGroup{}
|
|
|
|
var resp = C.allocResp(unsafe.Pointer(&wg))
|
|
defer C.freeResp(resp)
|
|
|
|
if C.getRet(resp) != C.RET_OK {
|
|
errMsg := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
|
return nil, errors.New(errMsg)
|
|
}
|
|
|
|
wg.Add(1)
|
|
rm.rmCtx = C.cGoSdsNewReliabilityManager(resp)
|
|
wg.Wait()
|
|
|
|
C.cGoSdsSetEventCallback(rm.rmCtx)
|
|
registerReliabilityManager(rm)
|
|
|
|
rm.logger.Debug("successfully created reliability manager")
|
|
return rm, nil
|
|
}
|
|
|
|
//export sdsGlobalEventCallback
|
|
func sdsGlobalEventCallback(callerRet C.int, msg *C.char, len C.size_t, userData unsafe.Pointer) {
|
|
msgStr := C.GoStringN(msg, C.int(len))
|
|
rm, ok := rmRegistry[userData] // userData contains rm's ctx
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if callerRet == C.RET_OK {
|
|
rm.OnEvent(msgStr)
|
|
} else {
|
|
rm.OnCallbackError(int(callerRet), msgStr)
|
|
}
|
|
}
|
|
|
|
func (rm *ReliabilityManager) Cleanup() error {
|
|
if rm == nil {
|
|
return errEmptyReliabilityManager
|
|
}
|
|
|
|
rm.logger.Debug("cleaning up reliability manager")
|
|
|
|
wg := sync.WaitGroup{}
|
|
var resp = C.allocResp(unsafe.Pointer(&wg))
|
|
defer C.freeResp(resp)
|
|
|
|
wg.Add(1)
|
|
C.cGoSdsCleanupReliabilityManager(rm.rmCtx, resp)
|
|
wg.Wait()
|
|
|
|
if C.getRet(resp) == C.RET_OK {
|
|
unregisterReliabilityManager(rm)
|
|
rm.logger.Debug("cleaned up reliability manager")
|
|
return nil
|
|
}
|
|
|
|
errMsg := "error CleanupReliabilityManager: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
|
return errors.New(errMsg)
|
|
}
|
|
|
|
func (rm *ReliabilityManager) Reset() error {
|
|
if rm == nil {
|
|
return errEmptyReliabilityManager
|
|
}
|
|
|
|
rm.logger.Debug("resetting reliability manager")
|
|
|
|
wg := sync.WaitGroup{}
|
|
var resp = C.allocResp(unsafe.Pointer(&wg))
|
|
defer C.freeResp(resp)
|
|
|
|
wg.Add(1)
|
|
C.cGoSdsResetReliabilityManager(rm.rmCtx, resp)
|
|
wg.Wait()
|
|
|
|
if C.getRet(resp) == C.RET_OK {
|
|
rm.logger.Debug("successfully resetted reliability manager")
|
|
return nil
|
|
}
|
|
|
|
errMsg := "error ResetReliabilityManager: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
|
return errors.New(errMsg)
|
|
}
|
|
|
|
func (rm *ReliabilityManager) WrapOutgoingMessage(message []byte, messageId MessageID, channelId string) ([]byte, error) {
|
|
if rm == nil {
|
|
return nil, errEmptyReliabilityManager
|
|
}
|
|
|
|
logger := rm.logger.With(zap.String("messageId", string(messageId)))
|
|
logger.Debug("wrapping outgoing message", zap.String("messageId", string(messageId)))
|
|
|
|
wg := sync.WaitGroup{}
|
|
var resp = C.allocResp(unsafe.Pointer(&wg))
|
|
defer C.freeResp(resp)
|
|
|
|
cMessageId := C.CString(string(messageId))
|
|
defer C.free(unsafe.Pointer(cMessageId))
|
|
|
|
var cMessagePtr unsafe.Pointer
|
|
if len(message) > 0 {
|
|
cMessagePtr = C.CBytes(message) // C.CBytes allocates memory that needs to be freed
|
|
defer C.free(cMessagePtr)
|
|
} else {
|
|
cMessagePtr = nil
|
|
}
|
|
cMessageLen := C.size_t(len(message))
|
|
|
|
cChannelId := C.CString(channelId)
|
|
defer C.free(unsafe.Pointer(cChannelId))
|
|
|
|
wg.Add(1)
|
|
C.cGoSdsWrapOutgoingMessage(rm.rmCtx, cMessagePtr, cMessageLen, cMessageId, cChannelId, resp)
|
|
wg.Wait()
|
|
|
|
if C.getRet(resp) == C.RET_OK {
|
|
resStr := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
|
if resStr == "" {
|
|
logger.Debug("received empty res string for messageId")
|
|
return nil, nil
|
|
}
|
|
logger.Debug("successfully wrapped message")
|
|
|
|
parts := strings.Split(resStr, ",")
|
|
bytes := make([]byte, len(parts))
|
|
|
|
for i, part := range parts {
|
|
n, err := strconv.Atoi(strings.TrimSpace(part))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
bytes[i] = byte(n)
|
|
}
|
|
|
|
return bytes, nil
|
|
}
|
|
|
|
errMsg := "error WrapOutgoingMessage: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
|
return nil, errors.New(errMsg)
|
|
}
|
|
|
|
func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedMessage, error) {
|
|
if rm == nil {
|
|
return nil, errEmptyReliabilityManager
|
|
}
|
|
|
|
wg := sync.WaitGroup{}
|
|
var resp = C.allocResp(unsafe.Pointer(&wg))
|
|
defer C.freeResp(resp)
|
|
|
|
var cMessagePtr unsafe.Pointer
|
|
if len(message) > 0 {
|
|
cMessagePtr = C.CBytes(message) // C.CBytes allocates memory that needs to be freed
|
|
defer C.free(cMessagePtr)
|
|
} else {
|
|
cMessagePtr = nil
|
|
}
|
|
cMessageLen := C.size_t(len(message))
|
|
|
|
wg.Add(1)
|
|
C.cGoSdsUnwrapReceivedMessage(rm.rmCtx, cMessagePtr, cMessageLen, resp)
|
|
wg.Wait()
|
|
|
|
if C.getRet(resp) == C.RET_OK {
|
|
resStr := C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
|
if resStr == "" {
|
|
rm.logger.Debug("received empty res string")
|
|
return nil, nil
|
|
}
|
|
rm.logger.Debug("successfully unwrapped message")
|
|
|
|
unwrappedMessage := UnwrappedMessage{}
|
|
err := json.Unmarshal([]byte(resStr), &unwrappedMessage)
|
|
if err != nil {
|
|
return nil, errorspkg.Wrap(err, "failed to unmarshal unwrapped message")
|
|
}
|
|
|
|
return &unwrappedMessage, nil
|
|
}
|
|
|
|
errMsg := "error UnwrapReceivedMessage: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
|
return nil, errors.New(errMsg)
|
|
}
|
|
|
|
// MarkDependenciesMet informs the library that dependencies are met
|
|
func (rm *ReliabilityManager) MarkDependenciesMet(messageIDs []MessageID, channelId string) error {
|
|
if rm == nil {
|
|
return errEmptyReliabilityManager
|
|
}
|
|
|
|
if len(messageIDs) == 0 {
|
|
return nil // Nothing to do
|
|
}
|
|
|
|
wg := sync.WaitGroup{}
|
|
var resp = C.allocResp(unsafe.Pointer(&wg))
|
|
defer C.freeResp(resp)
|
|
|
|
// Convert Go string slice to C array of C strings (char**)
|
|
cMessageIDs := make([]*C.char, len(messageIDs))
|
|
for i, id := range messageIDs {
|
|
cMessageIDs[i] = C.CString(string(id))
|
|
defer C.free(unsafe.Pointer(cMessageIDs[i])) // Ensure each CString is freed
|
|
}
|
|
|
|
// Create a pointer (**C.char) to the first element of the slice
|
|
var cMessageIDsPtr **C.char
|
|
if len(cMessageIDs) > 0 {
|
|
cMessageIDsPtr = &cMessageIDs[0]
|
|
} else {
|
|
cMessageIDsPtr = nil // Handle empty slice case
|
|
}
|
|
|
|
wg.Add(1)
|
|
cChannelId := C.CString(channelId)
|
|
defer C.free(unsafe.Pointer(cChannelId))
|
|
|
|
// Pass the pointer variable (cMessageIDsPtr) directly, which is of type **C.char
|
|
C.cGoSdsMarkDependenciesMet(rm.rmCtx, cMessageIDsPtr, C.size_t(len(messageIDs)), cChannelId, resp)
|
|
wg.Wait()
|
|
|
|
if C.getRet(resp) == C.RET_OK {
|
|
rm.logger.Debug("successfully marked dependencies as met")
|
|
return nil
|
|
}
|
|
|
|
errMsg := "error MarkDependenciesMet: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
|
return errors.New(errMsg)
|
|
}
|
|
|
|
func (rm *ReliabilityManager) StartPeriodicTasks() error {
|
|
if rm == nil {
|
|
return errEmptyReliabilityManager
|
|
}
|
|
|
|
rm.logger.Debug("starting periodic tasks")
|
|
|
|
wg := sync.WaitGroup{}
|
|
var resp = C.allocResp(unsafe.Pointer(&wg))
|
|
defer C.freeResp(resp)
|
|
|
|
wg.Add(1)
|
|
C.cGoSdsStartPeriodicTasks(rm.rmCtx, resp)
|
|
wg.Wait()
|
|
|
|
if C.getRet(resp) == C.RET_OK {
|
|
rm.logger.Debug("successfully started periodic tasks")
|
|
return nil
|
|
}
|
|
|
|
errMsg := "error StartPeriodicTasks: " + C.GoStringN(C.getMyCharPtr(resp), C.int(C.getMyCharLen(resp)))
|
|
return errors.New(errMsg)
|
|
}
|