diff --git a/leopard.nim b/leopard.nim index 77b78bc..84dad5c 100644 --- a/leopard.nim +++ b/leopard.nim @@ -16,6 +16,8 @@ const LeopardNotEnoughDataMsg = "Buffer counts are too low" MinBufferSize* = 64.uint + MinSymbols* = 1.uint + MaxTotalSymbols* = 65536.uint type Data* = seq[seq[byte]] @@ -61,11 +63,17 @@ type # data symbols = 239 # parity symbols = 255 - 239 = 16 +func isValid*(code: ReedSolomonCode): bool = + not ((code.codeword - code.data != code.parity) or + (code.parity > code.data) or (code.codeword < MinSymbols + 1) or + (code.data < MinSymbols) or (code.parity < MinSymbols) or + (code.codeword > MaxTotalSymbols)) + proc RS*(codeword, data: Positive): ReedSolomonCode = var parity = codeword - data - if parity <= 0: parity = 0 + if parity < 0: parity = 0 (codeword: codeword.uint, data: data.uint, parity: parity.uint) when (NimMajor, NimMinor, NimPatch) < (1, 4, 0): @@ -116,7 +124,7 @@ proc leoInit*() = proc encode*(code: ReedSolomonCode, data: Data): Result[ParityData, LeopardError] = - if code.parity < 1 or code.parity > code.data: + if not code.isValid: return err LeopardError(code: LeopardBadCode, msg: LeopardBadCodeMsg) var @@ -194,7 +202,7 @@ proc encode*(code: ReedSolomonCode, data: Data): proc decode*(code: ReedSolomonCode, data: Data, parityData: ParityData, symbolBytes: uint): Result[Data, LeopardError] = - if code.parity < 1 or code.parity > code.data: + if not code.isValid: return err LeopardError(code: LeopardBadCode, msg: LeopardBadCodeMsg) var diff --git a/tests/test_leopard.nim b/tests/test_leopard.nim index 1e4d496..c9c7d81 100644 --- a/tests/test_leopard.nim +++ b/tests/test_leopard.nim @@ -19,6 +19,33 @@ proc genData(outerLen, innerLen: uint): Data = var initialized = false +suite "Helpers": + test "isValid should return false if RS code is nonsensical or is invalid per Leopard-RS": + var + rsCode = (codeword: 8.uint, data: 5.uint, parity: 1.uint) + + check: not rsCode.isValid + + rsCode = RS(110,10) + + check: not rsCode.isValid + + rsCode = RS(1,1) + + check: not rsCode.isValid + + rsCode = (codeword: 2.uint, data: 0.uint, parity: 2.uint) + + check: not rsCode.isValid + + rsCode = RS(2,2) + + check: not rsCode.isValid + + rsCode = RS(65537,65409) + + check: not rsCode.isValid + suite "Initialization": test "encode and decode should fail if Leopard-RS is not initialized": let @@ -53,7 +80,7 @@ suite "Initialization": check: initialized suite "Encoder": - test "should fail if RS code is nonsensical or is so per Leopard-RS": + test "should fail if RS code is nonsensical or is invalid per Leopard-RS": check: initialized if not initialized: return @@ -61,7 +88,7 @@ suite "Encoder": symbolBytes = MinBufferSize var - rsCode = RS(5,5) + rsCode = RS(110,10) data = genData(rsCode.data, symbolBytes) encodeRes = rsCode.encode data @@ -69,22 +96,6 @@ suite "Encoder": if encodeRes.isErr: check: encodeRes.error.code == LeopardBadCode - rsCode = RS(5,10) - data = genData(rsCode.data, symbolBytes) - encodeRes = rsCode.encode data - - check: encodeRes.isErr - if encodeRes.isErr: - check: encodeRes.error.code == LeopardBadCode - - rsCode = RS(110,10) - data = genData(rsCode.data, symbolBytes) - encodeRes = rsCode.encode data - - check: encodeRes.isErr - if encodeRes.isErr: - check: encodeRes.error.code == LeopardBadCode - test "should fail if outer length of data does not match the RS code": check: initialized if not initialized: return @@ -185,7 +196,7 @@ suite "Encoder": check: encodeRes.isOk suite "Decoder": - test "should fail if RS code is nonsensical or is so per Leopard-RS": + test "should fail if RS code is nonsensical or is invalid per Leopard-RS": check: initialized if not initialized: return @@ -193,7 +204,7 @@ suite "Decoder": symbolBytes = MinBufferSize var - rsCode = RS(5,5) + rsCode = RS(110,10) data = genData(rsCode.data, symbolBytes) parityData: ParityData decodeRes = rsCode.decode(data, parityData, symbolBytes) @@ -202,22 +213,6 @@ suite "Decoder": if decodeRes.isErr: check: decodeRes.error.code == LeopardBadCode - rsCode = RS(5,10) - data = genData(rsCode.data, symbolBytes) - decodeRes = rsCode.decode(data, parityData, symbolBytes) - - check: decodeRes.isErr - if decodeRes.isErr: - check: decodeRes.error.code == LeopardBadCode - - rsCode = RS(110,10) - data = genData(rsCode.data, symbolBytes) - decodeRes = rsCode.decode(data, parityData, symbolBytes) - - check: decodeRes.isErr - if decodeRes.isErr: - check: decodeRes.error.code == LeopardBadCode - test "should fail if outer length of data does not match the RS code": check: initialized if not initialized: return