diff --git a/evm/src/bn254_pairing.rs b/evm/src/bn254_pairing.rs index 25cb649f..44e22189 100644 --- a/evm/src/bn254_pairing.rs +++ b/evm/src/bn254_pairing.rs @@ -29,6 +29,15 @@ impl Add for Curve { type Output = Self; fn add(self, other: Self) -> Self { + if self == Curve::::unit() { + return other; + } + if other == Curve::::unit() { + return self; + } + if self == -other { + return Curve::::unit(); + } let m = if self == other { T::new(3) * self.x * self.x / (T::new(2) * self.y) } else { @@ -69,26 +78,39 @@ impl CurveGroup for Curve { }; } -// impl Mul for Curve { -// type Output = Curve; +impl Mul for Curve +where + T: FieldExt, + Curve: CurveGroup, +{ + type Output = Curve; -// fn mul(self, other: i32) -> Self { -// let mut result: Curve = self; -// if other.is_negative() { -// result = -result; -// } -// let mut multiplier = result; -// let mut exp = other.unsigned_abs() as usize; -// while exp > 0 { -// if exp % 2 == 1 { -// result = result + multiplier; -// } -// exp >>= 1; -// multiplier = multiplier + multiplier; -// } -// result -// } -// } + fn mul(self, other: i32) -> Self { + if other == 0 { + return Curve::::unit(); + } + if self == Curve::::unit() { + return Curve::::unit(); + } + + let mut x: Curve = self; + if other.is_negative() { + x = -x; + } + let mut result = Curve::::unit(); + + let mut exp = other.unsigned_abs() as usize; + while exp > 0 { + if exp % 2 == 1 { + result = result + x; + } + exp >>= 1; + x = x + x; + } + println!("result: {:?}", result); + result + } +} /// The twisted curve consists of pairs /// (x, y): (Fp2, Fp2) | y^2 = x^3 + 3/(9 + i) diff --git a/evm/src/cpu/kernel/tests/bn254.rs b/evm/src/cpu/kernel/tests/bn254.rs index 6895e8a4..ee1ec702 100644 --- a/evm/src/cpu/kernel/tests/bn254.rs +++ b/evm/src/cpu/kernel/tests/bn254.rs @@ -202,7 +202,7 @@ fn test_bn_final_exponent() -> Result<()> { } fn pairing_input() -> Vec { - let curve_gen: [U256; 2] = unsafe { transmute(Curve::::GENERATOR) }; + let curve_gen: [U256; 2] = unsafe { transmute(Curve::::GENERATOR * 1) }; let twisted_gen: [U256; 4] = unsafe { transmute(Curve::>::GENERATOR) }; let mut input = curve_gen.to_vec(); input.extend_from_slice(&twisted_gen); @@ -223,7 +223,8 @@ fn test_bn_miller() -> Result<()> { }; let interpreter = run_interpreter_with_memory(setup).unwrap(); let output: Vec = interpreter.extract_kernel_memory(BnPairing, out..out + 12); - let expected = miller_loop(Curve::::GENERATOR, Curve::>::GENERATOR).on_stack(); + let expected = + miller_loop(Curve::::GENERATOR, Curve::>::GENERATOR).on_stack(); assert_eq!(output, expected); diff --git a/evm/src/extension_tower.rs b/evm/src/extension_tower.rs index 1b5bf684..2c81b035 100644 --- a/evm/src/extension_tower.rs +++ b/evm/src/extension_tower.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use std::mem::transmute; use std::ops::{Add, Div, Mul, Neg, Sub}; @@ -7,6 +8,7 @@ use rand::Rng; pub trait FieldExt: Copy + + std::fmt::Debug + std::cmp::PartialEq + std::ops::Add + std::ops::Neg @@ -980,7 +982,7 @@ where t1: Fp2::::ZERO, t2: Fp2::::ZERO, }; - + fn new(val: usize) -> Fp6 { Fp6 { t0: Fp2::::new(val),