Merge branch 'ec_use_macro_params' into evm_interpreter

This commit is contained in:
wborgeaud 2022-07-07 19:17:31 +02:00
commit 122188c817
9 changed files with 195 additions and 113 deletions

View File

@ -35,7 +35,8 @@ mod tests {
#[test]
fn make_kernel() {
// Make sure we can parse and assemble the entire kernel.
combined_kernel();
let kernel = combined_kernel();
println!("Kernel size: {} bytes", kernel.code.len());
}
fn u256ify<'a>(hexes: impl IntoIterator<Item = &'a str>) -> Result<Vec<U256>> {

View File

@ -1,3 +1,28 @@
%macro jump(dst)
push $dst
jump
%endmacro
%macro jumpi(dst)
push $dst
jumpi
%endmacro
%macro pop2
pop
pop
%endmacro
%macro pop3
pop
%pop2
%endmacro
%macro pop4
%pop2
%pop2
%endmacro
// If pred is zero, yields z; otherwise, yields nz
%macro select
// stack: pred, nz, z

View File

@ -27,19 +27,11 @@ global ec_add:
// stack: isValid(x1, y1), isValid(x0, y0), x0, y0, x1, y1, retdest
AND
// stack: isValid(x1, y1) & isValid(x0, y0), x0, y0, x1, y1, retdest
PUSH ec_add_valid_points
// stack: ec_add_valid_points, isValid(x1, y1) & isValid(x0, y0), x0, y0, x1, y1, retdest
JUMPI
%jumpi(ec_add_valid_points)
// stack: x0, y0, x1, y1, retdest
// Otherwise return
POP
// stack: y0, x1, y1, retdest
POP
// stack: x1, y1, retdest
POP
// stack: y1, retdest
POP
%pop4
// stack: retdest
%ec_invalid_input
@ -56,9 +48,7 @@ global ec_add_valid_points:
// stack: x0, y0, x0, y0, x1, y1, retdest
%ec_isidentity
// stack: (x0,y0)==(0,0), x0, y0, x1, y1, retdest
PUSH ec_add_first_zero
// stack: ec_add_first_zero, (x0,y0)==(0,0), x0, y0, x1, y1, retdest
JUMPI
%jumpi(ec_add_first_zero)
// stack: x0, y0, x1, y1, retdest
// Check if the first point is the identity.
@ -68,9 +58,7 @@ global ec_add_valid_points:
// stack: x1, y1, x0, y0, x1, y1, retdest
%ec_isidentity
// stack: (x1,y1)==(0,0), x0, y0, x1, y1, retdest
PUSH ec_add_snd_zero
// stack: ec_add_snd_zero, (x1,y1)==(0,0), x0, y0, x1, y1, retdest
JUMPI
%jumpi(ec_add_snd_zero)
// stack: x0, y0, x1, y1, retdest
// Check if both points have the same x-coordinate.
@ -80,9 +68,7 @@ global ec_add_valid_points:
// stack: x0, x1, x0, y0, x1, y1, retdest
EQ
// stack: x0 == x1, x0, y0, x1, y1, retdest
PUSH ec_add_equal_first_coord
// stack: ec_add_equal_first_coord, x0 == x1, x0, y0, x1, y1, retdest
JUMPI
%jumpi(ec_add_equal_first_coord)
// stack: x0, y0, x1, y1, retdest
// Otherwise, we can use the standard formula.
@ -101,9 +87,7 @@ global ec_add_valid_points:
// stack: x0 - x1, y0 - y1, x0, y0, x1, y1, retdest
%moddiv
// stack: lambda, x0, y0, x1, y1, retdest
PUSH ec_add_valid_points_with_lambda
// stack: ec_add_valid_points_with_lambda, lambda, x0, y0, x1, y1, retdest
JUMP
%jump(ec_add_valid_points_with_lambda)
// BN254 elliptic curve addition.
// Assumption: (x0,y0) == (0,0)
@ -112,9 +96,7 @@ ec_add_first_zero:
// stack: x0, y0, x1, y1, retdest
// Just return (x1,y1)
POP
// stack: y0, x1, y1, retdest
POP
%pop2
// stack: x1, y1, retdest
SWAP1
// stack: y1, x1, retdest
@ -194,13 +176,7 @@ ec_add_valid_points_with_lambda:
// stack: x2, lambda, x0, y0, y2, y1, retdest
SWAP5
// stack: y1, lambda, x0, y0, y2, x2, retdest
POP
// stack: lambda, x0, y0, y2, x2, retdest
POP
// stack: x0, y0, y2, x2, retdest
POP
// stack: y0, y2, x2, retdest
POP
%pop4
// stack: y2, x2, retdest
SWAP2
// stack: retdest, x2, y2
@ -219,19 +195,11 @@ ec_add_equal_first_coord:
// stack: y1, y0, x0, y0, x1, y1, retdest
EQ
// stack: y1 == y0, x0, y0, x1, y1, retdest
PUSH ec_add_equal_points
// stack: ec_add_equal_points, y1 == y0, x0, y0, x1, y1, retdest
JUMPI
%jumpi(ec_add_equal_points)
// stack: x0, y0, x1, y1, retdest
// Otherwise, one is the negation of the other so we can return (0,0).
POP
// stack: y0, x1, y1, retdest
POP
// stack: x1, y1, retdest
POP
// stack: y1, retdest
POP
%pop4
// stack: retdest
PUSH 0
// stack: 0, retdest
@ -268,9 +236,7 @@ ec_add_equal_points:
// stack: y0, 3/2 * x0^2, x0, y0, x1, y1, retdest
%moddiv
// stack: lambda, x0, y0, x1, y1, retdest
PUSH ec_add_valid_points_with_lambda
// stack: ec_add_valid_points_with_lambda, lambda, x0, y0, x1, y1, retdest
JUMP
%jump(ec_add_valid_points_with_lambda)
// BN254 elliptic curve doubling.
// Assumption: (x0,y0) is a valid point.
@ -282,9 +248,7 @@ global ec_double:
// stack: y0, x0, y0, retdest
DUP2
// stack: x0, y0, x0, y0, retdest
PUSH ec_add_equal_points
// stack: ec_add_equal_points, x0, y0, x0, y0, retdest
JUMP
%jump(ec_add_equal_points)
// Push the order of the BN254 base field.
%macro bn_base

View File

@ -28,12 +28,7 @@ global ec_mul:
// stack: ec_mul_valid_point, isValid(x, y), x, y, s, retdest
JUMPI
// stack: x, y, s, retdest
POP
// stack: y, s, retdest
POP
// stack: s, retdest
POP
// stack: retdest
%pop3
%ec_invalid_input
// Same algorithm as in `exp.asm`
@ -46,9 +41,7 @@ ec_mul_valid_point:
// stack: step_case, s, x, y, s, retdest
JUMPI
// stack: x, y, s, retdest
PUSH ret_zero
// stack: ret_zero, x, y, s, retdest
JUMP
%jump(ret_zero)
step_case:
JUMPDEST
@ -67,17 +60,13 @@ step_case:
// stack: y, step_case_contd, s / 2, recursion_return, x, y, s, retdest
DUP5
// stack: x, y, step_case_contd, s / 2, recursion_return, x, y, s, retdest
PUSH ec_double
// stack: ec_double, x, y, step_case_contd, s / 2, recursion_return, x, y, s, retdest
JUMP
%jump(ec_double)
// Assumption: 2(x,y) = (x',y')
step_case_contd:
JUMPDEST
// stack: x', y', s / 2, recursion_return, x, y, s, retdest
PUSH ec_mul_valid_point
// stack: ec_mul_valid_point, x', y', s / 2, recursion_return, x, y, s, retdest
JUMP
%jump(ec_mul_valid_point)
recursion_return:
JUMPDEST
@ -98,9 +87,7 @@ recursion_return:
// stack: x', s & 1, y', x, y, retdest
SWAP1
// stack: s & 1, x', y', x, y, retdest
PUSH odd_scalar
// stack: odd_scalar, s & 1, x', y', x, y, retdest
JUMPI
%jumpi(odd_scalar)
// stack: x', y', x, y, retdest
SWAP3
// stack: y, y', x, x', retdest
@ -117,18 +104,12 @@ recursion_return:
odd_scalar:
JUMPDEST
// stack: x', y', x, y, retdest
PUSH ec_add_valid_points
// stack: ec_add_valid_points, x', y', x, y, retdest
JUMP
%jump(ec_add_valid_points)
ret_zero:
JUMPDEST
// stack: x, y, s, retdest
POP
// stack: y, s, retdest
POP
// stack: s, retdest
POP
%pop3
// stack: retdest
PUSH 0
// stack: 0, retdest

View File

@ -14,9 +14,7 @@ global exp:
// stack: x, e, retdest
dup2
// stack: e, x, e, retdest
push step_case
// stack: step_case, e, x, e, retdest
jumpi
%jumpi(step_case)
// stack: x, e, retdest
pop
// stack: e, retdest
@ -43,9 +41,7 @@ step_case:
// stack: x, e / 2, recursion_return, x, e, retdest
%square
// stack: x * x, e / 2, recursion_return, x, e, retdest
push exp
// stack: exp, x * x, e / 2, recursion_return, x, e, retdest
jump
%jump(exp)
recursion_return:
jumpdest
// stack: exp(x * x, e / 2), x, e, retdest

View File

@ -19,6 +19,20 @@ pub struct Kernel {
pub(crate) global_labels: HashMap<String, usize>,
}
struct Macro {
params: Vec<String>,
items: Vec<Item>,
}
impl Macro {
fn get_param_index(&self, param: &str) -> usize {
self.params
.iter()
.position(|p| p == param)
.unwrap_or_else(|| panic!("No such param: {} {:?}", param, &self.params))
}
}
pub(crate) fn assemble(files: Vec<File>) -> Kernel {
let macros = find_macros(&files);
let mut global_labels = HashMap::new();
@ -41,33 +55,31 @@ pub(crate) fn assemble(files: Vec<File>) -> Kernel {
}
}
fn find_macros(files: &[File]) -> HashMap<String, Vec<Item>> {
fn find_macros(files: &[File]) -> HashMap<String, Macro> {
let mut macros = HashMap::new();
for file in files {
for item in &file.body {
if let Item::MacroDef(name, items) = item {
macros.insert(name.clone(), items.clone());
if let Item::MacroDef(name, params, items) = item {
let _macro = Macro {
params: params.clone(),
items: items.clone(),
};
macros.insert(name.clone(), _macro);
}
}
}
macros
}
fn expand_macros(body: Vec<Item>, macros: &HashMap<String, Vec<Item>>) -> Vec<Item> {
fn expand_macros(body: Vec<Item>, macros: &HashMap<String, Macro>) -> Vec<Item> {
let mut expanded = vec![];
for item in body {
match item {
Item::MacroDef(_, _) => {
Item::MacroDef(_, _, _) => {
// At this phase, we no longer need macro definitions.
}
Item::MacroCall(m) => {
let mut expanded_item = macros
.get(&m)
.cloned()
.unwrap_or_else(|| panic!("No such macro: {}", m));
// Recursively expand any macros in the expanded code.
expanded_item = expand_macros(expanded_item, macros);
expanded.extend(expanded_item);
Item::MacroCall(m, args) => {
expanded.extend(expand_macro_call(m, args, macros));
}
item => {
expanded.push(item);
@ -77,6 +89,41 @@ fn expand_macros(body: Vec<Item>, macros: &HashMap<String, Vec<Item>>) -> Vec<It
expanded
}
fn expand_macro_call(
name: String,
args: Vec<PushTarget>,
macros: &HashMap<String, Macro>,
) -> Vec<Item> {
let _macro = macros
.get(&name)
.unwrap_or_else(|| panic!("No such macro: {}", name));
assert_eq!(
args.len(),
_macro.params.len(),
"Macro `{}`: expected {} arguments, got {}",
name,
_macro.params.len(),
args.len()
);
let expanded_item = _macro
.items
.iter()
.map(|item| {
if let Item::Push(PushTarget::MacroVar(var)) = item {
let param_index = _macro.get_param_index(var);
Item::Push(args[param_index].clone())
} else {
item.clone()
}
})
.collect();
// Recursively expand any macros in the expanded code.
expand_macros(expanded_item, macros)
}
fn find_labels(
body: &[Item],
offset: &mut usize,
@ -86,7 +133,7 @@ fn find_labels(
let mut local_labels = HashMap::<String, usize>::new();
for item in body {
match item {
Item::MacroDef(_, _) | Item::MacroCall(_) => {
Item::MacroDef(_, _, _) | Item::MacroCall(_, _) => {
panic!("Macros should have been expanded already")
}
Item::GlobalLabelDeclaration(label) => {
@ -114,7 +161,7 @@ fn assemble_file(
// Assemble the file.
for item in body {
match item {
Item::MacroDef(_, _) | Item::MacroCall(_) => {
Item::MacroDef(_, _, _) | Item::MacroCall(_, _) => {
panic!("Macros should have been expanded already")
}
Item::GlobalLabelDeclaration(_) | Item::LocalLabelDeclaration(_) => {
@ -135,6 +182,7 @@ fn assemble_file(
.map(|i| offset.to_le_bytes()[i as usize])
.collect()
}
PushTarget::MacroVar(v) => panic!("Variable not in a macro: {}", v),
};
code.push(get_push_opcode(target_bytes.len() as u8));
code.extend(target_bytes);
@ -152,6 +200,7 @@ fn push_target_size(target: &PushTarget) -> u8 {
match target {
PushTarget::Literal(lit) => lit.to_trimmed_be_bytes().len() as u8,
PushTarget::Label(_) => BYTES_PER_OFFSET,
PushTarget::MacroVar(v) => panic!("Variable not in a macro: {}", v),
}
}
@ -281,6 +330,32 @@ mod tests {
assert_eq!(kernel.code, vec![add, add]);
}
#[test]
fn macro_with_vars() {
let kernel = parse_and_assemble(&[
"%macro add(x, y) PUSH $x PUSH $y ADD %endmacro",
"%add(2, 3)",
]);
let push1 = get_push_opcode(1);
let add = get_opcode("ADD");
assert_eq!(kernel.code, vec![push1, 2, push1, 3, add]);
}
#[test]
#[should_panic]
fn macro_with_wrong_vars() {
parse_and_assemble(&[
"%macro add(x, y) PUSH $x PUSH $y ADD %endmacro",
"%add(2, 3, 4)",
]);
}
#[test]
#[should_panic]
fn var_not_in_macro() {
parse_and_assemble(&["push $abc"]);
}
fn parse_and_assemble(files: &[&str]) -> Kernel {
let parsed_files = files.iter().map(|f| parse(f)).collect_vec();
assemble(parsed_files)

View File

@ -8,10 +8,10 @@ pub(crate) struct File {
#[derive(Clone, Debug)]
pub(crate) enum Item {
/// Defines a new macro.
MacroDef(String, Vec<Item>),
/// Calls a macro.
MacroCall(String),
/// Defines a new macro: name, params, body.
MacroDef(String, Vec<String>, Vec<Item>),
/// Calls a macro: name, args.
MacroCall(String, Vec<PushTarget>),
/// Declares a global label.
GlobalLabelDeclaration(String),
/// Declares a label that is local to the current file.
@ -29,6 +29,7 @@ pub(crate) enum Item {
pub(crate) enum PushTarget {
Literal(Literal),
Label(String),
MacroVar(String),
}
#[derive(Clone, Debug)]

View File

@ -12,13 +12,18 @@ literal_decimal = @{ ASCII_DIGIT+ }
literal_hex = @{ ^"0x" ~ ASCII_HEX_DIGIT+ }
literal = { literal_hex | literal_decimal }
variable = ${ "$" ~ identifier }
item = { macro_def | macro_call | global_label | local_label | bytes_item | push_instruction | nullary_instruction }
macro_def = { ^"%macro" ~ identifier ~ item* ~ ^"%endmacro" }
macro_call = ${ "%" ~ !(^"macro" | ^"endmacro") ~ identifier }
macro_def = { ^"%macro" ~ identifier ~ macro_paramlist? ~ item* ~ ^"%endmacro" }
macro_call = ${ "%" ~ !(^"macro" | ^"endmacro") ~ identifier ~ macro_arglist? }
macro_paramlist = { "(" ~ identifier ~ ("," ~ identifier)* ~ ")" }
macro_arglist = !{ "(" ~ push_target ~ ("," ~ push_target)* ~ ")" }
global_label = { ^"GLOBAL " ~ identifier ~ ":" }
local_label = { identifier ~ ":" }
bytes_item = { ^"BYTES " ~ literal ~ ("," ~ literal)* }
push_instruction = { ^"PUSH " ~ (literal | identifier) }
push_instruction = { ^"PUSH " ~ push_target }
push_target = { literal | identifier | variable }
nullary_instruction = { identifier }
file = { SOI ~ item* ~ silent_eoi }

View File

@ -18,14 +18,11 @@ pub(crate) fn parse(s: &str) -> File {
}
fn parse_item(item: Pair<Rule>) -> Item {
assert_eq!(item.as_rule(), Rule::item);
let item = item.into_inner().next().unwrap();
match item.as_rule() {
Rule::macro_def => {
let mut inner = item.into_inner();
let name = inner.next().unwrap().as_str().into();
Item::MacroDef(name, inner.map(parse_item).collect())
}
Rule::macro_call => Item::MacroCall(item.into_inner().next().unwrap().as_str().into()),
Rule::macro_def => parse_macro_def(item),
Rule::macro_call => parse_macro_call(item),
Rule::global_label => {
Item::GlobalLabelDeclaration(item.into_inner().next().unwrap().as_str().into())
}
@ -39,11 +36,48 @@ fn parse_item(item: Pair<Rule>) -> Item {
}
}
fn parse_macro_def(item: Pair<Rule>) -> Item {
assert_eq!(item.as_rule(), Rule::macro_def);
let mut inner = item.into_inner().peekable();
let name = inner.next().unwrap().as_str().into();
// The parameter list is optional.
let params = if let Some(Rule::macro_paramlist) = inner.peek().map(|pair| pair.as_rule()) {
let params = inner.next().unwrap().into_inner();
params.map(|param| param.as_str().to_string()).collect()
} else {
vec![]
};
Item::MacroDef(name, params, inner.map(parse_item).collect())
}
fn parse_macro_call(item: Pair<Rule>) -> Item {
assert_eq!(item.as_rule(), Rule::macro_call);
let mut inner = item.into_inner();
let name = inner.next().unwrap().as_str().into();
// The arg list is optional.
let args = if let Some(arglist) = inner.next() {
assert_eq!(arglist.as_rule(), Rule::macro_arglist);
arglist.into_inner().map(parse_push_target).collect()
} else {
vec![]
};
Item::MacroCall(name, args)
}
fn parse_push_target(target: Pair<Rule>) -> PushTarget {
match target.as_rule() {
Rule::identifier => PushTarget::Label(target.as_str().into()),
Rule::literal => PushTarget::Literal(parse_literal(target)),
_ => panic!("Unexpected {:?}", target.as_rule()),
assert_eq!(target.as_rule(), Rule::push_target);
let inner = target.into_inner().next().unwrap();
match inner.as_rule() {
Rule::literal => PushTarget::Literal(parse_literal(inner)),
Rule::identifier => PushTarget::Label(inner.as_str().into()),
Rule::variable => PushTarget::MacroVar(inner.into_inner().next().unwrap().as_str().into()),
_ => panic!("Unexpected {:?}", inner.as_rule()),
}
}