diff --git a/hardy/ct_primitives.nim b/hardy/ct_primitives.nim index 6764b54..c27cac0 100644 --- a/hardy/ct_primitives.nim +++ b/hardy/ct_primitives.nim @@ -56,6 +56,17 @@ func `-`*(x: HardBase): HardBase {.inline.}= ## of an unsigned integer {.emit:"`result` = -`x`;".} +# ############################################################ +# +# Bit hacks +# +# ############################################################ + +func isMsbSet*[T: HardBase](x: T): HardBool[T] {.inline.} = + ## Returns the most significant bit of an integer + const msb_pos = T.sizeof * 8 - 1 + result = (HardBool[T])(x shr msb_pos) + # ############################################################ # # Hardened Boolean primitives @@ -77,37 +88,36 @@ func select*[T: HardBase](ctl: HardBool[T], x, y: T): T {.inline.}= # is optimized into a branch by Clang :/ y xor (-ctl.T and (x xor y)) -func `!=`*[T: HardBase](x, y: T): HardBool[T] {.inline.}= +func noteq[T: HardBase](x, y: T): HardBool[T] {.inline.}= const msb = T.sizeof * 8 - 1 let z = x xor y result = (type result)((z or -z) shr msb) func `==`*[T: HardBase](x, y: T): HardBool[T] {.inline.}= - not(x != y) + not(noteq(x, y)) func `<`*[T: HardBase](x, y: T): HardBool[T] {.inline.}= - const msb = T.sizeof * 8 - 1 - result = (type result)( - ( + result = isMsbSet( x xor ( (x xor y) or ((x - y) xor y) ) - ) shr msb - ) + ) func `<=`*[T: HardBase](x, y: T): HardBool[T] {.inline.}= (y < x) xor 1 # ############################################################ # -# Bit hacks +# Workaround system.nim `!=` template # # ############################################################ -func isMsbSet*[T: HardBase](x: T): HardBool[T] {.inline.} = - ## Returns the most significant bit of an integer - const msb_pos = T.sizeof * 8 - 1 - result = (HardBool[T])(x shr msb_pos) +# system.nim defines `!=` as a catchall template +# in terms of `==` while we define `==` in terms of `!=` +# So we would have not(not(noteq(x,y))) + +template trmFixSystemNotEq*{x != y}[T: HardBase](x, y: T): HardBool[T] = + noteq(x, y) # ############################################################ # diff --git a/hardy/datatypes.nim b/hardy/datatypes.nim index 6ddeedf..ee947bb 100644 --- a/hardy/datatypes.nim +++ b/hardy/datatypes.nim @@ -24,10 +24,10 @@ type ## Note that constant-time allocation is very involved for ## heap-allocated types (i.e. requires a memory pool) -func htrue*(T: type(BaseUint)): auto {.compileTime.}= +func htrue*(T: type(BaseUint)): auto {.inline.}= (HardBool[HardBase[T]])(true) -func hfalse*(T: type(BaseUint)): auto {.compileTime.}= +func hfalse*(T: type(BaseUint)): auto {.inline.}= (HardBool[HardBase[T]])(false) func hard*[T: BaseUint](x: T): HardBase[T] {.inline.}= diff --git a/tests/all_tests.nim b/tests/all_tests.nim index bff9359..d361491 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -132,3 +132,37 @@ suite "Hardened unsigned integers": undistinct(-y1.hard) == undistinct(not(y1.hard) + hard(1'u64)) undistinct(-y2.hard) == undistinct(not(y2.hard) + hard(1'u64)) undistinct(-y3.hard) == undistinct(not(y3.hard) + hard(1'u64)) + +suite "Hardened booleans": + test "Boolean not": + check: + not(htrue(uint32)).bool == false + not(hfalse(uint32)).bool == true + + test "Comparison": + check: + bool(hard(0'u32) != hard(0'u32)) == false + bool(hard(0'u32) != hard(1'u32)) == true + + bool(hard(10'u32) == hard(10'u32)) == true + bool(hard(10'u32) != hard(20'u32)) == true + + bool(hard(10'u32) <= hard(10'u32)) == true + bool(hard(10'u32) <= hard(20'u32)) == true + bool(hard(10'u32) <= hard(5'u32)) == false + bool(hard(10'u32) <= hard(0xFFFFFFFF'u32)) == true + + bool(hard(10'u32) < hard(10'u32)) == false + bool(hard(10'u32) < hard(20'u32)) == true + bool(hard(10'u32) < hard(5'u32)) == false + bool(hard(10'u32) < hard(0xFFFFFFFF'u32)) == true + + bool(hard(10'u32) > hard(10'u32)) == false + bool(hard(10'u32) > hard(20'u32)) == false + bool(hard(10'u32) > hard(5'u32)) == true + bool(hard(10'u32) > hard(0xFFFFFFFF'u32)) == false + + bool(hard(10'u32) >= hard(10'u32)) == true + bool(hard(10'u32) >= hard(20'u32)) == false + bool(hard(10'u32) >= hard(5'u32)) == true + bool(hard(10'u32) >= hard(0xFFFFFFFF'u32)) == false