From fcdeb1c31e9514c9b617dfbfab7364ed0e2bb61b Mon Sep 17 00:00:00 2001 From: gmega Date: Fri, 21 Feb 2025 18:16:55 -0300 Subject: [PATCH] feat: add primitives for sampling with and without replacement from a set using rng --- codex/rng.nim | 38 +++++++++++++++++++++++++++++++++----- tests/codex/testrng.nim | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 5 deletions(-) create mode 100644 tests/codex/testrng.nim diff --git a/codex/rng.nim b/codex/rng.nim index 9d82156e..e36507f7 100644 --- a/codex/rng.nim +++ b/codex/rng.nim @@ -7,11 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import pkg/upraises - -push: - {.upraises: [].} - +import std/sugar import pkg/libp2p/crypto/crypto import pkg/bearssl/rand @@ -39,9 +35,41 @@ proc rand*(rng: Rng, max: Natural): int = if x < randMax - (randMax mod (uint64(max) + 1'u64)): # against modulo bias return int(x mod (uint64(max) + 1'u64)) +proc sampleNoReplacement[T](a: seq[T], n: int): seq[T] = + if n > a.len: + raise newException( + RngSampleError, + "Cannot sample " & $n & " elements from a set of size " & $a.len & + " without replacement.", + ) + + if n == a.len: + return a + + var x = a + collect: + for i in 0 ..< n: + swap(x[i], x[i + rng.rand(x.len - i - 1)]) + x[i] + +proc sampleWithReplacement[T](a: seq[T], n: int): seq[T] = + collect: + for i in 0 ..< n: + a[rng.rand(a.high)] + proc sample*[T](rng: Rng, a: openArray[T]): T = result = a[rng.rand(a.high)] +proc sample*[T](rng: Rng, a: seq[T], n: int, replace: bool = false): seq[T] = + ## Sample `n` elements from a set `a` with or without replacement. In case of + ## sampling with replacement, `n` must not be greater than the size of `a`. + ## In case of sampling without replacement, `n` must not be greater than the + ## size of `a` minus 1. + if replace: + sampleWithReplacement(a, n) + else: + sampleNoReplacement(a, n) + proc sample*[T]( rng: Rng, sample, exclude: openArray[T] ): T {.raises: [Defect, RngSampleError].} = diff --git a/tests/codex/testrng.nim b/tests/codex/testrng.nim new file mode 100644 index 00000000..f97a253b --- /dev/null +++ b/tests/codex/testrng.nim @@ -0,0 +1,37 @@ +import std/unittest +import std/sequtils +import std/sets + +import ../../codex/rng + +suite "Random Number Generator (RNG)": + let rng = Rng.instance() + + test "should sample with replacement": + let elements = toSeq(1 .. 10) + + let sample = rng.sample(elements, n = 15, replace = true) + check sample.len == 15 + for element in sample: + check element in elements + + test "should sample without replacement": + let elements = toSeq(1 .. 10) + + # If we were not drawing without replacement, there'd be a 1/2 chance + # that we'd draw the same element twice in a sample of size 5. + # Running this 40 times gives enough assurance. + var seen: array[10, bool] + for i in 1 .. 40: + let sample = rng.sample(elements, n = 5, replace = false) + + check sample.len == 5 + check sample.toHashSet.len == 5 + + for element in sample: + seen[element - 1] = true + + # There's a 1/2 chance we'll see an element for each draw we do. + # After 40 draws, we are reasonably sure we've seen every element. + for seen in seen: + check seen