diff --git a/hardy/ct_primitives.nim b/hardy/ct_primitives.nim index 841563a..6764b54 100644 --- a/hardy/ct_primitives.nim +++ b/hardy/ct_primitives.nim @@ -51,6 +51,11 @@ func `*`*[T: HardBase](x, y: T): T {.magic: "MulU".} # We don't implement div/mod as we can't assume the hardware implementation # is constant-time +func `-`*(x: HardBase): HardBase {.inline.}= + ## Unary minus returns the two-complement representation + ## of an unsigned integer + {.emit:"`result` = -`x`;".} + # ############################################################ # # Hardened Boolean primitives @@ -61,11 +66,6 @@ func `not`*(ctl: HardBool): HardBool {.inline.}= ## Negate a constant-time boolean ctl xor 1 -func `-`*(x: HardBase): HardBase {.inline.}= - ## Unary minus returns the two-complement representation - ## of an unsigned integer - {.emit:"`result` = -`x`;".} - func select*[T: HardBase](ctl: HardBool[T], x, y: T): T {.inline.}= ## Multiplexer / selector ## Returns x if ctl == 1 diff --git a/hardy/datatypes.nim b/hardy/datatypes.nim index ad72c0b..6ddeedf 100644 --- a/hardy/datatypes.nim +++ b/hardy/datatypes.nim @@ -30,10 +30,6 @@ func htrue*(T: type(BaseUint)): auto {.compileTime.}= func hfalse*(T: type(BaseUint)): auto {.compileTime.}= (HardBool[HardBase[T]])(false) -template hard*(x: static int, T: type BaseUint): HardBase[T] = - ## For int literals - (HardBase[T])(x) - func hard*[T: BaseUint](x: T): HardBase[T] {.inline.}= (HardBase[T])(x) diff --git a/tests/all_tests.nim b/tests/all_tests.nim index 1ca2700..bff9359 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -113,3 +113,22 @@ suite "Hardened unsigned integers": operator_check(`+`) operator_check(`-`) operator_check(`*`) + + test "Unary `-`, returning the 2-complement of an unsigned integer": + let x1 = rand(high(int)).uint64 + let y1 = rand(high(int)).uint64 + let x2 = rand(high(int)).uint64 + let y2 = rand(high(int)).uint64 + let x3 = rand(high(int)).uint64 + let y3 = rand(high(int)).uint64 + check: + (-hard(0'u32)).undistinct == 0 + (-high(HardBase[uint32])).undistinct == 1'u32 + (-hard(0x80000000'u32)).undistinct == 0x80000000'u32 # This is low(int32) == 0b10000..0000 + + undistinct(-x1.hard) == undistinct(not(x1.hard) + hard(1'u64)) + undistinct(-x2.hard) == undistinct(not(x2.hard) + hard(1'u64)) + undistinct(-x3.hard) == undistinct(not(x3.hard) + hard(1'u64)) + undistinct(-y1.hard) == undistinct(not(y1.hard) + hard(1'u64)) + undistinct(-y2.hard) == undistinct(not(y2.hard) + hard(1'u64)) + undistinct(-y3.hard) == undistinct(not(y3.hard) + hard(1'u64))