diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index a23e18b8..a5d46a28 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -35,7 +35,8 @@ mod tests { #[test] fn make_kernel() { // Make sure we can parse and assemble the entire kernel. - combined_kernel(); + let kernel = combined_kernel(); + println!("Kernel size: {} bytes", kernel.code.len()); } fn u256ify<'a>(hexes: impl IntoIterator) -> Result> { diff --git a/evm/src/cpu/kernel/asm/basic_macros.asm b/evm/src/cpu/kernel/asm/basic_macros.asm index 8b6410c7..200aeea0 100644 --- a/evm/src/cpu/kernel/asm/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/basic_macros.asm @@ -1,3 +1,28 @@ +%macro jump(dst) + push $dst + jump +%endmacro + +%macro jumpi(dst) + push $dst + jumpi +%endmacro + +%macro pop2 + pop + pop +%endmacro + +%macro pop3 + pop + %pop2 +%endmacro + +%macro pop4 + %pop2 + %pop2 +%endmacro + // If pred is zero, yields z; otherwise, yields nz %macro select // stack: pred, nz, z diff --git a/evm/src/cpu/kernel/asm/curve_add.asm b/evm/src/cpu/kernel/asm/curve_add.asm index fdbbf997..4ac4e0e4 100644 --- a/evm/src/cpu/kernel/asm/curve_add.asm +++ b/evm/src/cpu/kernel/asm/curve_add.asm @@ -27,19 +27,11 @@ global ec_add: // stack: isValid(x1, y1), isValid(x0, y0), x0, y0, x1, y1, retdest AND // stack: isValid(x1, y1) & isValid(x0, y0), x0, y0, x1, y1, retdest - PUSH ec_add_valid_points - // stack: ec_add_valid_points, isValid(x1, y1) & isValid(x0, y0), x0, y0, x1, y1, retdest - JUMPI + %jumpi(ec_add_valid_points) // stack: x0, y0, x1, y1, retdest // Otherwise return - POP - // stack: y0, x1, y1, retdest - POP - // stack: x1, y1, retdest - POP - // stack: y1, retdest - POP + %pop4 // stack: retdest %ec_invalid_input @@ -56,9 +48,7 @@ global ec_add_valid_points: // stack: x0, y0, x0, y0, x1, y1, retdest %ec_isidentity // stack: (x0,y0)==(0,0), x0, y0, x1, y1, retdest - PUSH ec_add_first_zero - // stack: ec_add_first_zero, (x0,y0)==(0,0), x0, y0, x1, y1, retdest - JUMPI + %jumpi(ec_add_first_zero) // stack: x0, y0, x1, y1, retdest // Check if the first point is the identity. @@ -68,9 +58,7 @@ global ec_add_valid_points: // stack: x1, y1, x0, y0, x1, y1, retdest %ec_isidentity // stack: (x1,y1)==(0,0), x0, y0, x1, y1, retdest - PUSH ec_add_snd_zero - // stack: ec_add_snd_zero, (x1,y1)==(0,0), x0, y0, x1, y1, retdest - JUMPI + %jumpi(ec_add_snd_zero) // stack: x0, y0, x1, y1, retdest // Check if both points have the same x-coordinate. @@ -80,9 +68,7 @@ global ec_add_valid_points: // stack: x0, x1, x0, y0, x1, y1, retdest EQ // stack: x0 == x1, x0, y0, x1, y1, retdest - PUSH ec_add_equal_first_coord - // stack: ec_add_equal_first_coord, x0 == x1, x0, y0, x1, y1, retdest - JUMPI + %jumpi(ec_add_equal_first_coord) // stack: x0, y0, x1, y1, retdest // Otherwise, we can use the standard formula. @@ -101,9 +87,7 @@ global ec_add_valid_points: // stack: x0 - x1, y0 - y1, x0, y0, x1, y1, retdest %moddiv // stack: lambda, x0, y0, x1, y1, retdest - PUSH ec_add_valid_points_with_lambda - // stack: ec_add_valid_points_with_lambda, lambda, x0, y0, x1, y1, retdest - JUMP + %jump(ec_add_valid_points_with_lambda) // BN254 elliptic curve addition. // Assumption: (x0,y0) == (0,0) @@ -112,9 +96,7 @@ ec_add_first_zero: // stack: x0, y0, x1, y1, retdest // Just return (x1,y1) - POP - // stack: y0, x1, y1, retdest - POP + %pop2 // stack: x1, y1, retdest SWAP1 // stack: y1, x1, retdest @@ -194,13 +176,7 @@ ec_add_valid_points_with_lambda: // stack: x2, lambda, x0, y0, y2, y1, retdest SWAP5 // stack: y1, lambda, x0, y0, y2, x2, retdest - POP - // stack: lambda, x0, y0, y2, x2, retdest - POP - // stack: x0, y0, y2, x2, retdest - POP - // stack: y0, y2, x2, retdest - POP + %pop4 // stack: y2, x2, retdest SWAP2 // stack: retdest, x2, y2 @@ -219,19 +195,11 @@ ec_add_equal_first_coord: // stack: y1, y0, x0, y0, x1, y1, retdest EQ // stack: y1 == y0, x0, y0, x1, y1, retdest - PUSH ec_add_equal_points - // stack: ec_add_equal_points, y1 == y0, x0, y0, x1, y1, retdest - JUMPI + %jumpi(ec_add_equal_points) // stack: x0, y0, x1, y1, retdest // Otherwise, one is the negation of the other so we can return (0,0). - POP - // stack: y0, x1, y1, retdest - POP - // stack: x1, y1, retdest - POP - // stack: y1, retdest - POP + %pop4 // stack: retdest PUSH 0 // stack: 0, retdest @@ -268,9 +236,7 @@ ec_add_equal_points: // stack: y0, 3/2 * x0^2, x0, y0, x1, y1, retdest %moddiv // stack: lambda, x0, y0, x1, y1, retdest - PUSH ec_add_valid_points_with_lambda - // stack: ec_add_valid_points_with_lambda, lambda, x0, y0, x1, y1, retdest - JUMP + %jump(ec_add_valid_points_with_lambda) // BN254 elliptic curve doubling. // Assumption: (x0,y0) is a valid point. @@ -282,9 +248,7 @@ global ec_double: // stack: y0, x0, y0, retdest DUP2 // stack: x0, y0, x0, y0, retdest - PUSH ec_add_equal_points - // stack: ec_add_equal_points, x0, y0, x0, y0, retdest - JUMP + %jump(ec_add_equal_points) // Push the order of the BN254 base field. %macro bn_base diff --git a/evm/src/cpu/kernel/asm/curve_mul.asm b/evm/src/cpu/kernel/asm/curve_mul.asm index 0826b0e3..85469b65 100644 --- a/evm/src/cpu/kernel/asm/curve_mul.asm +++ b/evm/src/cpu/kernel/asm/curve_mul.asm @@ -28,12 +28,7 @@ global ec_mul: // stack: ec_mul_valid_point, isValid(x, y), x, y, s, retdest JUMPI // stack: x, y, s, retdest - POP - // stack: y, s, retdest - POP - // stack: s, retdest - POP - // stack: retdest + %pop3 %ec_invalid_input // Same algorithm as in `exp.asm` @@ -46,9 +41,7 @@ ec_mul_valid_point: // stack: step_case, s, x, y, s, retdest JUMPI // stack: x, y, s, retdest - PUSH ret_zero - // stack: ret_zero, x, y, s, retdest - JUMP + %jump(ret_zero) step_case: JUMPDEST @@ -67,17 +60,13 @@ step_case: // stack: y, step_case_contd, s / 2, recursion_return, x, y, s, retdest DUP5 // stack: x, y, step_case_contd, s / 2, recursion_return, x, y, s, retdest - PUSH ec_double - // stack: ec_double, x, y, step_case_contd, s / 2, recursion_return, x, y, s, retdest - JUMP + %jump(ec_double) // Assumption: 2(x,y) = (x',y') step_case_contd: JUMPDEST // stack: x', y', s / 2, recursion_return, x, y, s, retdest - PUSH ec_mul_valid_point - // stack: ec_mul_valid_point, x', y', s / 2, recursion_return, x, y, s, retdest - JUMP + %jump(ec_mul_valid_point) recursion_return: JUMPDEST @@ -98,9 +87,7 @@ recursion_return: // stack: x', s & 1, y', x, y, retdest SWAP1 // stack: s & 1, x', y', x, y, retdest - PUSH odd_scalar - // stack: odd_scalar, s & 1, x', y', x, y, retdest - JUMPI + %jumpi(odd_scalar) // stack: x', y', x, y, retdest SWAP3 // stack: y, y', x, x', retdest @@ -117,18 +104,12 @@ recursion_return: odd_scalar: JUMPDEST // stack: x', y', x, y, retdest - PUSH ec_add_valid_points - // stack: ec_add_valid_points, x', y', x, y, retdest - JUMP + %jump(ec_add_valid_points) ret_zero: JUMPDEST // stack: x, y, s, retdest - POP - // stack: y, s, retdest - POP - // stack: s, retdest - POP + %pop3 // stack: retdest PUSH 0 // stack: 0, retdest diff --git a/evm/src/cpu/kernel/asm/exp.asm b/evm/src/cpu/kernel/asm/exp.asm index 683e67c3..389f8490 100644 --- a/evm/src/cpu/kernel/asm/exp.asm +++ b/evm/src/cpu/kernel/asm/exp.asm @@ -14,9 +14,7 @@ global exp: // stack: x, e, retdest dup2 // stack: e, x, e, retdest - push step_case - // stack: step_case, e, x, e, retdest - jumpi + %jumpi(step_case) // stack: x, e, retdest pop // stack: e, retdest @@ -43,9 +41,7 @@ step_case: // stack: x, e / 2, recursion_return, x, e, retdest %square // stack: x * x, e / 2, recursion_return, x, e, retdest - push exp - // stack: exp, x * x, e / 2, recursion_return, x, e, retdest - jump + %jump(exp) recursion_return: jumpdest // stack: exp(x * x, e / 2), x, e, retdest diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index 91f64a42..59db93a3 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -19,6 +19,20 @@ pub struct Kernel { pub(crate) global_labels: HashMap, } +struct Macro { + params: Vec, + items: Vec, +} + +impl Macro { + fn get_param_index(&self, param: &str) -> usize { + self.params + .iter() + .position(|p| p == param) + .unwrap_or_else(|| panic!("No such param: {} {:?}", param, &self.params)) + } +} + pub(crate) fn assemble(files: Vec) -> Kernel { let macros = find_macros(&files); let mut global_labels = HashMap::new(); @@ -41,33 +55,31 @@ pub(crate) fn assemble(files: Vec) -> Kernel { } } -fn find_macros(files: &[File]) -> HashMap> { +fn find_macros(files: &[File]) -> HashMap { let mut macros = HashMap::new(); for file in files { for item in &file.body { - if let Item::MacroDef(name, items) = item { - macros.insert(name.clone(), items.clone()); + if let Item::MacroDef(name, params, items) = item { + let _macro = Macro { + params: params.clone(), + items: items.clone(), + }; + macros.insert(name.clone(), _macro); } } } macros } -fn expand_macros(body: Vec, macros: &HashMap>) -> Vec { +fn expand_macros(body: Vec, macros: &HashMap) -> Vec { let mut expanded = vec![]; for item in body { match item { - Item::MacroDef(_, _) => { + Item::MacroDef(_, _, _) => { // At this phase, we no longer need macro definitions. } - Item::MacroCall(m) => { - let mut expanded_item = macros - .get(&m) - .cloned() - .unwrap_or_else(|| panic!("No such macro: {}", m)); - // Recursively expand any macros in the expanded code. - expanded_item = expand_macros(expanded_item, macros); - expanded.extend(expanded_item); + Item::MacroCall(m, args) => { + expanded.extend(expand_macro_call(m, args, macros)); } item => { expanded.push(item); @@ -77,6 +89,41 @@ fn expand_macros(body: Vec, macros: &HashMap>) -> Vec, + macros: &HashMap, +) -> Vec { + let _macro = macros + .get(&name) + .unwrap_or_else(|| panic!("No such macro: {}", name)); + + assert_eq!( + args.len(), + _macro.params.len(), + "Macro `{}`: expected {} arguments, got {}", + name, + _macro.params.len(), + args.len() + ); + + let expanded_item = _macro + .items + .iter() + .map(|item| { + if let Item::Push(PushTarget::MacroVar(var)) = item { + let param_index = _macro.get_param_index(var); + Item::Push(args[param_index].clone()) + } else { + item.clone() + } + }) + .collect(); + + // Recursively expand any macros in the expanded code. + expand_macros(expanded_item, macros) +} + fn find_labels( body: &[Item], offset: &mut usize, @@ -86,7 +133,7 @@ fn find_labels( let mut local_labels = HashMap::::new(); for item in body { match item { - Item::MacroDef(_, _) | Item::MacroCall(_) => { + Item::MacroDef(_, _, _) | Item::MacroCall(_, _) => { panic!("Macros should have been expanded already") } Item::GlobalLabelDeclaration(label) => { @@ -114,7 +161,7 @@ fn assemble_file( // Assemble the file. for item in body { match item { - Item::MacroDef(_, _) | Item::MacroCall(_) => { + Item::MacroDef(_, _, _) | Item::MacroCall(_, _) => { panic!("Macros should have been expanded already") } Item::GlobalLabelDeclaration(_) | Item::LocalLabelDeclaration(_) => { @@ -135,6 +182,7 @@ fn assemble_file( .map(|i| offset.to_le_bytes()[i as usize]) .collect() } + PushTarget::MacroVar(v) => panic!("Variable not in a macro: {}", v), }; code.push(get_push_opcode(target_bytes.len() as u8)); code.extend(target_bytes); @@ -152,6 +200,7 @@ fn push_target_size(target: &PushTarget) -> u8 { match target { PushTarget::Literal(lit) => lit.to_trimmed_be_bytes().len() as u8, PushTarget::Label(_) => BYTES_PER_OFFSET, + PushTarget::MacroVar(v) => panic!("Variable not in a macro: {}", v), } } @@ -281,6 +330,32 @@ mod tests { assert_eq!(kernel.code, vec![add, add]); } + #[test] + fn macro_with_vars() { + let kernel = parse_and_assemble(&[ + "%macro add(x, y) PUSH $x PUSH $y ADD %endmacro", + "%add(2, 3)", + ]); + let push1 = get_push_opcode(1); + let add = get_opcode("ADD"); + assert_eq!(kernel.code, vec![push1, 2, push1, 3, add]); + } + + #[test] + #[should_panic] + fn macro_with_wrong_vars() { + parse_and_assemble(&[ + "%macro add(x, y) PUSH $x PUSH $y ADD %endmacro", + "%add(2, 3, 4)", + ]); + } + + #[test] + #[should_panic] + fn var_not_in_macro() { + parse_and_assemble(&["push $abc"]); + } + fn parse_and_assemble(files: &[&str]) -> Kernel { let parsed_files = files.iter().map(|f| parse(f)).collect_vec(); assemble(parsed_files) diff --git a/evm/src/cpu/kernel/ast.rs b/evm/src/cpu/kernel/ast.rs index 1409b5d7..f011f1ff 100644 --- a/evm/src/cpu/kernel/ast.rs +++ b/evm/src/cpu/kernel/ast.rs @@ -8,10 +8,10 @@ pub(crate) struct File { #[derive(Clone, Debug)] pub(crate) enum Item { - /// Defines a new macro. - MacroDef(String, Vec), - /// Calls a macro. - MacroCall(String), + /// Defines a new macro: name, params, body. + MacroDef(String, Vec, Vec), + /// Calls a macro: name, args. + MacroCall(String, Vec), /// Declares a global label. GlobalLabelDeclaration(String), /// Declares a label that is local to the current file. @@ -29,6 +29,7 @@ pub(crate) enum Item { pub(crate) enum PushTarget { Literal(Literal), Label(String), + MacroVar(String), } #[derive(Clone, Debug)] diff --git a/evm/src/cpu/kernel/evm_asm.pest b/evm/src/cpu/kernel/evm_asm.pest index 587f87f1..8333a230 100644 --- a/evm/src/cpu/kernel/evm_asm.pest +++ b/evm/src/cpu/kernel/evm_asm.pest @@ -12,13 +12,18 @@ literal_decimal = @{ ASCII_DIGIT+ } literal_hex = @{ ^"0x" ~ ASCII_HEX_DIGIT+ } literal = { literal_hex | literal_decimal } +variable = ${ "$" ~ identifier } + item = { macro_def | macro_call | global_label | local_label | bytes_item | push_instruction | nullary_instruction } -macro_def = { ^"%macro" ~ identifier ~ item* ~ ^"%endmacro" } -macro_call = ${ "%" ~ !(^"macro" | ^"endmacro") ~ identifier } +macro_def = { ^"%macro" ~ identifier ~ macro_paramlist? ~ item* ~ ^"%endmacro" } +macro_call = ${ "%" ~ !(^"macro" | ^"endmacro") ~ identifier ~ macro_arglist? } +macro_paramlist = { "(" ~ identifier ~ ("," ~ identifier)* ~ ")" } +macro_arglist = !{ "(" ~ push_target ~ ("," ~ push_target)* ~ ")" } global_label = { ^"GLOBAL " ~ identifier ~ ":" } local_label = { identifier ~ ":" } bytes_item = { ^"BYTES " ~ literal ~ ("," ~ literal)* } -push_instruction = { ^"PUSH " ~ (literal | identifier) } +push_instruction = { ^"PUSH " ~ push_target } +push_target = { literal | identifier | variable } nullary_instruction = { identifier } file = { SOI ~ item* ~ silent_eoi } diff --git a/evm/src/cpu/kernel/parser.rs b/evm/src/cpu/kernel/parser.rs index ab928582..4145b5f0 100644 --- a/evm/src/cpu/kernel/parser.rs +++ b/evm/src/cpu/kernel/parser.rs @@ -18,14 +18,11 @@ pub(crate) fn parse(s: &str) -> File { } fn parse_item(item: Pair) -> Item { + assert_eq!(item.as_rule(), Rule::item); let item = item.into_inner().next().unwrap(); match item.as_rule() { - Rule::macro_def => { - let mut inner = item.into_inner(); - let name = inner.next().unwrap().as_str().into(); - Item::MacroDef(name, inner.map(parse_item).collect()) - } - Rule::macro_call => Item::MacroCall(item.into_inner().next().unwrap().as_str().into()), + Rule::macro_def => parse_macro_def(item), + Rule::macro_call => parse_macro_call(item), Rule::global_label => { Item::GlobalLabelDeclaration(item.into_inner().next().unwrap().as_str().into()) } @@ -39,11 +36,48 @@ fn parse_item(item: Pair) -> Item { } } +fn parse_macro_def(item: Pair) -> Item { + assert_eq!(item.as_rule(), Rule::macro_def); + let mut inner = item.into_inner().peekable(); + + let name = inner.next().unwrap().as_str().into(); + + // The parameter list is optional. + let params = if let Some(Rule::macro_paramlist) = inner.peek().map(|pair| pair.as_rule()) { + let params = inner.next().unwrap().into_inner(); + params.map(|param| param.as_str().to_string()).collect() + } else { + vec![] + }; + + Item::MacroDef(name, params, inner.map(parse_item).collect()) +} + +fn parse_macro_call(item: Pair) -> Item { + assert_eq!(item.as_rule(), Rule::macro_call); + let mut inner = item.into_inner(); + + let name = inner.next().unwrap().as_str().into(); + + // The arg list is optional. + let args = if let Some(arglist) = inner.next() { + assert_eq!(arglist.as_rule(), Rule::macro_arglist); + arglist.into_inner().map(parse_push_target).collect() + } else { + vec![] + }; + + Item::MacroCall(name, args) +} + fn parse_push_target(target: Pair) -> PushTarget { - match target.as_rule() { - Rule::identifier => PushTarget::Label(target.as_str().into()), - Rule::literal => PushTarget::Literal(parse_literal(target)), - _ => panic!("Unexpected {:?}", target.as_rule()), + assert_eq!(target.as_rule(), Rule::push_target); + let inner = target.into_inner().next().unwrap(); + match inner.as_rule() { + Rule::literal => PushTarget::Literal(parse_literal(inner)), + Rule::identifier => PushTarget::Label(inner.as_str().into()), + Rule::variable => PushTarget::MacroVar(inner.into_inner().next().unwrap().as_str().into()), + _ => panic!("Unexpected {:?}", inner.as_rule()), } }