From 68e691583e83e98f0e23d6b5e4df3354966aa33c Mon Sep 17 00:00:00 2001 From: Tomasz Bekas Date: Wed, 28 Aug 2024 18:17:03 +0200 Subject: [PATCH] Add prepare phase to encode/decode --- leopard/leopard.nim | 237 +++++++++++++++++++++++++++++------------- tests/helpers.nim | 2 + tests/testleopard.nim | 27 ++--- 3 files changed, 182 insertions(+), 84 deletions(-) diff --git a/leopard/leopard.nim b/leopard/leopard.nim index 615f7ad..a542bad 100644 --- a/leopard/leopard.nim +++ b/leopard/leopard.nim @@ -36,6 +36,9 @@ type dataBufferPtr: seq[LeoBufferPtr] # buffer where data is copied before encoding workBufferCount: int # number of parity work buffers workBufferPtr: seq[LeoBufferPtr] # buffer where parity data is written during encoding or before decoding + + dataBufferNil: seq[bool] # true represents Nil in dataBufferPtr + workBufferNil: seq[bool] # true represents nil in workBufferPtr case kind: LeoCoderKind of LeoCoderKind.Decoder: decodeBufferCount: int # number of decoding work buffers @@ -46,31 +49,33 @@ type LeoEncoder* = object of Leo LeoDecoder* = object of Leo -func encode*( - self: var LeoEncoder, - data, - parity: var openArray[seq[byte]]): Result[void, cstring] = - ## Encode a list of buffers in `data` into a number of `bufSize` sized - ## `parity` buffers - ## - ## `data` - list of original data `buffers` of size `bufSize` - ## `parity` - list of parity `buffers` of size `bufSize` - ## +func prepareEncode*( + self: var LeoEncoder, + data: var openArray[seq[byte]] + ): Result[void, cstring] = + ## Copy `data` into internal encode buffer + ## + if data.len != self.buffers: return err("Number of data buffers should match!") - if parity.len != self.parity: - return err("Number of parity buffers should match!") + # copy data into aligned buffer + for i in 0.. 0: + copyMem(self.dataBufferPtr[i], addr data[i][0], self.bufSize) + self.dataBufferNil[i] = false + else: + self.dataBufferNil[i] = true + + # copy parity into aligned buffer + for i in 0.. 0: + copyMem(self.workBufferPtr[i], addr parity[i][0], self.bufSize) + self.workBufferNil[i] = false + else: + self.workBufferNil[i] = true + + ok() + +func decodePrepared*( + self: var LeoDecoder + ): Result[void, cstring] = + + for i in 0.. 0: - copyMem(self.dataBufferPtr[i], addr data[i][0], self.bufSize) - dataPtr[i] = self.dataBufferPtr[i] - else: - dataPtr[i] = nil - - # copy parity into aligned buffer - for i in 0.. 0: - copyMem(self.workBufferPtr[i], addr parity[i][0], self.bufSize) - parityPtr[i] = self.workBufferPtr[i] - else: - parityPtr[i] = nil - - let - res = leoDecode( - self.bufSize.culonglong, - self.buffers.cuint, - self.parity.cuint, - self.decodeBufferCount.cuint, - cast[LeoDataPtr](addr dataPtr[0]), - cast[LeoDataPtr](addr parityPtr[0]), - cast[ptr pointer](addr self.decodeBufferPtr[0])) - - if ord(res) != ord(LeopardSuccess): - return err(leoResultString(res.LeopardResult)) - - for i, p in dataPtr: - if p.isNil: - copyMem(addr recovered[i][0], self.decodeBufferPtr[i], self.bufSize) - - ok() + self.readDecoded(recovered) func free*(self: var Leo) = + if self.dataBufferNil.len > 0: + self.dataBufferNil.setLen(0) + + if self.workBufferNil.len > 0: + self.workBufferNil.setLen(0) + if self.workBufferPtr.len > 0: for i, p in self.workBufferPtr: if not isNil(p): @@ -232,6 +324,9 @@ proc init[TT: Leo]( buffers.cuint, parity.cuint).int + self.workBufferNil.setLen(self.workBufferCount) + self.dataBufferNil.setLen(self.buffers) + # initialize encode work buffers for _ in 0.. 0: dropRandomIdx(parityBuf, parityLosses) + GC_fullCollect() + decoder.decode(dataBuf, parityBuf, recoveredBuf).tryGet() for i, d in dataBuf: diff --git a/tests/testleopard.nim b/tests/testleopard.nim index 8c5cb00..b1a782a 100644 --- a/tests/testleopard.nim +++ b/tests/testleopard.nim @@ -1,5 +1,6 @@ import std/random import std/sets +import std/sequtils import pkg/unittest2 import pkg/stew/results @@ -31,8 +32,8 @@ suite "Leopard Parametrization": test "Should not allow encoding with invalid data buffer counts": var leo = LeoEncoder.init(64, 4, 2).tryGet() - data = newSeq[seq[byte]](3) - parity = newSeq[seq[byte]](2) + data = newSeqWith[seq[byte]](3, newSeq[byte](64)) + parity = newSeqWith[seq[byte]](2, newSeq[byte](64)) check: leo.encode(data, parity).error == "Number of data buffers should match!" @@ -40,8 +41,8 @@ suite "Leopard Parametrization": test "Should not allow encoding with invalid parity buffer counts": var leo = LeoEncoder.init(64, 4, 2).tryGet() - data = newSeq[seq[byte]](4) - parity = newSeq[seq[byte]](3) + data = newSeqWith[seq[byte]](4, newSeq[byte](64)) + parity = newSeqWith[seq[byte]](3, newSeq[byte](64)) check: leo.encode(data, parity).error == "Number of parity buffers should match!" @@ -49,9 +50,9 @@ suite "Leopard Parametrization": test "Should not allow decoding with invalid data buffer counts": var leo = LeoDecoder.init(64, 4, 2).tryGet() - data = newSeq[seq[byte]](3) - parity = newSeq[seq[byte]](2) - recovered = newSeq[seq[byte]](3) + data = newSeqWith[seq[byte]](3, newSeq[byte](64)) + parity = newSeqWith[seq[byte]](2, newSeq[byte](64)) + recovered = newSeqWith[seq[byte]](3, newSeq[byte](64)) check: leo.decode(data, parity, recovered).error == "Number of data buffers should match!" @@ -59,9 +60,9 @@ suite "Leopard Parametrization": test "Should not allow decoding with invalid data buffer counts": var leo = LeoDecoder.init(64, 4, 2).tryGet() - data = newSeq[seq[byte]](4) - parity = newSeq[seq[byte]](1) - recovered = newSeq[seq[byte]](3) + data = newSeqWith[seq[byte]](4, newSeq[byte](64)) + parity = newSeqWith[seq[byte]](1, newSeq[byte](64)) + recovered = newSeqWith[seq[byte]](3, newSeq[byte](64)) check: leo.decode(data, parity, recovered).error == "Number of parity buffers should match!" @@ -69,9 +70,9 @@ suite "Leopard Parametrization": test "Should not allow decoding with invalid data buffer counts": var leo = LeoDecoder.init(64, 4, 2).tryGet() - data = newSeq[seq[byte]](4) - parity = newSeq[seq[byte]](2) - recovered = newSeq[seq[byte]](3) + data = newSeqWith[seq[byte]](4, newSeq[byte](64)) + parity = newSeqWith[seq[byte]](2, newSeq[byte](64)) + recovered = newSeqWith[seq[byte]](3, newSeq[byte](64)) check: leo.decode(data, parity, recovered).error == "Number of recovered buffers should match buffers!"