diff --git a/evm/Cargo.toml b/evm/Cargo.toml index db774345..6db81902 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -5,11 +5,13 @@ version = "0.1.0" edition = "2021" [dependencies] -plonky2 = { path = "../plonky2" } +plonky2 = { path = "../plonky2", default-features = false, features = ["rand", "timing"] } plonky2_util = { path = "../util" } +eth-trie-utils = { git = "https://github.com/mir-protocol/eth-trie-utils.git", rev = "dd3595b4ba7923f8d465450d210f17a2b4e20f96" } +maybe_rayon = { path = "../maybe_rayon" } anyhow = "1.0.40" env_logger = "0.9.0" -ethereum-types = "0.13.1" +ethereum-types = "0.14.0" hex = { version = "0.4.3", optional = true } hex-literal = "0.3.4" itertools = "0.10.3" @@ -17,10 +19,10 @@ log = "0.4.14" once_cell = "1.13.0" pest = "2.1.3" pest_derive = "2.1.0" -maybe_rayon = { path = "../maybe_rayon" } rand = "0.8.5" rand_chacha = "0.3.1" rlp = "0.5.1" +serde = { version = "1.0.144", features = ["derive"] } keccak-hash = "0.9.0" tiny-keccak = "2.0.2" @@ -31,7 +33,7 @@ hex = "0.4.3" [features] default = ["parallel"] asmtools = ["hex"] -parallel = ["maybe_rayon/parallel"] +parallel = ["plonky2/parallel", "maybe_rayon/parallel"] [[bin]] name = "assemble" diff --git a/evm/spec/.gitignore b/evm/spec/.gitignore new file mode 100644 index 00000000..ba6d4007 --- /dev/null +++ b/evm/spec/.gitignore @@ -0,0 +1,7 @@ +## Files generated by pdflatex, bibtex, etc. +*.aux +*.log +*.out +*.toc +*.bbl +*.blg diff --git a/evm/spec/Makefile b/evm/spec/Makefile new file mode 100644 index 00000000..97954528 --- /dev/null +++ b/evm/spec/Makefile @@ -0,0 +1,20 @@ +DOCNAME=zkevm + +all: pdf + +.PHONY: clean + +quick: + pdflatex $(DOCNAME).tex + +pdf: + pdflatex $(DOCNAME).tex + bibtex $(DOCNAME).aux + pdflatex $(DOCNAME).tex + pdflatex $(DOCNAME).tex + +view: pdf + open $(DOCNAME).pdf + +clean: + rm -f *.blg *.bbl *.aux *.log diff --git a/evm/spec/bibliography.bib b/evm/spec/bibliography.bib new file mode 100644 index 00000000..41fa56b8 --- /dev/null +++ b/evm/spec/bibliography.bib @@ -0,0 +1,20 @@ +@misc{stark, + author = {Eli Ben-Sasson and + Iddo Bentov and + Yinon Horesh and + Michael Riabzev}, + title = {Scalable, transparent, and post-quantum secure computational integrity}, + howpublished = {Cryptology ePrint Archive, Report 2018/046}, + year = {2018}, + note = {\url{https://ia.cr/2018/046}}, +} + +@misc{plonk, + author = {Ariel Gabizon and + Zachary J. Williamson and + Oana Ciobotaru}, + title = {PLONK: Permutations over Lagrange-bases for Oecumenical Noninteractive arguments of Knowledge}, + howpublished = {Cryptology ePrint Archive, Report 2019/953}, + year = {2019}, + note = {\url{https://ia.cr/2019/953}}, +} diff --git a/evm/spec/framework.tex b/evm/spec/framework.tex new file mode 100644 index 00000000..d99a31bb --- /dev/null +++ b/evm/spec/framework.tex @@ -0,0 +1,37 @@ +\section{STARK framework} +\label{framework} + + +\subsection{Cost model} + +Our zkEVM is designed for efficient verification by STARKs \cite{stark}, particularly by an AIR with degree 3 constraints. In this model, the prover bottleneck is typically constructing Merkle trees, particularly constructing the tree containing low-degree extensions of witness polynomials. + + +\subsection{Field selection} +\label{field} +Our zkEVM is designed to have its execution traces encoded in a particular prime field $\mathbb{F}_p$, with $p = 2^{64} - 2^{32} + 1$. A nice property of this field is that it can represent the results of many common \texttt{u32} operations. For example, (widening) \texttt{u32} multiplication has a maximum value of $(2^{32} - 1)^2$, which is less than $p$. In fact a \texttt{u32} multiply-add has a maximum value of $p - 1$, so the result can be represented with a single field element, although if we were to add a carry in bit, this no longer holds. + +This field also enables a very efficient reduction method. Observe that +$$ +2^{64} \equiv 2^{32} - 1 \pmod p +$$ +and consequently +\begin{align*} + 2^{96} &\equiv 2^{32} (2^{32} - 1) \pmod p \\ + &\equiv 2^{64} - 2^{32} \pmod p \\ + &\equiv -1 \pmod p. +\end{align*} +To reduce a 128-bit number $n$, we first rewrite $n$ as $n_0 + 2^{64} n_1 + 2^{96} n_2$, where $n_0$ is 64 bits and $n_1, n_2$ are 32 bits each. Then +\begin{align*} + n &\equiv n_0 + 2^{64} n_1 + 2^{96} n_2 \pmod p \\ + &\equiv n_0 + (2^{32} - 1) n_1 - n_2 \pmod p +\end{align*} +After computing $(2^{32} - 1) n_1$, which can be done with a shift and subtraction, we add the first two terms, subtracting $p$ if overflow occurs. We then subtract $n_2$, adding $p$ if underflow occurs. + +At this point we have reduced $n$ to a \texttt{u64}. This partial reduction is adequate for most purposes, but if we needed the result in canonical form, we would perform a final conditional subtraction. + + +\subsection{Cross-table lookups} +\label{ctl} + +TODO diff --git a/evm/spec/instructions.tex b/evm/spec/instructions.tex new file mode 100644 index 00000000..ea096982 --- /dev/null +++ b/evm/spec/instructions.tex @@ -0,0 +1,8 @@ +\section{Privileged instructions} +\label{privileged-instructions} + +\begin{enumerate} + \item[0xFB.] \texttt{MLOAD\_GENERAL}. Returns + \item[0xFC.] \texttt{MSTORE\_GENERAL}. Returns + \item[TODO.] \texttt{STACK\_SIZE}. Returns +\end{enumerate} diff --git a/evm/spec/introduction.tex b/evm/spec/introduction.tex new file mode 100644 index 00000000..cb969a16 --- /dev/null +++ b/evm/spec/introduction.tex @@ -0,0 +1,3 @@ +\section{Introduction} + +TODO diff --git a/evm/spec/tables.tex b/evm/spec/tables.tex new file mode 100644 index 00000000..92ee1d2a --- /dev/null +++ b/evm/spec/tables.tex @@ -0,0 +1,9 @@ +\section{Tables} +\label{tables} + +\input{tables/cpu} +\input{tables/arithmetic} +\input{tables/logic} +\input{tables/memory} +\input{tables/keccak-f} +\input{tables/keccak-sponge} diff --git a/evm/spec/tables/arithmetic.tex b/evm/spec/tables/arithmetic.tex new file mode 100644 index 00000000..eafed3ba --- /dev/null +++ b/evm/spec/tables/arithmetic.tex @@ -0,0 +1,4 @@ +\subsection{Arithmetic} +\label{arithmetic} + +TODO diff --git a/evm/spec/tables/cpu.tex b/evm/spec/tables/cpu.tex new file mode 100644 index 00000000..76c8be07 --- /dev/null +++ b/evm/spec/tables/cpu.tex @@ -0,0 +1,4 @@ +\subsection{CPU} +\label{cpu} + +TODO diff --git a/evm/spec/tables/keccak-f.tex b/evm/spec/tables/keccak-f.tex new file mode 100644 index 00000000..76e9e9f4 --- /dev/null +++ b/evm/spec/tables/keccak-f.tex @@ -0,0 +1,4 @@ +\subsection{Keccak-f} +\label{keccak-f} + +This table computes the Keccak-f[1600] permutation. diff --git a/evm/spec/tables/keccak-sponge.tex b/evm/spec/tables/keccak-sponge.tex new file mode 100644 index 00000000..29f71ba1 --- /dev/null +++ b/evm/spec/tables/keccak-sponge.tex @@ -0,0 +1,4 @@ +\subsection{Keccak sponge} +\label{keccak-sponge} + +This table computes the Keccak256 hash, a sponge-based hash built on top of the Keccak-f[1600] permutation. diff --git a/evm/spec/tables/logic.tex b/evm/spec/tables/logic.tex new file mode 100644 index 00000000..b430c95d --- /dev/null +++ b/evm/spec/tables/logic.tex @@ -0,0 +1,4 @@ +\subsection{Logic} +\label{logic} + +TODO diff --git a/evm/spec/tables/memory.tex b/evm/spec/tables/memory.tex new file mode 100644 index 00000000..9653f391 --- /dev/null +++ b/evm/spec/tables/memory.tex @@ -0,0 +1,61 @@ +\subsection{Memory} +\label{memory} + +For simplicity, let's treat addresses and values as individual field elements. The generalization to multi-element addresses and values is straightforward. + +Each row of the memory table corresponds to a single memory operation (a read or a write), and contains the following columns: + +\begin{enumerate} + \item $a$, the target address + \item $r$, an ``is read'' flag, which should be 1 for a read or 0 for a write + \item $v$, the value being read or written + \item $\tau$, the timestamp of the operation +\end{enumerate} +The memory table should be ordered by $(a, \tau)$. Note that the correctness memory could be checked as follows: +\begin{enumerate} + \item Verify the ordering by checking that $(a_i, \tau_i) < (a_{i+1}, \tau_{i+1})$ for each consecutive pair. + \item Enumerate the purportedly-ordered log while tracking a ``current'' value $c$, which is initially zero.\footnote{EVM memory is zero-initialized.} + \begin{enumerate} + \item Upon observing an address which doesn't match that of the previous row, set $c \leftarrow 0$. + \item Upon observing a write, set $c \leftarrow v$. + \item Upon observing a read, check that $v = c$. + \end{enumerate} +\end{enumerate} + +The ordering check is slightly involved since we are comparing multiple columns. To facilitate this, we add an additional column $e$, where the prover can indicate whether two consecutive addresses are equal. An honest prover will set +$$ +e_i \leftarrow \begin{cases} + 1 & \text{if } a_i = a_{i + 1}, \\ + 0 & \text{otherwise}. +\end{cases} +$$ +We then impose the following transition constraints: +\begin{enumerate} + \item $e_i (e_i - 1) = 0$, + \item $e_i (a_i - a_{i + 1}) = 0$, + \item $e_i (\tau_{i + 1} - \tau_i) + (1 - e_i) (a_{i + 1} - a_i - 1) < 2^{32}$. +\end{enumerate} +The last constraint emulates a comparison between two addresses or timestamps by bounding their difference; this assumes that all addresses and timestamps fit in 32 bits and that the field is larger than that. + +Finally, the iterative checks can be arithmetized by introducing a trace column for the current value $c$. We add a boundary constraint $c_0 = 0$, and the following transition constraints: +\todo{This is out of date, we don't actually need a $c$ column.} +\begin{enumerate} + \item $v_{\text{from},i} = c_i$, + \item $c_{i + 1} = e_i v_{\text{to},i}$. +\end{enumerate} + + +\subsubsection{Virtual memory} + +In the EVM, each contract call has its own address space. Within that address space, there are separate segments for code, main memory, stack memory, calldata, and returndata. Thus each address actually has three compoments: +\begin{enumerate} + \item an execution context, representing a contract call, + \item a segment ID, used to separate code, main memory, and so forth, and so on + \item a virtual address. +\end{enumerate} +The comparisons now involve several columns, which requires some minor adaptations to the technique described above; we will leave these as an exercise to the reader. + + +\subsubsection{Timestamps} + +TODO: Explain $\tau = \texttt{NUM\_CHANNELS} \times \texttt{cycle} + \texttt{channel}$. diff --git a/evm/spec/tries.tex b/evm/spec/tries.tex new file mode 100644 index 00000000..fed78f40 --- /dev/null +++ b/evm/spec/tries.tex @@ -0,0 +1,16 @@ +\section{Merkle Patricia tries} +\label{tries} + +\subsection{Internal memory format} + +Withour our zkEVM's kernel memory, +\begin{enumerate} + \item An empty node is encoded as $(\texttt{MPT\_NODE\_EMPTY})$. + \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, c_1, \dots, c_{16}, v)$, where each $c_i$ is a pointer to a child node, and $v$ is a leaf payload. + \item An extension node is encoded as $(\texttt{MPT\_NODE\_EXTENSION}, k, c)$, $k$ is a 2-tuple $(\texttt{packed\_nibbles}, \texttt{num\_nibbles})$, and $c$ is a pointer to a child node. + \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, v)$, where $k$ is a 2-tuple as above, and $v$ is a leaf payload. + \item A digest node is encoded as $(\texttt{MPT\_NODE\_DIGEST}, d)$, where $d$ is a Keccak256 digest. +\end{enumerate} + + +\subsection{Prover input format} diff --git a/evm/spec/zkevm.pdf b/evm/spec/zkevm.pdf new file mode 100644 index 00000000..aff46eda Binary files /dev/null and b/evm/spec/zkevm.pdf differ diff --git a/evm/spec/zkevm.tex b/evm/spec/zkevm.tex new file mode 100644 index 00000000..f87f02f3 --- /dev/null +++ b/evm/spec/zkevm.tex @@ -0,0 +1,59 @@ +\documentclass[12pt]{article} +\usepackage{amsmath} +\usepackage{amssymb} +\usepackage{cite} +\usepackage{draftwatermark} +\usepackage[margin=1.5in]{geometry} +\usepackage{hyperref} +\usepackage{makecell} +\usepackage{mathtools} +\usepackage{tabularx} +\usepackage[textwidth=1.25in]{todonotes} + +% Scale for DRAFT watermark. +\SetWatermarkFontSize{24cm} +\SetWatermarkScale{5} +\SetWatermarkLightness{0.92} + +% Hyperlink colors. +\hypersetup{ + colorlinks=true, + linkcolor=blue, + citecolor=blue, + urlcolor=blue, +} + +% We want all section autorefs to say "Section". +\def\sectionautorefname{Section} +\let\subsectionautorefname\sectionautorefname +\let\subsubsectionautorefname\sectionautorefname + +% \floor{...} and \ceil{...} +\DeclarePairedDelimiter\ceil{\lceil}{\rceil} +\DeclarePairedDelimiter\floor{\lfloor}{\rfloor} + +\title{The Polygon Zero zkEVM} +%\author{Polygon Zero Team} +\date{DRAFT\\\today} + +\begin{document} +\maketitle + +\begin{abstract} + We describe the design of Polygon Zero's zkEVM, ... +\end{abstract} + +\newpage +{\hypersetup{hidelinks} \tableofcontents} +\newpage + +\input{introduction} +\input{framework} +\input{tables} +\input{tries} +\input{instructions} + +\bibliography{bibliography}{} +\bibliographystyle{ieeetr} + +\end{document} diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 164c60fc..07f38694 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -1,3 +1,5 @@ +use std::iter; + use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; @@ -5,6 +7,7 @@ use plonky2::hash::hash_types::RichField; use crate::config::StarkConfig; use crate::cpu::cpu_stark; use crate::cpu::cpu_stark::CpuStark; +use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns}; use crate::keccak::keccak_stark; use crate::keccak::keccak_stark::KeccakStark; @@ -13,8 +16,8 @@ use crate::keccak_memory::keccak_memory_stark; use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; use crate::logic; use crate::logic::LogicStark; +use crate::memory::memory_stark; use crate::memory::memory_stark::MemoryStark; -use crate::memory::{memory_stark, NUM_CHANNELS}; use crate::stark::Stark; #[derive(Clone)] @@ -129,11 +132,16 @@ fn ctl_logic() -> CrossTableLookup { } fn ctl_memory() -> CrossTableLookup { - let cpu_memory_ops = (0..NUM_CHANNELS).map(|channel| { + let cpu_memory_code_read = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_code_memory(), + Some(cpu_stark::ctl_filter_code_memory()), + ); + let cpu_memory_gp_ops = (0..NUM_GP_CHANNELS).map(|channel| { TableWithColumns::new( Table::Cpu, - cpu_stark::ctl_data_memory(channel), - Some(cpu_stark::ctl_filter_memory(channel)), + cpu_stark::ctl_data_gp_memory(channel), + Some(cpu_stark::ctl_filter_gp_memory(channel)), ) }); let keccak_memory_reads = (0..KECCAK_WIDTH_BYTES).map(|i| { @@ -150,7 +158,8 @@ fn ctl_memory() -> CrossTableLookup { Some(keccak_memory_stark::ctl_filter()), ) }); - let all_lookers = cpu_memory_ops + let all_lookers = iter::once(cpu_memory_code_read) + .chain(cpu_memory_gp_ops) .chain(keccak_memory_reads) .chain(keccak_memory_writes) .collect(); @@ -214,14 +223,18 @@ mod tests { let keccak_inputs = (0..num_keccak_perms) .map(|_| [0u64; NUM_INPUTS].map(|_| rng.gen())) .collect_vec(); - keccak_stark.generate_trace(keccak_inputs) + keccak_stark.generate_trace(keccak_inputs, &mut TimingTree::default()) } fn make_keccak_memory_trace( keccak_memory_stark: &KeccakMemoryStark, config: &StarkConfig, ) -> Vec> { - keccak_memory_stark.generate_trace(vec![], config.fri_config.num_cap_elements()) + keccak_memory_stark.generate_trace( + vec![], + config.fri_config.num_cap_elements(), + &mut TimingTree::default(), + ) } fn make_logic_trace( @@ -238,7 +251,7 @@ mod tests { Operation::new(op, input0, input1) }) .collect(); - logic_stark.generate_trace(ops) + logic_stark.generate_trace(ops, &mut TimingTree::default()) } fn make_memory_trace( @@ -247,7 +260,7 @@ mod tests { rng: &mut R, ) -> (Vec>, usize) { let memory_ops = generate_random_memory_ops(num_memory_ops, rng); - let trace = memory_stark.generate_trace(memory_ops); + let trace = memory_stark.generate_trace(memory_ops, &mut TimingTree::default()); let num_ops = trace[0].values.len(); (trace, num_ops) } @@ -325,7 +338,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x5b); row.is_cpu_cycle = F::ONE; row.is_kernel_mode = F::ONE; - row.program_counter = F::from_canonical_usize(KERNEL.global_labels["route_txn"]); + row.program_counter = F::from_canonical_usize(KERNEL.global_labels["main"]); cpu_stark.generate(row.borrow_mut()); cpu_trace_rows.push(row.into()); } @@ -364,8 +377,8 @@ mod tests { row.is_cpu_cycle = F::ONE; row.is_kernel_mode = F::ONE; - // Since these are the first cycle rows, we must start with PC=route_txn then increment. - row.program_counter = F::from_canonical_usize(KERNEL.global_labels["route_txn"] + i); + // Since these are the first cycle rows, we must start with PC=main then increment. + row.program_counter = F::from_canonical_usize(KERNEL.global_labels["main"] + i); row.opcode_bits = bits_from_opcode( if logic_trace[logic::columns::IS_AND].values[i] != F::ZERO { 0x16 @@ -725,6 +738,7 @@ mod tests { } #[test] + #[ignore] // Ignoring but not deleting so the test can serve as an API usage example fn test_all_stark() -> Result<()> { let config = StarkConfig::standard_fast_config(); let (all_stark, proof) = get_proof(&config)?; @@ -732,6 +746,7 @@ mod tests { } #[test] + #[ignore] // Ignoring but not deleting so the test can serve as an API usage example fn test_all_stark_recursive_verifier() -> Result<()> { init_logger(); diff --git a/evm/src/arithmetic/compare.rs b/evm/src/arithmetic/compare.rs index 8410cade..a6566db5 100644 --- a/evm/src/arithmetic/compare.rs +++ b/evm/src/arithmetic/compare.rs @@ -45,6 +45,14 @@ pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], op: usize) lv[CMP_OUTPUT] = F::from_canonical_u64(br); } +fn eval_packed_generic_check_is_one_bit( + yield_constr: &mut ConstraintConsumer

, + filter: P, + x: P, +) { + yield_constr.constraint(filter * x * (x - P::ONES)); +} + pub(crate) fn eval_packed_generic_lt( yield_constr: &mut ConstraintConsumer

, is_op: P, @@ -69,15 +77,31 @@ pub fn eval_packed_generic( range_check_error!(CMP_INPUT_0, 16); range_check_error!(CMP_INPUT_1, 16); range_check_error!(CMP_AUX_INPUT, 16); - range_check_error!([CMP_OUTPUT], 1); + + let is_lt = lv[IS_LT]; + let is_gt = lv[IS_GT]; let input0 = CMP_INPUT_0.map(|c| lv[c]); let input1 = CMP_INPUT_1.map(|c| lv[c]); let aux = CMP_AUX_INPUT.map(|c| lv[c]); let output = lv[CMP_OUTPUT]; - eval_packed_generic_lt(yield_constr, lv[IS_LT], input0, input1, aux, output); - eval_packed_generic_lt(yield_constr, lv[IS_GT], input1, input0, aux, output); + let is_cmp = is_lt + is_gt; + eval_packed_generic_check_is_one_bit(yield_constr, is_cmp, output); + + eval_packed_generic_lt(yield_constr, is_lt, input0, input1, aux, output); + eval_packed_generic_lt(yield_constr, is_gt, input1, input0, aux, output); +} + +fn eval_ext_circuit_check_is_one_bit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + yield_constr: &mut RecursiveConstraintConsumer, + filter: ExtensionTarget, + x: ExtensionTarget, +) { + let constr = builder.mul_sub_extension(x, x, x); + let filtered_constr = builder.mul_extension(filter, constr); + yield_constr.constraint(builder, filtered_constr); } #[allow(clippy::needless_collect)] @@ -117,29 +141,19 @@ pub fn eval_ext_circuit, const D: usize>( lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { + let is_lt = lv[IS_LT]; + let is_gt = lv[IS_GT]; + let input0 = CMP_INPUT_0.map(|c| lv[c]); let input1 = CMP_INPUT_1.map(|c| lv[c]); let aux = CMP_AUX_INPUT.map(|c| lv[c]); let output = lv[CMP_OUTPUT]; - eval_ext_circuit_lt( - builder, - yield_constr, - lv[IS_LT], - input0, - input1, - aux, - output, - ); - eval_ext_circuit_lt( - builder, - yield_constr, - lv[IS_GT], - input1, - input0, - aux, - output, - ); + let is_cmp = builder.add_extension(is_lt, is_gt); + eval_ext_circuit_check_is_one_bit(builder, yield_constr, is_cmp, output); + + eval_ext_circuit_lt(builder, yield_constr, is_lt, input0, input1, aux, output); + eval_ext_circuit_lt(builder, yield_constr, is_gt, input1, input0, aux, output); } #[cfg(test)] diff --git a/evm/src/config.rs b/evm/src/config.rs index 500cd957..a593c827 100644 --- a/evm/src/config.rs +++ b/evm/src/config.rs @@ -21,9 +21,9 @@ impl StarkConfig { fri_config: FriConfig { rate_bits: 1, cap_height: 4, - proof_of_work_bits: 10, + proof_of_work_bits: 16, reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5), - num_query_rounds: 90, + num_query_rounds: 84, }, } } diff --git a/evm/src/cpu/bootstrap_kernel.rs b/evm/src/cpu/bootstrap_kernel.rs index 533589af..dd52f166 100644 --- a/evm/src/cpu/bootstrap_kernel.rs +++ b/evm/src/cpu/bootstrap_kernel.rs @@ -17,7 +17,6 @@ use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::keccak_util::keccakf_u32s; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; -use crate::memory::NUM_CHANNELS; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// The Keccak rate (1088 bits), measured in bytes. @@ -47,8 +46,7 @@ pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState // Write this chunk to memory, while simultaneously packing its bytes into a u32 word. let mut packed_bytes: u32 = 0; - for (addr, byte) in chunk { - let channel = addr % NUM_CHANNELS; + for (channel, (addr, byte)) in chunk.enumerate() { state.set_mem_cpu_current(channel, Segment::Code, addr, byte.into()); packed_bytes = (packed_bytes << 8) | byte as u32; diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index 567c5a97..d0ef3f28 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -7,10 +7,15 @@ use std::mem::{size_of, transmute}; use std::ops::{Index, IndexMut}; use crate::cpu::columns::general::CpuGeneralColumnsView; +use crate::cpu::columns::ops::OpsColumnsView; +use crate::cpu::membus::NUM_GP_CHANNELS; use crate::memory; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; mod general; +pub(crate) mod ops; + +pub type MemValue = [T; memory::VALUE_LIMBS]; #[repr(C)] #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -22,7 +27,7 @@ pub struct MemoryChannelView { pub addr_context: T, pub addr_segment: T, pub addr_virtual: T, - pub value: [T; memory::VALUE_LIMBS], + pub value: MemValue, } #[repr(C)] @@ -35,110 +40,29 @@ pub struct CpuColumnsView { /// Lets us re-use columns in non-cycle rows. pub is_cpu_cycle: T, + /// If CPU cycle: Current context. + // TODO: this is currently unconstrained + pub context: T, + + /// If CPU cycle: Context for code memory channel. + pub code_context: T, + /// If CPU cycle: The program counter for the current instruction. pub program_counter: T, + /// If CPU cycle: The stack length. + pub stack_len: T, + + /// If CPU cycle: A prover-provided value needed to show that the instruction does not cause the + /// stack to underflow or overflow. + pub stack_len_bounds_aux: T, + /// If CPU cycle: We're in kernel (privileged) mode. pub is_kernel_mode: T, - // If CPU cycle: flags for EVM instructions. PUSHn, DUPn, and SWAPn only get one flag each. - // Invalid opcodes are split between a number of flags for practical reasons. Exactly one of - // these flags must be 1. - pub is_stop: T, - pub is_add: T, - pub is_mul: T, - pub is_sub: T, - pub is_div: T, - pub is_sdiv: T, - pub is_mod: T, - pub is_smod: T, - pub is_addmod: T, - pub is_mulmod: T, - pub is_exp: T, - pub is_signextend: T, - pub is_lt: T, - pub is_gt: T, - pub is_slt: T, - pub is_sgt: T, - pub is_eq: T, // Note: This column must be 0 when is_cpu_cycle = 0. - pub is_iszero: T, // Note: This column must be 0 when is_cpu_cycle = 0. - pub is_and: T, - pub is_or: T, - pub is_xor: T, - pub is_not: T, - pub is_byte: T, - pub is_shl: T, - pub is_shr: T, - pub is_sar: T, - pub is_keccak256: T, - pub is_address: T, - pub is_balance: T, - pub is_origin: T, - pub is_caller: T, - pub is_callvalue: T, - pub is_calldataload: T, - pub is_calldatasize: T, - pub is_calldatacopy: T, - pub is_codesize: T, - pub is_codecopy: T, - pub is_gasprice: T, - pub is_extcodesize: T, - pub is_extcodecopy: T, - pub is_returndatasize: T, - pub is_returndatacopy: T, - pub is_extcodehash: T, - pub is_blockhash: T, - pub is_coinbase: T, - pub is_timestamp: T, - pub is_number: T, - pub is_difficulty: T, - pub is_gaslimit: T, - pub is_chainid: T, - pub is_selfbalance: T, - pub is_basefee: T, - pub is_prover_input: T, - pub is_pop: T, - pub is_mload: T, - pub is_mstore: T, - pub is_mstore8: T, - pub is_sload: T, - pub is_sstore: T, - pub is_jump: T, // Note: This column must be 0 when is_cpu_cycle = 0. - pub is_jumpi: T, // Note: This column must be 0 when is_cpu_cycle = 0. - pub is_pc: T, - pub is_msize: T, - pub is_gas: T, - pub is_jumpdest: T, - pub is_get_state_root: T, - pub is_set_state_root: T, - pub is_get_receipt_root: T, - pub is_set_receipt_root: T, - pub is_push: T, - pub is_dup: T, - pub is_swap: T, - pub is_log0: T, - pub is_log1: T, - pub is_log2: T, - pub is_log3: T, - pub is_log4: T, - // PANIC does not get a flag; it fails at the decode stage. - pub is_create: T, - pub is_call: T, - pub is_callcode: T, - pub is_return: T, - pub is_delegatecall: T, - pub is_create2: T, - pub is_get_context: T, - pub is_set_context: T, - pub is_consume_gas: T, - pub is_exit_kernel: T, - pub is_staticcall: T, - pub is_mload_general: T, - pub is_mstore_general: T, - pub is_revert: T, - pub is_selfdestruct: T, - - pub is_invalid: T, + /// If CPU cycle: flags for EVM instructions (a few cannot be shared; see the comments in + /// `OpsColumnsView`). + pub op: OpsColumnsView, /// If CPU cycle: the opcode, broken up into bits in little-endian order. pub opcode_bits: [T; 8], @@ -152,7 +76,7 @@ pub struct CpuColumnsView { pub(crate) general: CpuGeneralColumnsView, pub(crate) clock: T, - pub mem_channels: [MemoryChannelView; memory::NUM_CHANNELS], + pub mem_channels: [MemoryChannelView; NUM_GP_CHANNELS], } // `u8` is guaranteed to have a `size_of` of 1. diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs new file mode 100644 index 00000000..e0cb2952 --- /dev/null +++ b/evm/src/cpu/columns/ops.rs @@ -0,0 +1,154 @@ +use std::borrow::{Borrow, BorrowMut}; +use std::mem::{size_of, transmute}; +use std::ops::{Deref, DerefMut}; + +use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; + +#[repr(C)] +#[derive(Eq, PartialEq, Debug)] +pub struct OpsColumnsView { + pub stop: T, + pub add: T, + pub mul: T, + pub sub: T, + pub div: T, + pub sdiv: T, + pub mod_: T, + pub smod: T, + pub addmod: T, + pub mulmod: T, + pub exp: T, + pub signextend: T, + pub lt: T, + pub gt: T, + pub slt: T, + pub sgt: T, + pub eq: T, // Note: This column must be 0 when is_cpu_cycle = 0. + pub iszero: T, // Note: This column must be 0 when is_cpu_cycle = 0. + pub and: T, + pub or: T, + pub xor: T, + pub not: T, + pub byte: T, + pub shl: T, + pub shr: T, + pub sar: T, + pub keccak256: T, + pub keccak_general: T, + pub address: T, + pub balance: T, + pub origin: T, + pub caller: T, + pub callvalue: T, + pub calldataload: T, + pub calldatasize: T, + pub calldatacopy: T, + pub codesize: T, + pub codecopy: T, + pub gasprice: T, + pub extcodesize: T, + pub extcodecopy: T, + pub returndatasize: T, + pub returndatacopy: T, + pub extcodehash: T, + pub blockhash: T, + pub coinbase: T, + pub timestamp: T, + pub number: T, + pub difficulty: T, + pub gaslimit: T, + pub chainid: T, + pub selfbalance: T, + pub basefee: T, + pub prover_input: T, + pub pop: T, + pub mload: T, + pub mstore: T, + pub mstore8: T, + pub sload: T, + pub sstore: T, + pub jump: T, // Note: This column must be 0 when is_cpu_cycle = 0. + pub jumpi: T, // Note: This column must be 0 when is_cpu_cycle = 0. + pub pc: T, + pub msize: T, + pub gas: T, + pub jumpdest: T, + pub get_state_root: T, + pub set_state_root: T, + pub get_receipt_root: T, + pub set_receipt_root: T, + pub push: T, + pub dup: T, + pub swap: T, + pub log0: T, + pub log1: T, + pub log2: T, + pub log3: T, + pub log4: T, + // PANIC does not get a flag; it fails at the decode stage. + pub create: T, + pub call: T, + pub callcode: T, + pub return_: T, + pub delegatecall: T, + pub create2: T, + pub get_context: T, + pub set_context: T, + pub consume_gas: T, + pub exit_kernel: T, + pub staticcall: T, + pub mload_general: T, + pub mstore_general: T, + pub revert: T, + pub selfdestruct: T, + + // TODO: this doesn't actually need its own flag. We can just do `1 - sum(all other flags)`. + pub invalid: T, +} + +// `u8` is guaranteed to have a `size_of` of 1. +pub const NUM_OPS_COLUMNS: usize = size_of::>(); + +impl From<[T; NUM_OPS_COLUMNS]> for OpsColumnsView { + fn from(value: [T; NUM_OPS_COLUMNS]) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl From> for [T; NUM_OPS_COLUMNS] { + fn from(value: OpsColumnsView) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl Borrow> for [T; NUM_OPS_COLUMNS] { + fn borrow(&self) -> &OpsColumnsView { + unsafe { transmute(self) } + } +} + +impl BorrowMut> for [T; NUM_OPS_COLUMNS] { + fn borrow_mut(&mut self) -> &mut OpsColumnsView { + unsafe { transmute(self) } + } +} + +impl Deref for OpsColumnsView { + type Target = [T; NUM_OPS_COLUMNS]; + fn deref(&self) -> &Self::Target { + unsafe { transmute(self) } + } +} + +impl DerefMut for OpsColumnsView { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { transmute(self) } + } +} + +const fn make_col_map() -> OpsColumnsView { + let indices_arr = indices_arr::(); + unsafe { transmute::<[usize; NUM_OPS_COLUMNS], OpsColumnsView>(indices_arr) } +} + +pub const COL_MAP: OpsColumnsView = make_col_map(); diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index e6ded598..3856726c 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -10,31 +10,31 @@ use crate::cpu::kernel::aggregator::KERNEL; // TODO: This list is incomplete. const NATIVE_INSTRUCTIONS: [usize; 25] = [ - COL_MAP.is_add, - COL_MAP.is_mul, - COL_MAP.is_sub, - COL_MAP.is_div, - COL_MAP.is_sdiv, - COL_MAP.is_mod, - COL_MAP.is_smod, - COL_MAP.is_addmod, - COL_MAP.is_mulmod, - COL_MAP.is_signextend, - COL_MAP.is_lt, - COL_MAP.is_gt, - COL_MAP.is_slt, - COL_MAP.is_sgt, - COL_MAP.is_eq, - COL_MAP.is_iszero, - COL_MAP.is_and, - COL_MAP.is_or, - COL_MAP.is_xor, - COL_MAP.is_not, - COL_MAP.is_byte, - COL_MAP.is_shl, - COL_MAP.is_shr, - COL_MAP.is_sar, - COL_MAP.is_pop, + COL_MAP.op.add, + COL_MAP.op.mul, + COL_MAP.op.sub, + COL_MAP.op.div, + COL_MAP.op.sdiv, + COL_MAP.op.mod_, + COL_MAP.op.smod, + COL_MAP.op.addmod, + COL_MAP.op.mulmod, + COL_MAP.op.signextend, + COL_MAP.op.lt, + COL_MAP.op.gt, + COL_MAP.op.slt, + COL_MAP.op.sgt, + COL_MAP.op.eq, + COL_MAP.op.iszero, + COL_MAP.op.and, + COL_MAP.op.or, + COL_MAP.op.xor, + COL_MAP.op.not, + COL_MAP.op.byte, + COL_MAP.op.shl, + COL_MAP.op.shr, + COL_MAP.op.sar, + COL_MAP.op.pop, ]; fn get_halt_pcs() -> (F, F) { @@ -68,14 +68,16 @@ pub fn eval_packed_generic( lv.is_cpu_cycle * is_native_instruction * (lv.is_kernel_mode - nv.is_kernel_mode), ); - // If a non-CPU cycle row is followed by a CPU cycle row, then the `program_counter` of the CPU - // cycle row is route_txn (the entry point of our kernel) and it is in kernel mode. + // If a non-CPU cycle row is followed by a CPU cycle row, then: + // - the `program_counter` of the CPU cycle row is `main` (the entry point of our kernel), + // - execution is in kernel mode, and + // - the stack is empty. + let is_last_noncpu_cycle = (lv.is_cpu_cycle - P::ONES) * nv.is_cpu_cycle; let pc_diff = - nv.program_counter - P::Scalar::from_canonical_usize(KERNEL.global_labels["route_txn"]); - yield_constr.constraint_transition((lv.is_cpu_cycle - P::ONES) * nv.is_cpu_cycle * pc_diff); - yield_constr.constraint_transition( - (lv.is_cpu_cycle - P::ONES) * nv.is_cpu_cycle * (nv.is_kernel_mode - P::ONES), - ); + nv.program_counter - P::Scalar::from_canonical_usize(KERNEL.global_labels["main"]); + yield_constr.constraint_transition(is_last_noncpu_cycle * pc_diff); + yield_constr.constraint_transition(is_last_noncpu_cycle * (nv.is_kernel_mode - P::ONES)); + yield_constr.constraint_transition(is_last_noncpu_cycle * nv.stack_len); // The last row must be a CPU cycle row. yield_constr.constraint_last_row(lv.is_cpu_cycle - P::ONES); @@ -115,17 +117,32 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint_transition(builder, kernel_constr); } - // If a non-CPU cycle row is followed by a CPU cycle row, then the `program_counter` of the CPU - // cycle row is route_txn (the entry point of our kernel) and it is in kernel mode. + // If a non-CPU cycle row is followed by a CPU cycle row, then: + // - the `program_counter` of the CPU cycle row is `main` (the entry point of our kernel), + // - execution is in kernel mode, and + // - the stack is empty. { - let filter = builder.mul_sub_extension(lv.is_cpu_cycle, nv.is_cpu_cycle, nv.is_cpu_cycle); - let route_txn = builder.constant_extension(F::Extension::from_canonical_usize( - KERNEL.global_labels["route_txn"], + let is_last_noncpu_cycle = + builder.mul_sub_extension(lv.is_cpu_cycle, nv.is_cpu_cycle, nv.is_cpu_cycle); + + // Start at `main`. + let main = builder.constant_extension(F::Extension::from_canonical_usize( + KERNEL.global_labels["main"], )); - let pc_diff = builder.sub_extension(nv.program_counter, route_txn); - let pc_constr = builder.mul_extension(filter, pc_diff); + let pc_diff = builder.sub_extension(nv.program_counter, main); + let pc_constr = builder.mul_extension(is_last_noncpu_cycle, pc_diff); yield_constr.constraint_transition(builder, pc_constr); - let kernel_constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter); + + // Start in kernel mode + let kernel_constr = builder.mul_sub_extension( + is_last_noncpu_cycle, + nv.is_kernel_mode, + is_last_noncpu_cycle, + ); + yield_constr.constraint_transition(builder, kernel_constr); + + // Start with empty stack + let kernel_constr = builder.mul_extension(is_last_noncpu_cycle, nv.stack_len); yield_constr.constraint_transition(builder, kernel_constr); } diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 9fd4792d..7ee204ca 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -1,4 +1,5 @@ use std::borrow::{Borrow, BorrowMut}; +use std::iter::repeat; use std::marker::PhantomData; use itertools::Itertools; @@ -9,9 +10,13 @@ use plonky2::hash::hash_types::RichField; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; -use crate::cpu::{bootstrap_kernel, control_flow, decode, jumps, simple_logic, syscalls}; +use crate::cpu::{ + bootstrap_kernel, control_flow, decode, jumps, membus, simple_logic, stack, stack_bounds, + syscalls, +}; use crate::cross_table_lookup::Column; -use crate::memory::NUM_CHANNELS; +use crate::memory::segments::Segment; +use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; @@ -23,14 +28,13 @@ pub fn ctl_data_keccak() -> Vec> { } pub fn ctl_data_keccak_memory() -> Vec> { - // When executing KECCAK_GENERAL, the memory channels are used as follows: - // channel 0: instruction - // channel 1: stack[-1] = context - // channel 2: stack[-2] = segment - // channel 3: stack[-3] = virtual - let context = Column::single(COL_MAP.mem_channels[1].value[0]); - let segment = Column::single(COL_MAP.mem_channels[2].value[0]); - let virt = Column::single(COL_MAP.mem_channels[3].value[0]); + // When executing KECCAK_GENERAL, the GP memory channels are used as follows: + // GP channel 0: stack[-1] = context + // GP channel 1: stack[-2] = segment + // GP channel 2: stack[-3] = virtual + let context = Column::single(COL_MAP.mem_channels[0].value[0]); + let segment = Column::single(COL_MAP.mem_channels[1].value[0]); + let virt = Column::single(COL_MAP.mem_channels[2].value[0]); let num_channels = F::from_canonical_usize(NUM_CHANNELS); let clock = Column::linear_combination([(COL_MAP.clock, num_channels)]); @@ -47,7 +51,7 @@ pub fn ctl_filter_keccak_memory() -> Column { } pub fn ctl_data_logic() -> Vec> { - let mut res = Column::singles([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]).collect_vec(); + let mut res = Column::singles([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor]).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[0].value)); res.extend(Column::singles(COL_MAP.mem_channels[1].value)); res.extend(Column::singles(COL_MAP.mem_channels[2].value)); @@ -55,32 +59,60 @@ pub fn ctl_data_logic() -> Vec> { } pub fn ctl_filter_logic() -> Column { - Column::sum([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]) + Column::sum([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor]) } -pub fn ctl_data_memory(channel: usize) -> Vec> { - debug_assert!(channel < NUM_CHANNELS); +pub const MEM_CODE_CHANNEL_IDX: usize = 0; +pub const MEM_GP_CHANNELS_IDX_START: usize = MEM_CODE_CHANNEL_IDX + 1; + +/// Make the time/channel column for memory lookups. +fn mem_time_and_channel(channel: usize) -> Column { + let scalar = F::from_canonical_usize(NUM_CHANNELS); + let addend = F::from_canonical_usize(channel); + Column::linear_combination_with_constant([(COL_MAP.clock, scalar)], addend) +} + +pub fn ctl_data_code_memory() -> Vec> { + let mut cols = vec![ + Column::constant(F::ONE), // is_read + Column::single(COL_MAP.code_context), // addr_context + Column::constant(F::from_canonical_u64(Segment::Code as u64)), // addr_segment + Column::single(COL_MAP.program_counter), // addr_virtual + ]; + + // Low limb of the value matches the opcode bits + cols.push(Column::le_bits(COL_MAP.opcode_bits)); + + // High limbs of the value are all zero. + cols.extend(repeat(Column::constant(F::ZERO)).take(VALUE_LIMBS - 1)); + + cols.push(mem_time_and_channel(MEM_CODE_CHANNEL_IDX)); + + cols +} + +pub fn ctl_data_gp_memory(channel: usize) -> Vec> { let channel_map = COL_MAP.mem_channels[channel]; - let mut cols: Vec> = Column::singles([ + let mut cols: Vec<_> = Column::singles([ channel_map.is_read, channel_map.addr_context, channel_map.addr_segment, channel_map.addr_virtual, ]) - .collect_vec(); + .collect(); + cols.extend(Column::singles(channel_map.value)); - let scalar = F::from_canonical_usize(NUM_CHANNELS); - let addend = F::from_canonical_usize(channel); - cols.push(Column::linear_combination_with_constant( - [(COL_MAP.clock, scalar)], - addend, - )); + cols.push(mem_time_and_channel(MEM_GP_CHANNELS_IDX_START + channel)); cols } -pub fn ctl_filter_memory(channel: usize) -> Column { +pub fn ctl_filter_code_memory() -> Column { + Column::single(COL_MAP.is_cpu_cycle) +} + +pub fn ctl_filter_gp_memory(channel: usize) -> Column { Column::single(COL_MAP.mem_channels[channel].used) } @@ -93,7 +125,9 @@ impl CpuStark { pub fn generate(&self, local_values: &mut [F; NUM_CPU_COLUMNS]) { let local_values: &mut CpuColumnsView<_> = local_values.borrow_mut(); decode::generate(local_values); + membus::generate(local_values); simple_logic::generate(local_values); + stack_bounds::generate(local_values); // Must come after `decode`. } } @@ -114,7 +148,10 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark(lv: &mut CpuColumnsView) { let cycle_filter = lv.is_cpu_cycle; if cycle_filter == F::ZERO { // These columns cannot be shared. - lv.is_eq = F::ZERO; - lv.is_iszero = F::ZERO; + lv.op.eq = F::ZERO; + lv.op.iszero = F::ZERO; return; } // This assert is not _strictly_ necessary, but I include it as a sanity check. @@ -196,7 +197,7 @@ pub fn generate(lv: &mut CpuColumnsView) { any_flag_set = any_flag_set || flag; } // is_invalid is a catch-all for opcodes we can't decode. - lv.is_invalid = F::from_bool(!any_flag_set); + lv.op.invalid = F::from_bool(!any_flag_set); } /// Break up an opcode (which is 8 bits long) into its eight bits. @@ -234,13 +235,13 @@ pub fn eval_packed_generic( let flag = lv[flag_col]; yield_constr.constraint(cycle_filter * flag * (flag - P::ONES)); } - yield_constr.constraint(cycle_filter * lv.is_invalid * (lv.is_invalid - P::ONES)); + yield_constr.constraint(cycle_filter * lv.op.invalid * (lv.op.invalid - P::ONES)); // Now check that exactly one is 1. let flag_sum: P = OPCODES .into_iter() .map(|(_, _, _, flag_col)| lv[flag_col]) .sum::

() - + lv.is_invalid; + + lv.op.invalid; yield_constr.constraint(cycle_filter * (P::ONES - flag_sum)); // Finally, classify all opcodes, together with the kernel flag, into blocks @@ -305,7 +306,7 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint(builder, constr); } { - let constr = builder.mul_sub_extension(lv.is_invalid, lv.is_invalid, lv.is_invalid); + let constr = builder.mul_sub_extension(lv.op.invalid, lv.op.invalid, lv.op.invalid); let constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } @@ -316,7 +317,7 @@ pub fn eval_ext_circuit, const D: usize>( let flag = lv[flag_col]; constr = builder.sub_extension(constr, flag); } - constr = builder.sub_extension(constr, lv.is_invalid); + constr = builder.sub_extension(constr, lv.op.invalid); constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } diff --git a/evm/src/cpu/jumps.rs b/evm/src/cpu/jumps.rs index 219b39dd..fb13f83b 100644 --- a/evm/src/cpu/jumps.rs +++ b/evm/src/cpu/jumps.rs @@ -23,10 +23,10 @@ pub fn eval_packed_exit_kernel( // flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the // kernel to set them to zero). yield_constr.constraint_transition( - lv.is_cpu_cycle * lv.is_exit_kernel * (input[0] - nv.program_counter), + lv.is_cpu_cycle * lv.op.exit_kernel * (input[0] - nv.program_counter), ); yield_constr.constraint_transition( - lv.is_cpu_cycle * lv.is_exit_kernel * (input[1] - nv.is_kernel_mode), + lv.is_cpu_cycle * lv.op.exit_kernel * (input[1] - nv.is_kernel_mode), ); } @@ -37,7 +37,7 @@ pub fn eval_ext_circuit_exit_kernel, const D: usize yield_constr: &mut RecursiveConstraintConsumer, ) { let input = lv.mem_channels[0].value; - let filter = builder.mul_extension(lv.is_cpu_cycle, lv.is_exit_kernel); + let filter = builder.mul_extension(lv.is_cpu_cycle, lv.op.exit_kernel); // If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode // flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the @@ -60,16 +60,16 @@ pub fn eval_packed_jump_jumpi( let jumps_lv = lv.general.jumps(); let input0 = lv.mem_channels[0].value; let input1 = lv.mem_channels[1].value; - let filter = lv.is_jump + lv.is_jumpi; // `JUMP` or `JUMPI` + let filter = lv.op.jump + lv.op.jumpi; // `JUMP` or `JUMPI` // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`. - yield_constr.constraint(lv.is_jump * (input1[0] - P::ONES)); + yield_constr.constraint(lv.op.jump * (input1[0] - P::ONES)); for &limb in &input1[1..] { // Set all limbs (other than the least-significant limb) to 0. // NB: Technically, they don't have to be 0, as long as the sum // `input1[0] + ... + input1[7]` cannot overflow. - yield_constr.constraint(lv.is_jump * limb); + yield_constr.constraint(lv.op.jump * limb); } // Check `input0_upper_zero` @@ -162,19 +162,19 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> let jumps_lv = lv.general.jumps(); let input0 = lv.mem_channels[0].value; let input1 = lv.mem_channels[1].value; - let filter = builder.add_extension(lv.is_jump, lv.is_jumpi); // `JUMP` or `JUMPI` + let filter = builder.add_extension(lv.op.jump, lv.op.jumpi); // `JUMP` or `JUMPI` // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`. { - let constr = builder.mul_sub_extension(lv.is_jump, input1[0], lv.is_jump); + let constr = builder.mul_sub_extension(lv.op.jump, input1[0], lv.op.jump); yield_constr.constraint(builder, constr); } for &limb in &input1[1..] { // Set all limbs (other than the least-significant limb) to 0. // NB: Technically, they don't have to be 0, as long as the sum // `input1[0] + ... + input1[7]` cannot overflow. - let constr = builder.mul_extension(lv.is_jump, limb); + let constr = builder.mul_extension(lv.op.jump, limb); yield_constr.constraint(builder, constr); } diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index dda006e6..002a84fb 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -33,6 +33,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/curve/secp256k1/moddiv.asm"), include_str!("asm/exp.asm"), include_str!("asm/halt.asm"), + include_str!("asm/main.asm"), include_str!("asm/memory/core.asm"), include_str!("asm/memory/memcpy.asm"), include_str!("asm/memory/metadata.asm"), @@ -41,8 +42,15 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/rlp/encode.asm"), include_str!("asm/rlp/decode.asm"), include_str!("asm/rlp/read_to_memory.asm"), - include_str!("asm/storage/read.asm"), - include_str!("asm/storage/write.asm"), + include_str!("asm/mpt/hash.asm"), + include_str!("asm/mpt/hash_trie_specific.asm"), + include_str!("asm/mpt/hex_prefix.asm"), + include_str!("asm/mpt/load.asm"), + include_str!("asm/mpt/read.asm"), + include_str!("asm/mpt/storage_read.asm"), + include_str!("asm/mpt/storage_write.asm"), + include_str!("asm/mpt/util.asm"), + include_str!("asm/mpt/write.asm"), include_str!("asm/transactions/router.asm"), include_str!("asm/transactions/type_0.asm"), include_str!("asm/transactions/type_1.asm"), diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 3cbbb441..1b8a535f 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -2,21 +2,21 @@ // Creates a new sub context and executes the code of the given account. global call: - // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size + // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size, retdest %address %stack (self, gas, address, value) - // These are (should_transfer_value, value, static, gas, sender, storage, code_addr) - -> (1, value, 0, gas, self, address, address) + // These are (static, should_transfer_value, value, sender, address, code_addr, gas) + -> (0, 1, value, self, address, address, gas) %jump(call_common) // Creates a new sub context as if calling itself, but with the code of the // given account. In particular the storage remains the same. global call_code: - // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size + // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size, retdest %address %stack (self, gas, address, value) - // These are (should_transfer_value, value, static, gas, sender, storage, code_addr) - -> (1, value, 0, gas, self, self, address) + // These are (static, should_transfer_value, value, sender, address, code_addr, gas) + -> (0, 1, value, self, self, address, gas) %jump(call_common) // Creates a new sub context and executes the code of the given account. @@ -25,35 +25,86 @@ global call_code: // are CREATE, CREATE2, LOG0, LOG1, LOG2, LOG3, LOG4, SSTORE, SELFDESTRUCT and // CALL if the value sent is not 0. global static_all: - // stack: gas, address, args_offset, args_size, ret_offset, ret_size + // stack: gas, address, args_offset, args_size, ret_offset, ret_size, retdest %address %stack (self, gas, address) - // These are (should_transfer_value, value, static, gas, sender, storage, code_addr) - -> (0, 0, 1, gas, self, address, address) + // These are (static, should_transfer_value, value, sender, address, code_addr, gas) + -> (1, 0, 0, self, address, address, gas) %jump(call_common) // Creates a new sub context as if calling itself, but with the code of the // given account. In particular the storage, the current sender and the current // value remain the same. global delegate_call: - // stack: gas, address, args_offset, args_size, ret_offset, ret_size + // stack: gas, address, args_offset, args_size, ret_offset, ret_size, retdest %address %sender %callvalue %stack (self, sender, value, gas, address) - // These are (should_transfer_value, value, static, gas, sender, storage, code_addr) - -> (0, value, 0, gas, sender, self, address) + // These are (static, should_transfer_value, value, sender, address, code_addr, gas) + -> (0, 0, value, sender, self, address, gas) %jump(call_common) call_common: - // stack: should_transfer_value, value, static, gas, sender, storage, code_addr, args_offset, args_size, ret_offset, ret_size - // TODO: Set all the appropriate metadata fields... + // stack: static, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest %create_context - // stack: new_ctx, after_call + // Store the static flag in metadata. + %stack (new_ctx, static) -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_STATIC, static, new_ctx) + MSTORE_GENERAL + // stack: new_ctx, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store the address in metadata. + %stack (new_ctx, should_transfer_value, value, sender, address) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_ADDRESS, address, + new_ctx, should_transfer_value, value, sender, address) + MSTORE_GENERAL + // stack: new_ctx, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store the caller in metadata. + %stack (new_ctx, should_transfer_value, value, sender) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CALLER, sender, + new_ctx, should_transfer_value, value, sender) + MSTORE_GENERAL + // stack: new_ctx, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store the call value field in metadata. + %stack (new_ctx, should_transfer_value, value, sender, address) = + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CALL_VALUE, value, + should_transfer_value, sender, address, value, new_ctx) + MSTORE_GENERAL + // stack: should_transfer_value, sender, address, value, new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + %maybe_transfer_eth + // stack: new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store parent context in metadata. + GET_CONTEXT + PUSH @CTX_METADATA_PARENT_CONTEXT + PUSH @SEGMENT_CONTEXT_METADATA + DUP4 // new_ctx + MSTORE_GENERAL + // stack: new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store parent PC = after_call. + %stack (new_ctx) -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_PARENT_PC, after_call, new_ctx) + MSTORE_GENERAL + // stack: new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // TODO: Populate CALLDATA + // TODO: Save parent gas and set child gas + // TODO: Populate code + + // TODO: Temporary, remove after above steps are done. + %stack (new_ctx, code_addr, gas, args_offset, args_size) -> (new_ctx) + // stack: new_ctx, ret_offset, ret_size, retdest + // Now, switch to the new context and go to usermode with PC=0. + DUP1 // new_ctx SET_CONTEXT - PUSH 0 + PUSH 0 // jump dest EXIT_KERNEL after_call: - // TODO: Set RETURNDATA etc. + // stack: new_ctx, ret_offset, ret_size, retdest + // TODO: Set RETURNDATA. + // TODO: Return to caller w/ EXIT_KERNEL. diff --git a/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm b/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm index 3c6f47b6..931a6a7b 100644 --- a/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm +++ b/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm @@ -1,6 +1,3 @@ -// After the transaction data has been parsed into a normalized set of fields -// (see NormalizedTxnField), this routine processes the transaction. - global intrinsic_gas: // stack: retdest // Calculate the number of zero and nonzero bytes in the txn data. diff --git a/evm/src/cpu/kernel/asm/core/process_txn.asm b/evm/src/cpu/kernel/asm/core/process_txn.asm index b3d92131..ac52c53d 100644 --- a/evm/src/cpu/kernel/asm/core/process_txn.asm +++ b/evm/src/cpu/kernel/asm/core/process_txn.asm @@ -1,6 +1,8 @@ // After the transaction data has been parsed into a normalized set of fields // (see NormalizedTxnField), this routine processes the transaction. +// TODO: Save checkpoints in @CTX_METADATA_STATE_TRIE_CHECKPOINT_PTR and @SEGMENT_STORAGE_TRIE_CHECKPOINT_PTRS. + global process_normalized_txn: // stack: (empty) PUSH validate diff --git a/evm/src/cpu/kernel/asm/core/transfer.asm b/evm/src/cpu/kernel/asm/core/transfer.asm index 41057aff..b12bc9de 100644 --- a/evm/src/cpu/kernel/asm/core/transfer.asm +++ b/evm/src/cpu/kernel/asm/core/transfer.asm @@ -1,11 +1,15 @@ // Transfers some ETH from one address to another. The amount is given in wei. // Pre stack: from, to, amount, retdest // Post stack: (empty) - global transfer_eth: // stack: from, to, amount, retdest - // TODO: Replace with actual implementation. - %pop3 + %stack (from, to, amount, retdest) + -> (from, amount, to, amount) + %deduct_eth + // TODO: Handle exception from %deduct_eth? + // stack: to, amount, retdest + %add_eth + // stack: retdest JUMP // Convenience macro to call transfer_eth and return where we left off. @@ -14,3 +18,38 @@ global transfer_eth: %jump(transfer_eth) %%after: %endmacro + +// Pre stack: should_transfer, from, to, amount +// Post stack: (empty) +%macro maybe_transfer_eth + %jumpi(%%transfer) + // We're skipping the transfer, so just pop the arguments and return. + %pop3 + %jump(%%after) +%%transfer: + %transfer_eth +%%after: +%endmacro + +global deduct_eth: + // stack: addr, amount, retdest + %jump(mpt_read_state_trie) +deduct_eth_after_read: + PANIC // TODO + +// Convenience macro to call deduct_eth and return where we left off. +%macro deduct_eth + %stack (addr, amount) -> (addr, amount, %%after) + %jump(deduct_eth) +%%after: +%endmacro + +global add_eth: + PANIC // TODO + +// Convenience macro to call add_eth and return where we left off. +%macro add_eth + %stack (addr, amount) -> (addr, amount, %%after) + %jump(add_eth) +%%after: +%endmacro diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_add.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_add.asm index 15f9df05..dda82109 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_add.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/curve_add.asm @@ -9,7 +9,6 @@ global ec_add: // PUSH 1 // PUSH 0x1bf9384aa3f0b3ad763aee81940cacdde1af71617c06f46e11510f14f3d5d121 // PUSH 0xe7313274bb29566ff0c8220eb9841de1d96c2923c6a4028f7dd3c6a14cee770 - JUMPDEST // stack: x0, y0, x1, y1, retdest // Check if points are valid BN254 points. @@ -38,7 +37,6 @@ global ec_add: // BN254 elliptic curve addition. // Assumption: (x0,y0) and (x1,y1) are valid points. global ec_add_valid_points: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Check if the first point is the identity. @@ -92,7 +90,6 @@ global ec_add_valid_points: // BN254 elliptic curve addition. // Assumption: (x0,y0) == (0,0) ec_add_first_zero: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Just return (x1,y1) %stack (x0, y0, x1, y1, retdest) -> (retdest, x1, y1) @@ -101,7 +98,6 @@ ec_add_first_zero: // BN254 elliptic curve addition. // Assumption: (x1,y1) == (0,0) ec_add_snd_zero: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Just return (x0,y0) @@ -111,7 +107,6 @@ ec_add_snd_zero: // BN254 elliptic curve addition. // Assumption: lambda = (y0 - y1)/(x0 - x1) ec_add_valid_points_with_lambda: - JUMPDEST // stack: lambda, x0, y0, x1, y1, retdest // Compute x2 = lambda^2 - x1 - x0 @@ -159,7 +154,6 @@ ec_add_valid_points_with_lambda: // BN254 elliptic curve addition. // Assumption: (x0,y0) and (x1,y1) are valid points and x0 == x1 ec_add_equal_first_coord: - JUMPDEST // stack: x0, y0, x1, y1, retdest with x0 == x1 // Check if the points are equal @@ -188,7 +182,6 @@ ec_add_equal_first_coord: // Assumption: x0 == x1 and y0 == y1 // Standard doubling formula. ec_add_equal_points: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Compute lambda = 3/2 * x0^2 / y0 @@ -216,7 +209,6 @@ ec_add_equal_points: // Assumption: (x0,y0) is a valid point. // Standard doubling formula. global ec_double: - JUMPDEST // stack: x0, y0, retdest DUP2 // stack: y0, x0, y0, retdest diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_mul.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_mul.asm index 62cf2235..b1472812 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_mul.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/curve_mul.asm @@ -6,7 +6,6 @@ global ec_mul: // PUSH 0xd // PUSH 2 // PUSH 1 - JUMPDEST // stack: x, y, s, retdest DUP2 // stack: y, x, y, s, retdest @@ -29,7 +28,6 @@ global ec_mul: // Same algorithm as in `exp.asm` ec_mul_valid_point: - JUMPDEST // stack: x, y, s, retdest DUP3 // stack: s, x, y, s, retdest @@ -38,7 +36,6 @@ ec_mul_valid_point: %jump(ret_zero_ec_mul) step_case: - JUMPDEST // stack: x, y, s, retdest PUSH recursion_return // stack: recursion_return, x, y, s, retdest @@ -58,12 +55,10 @@ step_case: // Assumption: 2(x,y) = (x',y') step_case_contd: - JUMPDEST // stack: x', y', s / 2, recursion_return, x, y, s, retdest %jump(ec_mul_valid_point) recursion_return: - JUMPDEST // stack: x', y', x, y, s, retdest SWAP4 // stack: s, y', x, y, x', retdest @@ -96,6 +91,5 @@ recursion_return: JUMP odd_scalar: - JUMPDEST // stack: x', y', x, y, retdest %jump(ec_add_valid_points) diff --git a/evm/src/cpu/kernel/asm/curve/common.asm b/evm/src/cpu/kernel/asm/curve/common.asm index 107dc63c..9e273c15 100644 --- a/evm/src/cpu/kernel/asm/curve/common.asm +++ b/evm/src/cpu/kernel/asm/curve/common.asm @@ -1,5 +1,4 @@ global ret_zero_ec_mul: - JUMPDEST // stack: x, y, s, retdest %pop3 // stack: retdest diff --git a/evm/src/cpu/kernel/asm/curve/secp256k1/curve_add.asm b/evm/src/cpu/kernel/asm/curve/secp256k1/curve_add.asm index 7f9c1fff..790fb116 100644 --- a/evm/src/cpu/kernel/asm/curve/secp256k1/curve_add.asm +++ b/evm/src/cpu/kernel/asm/curve/secp256k1/curve_add.asm @@ -3,7 +3,6 @@ // Secp256k1 elliptic curve addition. // Assumption: (x0,y0) and (x1,y1) are valid points. global ec_add_valid_points_secp: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Check if the first point is the identity. @@ -57,7 +56,6 @@ global ec_add_valid_points_secp: // Secp256k1 elliptic curve addition. // Assumption: (x0,y0) == (0,0) ec_add_first_zero: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Just return (x1,y1) @@ -72,7 +70,6 @@ ec_add_first_zero: // Secp256k1 elliptic curve addition. // Assumption: (x1,y1) == (0,0) ec_add_snd_zero: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Just return (x1,y1) @@ -93,7 +90,6 @@ ec_add_snd_zero: // Secp256k1 elliptic curve addition. // Assumption: lambda = (y0 - y1)/(x0 - x1) ec_add_valid_points_with_lambda: - JUMPDEST // stack: lambda, x0, y0, x1, y1, retdest // Compute x2 = lambda^2 - x1 - x0 @@ -150,7 +146,6 @@ ec_add_valid_points_with_lambda: // Secp256k1 elliptic curve addition. // Assumption: (x0,y0) and (x1,y1) are valid points and x0 == x1 ec_add_equal_first_coord: - JUMPDEST // stack: x0, y0, x1, y1, retdest with x0 == x1 // Check if the points are equal @@ -179,7 +174,6 @@ ec_add_equal_first_coord: // Assumption: x0 == x1 and y0 == y1 // Standard doubling formula. ec_add_equal_points: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Compute lambda = 3/2 * x0^2 / y0 @@ -207,7 +201,6 @@ ec_add_equal_points: // Assumption: (x0,y0) is a valid point. // Standard doubling formula. global ec_double_secp: - JUMPDEST // stack: x0, y0, retdest DUP2 // stack: y0, x0, y0, retdest diff --git a/evm/src/cpu/kernel/asm/curve/secp256k1/curve_mul.asm b/evm/src/cpu/kernel/asm/curve/secp256k1/curve_mul.asm index f0825e88..892d57c0 100644 --- a/evm/src/cpu/kernel/asm/curve/secp256k1/curve_mul.asm +++ b/evm/src/cpu/kernel/asm/curve/secp256k1/curve_mul.asm @@ -1,6 +1,5 @@ // Same algorithm as in `exp.asm` global ec_mul_valid_point_secp: - JUMPDEST // stack: x, y, s, retdest %stack (x,y) -> (x,y,x,y) %ec_isidentity @@ -13,7 +12,6 @@ global ec_mul_valid_point_secp: %jump(ret_zero_ec_mul) step_case: - JUMPDEST // stack: x, y, s, retdest PUSH recursion_return // stack: recursion_return, x, y, s, retdest @@ -33,12 +31,10 @@ step_case: // Assumption: 2(x,y) = (x',y') step_case_contd: - JUMPDEST // stack: x', y', s / 2, recursion_return, x, y, s, retdest %jump(ec_mul_valid_point_secp) recursion_return: - JUMPDEST // stack: x', y', x, y, s, retdest SWAP4 // stack: s, y', x, y, x', retdest @@ -71,6 +67,5 @@ recursion_return: JUMP odd_scalar: - JUMPDEST // stack: x', y', x, y, retdest %jump(ec_add_valid_points_secp) diff --git a/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm b/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm index 538a86dc..96e177ff 100644 --- a/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm +++ b/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm @@ -1,6 +1,5 @@ // ecrecover precompile. global ecrecover: - JUMPDEST // stack: hash, v, r, s, retdest // Check if inputs are valid. @@ -47,7 +46,6 @@ global ecrecover: // let u2 = -hash * r_inv; // return u1*P + u2*GENERATOR; ecrecover_valid_input: - JUMPDEST // stack: hash, y, r, s, retdest // Compute u1 = s * r^(-1) @@ -83,7 +81,6 @@ ecrecover_valid_input: // ecrecover precompile. // Assumption: (X,Y) = u1 * P. Result is (X,Y) + u2*GENERATOR ecrecover_with_first_point: - JUMPDEST // stack: X, Y, hash, r^(-1), retdest %secp_scalar // stack: p, X, Y, hash, r^(-1), retdest @@ -132,7 +129,6 @@ ecrecover_with_first_point: // Take a public key (PKx, PKy) and return the associated address KECCAK256(PKx || PKy)[-20:]. pubkey_to_addr: - JUMPDEST // stack: PKx, PKy, retdest PUSH 0 // stack: 0, PKx, PKy, retdest diff --git a/evm/src/cpu/kernel/asm/exp.asm b/evm/src/cpu/kernel/asm/exp.asm index 3640b2f6..f025e312 100644 --- a/evm/src/cpu/kernel/asm/exp.asm +++ b/evm/src/cpu/kernel/asm/exp.asm @@ -10,7 +10,6 @@ /// Note that this correctly handles exp(0, 0) == 1. global exp: - jumpdest // stack: x, e, retdest dup2 // stack: e, x, e, retdest @@ -27,7 +26,6 @@ global exp: jump step_case: - jumpdest // stack: x, e, retdest push recursion_return // stack: recursion_return, x, e, retdest @@ -43,7 +41,6 @@ step_case: // stack: x * x, e / 2, recursion_return, x, e, retdest %jump(exp) recursion_return: - jumpdest // stack: exp(x * x, e / 2), x, e, retdest push 2 // stack: 2, exp(x * x, e / 2), x, e, retdest diff --git a/evm/src/cpu/kernel/asm/keccak.asm b/evm/src/cpu/kernel/asm/keccak.asm deleted file mode 100644 index d464bb6a..00000000 --- a/evm/src/cpu/kernel/asm/keccak.asm +++ /dev/null @@ -1,8 +0,0 @@ -// Computes the Keccak256 hash of some arbitrary bytes in memory. -// The given memory values should be in the range of a byte. -// -// Pre stack: ADDR, len, retdest -// Post stack: hash -global keccak_general: - // stack: ADDR, len - // TODO diff --git a/evm/src/cpu/kernel/asm/main.asm b/evm/src/cpu/kernel/asm/main.asm new file mode 100644 index 00000000..e8c8e3e4 --- /dev/null +++ b/evm/src/cpu/kernel/asm/main.asm @@ -0,0 +1,24 @@ +global main: + // First, load all MPT data from the prover. + PUSH txn_loop + %jump(load_all_mpts) + +hash_initial_tries: + %mpt_hash_state_trie %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE) + %mpt_hash_txn_trie %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_DIGEST_BEFORE) + %mpt_hash_receipt_trie %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_BEFORE) + +txn_loop: + // If the prover has no more txns for us to process, halt. + PROVER_INPUT(end_of_txns) + %jumpi(hash_final_tries) + + // Call route_txn. When we return, continue the txn loop. + PUSH txn_loop + %jump(route_txn) + +hash_final_tries: + %mpt_hash_state_trie %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_AFTER) + %mpt_hash_txn_trie %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_DIGEST_AFTER) + %mpt_hash_receipt_trie %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_AFTER) + %jump(halt) diff --git a/evm/src/cpu/kernel/asm/memory/core.asm b/evm/src/cpu/kernel/asm/memory/core.asm index 73bafbee..6722b0ca 100644 --- a/evm/src/cpu/kernel/asm/memory/core.asm +++ b/evm/src/cpu/kernel/asm/memory/core.asm @@ -98,3 +98,10 @@ %mstore_kernel(@SEGMENT_CODE) // stack: (empty) %endmacro + +// Store a single byte to @SEGMENT_RLP_RAW. +%macro mstore_rlp + // stack: offset, value + %mstore_kernel(@SEGMENT_RLP_RAW) + // stack: (empty) +%endmacro diff --git a/evm/src/cpu/kernel/asm/memory/memcpy.asm b/evm/src/cpu/kernel/asm/memory/memcpy.asm index 0a390736..3feca35d 100644 --- a/evm/src/cpu/kernel/asm/memory/memcpy.asm +++ b/evm/src/cpu/kernel/asm/memory/memcpy.asm @@ -4,7 +4,6 @@ // DST = (dst_ctx, dst_segment, dst_addr). // These tuple definitions are used for brevity in the stack comments below. global memcpy: - JUMPDEST // stack: DST, SRC, count, retdest DUP7 // stack: count, DST, SRC, count, retdest @@ -44,7 +43,6 @@ global memcpy: %jump(memcpy) memcpy_finish: - JUMPDEST // stack: DST, SRC, count, retdest %pop7 // stack: retdest diff --git a/evm/src/cpu/kernel/asm/memory/metadata.asm b/evm/src/cpu/kernel/asm/memory/metadata.asm index 23c45d13..644699e0 100644 --- a/evm/src/cpu/kernel/asm/memory/metadata.asm +++ b/evm/src/cpu/kernel/asm/memory/metadata.asm @@ -12,7 +12,7 @@ // stack: value PUSH $field // stack: offset, value - %mload_kernel(@SEGMENT_GLOBAL_METADATA) + %mstore_kernel(@SEGMENT_GLOBAL_METADATA) // stack: (empty) %endmacro @@ -30,18 +30,18 @@ // stack: value PUSH $field // stack: offset, value - %mload_current(@SEGMENT_CONTEXT_METADATA) + %mstore_current(@SEGMENT_CONTEXT_METADATA) // stack: (empty) %endmacro %macro address - %mload_context_metadata(0) // TODO: Read proper field. + %mload_context_metadata(@CTX_METADATA_ADDRESS) %endmacro %macro sender - %mload_context_metadata(0) // TODO: Read proper field. + %mload_context_metadata(@CTX_METADATA_CALLER) %endmacro %macro callvalue - %mload_context_metadata(0) // TODO: Read proper field. + %mload_context_metadata(@CTX_METADATA_CALL_VALUE) %endmacro diff --git a/evm/src/cpu/kernel/asm/mpt/hash.asm b/evm/src/cpu/kernel/asm/mpt/hash.asm new file mode 100644 index 00000000..053f357c --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/hash.asm @@ -0,0 +1,82 @@ +// Computes the Merkle root of the given trie node. +// +// The encode_value function should take as input +// - the position withing @SEGMENT_RLP_RAW to write to, +// - the offset of a value within @SEGMENT_TRIE_DATA, and +// - a return address. +// It should serialize the value, write it to @SEGMENT_RLP_RAW starting at the +// given position, and return an updated position (the next unused offset). +%macro mpt_hash(encode_value) + // stack: node_ptr, retdest + DUP1 + %mload_trie_data + // stack: node_type, node_ptr, retdest + // Increment node_ptr, so it points to the node payload instead of its type. + SWAP1 %add_const(1) SWAP1 + // stack: node_type, node_payload_ptr, retdest + + DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(mpt_hash_empty) + DUP1 %eq_const(@MPT_NODE_HASH) %jumpi(mpt_hash_hash) + DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(%%mpt_hash_branch) + DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(%%mpt_hash_extension) + DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(%%mpt_hash_leaf) + PANIC // Invalid node type? Shouldn't get here. + +%%mpt_hash_branch: + // stack: node_type, node_payload_ptr, retdest + POP + // stack: node_payload_ptr, retdest + PANIC // TODO + +%%mpt_hash_extension: + // stack: node_type, node_payload_ptr, retdest + POP + // stack: node_payload_ptr, retdest + PANIC // TODO + +%%mpt_hash_leaf: + // stack: node_type, node_payload_ptr, retdest + POP + // stack: node_payload_ptr, retdest + PUSH %%mpt_hash_leaf_after_hex_prefix // retdest + PUSH 1 // terminated + // stack: terminated, %%mpt_hash_leaf_after_hex_prefix, node_payload_ptr, retdest + DUP3 %add_const(1) %mload_trie_data // Load the packed_nibbles field, which is at index 1. + // stack: packed_nibbles, terminated, %%mpt_hash_leaf_after_hex_prefix, node_payload_ptr, retdest + DUP4 %mload_trie_data // Load the num_nibbles field, which is at index 0. + // stack: num_nibbles, packed_nibbles, terminated, %%mpt_hash_leaf_after_hex_prefix, node_payload_ptr, retdest + PUSH 9 // We start at 9 to leave room to prepend the largest possible RLP list header. + // stack: rlp_start, num_nibbles, packed_nibbles, terminated, %%mpt_hash_leaf_after_hex_prefix, node_payload_ptr, retdest + %jump(hex_prefix_rlp) +%%mpt_hash_leaf_after_hex_prefix: + // stack: rlp_pos, node_payload_ptr, retdest + SWAP1 + %add_const(2) // The value starts at index 2. + %stack (value_ptr, rlp_pos, retdest) + -> (rlp_pos, value_ptr, %%mpt_hash_leaf_after_encode_value, retdest) + %jump($encode_value) +%%mpt_hash_leaf_after_encode_value: + // stack: rlp_end_pos, retdest + %prepend_rlp_list_prefix + // stack: rlp_start_pos, rlp_len, retdest + PUSH @SEGMENT_RLP_RAW + PUSH 0 // kernel context + // stack: rlp_start_addr: 3, rlp_len, retdest + KECCAK_GENERAL + // stack: hash, retdest + SWAP1 + JUMP +%endmacro + +global mpt_hash_empty: + %stack (node_type, node_payload_ptr, retdest) -> (retdest, @EMPTY_NODE_HASH) + JUMP + +global mpt_hash_hash: + // stack: node_type, node_payload_ptr, retdest + POP + // stack: node_payload_ptr, retdest + %mload_trie_data + // stack: hash, retdest + SWAP1 + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm new file mode 100644 index 00000000..30ea730f --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm @@ -0,0 +1,80 @@ +// Hashing logic specific to a particular trie. + +global mpt_hash_state_trie: + // stack: retdest + %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + // stack: node_ptr, retdest + %mpt_hash(encode_account) + +%macro mpt_hash_state_trie + PUSH %%after + %jump(mpt_hash_state_trie) +%%after: +%endmacro + +global mpt_hash_txn_trie: + // stack: retdest + %mload_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) + // stack: node_ptr, retdest + %mpt_hash(encode_txn) + +%macro mpt_hash_txn_trie + PUSH %%after + %jump(mpt_hash_txn_trie) +%%after: +%endmacro + +global mpt_hash_receipt_trie: + // stack: retdest + %mload_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) + // stack: node_ptr, retdest + %mpt_hash(encode_receipt) + +%macro mpt_hash_receipt_trie + PUSH %%after + %jump(mpt_hash_receipt_trie) +%%after: +%endmacro + +encode_account: + // stack: rlp_pos, value_ptr, retdest + // First, we compute the length of the RLP data we're about to write. + // The nonce and balance fields are variable-length, so we need to load them + // to determine their contribution, while the other two fields are fixed + // 32-bytes integers. + DUP2 %mload_trie_data // nonce = value[0] + %scalar_rlp_len + // stack: nonce_rlp_len, rlp_pos, value_ptr, retdest + DUP3 %add_const(1) %mload_trie_data // balance = value[1] + %scalar_rlp_len + // stack: balance_rlp_lenm, nonce_rlp_len, rlp_pos, value_ptr, retdest + PUSH 66 // storage_root and code_hash fields each take 1 + 32 bytes + ADD ADD + // stack: payload_len, rlp_pos, value_ptr, retdest + SWAP1 + %encode_rlp_list_prefix + // stack: rlp_pos', value_ptr, retdest + DUP2 %mload_trie_data // nonce = value[0] + // stack: nonce, rlp_pos', value_ptr, retdest + SWAP1 %encode_rlp_scalar + // stack: rlp_pos'', value_ptr, retdest + DUP2 %add_const(1) %mload_trie_data // balance = value[1] + // stack: balance, rlp_pos'', value_ptr, retdest + SWAP1 %encode_rlp_scalar + // stack: rlp_pos''', value_ptr, retdest + DUP2 %add_const(2) %mload_trie_data // storage_root = value[2] + // stack: storage_root, rlp_pos''', value_ptr, retdest + SWAP1 %encode_rlp_256 + // stack: rlp_pos'''', value_ptr, retdest + SWAP1 %add_const(3) %mload_trie_data // code_hash = value[3] + // stack: code_hash, rlp_pos'''', retdest + SWAP1 %encode_rlp_256 + // stack: rlp_pos''''', retdest + SWAP1 + JUMP + +encode_txn: + PANIC // TODO + +encode_receipt: + PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm b/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm new file mode 100644 index 00000000..72ac18cc --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm @@ -0,0 +1,104 @@ +// Computes the RLP encoding of the hex-prefix encoding of the given nibble list +// and termination flag. Writes the result to @SEGMENT_RLP_RAW starting at the +// given position, and returns the updated position, i.e. a pointer to the next +// unused offset. +// +// Pre stack: rlp_start_pos, num_nibbles, packed_nibbles, terminated, retdest +// Post stack: rlp_end_pos + +global hex_prefix_rlp: + // stack: rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + // We will iterate backwards, from i = num_nibbles / 2 to i = 0, so that we + // can take nibbles from the least-significant end of packed_nibbles. + PUSH 2 DUP3 DIV // i = num_nibbles / 2 + // stack: i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + + // Compute the length of the hex-prefix string, in bytes: + // hp_len = num_nibbles / 2 + 1 = i + 1 + DUP1 %add_const(1) + // stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + + // Write the RLP header. + DUP1 %gt_const(55) %jumpi(rlp_header_large) + DUP1 %gt_const(1) %jumpi(rlp_header_medium) + + // The hex-prefix is a single byte. It must be <= 127, since its first + // nibble only has two bits. So this is the "small" RLP string case, where + // the byte is its own RLP encoding. + // stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + %jump(start_loop) + +rlp_header_medium: + // stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + DUP1 %add_const(0x80) // value = 0x80 + hp_len + DUP4 // offset = rlp_pos + %mstore_rlp + + // rlp_pos += 1 + SWAP2 %add_const(1) SWAP2 + + %jump(start_loop) + +rlp_header_large: + // stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + // In practice hex-prefix length will never exceed 256, so the length of the + // length will always be 1 byte in this case. + + PUSH 0xb8 // value = 0xb7 + len_of_len = 0xb8 + DUP4 // offset = rlp_pos + %mstore_rlp + + DUP1 // value = hp_len + DUP4 %add_const(1) // offset = rlp_pos + 1 + %mstore_rlp + + // rlp_pos += 2 + SWAP2 %add_const(2) SWAP2 + +start_loop: + // stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + SWAP1 + +loop: + // stack: i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + // If i == 0, break to first_byte. + DUP1 ISZERO %jumpi(first_byte) + + // stack: i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + DUP5 // packed_nibbles + %and_const(0xFF) + // stack: byte_i, i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + DUP4 // rlp_pos + DUP3 // i + ADD // We'll write to offset rlp_pos + i + %mstore_rlp + + // stack: i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest + %sub_const(1) + SWAP4 %shr_const(8) SWAP4 // packed_nibbles >>= 8 + %jump(loop) + +first_byte: + // stack: 0, hp_len, rlp_pos, num_nibbles, first_nibble_or_zero, terminated, retdest + POP + // stack: hp_len, rlp_pos, num_nibbles, first_nibble_or_zero, terminated, retdest + DUP2 ADD + // stack: rlp_end_pos, rlp_pos, num_nibbles, first_nibble_or_zero, terminated, retdest + SWAP4 + // stack: terminated, rlp_pos, num_nibbles, first_nibble_or_zero, rlp_end_pos, retdest + %mul_const(2) + // stack: terminated * 2, rlp_pos, num_nibbles, first_nibble_or_zero, rlp_end_pos, retdest + %stack (terminated_x2, rlp_pos, num_nibbles, first_nibble_or_zero) + -> (num_nibbles, terminated_x2, first_nibble_or_zero, rlp_pos) + // stack: num_nibbles, terminated * 2, first_nibble_or_zero, rlp_pos, rlp_end_pos, retdest + %mod_const(2) // parity + ADD + // stack: parity + terminated * 2, first_nibble_or_zero, rlp_pos, rlp_end_pos, retdest + %mul_const(16) + ADD + // stack: first_byte, rlp_pos, rlp_end_pos, retdest + SWAP1 + %mstore_rlp + // stack: rlp_end_pos, retdest + SWAP1 + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/load.asm b/evm/src/cpu/kernel/asm/mpt/load.asm new file mode 100644 index 00000000..2f1bd624 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/load.asm @@ -0,0 +1,180 @@ +// TODO: Receipt trie leaves are variable-length, so we need to be careful not +// to permit buffer over-reads. + +// Load all partial trie data from prover inputs. +global load_all_mpts: + // stack: retdest + // First set @GLOBAL_METADATA_TRIE_DATA_SIZE = 1. + // We don't want it to start at 0, as we use 0 as a null pointer. + PUSH 1 + %set_trie_data_size + + %load_mpt_and_return_root_ptr %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + %load_mpt_and_return_root_ptr %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) + %load_mpt_and_return_root_ptr %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) + + PROVER_INPUT(mpt) + // stack: num_storage_tries, retdest + DUP1 %mstore_global_metadata(@GLOBAL_METADATA_NUM_STORAGE_TRIES) + // stack: num_storage_tries, retdest + PUSH 0 // i = 0 + // stack: i, num_storage_tries, retdest +storage_trie_loop: + DUP2 DUP2 EQ + // stack: i == num_storage_tries, i, num_storage_tries, retdest + %jumpi(storage_trie_loop_end) + // stack: i, num_storage_tries, retdest + PROVER_INPUT(mpt) + // stack: storage_trie_addr, i, num_storage_tries, retdest + DUP2 + // stack: i, storage_trie_addr, i, num_storage_tries, retdest + %mstore_kernel(@SEGMENT_STORAGE_TRIE_ADDRS) + // stack: i, num_storage_tries, retdest + %load_mpt_and_return_root_ptr + // stack: root_ptr, i, num_storage_tries, retdest + DUP2 + // stack: i, root_ptr, i, num_storage_tries, retdest + %mstore_kernel(@SEGMENT_STORAGE_TRIE_PTRS) + // stack: i, num_storage_tries, retdest + %jump(storage_trie_loop) +storage_trie_loop_end: + // stack: i, num_storage_tries, retdest + %pop2 + // stack: retdest + JUMP + +// Load an MPT from prover inputs. +// Pre stack: retdest +// Post stack: (empty) +load_mpt: + // stack: retdest + PROVER_INPUT(mpt) + // stack: node_type, retdest + DUP1 %append_to_trie_data + // stack: node_type, retdest + + DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(load_mpt_empty) + DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(load_mpt_branch) + DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(load_mpt_extension) + DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(load_mpt_leaf) + DUP1 %eq_const(@MPT_NODE_HASH) %jumpi(load_mpt_digest) + PANIC // Invalid node type + +load_mpt_empty: + // stack: node_type, retdest + POP + // stack: retdest + JUMP + +load_mpt_branch: + // stack: node_type, retdest + POP + // stack: retdest + %get_trie_data_size + // stack: ptr_children, retdest + DUP1 %add_const(16) + // stack: ptr_leaf, ptr_children, retdest + %set_trie_data_size + // stack: ptr_children, retdest + %load_leaf_value + + // Save the current trie_data_size (which now points to the end of the leaf) + // for later, then have it point to the start of our 16 child pointers. + %get_trie_data_size + // stack: ptr_end_of_leaf, ptr_children, retdest + SWAP1 + %set_trie_data_size + // stack: ptr_end_of_leaf, retdest + + // Load the 16 children. + %rep 16 + %load_mpt_and_return_root_ptr + // stack: child_ptr, ptr_end_of_leaf, retdest + %append_to_trie_data + // stack: ptr_end_of_leaf, retdest + %endrep + + %set_trie_data_size + // stack: retdest + JUMP + +load_mpt_extension: + // stack: node_type, retdest + POP + // stack: retdest + PROVER_INPUT(mpt) // read num_nibbles + %append_to_trie_data + PROVER_INPUT(mpt) // read packed_nibbles + %append_to_trie_data + // stack: retdest + + // Let i be the current trie data size. We still need to expand this node by + // one element, appending our child pointer. Thus our child node will start + // at i + 1. So we will set our child pointer to i + 1. + %get_trie_data_size + %add_const(1) + %append_to_trie_data + // stack: retdest + + %load_mpt + // stack: retdest + JUMP + +load_mpt_leaf: + // stack: node_type, retdest + POP + // stack: retdest + PROVER_INPUT(mpt) // read num_nibbles + %append_to_trie_data + PROVER_INPUT(mpt) // read packed_nibbles + %append_to_trie_data + // stack: retdest + %load_leaf_value + // stack: retdest + JUMP + +load_mpt_digest: + // stack: node_type, retdest + POP + // stack: retdest + PROVER_INPUT(mpt) // read digest + %append_to_trie_data + // stack: retdest + JUMP + +// Convenience macro to call load_mpt and return where we left off. +%macro load_mpt + PUSH %%after + %jump(load_mpt) +%%after: +%endmacro + +%macro load_mpt_and_return_root_ptr + // stack: (empty) + %get_trie_data_size + // stack: ptr + %load_mpt + // stack: ptr +%endmacro + +// Load a leaf from prover input, and append it to trie data. +%macro load_leaf_value + // stack: (empty) + PROVER_INPUT(mpt) + // stack: leaf_len +%%loop: + DUP1 ISZERO + // stack: leaf_len == 0, leaf_len + %jumpi(%%finish) + // stack: leaf_len + PROVER_INPUT(mpt) + // stack: leaf_part, leaf_len + %append_to_trie_data + // stack: leaf_len + %sub_const(1) + // stack: leaf_len' + %jump(%%loop) +%%finish: + POP + // stack: (empty) +%endmacro diff --git a/evm/src/cpu/kernel/asm/mpt/read.asm b/evm/src/cpu/kernel/asm/mpt/read.asm new file mode 100644 index 00000000..aec0c776 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/read.asm @@ -0,0 +1,144 @@ +// Given an address, return a pointer to the associated account data, which +// consists of four words (nonce, balance, storage_root, code_hash), in the +// state trie. Returns 0 if the address is not found. +global mpt_read_state_trie: + // stack: addr, retdest + // The key is the hash of the address. Since KECCAK_GENERAL takes input from + // memory, we will write addr bytes to SEGMENT_KERNEL_GENERAL[0..20] first. + %stack (addr) -> (0, @SEGMENT_KERNEL_GENERAL, 0, addr, 20, mpt_read_state_trie_after_mstore) + %jump(mstore_unpacking) +mpt_read_state_trie_after_mstore: + // stack: retdest + %stack () -> (0, @SEGMENT_KERNEL_GENERAL, 0, 20) // context, segment, offset, len + KECCAK_GENERAL + // stack: key, retdest + PUSH 64 // num_nibbles + %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) // node_ptr + // stack: node_ptr, num_nibbles, key, retdest + %jump(mpt_read) + +// Read a value from a MPT. +// +// Arguments: +// - the virtual address of the trie to search in +// - the key, as a U256 +// - the number of nibbles in the key (should start at 64) +// +// This function returns a pointer to the leaf, or 0 if the key is not found. +global mpt_read: + // stack: node_ptr, num_nibbles, key, retdest + DUP1 + %mload_trie_data + // stack: node_type, node_ptr, num_nibbles, key, retdest + // Increment node_ptr, so it points to the node payload instead of its type. + SWAP1 %add_const(1) SWAP1 + // stack: node_type, node_payload_ptr, num_nibbles, key, retdest + + DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(mpt_read_empty) + DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(mpt_read_branch) + DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(mpt_read_extension) + DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(mpt_read_leaf) + + // There's still the MPT_NODE_HASH case, but if we hit a digest node, + // it means the prover failed to provide necessary Merkle data, so panic. + PANIC + +mpt_read_empty: + // Return 0 to indicate that the value was not found. + %stack (node_type, node_payload_ptr, num_nibbles, key, retdest) + -> (retdest, 0) + JUMP + +mpt_read_branch: + // stack: node_type, node_payload_ptr, num_nibbles, key, retdest + POP + // stack: node_payload_ptr, num_nibbles, key, retdest + DUP2 // num_nibbles + ISZERO + // stack: num_nibbles == 0, node_payload_ptr, num_nibbles, key, retdest + %jumpi(mpt_read_branch_end_of_key) + + // We have not reached the end of the key, so we descend to one of our children. + // stack: node_payload_ptr, num_nibbles, key, retdest + %stack (node_payload_ptr, num_nibbles, key) + -> (num_nibbles, key, node_payload_ptr) + // stack: num_nibbles, key, node_payload_ptr, retdest + %split_first_nibble + %stack (first_nibble, num_nibbles, key, node_payload_ptr) + -> (node_payload_ptr, first_nibble, num_nibbles, key) + // child_ptr = load(node_payload_ptr + first_nibble) + ADD %mload_trie_data + // stack: child_ptr, num_nibbles, key, retdest + %jump(mpt_read) // recurse + +mpt_read_branch_end_of_key: + %stack (node_payload_ptr, num_nibbles, key, retdest) -> (node_payload_ptr, retdest) + // stack: node_payload_ptr, retdest + %add_const(16) // skip over the 16 child nodes + // stack: leaf_ptr, retdest + SWAP1 + JUMP + +mpt_read_extension: + // stack: node_type, node_payload_ptr, num_nibbles, key, retdest + %stack (node_type, node_payload_ptr, num_nibbles, key) + -> (num_nibbles, key, node_payload_ptr) + // stack: num_nibbles, key, node_payload_ptr, retdest + DUP3 %mload_trie_data + // stack: node_num_nibbles, num_nibbles, key, node_payload_ptr, retdest + SWAP1 + SUB + // stack: future_nibbles, key, node_payload_ptr, retdest + DUP2 DUP2 + // stack: future_nibbles, key, future_nibbles, key, node_payload_ptr, retdest + %mul_const(4) SHR // key_part = key >> (future_nibbles * 4) + DUP1 + // stack: key_part, key_part, future_nibbles, key, node_payload_ptr, retdest + DUP5 %add_const(1) %mload_trie_data + // stack: node_key, key_part, key_part, future_nibbles, key, node_payload_ptr, retdest + EQ // does the first part of our key match the node's key? + %jumpi(mpt_read_extension_found) + // Not found; return 0. + %stack (key_part, future_nibbles, node_payload_ptr, retdest) -> (retdest, 0) + JUMP +mpt_read_extension_found: + // stack: key_part, future_nibbles, key, node_payload_ptr, retdest + DUP2 %mul_const(4) SHL // key_part_shifted = (key_part << (future_nibbles * 4)) + // stack: key_part_shifted, future_nibbles, key, node_payload_ptr, retdest + %stack (key_part_shifted, future_nibbles, key) + -> (key, key_part_shifted, future_nibbles) + SUB // key -= key_part_shifted + // stack: key, future_nibbles, node_payload_ptr, retdest + SWAP2 + // stack: node_payload_ptr, future_nibbles, key, retdest + %add_const(2) // child pointer is third field of extension node + %mload_trie_data + // stack: child_ptr, future_nibbles, key, retdest + %jump(mpt_read) // recurse + +mpt_read_leaf: + // stack: node_type, node_payload_ptr, num_nibbles, key, retdest + POP + // stack: node_payload_ptr, num_nibbles, key, retdest + DUP1 %mload_trie_data + // stack: node_num_nibbles, node_payload_ptr, num_nibbles, key, retdest + DUP2 %add_const(1) %mload_trie_data + // stack: node_key, node_num_nibbles, node_payload_ptr, num_nibbles, key, retdest + SWAP3 + // stack: num_nibbles, node_num_nibbles, node_payload_ptr, node_key, key, retdest + EQ + %stack (num_nibbles_match, node_payload_ptr, node_key, key) + -> (key, node_key, num_nibbles_match, node_payload_ptr) + EQ + AND + // stack: keys_match && num_nibbles_match, node_payload_ptr, retdest + %jumpi(mpt_read_leaf_found) + // Not found; return 0. + %stack (node_payload_ptr, retdest) -> (retdest, 0) + JUMP +mpt_read_leaf_found: + // stack: node_payload_ptr, retdest + %add_const(2) // The leaf data is located after num_nibbles and the key. + // stack: value_ptr, retdest + SWAP1 + JUMP diff --git a/evm/src/cpu/kernel/asm/storage/read.asm b/evm/src/cpu/kernel/asm/mpt/storage_read.asm similarity index 100% rename from evm/src/cpu/kernel/asm/storage/read.asm rename to evm/src/cpu/kernel/asm/mpt/storage_read.asm diff --git a/evm/src/cpu/kernel/asm/storage/write.asm b/evm/src/cpu/kernel/asm/mpt/storage_write.asm similarity index 100% rename from evm/src/cpu/kernel/asm/storage/write.asm rename to evm/src/cpu/kernel/asm/mpt/storage_write.asm diff --git a/evm/src/cpu/kernel/asm/mpt/util.asm b/evm/src/cpu/kernel/asm/mpt/util.asm new file mode 100644 index 00000000..0e0006d3 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/util.asm @@ -0,0 +1,74 @@ +%macro mload_trie_data + // stack: virtual + %mload_kernel(@SEGMENT_TRIE_DATA) + // stack: value +%endmacro + +%macro mstore_trie_data + // stack: virtual, value + %mstore_kernel(@SEGMENT_TRIE_DATA) + // stack: (empty) +%endmacro + +%macro get_trie_data_size + // stack: (empty) + %mload_global_metadata(@GLOBAL_METADATA_TRIE_DATA_SIZE) + // stack: trie_data_size +%endmacro + +%macro set_trie_data_size + // stack: trie_data_size + %mstore_global_metadata(@GLOBAL_METADATA_TRIE_DATA_SIZE) + // stack: (empty) +%endmacro + +// Equivalent to: trie_data[trie_data_size++] = value +%macro append_to_trie_data + // stack: value + %get_trie_data_size + // stack: trie_data_size, value + DUP1 + %add_const(1) + // stack: trie_data_size', trie_data_size, value + %set_trie_data_size + // stack: trie_data_size, value + %mstore_trie_data + // stack: (empty) +%endmacro + +// Split off the first nibble from a key part. Roughly equivalent to +// def split_first_nibble(num_nibbles, key): +// num_nibbles -= 1 +// num_nibbles_x4 = num_nibbles * 4 +// first_nibble = (key >> num_nibbles_x4) & 0xF +// key -= (first_nibble << num_nibbles_x4) +// return (first_nibble, num_nibbles, key) +%macro split_first_nibble + // stack: num_nibbles, key + %sub_const(1) // num_nibbles -= 1 + // stack: num_nibbles, key + DUP2 + // stack: key, num_nibbles, key + DUP2 %mul_const(4) + // stack: num_nibbles_x4, key, num_nibbles, key + SHR + // stack: key >> num_nibbles_x4, num_nibbles, key + %and_const(0xF) + // stack: first_nibble, num_nibbles, key + DUP1 + // stack: first_nibble, first_nibble, num_nibbles, key + DUP3 %mul_const(4) + // stack: num_nibbles_x4, first_nibble, first_nibble, num_nibbles, key + SHL + // stack: first_nibble << num_nibbles_x4, first_nibble, num_nibbles, key + DUP1 + // stack: junk, first_nibble << num_nibbles_x4, first_nibble, num_nibbles, key + SWAP4 + // stack: key, first_nibble << num_nibbles_x4, first_nibble, num_nibbles, junk + SUB + // stack: key, first_nibble, num_nibbles, junk + SWAP3 + // stack: junk, first_nibble, num_nibbles, key + POP + // stack: first_nibble, num_nibbles, key +%endmacro diff --git a/evm/src/cpu/kernel/asm/mpt/write.asm b/evm/src/cpu/kernel/asm/mpt/write.asm new file mode 100644 index 00000000..5b59d016 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/write.asm @@ -0,0 +1,3 @@ +global mpt_write: + // stack: node_ptr, num_nibbles, key, retdest + // TODO diff --git a/evm/src/cpu/kernel/asm/rlp/decode.asm b/evm/src/cpu/kernel/asm/rlp/decode.asm index 0388276a..5749aee7 100644 --- a/evm/src/cpu/kernel/asm/rlp/decode.asm +++ b/evm/src/cpu/kernel/asm/rlp/decode.asm @@ -12,7 +12,6 @@ // Pre stack: pos, retdest // Post stack: pos', len global decode_rlp_string_len: - JUMPDEST // stack: pos, retdest DUP1 %mload_current(@SEGMENT_RLP_RAW) @@ -32,7 +31,6 @@ global decode_rlp_string_len: JUMP decode_rlp_string_len_medium: - JUMPDEST // String is 0-55 bytes long. First byte contains the len. // stack: first_byte, pos, retdest %sub_const(0x80) @@ -44,7 +42,6 @@ decode_rlp_string_len_medium: JUMP decode_rlp_string_len_large: - JUMPDEST // String is >55 bytes long. First byte contains the len of the len. // stack: first_byte, pos, retdest %sub_const(0xb7) @@ -69,7 +66,6 @@ decode_rlp_string_len_large: // bytes, so that the result can be returned as a single word on the stack. // As per the spec, scalars must not have leading zeros. global decode_rlp_scalar: - JUMPDEST // stack: pos, retdest PUSH decode_int_given_len // stack: decode_int_given_len, pos, retdest @@ -91,7 +87,6 @@ global decode_rlp_scalar: // Pre stack: pos, retdest // Post stack: pos', len global decode_rlp_list_len: - JUMPDEST // stack: pos, retdest DUP1 %mload_current(@SEGMENT_RLP_RAW) @@ -116,7 +111,6 @@ global decode_rlp_list_len: JUMP decode_rlp_list_len_big: - JUMPDEST // The length of the length is first_byte - 0xf7. // stack: first_byte, pos', retdest %sub_const(0xf7) @@ -137,7 +131,6 @@ decode_rlp_list_len_big: // Pre stack: pos, len, retdest // Post stack: pos', int decode_int_given_len: - JUMPDEST %stack (pos, len, retdest) -> (pos, len, pos, retdest) ADD // stack: end_pos, pos, retdest @@ -147,7 +140,6 @@ decode_int_given_len: // stack: acc, pos, end_pos, retdest decode_int_given_len_loop: - JUMPDEST // stack: acc, pos, end_pos, retdest DUP3 DUP3 @@ -171,6 +163,5 @@ decode_int_given_len_loop: %jump(decode_int_given_len_loop) decode_int_given_len_finish: - JUMPDEST %stack (acc, pos, end_pos, retdest) -> (retdest, pos, acc) JUMP diff --git a/evm/src/cpu/kernel/asm/rlp/encode.asm b/evm/src/cpu/kernel/asm/rlp/encode.asm index 7e296f9d..f92a8fda 100644 --- a/evm/src/cpu/kernel/asm/rlp/encode.asm +++ b/evm/src/cpu/kernel/asm/rlp/encode.asm @@ -79,7 +79,7 @@ encode_rlp_fixed: %add_const(1) // increment pos // stack: pos, len, string, retdest %stack (pos, len, string) -> (@SEGMENT_RLP_RAW, pos, string, len, encode_rlp_fixed_finish, pos, len) - GET_CONTEXT + PUSH 0 // context // stack: context, segment, pos, string, len, encode_rlp_fixed, pos, retdest %jump(mstore_unpacking) @@ -90,10 +90,125 @@ encode_rlp_fixed_finish: SWAP1 JUMP +// Pre stack: pos, payload_len, retdest +// Post stack: pos' +global encode_rlp_list_prefix: + // stack: pos, payload_len, retdest + DUP2 %gt_const(55) + %jumpi(encode_rlp_list_prefix_large) + // Small case: prefix is just 0xc0 + length. + // stack: pos, payload_len, retdest + SWAP1 + %add_const(0xc0) + // stack: prefix, pos, retdest + DUP2 + // stack: pos, prefix, pos, retdest + %mstore_rlp + // stack: pos, retdest + %add_const(1) + SWAP1 + JUMP +encode_rlp_list_prefix_large: + // Write 0xf7 + len_of_len. + // stack: pos, payload_len, retdest + DUP2 %num_bytes + // stack: len_of_len, pos, payload_len, retdest + DUP1 %add_const(0xf7) + // stack: first_byte, len_of_len, pos, payload_len, retdest + DUP3 // pos + %mstore_rlp + // stack: len_of_len, pos, payload_len, retdest + SWAP1 %add_const(1) + // stack: pos', len_of_len, payload_len, retdest + %stack (pos, len_of_len, payload_len, retdest) + -> (0, @SEGMENT_RLP_RAW, pos, payload_len, len_of_len, + encode_rlp_list_prefix_large_done_writing_len, + pos, len_of_len, retdest) + %jump(mstore_unpacking) +encode_rlp_list_prefix_large_done_writing_len: + // stack: pos', len_of_len, retdest + ADD + // stack: pos'', retdest + SWAP1 + JUMP + +%macro encode_rlp_list_prefix + %stack (pos, payload_len) -> (pos, payload_len, %%after) + %jump(encode_rlp_list_prefix) +%%after: +%endmacro + +// Given an RLP list payload which starts at position 9 and ends at the given +// position, prepend the appropriate RLP list prefix. Returns the updated start +// position, as well as the length of the RLP data (including the newly-added +// prefix). +// +// (We sometimes start list payloads at position 9 because 9 is the length of +// the longest possible RLP list prefix.) +// +// Pre stack: end_pos, retdest +// Post stack: start_pos, rlp_len +global prepend_rlp_list_prefix: + // stack: end_pos, retdest + // Since the list payload starts at position 9, payload_len = end_pos - 9. + PUSH 9 DUP2 SUB + // stack: payload_len, end_pos, retdest + DUP1 %gt_const(55) + %jumpi(prepend_rlp_list_prefix_big) + + // If we got here, we have a small list, so we prepend 0xc0 + len at position 8. + // stack: payload_len, end_pos, retdest + %add_const(0xc0) + // stack: prefix_byte, end_pos, retdest + PUSH 8 // offset + %mstore_rlp + // stack: end_pos, retdest + %sub_const(8) + // stack: rlp_len, retdest + PUSH 8 // start_pos + %stack (start_pos, rlp_len, retdest) -> (retdest, start_pos, rlp_len) + JUMP + +prepend_rlp_list_prefix_big: + // We have a large list, so we prepend 0xf7 + len_of_len at position + // 8 - len_of_len, followed by the length itself. + // stack: payload_len, end_pos, retdest + DUP1 %num_bytes + // stack: len_of_len, payload_len, end_pos, retdest + DUP1 + PUSH 8 + SUB + // stack: start_pos, len_of_len, payload_len, end_pos, retdest + DUP2 %add_const(0xf7) DUP2 %mstore_rlp // rlp[start_pos] = 0xf7 + len_of_len + DUP1 %add_const(1) // start_len_pos = start_pos + 1 + %stack (start_len_pos, start_pos, len_of_len, payload_len, end_pos, retdest) + -> (0, @SEGMENT_RLP_RAW, start_len_pos, // context, segment, offset + payload_len, len_of_len, + prepend_rlp_list_prefix_big_done_writing_len, + start_pos, end_pos, retdest) + %jump(mstore_unpacking) +prepend_rlp_list_prefix_big_done_writing_len: + // stack: start_pos, end_pos, retdest + DUP1 + SWAP2 + // stack: end_pos, start_pos, start_pos, retdest + SUB + // stack: rlp_len, start_pos, retdest + %stack (rlp_len, start_pos, retdest) -> (retdest, start_pos, rlp_len) + JUMP + +// Convenience macro to call prepend_rlp_list_prefix and return where we left off. +%macro prepend_rlp_list_prefix + %stack (start_pos) -> (start_pos, %%after) + %jump(prepend_rlp_list_prefix) +%%after: +%endmacro + // Get the number of bytes required to represent the given scalar. // The scalar is assumed to be non-zero, as small scalars like zero should // have already been handled with the small-scalar encoding. -num_bytes: +// TODO: Should probably unroll the loop +global num_bytes: // stack: x, retdest PUSH 0 // i // stack: i, x, retdest @@ -125,3 +240,22 @@ num_bytes_finish: %jump(num_bytes) %%after: %endmacro + +// Given some scalar, compute the number of bytes used in its RLP encoding, +// including any length prefix. +%macro scalar_rlp_len + // stack: scalar + // Since the scalar fits in a word, we can't hit the large (>55 byte) + // case, so we just check for small vs medium. + DUP1 %gt_const(0x7f) + // stack: is_medium, scalar + %jumpi(%%medium) + // Small case; result is 1. + %stack (scalar) -> (1) +%%medium: + // stack: scalar + %num_bytes + // stack: scalar_bytes + %add_const(1) // Account for the length prefix. + // stack: rlp_len +%endmacro diff --git a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm index ae75e3d7..189edd1d 100644 --- a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm +++ b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm @@ -5,15 +5,13 @@ // Post stack: (empty) global read_rlp_to_memory: - JUMPDEST // stack: retdest - PROVER_INPUT // Read the RLP blob length from the prover tape. + PROVER_INPUT(rlp) // Read the RLP blob length from the prover tape. // stack: len, retdest PUSH 0 // initial position // stack: pos, len, retdest read_rlp_to_memory_loop: - JUMPDEST // stack: pos, len, retdest DUP2 DUP2 @@ -21,7 +19,7 @@ read_rlp_to_memory_loop: // stack: pos == len, pos, len, retdest %jumpi(read_rlp_to_memory_finish) // stack: pos, len, retdest - PROVER_INPUT + PROVER_INPUT(rlp) // stack: byte, pos, len, retdest DUP2 // stack: pos, byte, pos, len, retdest @@ -32,7 +30,6 @@ read_rlp_to_memory_loop: %jump(read_rlp_to_memory_loop) read_rlp_to_memory_finish: - JUMPDEST // stack: pos, len, retdest %pop2 // stack: retdest diff --git a/evm/src/cpu/kernel/asm/transactions/router.asm b/evm/src/cpu/kernel/asm/transactions/router.asm index 01a65fec..974fed99 100644 --- a/evm/src/cpu/kernel/asm/transactions/router.asm +++ b/evm/src/cpu/kernel/asm/transactions/router.asm @@ -3,16 +3,14 @@ // jump to the appropriate transaction parsing method. global route_txn: - JUMPDEST - // stack: (empty) + // stack: retdest // First load transaction data into memory, where it will be parsed. PUSH read_txn_from_memory %jump(read_rlp_to_memory) // At this point, the raw txn data is in memory. read_txn_from_memory: - JUMPDEST - // stack: (empty) + // stack: retdest // We will peak at the first byte to determine what type of transaction this is. // Note that type 1 and 2 transactions have a first byte of 1 and 2, respectively. @@ -22,17 +20,17 @@ read_txn_from_memory: PUSH 0 %mload_current(@SEGMENT_RLP_RAW) %eq_const(1) - // stack: first_byte == 1 + // stack: first_byte == 1, retdest %jumpi(process_type_1_txn) - // stack: (empty) + // stack: retdest PUSH 0 %mload_current(@SEGMENT_RLP_RAW) %eq_const(2) - // stack: first_byte == 2 + // stack: first_byte == 2, retdest %jumpi(process_type_2_txn) - // stack: (empty) + // stack: retdest // At this point, since it's not a type 1 or 2 transaction, // it must be a legacy (aka type 0) transaction. - %jump(process_type_2_txn) + %jump(process_type_0_txn) diff --git a/evm/src/cpu/kernel/asm/transactions/type_0.asm b/evm/src/cpu/kernel/asm/transactions/type_0.asm index 7c8488f7..7bc7a399 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_0.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_0.asm @@ -12,16 +12,15 @@ // keccak256(rlp([nonce, gas_price, gas_limit, to, value, data])) global process_type_0_txn: - JUMPDEST - // stack: (empty) + // stack: retdest PUSH 0 // initial pos - // stack: pos + // stack: pos, retdest %decode_rlp_list_len // We don't actually need the length. %stack (pos, len) -> (pos) // Decode the nonce and store it. - // stack: pos + // stack: pos, retdest %decode_rlp_scalar %stack (pos, nonce) -> (nonce, pos) %mstore_txn_field(@TXN_FIELD_NONCE) @@ -30,38 +29,38 @@ global process_type_0_txn: // For legacy transactions, we set both the // TXN_FIELD_MAX_PRIORITY_FEE_PER_GAS and TXN_FIELD_MAX_FEE_PER_GAS // fields to gas_price. - // stack: pos + // stack: pos, retdest %decode_rlp_scalar %stack (pos, gas_price) -> (gas_price, gas_price, pos) %mstore_txn_field(@TXN_FIELD_MAX_PRIORITY_FEE_PER_GAS) %mstore_txn_field(@TXN_FIELD_MAX_FEE_PER_GAS) // Decode the gas limit and store it. - // stack: pos + // stack: pos, retdest %decode_rlp_scalar %stack (pos, gas_limit) -> (gas_limit, pos) %mstore_txn_field(@TXN_FIELD_GAS_LIMIT) // Decode the "to" field and store it. - // stack: pos + // stack: pos, retdest %decode_rlp_scalar %stack (pos, to) -> (to, pos) %mstore_txn_field(@TXN_FIELD_TO) // Decode the value field and store it. - // stack: pos + // stack: pos, retdest %decode_rlp_scalar %stack (pos, value) -> (value, pos) %mstore_txn_field(@TXN_FIELD_VALUE) // Decode the data length, store it, and compute new_pos after any data. - // stack: pos + // stack: pos, retdest %decode_rlp_string_len %stack (pos, data_len) -> (data_len, pos, data_len, pos, data_len) %mstore_txn_field(@TXN_FIELD_DATA_LEN) - // stack: pos, data_len, pos, data_len + // stack: pos, data_len, pos, data_len, retdest ADD - // stack: new_pos, pos, data_len + // stack: new_pos, pos, data_len, retdest // Memcpy the txn data from @SEGMENT_RLP_RAW to @SEGMENT_TXN_DATA. PUSH parse_v @@ -71,62 +70,62 @@ global process_type_0_txn: PUSH 0 PUSH @SEGMENT_TXN_DATA GET_CONTEXT - // stack: DST, SRC, data_len, parse_v, new_pos + // stack: DST, SRC, data_len, parse_v, new_pos, retdest %jump(memcpy) parse_v: - // stack: pos + // stack: pos, retdest %decode_rlp_scalar - // stack: pos, v + // stack: pos, v, retdest SWAP1 - // stack: v, pos + // stack: v, pos, retdest DUP1 %gt_const(28) - // stack: v > 28, v, pos + // stack: v > 28, v, pos, retdest %jumpi(process_v_new_style) // We have an old style v, so y_parity = v - 27. // No chain ID is present, so we can leave TXN_FIELD_CHAIN_ID_PRESENT and // TXN_FIELD_CHAIN_ID with their default values of zero. - // stack: v, pos + // stack: v, pos, retdest %sub_const(27) %stack (y_parity, pos) -> (y_parity, pos) %mstore_txn_field(@TXN_FIELD_Y_PARITY) - // stack: pos + // stack: pos, retdest %jump(parse_r) process_v_new_style: - // stack: v, pos + // stack: v, pos, retdest // We have a new style v, so chain_id_present = 1, // chain_id = (v - 35) / 2, and y_parity = (v - 35) % 2. %stack (v, pos) -> (1, v, pos) %mstore_txn_field(@TXN_FIELD_CHAIN_ID_PRESENT) - // stack: v, pos + // stack: v, pos, retdest %sub_const(35) DUP1 - // stack: v - 35, v - 35, pos + // stack: v - 35, v - 35, pos, retdest %div_const(2) - // stack: chain_id, v - 35, pos + // stack: chain_id, v - 35, pos, retdest %mstore_txn_field(@TXN_FIELD_CHAIN_ID) - // stack: v - 35, pos + // stack: v - 35, pos, retdest %mod_const(2) - // stack: y_parity, pos + // stack: y_parity, pos, retdest %mstore_txn_field(@TXN_FIELD_Y_PARITY) parse_r: - // stack: pos + // stack: pos, retdest %decode_rlp_scalar %stack (pos, r) -> (r, pos) %mstore_txn_field(@TXN_FIELD_R) - // stack: pos + // stack: pos, retdest %decode_rlp_scalar %stack (pos, s) -> (s) %mstore_txn_field(@TXN_FIELD_S) - // stack: (empty) + // stack: retdest // TODO: Write the signed txn data to memory, where it can be hashed and // checked against the signature. diff --git a/evm/src/cpu/kernel/asm/transactions/type_1.asm b/evm/src/cpu/kernel/asm/transactions/type_1.asm index 5b9d2cdf..8c7fcaae 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_1.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_1.asm @@ -7,6 +7,5 @@ // data, access_list])) global process_type_1_txn: - JUMPDEST - // stack: (empty) + // stack: retdest PANIC // TODO: Unfinished diff --git a/evm/src/cpu/kernel/asm/transactions/type_2.asm b/evm/src/cpu/kernel/asm/transactions/type_2.asm index 9807f88f..f1ff18d8 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_2.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_2.asm @@ -8,6 +8,5 @@ // access_list])) global process_type_2_txn: - JUMPDEST - // stack: (empty) + // stack: retdest PANIC // TODO: Unfinished diff --git a/evm/src/cpu/kernel/asm/util/assertions.asm b/evm/src/cpu/kernel/asm/util/assertions.asm index 69193e5f..0051219c 100644 --- a/evm/src/cpu/kernel/asm/util/assertions.asm +++ b/evm/src/cpu/kernel/asm/util/assertions.asm @@ -1,7 +1,6 @@ // It is convenient to have a single panic routine, which we can jump to from // anywhere. global panic: - JUMPDEST PANIC // Consumes the top element and asserts that it is zero. diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index f5175c41..ede60a29 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -9,13 +9,13 @@ use crate::cpu::kernel::ast::Item::LocalLabelDeclaration; use crate::cpu::kernel::ast::StackReplacement; use crate::cpu::kernel::keccak_util::hash_kernel; use crate::cpu::kernel::optimizer::optimize_asm; -use crate::cpu::kernel::prover_input::ProverInputFn; use crate::cpu::kernel::stack::stack_manipulation::expand_stack_manipulation; use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes; use crate::cpu::kernel::{ ast::{File, Item}, opcodes::{get_opcode, get_push_opcode}, }; +use crate::generation::prover_input::ProverInputFn; /// The number of bytes to push when pushing an offset within the code (i.e. when assembling jumps). /// Ideally we would automatically use the minimal number of bytes required, but that would be @@ -52,6 +52,12 @@ impl Kernel { } } +#[derive(Eq, PartialEq, Hash, Clone, Debug)] +struct MacroSignature { + name: String, + num_params: usize, +} + struct Macro { params: Vec, items: Vec, @@ -79,20 +85,20 @@ pub(crate) fn assemble( let mut local_labels = Vec::with_capacity(files.len()); let mut macro_counter = 0; for file in files { - let expanded_file = expand_macros(file.body, ¯os, &mut macro_counter); - let expanded_file = expand_repeats(expanded_file); - let expanded_file = inline_constants(expanded_file, &constants); - let mut expanded_file = expand_stack_manipulation(expanded_file); + let mut file = file.body; + file = expand_macros(file, ¯os, &mut macro_counter); + file = inline_constants(file, &constants); + file = expand_stack_manipulation(file); if optimize { - optimize_asm(&mut expanded_file); + optimize_asm(&mut file); } local_labels.push(find_labels( - &expanded_file, + &file, &mut offset, &mut global_labels, &mut prover_inputs, )); - expanded_files.push(expanded_file); + expanded_files.push(file); } let mut code = vec![]; for (file, locals) in izip!(expanded_files, local_labels) { @@ -105,17 +111,21 @@ pub(crate) fn assemble( Kernel::new(code, global_labels, prover_inputs) } -fn find_macros(files: &[File]) -> HashMap { +fn find_macros(files: &[File]) -> HashMap { let mut macros = HashMap::new(); for file in files { for item in &file.body { if let Item::MacroDef(name, params, items) = item { - let _macro = Macro { + let signature = MacroSignature { + name: name.clone(), + num_params: params.len(), + }; + let macro_ = Macro { params: params.clone(), items: items.clone(), }; - let old = macros.insert(name.clone(), _macro); - assert!(old.is_none(), "Duplicate macro: {name}"); + let old = macros.insert(signature.clone(), macro_); + assert!(old.is_none(), "Duplicate macro signature: {:?}", signature); } } } @@ -124,7 +134,7 @@ fn find_macros(files: &[File]) -> HashMap { fn expand_macros( body: Vec, - macros: &HashMap, + macros: &HashMap, macro_counter: &mut u32, ) -> Vec { let mut expanded = vec![]; @@ -136,6 +146,11 @@ fn expand_macros( Item::MacroCall(m, args) => { expanded.extend(expand_macro_call(m, args, macros, macro_counter)); } + Item::Repeat(count, body) => { + for _ in 0..count.as_usize() { + expanded.extend(expand_macros(body.clone(), macros, macro_counter)); + } + } item => { expanded.push(item); } @@ -147,30 +162,25 @@ fn expand_macros( fn expand_macro_call( name: String, args: Vec, - macros: &HashMap, + macros: &HashMap, macro_counter: &mut u32, ) -> Vec { - let _macro = macros - .get(&name) - .unwrap_or_else(|| panic!("No such macro: {}", name)); - - assert_eq!( - args.len(), - _macro.params.len(), - "Macro `{}`: expected {} arguments, got {}", + let signature = MacroSignature { name, - _macro.params.len(), - args.len() - ); + num_params: args.len(), + }; + let macro_ = macros + .get(&signature) + .unwrap_or_else(|| panic!("No such macro: {:?}", signature)); let get_actual_label = |macro_label| format!("@{}.{}", macro_counter, macro_label); let get_arg = |var| { - let param_index = _macro.get_param_index(var); + let param_index = macro_.get_param_index(var); args[param_index].clone() }; - let expanded_item = _macro + let expanded_item = macro_ .items .iter() .map(|item| match item { @@ -182,12 +192,10 @@ fn expand_macro_call( Item::MacroCall(name, args) => { let expanded_args = args .iter() - .map(|arg| { - if let PushTarget::MacroVar(var) = arg { - get_arg(var) - } else { - arg.clone() - } + .map(|arg| match arg { + PushTarget::MacroVar(var) => get_arg(var), + PushTarget::MacroLabel(l) => PushTarget::Label(get_actual_label(l)), + _ => arg.clone(), }) .collect(); Item::MacroCall(name.clone(), expanded_args) @@ -195,12 +203,12 @@ fn expand_macro_call( Item::StackManipulation(before, after) => { let after = after .iter() - .map(|replacement| { - if let StackReplacement::MacroLabel(label) = replacement { + .map(|replacement| match replacement { + StackReplacement::MacroLabel(label) => { StackReplacement::Identifier(get_actual_label(label)) - } else { - replacement.clone() } + StackReplacement::MacroVar(var) => get_arg(var).into(), + _ => replacement.clone(), }) .collect(); Item::StackManipulation(before.clone(), after) @@ -215,21 +223,6 @@ fn expand_macro_call( expand_macros(expanded_item, macros, macro_counter) } -fn expand_repeats(body: Vec) -> Vec { - let mut expanded = vec![]; - for item in body { - if let Item::Repeat(count, block) = item { - let reps = count.as_usize(); - for _ in 0..reps { - expanded.extend(block.clone()); - } - } else { - expanded.push(item); - } - } - expanded -} - fn inline_constants(body: Vec, constants: &HashMap) -> Vec { let resolve_const = |c| { *constants @@ -489,7 +482,8 @@ mod tests { #[test] fn macro_with_label() { let files = &[ - "%macro spin %%start: PUSH %%start JUMP %endmacro", + "%macro jump(x) PUSH $x JUMP %endmacro", + "%macro spin %%start: %jump(%%start) %endmacro", "%spin %spin", ]; let kernel = parse_and_assemble_ext(files, HashMap::new(), false); @@ -508,8 +502,31 @@ mod tests { "%macro bar(y) PUSH $y %endmacro", "%foo(42)", ]); - let push = get_push_opcode(1); - assert_eq!(kernel.code, vec![push, 42, push, 42]); + let push1 = get_push_opcode(1); + assert_eq!(kernel.code, vec![push1, 42, push1, 42]); + } + + #[test] + fn macro_with_reserved_prefix() { + // The name `repeat` should be allowed, even though `rep` is reserved. + parse_and_assemble(&["%macro repeat %endmacro", "%repeat"]); + } + + #[test] + fn overloaded_macros() { + let kernel = parse_and_assemble(&[ + "%macro push(x) PUSH $x %endmacro", + "%macro push(x, y) PUSH $x PUSH $y %endmacro", + "%push(5)", + "%push(6, 7)", + ]); + let push1 = get_push_opcode(1); + assert_eq!(kernel.code, vec![push1, 5, push1, 6, push1, 7]); + } + + #[test] + fn pop2_macro() { + parse_and_assemble(&["%macro pop2 %rep 2 pop %endrep %endmacro", "%pop2"]); } #[test] @@ -551,8 +568,16 @@ mod tests { let dup1 = get_opcode("DUP1"); let swap1 = get_opcode("SWAP1"); let swap2 = get_opcode("SWAP2"); + let swap3 = get_opcode("SWAP3"); + let push_one_byte = get_push_opcode(1); let push_label = get_push_opcode(BYTES_PER_OFFSET); + let kernel = parse_and_assemble(&["%stack () -> (1, 2, 3)"]); + assert_eq!( + kernel.code, + vec![push_one_byte, 3, push_one_byte, 2, push_one_byte, 1] + ); + let kernel = parse_and_assemble(&["%stack (a) -> (a)"]); assert_eq!(kernel.code, vec![]); @@ -562,6 +587,20 @@ mod tests { let kernel = parse_and_assemble(&["%stack (a, b, c) -> (b)"]); assert_eq!(kernel.code, vec![pop, swap1, pop]); + let kernel = parse_and_assemble(&["%stack (a, b, c) -> (7, b)"]); + assert_eq!(kernel.code, vec![pop, swap1, pop, push_one_byte, 7]); + + let kernel = parse_and_assemble(&["%stack (a, b: 3, c) -> (c)"]); + assert_eq!(kernel.code, vec![pop, pop, pop, pop]); + + let kernel = parse_and_assemble(&["%stack (a: 2, b: 2) -> (b, a)"]); + assert_eq!(kernel.code, vec![swap1, swap3, swap1, swap2]); + + let kernel1 = parse_and_assemble(&["%stack (a: 3, b: 3, c) -> (c, b, a)"]); + let kernel2 = + parse_and_assemble(&["%stack (a, b, c, d, e, f, g) -> (g, d, e, f, a, b, c)"]); + assert_eq!(kernel1.code, kernel2.code); + let mut consts = HashMap::new(); consts.insert("LIFE".into(), 42.into()); parse_and_assemble_ext(&["%stack (a, b) -> (b, @LIFE)"], consts, true); @@ -575,6 +614,34 @@ mod tests { assert_eq!(kernel.code, vec![dup1]); } + #[test] + fn stack_manipulation_in_macro() { + let pop = get_opcode("POP"); + let push1 = get_push_opcode(1); + + let kernel = parse_and_assemble(&[ + "%macro set_top(x) %stack (a) -> ($x) %endmacro", + "%set_top(42)", + ]); + assert_eq!(kernel.code, vec![pop, push1, 42]); + } + + #[test] + fn stack_manipulation_in_macro_with_name_collision() { + let pop = get_opcode("POP"); + let push_label = get_push_opcode(BYTES_PER_OFFSET); + + // In the stack directive, there's a named item `foo`. + // But when we invoke `%foo(foo)`, the argument refers to the `foo` label. + // Thus the expanded macro is `%stack (foo) -> (label foo)` (not real syntax). + let kernel = parse_and_assemble(&[ + "global foo:", + "%macro foo(x) %stack (foo) -> ($x) %endmacro", + "%foo(foo)", + ]); + assert_eq!(kernel.code, vec![pop, push_label, 0, 0, 0]); + } + fn parse_and_assemble(files: &[&str]) -> Kernel { parse_and_assemble_ext(files, HashMap::new(), true) } diff --git a/evm/src/cpu/kernel/ast.rs b/evm/src/cpu/kernel/ast.rs index 24cf01e1..3728aa35 100644 --- a/evm/src/cpu/kernel/ast.rs +++ b/evm/src/cpu/kernel/ast.rs @@ -1,6 +1,6 @@ use ethereum_types::U256; -use crate::cpu::kernel::prover_input::ProverInputFn; +use crate::generation::prover_input::ProverInputFn; #[derive(Debug)] pub(crate) struct File { @@ -19,7 +19,7 @@ pub(crate) enum Item { /// The first list gives names to items on the top of the stack. /// The second list specifies replacement items. /// Example: `(a, b, c) -> (c, 5, 0x20, @SOME_CONST, a)`. - StackManipulation(Vec, Vec), + StackManipulation(Vec, Vec), /// Declares a global label. GlobalLabelDeclaration(String), /// Declares a label that is local to the current file. @@ -36,16 +36,37 @@ pub(crate) enum Item { Bytes(Vec), } +/// The left hand side of a %stack stack-manipulation macro. +#[derive(Eq, PartialEq, Clone, Debug)] +pub(crate) enum StackPlaceholder { + Identifier(String), + Block(String, usize), +} + +/// The right hand side of a %stack stack-manipulation macro. #[derive(Eq, PartialEq, Clone, Debug)] pub(crate) enum StackReplacement { + Literal(U256), /// Can be either a named item or a label. Identifier(String), - Literal(U256), + Label(String), MacroLabel(String), MacroVar(String), Constant(String), } +impl From for StackReplacement { + fn from(target: PushTarget) -> Self { + match target { + PushTarget::Literal(x) => Self::Literal(x), + PushTarget::Label(l) => Self::Label(l), + PushTarget::MacroLabel(l) => Self::MacroLabel(l), + PushTarget::MacroVar(v) => Self::MacroVar(v), + PushTarget::Constant(c) => Self::Constant(c), + } + } +} + /// The target of a `PUSH` operation. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub(crate) enum PushTarget { diff --git a/evm/src/cpu/kernel/constants.rs b/evm/src/cpu/kernel/constants/mod.rs similarity index 83% rename from evm/src/cpu/kernel/constants.rs rename to evm/src/cpu/kernel/constants/mod.rs index 98fe57c6..2694b82a 100644 --- a/evm/src/cpu/kernel/constants.rs +++ b/evm/src/cpu/kernel/constants/mod.rs @@ -4,17 +4,23 @@ use ethereum_types::U256; use hex_literal::hex; use crate::cpu::decode::invalid_opcodes_user; +use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::cpu::kernel::context_metadata::ContextMetadata; use crate::cpu::kernel::global_metadata::GlobalMetadata; use crate::cpu::kernel::txn_fields::NormalizedTxnField; use crate::memory::segments::Segment; +pub(crate) mod trie_type; + /// Constants that are accessible to our kernel assembly code. pub fn evm_constants() -> HashMap { let mut c = HashMap::new(); for (name, value) in EC_CONSTANTS { c.insert(name.into(), U256::from_big_endian(&value)); } + for (name, value) in HASH_CONSTANTS { + c.insert(name.into(), U256::from_big_endian(&value)); + } for (name, value) in GAS_CONSTANTS { c.insert(name.into(), U256::from(value)); } @@ -30,6 +36,9 @@ pub fn evm_constants() -> HashMap { for txn_field in ContextMetadata::all() { c.insert(txn_field.var_name().into(), (txn_field as u32).into()); } + for trie_type in PartialTrieType::all() { + c.insert(trie_type.var_name().into(), (trie_type as u32).into()); + } c.insert( "INVALID_OPCODES_USER".into(), U256::from_little_endian(&invalid_opcodes_user()), @@ -37,6 +46,14 @@ pub fn evm_constants() -> HashMap { c } +const HASH_CONSTANTS: [(&str, [u8; 32]); 1] = [ + // Hash of an empty node: keccak(rlp.encode(b'')).hex() + ( + "EMPTY_NODE_HASH", + hex!("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421"), + ), +]; + const EC_CONSTANTS: [(&str, [u8; 32]); 3] = [ ( "BN_BASE", diff --git a/evm/src/cpu/kernel/constants/trie_type.rs b/evm/src/cpu/kernel/constants/trie_type.rs new file mode 100644 index 00000000..08fd8748 --- /dev/null +++ b/evm/src/cpu/kernel/constants/trie_type.rs @@ -0,0 +1,44 @@ +use eth_trie_utils::partial_trie::PartialTrie; + +pub(crate) enum PartialTrieType { + Empty = 0, + Hash = 1, + Branch = 2, + Extension = 3, + Leaf = 4, +} + +impl PartialTrieType { + pub(crate) const COUNT: usize = 5; + + pub(crate) fn of(trie: &PartialTrie) -> Self { + match trie { + PartialTrie::Empty => Self::Empty, + PartialTrie::Hash(_) => Self::Hash, + PartialTrie::Branch { .. } => Self::Branch, + PartialTrie::Extension { .. } => Self::Extension, + PartialTrie::Leaf { .. } => Self::Leaf, + } + } + + pub(crate) fn all() -> [Self; Self::COUNT] { + [ + Self::Empty, + Self::Hash, + Self::Branch, + Self::Extension, + Self::Leaf, + ] + } + + /// The variable name that gets passed into kernel assembly code. + pub(crate) fn var_name(&self) -> &'static str { + match self { + Self::Empty => "MPT_NODE_EMPTY", + Self::Hash => "MPT_NODE_HASH", + Self::Branch => "MPT_NODE_BRANCH", + Self::Extension => "MPT_NODE_EXTENSION", + Self::Leaf => "MPT_NODE_LEAF", + } + } +} diff --git a/evm/src/cpu/kernel/evm_asm.pest b/evm/src/cpu/kernel/evm_asm.pest index 8ea7de4b..9b8721f4 100644 --- a/evm/src/cpu/kernel/evm_asm.pest +++ b/evm/src/cpu/kernel/evm_asm.pest @@ -17,17 +17,23 @@ constant = ${ "@" ~ identifier } item = { macro_def | macro_call | repeat | stack | global_label_decl | local_label_decl | macro_label_decl | bytes_item | push_instruction | prover_input_instruction | nullary_instruction } macro_def = { ^"%macro" ~ identifier ~ paramlist? ~ item* ~ ^"%endmacro" } -macro_call = ${ "%" ~ !(^"macro" | ^"endmacro" | ^"rep" | ^"endrep" | ^"stack") ~ identifier ~ macro_arglist? } +macro_call = ${ "%" ~ !((^"macro" | ^"endmacro" | ^"rep" | ^"endrep" | ^"stack") ~ !identifier_char) ~ identifier ~ macro_arglist? } repeat = { ^"%rep" ~ literal ~ item* ~ ^"%endrep" } paramlist = { "(" ~ identifier ~ ("," ~ identifier)* ~ ")" } macro_arglist = !{ "(" ~ push_target ~ ("," ~ push_target)* ~ ")" } -stack = { ^"%stack" ~ paramlist ~ "->" ~ stack_replacements } + +stack = { ^"%stack" ~ stack_placeholders ~ "->" ~ stack_replacements } +stack_placeholders = { "(" ~ (stack_placeholder ~ ("," ~ stack_placeholder)*)? ~ ")" } +stack_placeholder = { stack_block | identifier } +stack_block = { identifier ~ ":" ~ literal_decimal } stack_replacements = { "(" ~ stack_replacement ~ ("," ~ stack_replacement)* ~ ")" } stack_replacement = { literal | identifier | constant | macro_label | variable } + global_label_decl = ${ ^"GLOBAL " ~ identifier ~ ":" } local_label_decl = ${ identifier ~ ":" } macro_label_decl = ${ "%%" ~ identifier ~ ":" } macro_label = ${ "%%" ~ identifier } + bytes_item = { ^"BYTES " ~ literal ~ ("," ~ literal)* } push_instruction = { ^"PUSH " ~ push_target } push_target = { literal | identifier | macro_label | variable | constant } diff --git a/evm/src/cpu/kernel/global_metadata.rs b/evm/src/cpu/kernel/global_metadata.rs index ddc3c839..f3f34e7a 100644 --- a/evm/src/cpu/kernel/global_metadata.rs +++ b/evm/src/cpu/kernel/global_metadata.rs @@ -24,13 +24,13 @@ pub(crate) enum GlobalMetadata { // The root digests of each Merkle trie before these transactions. StateTrieRootDigestBefore = 8, - TransactionsTrieRootDigestBefore = 9, - ReceiptsTrieRootDigestBefore = 10, + TransactionTrieRootDigestBefore = 9, + ReceiptTrieRootDigestBefore = 10, // The root digests of each Merkle trie after these transactions. StateTrieRootDigestAfter = 11, - TransactionsTrieRootDigestAfter = 12, - ReceiptsTrieRootDigestAfter = 13, + TransactionTrieRootDigestAfter = 12, + ReceiptTrieRootDigestAfter = 13, } impl GlobalMetadata { @@ -47,11 +47,11 @@ impl GlobalMetadata { Self::ReceiptTrieRoot, Self::NumStorageTries, Self::StateTrieRootDigestBefore, - Self::TransactionsTrieRootDigestBefore, - Self::ReceiptsTrieRootDigestBefore, + Self::TransactionTrieRootDigestBefore, + Self::ReceiptTrieRootDigestBefore, Self::StateTrieRootDigestAfter, - Self::TransactionsTrieRootDigestAfter, - Self::ReceiptsTrieRootDigestAfter, + Self::TransactionTrieRootDigestAfter, + Self::ReceiptTrieRootDigestAfter, ] } @@ -67,18 +67,18 @@ impl GlobalMetadata { GlobalMetadata::ReceiptTrieRoot => "GLOBAL_METADATA_RECEIPT_TRIE_ROOT", GlobalMetadata::NumStorageTries => "GLOBAL_METADATA_NUM_STORAGE_TRIES", GlobalMetadata::StateTrieRootDigestBefore => "GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE", - GlobalMetadata::TransactionsTrieRootDigestBefore => { - "GLOBAL_METADATA_TXNS_TRIE_DIGEST_BEFORE" + GlobalMetadata::TransactionTrieRootDigestBefore => { + "GLOBAL_METADATA_TXN_TRIE_DIGEST_BEFORE" } - GlobalMetadata::ReceiptsTrieRootDigestBefore => { - "GLOBAL_METADATA_RECEIPTS_TRIE_DIGEST_BEFORE" + GlobalMetadata::ReceiptTrieRootDigestBefore => { + "GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_BEFORE" } GlobalMetadata::StateTrieRootDigestAfter => "GLOBAL_METADATA_STATE_TRIE_DIGEST_AFTER", - GlobalMetadata::TransactionsTrieRootDigestAfter => { - "GLOBAL_METADATA_TXNS_TRIE_DIGEST_AFTER" + GlobalMetadata::TransactionTrieRootDigestAfter => { + "GLOBAL_METADATA_TXN_TRIE_DIGEST_AFTER" } - GlobalMetadata::ReceiptsTrieRootDigestAfter => { - "GLOBAL_METADATA_RECEIPTS_TRIE_DIGEST_AFTER" + GlobalMetadata::ReceiptTrieRootDigestAfter => { + "GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_AFTER" } } } diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 17be0523..45211848 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -1,16 +1,22 @@ use std::collections::HashMap; -use anyhow::{anyhow, bail}; -use ethereum_types::{BigEndianHash, U256, U512}; +use anyhow::{anyhow, bail, ensure}; +use ethereum_types::{U256, U512}; use keccak_hash::keccak; +use plonky2::field::goldilocks_field::GoldilocksField; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::assembler::Kernel; -use crate::cpu::kernel::prover_input::ProverInputFn; +use crate::cpu::kernel::global_metadata::GlobalMetadata; use crate::cpu::kernel::txn_fields::NormalizedTxnField; use crate::generation::memory::{MemoryContextState, MemorySegmentState}; +use crate::generation::prover_input::ProverInputFn; +use crate::generation::state::GenerationState; +use crate::generation::GenerationInputs; use crate::memory::segments::Segment; +type F = GoldilocksField; + /// Halt interpreter execution whenever a jump to this offset is done. const DEFAULT_HALT_OFFSET: usize = 0xdeadbeef; @@ -41,10 +47,21 @@ impl InterpreterMemory { impl InterpreterMemory { fn mload_general(&self, context: usize, segment: Segment, offset: usize) -> U256 { - self.context_memory[context].segments[segment as usize].get(offset) + let value = self.context_memory[context].segments[segment as usize].get(offset); + assert!( + value.bits() <= segment.bit_range(), + "Value read from memory exceeds expected range of {:?} segment", + segment + ); + value } fn mstore_general(&mut self, context: usize, segment: Segment, offset: usize, value: U256) { + assert!( + value.bits() <= segment.bit_range(), + "Value written to memory exceeds expected range of {:?} segment", + segment + ); self.context_memory[context].segments[segment as usize].set(offset, value) } } @@ -52,11 +69,11 @@ impl InterpreterMemory { pub struct Interpreter<'a> { kernel_mode: bool, jumpdests: Vec, - offset: usize, + pub(crate) offset: usize, context: usize, pub(crate) memory: InterpreterMemory, + pub(crate) generation_state: GenerationState, prover_inputs_map: &'a HashMap, - prover_inputs: Vec, pub(crate) halt_offsets: Vec, running: bool, } @@ -107,15 +124,16 @@ impl<'a> Interpreter<'a> { jumpdests: find_jumpdests(code), offset: initial_offset, memory: InterpreterMemory::with_code_and_stack(code, initial_stack), + generation_state: GenerationState::new(GenerationInputs::default()), prover_inputs_map: prover_inputs, - prover_inputs: Vec::new(), context: 0, halt_offsets: vec![DEFAULT_HALT_OFFSET], - running: true, + running: false, } } pub(crate) fn run(&mut self) -> anyhow::Result<()> { + self.running = true; while self.running { self.run_opcode()?; } @@ -146,8 +164,16 @@ impl<'a> Interpreter<'a> { &self.memory.context_memory[0].segments[Segment::TxnData as usize].content } + pub(crate) fn get_global_metadata_field(&self, field: GlobalMetadata) -> U256 { + self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize].get(field as usize) + } + + pub(crate) fn get_trie_data(&self) -> &[U256] { + &self.memory.context_memory[0].segments[Segment::TrieData as usize].content + } + pub(crate) fn get_rlp_memory(&self) -> Vec { - self.memory.context_memory[self.context].segments[Segment::RlpRaw as usize] + self.memory.context_memory[0].segments[Segment::RlpRaw as usize] .content .iter() .map(|x| x.as_u32() as u8) @@ -155,7 +181,7 @@ impl<'a> Interpreter<'a> { } pub(crate) fn set_rlp_memory(&mut self, rlp: Vec) { - self.memory.context_memory[self.context].segments[Segment::RlpRaw as usize].content = + self.memory.context_memory[0].segments[Segment::RlpRaw as usize].content = rlp.into_iter().map(U256::from).collect(); } @@ -171,7 +197,7 @@ impl<'a> Interpreter<'a> { &mut self.memory.context_memory[self.context].segments[Segment::Stack as usize].content } - fn push(&mut self, x: U256) { + pub(crate) fn push(&mut self, x: U256) { self.stack_mut().push(x); } @@ -187,99 +213,100 @@ impl<'a> Interpreter<'a> { let opcode = self.code().get(self.offset).byte(0); self.incr(1); match opcode { - 0x00 => self.run_stop(), // "STOP", - 0x01 => self.run_add(), // "ADD", - 0x02 => self.run_mul(), // "MUL", - 0x03 => self.run_sub(), // "SUB", - 0x04 => self.run_div(), // "DIV", - 0x05 => todo!(), // "SDIV", - 0x06 => self.run_mod(), // "MOD", - 0x07 => todo!(), // "SMOD", - 0x08 => self.run_addmod(), // "ADDMOD", - 0x09 => self.run_mulmod(), // "MULMOD", - 0x0a => self.run_exp(), // "EXP", - 0x0b => todo!(), // "SIGNEXTEND", - 0x10 => self.run_lt(), // "LT", - 0x11 => self.run_gt(), // "GT", - 0x12 => todo!(), // "SLT", - 0x13 => todo!(), // "SGT", - 0x14 => self.run_eq(), // "EQ", - 0x15 => self.run_iszero(), // "ISZERO", - 0x16 => self.run_and(), // "AND", - 0x17 => self.run_or(), // "OR", - 0x18 => self.run_xor(), // "XOR", - 0x19 => self.run_not(), // "NOT", - 0x1a => self.run_byte(), // "BYTE", - 0x1b => self.run_shl(), // "SHL", - 0x1c => todo!(), // "SHR", - 0x1d => todo!(), // "SAR", - 0x20 => self.run_keccak256(), // "KECCAK256", - 0x30 => todo!(), // "ADDRESS", - 0x31 => todo!(), // "BALANCE", - 0x32 => todo!(), // "ORIGIN", - 0x33 => todo!(), // "CALLER", - 0x34 => todo!(), // "CALLVALUE", - 0x35 => todo!(), // "CALLDATALOAD", - 0x36 => todo!(), // "CALLDATASIZE", - 0x37 => todo!(), // "CALLDATACOPY", - 0x38 => todo!(), // "CODESIZE", - 0x39 => todo!(), // "CODECOPY", - 0x3a => todo!(), // "GASPRICE", - 0x3b => todo!(), // "EXTCODESIZE", - 0x3c => todo!(), // "EXTCODECOPY", - 0x3d => todo!(), // "RETURNDATASIZE", - 0x3e => todo!(), // "RETURNDATACOPY", - 0x3f => todo!(), // "EXTCODEHASH", - 0x40 => todo!(), // "BLOCKHASH", - 0x41 => todo!(), // "COINBASE", - 0x42 => todo!(), // "TIMESTAMP", - 0x43 => todo!(), // "NUMBER", - 0x44 => todo!(), // "DIFFICULTY", - 0x45 => todo!(), // "GASLIMIT", - 0x46 => todo!(), // "CHAINID", - 0x48 => todo!(), // "BASEFEE", - 0x49 => self.run_prover_input()?, // "PROVER_INPUT", - 0x50 => self.run_pop(), // "POP", - 0x51 => self.run_mload(), // "MLOAD", - 0x52 => self.run_mstore(), // "MSTORE", - 0x53 => self.run_mstore8(), // "MSTORE8", - 0x54 => todo!(), // "SLOAD", - 0x55 => todo!(), // "SSTORE", - 0x56 => self.run_jump(), // "JUMP", - 0x57 => self.run_jumpi(), // "JUMPI", - 0x58 => todo!(), // "GETPC", - 0x59 => todo!(), // "MSIZE", - 0x5a => todo!(), // "GAS", - 0x5b => (), // "JUMPDEST", - 0x5c => todo!(), // "GET_STATE_ROOT", - 0x5d => todo!(), // "SET_STATE_ROOT", - 0x5e => todo!(), // "GET_RECEIPT_ROOT", - 0x5f => todo!(), // "SET_RECEIPT_ROOT", - x if (0x60..0x80).contains(&x) => self.run_push(x - 0x5f), // "PUSH" - x if (0x80..0x90).contains(&x) => self.run_dup(x - 0x7f), // "DUP" - x if (0x90..0xa0).contains(&x) => self.run_swap(x - 0x8f), // "SWAP" - 0xa0 => todo!(), // "LOG0", - 0xa1 => todo!(), // "LOG1", - 0xa2 => todo!(), // "LOG2", - 0xa3 => todo!(), // "LOG3", - 0xa4 => todo!(), // "LOG4", - 0xa5 => bail!("Executed PANIC"), // "PANIC", - 0xf0 => todo!(), // "CREATE", - 0xf1 => todo!(), // "CALL", - 0xf2 => todo!(), // "CALLCODE", - 0xf3 => todo!(), // "RETURN", - 0xf4 => todo!(), // "DELEGATECALL", - 0xf5 => todo!(), // "CREATE2", - 0xf6 => self.run_get_context(), // "GET_CONTEXT", - 0xf7 => self.run_set_context(), // "SET_CONTEXT", - 0xf8 => todo!(), // "CONSUME_GAS", - 0xf9 => todo!(), // "EXIT_KERNEL", - 0xfa => todo!(), // "STATICCALL", - 0xfb => self.run_mload_general(), // "MLOAD_GENERAL", - 0xfc => self.run_mstore_general(), // "MSTORE_GENERAL", - 0xfd => todo!(), // "REVERT", - 0xfe => bail!("Executed INVALID"), // "INVALID", - 0xff => todo!(), // "SELFDESTRUCT", + 0x00 => self.run_stop(), // "STOP", + 0x01 => self.run_add(), // "ADD", + 0x02 => self.run_mul(), // "MUL", + 0x03 => self.run_sub(), // "SUB", + 0x04 => self.run_div(), // "DIV", + 0x05 => todo!(), // "SDIV", + 0x06 => self.run_mod(), // "MOD", + 0x07 => todo!(), // "SMOD", + 0x08 => self.run_addmod(), // "ADDMOD", + 0x09 => self.run_mulmod(), // "MULMOD", + 0x0a => self.run_exp(), // "EXP", + 0x0b => todo!(), // "SIGNEXTEND", + 0x10 => self.run_lt(), // "LT", + 0x11 => self.run_gt(), // "GT", + 0x12 => todo!(), // "SLT", + 0x13 => todo!(), // "SGT", + 0x14 => self.run_eq(), // "EQ", + 0x15 => self.run_iszero(), // "ISZERO", + 0x16 => self.run_and(), // "AND", + 0x17 => self.run_or(), // "OR", + 0x18 => self.run_xor(), // "XOR", + 0x19 => self.run_not(), // "NOT", + 0x1a => self.run_byte(), // "BYTE", + 0x1b => self.run_shl(), // "SHL", + 0x1c => self.run_shr(), // "SHR", + 0x1d => todo!(), // "SAR", + 0x20 => self.run_keccak256(), // "KECCAK256", + 0x21 => self.run_keccak_general(), // "KECCAK_GENERAL", + 0x30 => todo!(), // "ADDRESS", + 0x31 => todo!(), // "BALANCE", + 0x32 => todo!(), // "ORIGIN", + 0x33 => todo!(), // "CALLER", + 0x34 => todo!(), // "CALLVALUE", + 0x35 => todo!(), // "CALLDATALOAD", + 0x36 => todo!(), // "CALLDATASIZE", + 0x37 => todo!(), // "CALLDATACOPY", + 0x38 => todo!(), // "CODESIZE", + 0x39 => todo!(), // "CODECOPY", + 0x3a => todo!(), // "GASPRICE", + 0x3b => todo!(), // "EXTCODESIZE", + 0x3c => todo!(), // "EXTCODECOPY", + 0x3d => todo!(), // "RETURNDATASIZE", + 0x3e => todo!(), // "RETURNDATACOPY", + 0x3f => todo!(), // "EXTCODEHASH", + 0x40 => todo!(), // "BLOCKHASH", + 0x41 => todo!(), // "COINBASE", + 0x42 => todo!(), // "TIMESTAMP", + 0x43 => todo!(), // "NUMBER", + 0x44 => todo!(), // "DIFFICULTY", + 0x45 => todo!(), // "GASLIMIT", + 0x46 => todo!(), // "CHAINID", + 0x48 => todo!(), // "BASEFEE", + 0x49 => self.run_prover_input()?, // "PROVER_INPUT", + 0x50 => self.run_pop(), // "POP", + 0x51 => self.run_mload(), // "MLOAD", + 0x52 => self.run_mstore(), // "MSTORE", + 0x53 => self.run_mstore8(), // "MSTORE8", + 0x54 => todo!(), // "SLOAD", + 0x55 => todo!(), // "SSTORE", + 0x56 => self.run_jump(), // "JUMP", + 0x57 => self.run_jumpi(), // "JUMPI", + 0x58 => todo!(), // "GETPC", + 0x59 => self.run_msize(), // "MSIZE", + 0x5a => todo!(), // "GAS", + 0x5b => self.run_jumpdest(), // "JUMPDEST", + 0x5c => todo!(), // "GET_STATE_ROOT", + 0x5d => todo!(), // "SET_STATE_ROOT", + 0x5e => todo!(), // "GET_RECEIPT_ROOT", + 0x5f => todo!(), // "SET_RECEIPT_ROOT", + x if (0x60..0x80).contains(&x) => self.run_push(x - 0x5f), // "PUSH" + x if (0x80..0x90).contains(&x) => self.run_dup(x - 0x7f), // "DUP" + x if (0x90..0xa0).contains(&x) => self.run_swap(x - 0x8f)?, // "SWAP" + 0xa0 => todo!(), // "LOG0", + 0xa1 => todo!(), // "LOG1", + 0xa2 => todo!(), // "LOG2", + 0xa3 => todo!(), // "LOG3", + 0xa4 => todo!(), // "LOG4", + 0xa5 => bail!("Executed PANIC"), // "PANIC", + 0xf0 => todo!(), // "CREATE", + 0xf1 => todo!(), // "CALL", + 0xf2 => todo!(), // "CALLCODE", + 0xf3 => todo!(), // "RETURN", + 0xf4 => todo!(), // "DELEGATECALL", + 0xf5 => todo!(), // "CREATE2", + 0xf6 => self.run_get_context(), // "GET_CONTEXT", + 0xf7 => self.run_set_context(), // "SET_CONTEXT", + 0xf8 => todo!(), // "CONSUME_GAS", + 0xf9 => todo!(), // "EXIT_KERNEL", + 0xfa => todo!(), // "STATICCALL", + 0xfb => self.run_mload_general(), // "MLOAD_GENERAL", + 0xfc => self.run_mstore_general(), // "MSTORE_GENERAL", + 0xfd => todo!(), // "REVERT", + 0xfe => bail!("Executed INVALID"), // "INVALID", + 0xff => todo!(), // "SELFDESTRUCT", _ => bail!("Unrecognized opcode {}.", opcode), }; Ok(()) @@ -412,6 +439,12 @@ impl<'a> Interpreter<'a> { self.push(x << shift); } + fn run_shr(&mut self) { + let shift = self.pop(); + let x = self.pop(); + self.push(x >> shift); + } + fn run_keccak256(&mut self) { let offset = self.pop().as_usize(); let size = self.pop().as_usize(); @@ -423,7 +456,19 @@ impl<'a> Interpreter<'a> { }) .collect::>(); let hash = keccak(bytes); - self.push(hash.into_uint()); + self.push(U256::from_big_endian(hash.as_bytes())); + } + + fn run_keccak_general(&mut self) { + let context = self.pop().as_usize(); + let segment = Segment::all()[self.pop().as_usize()]; + let offset = self.pop().as_usize(); + let size = self.pop().as_usize(); + let bytes = (offset..offset + size) + .map(|i| self.memory.mload_general(context, segment, i).byte(0)) + .collect::>(); + let hash = keccak(bytes); + self.push(U256::from_big_endian(hash.as_bytes())); } fn run_prover_input(&mut self) -> anyhow::Result<()> { @@ -431,9 +476,9 @@ impl<'a> Interpreter<'a> { .prover_inputs_map .get(&(self.offset - 1)) .ok_or_else(|| anyhow!("Offset not in prover inputs."))?; - let output = prover_input_fn.run(self.stack()); + let stack = self.stack().to_vec(); + let output = self.generation_state.prover_input(&stack, prover_input_fn); self.push(output); - self.prover_inputs.push(output); Ok(()) } @@ -490,6 +535,18 @@ impl<'a> Interpreter<'a> { } } + fn run_msize(&mut self) { + let num_bytes = self.memory.context_memory[self.context].segments + [Segment::MainMemory as usize] + .content + .len(); + self.push(U256::from(num_bytes)); + } + + fn run_jumpdest(&mut self) { + assert!(!self.kernel_mode, "JUMPDEST is not needed in kernel code"); + } + fn jump_to(&mut self, offset: usize) { // The JUMPDEST rule is not enforced in kernel mode. if !self.kernel_mode && self.jumpdests.binary_search(&offset).is_err() { @@ -513,9 +570,11 @@ impl<'a> Interpreter<'a> { self.push(self.stack()[self.stack().len() - n as usize]); } - fn run_swap(&mut self, n: u8) { + fn run_swap(&mut self, n: u8) -> anyhow::Result<()> { let len = self.stack().len(); + ensure!(len > n as usize); self.stack_mut().swap(len - 1, len - n as usize - 1); + Ok(()) } fn run_get_context(&mut self) { @@ -532,7 +591,6 @@ impl<'a> Interpreter<'a> { let segment = Segment::all()[self.pop().as_usize()]; let offset = self.pop().as_usize(); let value = self.memory.mload_general(context, segment, offset); - assert!(value.bits() <= segment.bit_range()); self.push(value); } @@ -541,7 +599,6 @@ impl<'a> Interpreter<'a> { let segment = Segment::all()[self.pop().as_usize()]; let offset = self.pop().as_usize(); let value = self.pop(); - assert!(value.bits() <= segment.bit_range()); self.memory.mstore_general(context, segment, offset, value); } } diff --git a/evm/src/cpu/kernel/mod.rs b/evm/src/cpu/kernel/mod.rs index ef5a9ba0..e14a6cd6 100644 --- a/evm/src/cpu/kernel/mod.rs +++ b/evm/src/cpu/kernel/mod.rs @@ -1,7 +1,7 @@ pub mod aggregator; pub mod assembler; mod ast; -mod constants; +pub(crate) mod constants; pub(crate) mod context_metadata; mod cost_estimator; pub(crate) mod global_metadata; @@ -9,7 +9,6 @@ pub(crate) mod keccak_util; mod opcodes; mod optimizer; mod parser; -pub mod prover_input; pub mod stack; mod txn_fields; mod utils; diff --git a/evm/src/cpu/kernel/opcodes.rs b/evm/src/cpu/kernel/opcodes.rs index 69ee13fe..2325c53a 100644 --- a/evm/src/cpu/kernel/opcodes.rs +++ b/evm/src/cpu/kernel/opcodes.rs @@ -35,6 +35,7 @@ pub(crate) fn get_opcode(mnemonic: &str) -> u8 { "SHR" => 0x1c, "SAR" => 0x1d, "KECCAK256" => 0x20, + "KECCAK_GENERAL" => 0x21, "ADDRESS" => 0x30, "BALANCE" => 0x31, "ORIGIN" => 0x32, diff --git a/evm/src/cpu/kernel/optimizer.rs b/evm/src/cpu/kernel/optimizer.rs index e23bf520..e2504203 100644 --- a/evm/src/cpu/kernel/optimizer.rs +++ b/evm/src/cpu/kernel/optimizer.rs @@ -80,9 +80,9 @@ fn no_op_jumps(code: &mut Vec) { replace_windows(code, |window| { if let [Push(Label(l)), StandardOp(jump), decl] = window && &jump == "JUMP" - && (decl == LocalLabelDeclaration(l.clone()) || decl == GlobalLabelDeclaration(l.clone())) + && (decl == LocalLabelDeclaration(l.clone()) || decl == GlobalLabelDeclaration(l)) { - Some(vec![LocalLabelDeclaration(l)]) + Some(vec![decl]) } else { None } diff --git a/evm/src/cpu/kernel/parser.rs b/evm/src/cpu/kernel/parser.rs index 35bde4b6..fd762eae 100644 --- a/evm/src/cpu/kernel/parser.rs +++ b/evm/src/cpu/kernel/parser.rs @@ -4,6 +4,7 @@ use ethereum_types::U256; use pest::iterators::Pair; use pest::Parser; +use super::ast::StackPlaceholder; use crate::cpu::kernel::ast::{File, Item, PushTarget, StackReplacement}; /// Parses EVM assembly code. @@ -98,20 +99,35 @@ fn parse_stack(item: Pair) -> Item { assert_eq!(item.as_rule(), Rule::stack); let mut inner = item.into_inner(); - let params = inner.next().unwrap(); - assert_eq!(params.as_rule(), Rule::paramlist); + let placeholders = inner.next().unwrap(); + assert_eq!(placeholders.as_rule(), Rule::stack_placeholders); let replacements = inner.next().unwrap(); assert_eq!(replacements.as_rule(), Rule::stack_replacements); - let params = params + let placeholders = placeholders .into_inner() - .map(|param| param.as_str().to_string()) + .map(parse_stack_placeholder) .collect(); let replacements = replacements .into_inner() .map(parse_stack_replacement) .collect(); - Item::StackManipulation(params, replacements) + Item::StackManipulation(placeholders, replacements) +} + +fn parse_stack_placeholder(target: Pair) -> StackPlaceholder { + assert_eq!(target.as_rule(), Rule::stack_placeholder); + let inner = target.into_inner().next().unwrap(); + match inner.as_rule() { + Rule::identifier => StackPlaceholder::Identifier(inner.as_str().into()), + Rule::stack_block => { + let mut block = inner.into_inner(); + let identifier = block.next().unwrap().as_str(); + let length = block.next().unwrap().as_str().parse().unwrap(); + StackPlaceholder::Block(identifier.to_string(), length) + } + _ => panic!("Unexpected {:?}", inner.as_rule()), + } } fn parse_stack_replacement(target: Pair) -> StackReplacement { diff --git a/evm/src/cpu/kernel/stack/stack_manipulation.rs b/evm/src/cpu/kernel/stack/stack_manipulation.rs index 9f685953..ebc54af1 100644 --- a/evm/src/cpu/kernel/stack/stack_manipulation.rs +++ b/evm/src/cpu/kernel/stack/stack_manipulation.rs @@ -1,13 +1,13 @@ use std::cmp::Ordering; use std::collections::hash_map::Entry::{Occupied, Vacant}; -use std::collections::{BinaryHeap, HashMap}; +use std::collections::{BinaryHeap, HashMap, HashSet}; use std::hash::Hash; use itertools::Itertools; use crate::cpu::columns::NUM_CPU_COLUMNS; use crate::cpu::kernel::assembler::BYTES_PER_OFFSET; -use crate::cpu::kernel::ast::{Item, PushTarget, StackReplacement}; +use crate::cpu::kernel::ast::{Item, PushTarget, StackPlaceholder, StackReplacement}; use crate::cpu::kernel::stack::permutations::{get_stack_ops_for_perm, is_permutation}; use crate::cpu::kernel::stack::stack_manipulation::StackOp::Pop; use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes; @@ -25,25 +25,51 @@ pub(crate) fn expand_stack_manipulation(body: Vec) -> Vec { expanded } -fn expand(names: Vec, replacements: Vec) -> Vec { +fn expand(names: Vec, replacements: Vec) -> Vec { + let mut stack_blocks = HashMap::new(); + let mut stack_names = HashSet::new(); + let mut src = names .iter() .cloned() - .map(StackItem::NamedItem) + .flat_map(|item| match item { + StackPlaceholder::Identifier(name) => { + stack_names.insert(name.clone()); + vec![StackItem::NamedItem(name)] + } + StackPlaceholder::Block(name, n) => { + stack_blocks.insert(name.clone(), n); + (0..n) + .map(|i| { + let literal_name = format!("block_{}_{}", name, i); + StackItem::NamedItem(literal_name) + }) + .collect_vec() + } + }) .collect_vec(); let mut dst = replacements .into_iter() - .map(|item| match item { + .flat_map(|item| match item { + StackReplacement::Literal(n) => vec![StackItem::PushTarget(PushTarget::Literal(n))], StackReplacement::Identifier(name) => { // May be either a named item or a label. Named items have precedence. - if names.contains(&name) { - StackItem::NamedItem(name) + if stack_blocks.contains_key(&name) { + let n = *stack_blocks.get(&name).unwrap(); + (0..n) + .map(|i| { + let literal_name = format!("block_{}_{}", name, i); + StackItem::NamedItem(literal_name) + }) + .collect_vec() + } else if stack_names.contains(&name) { + vec![StackItem::NamedItem(name)] } else { - StackItem::PushTarget(PushTarget::Label(name)) + vec![StackItem::PushTarget(PushTarget::Label(name))] } } - StackReplacement::Literal(n) => StackItem::PushTarget(PushTarget::Literal(n)), + StackReplacement::Label(name) => vec![StackItem::PushTarget(PushTarget::Label(name))], StackReplacement::MacroLabel(_) | StackReplacement::MacroVar(_) | StackReplacement::Constant(_) => { diff --git a/evm/src/cpu/kernel/tests/mod.rs b/evm/src/cpu/kernel/tests/mod.rs index 925db56f..a9c8c08c 100644 --- a/evm/src/cpu/kernel/tests/mod.rs +++ b/evm/src/cpu/kernel/tests/mod.rs @@ -2,6 +2,7 @@ mod core; mod curve_ops; mod ecrecover; mod exp; +mod mpt; mod packing; mod rlp; mod transaction_parsing; diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs new file mode 100644 index 00000000..5f212e3c --- /dev/null +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -0,0 +1,62 @@ +use anyhow::Result; +use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use ethereum_types::{BigEndianHash, H256, U256}; +use hex_literal::hex; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::interpreter::Interpreter; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; +use crate::generation::TrieInputs; + +#[test] +fn mpt_hash() -> Result<()> { + let nonce = U256::from(1111); + let balance = U256::from(2222); + let storage_root = U256::from(3333); + let code_hash = U256::from(4444); + + let account = &[nonce, balance, storage_root, code_hash]; + let account_rlp = rlp::encode_list(account); + + // TODO: Try this more "advanced" trie. + // let state_trie = state_trie_ext_to_account_leaf(account_rlp.to_vec()); + let state_trie = PartialTrie::Leaf { + nibbles: Nibbles { + count: 3, + packed: 0xABC.into(), + }, + value: account_rlp.to_vec(), + }; + // TODO: It seems like calc_hash isn't giving the expected hash yet, so for now, I'm using a + // hardcoded hash obtained from py-evm. + // let state_trie_hash = state_trie.calc_hash(); + let state_trie_hash = + hex!("e38d6053838fe057c865ec0c74a8f0de21865d74fac222a2d3241fe57c9c3a0f").into(); + + let trie_inputs = TrieInputs { + state_trie, + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; + + let initial_stack = vec![0xdeadbeefu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + // Now, execute mpt_hash_state_trie. + interpreter.offset = mpt_hash_state_trie; + interpreter.push(0xDEADBEEFu32.into()); + interpreter.run()?; + + assert_eq!(interpreter.stack().len(), 1); + let hash = H256::from_uint(&interpreter.stack()[0]); + assert_eq!(hash, state_trie_hash); + + Ok(()) +} diff --git a/evm/src/cpu/kernel/tests/mpt/hex_prefix.rs b/evm/src/cpu/kernel/tests/mpt/hex_prefix.rs new file mode 100644 index 00000000..c13b8122 --- /dev/null +++ b/evm/src/cpu/kernel/tests/mpt/hex_prefix.rs @@ -0,0 +1,87 @@ +use anyhow::Result; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::interpreter::Interpreter; + +#[test] +fn hex_prefix_even_nonterminated() -> Result<()> { + let hex_prefix = KERNEL.global_labels["hex_prefix_rlp"]; + + let retdest = 0xDEADBEEFu32.into(); + let terminated = 0.into(); + let packed_nibbles = 0xABCDEF.into(); + let num_nibbles = 6.into(); + let rlp_pos = 0.into(); + let initial_stack = vec![retdest, terminated, packed_nibbles, num_nibbles, rlp_pos]; + let mut interpreter = Interpreter::new_with_kernel(hex_prefix, initial_stack); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![5.into()]); + + assert_eq!( + interpreter.get_rlp_memory(), + vec![ + 0x80 + 4, // prefix + 0, // neither flag is set + 0xAB, + 0xCD, + 0xEF + ] + ); + + Ok(()) +} + +#[test] +fn hex_prefix_odd_terminated() -> Result<()> { + let hex_prefix = KERNEL.global_labels["hex_prefix_rlp"]; + + let retdest = 0xDEADBEEFu32.into(); + let terminated = 1.into(); + let packed_nibbles = 0xABCDE.into(); + let num_nibbles = 5.into(); + let rlp_pos = 0.into(); + let initial_stack = vec![retdest, terminated, packed_nibbles, num_nibbles, rlp_pos]; + let mut interpreter = Interpreter::new_with_kernel(hex_prefix, initial_stack); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![4.into()]); + + assert_eq!( + interpreter.get_rlp_memory(), + vec![ + 0x80 + 3, // prefix + (2 + 1) * 16 + 0xA, + 0xBC, + 0xDE, + ] + ); + + Ok(()) +} + +#[test] +fn hex_prefix_odd_terminated_tiny() -> Result<()> { + let hex_prefix = KERNEL.global_labels["hex_prefix_rlp"]; + + let retdest = 0xDEADBEEFu32.into(); + let terminated = 1.into(); + let packed_nibbles = 0xA.into(); + let num_nibbles = 1.into(); + let rlp_pos = 2.into(); + let initial_stack = vec![retdest, terminated, packed_nibbles, num_nibbles, rlp_pos]; + let mut interpreter = Interpreter::new_with_kernel(hex_prefix, initial_stack); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![3.into()]); + + assert_eq!( + interpreter.get_rlp_memory(), + vec![ + // Since rlp_pos = 2, we skipped over the first two bytes. + 0, + 0, + // No length prefix; this tiny string is its own RLP encoding. + (2 + 1) * 16 + 0xA, + ] + ); + + Ok(()) +} diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs new file mode 100644 index 00000000..fbbc690a --- /dev/null +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -0,0 +1,65 @@ +use anyhow::Result; +use ethereum_types::U256; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::constants::trie_type::PartialTrieType; +use crate::cpu::kernel::global_metadata::GlobalMetadata; +use crate::cpu::kernel::interpreter::Interpreter; +use crate::cpu::kernel::tests::mpt::state_trie_ext_to_account_leaf; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; +use crate::generation::TrieInputs; + +#[test] +fn load_all_mpts() -> Result<()> { + let nonce = U256::from(1111); + let balance = U256::from(2222); + let storage_root = U256::from(3333); + let code_hash = U256::from(4444); + + let account_rlp = rlp::encode_list(&[nonce, balance, storage_root, code_hash]); + + let trie_inputs = TrieInputs { + state_trie: state_trie_ext_to_account_leaf(account_rlp.to_vec()), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xdeadbeefu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + let type_empty = U256::from(PartialTrieType::Empty as u32); + let type_extension = U256::from(PartialTrieType::Extension as u32); + let type_leaf = U256::from(PartialTrieType::Leaf as u32); + assert_eq!( + interpreter.get_trie_data(), + vec![ + 0.into(), // First address is unused, so that 0 can be treated as a null pointer. + type_extension, + 3.into(), // 3 nibbles + 0xABC.into(), // key part + 5.into(), // Pointer to the leaf node immediately below. + type_leaf, + 3.into(), // 3 nibbles + 0xDEF.into(), // key part + nonce, + balance, + storage_root, + code_hash, + type_empty, + type_empty, + ] + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::NumStorageTries), + trie_inputs.storage_tries.len().into() + ); + + Ok(()) +} diff --git a/evm/src/cpu/kernel/tests/mpt/mod.rs b/evm/src/cpu/kernel/tests/mpt/mod.rs new file mode 100644 index 00000000..8308962a --- /dev/null +++ b/evm/src/cpu/kernel/tests/mpt/mod.rs @@ -0,0 +1,23 @@ +use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; + +mod hash; +mod hex_prefix; +mod load; +mod read; + +/// A `PartialTrie` where an extension node leads to a leaf node containing an account. +pub(crate) fn state_trie_ext_to_account_leaf(value: Vec) -> PartialTrie { + PartialTrie::Extension { + nibbles: Nibbles { + count: 3, + packed: 0xABC.into(), + }, + child: Box::new(PartialTrie::Leaf { + nibbles: Nibbles { + count: 3, + packed: 0xDEF.into(), + }, + value, + }), + } +} diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs new file mode 100644 index 00000000..e20aa0fe --- /dev/null +++ b/evm/src/cpu/kernel/tests/mpt/read.rs @@ -0,0 +1,51 @@ +use anyhow::Result; +use ethereum_types::U256; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::global_metadata::GlobalMetadata; +use crate::cpu::kernel::interpreter::Interpreter; +use crate::cpu::kernel::tests::mpt::state_trie_ext_to_account_leaf; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; +use crate::generation::TrieInputs; + +#[test] +fn mpt_read() -> Result<()> { + let nonce = U256::from(1111); + let balance = U256::from(2222); + let storage_root = U256::from(3333); + let code_hash = U256::from(4444); + + let account = &[nonce, balance, storage_root, code_hash]; + let account_rlp = rlp::encode_list(account); + + let trie_inputs = TrieInputs { + state_trie: state_trie_ext_to_account_leaf(account_rlp.to_vec()), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + let mpt_read = KERNEL.global_labels["mpt_read"]; + + let initial_stack = vec![0xdeadbeefu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + // Now, execute mpt_read on the state trie. + interpreter.offset = mpt_read; + interpreter.push(0xdeadbeefu32.into()); + interpreter.push(0xABCDEFu64.into()); + interpreter.push(6.into()); + interpreter.push(interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot)); + interpreter.run()?; + + assert_eq!(interpreter.stack().len(), 1); + let result_ptr = interpreter.stack()[0].as_usize(); + let result = &interpreter.get_trie_data()[result_ptr..][..account.len()]; + assert_eq!(result, account); + + Ok(()) +} diff --git a/evm/src/cpu/kernel/tests/rlp.rs b/evm/src/cpu/kernel/tests/rlp/decode.rs similarity index 59% rename from evm/src/cpu/kernel/tests/rlp.rs rename to evm/src/cpu/kernel/tests/rlp/decode.rs index 37949e13..a1ca3609 100644 --- a/evm/src/cpu/kernel/tests/rlp.rs +++ b/evm/src/cpu/kernel/tests/rlp/decode.rs @@ -3,84 +3,6 @@ use anyhow::Result; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; -#[test] -fn test_encode_rlp_scalar_small() -> Result<()> { - let encode_rlp_scalar = KERNEL.global_labels["encode_rlp_scalar"]; - - let retdest = 0xDEADBEEFu32.into(); - let scalar = 42.into(); - let pos = 2.into(); - let initial_stack = vec![retdest, scalar, pos]; - let mut interpreter = Interpreter::new_with_kernel(encode_rlp_scalar, initial_stack); - - interpreter.run()?; - let expected_stack = vec![3.into()]; // pos' = pos + rlp_len = 2 + 1 - let expected_rlp = vec![0, 0, 42]; - assert_eq!(interpreter.stack(), expected_stack); - assert_eq!(interpreter.get_rlp_memory(), expected_rlp); - - Ok(()) -} - -#[test] -fn test_encode_rlp_scalar_medium() -> Result<()> { - let encode_rlp_scalar = KERNEL.global_labels["encode_rlp_scalar"]; - - let retdest = 0xDEADBEEFu32.into(); - let scalar = 0x12345.into(); - let pos = 2.into(); - let initial_stack = vec![retdest, scalar, pos]; - let mut interpreter = Interpreter::new_with_kernel(encode_rlp_scalar, initial_stack); - - interpreter.run()?; - let expected_stack = vec![6.into()]; // pos' = pos + rlp_len = 2 + 4 - let expected_rlp = vec![0, 0, 0x80 + 3, 0x01, 0x23, 0x45]; - assert_eq!(interpreter.stack(), expected_stack); - assert_eq!(interpreter.get_rlp_memory(), expected_rlp); - - Ok(()) -} - -#[test] -fn test_encode_rlp_160() -> Result<()> { - let encode_rlp_160 = KERNEL.global_labels["encode_rlp_160"]; - - let retdest = 0xDEADBEEFu32.into(); - let string = 0x12345.into(); - let pos = 0.into(); - let initial_stack = vec![retdest, string, pos]; - let mut interpreter = Interpreter::new_with_kernel(encode_rlp_160, initial_stack); - - interpreter.run()?; - let expected_stack = vec![(1 + 20).into()]; // pos' - #[rustfmt::skip] - let expected_rlp = vec![0x80 + 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01, 0x23, 0x45]; - assert_eq!(interpreter.stack(), expected_stack); - assert_eq!(interpreter.get_rlp_memory(), expected_rlp); - - Ok(()) -} - -#[test] -fn test_encode_rlp_256() -> Result<()> { - let encode_rlp_256 = KERNEL.global_labels["encode_rlp_256"]; - - let retdest = 0xDEADBEEFu32.into(); - let string = 0x12345.into(); - let pos = 0.into(); - let initial_stack = vec![retdest, string, pos]; - let mut interpreter = Interpreter::new_with_kernel(encode_rlp_256, initial_stack); - - interpreter.run()?; - let expected_stack = vec![(1 + 32).into()]; // pos' - #[rustfmt::skip] - let expected_rlp = vec![0x80 + 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01, 0x23, 0x45]; - assert_eq!(interpreter.stack(), expected_stack); - assert_eq!(interpreter.get_rlp_memory(), expected_rlp); - - Ok(()) -} - #[test] fn test_decode_rlp_string_len_short() -> Result<()> { let decode_rlp_string_len = KERNEL.global_labels["decode_rlp_string_len"]; diff --git a/evm/src/cpu/kernel/tests/rlp/encode.rs b/evm/src/cpu/kernel/tests/rlp/encode.rs new file mode 100644 index 00000000..4e04b248 --- /dev/null +++ b/evm/src/cpu/kernel/tests/rlp/encode.rs @@ -0,0 +1,155 @@ +use anyhow::Result; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::interpreter::Interpreter; + +#[test] +fn test_encode_rlp_scalar_small() -> Result<()> { + let encode_rlp_scalar = KERNEL.global_labels["encode_rlp_scalar"]; + + let retdest = 0xDEADBEEFu32.into(); + let scalar = 42.into(); + let pos = 2.into(); + let initial_stack = vec![retdest, scalar, pos]; + let mut interpreter = Interpreter::new_with_kernel(encode_rlp_scalar, initial_stack); + + interpreter.run()?; + let expected_stack = vec![3.into()]; // pos' = pos + rlp_len = 2 + 1 + let expected_rlp = vec![0, 0, 42]; + assert_eq!(interpreter.stack(), expected_stack); + assert_eq!(interpreter.get_rlp_memory(), expected_rlp); + + Ok(()) +} + +#[test] +fn test_encode_rlp_scalar_medium() -> Result<()> { + let encode_rlp_scalar = KERNEL.global_labels["encode_rlp_scalar"]; + + let retdest = 0xDEADBEEFu32.into(); + let scalar = 0x12345.into(); + let pos = 2.into(); + let initial_stack = vec![retdest, scalar, pos]; + let mut interpreter = Interpreter::new_with_kernel(encode_rlp_scalar, initial_stack); + + interpreter.run()?; + let expected_stack = vec![6.into()]; // pos' = pos + rlp_len = 2 + 4 + let expected_rlp = vec![0, 0, 0x80 + 3, 0x01, 0x23, 0x45]; + assert_eq!(interpreter.stack(), expected_stack); + assert_eq!(interpreter.get_rlp_memory(), expected_rlp); + + Ok(()) +} + +#[test] +fn test_encode_rlp_160() -> Result<()> { + let encode_rlp_160 = KERNEL.global_labels["encode_rlp_160"]; + + let retdest = 0xDEADBEEFu32.into(); + let string = 0x12345.into(); + let pos = 0.into(); + let initial_stack = vec![retdest, string, pos]; + let mut interpreter = Interpreter::new_with_kernel(encode_rlp_160, initial_stack); + + interpreter.run()?; + let expected_stack = vec![(1 + 20).into()]; // pos' + #[rustfmt::skip] + let expected_rlp = vec![0x80 + 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01, 0x23, 0x45]; + assert_eq!(interpreter.stack(), expected_stack); + assert_eq!(interpreter.get_rlp_memory(), expected_rlp); + + Ok(()) +} + +#[test] +fn test_encode_rlp_256() -> Result<()> { + let encode_rlp_256 = KERNEL.global_labels["encode_rlp_256"]; + + let retdest = 0xDEADBEEFu32.into(); + let string = 0x12345.into(); + let pos = 0.into(); + let initial_stack = vec![retdest, string, pos]; + let mut interpreter = Interpreter::new_with_kernel(encode_rlp_256, initial_stack); + + interpreter.run()?; + let expected_stack = vec![(1 + 32).into()]; // pos' + #[rustfmt::skip] + let expected_rlp = vec![0x80 + 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01, 0x23, 0x45]; + assert_eq!(interpreter.stack(), expected_stack); + assert_eq!(interpreter.get_rlp_memory(), expected_rlp); + + Ok(()) +} + +#[test] +fn test_prepend_rlp_list_prefix_small() -> Result<()> { + let prepend_rlp_list_prefix = KERNEL.global_labels["prepend_rlp_list_prefix"]; + + let retdest = 0xDEADBEEFu32.into(); + let end_pos = (9 + 5).into(); + let initial_stack = vec![retdest, end_pos]; + let mut interpreter = Interpreter::new_with_kernel(prepend_rlp_list_prefix, initial_stack); + interpreter.set_rlp_memory(vec![ + // Nine 0s to leave room for the longest possible RLP list prefix. + 0, 0, 0, 0, 0, 0, 0, 0, 0, + // The actual RLP list payload, consisting of 5 tiny strings. + 1, 2, 3, 4, 5, + ]); + + interpreter.run()?; + + let expected_rlp_len = 6.into(); + let expected_start_pos = 8.into(); + let expected_stack = vec![expected_rlp_len, expected_start_pos]; + let expected_rlp = vec![0, 0, 0, 0, 0, 0, 0, 0, 0xc0 + 5, 1, 2, 3, 4, 5]; + + assert_eq!(interpreter.stack(), expected_stack); + assert_eq!(interpreter.get_rlp_memory(), expected_rlp); + + Ok(()) +} + +#[test] +fn test_prepend_rlp_list_prefix_large() -> Result<()> { + let prepend_rlp_list_prefix = KERNEL.global_labels["prepend_rlp_list_prefix"]; + + let retdest = 0xDEADBEEFu32.into(); + let end_pos = (9 + 60).into(); + let initial_stack = vec![retdest, end_pos]; + let mut interpreter = Interpreter::new_with_kernel(prepend_rlp_list_prefix, initial_stack); + + #[rustfmt::skip] + interpreter.set_rlp_memory(vec![ + // Nine 0s to leave room for the longest possible RLP list prefix. + 0, 0, 0, 0, 0, 0, 0, 0, 0, + // The actual RLP list payload, consisting of 60 tiny strings. + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + ]); + + interpreter.run()?; + + let expected_rlp_len = 62.into(); + let expected_start_pos = 7.into(); + let expected_stack = vec![expected_rlp_len, expected_start_pos]; + + #[rustfmt::skip] + let expected_rlp = vec![ + 0, 0, 0, 0, 0, 0, 0, 0xf7 + 1, 60, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + ]; + + assert_eq!(interpreter.stack(), expected_stack); + assert_eq!(interpreter.get_rlp_memory(), expected_rlp); + + Ok(()) +} diff --git a/evm/src/cpu/kernel/tests/rlp/mod.rs b/evm/src/cpu/kernel/tests/rlp/mod.rs new file mode 100644 index 00000000..bc9bde59 --- /dev/null +++ b/evm/src/cpu/kernel/tests/rlp/mod.rs @@ -0,0 +1,2 @@ +mod decode; +mod encode; diff --git a/evm/src/cpu/kernel/tests/transaction_parsing/parse_type_0_txn.rs b/evm/src/cpu/kernel/tests/transaction_parsing/parse_type_0_txn.rs index c01474ce..53a3d282 100644 --- a/evm/src/cpu/kernel/tests/transaction_parsing/parse_type_0_txn.rs +++ b/evm/src/cpu/kernel/tests/transaction_parsing/parse_type_0_txn.rs @@ -12,7 +12,8 @@ fn process_type_0_txn() -> Result<()> { let process_type_0_txn = KERNEL.global_labels["process_type_0_txn"]; let process_normalized_txn = KERNEL.global_labels["process_normalized_txn"]; - let mut interpreter = Interpreter::new_with_kernel(process_type_0_txn, vec![]); + let retaddr = 0xDEADBEEFu32.into(); + let mut interpreter = Interpreter::new_with_kernel(process_type_0_txn, vec![retaddr]); // When we reach process_normalized_txn, we're done with parsing and normalizing. // Processing normalized transactions is outside the scope of this test. diff --git a/evm/src/cpu/membus.rs b/evm/src/cpu/membus.rs new file mode 100644 index 00000000..c154301e --- /dev/null +++ b/evm/src/cpu/membus.rs @@ -0,0 +1,70 @@ +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::PrimeField64; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::columns::CpuColumnsView; + +/// General-purpose memory channels; they can read and write to all contexts/segments/addresses. +pub const NUM_GP_CHANNELS: usize = 4; + +pub mod channel_indices { + use std::ops::Range; + + pub const CODE: usize = 0; + pub const GP: Range = CODE + 1..(CODE + 1) + super::NUM_GP_CHANNELS; +} + +/// Total memory channels used by the CPU table. This includes all the `GP_MEM_CHANNELS` as well as +/// all special-purpose memory channels. +/// +/// Currently, there is one special-purpose memory channel, which reads the opcode from memory. Its +/// limitations are: +/// - it is enabled by `is_cpu_cycle`, +/// - it always reads and cannot write, +/// - the context is derived from the current context and the `is_kernel_mode` flag, +/// - the segment is hard-wired to the code segment, +/// - the address is `program_counter`, +/// - the value must fit in one byte (in the least-significant position) and its eight bits are +/// found in `opcode_bits`. +/// These limitations save us numerous columns in the CPU table. +pub const NUM_CHANNELS: usize = channel_indices::GP.end; + +/// Calculates `lv.stack_len_bounds_aux`. Note that this must be run after decode. +pub fn generate(lv: &mut CpuColumnsView) { + let cycle_filter = lv.is_cpu_cycle; + if cycle_filter == F::ZERO { + return; + } + + assert!(lv.is_kernel_mode.to_canonical_u64() <= 1); + + // Set `lv.code_context` to 0 if in kernel mode and to `lv.context` if in user mode. + lv.code_context = (F::ONE - lv.is_kernel_mode) * lv.context; +} + +pub fn eval_packed( + lv: &CpuColumnsView

, + yield_constr: &mut ConstraintConsumer

, +) { + // Validate `lv.code_context`. It should be 0 if in kernel mode and `lv.context` if in user + // mode. + yield_constr.constraint( + lv.is_cpu_cycle * (lv.code_context - (P::ONES - lv.is_kernel_mode) * lv.context), + ); +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + // Validate `lv.code_context`. It should be 0 if in kernel mode and `lv.context` if in user + // mode. + let diff = builder.sub_extension(lv.context, lv.code_context); + let constr = builder.mul_sub_extension(lv.is_kernel_mode, lv.context, diff); + let filtered_constr = builder.mul_extension(lv.is_cpu_cycle, constr); + yield_constr.constraint(builder, filtered_constr); +} diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index 92e3e6ec..c5b7dd32 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -5,5 +5,8 @@ pub mod cpu_stark; pub(crate) mod decode; mod jumps; pub mod kernel; +pub(crate) mod membus; mod simple_logic; +mod stack; +mod stack_bounds; mod syscalls; diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index 6b7294a8..37e06248 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -10,8 +10,8 @@ use crate::cpu::columns::CpuColumnsView; pub fn generate(lv: &mut CpuColumnsView) { let input0 = lv.mem_channels[0].value; - let eq_filter = lv.is_eq.to_canonical_u64(); - let iszero_filter = lv.is_iszero.to_canonical_u64(); + let eq_filter = lv.op.eq.to_canonical_u64(); + let iszero_filter = lv.op.iszero.to_canonical_u64(); assert!(eq_filter <= 1); assert!(iszero_filter <= 1); assert!(eq_filter + iszero_filter <= 1); @@ -62,8 +62,8 @@ pub fn eval_packed( let input1 = lv.mem_channels[1].value; let output = lv.mem_channels[2].value; - let eq_filter = lv.is_eq; - let iszero_filter = lv.is_iszero; + let eq_filter = lv.op.eq; + let iszero_filter = lv.op.iszero; let eq_or_iszero_filter = eq_filter + iszero_filter; let equal = output[0]; @@ -110,8 +110,8 @@ pub fn eval_ext_circuit, const D: usize>( let input1 = lv.mem_channels[1].value; let output = lv.mem_channels[2].value; - let eq_filter = lv.is_eq; - let iszero_filter = lv.is_iszero; + let eq_filter = lv.op.eq; + let iszero_filter = lv.op.iszero; let eq_or_iszero_filter = builder.add_extension(eq_filter, iszero_filter); let equal = output[0]; diff --git a/evm/src/cpu/simple_logic/not.rs b/evm/src/cpu/simple_logic/not.rs index 83d43276..3b8a888f 100644 --- a/evm/src/cpu/simple_logic/not.rs +++ b/evm/src/cpu/simple_logic/not.rs @@ -11,7 +11,7 @@ const LIMB_SIZE: usize = 32; const ALL_1_LIMB: u64 = (1 << LIMB_SIZE) - 1; pub fn generate(lv: &mut CpuColumnsView) { - let is_not_filter = lv.is_not.to_canonical_u64(); + let is_not_filter = lv.op.not.to_canonical_u64(); if is_not_filter == 0 { return; } @@ -35,7 +35,7 @@ pub fn eval_packed( let input = lv.mem_channels[0].value; let output = lv.mem_channels[1].value; let cycle_filter = lv.is_cpu_cycle; - let is_not_filter = lv.is_not; + let is_not_filter = lv.op.not; let filter = cycle_filter * is_not_filter; for (input_limb, output_limb) in input.into_iter().zip(output) { yield_constr.constraint( @@ -52,7 +52,7 @@ pub fn eval_ext_circuit, const D: usize>( let input = lv.mem_channels[0].value; let output = lv.mem_channels[1].value; let cycle_filter = lv.is_cpu_cycle; - let is_not_filter = lv.is_not; + let is_not_filter = lv.op.not; let filter = builder.mul_extension(cycle_filter, is_not_filter); for (input_limb, output_limb) in input.into_iter().zip(output) { let constr = builder.add_extension(output_limb, input_limb); diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs new file mode 100644 index 00000000..3186b5ae --- /dev/null +++ b/evm/src/cpu/stack.rs @@ -0,0 +1,308 @@ +use itertools::izip; +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::columns::ops::OpsColumnsView; +use crate::cpu::columns::CpuColumnsView; +use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::memory::segments::Segment; + +#[derive(Clone, Copy)] +struct StackBehavior { + num_pops: usize, + pushes: bool, + disable_other_channels: bool, +} + +const BASIC_UNARY_OP: Option = Some(StackBehavior { + num_pops: 1, + pushes: true, + disable_other_channels: true, +}); +const BASIC_BINARY_OP: Option = Some(StackBehavior { + num_pops: 2, + pushes: true, + disable_other_channels: true, +}); +const BASIC_TERNARY_OP: Option = Some(StackBehavior { + num_pops: 2, + pushes: true, + disable_other_channels: true, +}); + +const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { + stop: None, // TODO + add: BASIC_BINARY_OP, + mul: BASIC_BINARY_OP, + sub: BASIC_BINARY_OP, + div: BASIC_BINARY_OP, + sdiv: BASIC_BINARY_OP, + mod_: BASIC_BINARY_OP, + smod: BASIC_BINARY_OP, + addmod: BASIC_TERNARY_OP, + mulmod: BASIC_TERNARY_OP, + exp: None, // TODO + signextend: BASIC_BINARY_OP, + lt: BASIC_BINARY_OP, + gt: BASIC_BINARY_OP, + slt: BASIC_BINARY_OP, + sgt: BASIC_BINARY_OP, + eq: BASIC_BINARY_OP, + iszero: BASIC_UNARY_OP, + and: BASIC_BINARY_OP, + or: BASIC_BINARY_OP, + xor: BASIC_BINARY_OP, + not: BASIC_TERNARY_OP, + byte: BASIC_BINARY_OP, + shl: BASIC_BINARY_OP, + shr: BASIC_BINARY_OP, + sar: BASIC_BINARY_OP, + keccak256: None, // TODO + keccak_general: None, // TODO + address: None, // TODO + balance: None, // TODO + origin: None, // TODO + caller: None, // TODO + callvalue: None, // TODO + calldataload: None, // TODO + calldatasize: None, // TODO + calldatacopy: None, // TODO + codesize: None, // TODO + codecopy: None, // TODO + gasprice: None, // TODO + extcodesize: None, // TODO + extcodecopy: None, // TODO + returndatasize: None, // TODO + returndatacopy: None, // TODO + extcodehash: None, // TODO + blockhash: None, // TODO + coinbase: None, // TODO + timestamp: None, // TODO + number: None, // TODO + difficulty: None, // TODO + gaslimit: None, // TODO + chainid: None, // TODO + selfbalance: None, // TODO + basefee: None, // TODO + prover_input: None, // TODO + pop: None, // TODO + mload: None, // TODO + mstore: None, // TODO + mstore8: None, // TODO + sload: None, // TODO + sstore: None, // TODO + jump: None, // TODO + jumpi: None, // TODO + pc: None, // TODO + msize: None, // TODO + gas: None, // TODO + jumpdest: None, // TODO + get_state_root: None, // TODO + set_state_root: None, // TODO + get_receipt_root: None, // TODO + set_receipt_root: None, // TODO + push: None, // TODO + dup: None, // TODO + swap: None, // TODO + log0: None, // TODO + log1: None, // TODO + log2: None, // TODO + log3: None, // TODO + log4: None, // TODO + create: None, // TODO + call: None, // TODO + callcode: None, // TODO + return_: None, // TODO + delegatecall: None, // TODO + create2: None, // TODO + get_context: None, // TODO + set_context: None, // TODO + consume_gas: None, // TODO + exit_kernel: None, // TODO + staticcall: None, // TODO + mload_general: None, // TODO + mstore_general: None, // TODO + revert: None, // TODO + selfdestruct: None, // TODO + invalid: None, // TODO +}; + +fn eval_packed_one( + lv: &CpuColumnsView

, + filter: P, + stack_behavior: StackBehavior, + yield_constr: &mut ConstraintConsumer

, +) { + let num_operands = stack_behavior.num_pops + (stack_behavior.pushes as usize); + assert!(num_operands <= NUM_GP_CHANNELS); + + // Pops + for i in 0..stack_behavior.num_pops { + let channel = lv.mem_channels[i]; + + yield_constr.constraint(filter * (channel.used - P::ONES)); + yield_constr.constraint(filter * (channel.is_read - P::ONES)); + + yield_constr.constraint(filter * (channel.addr_context - lv.context)); + yield_constr.constraint( + filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + // E.g. if `stack_len == 1` and `i == 0`, we want `add_virtual == 0`. + let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(i + 1); + yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + } + + // Pushes + if stack_behavior.pushes { + let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; + + yield_constr.constraint(filter * (channel.used - P::ONES)); + yield_constr.constraint(filter * channel.is_read); + + yield_constr.constraint(filter * (channel.addr_context - lv.context)); + yield_constr.constraint( + filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(stack_behavior.num_pops); + yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + } + + // Unused channels + if stack_behavior.disable_other_channels { + for i in stack_behavior.num_pops..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) { + let channel = lv.mem_channels[i]; + yield_constr.constraint(filter * channel.used); + } + } +} + +pub fn eval_packed( + lv: &CpuColumnsView

, + yield_constr: &mut ConstraintConsumer

, +) { + for (op, stack_behavior) in izip!(lv.op.into_iter(), STACK_BEHAVIORS.into_iter()) { + if let Some(stack_behavior) = stack_behavior { + let filter = lv.is_cpu_cycle * op; + eval_packed_one(lv, filter, stack_behavior, yield_constr); + } + } +} + +fn eval_ext_circuit_one, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + filter: ExtensionTarget, + stack_behavior: StackBehavior, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let num_operands = stack_behavior.num_pops + (stack_behavior.pushes as usize); + assert!(num_operands <= NUM_GP_CHANNELS); + + // Pops + for i in 0..stack_behavior.num_pops { + let channel = lv.mem_channels[i]; + + { + let constr = builder.mul_sub_extension(filter, channel.used, filter); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.mul_sub_extension(filter, channel.is_read, filter); + yield_constr.constraint(builder, constr); + } + + { + let diff = builder.sub_extension(channel.addr_context, lv.context); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_u64(Segment::Stack as u64), + filter, + channel.addr_segment, + filter, + ); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); + let constr = builder.arithmetic_extension( + F::ONE, + F::from_canonical_usize(i + 1), + filter, + diff, + filter, + ); + yield_constr.constraint(builder, constr); + } + } + + // Pushes + if stack_behavior.pushes { + let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; + + { + let constr = builder.mul_sub_extension(filter, channel.used, filter); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.mul_extension(filter, channel.is_read); + yield_constr.constraint(builder, constr); + } + + { + let diff = builder.sub_extension(channel.addr_context, lv.context); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_u64(Segment::Stack as u64), + filter, + channel.addr_segment, + filter, + ); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); + let constr = builder.arithmetic_extension( + F::ONE, + F::from_canonical_usize(stack_behavior.num_pops), + filter, + diff, + filter, + ); + yield_constr.constraint(builder, constr); + } + } + + // Unused channels + if stack_behavior.disable_other_channels { + for i in stack_behavior.num_pops..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) { + let channel = lv.mem_channels[i]; + let constr = builder.mul_extension(filter, channel.used); + yield_constr.constraint(builder, constr); + } + } +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + for (op, stack_behavior) in izip!(lv.op.into_iter(), STACK_BEHAVIORS.into_iter()) { + if let Some(stack_behavior) = stack_behavior { + let filter = builder.mul_extension(lv.is_cpu_cycle, op); + eval_ext_circuit_one(builder, lv, filter, stack_behavior, yield_constr); + } + } +} diff --git a/evm/src/cpu/stack_bounds.rs b/evm/src/cpu/stack_bounds.rs new file mode 100644 index 00000000..99734433 --- /dev/null +++ b/evm/src/cpu/stack_bounds.rs @@ -0,0 +1,157 @@ +//! Checks for stack underflow and overflow. +//! +//! The constraints defined herein validate that stack exceptions (underflow and overflow) do not +//! occur. For example, if `is_add` is set but an addition would underflow, these constraints would +//! make the proof unverifiable. +//! +//! Faults are handled under a separate operation flag, `is_exception` (this is still TODO), which +//! traps to the kernel. The kernel then handles the exception. However, before it may do so, the +//! kernel must verify in software that an exception did in fact occur (i.e. the trap was +//! warranted) and `PANIC` otherwise; this prevents the prover from faking an exception on a valid +//! operation. + +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::columns::{CpuColumnsView, COL_MAP}; + +const MAX_USER_STACK_SIZE: u64 = 1024; + +// Below only includes the operations that pop the top of the stack **without reading the value from +// memory**, i.e. `POP`. +// Other operations that have a minimum stack size (e.g. `MULMOD`, which has three inputs) read +// all their inputs from memory. On underflow, the cross-table lookup fails, as -1, ..., -17 are +// invalid memory addresses. +const DECREMENTING_FLAGS: [usize; 1] = [COL_MAP.op.pop]; + +// Operations that increase the stack length by 1, but excluding: +// - privileged (kernel-only) operations (superfluous; doesn't affect correctness), +// - operations that from userspace to the kernel (required for correctness). +// TODO: This list is incomplete. +const INCREMENTING_FLAGS: [usize; 2] = [COL_MAP.op.pc, COL_MAP.op.dup]; + +/// Calculates `lv.stack_len_bounds_aux`. Note that this must be run after decode. +pub fn generate(lv: &mut CpuColumnsView) { + let cycle_filter = lv.is_cpu_cycle; + if cycle_filter == F::ZERO { + return; + } + + let check_underflow: F = DECREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum(); + let check_overflow: F = INCREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum(); + let no_check = F::ONE - (check_underflow + check_overflow); + + let disallowed_len = check_overflow * F::from_canonical_u64(MAX_USER_STACK_SIZE) - no_check; + let diff = lv.stack_len - disallowed_len; + + let user_mode = F::ONE - lv.is_kernel_mode; + let rhs = user_mode + check_underflow; + + lv.stack_len_bounds_aux = match diff.try_inverse() { + Some(diff_inv) => diff_inv * rhs, // `rhs` may be a value other than 1 or 0 + None => { + assert_eq!(rhs, F::ZERO); + F::ZERO + } + } +} + +pub fn eval_packed( + lv: &CpuColumnsView

, + yield_constr: &mut ConstraintConsumer

, +) { + // `check_underflow`, `check_overflow`, and `no_check` are mutually exclusive. + let check_underflow: P = DECREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum(); + let check_overflow: P = INCREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum(); + let no_check = P::ONES - (check_underflow + check_overflow); + + // If `check_underflow`, then the instruction we are executing pops a value from the stack + // without reading it from memory, and the usual underflow checks do not work. We must show that + // `lv.stack_len` is not 0. We choose to perform this check whether or not we're in kernel mode. + // (The check in kernel mode is not necessary if the kernel is correct, but this is an easy + // sanity check. + // If `check_overflow`, then the instruction we are executing increases the stack length by 1. + // If we are in user mode, then we must show that the stack length is not currently + // `MAX_USER_STACK_SIZE`, as this is the maximum for the user stack. Note that this check must + // not run in kernel mode as the kernel's stack length is unrestricted. + // If `no_check`, then we don't need to check anything. The constraint is written to always + // test that `lv.stack_len` does not equal _something_ so we just show that it's not -1, which + // is always true. + + // 0 if `check_underflow`, `MAX_USER_STACK_SIZE` if `check_overflow`, and -1 if `no_check`. + let disallowed_len = + check_overflow * P::Scalar::from_canonical_u64(MAX_USER_STACK_SIZE) - no_check; + // This `lhs` must equal some `rhs`. If `rhs` is nonzero, then this shows that `lv.stack_len` is + // not `disallowed_len`. + let lhs = (lv.stack_len - disallowed_len) * lv.stack_len_bounds_aux; + + // We want this constraint to be active if we're in user mode OR the instruction might overflow. + // (In other words, we want to _skip_ overflow checks in kernel mode). + let user_mode = P::ONES - lv.is_kernel_mode; + // `rhs` is may be 0, 1, or 2. It's 0 if we're in kernel mode and we would be checking for + // overflow. + // Note: if `user_mode` and `check_underflow` then, `rhs` is 2. This is fine: we're still + // showing that `lv.stack_len - disallowed_len` is nonzero. + let rhs = user_mode + check_underflow; + + yield_constr.constraint(lv.is_cpu_cycle * (lhs - rhs)); +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let one = builder.one_extension(); + let max_stack_size = + builder.constant_extension(F::from_canonical_u64(MAX_USER_STACK_SIZE).into()); + + // `check_underflow`, `check_overflow`, and `no_check` are mutually exclusive. + let check_underflow = builder.add_many_extension(DECREMENTING_FLAGS.map(|i| lv[i])); + let check_overflow = builder.add_many_extension(INCREMENTING_FLAGS.map(|i| lv[i])); + let no_check = { + let any_check = builder.add_extension(check_underflow, check_overflow); + builder.sub_extension(one, any_check) + }; + + // If `check_underflow`, then the instruction we are executing pops a value from the stack + // without reading it from memory, and the usual underflow checks do not work. We must show that + // `lv.stack_len` is not 0. We choose to perform this check whether or not we're in kernel mode. + // (The check in kernel mode is not necessary if the kernel is correct, but this is an easy + // sanity check. + // If `check_overflow`, then the instruction we are executing increases the stack length by 1. + // If we are in user mode, then we must show that the stack length is not currently + // `MAX_USER_STACK_SIZE`, as this is the maximum for the user stack. Note that this check must + // not run in kernel mode as the kernel's stack length is unrestricted. + // If `no_check`, then we don't need to check anything. The constraint is written to always + // test that `lv.stack_len` does not equal _something_ so we just show that it's not -1, which + // is always true. + + // 0 if `check_underflow`, `MAX_USER_STACK_SIZE` if `check_overflow`, and -1 if `no_check`. + let disallowed_len = builder.mul_sub_extension(check_overflow, max_stack_size, no_check); + // This `lhs` must equal some `rhs`. If `rhs` is nonzero, then this shows that `lv.stack_len` is + // not `disallowed_len`. + let lhs = { + let diff = builder.sub_extension(lv.stack_len, disallowed_len); + builder.mul_extension(diff, lv.stack_len_bounds_aux) + }; + + // We want this constraint to be active if we're in user mode OR the instruction might overflow. + // (In other words, we want to _skip_ overflow checks in kernel mode). + let user_mode = builder.sub_extension(one, lv.is_kernel_mode); + // `rhs` is may be 0, 1, or 2. It's 0 if we're in kernel mode and we would be checking for + // overflow. + // Note: if `user_mode` and `check_underflow` then, `rhs` is 2. This is fine: we're still + // showing that `lv.stack_len - disallowed_len` is nonzero. + let rhs = builder.add_extension(user_mode, check_underflow); + + let constr = { + let diff = builder.sub_extension(lhs, rhs); + builder.mul_extension(lv.is_cpu_cycle, diff) + }; + yield_constr.constraint(builder, constr); +} diff --git a/evm/src/cpu/syscalls.rs b/evm/src/cpu/syscalls.rs index b0b63be8..0ac31ef6 100644 --- a/evm/src/cpu/syscalls.rs +++ b/evm/src/cpu/syscalls.rs @@ -18,9 +18,9 @@ const NUM_SYSCALLS: usize = 3; fn make_syscall_list() -> [(usize, usize); NUM_SYSCALLS] { let kernel = Lazy::force(&KERNEL); [ - (COL_MAP.is_stop, "sys_stop"), - (COL_MAP.is_exp, "sys_exp"), - (COL_MAP.is_invalid, "handle_invalid"), + (COL_MAP.op.stop, "sys_stop"), + (COL_MAP.op.exp, "sys_exp"), + (COL_MAP.op.invalid, "handle_invalid"), ] .map(|(col_index, handler_name)| (col_index, kernel.global_labels[handler_name])) } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index f9ec0b3c..1e0e31ed 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,15 +1,19 @@ -use ethereum_types::Address; +use std::collections::HashMap; + +use eth_trie_utils::partial_trie::PartialTrie; +use ethereum_types::{Address, BigEndianHash, H256}; use plonky2::field::extension::Extendable; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::util::timing::TimingTree; +use serde::{Deserialize, Serialize}; use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::cpu::bootstrap_kernel::generate_bootstrap_kernel; use crate::cpu::columns::NUM_CPU_COLUMNS; use crate::cpu::kernel::global_metadata::GlobalMetadata; -use crate::generation::partial_trie::PartialTrie; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::memory::NUM_CHANNELS; @@ -17,13 +21,27 @@ use crate::proof::{BlockMetadata, PublicValues, TrieRoots}; use crate::util::trace_rows_to_poly_values; pub(crate) mod memory; -pub mod partial_trie; +pub(crate) mod mpt; +pub(crate) mod prover_input; +pub(crate) mod rlp; pub(crate) mod state; +#[derive(Clone, Debug, Deserialize, Serialize, Default)] /// Inputs needed for trace generation. pub struct GenerationInputs { pub signed_txns: Vec>, + pub tries: TrieInputs, + + /// Mapping between smart contract code hashes and the contract byte code. + /// All account smart contracts that are invoked will have an entry present. + pub contract_code: HashMap>, + + pub block_metadata: BlockMetadata, +} + +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +pub struct TrieInputs { /// A partial version of the state trie prior to these transactions. It should include all nodes /// that will be accessed by these transactions. pub state_trie: PartialTrie, @@ -39,16 +57,15 @@ pub struct GenerationInputs { /// A partial version of each storage trie prior to these transactions. It should include all /// storage tries, and nodes therein, that will be accessed by these transactions. pub storage_tries: Vec<(Address, PartialTrie)>, - - pub block_metadata: BlockMetadata, } pub(crate) fn generate_traces, const D: usize>( all_stark: &AllStark, inputs: GenerationInputs, config: &StarkConfig, + timing: &mut TimingTree, ) -> ([Vec>; NUM_TABLES], PublicValues) { - let mut state = GenerationState::::default(); + let mut state = GenerationState::::new(inputs.clone()); generate_bootstrap_kernel::(&mut state); @@ -70,14 +87,18 @@ pub(crate) fn generate_traces, const D: usize>( }; let trie_roots_before = TrieRoots { - state_root: read_metadata(GlobalMetadata::StateTrieRootDigestBefore), - transactions_root: read_metadata(GlobalMetadata::TransactionsTrieRootDigestBefore), - receipts_root: read_metadata(GlobalMetadata::ReceiptsTrieRootDigestBefore), + state_root: H256::from_uint(&read_metadata(GlobalMetadata::StateTrieRootDigestBefore)), + transactions_root: H256::from_uint(&read_metadata( + GlobalMetadata::TransactionTrieRootDigestBefore, + )), + receipts_root: H256::from_uint(&read_metadata(GlobalMetadata::ReceiptTrieRootDigestBefore)), }; let trie_roots_after = TrieRoots { - state_root: read_metadata(GlobalMetadata::StateTrieRootDigestAfter), - transactions_root: read_metadata(GlobalMetadata::TransactionsTrieRootDigestAfter), - receipts_root: read_metadata(GlobalMetadata::ReceiptsTrieRootDigestAfter), + state_root: H256::from_uint(&read_metadata(GlobalMetadata::StateTrieRootDigestAfter)), + transactions_root: H256::from_uint(&read_metadata( + GlobalMetadata::TransactionTrieRootDigestAfter, + )), + receipts_root: H256::from_uint(&read_metadata(GlobalMetadata::ReceiptTrieRootDigestAfter)), }; let GenerationState { @@ -92,12 +113,14 @@ pub(crate) fn generate_traces, const D: usize>( assert_eq!(current_cpu_row, [F::ZERO; NUM_CPU_COLUMNS].into()); let cpu_trace = trace_rows_to_poly_values(cpu_rows); - let keccak_trace = all_stark.keccak_stark.generate_trace(keccak_inputs); - let keccak_memory_trace = all_stark - .keccak_memory_stark - .generate_trace(keccak_memory_inputs, config.fri_config.num_cap_elements()); - let logic_trace = all_stark.logic_stark.generate_trace(logic_ops); - let memory_trace = all_stark.memory_stark.generate_trace(memory.log); + let keccak_trace = all_stark.keccak_stark.generate_trace(keccak_inputs, timing); + let keccak_memory_trace = all_stark.keccak_memory_stark.generate_trace( + keccak_memory_inputs, + config.fri_config.num_cap_elements(), + timing, + ); + let logic_trace = all_stark.logic_stark.generate_trace(logic_ops, timing); + let memory_trace = all_stark.memory_stark.generate_trace(memory.log, timing); let traces = [ cpu_trace, keccak_trace, diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs new file mode 100644 index 00000000..0bdfede4 --- /dev/null +++ b/evm/src/generation/mpt.rs @@ -0,0 +1,77 @@ +use eth_trie_utils::partial_trie::PartialTrie; +use ethereum_types::U256; + +use crate::cpu::kernel::constants::trie_type::PartialTrieType; +use crate::generation::TrieInputs; + +pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec { + let mut inputs = all_mpt_prover_inputs(trie_inputs); + inputs.reverse(); + inputs +} + +/// Generate prover inputs for the initial MPT data, in the format expected by `mpt/load.asm`. +pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { + let mut prover_inputs = vec![]; + + mpt_prover_inputs(&trie_inputs.state_trie, &mut prover_inputs, &|rlp| { + rlp::decode_list(rlp) + }); + + mpt_prover_inputs(&trie_inputs.transactions_trie, &mut prover_inputs, &|rlp| { + rlp::decode_list(rlp) + }); + + mpt_prover_inputs(&trie_inputs.receipts_trie, &mut prover_inputs, &|_rlp| { + // TODO: Decode receipt RLP. + vec![] + }); + + prover_inputs.push(trie_inputs.storage_tries.len().into()); + for (addr, storage_trie) in &trie_inputs.storage_tries { + prover_inputs.push(addr.0.as_ref().into()); + mpt_prover_inputs(storage_trie, &mut prover_inputs, &|leaf_be| { + vec![U256::from_big_endian(leaf_be)] + }); + } + + prover_inputs +} + +/// Given a trie, generate the prover input data for that trie. In essence, this serializes a trie +/// into a `U256` array, in a simple format which the kernel understands. For example, a leaf node +/// is serialized as `(TYPE_LEAF, key, value)`, where key is a `(nibbles, depth)` pair and `value` +/// is a variable-length structure which depends on which trie we're dealing with. +pub(crate) fn mpt_prover_inputs( + trie: &PartialTrie, + prover_inputs: &mut Vec, + parse_leaf: &F, +) where + F: Fn(&[u8]) -> Vec, +{ + prover_inputs.push((PartialTrieType::of(trie) as u32).into()); + match trie { + PartialTrie::Empty => {} + PartialTrie::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), + PartialTrie::Branch { children, value } => { + for child in children { + mpt_prover_inputs(child, prover_inputs, parse_leaf); + } + let leaf = parse_leaf(value); + prover_inputs.push(leaf.len().into()); + prover_inputs.extend(leaf); + } + PartialTrie::Extension { nibbles, child } => { + prover_inputs.push(nibbles.count.into()); + prover_inputs.push(nibbles.packed); + mpt_prover_inputs(child, prover_inputs, parse_leaf); + } + PartialTrie::Leaf { nibbles, value } => { + prover_inputs.push(nibbles.count.into()); + prover_inputs.push(nibbles.packed); + let leaf = parse_leaf(value); + prover_inputs.push(leaf.len().into()); + prover_inputs.extend(leaf); + } + } +} diff --git a/evm/src/generation/partial_trie.rs b/evm/src/generation/partial_trie.rs deleted file mode 100644 index 96751310..00000000 --- a/evm/src/generation/partial_trie.rs +++ /dev/null @@ -1,32 +0,0 @@ -use ethereum_types::U256; - -/// A partial trie, or a sub-trie thereof. This mimics the structure of an Ethereum trie, except -/// with an additional `Hash` node type, representing a node whose data is not needed to process -/// our transaction. -pub enum PartialTrie { - /// An empty trie. - Empty, - /// The digest of trie whose data does not need to be stored. - Hash(U256), - /// A branch node, which consists of 16 children and an optional value. - Branch { - children: [Box; 16], - value: Option, - }, - /// An extension node, which consists of a list of nibbles and a single child. - Extension { - nibbles: Nibbles, - child: Box, - }, - /// A leaf node, which consists of a list of nibbles and a value. - Leaf { nibbles: Nibbles, value: Vec }, -} - -/// A sequence of nibbles. -pub struct Nibbles { - /// The number of nibbles in this sequence. - pub count: usize, - /// A packed encoding of these nibbles. Only the first (least significant) `4 * count` bits are - /// used. The rest are unused and should be zero. - pub packed: U256, -} diff --git a/evm/src/cpu/kernel/prover_input.rs b/evm/src/generation/prover_input.rs similarity index 63% rename from evm/src/cpu/kernel/prover_input.rs rename to evm/src/generation/prover_input.rs index 38e1914e..d5d7df7c 100644 --- a/evm/src/cpu/kernel/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -1,11 +1,13 @@ use std::str::FromStr; use ethereum_types::U256; +use plonky2::field::types::Field; -use crate::cpu::kernel::prover_input::Field::{ +use crate::generation::prover_input::EvmField::{ Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, }; -use crate::cpu::kernel::prover_input::FieldOp::{Inverse, Sqrt}; +use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; +use crate::generation::state::GenerationState; /// Prover input function represented as a scoped function name. /// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as `ProverInputFn([ff, bn254_base, inverse])`. @@ -18,32 +20,52 @@ impl From> for ProverInputFn { } } -impl ProverInputFn { - /// Run the function on the stack. - pub fn run(&self, stack: &[U256]) -> U256 { - match self.0[0].as_str() { - "ff" => self.run_ff(stack), - "mpt" => todo!(), +impl GenerationState { + #[allow(unused)] // TODO: Should be used soon. + pub(crate) fn prover_input(&mut self, stack: &[U256], input_fn: &ProverInputFn) -> U256 { + match input_fn.0[0].as_str() { + "end_of_txns" => self.run_end_of_txns(), + "ff" => self.run_ff(stack, input_fn), + "mpt" => self.run_mpt(), + "rlp" => self.run_rlp(), _ => panic!("Unrecognized prover input function."), } } - // Finite field operations. - fn run_ff(&self, stack: &[U256]) -> U256 { - let field = Field::from_str(self.0[1].as_str()).unwrap(); - let op = FieldOp::from_str(self.0[2].as_str()).unwrap(); + fn run_end_of_txns(&mut self) -> U256 { + let end = self.next_txn_index == self.inputs.signed_txns.len(); + if end { + U256::one() + } else { + self.next_txn_index += 1; + U256::zero() + } + } + + /// Finite field operations. + fn run_ff(&self, stack: &[U256], input_fn: &ProverInputFn) -> U256 { + let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap(); + let op = FieldOp::from_str(input_fn.0[2].as_str()).unwrap(); let x = *stack.last().expect("Empty stack"); field.op(op, x) } - // MPT operations. - #[allow(dead_code)] - fn run_mpt(&self, _stack: Vec) -> U256 { - todo!() + /// MPT data. + fn run_mpt(&mut self) -> U256 { + self.mpt_prover_inputs + .pop() + .unwrap_or_else(|| panic!("Out of MPT data")) + } + + /// RLP data. + fn run_rlp(&mut self) -> U256 { + self.rlp_prover_inputs + .pop() + .unwrap_or_else(|| panic!("Out of RLP data")) } } -enum Field { +enum EvmField { Bn254Base, Bn254Scalar, Secp256k1Base, @@ -55,7 +77,7 @@ enum FieldOp { Sqrt, } -impl FromStr for Field { +impl FromStr for EvmField { type Err = (); fn from_str(s: &str) -> Result { @@ -81,19 +103,19 @@ impl FromStr for FieldOp { } } -impl Field { +impl EvmField { fn order(&self) -> U256 { match self { - Field::Bn254Base => { + EvmField::Bn254Base => { U256::from_str("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") .unwrap() } - Field::Bn254Scalar => todo!(), - Field::Secp256k1Base => { + EvmField::Bn254Scalar => todo!(), + EvmField::Secp256k1Base => { U256::from_str("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f") .unwrap() } - Field::Secp256k1Scalar => { + EvmField::Secp256k1Scalar => { U256::from_str("0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141") .unwrap() } diff --git a/evm/src/generation/rlp.rs b/evm/src/generation/rlp.rs new file mode 100644 index 00000000..f28272a2 --- /dev/null +++ b/evm/src/generation/rlp.rs @@ -0,0 +1,18 @@ +use ethereum_types::U256; + +pub(crate) fn all_rlp_prover_inputs_reversed(signed_txns: &[Vec]) -> Vec { + let mut inputs = all_rlp_prover_inputs(signed_txns); + inputs.reverse(); + inputs +} + +fn all_rlp_prover_inputs(signed_txns: &[Vec]) -> Vec { + let mut prover_inputs = vec![]; + for txn in signed_txns { + prover_inputs.push(txn.len().into()); + for &byte in txn { + prover_inputs.push(byte.into()); + } + } + prover_inputs +} diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 4cbe61c8..17d63018 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -6,6 +6,9 @@ use tiny_keccak::keccakf; use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS}; use crate::generation::memory::MemoryState; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; +use crate::generation::rlp::all_rlp_prover_inputs_reversed; +use crate::generation::GenerationInputs; use crate::keccak_memory::keccak_memory_stark::KeccakMemoryOp; use crate::memory::memory_stark::MemoryOp; use crate::memory::segments::Segment; @@ -15,6 +18,9 @@ use crate::{keccak, logic}; #[derive(Debug)] pub(crate) struct GenerationState { + #[allow(unused)] // TODO: Should be used soon. + pub(crate) inputs: GenerationInputs, + pub(crate) next_txn_index: usize, pub(crate) cpu_rows: Vec<[F; NUM_CPU_COLUMNS]>, pub(crate) current_cpu_row: CpuColumnsView, @@ -24,9 +30,36 @@ pub(crate) struct GenerationState { pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>, pub(crate) keccak_memory_inputs: Vec, pub(crate) logic_ops: Vec, + + /// Prover inputs containing MPT data, in reverse order so that the next input can be obtained + /// via `pop()`. + pub(crate) mpt_prover_inputs: Vec, + + /// Prover inputs containing RLP data, in reverse order so that the next input can be obtained + /// via `pop()`. + pub(crate) rlp_prover_inputs: Vec, } impl GenerationState { + pub(crate) fn new(inputs: GenerationInputs) -> Self { + let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries); + let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns); + + Self { + inputs, + next_txn_index: 0, + cpu_rows: vec![], + current_cpu_row: [F::ZERO; NUM_CPU_COLUMNS].into(), + current_context: 0, + memory: MemoryState::default(), + keccak_inputs: vec![], + keccak_memory_inputs: vec![], + logic_ops: vec![], + mpt_prover_inputs, + rlp_prover_inputs, + } + } + /// Compute logical AND, and record the operation to be added in the logic table later. #[allow(unused)] // TODO: Should be used soon. pub(crate) fn and(&mut self, input0: U256, input1: U256) -> U256 { @@ -217,19 +250,3 @@ impl GenerationState { self.cpu_rows.push(swapped_row.into()); } } - -// `GenerationState` can't `derive(Default)` because `Default` is only implemented for arrays up to -// length 32 :-\. -impl Default for GenerationState { - fn default() -> Self { - Self { - cpu_rows: vec![], - current_cpu_row: [F::ZERO; NUM_CPU_COLUMNS].into(), - current_context: 0, - memory: MemoryState::default(), - keccak_inputs: vec![], - keccak_memory_inputs: vec![], - logic_ops: vec![], - } - } -} diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 23ffe0e9..87a61ae7 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -201,23 +201,22 @@ impl, const D: usize> KeccakStark { row[out_reg_hi] = F::from_canonical_u64(row[in_reg_hi].to_canonical_u64() ^ rc_hi); } - pub fn generate_trace(&self, inputs: Vec<[u64; NUM_INPUTS]>) -> Vec> { - let mut timing = TimingTree::new("generate trace", log::Level::Debug); - + pub fn generate_trace( + &self, + inputs: Vec<[u64; NUM_INPUTS]>, + timing: &mut TimingTree, + ) -> Vec> { // Generate the witness, except for permuted columns in the lookup argument. let trace_rows = timed!( - &mut timing, + timing, "generate trace rows", self.generate_trace_rows(inputs) ); - let trace_polys = timed!( - &mut timing, + timing, "convert to PolynomialValues", trace_rows_to_poly_values(trace_rows) ); - - timing.print(); trace_polys } } @@ -542,12 +541,22 @@ impl, const D: usize> Stark for KeccakStark Result<()> { + const NUM_PERMS: usize = 85; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = KeccakStark; + let stark = S::default(); + let config = StarkConfig::standard_fast_config(); + + init_logger(); + + let input: Vec<[u64; NUM_INPUTS]> = (0..NUM_PERMS).map(|_| rand::random()).collect(); + + let mut timing = TimingTree::new("prove", log::Level::Debug); + let trace_poly_values = timed!( + timing, + "generate trace", + stark.generate_trace(input.try_into().unwrap(), &mut timing) + ); + + // TODO: Cloning this isn't great; consider having `from_values` accept a reference, + // or having `compute_permutation_z_polys` read trace values from the `PolynomialBatch`. + let cloned_trace_poly_values = timed!(timing, "clone", trace_poly_values.clone()); + + let trace_commitments = timed!( + timing, + "compute trace commitment", + PolynomialBatch::::from_values( + cloned_trace_poly_values, + config.fri_config.rate_bits, + false, + config.fri_config.cap_height, + &mut timing, + None, + ) + ); + let degree = 1 << trace_commitments.degree_log; + + // Fake CTL data. + let ctl_z_data = CtlZData { + z: PolynomialValues::zero(degree), + challenge: GrandProductChallenge { + beta: F::ZERO, + gamma: F::ZERO, + }, + columns: vec![], + filter_column: None, + }; + let ctl_data = CtlData { + zs_columns: vec![ctl_z_data; config.num_challenges], + }; + + prove_single_table( + &stark, + &config, + &trace_poly_values, + &trace_commitments, + &ctl_data, + &mut Challenger::new(), + &mut timing, + )?; + + timing.print(); + Ok(()) + } + + fn init_logger() { + let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); + } } diff --git a/evm/src/keccak_memory/keccak_memory_stark.rs b/evm/src/keccak_memory/keccak_memory_stark.rs index cf8955b3..1bbea168 100644 --- a/evm/src/keccak_memory/keccak_memory_stark.rs +++ b/evm/src/keccak_memory/keccak_memory_stark.rs @@ -93,23 +93,21 @@ impl, const D: usize> KeccakMemoryStark { &self, operations: Vec, min_rows: usize, + timing: &mut TimingTree, ) -> Vec> { - let mut timing = TimingTree::new("generate trace", log::Level::Debug); - // Generate the witness row-wise. let trace_rows = timed!( - &mut timing, + timing, "generate trace rows", self.generate_trace_rows(operations, min_rows) ); let trace_polys = timed!( - &mut timing, + timing, "convert to PolynomialValues", trace_rows_to_poly_values(trace_rows) ); - timing.print(); trace_polys } diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index afde02c2..219c0c21 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -171,23 +171,21 @@ impl, const D: usize> KeccakSpongeStark { &self, operations: Vec, min_rows: usize, + timing: &mut TimingTree, ) -> Vec> { - let mut timing = TimingTree::new("generate trace", log::Level::Debug); - // Generate the witness row-wise. let trace_rows = timed!( - &mut timing, + timing, "generate trace rows", self.generate_trace_rows(operations, min_rows) ); let trace_polys = timed!( - &mut timing, + timing, "convert to PolynomialValues", trace_rows_to_poly_values(trace_rows) ); - timing.print(); trace_polys } diff --git a/evm/src/logic.rs b/evm/src/logic.rs index 2499101b..dc6fc777 100644 --- a/evm/src/logic.rs +++ b/evm/src/logic.rs @@ -7,9 +7,13 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::timed; +use plonky2::util::timing::TimingTree; +use plonky2_util::ceil_div_usize; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; +use crate::logic::columns::NUM_COLUMNS; use crate::stark::Stark; use crate::util::{limb_from_bits_le, limb_from_bits_le_recursive, trace_rows_to_poly_values}; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; @@ -19,7 +23,7 @@ const VAL_BITS: usize = 256; // Number of bits stored per field element. Ensure that this fits; it is not checked. pub(crate) const PACKED_LIMB_BITS: usize = 32; // Number of field elements needed to store each input/output at the specified packing. -const PACKED_LEN: usize = (VAL_BITS + PACKED_LIMB_BITS - 1) / PACKED_LIMB_BITS; +const PACKED_LEN: usize = ceil_div_usize(VAL_BITS, PACKED_LIMB_BITS); pub(crate) mod columns { use std::cmp::min; @@ -100,7 +104,25 @@ impl Operation { } impl LogicStark { - pub(crate) fn generate_trace(&self, operations: Vec) -> Vec> { + pub(crate) fn generate_trace( + &self, + operations: Vec, + timing: &mut TimingTree, + ) -> Vec> { + let trace_rows = timed!( + timing, + "generate trace rows", + self.generate_trace_rows(operations) + ); + let trace_polys = timed!( + timing, + "convert to PolynomialValues", + trace_rows_to_poly_values(trace_rows) + ); + trace_polys + } + + fn generate_trace_rows(&self, operations: Vec) -> Vec<[F; NUM_COLUMNS]> { let len = operations.len(); let padded_len = len.next_power_of_two(); @@ -114,7 +136,7 @@ impl LogicStark { rows.push([F::ZERO; columns::NUM_COLUMNS]); } - trace_rows_to_poly_values(rows) + rows } fn generate_row(operation: Operation) -> [F; columns::NUM_COLUMNS] { diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 1ec0c11c..f5455a53 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -187,12 +187,14 @@ impl, const D: usize> MemoryStark { } } - pub(crate) fn generate_trace(&self, memory_ops: Vec) -> Vec> { - let mut timing = TimingTree::new("generate trace", log::Level::Debug); - + pub(crate) fn generate_trace( + &self, + memory_ops: Vec, + timing: &mut TimingTree, + ) -> Vec> { // Generate most of the trace in row-major form. let trace_rows = timed!( - &mut timing, + timing, "generate trace rows", self.generate_trace_row_major(memory_ops) ); @@ -204,13 +206,10 @@ impl, const D: usize> MemoryStark { // A few final generation steps, which work better in column-major form. Self::generate_trace_col_major(&mut trace_col_vecs); - let trace_polys = trace_col_vecs + trace_col_vecs .into_iter() .map(|column| PolynomialValues::new(column)) - .collect(); - - timing.print(); - trace_polys + .collect() } } diff --git a/evm/src/memory/mod.rs b/evm/src/memory/mod.rs index dd82ad04..4cdfd1be 100644 --- a/evm/src/memory/mod.rs +++ b/evm/src/memory/mod.rs @@ -3,5 +3,5 @@ pub mod memory_stark; pub mod segments; // TODO: Move to CPU module, now that channels have been removed from the memory table. -pub(crate) const NUM_CHANNELS: usize = 4; +pub(crate) const NUM_CHANNELS: usize = crate::cpu::membus::NUM_CHANNELS; pub(crate) const VALUE_LIMBS: usize = 8; diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 28560a6c..e9002c62 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -1,4 +1,4 @@ -use ethereum_types::{Address, U256}; +use ethereum_types::{Address, H256, U256}; use itertools::Itertools; use maybe_rayon::*; use plonky2::field::extension::{Extendable, FieldExtension}; @@ -13,6 +13,7 @@ use plonky2::hash::merkle_tree::MerkleCap; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::config::GenericConfig; +use serde::{Deserialize, Serialize}; use crate::all_stark::NUM_TABLES; use crate::config::StarkConfig; @@ -57,12 +58,12 @@ pub struct PublicValues { #[derive(Debug, Clone, Default)] pub struct TrieRoots { - pub state_root: U256, - pub transactions_root: U256, - pub receipts_root: U256, + pub state_root: H256, + pub transactions_root: H256, + pub receipts_root: H256, } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct BlockMetadata { pub block_beneficiary: Address, pub block_timestamp: U256, diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 3b702c56..20e8c628 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -53,7 +53,7 @@ where [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, { - let (traces, public_values) = generate_traces(all_stark, inputs, config); + let (traces, public_values) = generate_traces(all_stark, inputs, config, timing); prove_with_traces(all_stark, config, traces, public_values, timing) } @@ -175,7 +175,7 @@ where } /// Compute proof for a single STARK table. -fn prove_single_table( +pub(crate) fn prove_single_table( stark: &S, config: &StarkConfig, trace_poly_values: &[PolynomialValues], @@ -212,7 +212,11 @@ where ) }); let permutation_zs = permutation_challenges.as_ref().map(|challenges| { - compute_permutation_z_polys::(stark, config, trace_poly_values, challenges) + timed!( + timing, + "compute permutation Z(x) polys", + compute_permutation_z_polys::(stark, config, trace_poly_values, challenges) + ) }); let num_permutation_zs = permutation_zs.as_ref().map(|v| v.len()).unwrap_or(0); @@ -225,13 +229,17 @@ where }; assert!(!z_polys.is_empty(), "No CTL?"); - let permutation_ctl_zs_commitment = PolynomialBatch::from_values( - z_polys, - rate_bits, - false, - config.fri_config.cap_height, + let permutation_ctl_zs_commitment = timed!( timing, - None, + "compute Zs commitment", + PolynomialBatch::from_values( + z_polys, + rate_bits, + false, + config.fri_config.cap_height, + timing, + None, + ) ); let permutation_ctl_zs_cap = permutation_ctl_zs_commitment.merkle_tree.cap.clone(); @@ -251,27 +259,37 @@ where config, ); } - let quotient_polys = compute_quotient_polys::::Packing, C, S, D>( - stark, - trace_commitment, - &permutation_ctl_zs_commitment, - permutation_challenges.as_ref(), - ctl_data, - alphas, - degree_bits, - num_permutation_zs, - config, + let quotient_polys = timed!( + timing, + "compute quotient polys", + compute_quotient_polys::::Packing, C, S, D>( + stark, + trace_commitment, + &permutation_ctl_zs_commitment, + permutation_challenges.as_ref(), + ctl_data, + alphas, + degree_bits, + num_permutation_zs, + config, + ) + ); + let all_quotient_chunks = timed!( + timing, + "split quotient polys", + quotient_polys + .into_par_iter() + .flat_map(|mut quotient_poly| { + quotient_poly + .trim_to_len(degree * stark.quotient_degree_factor()) + .expect( + "Quotient has failed, the vanishing polynomial is not divisible by Z_H", + ); + // Split quotient into degree-n chunks. + quotient_poly.chunks(degree) + }) + .collect() ); - let all_quotient_chunks = quotient_polys - .into_par_iter() - .flat_map(|mut quotient_poly| { - quotient_poly - .trim_to_len(degree * stark.quotient_degree_factor()) - .expect("Quotient has failed, the vanishing polynomial is not divisible by Z_H"); - // Split quotient into degree-n chunks. - quotient_poly.chunks(degree) - }) - .collect(); let quotient_commitment = timed!( timing, "compute quotient commitment", diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index ae3bd27f..b2f67610 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -19,7 +19,6 @@ use plonky2::plonk::proof::ProofWithPublicInputs; use plonky2::util::reducing::ReducingFactorTarget; use plonky2::with_context; -use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::RecursiveConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; @@ -41,9 +40,13 @@ use crate::proof::{ StarkProofTarget, TrieRoots, TrieRootsTarget, }; use crate::stark::Stark; -use crate::util::{h160_limbs, u256_limbs}; +use crate::util::h160_limbs; use crate::vanishing_poly::eval_vanishing_poly_circuit; use crate::vars::StarkEvaluationTargets; +use crate::{ + all_stark::{AllStark, Table}, + util::h256_limbs, +}; /// Table-wise recursive proofs of an `AllProof`. pub struct RecursiveAllProof< @@ -606,8 +609,8 @@ fn verify_stark_proof_with_challenges_circuit< let degree_bits = proof.recover_degree_bits(inner_config); let zeta_pow_deg = builder.exp_power_of_2_extension(challenges.stark_zeta, degree_bits); let z_h_zeta = builder.sub_extension(zeta_pow_deg, one); - let (l_1, l_last) = - eval_l_1_and_l_last_circuit(builder, degree_bits, challenges.stark_zeta, z_h_zeta); + let (l_0, l_last) = + eval_l_0_and_l_last_circuit(builder, degree_bits, challenges.stark_zeta, z_h_zeta); let last = builder.constant_extension(F::Extension::primitive_root_of_unity(degree_bits).inverse()); let z_last = builder.sub_extension(challenges.stark_zeta, last); @@ -616,7 +619,7 @@ fn verify_stark_proof_with_challenges_circuit< builder.zero_extension(), challenges.stark_alphas.clone(), z_last, - l_1, + l_0, l_last, ); @@ -679,7 +682,7 @@ fn verify_stark_proof_with_challenges_circuit< ); } -fn eval_l_1_and_l_last_circuit, const D: usize>( +fn eval_l_0_and_l_last_circuit, const D: usize>( builder: &mut CircuitBuilder, log_n: usize, x: ExtensionTarget, @@ -688,12 +691,12 @@ fn eval_l_1_and_l_last_circuit, const D: usize>( let n = builder.constant_extension(F::Extension::from_canonical_usize(1 << log_n)); let g = builder.constant_extension(F::Extension::primitive_root_of_unity(log_n)); let one = builder.one_extension(); - let l_1_deno = builder.mul_sub_extension(n, x, n); + let l_0_deno = builder.mul_sub_extension(n, x, n); let l_last_deno = builder.mul_sub_extension(g, x, one); let l_last_deno = builder.mul_extension(n, l_last_deno); ( - builder.div_extension(z_x, l_1_deno), + builder.div_extension(z_x, l_0_deno), builder.div_extension(z_x, l_last_deno), ) } @@ -929,15 +932,15 @@ pub fn set_trie_roots_target( { witness.set_target_arr( trie_roots_target.state_root, - u256_limbs(trie_roots.state_root), + h256_limbs(trie_roots.state_root), ); witness.set_target_arr( trie_roots_target.transactions_root, - u256_limbs(trie_roots.transactions_root), + h256_limbs(trie_roots.transactions_root), ); witness.set_target_arr( trie_roots_target.receipts_root, - u256_limbs(trie_roots.receipts_root), + h256_limbs(trie_roots.receipts_root), ); } diff --git a/evm/src/stark.rs b/evm/src/stark.rs index a205547a..49c5b70b 100644 --- a/evm/src/stark.rs +++ b/evm/src/stark.rs @@ -16,6 +16,10 @@ use crate::permutation::PermutationPair; use crate::vars::StarkEvaluationTargets; use crate::vars::StarkEvaluationVars; +const TRACE_ORACLE_INDEX: usize = 0; +const PERMUTATION_CTL_ORACLE_INDEX: usize = 1; +const QUOTIENT_ORACLE_INDEX: usize = 2; + /// Represents a STARK system. pub trait Stark, const D: usize>: Sync { /// The total number of columns in the trace. @@ -72,6 +76,10 @@ pub trait Stark, const D: usize>: Sync { 1.max(self.constraint_degree() - 1) } + fn num_quotient_polys(&self, config: &StarkConfig) -> usize { + self.quotient_degree_factor() * config.num_challenges + } + /// Computes the FRI instance used to prove this Stark. fn fri_instance( &self, @@ -81,28 +89,35 @@ pub trait Stark, const D: usize>: Sync { num_ctl_zs: usize, config: &StarkConfig, ) -> FriInstanceInfo { - let no_blinding_oracle = FriOracleInfo { blinding: false }; - let mut oracle_indices = 0..; - - let trace_info = - FriPolynomialInfo::from_range(oracle_indices.next().unwrap(), 0..Self::COLUMNS); + let trace_oracle = FriOracleInfo { + num_polys: Self::COLUMNS, + blinding: false, + }; + let trace_info = FriPolynomialInfo::from_range(TRACE_ORACLE_INDEX, 0..Self::COLUMNS); let num_permutation_batches = self.num_permutation_batches(config); - let permutation_ctl_index = oracle_indices.next().unwrap(); + let num_perutation_ctl_polys = num_permutation_batches + num_ctl_zs; + let permutation_ctl_oracle = FriOracleInfo { + num_polys: num_perutation_ctl_polys, + blinding: false, + }; let permutation_ctl_zs_info = FriPolynomialInfo::from_range( - permutation_ctl_index, - 0..num_permutation_batches + num_ctl_zs, + PERMUTATION_CTL_ORACLE_INDEX, + 0..num_perutation_ctl_polys, ); let ctl_zs_info = FriPolynomialInfo::from_range( - permutation_ctl_index, + PERMUTATION_CTL_ORACLE_INDEX, num_permutation_batches..num_permutation_batches + num_ctl_zs, ); - let quotient_info = FriPolynomialInfo::from_range( - oracle_indices.next().unwrap(), - 0..self.quotient_degree_factor() * config.num_challenges, - ); + let num_quotient_polys = self.num_quotient_polys(config); + let quotient_oracle = FriOracleInfo { + num_polys: num_quotient_polys, + blinding: false, + }; + let quotient_info = + FriPolynomialInfo::from_range(QUOTIENT_ORACLE_INDEX, 0..num_quotient_polys); let zeta_batch = FriBatchInfo { point: zeta, @@ -122,7 +137,7 @@ pub trait Stark, const D: usize>: Sync { polynomials: ctl_zs_info, }; FriInstanceInfo { - oracles: vec![no_blinding_oracle; oracle_indices.next().unwrap()], + oracles: vec![trace_oracle, permutation_ctl_oracle, quotient_oracle], batches: vec![zeta_batch, zeta_next_batch, ctl_last_batch], } } @@ -137,28 +152,35 @@ pub trait Stark, const D: usize>: Sync { num_ctl_zs: usize, inner_config: &StarkConfig, ) -> FriInstanceInfoTarget { - let no_blinding_oracle = FriOracleInfo { blinding: false }; - let mut oracle_indices = 0..; - - let trace_info = - FriPolynomialInfo::from_range(oracle_indices.next().unwrap(), 0..Self::COLUMNS); + let trace_oracle = FriOracleInfo { + num_polys: Self::COLUMNS, + blinding: false, + }; + let trace_info = FriPolynomialInfo::from_range(TRACE_ORACLE_INDEX, 0..Self::COLUMNS); let num_permutation_batches = self.num_permutation_batches(inner_config); - let permutation_ctl_index = oracle_indices.next().unwrap(); + let num_perutation_ctl_polys = num_permutation_batches + num_ctl_zs; + let permutation_ctl_oracle = FriOracleInfo { + num_polys: num_perutation_ctl_polys, + blinding: false, + }; let permutation_ctl_zs_info = FriPolynomialInfo::from_range( - permutation_ctl_index, - 0..num_permutation_batches + num_ctl_zs, + PERMUTATION_CTL_ORACLE_INDEX, + 0..num_perutation_ctl_polys, ); let ctl_zs_info = FriPolynomialInfo::from_range( - permutation_ctl_index, + PERMUTATION_CTL_ORACLE_INDEX, num_permutation_batches..num_permutation_batches + num_ctl_zs, ); - let quotient_info = FriPolynomialInfo::from_range( - oracle_indices.next().unwrap(), - 0..self.quotient_degree_factor() * inner_config.num_challenges, - ); + let num_quotient_polys = self.num_quotient_polys(inner_config); + let quotient_oracle = FriOracleInfo { + num_polys: num_quotient_polys, + blinding: false, + }; + let quotient_info = + FriPolynomialInfo::from_range(QUOTIENT_ORACLE_INDEX, 0..num_quotient_polys); let zeta_batch = FriBatchInfoTarget { point: zeta, @@ -180,7 +202,7 @@ pub trait Stark, const D: usize>: Sync { polynomials: ctl_zs_info, }; FriInstanceInfoTarget { - oracles: vec![no_blinding_oracle; oracle_indices.next().unwrap()], + oracles: vec![trace_oracle, permutation_ctl_oracle, quotient_oracle], batches: vec![zeta_batch, zeta_next_batch, ctl_last_batch], } } diff --git a/evm/src/util.rs b/evm/src/util.rs index 12aead46..7f958fd2 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -1,6 +1,6 @@ use std::mem::{size_of, transmute_copy, ManuallyDrop}; -use ethereum_types::{H160, U256}; +use ethereum_types::{H160, H256, U256}; use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; @@ -59,6 +59,17 @@ pub(crate) fn u256_limbs(u256: U256) -> [F; 8] { .unwrap() } +/// Returns the 32-bit little-endian limbs of a `H256`. +pub(crate) fn h256_limbs(h256: H256) -> [F; 8] { + h256.0 + .chunks(4) + .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) + .map(F::from_canonical_u32) + .collect_vec() + .try_into() + .unwrap() +} + /// Returns the 32-bit limbs of a `U160`. pub(crate) fn h160_limbs(h160: H160) -> [F; 5] { h160.0 diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index bcf483e4..0bfbc3d4 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -122,6 +122,7 @@ where [(); S::COLUMNS]:, [(); C::Hasher::HASH_SIZE]:, { + validate_proof_shape(&stark, proof, config, ctl_vars.len())?; let StarkOpeningSet { local_values, next_values, @@ -136,7 +137,7 @@ where }; let degree_bits = proof.recover_degree_bits(config); - let (l_1, l_last) = eval_l_1_and_l_last(degree_bits, challenges.stark_zeta); + let (l_0, l_last) = eval_l_0_and_l_last(degree_bits, challenges.stark_zeta); let last = F::primitive_root_of_unity(degree_bits).inverse(); let z_last = challenges.stark_zeta - last.into(); let mut consumer = ConstraintConsumer::::new( @@ -146,7 +147,7 @@ where .map(|&alpha| F::Extension::from_basefield(alpha)) .collect::>(), z_last, - l_1, + l_0, l_last, ); let num_permutation_zs = stark.num_permutation_batches(config); @@ -207,10 +208,61 @@ where Ok(()) } -/// Evaluate the Lagrange polynomials `L_1` and `L_n` at a point `x`. -/// `L_1(x) = (x^n - 1)/(n * (x - 1))` -/// `L_n(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` the first element of the subgroup. -fn eval_l_1_and_l_last(log_n: usize, x: F) -> (F, F) { +fn validate_proof_shape( + stark: &S, + proof: &StarkProof, + config: &StarkConfig, + num_ctl_zs: usize, +) -> anyhow::Result<()> +where + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + [(); S::COLUMNS]:, + [(); C::Hasher::HASH_SIZE]:, +{ + let StarkProof { + trace_cap, + permutation_ctl_zs_cap, + quotient_polys_cap, + openings, + // The shape of the opening proof will be checked in the FRI verifier (see + // validate_fri_proof_shape), so we ignore it here. + opening_proof: _, + } = proof; + + let StarkOpeningSet { + local_values, + next_values, + permutation_ctl_zs, + permutation_ctl_zs_next, + ctl_zs_last, + quotient_polys, + } = openings; + + let degree_bits = proof.recover_degree_bits(config); + let fri_params = config.fri_params(degree_bits); + let cap_height = fri_params.config.cap_height; + let num_zs = num_ctl_zs + stark.num_permutation_batches(config); + + ensure!(trace_cap.height() == cap_height); + ensure!(permutation_ctl_zs_cap.height() == cap_height); + ensure!(quotient_polys_cap.height() == cap_height); + + ensure!(local_values.len() == S::COLUMNS); + ensure!(next_values.len() == S::COLUMNS); + ensure!(permutation_ctl_zs.len() == num_zs); + ensure!(permutation_ctl_zs_next.len() == num_zs); + ensure!(ctl_zs_last.len() == num_ctl_zs); + ensure!(quotient_polys.len() == stark.num_quotient_polys(config)); + + Ok(()) +} + +/// Evaluate the Lagrange polynomials `L_0` and `L_(n-1)` at a point `x`. +/// `L_0(x) = (x^n - 1)/(n * (x - 1))` +/// `L_(n-1)(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` the first element of the subgroup. +fn eval_l_0_and_l_last(log_n: usize, x: F) -> (F, F) { let n = F::from_canonical_usize(1 << log_n); let g = F::primitive_root_of_unity(log_n); let z_x = x.exp_power_of_2(log_n) - F::ONE; @@ -225,10 +277,10 @@ mod tests { use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; - use crate::verifier::eval_l_1_and_l_last; + use crate::verifier::eval_l_0_and_l_last; #[test] - fn test_eval_l_1_and_l_last() { + fn test_eval_l_0_and_l_last() { type F = GoldilocksField; let log_n = 5; let n = 1 << log_n; @@ -237,7 +289,7 @@ mod tests { let expected_l_first_x = PolynomialValues::selector(n, 0).ifft().eval(x); let expected_l_last_x = PolynomialValues::selector(n, n - 1).ifft().eval(x); - let (l_first_x, l_last_x) = eval_l_1_and_l_last(log_n, x); + let (l_first_x, l_last_x) = eval_l_0_and_l_last(log_n, x); assert_eq!(l_first_x, expected_l_first_x); assert_eq!(l_last_x, expected_l_last_x); } diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs new file mode 100644 index 00000000..6e16fa47 --- /dev/null +++ b/evm/tests/empty_txn_list.rs @@ -0,0 +1,81 @@ +use std::collections::HashMap; + +use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::plonk::config::PoseidonGoldilocksConfig; +use plonky2::util::timing::TimingTree; +use plonky2_evm::all_stark::AllStark; +use plonky2_evm::config::StarkConfig; +use plonky2_evm::generation::{GenerationInputs, TrieInputs}; +use plonky2_evm::proof::BlockMetadata; +use plonky2_evm::prover::prove; +use plonky2_evm::verifier::verify_proof; + +type F = GoldilocksField; +const D: usize = 2; +type C = PoseidonGoldilocksConfig; + +/// Execute the empty list of transactions, i.e. a no-op. +#[test] +#[ignore] // TODO: Won't work until witness generation logic is finished. +fn test_empty_txn_list() -> anyhow::Result<()> { + let all_stark = AllStark::::default(); + let config = StarkConfig::standard_fast_config(); + + let block_metadata = BlockMetadata::default(); + + let state_trie = PartialTrie::Leaf { + nibbles: Nibbles { + count: 5, + packed: 0xABCDE.into(), + }, + value: vec![1, 2, 3], + }; + let transactions_trie = PartialTrie::Empty; + let receipts_trie = PartialTrie::Empty; + let storage_tries = vec![]; + + let state_trie_root = state_trie.calc_hash(); + let txns_trie_root = transactions_trie.calc_hash(); + let receipts_trie_root = receipts_trie.calc_hash(); + + let inputs = GenerationInputs { + signed_txns: vec![], + tries: TrieInputs { + state_trie, + transactions_trie, + receipts_trie, + storage_tries, + }, + contract_code: HashMap::new(), + block_metadata, + }; + + let proof = prove::(&all_stark, &config, inputs, &mut TimingTree::default())?; + assert_eq!( + proof.public_values.trie_roots_before.state_root, + state_trie_root + ); + assert_eq!( + proof.public_values.trie_roots_after.state_root, + state_trie_root + ); + assert_eq!( + proof.public_values.trie_roots_before.transactions_root, + txns_trie_root + ); + assert_eq!( + proof.public_values.trie_roots_after.transactions_root, + txns_trie_root + ); + assert_eq!( + proof.public_values.trie_roots_before.receipts_root, + receipts_trie_root + ); + assert_eq!( + proof.public_values.trie_roots_after.receipts_root, + receipts_trie_root + ); + + verify_proof(all_stark, proof, &config) +} diff --git a/evm/tests/transfer_to_new_addr.rs b/evm/tests/transfer_to_new_addr.rs index ecb71076..1c74366e 100644 --- a/evm/tests/transfer_to_new_addr.rs +++ b/evm/tests/transfer_to_new_addr.rs @@ -1,11 +1,13 @@ +use std::collections::HashMap; + +use eth_trie_utils::partial_trie::PartialTrie; use hex_literal::hex; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::plonk::config::PoseidonGoldilocksConfig; use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; -use plonky2_evm::generation::partial_trie::PartialTrie; -use plonky2_evm::generation::GenerationInputs; +use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::BlockMetadata; use plonky2_evm::prover::prove; use plonky2_evm::verifier::verify_proof; @@ -27,10 +29,13 @@ fn test_simple_transfer() -> anyhow::Result<()> { let inputs = GenerationInputs { signed_txns: vec![txn.to_vec()], - state_trie: PartialTrie::Empty, - transactions_trie: PartialTrie::Empty, - receipts_trie: PartialTrie::Empty, - storage_tries: vec![], + tries: TrieInputs { + state_trie: PartialTrie::Empty, + transactions_trie: PartialTrie::Empty, + receipts_trie: PartialTrie::Empty, + storage_tries: vec![], + }, + contract_code: HashMap::new(), block_metadata, }; diff --git a/field/src/extension/mod.rs b/field/src/extension/mod.rs index f54d669c..ed596764 100644 --- a/field/src/extension/mod.rs +++ b/field/src/extension/mod.rs @@ -22,8 +22,8 @@ pub trait OEF: FieldExtension { } impl OEF<1> for F { - const W: Self::BaseField = F::ZERO; - const DTH_ROOT: Self::BaseField = F::ZERO; + const W: Self::BaseField = F::ONE; + const DTH_ROOT: Self::BaseField = F::ONE; } pub trait Frobenius: OEF { @@ -80,8 +80,8 @@ pub trait Extendable: Field + Sized { impl + FieldExtension<1, BaseField = F>> Extendable<1> for F { type Extension = F; - const W: Self = F::ZERO; - const DTH_ROOT: Self = F::ZERO; + const W: Self = F::ONE; + const DTH_ROOT: Self = F::ONE; const EXT_MULTIPLICATIVE_GROUP_GENERATOR: [Self; 1] = [F::MULTIPLICATIVE_GROUP_GENERATOR]; const EXT_POWER_OF_TWO_GENERATOR: [Self; 1] = [F::POWER_OF_TWO_GENERATOR]; } diff --git a/field/src/types.rs b/field/src/types.rs index ac94bcfa..7130b7f5 100644 --- a/field/src/types.rs +++ b/field/src/types.rs @@ -427,6 +427,59 @@ pub trait Field: pub trait PrimeField: Field { fn to_canonical_biguint(&self) -> BigUint; + + fn is_quadratic_residue(&self) -> bool { + if self.is_zero() { + return true; + } + // This is based on Euler's criterion. + let power = Self::NEG_ONE.to_canonical_biguint() / 2u8; + let exp = self.exp_biguint(&power); + if exp == Self::ONE { + return true; + } + if exp == Self::NEG_ONE { + return false; + } + panic!("Unreachable") + } + + fn sqrt(&self) -> Option { + if self.is_zero() { + Some(*self) + } else if self.is_quadratic_residue() { + let t = (Self::order() - BigUint::from(1u32)) + / (BigUint::from(2u32).pow(Self::TWO_ADICITY as u32)); + let mut z = Self::POWER_OF_TWO_GENERATOR; + let mut w = self.exp_biguint(&((t - BigUint::from(1u32)) / BigUint::from(2u32))); + let mut x = w * *self; + let mut b = x * w; + + let mut v = Self::TWO_ADICITY as usize; + + while !b.is_one() { + let mut k = 0usize; + let mut b2k = b; + while !b2k.is_one() { + b2k = b2k * b2k; + k += 1; + } + let j = v - k - 1; + w = z; + for _ in 0..j { + w = w * w; + } + + z = w * w; + b *= z; + x *= w; + v = k; + } + Some(x) + } else { + None + } + } } /// A finite field of order less than 2^64. diff --git a/field/src/zero_poly_coset.rs b/field/src/zero_poly_coset.rs index 18cc3238..8d63bc69 100644 --- a/field/src/zero_poly_coset.rs +++ b/field/src/zero_poly_coset.rs @@ -51,8 +51,8 @@ impl ZeroPolyOnCoset { packed } - /// Returns `L_1(x) = Z_H(x)/(n * (x - 1))` with `x = w^i`. - pub fn eval_l1(&self, i: usize, x: F) -> F { + /// Returns `L_0(x) = Z_H(x)/(n * (x - 1))` with `x = w^i`. + pub fn eval_l_0(&self, i: usize, x: F) -> F { // Could also precompute the inverses using Montgomery. self.eval(i) * (self.n * (x - F::ONE)).inverse() } diff --git a/maybe_rayon/src/lib.rs b/maybe_rayon/src/lib.rs index 1a9bd823..d24ba2e5 100644 --- a/maybe_rayon/src/lib.rs +++ b/maybe_rayon/src/lib.rs @@ -1,6 +1,6 @@ #[cfg(not(feature = "parallel"))] use std::{ - iter::{IntoIterator, Iterator}, + iter::{FlatMap, IntoIterator, Iterator}, slice::{Chunks, ChunksExact, ChunksExactMut, ChunksMut}, }; @@ -223,13 +223,21 @@ impl MaybeParChunksMut for [T] { } } +#[cfg(not(feature = "parallel"))] pub trait ParallelIteratorMock { type Item; fn find_any

(self, predicate: P) -> Option where P: Fn(&Self::Item) -> bool + Sync + Send; + + fn flat_map_iter(self, map_op: F) -> FlatMap + where + Self: Sized, + U: IntoIterator, + F: Fn(Self::Item) -> U; } +#[cfg(not(feature = "parallel"))] impl ParallelIteratorMock for T { type Item = T::Item; @@ -239,6 +247,15 @@ impl ParallelIteratorMock for T { { self.find(predicate) } + + fn flat_map_iter(self, map_op: F) -> FlatMap + where + Self: Sized, + U: IntoIterator, + F: Fn(Self::Item) -> U, + { + self.flat_map(map_op) + } } #[cfg(feature = "parallel")] diff --git a/plonky2/examples/factorial.rs b/plonky2/examples/factorial.rs new file mode 100644 index 00000000..bcdb35dc --- /dev/null +++ b/plonky2/examples/factorial.rs @@ -0,0 +1,43 @@ +use anyhow::Result; +use plonky2::field::types::Field; +use plonky2::iop::witness::{PartialWitness, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + +/// An example of using Plonky2 to prove a statement of the form +/// "I know n * (n + 1) * ... * (n + 99)". +/// When n == 1, this is proving knowledge of 100!. +fn main() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + // The arithmetic circuit. + let initial = builder.add_virtual_target(); + let mut cur_target = initial; + for i in 2..101 { + let i_target = builder.constant(F::from_canonical_u32(i)); + cur_target = builder.mul(cur_target, i_target); + } + + // Public inputs are the initial value (provided below) and the result (which is generated). + builder.register_public_input(initial); + builder.register_public_input(cur_target); + + let mut pw = PartialWitness::new(); + pw.set_target(initial, F::ONE); + + let data = builder.build::(); + let proof = data.prove(pw)?; + + println!( + "Factorial starting at {} is {}!", + proof.public_inputs[0], proof.public_inputs[1] + ); + + data.verify(proof) +} diff --git a/plonky2/examples/fibonacci.rs b/plonky2/examples/fibonacci.rs new file mode 100644 index 00000000..6609fc1d --- /dev/null +++ b/plonky2/examples/fibonacci.rs @@ -0,0 +1,49 @@ +use anyhow::Result; +use plonky2::field::types::Field; +use plonky2::iop::witness::{PartialWitness, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + +/// An example of using Plonky2 to prove a statement of the form +/// "I know the 100th element of the Fibonacci sequence, starting with constants a and b." +/// When a == 0 and b == 1, this is proving knowledge of the 100th (standard) Fibonacci number. +fn main() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + // The arithmetic circuit. + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + let mut prev_target = initial_a; + let mut cur_target = initial_b; + for _ in 0..99 { + let temp = builder.add(prev_target, cur_target); + prev_target = cur_target; + cur_target = temp; + } + + // Public inputs are the two initial values (provided below) and the result (which is generated). + builder.register_public_input(initial_a); + builder.register_public_input(initial_b); + builder.register_public_input(cur_target); + + // Provide initial values. + let mut pw = PartialWitness::new(); + pw.set_target(initial_a, F::ZERO); + pw.set_target(initial_b, F::ONE); + + let data = builder.build::(); + let proof = data.prove(pw)?; + + println!( + "100th Fibonacci number mod |F| (starting with {}, {}) is: {}", + proof.public_inputs[0], proof.public_inputs[1], proof.public_inputs[2] + ); + + data.verify(proof) +} diff --git a/plonky2/examples/square_root.rs b/plonky2/examples/square_root.rs new file mode 100644 index 00000000..0bc89f47 --- /dev/null +++ b/plonky2/examples/square_root.rs @@ -0,0 +1,81 @@ +use std::marker::PhantomData; + +use anyhow::Result; +use plonky2::field::types::{Field, PrimeField}; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::Target; +use plonky2::iop::witness::{PartialWitness, PartitionWitness, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; +use plonky2_field::extension::Extendable; + +/// A generator used by the prover to calculate the square root (`x`) of a given value +/// (`x_squared`), outside of the circuit, in order to supply it as an additional public input. +#[derive(Debug)] +struct SquareRootGenerator, const D: usize> { + x: Target, + x_squared: Target, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for SquareRootGenerator +{ + fn dependencies(&self) -> Vec { + vec![self.x_squared] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x_squared = witness.get_target(self.x_squared); + let x = x_squared.sqrt().unwrap(); + + println!("Square root: {}", x); + + out_buffer.set_target(self.x, x); + } +} + +/// An example of using Plonky2 to prove a statement of the form +/// "I know the square root of this field element." +fn main() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_recursion_config(); + + let mut builder = CircuitBuilder::::new(config); + + let x = builder.add_virtual_target(); + let x_squared = builder.square(x); + + builder.register_public_input(x_squared); + + builder.add_simple_generator(SquareRootGenerator:: { + x, + x_squared, + _phantom: PhantomData, + }); + + // Randomly generate the value of x^2: any quadratic residue in the field works. + let x_squared_value = { + let mut val = F::rand(); + while !val.is_quadratic_residue() { + val = F::rand(); + } + val + }; + + let mut pw = PartialWitness::new(); + pw.set_target(x_squared, x_squared_value); + + let data = builder.build::(); + let proof = data.prove(pw.clone())?; + + let x_squared_actual = proof.public_inputs[0]; + println!("Field element (square): {}", x_squared_actual); + + data.verify(proof) +} diff --git a/plonky2/plonky2.pdf b/plonky2/plonky2.pdf index ad0cef02..8f0f9ece 100644 Binary files a/plonky2/plonky2.pdf and b/plonky2/plonky2.pdf differ diff --git a/plonky2/src/fri/mod.rs b/plonky2/src/fri/mod.rs index def33a73..87c4c2aa 100644 --- a/plonky2/src/fri/mod.rs +++ b/plonky2/src/fri/mod.rs @@ -7,6 +7,7 @@ pub mod prover; pub mod recursive_verifier; pub mod reduction_strategies; pub mod structure; +mod validate_shape; pub mod verifier; pub mod witness_util; diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index 1f5b648f..75f8847a 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -180,7 +180,15 @@ impl, C: GenericConfig, const D: usize> // Final low-degree polynomial that goes into FRI. let mut final_poly = PolynomialCoeffs::empty(); + // Each batch `i` consists of an opening point `z_i` and polynomials `{f_ij}_j` to be opened at that point. + // For each batch, we compute the composition polynomial `F_i = sum alpha^j f_ij`, + // where `alpha` is a random challenge in the extension field. + // The final polynomial is then computed as `final_poly = sum_i alpha^(k_i) (F_i(X) - F_i(z_i))/(X-z_i)` + // where the `k_i`s are chosen such that each power of `alpha` appears only once in the final sum. + // There are usually two batches for the openings at `zeta` and `g * zeta`. + // The oracles used in Plonky2 are given in `FRI_ORACLES` in `plonky2/src/plonk/plonk_common.rs`. for FriBatchInfo { point, polynomials } in &instance.batches { + // Collect the coefficients of all the polynomials in `polynomials`. let polys_coeff = polynomials.iter().map(|fri_poly| { &oracles[fri_poly.oracle_index].polynomials[fri_poly.polynomial_index] }); diff --git a/plonky2/src/fri/prover.rs b/plonky2/src/fri/prover.rs index 39e25869..71efe98a 100644 --- a/plonky2/src/fri/prover.rs +++ b/plonky2/src/fri/prover.rs @@ -149,15 +149,12 @@ fn fri_prover_query_rounds< n: usize, fri_params: &FriParams, ) -> Vec> { - (0..fri_params.config.num_query_rounds) - .map(|_| { - fri_prover_query_round::( - initial_merkle_trees, - trees, - challenger, - n, - fri_params, - ) + challenger + .get_n_challenges(fri_params.config.num_query_rounds) + .into_par_iter() + .map(|rand| { + let x_index = rand.to_canonical_u64() as usize % n; + fri_prover_query_round::(initial_merkle_trees, trees, x_index, fri_params) }) .collect() } @@ -169,13 +166,10 @@ fn fri_prover_query_round< >( initial_merkle_trees: &[&MerkleTree], trees: &[MerkleTree], - challenger: &mut Challenger, - n: usize, + mut x_index: usize, fri_params: &FriParams, ) -> FriQueryRound { let mut query_steps = Vec::new(); - let x = challenger.get_challenge(); - let mut x_index = x.to_canonical_u64() as usize % n; let initial_proof = initial_merkle_trees .iter() .map(|t| (t.get(x_index).to_vec(), t.prove(x_index))) diff --git a/plonky2/src/fri/recursive_verifier.rs b/plonky2/src/fri/recursive_verifier.rs index 1a3739b4..ac7e3a87 100644 --- a/plonky2/src/fri/recursive_verifier.rs +++ b/plonky2/src/fri/recursive_verifier.rs @@ -8,9 +8,9 @@ use crate::fri::proof::{ }; use crate::fri::structure::{FriBatchInfoTarget, FriInstanceInfoTarget, FriOpeningsTarget}; use crate::fri::{FriConfig, FriParams}; -use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; -use crate::gates::interpolation::HighDegreeInterpolationGate; +use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; +use crate::gates::interpolation::InterpolationGate; use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::gates::random_access::RandomAccessGate; use crate::hash::hash_types::MerkleCapTarget; diff --git a/plonky2/src/fri/structure.rs b/plonky2/src/fri/structure.rs index 1a37a1b2..0d64ae20 100644 --- a/plonky2/src/fri/structure.rs +++ b/plonky2/src/fri/structure.rs @@ -25,6 +25,7 @@ pub struct FriInstanceInfoTarget { #[derive(Copy, Clone)] pub struct FriOracleInfo { + pub num_polys: usize, pub blinding: bool, } @@ -42,7 +43,7 @@ pub struct FriBatchInfoTarget { #[derive(Copy, Clone, Debug)] pub struct FriPolynomialInfo { - /// Index into `FriInstanceInfoTarget`'s `oracles` list. + /// Index into `FriInstanceInfo`'s `oracles` list. pub oracle_index: usize, /// Index of the polynomial within the oracle. pub polynomial_index: usize, diff --git a/plonky2/src/fri/validate_shape.rs b/plonky2/src/fri/validate_shape.rs new file mode 100644 index 00000000..0ef85c4c --- /dev/null +++ b/plonky2/src/fri/validate_shape.rs @@ -0,0 +1,67 @@ +use anyhow::ensure; +use plonky2_field::extension::Extendable; + +use crate::fri::proof::{FriProof, FriQueryRound, FriQueryStep}; +use crate::fri::structure::FriInstanceInfo; +use crate::fri::FriParams; +use crate::hash::hash_types::RichField; +use crate::plonk::config::GenericConfig; +use crate::plonk::plonk_common::salt_size; + +pub(crate) fn validate_fri_proof_shape( + proof: &FriProof, + instance: &FriInstanceInfo, + params: &FriParams, +) -> anyhow::Result<()> +where + F: RichField + Extendable, + C: GenericConfig, +{ + let FriProof { + commit_phase_merkle_caps, + query_round_proofs, + final_poly, + pow_witness: _pow_witness, + } = proof; + + let cap_height = params.config.cap_height; + for cap in commit_phase_merkle_caps { + ensure!(cap.height() == cap_height); + } + + for query_round in query_round_proofs { + let FriQueryRound { + initial_trees_proof, + steps, + } = query_round; + + ensure!(initial_trees_proof.evals_proofs.len() == instance.oracles.len()); + for ((leaf, merkle_proof), oracle) in initial_trees_proof + .evals_proofs + .iter() + .zip(&instance.oracles) + { + ensure!(leaf.len() == oracle.num_polys + salt_size(oracle.blinding && params.hiding)); + ensure!(merkle_proof.len() + cap_height == params.lde_bits()); + } + + ensure!(steps.len() == params.reduction_arity_bits.len()); + let mut codeword_len_bits = params.lde_bits(); + for (step, arity_bits) in steps.iter().zip(¶ms.reduction_arity_bits) { + let FriQueryStep { + evals, + merkle_proof, + } = step; + + let arity = 1 << arity_bits; + codeword_len_bits -= arity_bits; + + ensure!(evals.len() == arity); + ensure!(merkle_proof.len() + cap_height == codeword_len_bits); + } + } + + ensure!(final_poly.len() == params.final_poly_len()); + + Ok(()) +} diff --git a/plonky2/src/fri/verifier.rs b/plonky2/src/fri/verifier.rs index ed44f0c4..02816000 100644 --- a/plonky2/src/fri/verifier.rs +++ b/plonky2/src/fri/verifier.rs @@ -6,6 +6,7 @@ use plonky2_util::{log2_strict, reverse_index_bits_in_place}; use crate::fri::proof::{FriChallenges, FriInitialTreeProof, FriProof, FriQueryRound}; use crate::fri::structure::{FriBatchInfo, FriInstanceInfo, FriOpenings}; +use crate::fri::validate_shape::validate_fri_proof_shape; use crate::fri::{FriConfig, FriParams}; use crate::hash::hash_types::RichField; use crate::hash::merkle_proofs::verify_merkle_proof_to_cap; @@ -67,10 +68,7 @@ pub fn verify_fri_proof, C: GenericConfig where [(); C::Hasher::HASH_SIZE]:, { - ensure!( - params.final_poly_len() == proof.final_poly.len(), - "Final polynomial has wrong degree." - ); + validate_fri_proof_shape::(proof, instance, params)?; // Size of the LDE domain. let n = params.lde_size(); diff --git a/plonky2/src/gadgets/interpolation.rs b/plonky2/src/gadgets/interpolation.rs deleted file mode 100644 index b22f3b59..00000000 --- a/plonky2/src/gadgets/interpolation.rs +++ /dev/null @@ -1,178 +0,0 @@ -use std::ops::Range; - -use plonky2_field::extension::Extendable; - -use crate::gates::gate::Gate; -use crate::hash::hash_types::RichField; -use crate::iop::ext_target::ExtensionTarget; -use crate::iop::target::Target; -use crate::plonk::circuit_builder::CircuitBuilder; - -/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup -/// with the given size, and whose values are extension field elements, given by input wires. -/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. -pub(crate) trait InterpolationGate, const D: usize>: - Gate + Copy -{ - fn new(subgroup_bits: usize) -> Self; - - fn num_points(&self) -> usize; - - /// Wire index of the coset shift. - fn wire_shift(&self) -> usize { - 0 - } - - fn start_values(&self) -> usize { - 1 - } - - /// Wire indices of the `i`th interpolant value. - fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_values() + i * D; - start..start + D - } - - fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points() * D - } - - /// Wire indices of the point to evaluate the interpolant at. - fn wires_evaluation_point(&self) -> Range { - let start = self.start_evaluation_point(); - start..start + D - } - - fn start_evaluation_value(&self) -> usize { - self.start_evaluation_point() + D - } - - /// Wire indices of the interpolated value. - fn wires_evaluation_value(&self) -> Range { - let start = self.start_evaluation_value(); - start..start + D - } - - fn start_coeffs(&self) -> usize { - self.start_evaluation_value() + D - } - - /// The number of routed wires required in the typical usage of this gate, where the points to - /// interpolate, the evaluation point, and the corresponding value are all routed. - fn num_routed_wires(&self) -> usize { - self.start_coeffs() - } - - /// Wire indices of the interpolant's `i`th coefficient. - fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_coeffs() + i * D; - start..start + D - } - - fn end_coeffs(&self) -> usize { - self.start_coeffs() + D * self.num_points() - } -} - -impl, const D: usize> CircuitBuilder { - /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the - /// given size, and whose values are given. Returns the evaluation of the interpolant at - /// `evaluation_point`. - pub(crate) fn interpolate_coset>( - &mut self, - subgroup_bits: usize, - coset_shift: Target, - values: &[ExtensionTarget], - evaluation_point: ExtensionTarget, - ) -> ExtensionTarget { - let gate = G::new(subgroup_bits); - let row = self.add_gate(gate, vec![]); - self.connect(coset_shift, Target::wire(row, gate.wire_shift())); - for (i, &v) in values.iter().enumerate() { - self.connect_extension(v, ExtensionTarget::from_range(row, gate.wires_value(i))); - } - self.connect_extension( - evaluation_point, - ExtensionTarget::from_range(row, gate.wires_evaluation_point()), - ); - - ExtensionTarget::from_range(row, gate.wires_evaluation_value()) - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2_field::extension::FieldExtension; - use plonky2_field::interpolation::interpolant; - use plonky2_field::types::Field; - - use crate::gates::interpolation::HighDegreeInterpolationGate; - use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - fn test_interpolate() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let subgroup_bits = 2; - let len = 1 << subgroup_bits; - let coset_shift = F::rand(); - let g = F::primitive_root_of_unity(subgroup_bits); - let points = F::cyclic_subgroup_coset_known_order(g, coset_shift, len); - let values = FF::rand_vec(len); - - let homogeneous_points = points - .iter() - .zip(values.iter()) - .map(|(&a, &b)| (>::from_basefield(a), b)) - .collect::>(); - - let true_interpolant = interpolant(&homogeneous_points); - - let z = FF::rand(); - let true_eval = true_interpolant.eval(z); - - let coset_shift_target = builder.constant(coset_shift); - - let value_targets = values - .iter() - .map(|&v| (builder.constant_extension(v))) - .collect::>(); - - let zt = builder.constant_extension(z); - - let eval_hd = builder.interpolate_coset::>( - subgroup_bits, - coset_shift_target, - &value_targets, - zt, - ); - let eval_ld = builder.interpolate_coset::>( - subgroup_bits, - coset_shift_target, - &value_targets, - zt, - ); - let true_eval_target = builder.constant_extension(true_eval); - builder.connect_extension(eval_hd, true_eval_target); - builder.connect_extension(eval_ld, true_eval_target); - - let data = builder.build::(); - let proof = data.prove(pw)?; - - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index 6309eb3d..a3e50c4e 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -1,7 +1,6 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod hash; -pub mod interpolation; pub mod polynomial; pub mod random_access; pub mod range_check; diff --git a/plonky2/src/gates/base_sum.rs b/plonky2/src/gates/base_sum.rs index a7f4fcb3..5be54eeb 100644 --- a/plonky2/src/gates/base_sum.rs +++ b/plonky2/src/gates/base_sum.rs @@ -3,6 +3,7 @@ use std::ops::Range; use plonky2_field::extension::Extendable; use plonky2_field::packed::PackedField; use plonky2_field::types::{Field, Field64}; +use plonky2_util::log_floor; use crate::gates::gate::Gate; use crate::gates::packed_util::PackedEvaluableBase; @@ -32,7 +33,8 @@ impl BaseSumGate { } pub fn new_from_config(config: &CircuitConfig) -> Self { - let num_limbs = F::BITS.min(config.num_routed_wires - Self::START_LIMBS); + let num_limbs = + log_floor(F::ORDER - 1, B as u64).min(config.num_routed_wires - Self::START_LIMBS); Self::new(num_limbs) } diff --git a/plonky2/src/gates/high_degree_interpolation.rs b/plonky2/src/gates/high_degree_interpolation.rs new file mode 100644 index 00000000..bcdf2276 --- /dev/null +++ b/plonky2/src/gates/high_degree_interpolation.rs @@ -0,0 +1,363 @@ +use std::marker::PhantomData; +use std::ops::Range; + +use plonky2_field::extension::algebra::PolynomialCoeffsAlgebra; +use plonky2_field::extension::{Extendable, FieldExtension}; +use plonky2_field::interpolation::interpolant; +use plonky2_field::polynomial::PolynomialCoeffs; + +use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; +use crate::gates::gate::Gate; +use crate::gates::interpolation::InterpolationGate; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// One of the instantiations of `InterpolationGate`: allows constraints of variable +/// degree, up to `1<, const D: usize> { + pub subgroup_bits: usize, + _phantom: PhantomData, +} + +impl, const D: usize> InterpolationGate + for HighDegreeInterpolationGate +{ + fn new(subgroup_bits: usize) -> Self { + Self { + subgroup_bits, + _phantom: PhantomData, + } + } + + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } +} + +impl, const D: usize> HighDegreeInterpolationGate { + /// End of wire indices, exclusive. + fn end(&self) -> usize { + self.start_coeffs() + self.num_points() * D + } + + /// The domain of the points we're interpolating. + fn coset(&self, shift: F) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. + g.powers().take(size).map(move |x| x * shift) + } + + /// The domain of the points we're interpolating. + fn coset_ext(&self, shift: F::Extension) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers().take(size).map(move |x| shift.scalar_mul(x)) + } + + /// The domain of the points we're interpolating. + fn coset_ext_circuit( + &self, + builder: &mut CircuitBuilder, + shift: ExtensionTarget, + ) -> Vec> { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers() + .take(size) + .map(move |x| { + let subgroup_element = builder.constant(x); + builder.scalar_mul_ext(subgroup_element, shift) + }) + .collect() + } +} + +impl, const D: usize> Gate + for HighDegreeInterpolationGate +{ + fn id(&self) -> String { + format!("{:?}", self, D) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) + .collect(); + let interpolant = PolynomialCoeffsAlgebra::new(coeffs); + + let coset = self.coset_ext(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { + let value = vars.get_local_ext_algebra(self.wires_value(i)); + let computed_value = interpolant.eval_base(point); + constraints.extend((value - computed_value).to_basefield_array()); + } + + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval(evaluation_point); + constraints.extend((evaluation_value - computed_evaluation_value).to_basefield_array()); + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext(self.wires_coeff(i))) + .collect(); + let interpolant = PolynomialCoeffs::new(coeffs); + + let coset = self.coset(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { + let value = vars.get_local_ext(self.wires_value(i)); + let computed_value = interpolant.eval_base(point); + yield_constr.many((value - computed_value).to_basefield_array()); + } + + let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval(evaluation_point); + yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array()); + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) + .collect(); + let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); + + let coset = self.coset_ext_circuit(builder, vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { + let value = vars.get_local_ext_algebra(self.wires_value(i)); + let computed_value = interpolant.eval_scalar(builder, point); + constraints.extend( + builder + .sub_ext_algebra(value, computed_value) + .to_ext_target_array(), + ); + } + + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval(builder, evaluation_point); + constraints.extend( + builder + .sub_ext_algebra(evaluation_value, computed_evaluation_value) + .to_ext_target_array(), + ); + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + let gen = InterpolationGenerator:: { + row, + gate: *self, + _phantom: PhantomData, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + // The highest power of x is `num_points - 1`, and then multiplication by the coefficient + // adds 1. + self.num_points() + } + + fn num_constraints(&self) -> usize { + // num_points * D constraints to check for consistency between the coefficients and the + // point-value pairs, plus D constraints for the evaluation value. + self.num_points() * D + D + } +} + +#[derive(Debug)] +struct InterpolationGenerator, const D: usize> { + row: usize, + gate: HighDegreeInterpolationGate, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for InterpolationGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |column| { + Target::Wire(Wire { + row: self.row, + column, + }) + }; + + let local_targets = |columns: Range| columns.map(local_target); + + let num_points = self.gate.num_points(); + let mut deps = Vec::with_capacity(1 + D + num_points * D); + + deps.push(local_target(self.gate.wire_shift())); + deps.extend(local_targets(self.gate.wires_evaluation_point())); + for i in 0..num_points { + deps.extend(local_targets(self.gate.wires_value(i))); + } + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let get_local_wire = |column| witness.get_wire(local_wire(column)); + + let get_local_ext = |wire_range: Range| { + debug_assert_eq!(wire_range.len(), D); + let values = wire_range.map(get_local_wire).collect::>(); + let arr = values.try_into().unwrap(); + F::Extension::from_basefield_array(arr) + }; + + // Compute the interpolant. + let points = self.gate.coset(get_local_wire(self.gate.wire_shift())); + let points = points + .into_iter() + .enumerate() + .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) + .collect::>(); + let interpolant = interpolant(&points); + + for (i, &coeff) in interpolant.coeffs.iter().enumerate() { + let wires = self.gate.wires_coeff(i).map(local_wire); + out_buffer.set_ext_wires(wires, coeff); + } + + let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); + let evaluation_value = interpolant.eval(evaluation_point); + let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); + out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use plonky2_field::goldilocks_field::GoldilocksField; + use plonky2_field::polynomial::PolynomialCoeffs; + use plonky2_field::types::Field; + + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; + use crate::gates::interpolation::InterpolationGate; + use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn wire_indices() { + let gate = HighDegreeInterpolationGate:: { + subgroup_bits: 1, + _phantom: PhantomData, + }; + + // The exact indices aren't really important, but we want to make sure we don't have any + // overlaps or gaps. + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_evaluation_point(), 9..13); + assert_eq!(gate.wires_evaluation_value(), 13..17); + assert_eq!(gate.wires_coeff(0), 17..21); + assert_eq!(gate.wires_coeff(1), 21..25); + assert_eq!(gate.num_wires(), 25); + } + + #[test] + fn low_degree() { + test_low_degree::(HighDegreeInterpolationGate::new(2)); + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(HighDegreeInterpolationGate::new(2)) + } + + #[test] + fn test_gate_constraint() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; + + /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. + fn get_wires( + gate: &HighDegreeInterpolationGate, + shift: F, + coeffs: PolynomialCoeffs, + eval_point: FF, + ) -> Vec { + let points = gate.coset(shift); + let mut v = vec![shift]; + for x in points { + v.extend(coeffs.eval(x.into()).0); + } + v.extend(eval_point.0); + v.extend(coeffs.eval(eval_point).0); + for i in 0..coeffs.len() { + v.extend(coeffs.coeffs[i].0); + } + v.iter().map(|&x| x.into()).collect() + } + + // Get a working row for InterpolationGate. + let shift = F::rand(); + let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); + let eval_point = FF::rand(); + let gate = HighDegreeInterpolationGate::::new(1); + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(&gate, shift, coeffs, eval_point), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/plonky2/src/gates/interpolation.rs b/plonky2/src/gates/interpolation.rs index a619d1f2..d417fa6b 100644 --- a/plonky2/src/gates/interpolation.rs +++ b/plonky2/src/gates/interpolation.rs @@ -1,361 +1,178 @@ -use std::marker::PhantomData; use std::ops::Range; -use plonky2_field::extension::algebra::PolynomialCoeffsAlgebra; -use plonky2_field::extension::{Extendable, FieldExtension}; -use plonky2_field::interpolation::interpolant; -use plonky2_field::polynomial::PolynomialCoeffs; +use plonky2_field::extension::Extendable; -use crate::gadgets::interpolation::InterpolationGate; -use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; -use crate::gates::util::StridedConstraintConsumer; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; -use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; -use crate::iop::wire::Wire; -use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Interpolation gate with constraints of degree at most `1<, const D: usize> { - pub subgroup_bits: usize, - _phantom: PhantomData, -} - -impl, const D: usize> InterpolationGate - for HighDegreeInterpolationGate +/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup +/// with the given size, and whose values are extension field elements, given by input wires. +/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. +pub(crate) trait InterpolationGate, const D: usize>: + Gate + Copy { - fn new(subgroup_bits: usize) -> Self { - Self { - subgroup_bits, - _phantom: PhantomData, - } - } + fn new(subgroup_bits: usize) -> Self; - fn num_points(&self) -> usize { - 1 << self.subgroup_bits - } -} + fn num_points(&self) -> usize; -impl, const D: usize> HighDegreeInterpolationGate { - /// End of wire indices, exclusive. - fn end(&self) -> usize { - self.start_coeffs() + self.num_points() * D - } - - /// The domain of the points we're interpolating. - fn coset(&self, shift: F) -> impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. - g.powers().take(size).map(move |x| x * shift) - } - - /// The domain of the points we're interpolating. - fn coset_ext(&self, shift: F::Extension) -> impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - g.powers().take(size).map(move |x| shift.scalar_mul(x)) - } - - /// The domain of the points we're interpolating. - fn coset_ext_circuit( - &self, - builder: &mut CircuitBuilder, - shift: ExtensionTarget, - ) -> Vec> { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - g.powers() - .take(size) - .map(move |x| { - let subgroup_element = builder.constant(x); - builder.scalar_mul_ext(subgroup_element, shift) - }) - .collect() - } -} - -impl, const D: usize> Gate - for HighDegreeInterpolationGate -{ - fn id(&self) -> String { - format!("{:?}", self, D) - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffsAlgebra::new(coeffs); - - let coset = self.coset_ext(vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { - let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = interpolant.eval_base(point); - constraints.extend((value - computed_value).to_basefield_array()); - } - - let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(evaluation_point); - constraints.extend((evaluation_value - computed_evaluation_value).to_basefield_array()); - - constraints - } - - fn eval_unfiltered_base_one( - &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, - ) { - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffs::new(coeffs); - - let coset = self.coset(vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { - let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = interpolant.eval_base(point); - yield_constr.many((value - computed_value).to_basefield_array()); - } - - let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(evaluation_point); - yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array()); - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); - - let coset = self.coset_ext_circuit(builder, vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { - let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = interpolant.eval_scalar(builder, point); - constraints.extend( - builder - .sub_ext_algebra(value, computed_value) - .to_ext_target_array(), - ); - } - - let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(builder, evaluation_point); - constraints.extend( - builder - .sub_ext_algebra(evaluation_value, computed_evaluation_value) - .to_ext_target_array(), - ); - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - let gen = InterpolationGenerator:: { - row, - gate: *self, - _phantom: PhantomData, - }; - vec![Box::new(gen.adapter())] - } - - fn num_wires(&self) -> usize { - self.end() - } - - fn num_constants(&self) -> usize { + /// Wire index of the coset shift. + fn wire_shift(&self) -> usize { 0 } - fn degree(&self) -> usize { - // The highest power of x is `num_points - 1`, and then multiplication by the coefficient - // adds 1. - self.num_points() + fn start_values(&self) -> usize { + 1 } - fn num_constraints(&self) -> usize { - // num_points * D constraints to check for consistency between the coefficients and the - // point-value pairs, plus D constraints for the evaluation value. - self.num_points() * D + D + /// Wire indices of the `i`th interpolant value. + fn wires_value(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_values() + i * D; + start..start + D + } + + fn start_evaluation_point(&self) -> usize { + self.start_values() + self.num_points() * D + } + + /// Wire indices of the point to evaluate the interpolant at. + fn wires_evaluation_point(&self) -> Range { + let start = self.start_evaluation_point(); + start..start + D + } + + fn start_evaluation_value(&self) -> usize { + self.start_evaluation_point() + D + } + + /// Wire indices of the interpolated value. + fn wires_evaluation_value(&self) -> Range { + let start = self.start_evaluation_value(); + start..start + D + } + + fn start_coeffs(&self) -> usize { + self.start_evaluation_value() + D + } + + /// The number of routed wires required in the typical usage of this gate, where the points to + /// interpolate, the evaluation point, and the corresponding value are all routed. + fn num_routed_wires(&self) -> usize { + self.start_coeffs() + } + + /// Wire indices of the interpolant's `i`th coefficient. + fn wires_coeff(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_coeffs() + i * D; + start..start + D + } + + fn end_coeffs(&self) -> usize { + self.start_coeffs() + D * self.num_points() } } -#[derive(Debug)] -struct InterpolationGenerator, const D: usize> { - row: usize, - gate: HighDegreeInterpolationGate, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for InterpolationGenerator -{ - fn dependencies(&self) -> Vec { - let local_target = |column| { - Target::Wire(Wire { - row: self.row, - column, - }) - }; - - let local_targets = |columns: Range| columns.map(local_target); - - let num_points = self.gate.num_points(); - let mut deps = Vec::with_capacity(1 + D + num_points * D); - - deps.push(local_target(self.gate.wire_shift())); - deps.extend(local_targets(self.gate.wires_evaluation_point())); - for i in 0..num_points { - deps.extend(local_targets(self.gate.wires_value(i))); +impl, const D: usize> CircuitBuilder { + /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the + /// given size, and whose values are given. Returns the evaluation of the interpolant at + /// `evaluation_point`. + pub(crate) fn interpolate_coset>( + &mut self, + subgroup_bits: usize, + coset_shift: Target, + values: &[ExtensionTarget], + evaluation_point: ExtensionTarget, + ) -> ExtensionTarget { + let gate = G::new(subgroup_bits); + let row = self.add_gate(gate, vec![]); + self.connect(coset_shift, Target::wire(row, gate.wire_shift())); + for (i, &v) in values.iter().enumerate() { + self.connect_extension(v, ExtensionTarget::from_range(row, gate.wires_value(i))); } - deps - } + self.connect_extension( + evaluation_point, + ExtensionTarget::from_range(row, gate.wires_evaluation_point()), + ); - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let get_local_ext = |wire_range: Range| { - debug_assert_eq!(wire_range.len(), D); - let values = wire_range.map(get_local_wire).collect::>(); - let arr = values.try_into().unwrap(); - F::Extension::from_basefield_array(arr) - }; - - // Compute the interpolant. - let points = self.gate.coset(get_local_wire(self.gate.wire_shift())); - let points = points - .into_iter() - .enumerate() - .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) - .collect::>(); - let interpolant = interpolant(&points); - - for (i, &coeff) in interpolant.coeffs.iter().enumerate() { - let wires = self.gate.wires_coeff(i).map(local_wire); - out_buffer.set_ext_wires(wires, coeff); - } - - let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); - let evaluation_value = interpolant.eval(evaluation_point); - let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); - out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); + ExtensionTarget::from_range(row, gate.wires_evaluation_value()) } } #[cfg(test)] mod tests { - use std::marker::PhantomData; - use anyhow::Result; - use plonky2_field::goldilocks_field::GoldilocksField; - use plonky2_field::polynomial::PolynomialCoeffs; + use plonky2_field::extension::FieldExtension; + use plonky2_field::interpolation::interpolant; use plonky2_field::types::Field; - use crate::gadgets::interpolation::InterpolationGate; - use crate::gates::gate::Gate; - use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::interpolation::HighDegreeInterpolationGate; - use crate::hash::hash_types::HashOut; + use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; + use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::vars::EvaluationVars; + use crate::plonk::verifier::verify; #[test] - fn wire_indices() { - let gate = HighDegreeInterpolationGate:: { - subgroup_bits: 1, - _phantom: PhantomData, - }; - - // The exact indices aren't really important, but we want to make sure we don't have any - // overlaps or gaps. - assert_eq!(gate.wire_shift(), 0); - assert_eq!(gate.wires_value(0), 1..5); - assert_eq!(gate.wires_value(1), 5..9); - assert_eq!(gate.wires_evaluation_point(), 9..13); - assert_eq!(gate.wires_evaluation_value(), 13..17); - assert_eq!(gate.wires_coeff(0), 17..21); - assert_eq!(gate.wires_coeff(1), 21..25); - assert_eq!(gate.num_wires(), 25); - } - - #[test] - fn low_degree() { - test_low_degree::(HighDegreeInterpolationGate::new(2)); - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(HighDegreeInterpolationGate::new(2)) - } - - #[test] - fn test_gate_constraint() { + fn test_interpolate() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; type FF = >::FE; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); - /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. - fn get_wires( - gate: &HighDegreeInterpolationGate, - shift: F, - coeffs: PolynomialCoeffs, - eval_point: FF, - ) -> Vec { - let points = gate.coset(shift); - let mut v = vec![shift]; - for x in points { - v.extend(coeffs.eval(x.into()).0); - } - v.extend(eval_point.0); - v.extend(coeffs.eval(eval_point).0); - for i in 0..coeffs.len() { - v.extend(coeffs.coeffs[i].0); - } - v.iter().map(|&x| x.into()).collect() - } + let subgroup_bits = 2; + let len = 1 << subgroup_bits; + let coset_shift = F::rand(); + let g = F::primitive_root_of_unity(subgroup_bits); + let points = F::cyclic_subgroup_coset_known_order(g, coset_shift, len); + let values = FF::rand_vec(len); - // Get a working row for InterpolationGate. - let shift = F::rand(); - let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); - let eval_point = FF::rand(); - let gate = HighDegreeInterpolationGate::::new(1); - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(&gate, shift, coeffs, eval_point), - public_inputs_hash: &HashOut::rand(), - }; + let homogeneous_points = points + .iter() + .zip(values.iter()) + .map(|(&a, &b)| (>::from_basefield(a), b)) + .collect::>(); - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." + let true_interpolant = interpolant(&homogeneous_points); + + let z = FF::rand(); + let true_eval = true_interpolant.eval(z); + + let coset_shift_target = builder.constant(coset_shift); + + let value_targets = values + .iter() + .map(|&v| (builder.constant_extension(v))) + .collect::>(); + + let zt = builder.constant_extension(z); + + let eval_hd = builder.interpolate_coset::>( + subgroup_bits, + coset_shift_target, + &value_targets, + zt, ); + let eval_ld = builder.interpolate_coset::>( + subgroup_bits, + coset_shift_target, + &value_targets, + zt, + ); + let true_eval_target = builder.constant_extension(true_eval); + builder.connect_extension(eval_hd, true_eval_target); + builder.connect_extension(eval_ld, true_eval_target); + + let data = builder.build::(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) } } diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs index dabadfa4..3edc4175 100644 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ b/plonky2/src/gates/low_degree_interpolation.rs @@ -7,9 +7,9 @@ use plonky2_field::interpolation::interpolant; use plonky2_field::polynomial::PolynomialCoeffs; use plonky2_field::types::Field; -use crate::gadgets::interpolation::InterpolationGate; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; +use crate::gates::interpolation::InterpolationGate; use crate::gates::util::StridedConstraintConsumer; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; @@ -20,8 +20,9 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Interpolation gate with constraints of degree 2. -/// `eval_unfiltered_recursively` uses more gates than `HighDegreeInterpolationGate`. +/// One of the instantiations of `InterpolationGate`: all constraints are degree <= 2. +/// The lower degree is a tradeoff for more gates (`eval_unfiltered_recursively` for +/// this version uses more gates than `LowDegreeInterpolationGate`). #[derive(Copy, Clone, Debug)] pub struct LowDegreeInterpolationGate, const D: usize> { pub subgroup_bits: usize, @@ -387,9 +388,9 @@ mod tests { use plonky2_field::polynomial::PolynomialCoeffs; use plonky2_field::types::Field; - use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::interpolation::InterpolationGate; use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::hash::hash_types::HashOut; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index 48e319ef..1d2fc058 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -7,6 +7,7 @@ pub mod base_sum; pub mod constant; pub mod exponentiation; pub mod gate; +pub mod high_degree_interpolation; pub mod interpolation; pub mod low_degree_interpolation; pub mod multiplication_extension; diff --git a/plonky2/src/gates/random_access.rs b/plonky2/src/gates/random_access.rs index 2df392bc..fa365f16 100644 --- a/plonky2/src/gates/random_access.rs +++ b/plonky2/src/gates/random_access.rs @@ -24,9 +24,15 @@ use crate::plonk::vars::{ /// A gate for checking that a particular element of a list matches a given value. #[derive(Copy, Clone, Debug)] pub struct RandomAccessGate, const D: usize> { + /// Number of bits in the index (log2 of the list size). pub bits: usize, + + /// How many separate copies are packed into one gate. pub num_copies: usize, + + /// Leftover wires are used as global scratch space to store constants. pub num_extra_constants: usize, + _phantom: PhantomData, } @@ -41,13 +47,18 @@ impl, const D: usize> RandomAccessGate { } pub fn new_from_config(config: &CircuitConfig, bits: usize) -> Self { + // We can access a list of 2^bits elements. let vec_size = 1 << bits; - // Need `(2 + vec_size) * num_copies` routed wires + + // We need `(2 + vec_size) * num_copies` routed wires. let max_copies = (config.num_routed_wires / (2 + vec_size)).min( - // Need `(2 + vec_size + bits) * num_copies` wires + // We need `(2 + vec_size + bits) * num_copies` wires in total. config.num_wires / (2 + vec_size + bits), ); + + // Any leftover wires can be used for constants. let max_extra_constants = config.num_routed_wires - (2 + vec_size) * max_copies; + Self::new( max_copies, bits, @@ -55,20 +66,24 @@ impl, const D: usize> RandomAccessGate { ) } + /// Length of the list being accessed. fn vec_size(&self) -> usize { 1 << self.bits } + /// For each copy, a wire containing the claimed index of the element. pub fn wire_access_index(&self, copy: usize) -> usize { debug_assert!(copy < self.num_copies); (2 + self.vec_size()) * copy } + /// For each copy, a wire containing the element claimed to be at the index. pub fn wire_claimed_element(&self, copy: usize) -> usize { debug_assert!(copy < self.num_copies); (2 + self.vec_size()) * copy + 1 } + /// For each copy, wires containing the entire list. pub fn wire_list_item(&self, i: usize, copy: usize) -> usize { debug_assert!(i < self.vec_size()); debug_assert!(copy < self.num_copies); @@ -84,6 +99,7 @@ impl, const D: usize> RandomAccessGate { self.start_extra_constants() + i } + /// All above wires are routed. pub fn num_routed_wires(&self) -> usize { self.start_extra_constants() + self.num_extra_constants } @@ -202,10 +218,12 @@ impl, const D: usize> Gate for RandomAccessGa .collect() } + // Check that the one remaining element after the folding is the claimed element. debug_assert_eq!(list_items.len(), 1); constraints.push(builder.sub_extension(list_items[0], claimed_element)); } + // Check the constant values. constraints.extend((0..self.num_extra_constants).map(|i| { builder.sub_extension( vars.local_constants[i], diff --git a/plonky2/src/hash/hash_types.rs b/plonky2/src/hash/hash_types.rs index 14303ad3..f416732a 100644 --- a/plonky2/src/hash/hash_types.rs +++ b/plonky2/src/hash/hash_types.rs @@ -115,7 +115,7 @@ pub struct MerkleCapTarget(pub Vec); pub struct BytesHash(pub [u8; N]); impl BytesHash { - #[cfg(feature = "parallel")] + #[cfg(feature = "rand")] pub fn rand_from_rng(rng: &mut R) -> Self { let mut buf = [0; N]; rng.fill_bytes(&mut buf); diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index 90d55ce1..f54793d9 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -17,6 +17,12 @@ pub struct MerkleProof> { pub siblings: Vec, } +impl> MerkleProof { + pub fn len(&self) -> usize { + self.siblings.len() + } +} + #[derive(Clone, Debug)] pub struct MerkleProofTarget { /// The Merkle digest of each sibling subtree, staying from the bottommost layer. diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 1da66bff..703a353e 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -21,6 +21,10 @@ impl> MerkleCap { self.0.len() } + pub fn height(&self) -> usize { + log2_strict(self.len()) + } + pub fn flatten(&self) -> Vec { self.0.iter().flat_map(|&h| h.to_vec()).collect() } diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 5bedf13d..3614b2e4 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -31,7 +31,6 @@ pub(crate) fn generate_partial_witness< let mut witness = PartitionWitness::new( config.num_wires, common_data.degree(), - common_data.num_virtual_targets, &prover_data.representative_map, ); diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index caa22c33..e7f21241 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -278,14 +278,9 @@ pub struct PartitionWitness<'a, F: Field> { } impl<'a, F: Field> PartitionWitness<'a, F> { - pub fn new( - num_wires: usize, - degree: usize, - num_virtual_targets: usize, - representative_map: &'a [usize], - ) -> Self { + pub fn new(num_wires: usize, degree: usize, representative_map: &'a [usize]) -> Self { Self { - values: vec![None; degree * num_wires + num_virtual_targets], + values: vec![None; representative_map.len()], representative_map, num_wires, degree, diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 7be27c0a..05fe649b 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -821,7 +821,6 @@ impl, const D: usize> CircuitBuilder { quotient_degree_factor, num_gate_constraints, num_constants, - num_virtual_targets: self.virtual_target_index, num_public_inputs, k_is, num_partial_products, diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 16c899de..7e69ef31 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -9,7 +9,8 @@ use crate::field::types::Field; use crate::fri::oracle::PolynomialBatch; use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::structure::{ - FriBatchInfo, FriBatchInfoTarget, FriInstanceInfo, FriInstanceInfoTarget, FriPolynomialInfo, + FriBatchInfo, FriBatchInfoTarget, FriInstanceInfo, FriInstanceInfoTarget, FriOracleInfo, + FriPolynomialInfo, }; use crate::fri::{FriConfig, FriParams}; use crate::gates::gate::GateRef; @@ -22,7 +23,7 @@ use crate::iop::target::Target; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::config::{GenericConfig, Hasher}; -use crate::plonk::plonk_common::{PlonkOracle, FRI_ORACLES}; +use crate::plonk::plonk_common::PlonkOracle; use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs}; use crate::plonk::prover::prove; use crate::plonk::verifier::verify; @@ -289,8 +290,6 @@ pub struct CommonCircuitData< /// The number of constant wires. pub(crate) num_constants: usize, - pub(crate) num_virtual_targets: usize, - pub(crate) num_public_inputs: usize, /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. @@ -368,7 +367,7 @@ impl, C: GenericConfig, const D: usize> let openings = vec![zeta_batch, zeta_next_batch]; FriInstanceInfo { - oracles: FRI_ORACLES.to_vec(), + oracles: self.fri_oracles(), batches: openings, } } @@ -394,11 +393,32 @@ impl, C: GenericConfig, const D: usize> let openings = vec![zeta_batch, zeta_next_batch]; FriInstanceInfoTarget { - oracles: FRI_ORACLES.to_vec(), + oracles: self.fri_oracles(), batches: openings, } } + fn fri_oracles(&self) -> Vec { + vec![ + FriOracleInfo { + num_polys: self.num_preprocessed_polys(), + blinding: PlonkOracle::CONSTANTS_SIGMAS.blinding, + }, + FriOracleInfo { + num_polys: self.config.num_wires, + blinding: PlonkOracle::WIRES.blinding, + }, + FriOracleInfo { + num_polys: self.num_zs_partial_products_polys(), + blinding: PlonkOracle::ZS_PARTIAL_PRODUCTS.blinding, + }, + FriOracleInfo { + num_polys: self.num_quotient_polys(), + blinding: PlonkOracle::QUOTIENT.blinding, + }, + ] + } + fn fri_preprocessed_polys(&self) -> Vec { FriPolynomialInfo::from_range( PlonkOracle::CONSTANTS_SIGMAS.index, diff --git a/plonky2/src/plonk/mod.rs b/plonky2/src/plonk/mod.rs index 4f2fa4e1..73e6c96e 100644 --- a/plonky2/src/plonk/mod.rs +++ b/plonky2/src/plonk/mod.rs @@ -8,6 +8,7 @@ pub mod plonk_common; pub mod proof; pub mod prover; pub mod recursive_verifier; +mod validate_shape; pub(crate) mod vanishing_poly; pub mod vars; pub mod verifier; diff --git a/plonky2/src/plonk/plonk_common.rs b/plonky2/src/plonk/plonk_common.rs index 4f92d732..24a94bb3 100644 --- a/plonky2/src/plonk/plonk_common.rs +++ b/plonky2/src/plonk/plonk_common.rs @@ -3,7 +3,6 @@ use plonky2_field::packed::PackedField; use plonky2_field::types::Field; use crate::fri::oracle::SALT_SIZE; -use crate::fri::structure::FriOracleInfo; use crate::gates::arithmetic_base::ArithmeticGate; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; @@ -11,13 +10,6 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::reducing::ReducingFactorTarget; -pub(crate) const FRI_ORACLES: [FriOracleInfo; 4] = [ - PlonkOracle::CONSTANTS_SIGMAS.as_fri_oracle(), - PlonkOracle::WIRES.as_fri_oracle(), - PlonkOracle::ZS_PARTIAL_PRODUCTS.as_fri_oracle(), - PlonkOracle::QUOTIENT.as_fri_oracle(), -]; - /// Holds the Merkle tree index and blinding flag of a set of polynomials used in FRI. #[derive(Debug, Copy, Clone)] pub struct PlonkOracle { @@ -42,12 +34,6 @@ impl PlonkOracle { index: 3, blinding: true, }; - - pub(crate) const fn as_fri_oracle(&self) -> FriOracleInfo { - FriOracleInfo { - blinding: self.blinding, - } - } } pub fn salt_size(salted: bool) -> usize { @@ -64,31 +50,31 @@ pub(crate) fn eval_zero_poly(n: usize, x: F) -> F { x.exp_u64(n as u64) - F::ONE } -/// Evaluate the Lagrange basis `L_1` with `L_1(1) = 1`, and `L_1(x) = 0` for other members of an +/// Evaluate the Lagrange basis `L_0` with `L_0(1) = 1`, and `L_0(x) = 0` for other members of the /// order `n` multiplicative subgroup. -pub(crate) fn eval_l_1(n: usize, x: F) -> F { +pub(crate) fn eval_l_0(n: usize, x: F) -> F { if x.is_one() { // The code below would divide by zero, since we have (x - 1) in both the numerator and // denominator. return F::ONE; } - // L_1(x) = (x^n - 1) / (n * (x - 1)) + // L_0(x) = (x^n - 1) / (n * (x - 1)) // = Z(x) / (n * (x - 1)) eval_zero_poly(n, x) / (F::from_canonical_usize(n) * (x - F::ONE)) } -/// Evaluates the Lagrange basis L_1(x), which has L_1(1) = 1 and vanishes at all other points in +/// Evaluates the Lagrange basis L_0(x), which has L_0(1) = 1 and vanishes at all other points in /// the order-`n` subgroup. /// /// Assumes `x != 1`; if `x` could be 1 then this is unsound. -pub(crate) fn eval_l_1_circuit, const D: usize>( +pub(crate) fn eval_l_0_circuit, const D: usize>( builder: &mut CircuitBuilder, n: usize, x: ExtensionTarget, x_pow_n: ExtensionTarget, ) -> ExtensionTarget { - // L_1(x) = (x^n - 1) / (n * (x - 1)) + // L_0(x) = (x^n - 1) / (n * (x - 1)) // = Z(x) / (n * (x - 1)) let one = builder.one_extension(); let neg_one = builder.neg_one(); diff --git a/plonky2/src/plonk/validate_shape.rs b/plonky2/src/plonk/validate_shape.rs new file mode 100644 index 00000000..f7ec1b6e --- /dev/null +++ b/plonky2/src/plonk/validate_shape.rs @@ -0,0 +1,77 @@ +use anyhow::ensure; +use plonky2_field::extension::Extendable; + +use crate::hash::hash_types::RichField; +use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::config::{GenericConfig, Hasher}; +use crate::plonk::proof::{OpeningSet, Proof, ProofWithPublicInputs}; + +pub(crate) fn validate_proof_with_pis_shape( + proof_with_pis: &ProofWithPublicInputs, + common_data: &CommonCircuitData, +) -> anyhow::Result<()> +where + F: RichField + Extendable, + C: GenericConfig, + [(); C::Hasher::HASH_SIZE]:, +{ + let ProofWithPublicInputs { + proof, + public_inputs, + } = proof_with_pis; + + validate_proof_shape(proof, common_data)?; + + ensure!( + public_inputs.len() == common_data.num_public_inputs, + "Number of public inputs doesn't match circuit data." + ); + + Ok(()) +} + +fn validate_proof_shape( + proof: &Proof, + common_data: &CommonCircuitData, +) -> anyhow::Result<()> +where + F: RichField + Extendable, + C: GenericConfig, + [(); C::Hasher::HASH_SIZE]:, +{ + let config = &common_data.config; + let Proof { + wires_cap, + plonk_zs_partial_products_cap, + quotient_polys_cap, + openings, + // The shape of the opening proof will be checked in the FRI verifier (see + // validate_fri_proof_shape), so we ignore it here. + opening_proof: _, + } = proof; + + let OpeningSet { + constants, + plonk_sigmas, + wires, + plonk_zs, + plonk_zs_next, + partial_products, + quotient_polys, + } = openings; + + let cap_height = common_data.fri_params.config.cap_height; + ensure!(wires_cap.height() == cap_height); + ensure!(plonk_zs_partial_products_cap.height() == cap_height); + ensure!(quotient_polys_cap.height() == cap_height); + + ensure!(constants.len() == common_data.num_constants); + ensure!(plonk_sigmas.len() == config.num_routed_wires); + ensure!(wires.len() == config.num_wires); + ensure!(plonk_zs.len() == config.num_challenges); + ensure!(plonk_zs_next.len() == config.num_challenges); + ensure!(partial_products.len() == config.num_challenges * common_data.num_partial_products); + ensure!(quotient_polys.len() == common_data.num_quotient_polys()); + + Ok(()) +} diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index ab0ba53b..303f698b 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -10,7 +10,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::GenericConfig; use crate::plonk::plonk_common; -use crate::plonk::plonk_common::eval_l_1_circuit; +use crate::plonk::plonk_common::eval_l_0_circuit; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch}; use crate::util::partial_products::{check_partial_products, check_partial_products_circuit}; use crate::util::reducing::ReducingFactorTarget; @@ -41,17 +41,17 @@ pub(crate) fn eval_vanishing_poly< let constraint_terms = evaluate_gate_constraints(common_data, vars); - // The L_1(x) (Z(x) - 1) vanishing terms. + // The L_0(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - let l1_x = plonk_common::eval_l_1(common_data.degree(), x); + let l_0_x = plonk_common::eval_l_0(common_data.degree(), x); for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; let z_gx = next_zs[i]; - vanishing_z_1_terms.push(l1_x * (z_x - F::Extension::ONE)); + vanishing_z_1_terms.push(l_0_x * (z_x - F::Extension::ONE)); let numerator_values = (0..common_data.config.num_routed_wires) .map(|j| { @@ -135,7 +135,7 @@ pub(crate) fn eval_vanishing_poly_base_batch< let mut numerator_values = Vec::with_capacity(num_routed_wires); let mut denominator_values = Vec::with_capacity(num_routed_wires); - // The L_1(x) (Z(x) - 1) vanishing terms. + // The L_0(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); @@ -152,11 +152,11 @@ pub(crate) fn eval_vanishing_poly_base_batch< let constraint_terms = PackedStridedView::new(&constraint_terms_batch, n, k); - let l1_x = z_h_on_coset.eval_l1(index, x); + let l_0_x = z_h_on_coset.eval_l_0(index, x); for i in 0..num_challenges { let z_x = local_zs[i]; let z_gx = next_zs[i]; - vanishing_z_1_terms.push(l1_x * z_x.sub_one()); + vanishing_z_1_terms.push(l_0_x * z_x.sub_one()); numerator_values.extend((0..num_routed_wires).map(|j| { let wire_value = vars.local_wires[j]; @@ -332,12 +332,12 @@ pub(crate) fn eval_vanishing_poly_circuit< evaluate_gate_constraints_circuit(builder, common_data, vars,) ); - // The L_1(x) (Z(x) - 1) vanishing terms. + // The L_0(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - let l1_x = eval_l_1_circuit(builder, common_data.degree(), x, x_pow_deg); + let l_0_x = eval_l_0_circuit(builder, common_data.degree(), x, x_pow_deg); // Holds `k[i] * x`. let mut s_ids = Vec::new(); @@ -350,8 +350,8 @@ pub(crate) fn eval_vanishing_poly_circuit< let z_x = local_zs[i]; let z_gx = next_zs[i]; - // L_1(x) Z(x) = 0. - vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x)); + // L_0(x) (Z(x) - 1) = 0. + vanishing_z_1_terms.push(builder.mul_sub_extension(l_0_x, z_x, l_0_x)); let mut numerator_values = Vec::new(); let mut denominator_values = Vec::new(); diff --git a/plonky2/src/plonk/verifier.rs b/plonky2/src/plonk/verifier.rs index 13821ff3..6a4f3790 100644 --- a/plonky2/src/plonk/verifier.rs +++ b/plonky2/src/plonk/verifier.rs @@ -8,6 +8,7 @@ use crate::plonk::circuit_data::{CommonCircuitData, VerifierOnlyCircuitData}; use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::plonk_common::reduce_with_powers; use crate::plonk::proof::{Proof, ProofChallenges, ProofWithPublicInputs}; +use crate::plonk::validate_shape::validate_proof_with_pis_shape; use crate::plonk::vanishing_poly::eval_vanishing_poly; use crate::plonk::vars::EvaluationVars; @@ -19,10 +20,8 @@ pub(crate) fn verify, C: GenericConfig, c where [(); C::Hasher::HASH_SIZE]:, { - ensure!( - proof_with_pis.public_inputs.len() == common_data.num_public_inputs, - "Number of public inputs doesn't match circuit data." - ); + validate_proof_with_pis_shape(&proof_with_pis, common_data)?; + let public_inputs_hash = proof_with_pis.get_public_inputs_hash(); let challenges = proof_with_pis.get_challenges(public_inputs_hash, common_data)?; diff --git a/projects/cache-friendly-fft/__init__.py b/projects/cache-friendly-fft/__init__.py new file mode 100644 index 00000000..08f1acac --- /dev/null +++ b/projects/cache-friendly-fft/__init__.py @@ -0,0 +1,229 @@ +import numpy as np + +from transpose import transpose_square +from util import lb_exact + + +def _interleave(x, scratch): + """Interleave the elements in an array in-place. + + For example, if `x` is `array([1, 2, 3, 4, 5, 6, 7, 8])`, then its + contents will be rearranged to `array([1, 5, 2, 6, 3, 7, 4, 8])`. + + `scratch` is an externally-allocated buffer, whose `dtype` matches + `x` and whose length is at least half the length of `x`. + """ + assert len(x.shape) == len(scratch.shape) == 1 + + n, = x.shape + assert n % 2 == 0 + + half_n = n // 2 + assert scratch.shape[0] >= half_n + + assert x.dtype == scratch.dtype + scratch = scratch[:half_n] + + scratch[:] = x[:half_n] # Save the first half of `x`. + for i in range(half_n): + x[2 * i] = scratch[i] + x[2 * i + 1] = x[half_n + i] + + +def _deinterleave(x, scratch): + """Deinterleave the elements in an array in-place. + + For example, if `x` is `array([1, 2, 3, 4, 5, 6, 7, 8])`, then its + contents will be rearranged to `array([1, 3, 5, 7, 2, 4, 6, 8])`. + + `scratch` is an externally-allocated buffer, whose `dtype` matches + `x` and whose length is at least half the length of `x`. + """ + assert len(x.shape) == len(scratch.shape) == 1 + + n, = x.shape + assert n % 2 == 0 + + half_n = n // 2 + assert scratch.shape[0] >= half_n + + assert x.dtype == scratch.dtype + scratch = scratch[:half_n] + + for i in range(half_n): + x[i] = x[2 * i] + scratch[i] = x[2 * i + 1] + x[half_n:] = scratch + + +def _fft_inplace_evenpow(x, scratch): + """In-place FFT of length 2^even""" + # Reshape `x` to a square matrix in row-major order. + vec_len = x.shape[0] + n = 1 << (lb_exact(vec_len) >> 1) # Matrix dimension + x.shape = n, n, 1 + + # We want to recursively apply FFT to every column. Because `x` is + # in row-major order, we transpose it to make the columns contiguous + # in memory, then recurse, and finally transpose it back. While the + # row is in cache, we also multiply by the twiddle factors. + transpose_square(x) + for i, row in enumerate(x[..., 0]): + _fft_inplace(row, scratch) + # Multiply by the twiddle factors + for j in range(n): + row[j] *= np.exp(-2j * np.pi * (i * j) / vec_len) + transpose_square(x) + + # Now recursively apply FFT to the rows. + for row in x[..., 0]: + _fft_inplace(row, scratch) + + # Transpose again before returning. + transpose_square(x) + + +def _fft_inplace_oddpow(x, scratch): + """In-place FFT of length 2^odd""" + # This code is based on `_fft_inplace_evenpow`, but it has to + # account for some additional complications. + + vec_len = x.shape[0] + # `vec_len` is an odd power of 2, so we cannot reshape `x` to a + # matrix square. Instead, we'll (conceptually) reshape it to a + # matrix that's twice as wide as it is high. E.g., `[1 ... 8]` + # becomes `[1 2 3 4]` + # `[5 6 7 8]`. + col_len = 1 << (lb_exact(vec_len) >> 1) + row_len = col_len << 1 + + # We can only perform efficient, in-place transposes on square + # matrices, so we will actually treat this as a square matrix of + # 2-tuples, e.g. `[(1 2) (3 4)]` + # `[(5 6) (7 8)]`. + # Note that we can currently `.reshape` it to our intended wide + # matrix (although this is broken by transposition). + x.shape = col_len, col_len, 2 + + # We want to apply FFT to each column. We transpose our + # matrix-of-tuples and get something like `[(1 2) (5 6)]` + # `[(3 4) (7 8)]`. + # Note that each row of the transposed matrix represents two columns + # of the original matrix. We can deinterleave the values to recover + # the original columns. + transpose_square(x) + + for i, row_pair in enumerate(x): + # `row_pair` represents two columns of the original matrix. + # Their values must be deinterleaved to recover the columns. + row_pair.shape = row_len, + _deinterleave(row_pair, scratch) + # The below are rows of the transposed matrix(/cols of the + # original matrix. + row0 = row_pair[:col_len] + row1 = row_pair[col_len:] + + # Apply FFT and twiddle factors to each. + _fft_inplace(row0, scratch) + for j in range(col_len): + row0[j] *= np.exp(-2j * np.pi * ((2 * i) * j) / vec_len) + _fft_inplace(row1, scratch) + for j in range(col_len): + row1[j] *= np.exp(-2j * np.pi * ((2 * i + 1) * j) / vec_len) + + # Re-interleave them and transpose back. + _interleave(row_pair, scratch) + + transpose_square(x) + + # Recursively apply FFT to each row of the matrix. + for row in x: + # Turn vec of 2-tuples into vec of single elements. + row.shape = row_len, + _fft_inplace(row, scratch) + + # Transpose again before returning. This again involves + # deinterleaving. + transpose_square(x) + for row_pair in x: + row_pair.shape = row_len, + _deinterleave(row_pair, scratch) + + +def _fft_inplace(x, scratch): + """In-place FFT.""" + # Avoid modifying the shape of the original. + # This does not copy the buffer. + x = x.view() + assert x.flags['C_CONTIGUOUS'] + + n, = x.shape + if n == 1: + return + if n == 2: + x0, x1 = x + x[0] = x0 + x1 + x[1] = x0 - x1 + return + + lb_n = lb_exact(n) + is_odd = lb_n & 1 != 0 + if is_odd: + _fft_inplace_oddpow(x, scratch) + else: + _fft_inplace_evenpow(x, scratch) + + +def _scrach_length(lb_n): + """Find the amount of scratch space required to run the FFT. + + Layers where the input's length is an even power of two do not + require scratch space, but the layers where that power is odd do. + """ + if lb_n == 0: + # Length-1 input. + return 0 + # Repeatedly halve lb_n as long as it's even. This is the same as + # `n = sqrt(n)`, where the `sqrt` is exact. + while lb_n & 1 == 0: + lb_n >>= 1 + # `lb_n` is now odd, so `n` is not an even power of 2. + lb_res = (lb_n - 1) >> 1 + if lb_res == 0: + # Special case (n == 2 or n == 4): no scratch needed. + return 0 + return 1 << lb_res + + +def fft(x): + """Returns the FFT of `x`. + + This is a wrapper around an in-place routine, provided for user + convenience. + """ + n, = x.shape + lb_n = lb_exact(n) # Raises if not a power of 2. + # We have one scratch buffer for the whole algorithm. If we were to + # parallelize it, we'd need one thread-local buffer for each worker + # thread. + scratch_len = _scrach_length(lb_n) + if scratch_len == 0: + scratch = None + else: + scratch = np.empty_like(x, shape=scratch_len, order='C', subok=False) + + res = x.copy(order='C') + _fft_inplace(res, scratch) + + return res + + +if __name__ == "__main__": + LENGTH = 1 << 10 + v = np.random.normal(size=LENGTH).astype(complex) + print(v) + numpy_fft = np.fft.fft(v) + print(numpy_fft) + our_fft = fft(v) + print(our_fft) + print(np.isclose(numpy_fft, our_fft).all()) diff --git a/projects/cache-friendly-fft/transpose.py b/projects/cache-friendly-fft/transpose.py new file mode 100644 index 00000000..ea20bf6b --- /dev/null +++ b/projects/cache-friendly-fft/transpose.py @@ -0,0 +1,61 @@ +from util import lb_exact + + +def _swap_transpose_square(a, b): + """Transpose two square matrices in-place and swap them. + + The matrices must be a of shape `(n, n, m)`, where the `m` dimension + may be of arbitrary length and is not moved. + """ + assert len(a.shape) == len(b.shape) == 3 + n = a.shape[0] + m = a.shape[2] + assert n == a.shape[1] == b.shape[0] == b.shape[1] + assert m == b.shape[2] + + if n == 0: + return + if n == 1: + # Swap the two matrices (transposition is a no-op). + a = a[0, 0] + b = b[0, 0] + # Recall that each element of the matrix is an `m`-vector. Swap + # all `m` elements. + for i in range(m): + a[i], b[i] = b[i], a[i] + return + + half_n = n >> 1 + # Transpose and swap top-left of `a` with top-left of `b`. + _swap_transpose_square(a[:half_n, :half_n], b[:half_n, :half_n]) + # ...top-right of `a` with bottom-left of `b`. + _swap_transpose_square(a[:half_n, half_n:], b[half_n:, :half_n]) + # ...bottom-left of `a` with top-right of `b`. + _swap_transpose_square(a[half_n:, :half_n], b[:half_n, half_n:]) + # ...bottom-right of `a` with bottom-right of `b`. + _swap_transpose_square(a[half_n:, half_n:], b[half_n:, half_n:]) + + +def transpose_square(a): + """In-place transpose of a square matrix. + + The matrix must be a of shape `(n, n, m)`, where the `m` dimension + may be of arbitrary length and is not moved. + """ + if len(a.shape) != 3: + raise ValueError("a must be a matrix of batches") + n, n_, _ = a.shape + if n != n_: + raise ValueError("a must be square") + lb_exact(n) + + if n <= 1: + return # Base case: no-op + + half_n = n >> 1 + # Transpose top-left quarter in-place. + transpose_square(a[:half_n, :half_n]) + # Transpose top-right and bottom-left quarters and swap them. + _swap_transpose_square(a[:half_n, half_n:], a[half_n:, :half_n]) + # Transpose bottom-right quarter in-place. + transpose_square(a[half_n:, half_n:]) diff --git a/projects/cache-friendly-fft/util.py b/projects/cache-friendly-fft/util.py new file mode 100644 index 00000000..50118827 --- /dev/null +++ b/projects/cache-friendly-fft/util.py @@ -0,0 +1,6 @@ +def lb_exact(n): + """Returns `log2(n)`, raising if `n` is not a power of 2.""" + lb = n.bit_length() - 1 + if lb < 0 or n != 1 << lb: + raise ValueError(f"{n} is not a power of 2") + return lb diff --git a/starky/Cargo.toml b/starky/Cargo.toml index 80a26bfc..43bea53e 100644 --- a/starky/Cargo.toml +++ b/starky/Cargo.toml @@ -6,13 +6,13 @@ edition = "2021" [features] default = ["parallel"] -parallel = ["maybe_rayon/parallel"] +parallel = ["plonky2/parallel", "maybe_rayon/parallel"] [dependencies] -plonky2 = { path = "../plonky2" } +plonky2 = { path = "../plonky2", default-features = false, features = ["rand", "timing"] } plonky2_util = { path = "../util" } +maybe_rayon = { path = "../maybe_rayon"} anyhow = "1.0.40" env_logger = "0.9.0" itertools = "0.10.0" log = "0.4.14" -maybe_rayon = { path = "../maybe_rayon"} diff --git a/starky/src/config.rs b/starky/src/config.rs index 500cd957..a593c827 100644 --- a/starky/src/config.rs +++ b/starky/src/config.rs @@ -21,9 +21,9 @@ impl StarkConfig { fri_config: FriConfig { rate_bits: 1, cap_height: 4, - proof_of_work_bits: 10, + proof_of_work_bits: 16, reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5), - num_query_rounds: 90, + num_query_rounds: 84, }, } } diff --git a/starky/src/recursive_verifier.rs b/starky/src/recursive_verifier.rs index 7f20d89b..04858d55 100644 --- a/starky/src/recursive_verifier.rs +++ b/starky/src/recursive_verifier.rs @@ -102,8 +102,8 @@ fn verify_stark_proof_with_challenges_circuit< let zeta_pow_deg = builder.exp_power_of_2_extension(challenges.stark_zeta, degree_bits); let z_h_zeta = builder.sub_extension(zeta_pow_deg, one); - let (l_1, l_last) = - eval_l_1_and_l_last_circuit(builder, degree_bits, challenges.stark_zeta, z_h_zeta); + let (l_0, l_last) = + eval_l_0_and_l_last_circuit(builder, degree_bits, challenges.stark_zeta, z_h_zeta); let last = builder.constant_extension(F::Extension::primitive_root_of_unity(degree_bits).inverse()); let z_last = builder.sub_extension(challenges.stark_zeta, last); @@ -112,7 +112,7 @@ fn verify_stark_proof_with_challenges_circuit< builder.zero_extension(), challenges.stark_alphas, z_last, - l_1, + l_0, l_last, ); @@ -170,7 +170,7 @@ fn verify_stark_proof_with_challenges_circuit< ); } -fn eval_l_1_and_l_last_circuit, const D: usize>( +fn eval_l_0_and_l_last_circuit, const D: usize>( builder: &mut CircuitBuilder, log_n: usize, x: ExtensionTarget, @@ -179,12 +179,12 @@ fn eval_l_1_and_l_last_circuit, const D: usize>( let n = builder.constant_extension(F::Extension::from_canonical_usize(1 << log_n)); let g = builder.constant_extension(F::Extension::primitive_root_of_unity(log_n)); let one = builder.one_extension(); - let l_1_deno = builder.mul_sub_extension(n, x, n); + let l_0_deno = builder.mul_sub_extension(n, x, n); let l_last_deno = builder.mul_sub_extension(g, x, one); let l_last_deno = builder.mul_extension(n, l_last_deno); ( - builder.div_extension(z_x, l_1_deno), + builder.div_extension(z_x, l_0_deno), builder.div_extension(z_x, l_last_deno), ) } diff --git a/starky/src/stark.rs b/starky/src/stark.rs index df549572..8ebca87c 100644 --- a/starky/src/stark.rs +++ b/starky/src/stark.rs @@ -78,6 +78,10 @@ pub trait Stark, const D: usize>: Sync { 1.max(self.constraint_degree() - 1) } + fn num_quotient_polys(&self, config: &StarkConfig) -> usize { + self.quotient_degree_factor() * config.num_challenges + } + /// Computes the FRI instance used to prove this Stark. fn fri_instance( &self, @@ -85,25 +89,32 @@ pub trait Stark, const D: usize>: Sync { g: F, config: &StarkConfig, ) -> FriInstanceInfo { - let no_blinding_oracle = FriOracleInfo { blinding: false }; - let mut oracle_indices = 0..; + let mut oracles = vec![]; - let trace_info = - FriPolynomialInfo::from_range(oracle_indices.next().unwrap(), 0..Self::COLUMNS); + let trace_info = FriPolynomialInfo::from_range(oracles.len(), 0..Self::COLUMNS); + oracles.push(FriOracleInfo { + num_polys: Self::COLUMNS, + blinding: false, + }); let permutation_zs_info = if self.uses_permutation_args() { - FriPolynomialInfo::from_range( - oracle_indices.next().unwrap(), - 0..self.num_permutation_batches(config), - ) + let num_z_polys = self.num_permutation_batches(config); + let polys = FriPolynomialInfo::from_range(oracles.len(), 0..num_z_polys); + oracles.push(FriOracleInfo { + num_polys: num_z_polys, + blinding: false, + }); + polys } else { vec![] }; - let quotient_info = FriPolynomialInfo::from_range( - oracle_indices.next().unwrap(), - 0..self.quotient_degree_factor() * config.num_challenges, - ); + let num_quotient_polys = self.quotient_degree_factor() * config.num_challenges; + let quotient_info = FriPolynomialInfo::from_range(oracles.len(), 0..num_quotient_polys); + oracles.push(FriOracleInfo { + num_polys: num_quotient_polys, + blinding: false, + }); let zeta_batch = FriBatchInfo { point: zeta, @@ -118,10 +129,9 @@ pub trait Stark, const D: usize>: Sync { point: zeta.scalar_mul(g), polynomials: [trace_info, permutation_zs_info].concat(), }; - FriInstanceInfo { - oracles: vec![no_blinding_oracle; oracle_indices.next().unwrap()], - batches: vec![zeta_batch, zeta_next_batch], - } + let batches = vec![zeta_batch, zeta_next_batch]; + + FriInstanceInfo { oracles, batches } } /// Computes the FRI instance used to prove this Stark. @@ -132,25 +142,32 @@ pub trait Stark, const D: usize>: Sync { g: F, config: &StarkConfig, ) -> FriInstanceInfoTarget { - let no_blinding_oracle = FriOracleInfo { blinding: false }; - let mut oracle_indices = 0..; + let mut oracles = vec![]; - let trace_info = - FriPolynomialInfo::from_range(oracle_indices.next().unwrap(), 0..Self::COLUMNS); + let trace_info = FriPolynomialInfo::from_range(oracles.len(), 0..Self::COLUMNS); + oracles.push(FriOracleInfo { + num_polys: Self::COLUMNS, + blinding: false, + }); let permutation_zs_info = if self.uses_permutation_args() { - FriPolynomialInfo::from_range( - oracle_indices.next().unwrap(), - 0..self.num_permutation_batches(config), - ) + let num_z_polys = self.num_permutation_batches(config); + let polys = FriPolynomialInfo::from_range(oracles.len(), 0..num_z_polys); + oracles.push(FriOracleInfo { + num_polys: num_z_polys, + blinding: false, + }); + polys } else { vec![] }; - let quotient_info = FriPolynomialInfo::from_range( - oracle_indices.next().unwrap(), - 0..self.quotient_degree_factor() * config.num_challenges, - ); + let num_quotient_polys = self.quotient_degree_factor() * config.num_challenges; + let quotient_info = FriPolynomialInfo::from_range(oracles.len(), 0..num_quotient_polys); + oracles.push(FriOracleInfo { + num_polys: num_quotient_polys, + blinding: false, + }); let zeta_batch = FriBatchInfoTarget { point: zeta, @@ -166,10 +183,9 @@ pub trait Stark, const D: usize>: Sync { point: zeta_next, polynomials: [trace_info, permutation_zs_info].concat(), }; - FriInstanceInfoTarget { - oracles: vec![no_blinding_oracle; oracle_indices.next().unwrap()], - batches: vec![zeta_batch, zeta_next_batch], - } + let batches = vec![zeta_batch, zeta_next_batch]; + + FriInstanceInfoTarget { oracles, batches } } /// Pairs of lists of columns that should be permutations of one another. A permutation argument diff --git a/starky/src/verifier.rs b/starky/src/verifier.rs index 306d3d14..18ae9a27 100644 --- a/starky/src/verifier.rs +++ b/starky/src/verifier.rs @@ -1,6 +1,6 @@ use std::iter::once; -use anyhow::{ensure, Result}; +use anyhow::{anyhow, ensure, Result}; use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::types::Field; @@ -12,7 +12,7 @@ use plonky2::plonk::plonk_common::reduce_with_powers; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::permutation::PermutationCheckVars; -use crate::proof::{StarkOpeningSet, StarkProofChallenges, StarkProofWithPublicInputs}; +use crate::proof::{StarkOpeningSet, StarkProof, StarkProofChallenges, StarkProofWithPublicInputs}; use crate::stark::Stark; use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::StarkEvaluationVars; @@ -55,6 +55,7 @@ where [(); S::PUBLIC_INPUTS]:, [(); C::Hasher::HASH_SIZE]:, { + validate_proof_shape(&stark, &proof_with_pis, config)?; check_permutation_options(&stark, &proof_with_pis, &challenges)?; let StarkProofWithPublicInputs { proof, @@ -78,7 +79,7 @@ where .unwrap(), }; - let (l_1, l_last) = eval_l_1_and_l_last(degree_bits, challenges.stark_zeta); + let (l_0, l_last) = eval_l_0_and_l_last(degree_bits, challenges.stark_zeta); let last = F::primitive_root_of_unity(degree_bits).inverse(); let z_last = challenges.stark_zeta - last.into(); let mut consumer = ConstraintConsumer::::new( @@ -88,7 +89,7 @@ where .map(|&alpha| F::Extension::from_basefield(alpha)) .collect::>(), z_last, - l_1, + l_0, l_last, ); let permutation_data = stark.uses_permutation_args().then(|| PermutationCheckVars { @@ -144,10 +145,82 @@ where Ok(()) } -/// Evaluate the Lagrange polynomials `L_1` and `L_n` at a point `x`. -/// `L_1(x) = (x^n - 1)/(n * (x - 1))` -/// `L_n(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` the first element of the subgroup. -fn eval_l_1_and_l_last(log_n: usize, x: F) -> (F, F) { +fn validate_proof_shape( + stark: &S, + proof_with_pis: &StarkProofWithPublicInputs, + config: &StarkConfig, +) -> anyhow::Result<()> +where + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + [(); S::COLUMNS]:, + [(); C::Hasher::HASH_SIZE]:, +{ + let StarkProofWithPublicInputs { + proof, + public_inputs, + } = proof_with_pis; + let degree_bits = proof.recover_degree_bits(config); + + let StarkProof { + trace_cap, + permutation_zs_cap, + quotient_polys_cap, + openings, + // The shape of the opening proof will be checked in the FRI verifier (see + // validate_fri_proof_shape), so we ignore it here. + opening_proof: _, + } = proof; + + let StarkOpeningSet { + local_values, + next_values, + permutation_zs, + permutation_zs_next, + quotient_polys, + } = openings; + + ensure!(public_inputs.len() == S::PUBLIC_INPUTS); + + let fri_params = config.fri_params(degree_bits); + let cap_height = fri_params.config.cap_height; + let num_zs = stark.num_permutation_batches(config); + + ensure!(trace_cap.height() == cap_height); + ensure!(quotient_polys_cap.height() == cap_height); + + ensure!(local_values.len() == S::COLUMNS); + ensure!(next_values.len() == S::COLUMNS); + ensure!(quotient_polys.len() == stark.num_quotient_polys(config)); + + if stark.uses_permutation_args() { + let permutation_zs_cap = permutation_zs_cap + .as_ref() + .ok_or_else(|| anyhow!("Missing Zs cap"))?; + let permutation_zs = permutation_zs + .as_ref() + .ok_or_else(|| anyhow!("Missing permutation_zs"))?; + let permutation_zs_next = permutation_zs_next + .as_ref() + .ok_or_else(|| anyhow!("Missing permutation_zs_next"))?; + + ensure!(permutation_zs_cap.height() == cap_height); + ensure!(permutation_zs.len() == num_zs); + ensure!(permutation_zs_next.len() == num_zs); + } else { + ensure!(permutation_zs_cap.is_none()); + ensure!(permutation_zs.is_none()); + ensure!(permutation_zs_next.is_none()); + } + + Ok(()) +} + +/// Evaluate the Lagrange polynomials `L_0` and `L_(n-1)` at a point `x`. +/// `L_0(x) = (x^n - 1)/(n * (x - 1))` +/// `L_(n-1)(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` the first element of the subgroup. +fn eval_l_0_and_l_last(log_n: usize, x: F) -> (F, F) { let n = F::from_canonical_usize(1 << log_n); let g = F::primitive_root_of_unity(log_n); let z_x = x.exp_power_of_2(log_n) - F::ONE; @@ -189,10 +262,10 @@ mod tests { use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; - use crate::verifier::eval_l_1_and_l_last; + use crate::verifier::eval_l_0_and_l_last; #[test] - fn test_eval_l_1_and_l_last() { + fn test_eval_l_0_and_l_last() { type F = GoldilocksField; let log_n = 5; let n = 1 << log_n; @@ -201,7 +274,7 @@ mod tests { let expected_l_first_x = PolynomialValues::selector(n, 0).ifft().eval(x); let expected_l_last_x = PolynomialValues::selector(n, n - 1).ifft().eval(x); - let (l_first_x, l_last_x) = eval_l_1_and_l_last(log_n, x); + let (l_first_x, l_last_x) = eval_l_0_and_l_last(log_n, x); assert_eq!(l_first_x, expected_l_first_x); assert_eq!(l_last_x, expected_l_last_x); } diff --git a/util/src/lib.rs b/util/src/lib.rs index 61677ff0..bbc2af98 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -38,6 +38,23 @@ pub fn log2_strict(n: usize) -> usize { res as usize } +/// Returns the largest integer `i` such that `base**i <= n`. +pub const fn log_floor(n: u64, base: u64) -> usize { + assert!(n > 0); + assert!(base > 1); + let mut i = 0; + let mut cur: u64 = 1; + loop { + let (mul, overflow) = cur.overflowing_mul(base); + if overflow || mul > n { + return i; + } else { + i += 1; + cur = mul; + } + } +} + /// Permutes `arr` such that each index is mapped to its reverse in binary. pub fn reverse_index_bits(arr: &[T]) -> Vec { let n = arr.len();