diff --git a/starky/src/fibonacci_stark.rs b/starky/src/fibonacci_stark.rs index 5e41f18e..ad2c8e8e 100644 --- a/starky/src/fibonacci_stark.rs +++ b/starky/src/fibonacci_stark.rs @@ -12,26 +12,33 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// Toy STARK system used for testing. /// Computes a Fibonacci sequence with inital values `x0, x1` using the transition /// `x0 <- x1, x1 <- x0 + x1`. -pub struct FibonacciStark, const D: usize> { +struct FibonacciStark, const D: usize> { x0: F, x1: F, + num_rows: usize, _phantom: PhantomData, } impl, const D: usize> FibonacciStark { - const NUM_COLUMNS: usize = 2; - const NUM_ROWS: usize = 1 << 5; + // The first public input is `x0`. + const PI_INDEX_X0: usize = 0; + // The second public input is `x1`. + const PI_INDEX_X1: usize = 1; + // The third public input is the second element of the last row, which should be equal to the + // `(num_rows + 1)`-th Fibonacci number. + const PI_INDEX_RES: usize = 2; - fn new(x0: F, x1: F) -> Self { + fn new(num_rows: usize, x0: F, x1: F) -> Self { Self { x0, x1, + num_rows, _phantom: PhantomData, } } - fn generate_trace(&self) -> Vec<[F; Self::NUM_COLUMNS]> { - (0..Self::NUM_ROWS) + fn generate_trace(&self) -> Vec<[F; Self::COLUMNS]> { + (0..self.num_rows) .scan([self.x0, self.x1], |acc, _| { let tmp = *acc; acc[0] = tmp[1]; @@ -43,8 +50,8 @@ impl, const D: usize> FibonacciStark { } impl, const D: usize> Stark for FibonacciStark { - const COLUMNS: usize = Self::NUM_COLUMNS; - const PUBLIC_INPUTS: usize = 0; + const COLUMNS: usize = 2; + const PUBLIC_INPUTS: usize = 3; fn eval_packed_generic( &self, @@ -54,6 +61,9 @@ impl, const D: usize> Stark for FibonacciStar FE: FieldExtension, P: PackedField, { + yield_constr.one_first_row(vars.local_values[0] - vars.public_inputs[Self::PI_INDEX_X0]); + yield_constr.one_first_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_X1]); + yield_constr.one_last_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_RES]); // x0 <- x1 yield_constr.one(vars.next_values[0] - vars.local_values[1]); // x1 <- x0 + x1 @@ -81,6 +91,10 @@ mod tests { use crate::fibonacci_stark::FibonacciStark; use crate::prover::prove; + fn fibonacci(n: usize, x0: usize, x1: usize) -> usize { + (0..n).fold((0, 1), |x, _| (x.1, x.0 + x.1)).1 + } + #[test] fn test_fibonacci_stark() -> Result<()> { const D: usize = 2; @@ -89,9 +103,21 @@ mod tests { type S = FibonacciStark; let config = StarkConfig::standard_fast_config(); - let stark = S::new(F::ZERO, F::ONE); + let num_rows = 1 << 5; + let public_inputs = [ + F::ZERO, + F::ONE, + F::from_canonical_usize(fibonacci(num_rows - 1, 0, 1)), + ]; + let stark = S::new(num_rows, public_inputs[0], public_inputs[1]); let trace = stark.generate_trace(); - prove::(stark, config, trace, &mut TimingTree::default())?; + prove::( + stark, + config, + trace, + public_inputs, + &mut TimingTree::default(), + )?; Ok(()) } diff --git a/starky/src/prover.rs b/starky/src/prover.rs index 5af22871..5473db68 100644 --- a/starky/src/prover.rs +++ b/starky/src/prover.rs @@ -25,6 +25,7 @@ pub fn prove( stark: S, config: StarkConfig, trace: Vec<[F; S::COLUMNS]>, + public_inputs: [F; S::PUBLIC_INPUTS], timing: &mut TimingTree, ) -> Result> where @@ -72,6 +73,7 @@ where let quotient_polys = compute_quotient_polys::( &stark, &trace_commitment, + public_inputs, &alphas, degree_bits, rate_bits, @@ -142,6 +144,7 @@ where fn compute_quotient_polys( stark: &S, trace_commitment: &PolynomialBatch, + public_inputs: [F; S::PUBLIC_INPUTS], alphas: &[F], degree_bits: usize, rate_bits: usize, @@ -196,7 +199,7 @@ where trace_commitment, (i + 1) % (degree << rate_bits), ), - public_inputs: &[F::ZERO; S::PUBLIC_INPUTS], + public_inputs: &public_inputs, }; stark.eval_packed_base(vars, &mut consumer); // TODO: Fix this once we a genuine `PackedField`. diff --git a/system_zero/src/system_zero.rs b/system_zero/src/system_zero.rs index 47950eb2..49e25e6c 100644 --- a/system_zero/src/system_zero.rs +++ b/system_zero/src/system_zero.rs @@ -83,27 +83,33 @@ impl, const D: usize> Stark for SystemZero Result<()> { type F = GoldilocksField; type C = PoseidonGoldilocksConfig; const D: usize = 2; type S = SystemZero; let system = S::default(); + let public_inputs = [F::ZERO; S::PUBLIC_INPUTS]; let config = StarkConfig::standard_fast_config(); let mut timing = TimingTree::new("prove", Level::Debug); let trace = system.generate_trace(); - prove::(system, config, trace, &mut timing).unwrap(); + prove::(system, config, trace, public_inputs, &mut timing)?; + + Ok(()) } }