From ebce0799a2abca0ef0e58d062edc285c527afe18 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 28 Oct 2021 11:48:14 -0700 Subject: [PATCH] initial curve_types and curve_adds --- src/curve/curve_adds.rs | 129 +++++++++++++++++ src/curve/curve_types.rs | 299 +++++++++++++++++++++++++++++++++++++++ src/curve/mod.rs | 2 + src/field/field_types.rs | 8 ++ src/lib.rs | 1 + 5 files changed, 439 insertions(+) create mode 100644 src/curve/curve_adds.rs create mode 100644 src/curve/curve_types.rs create mode 100644 src/curve/mod.rs diff --git a/src/curve/curve_adds.rs b/src/curve/curve_adds.rs new file mode 100644 index 00000000..c5fc9ba4 --- /dev/null +++ b/src/curve/curve_adds.rs @@ -0,0 +1,129 @@ +use std::ops::Add; + +use crate::field::field_types::Field; +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: ProjectivePoint) -> Self::Output { + let ProjectivePoint { x: x1, y: y1, z: z1, zero: zero1 } = self; + let ProjectivePoint { x: x2, y: y2, z: z2, zero: zero2 } = rhs; + + if zero1 { + return rhs; + } + if zero2 { + return self; + } + + let x1z2 = x1 * z2; + let y1z2 = y1 * z2; + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1z2 == x2z1 { + if y1z2 == y2z1 { + // TODO: inline to avoid redundant muls. + return self.double(); + } + if y1z2 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + let z1z2 = z1 * z2; + let u = y2z1 - y1z2; + let uu = u.square(); + let v = x2z1 - x1z2; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1z2; + let a = uu * z1z2 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1z2; + let z3 = vvv * z1z2; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let ProjectivePoint { x: x1, y: y1, z: z1, zero: zero1 } = self; + let AffinePoint { x: x2, y: y2, zero: zero2 } = rhs; + + if zero1 { + return rhs.to_projective(); + } + if zero2 { + return self; + } + + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1 == x2z1 { + if y1 == y2z1 { + // TODO: inline to avoid redundant muls. + return self.double(); + } + if y1 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + let u = y2z1 - y1; + let uu = u.square(); + let v = x2z1 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu * z1 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv * z1; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for AffinePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let AffinePoint { x: x1, y: y1, zero: zero1 } = self; + let AffinePoint { x: x2, y: y2, zero: zero2 } = rhs; + + if zero1 { + return rhs.to_projective(); + } + if zero2 { + return self.to_projective(); + } + + // Check if we're doubling or adding inverses. + if x1 == x2 { + if y1 == y2 { + return self.to_projective().double(); + } + if y1 == -y2 { + return ProjectivePoint::ZERO; + } + } + + let u = y2 - y1; + let uu = u.square(); + let v = x2 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv; + ProjectivePoint::nonzero(x3, y3, z3) + } +} diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs new file mode 100644 index 00000000..f7d55c0c --- /dev/null +++ b/src/curve/curve_types.rs @@ -0,0 +1,299 @@ +use std::ops::Neg; + +use anyhow::Result; + +use crate::field::field_types::Field; +use std::fmt::Debug; + +// To avoid implementation conflicts from associated types, +// see https://github.com/rust-lang/rust/issues/20400 +pub struct CurveScalar(pub ::ScalarField); + +/// A short Weierstrass curve. +pub trait Curve: 'static + Sync + Sized + Copy + Debug { + type BaseField: Field; + type ScalarField: Field; + + const A: Self::BaseField; + const B: Self::BaseField; + + const GENERATOR_AFFINE: AffinePoint; + + const GENERATOR_PROJECTIVE: ProjectivePoint = ProjectivePoint { + x: Self::GENERATOR_AFFINE.x, + y: Self::GENERATOR_AFFINE.y, + z: Self::BaseField::ONE, + zero: false, + }; + + fn convert(x: Self::ScalarField) -> CurveScalar { + CurveScalar(x) + } + + /*fn try_convert_b2s(x: Self::BaseField) -> Result { + x.try_convert::() + } + + fn try_convert_s2b(x: Self::ScalarField) -> Result { + x.try_convert::() + } + + fn try_convert_s2b_slice(s: &[Self::ScalarField]) -> Result> { + let mut res = Vec::with_capacity(s.len()); + for &x in s { + res.push(Self::try_convert_s2b(x)?); + } + Ok(res) + } + + fn try_convert_b2s_slice(s: &[Self::BaseField]) -> Result> { + let mut res = Vec::with_capacity(s.len()); + for &x in s { + res.push(Self::try_convert_b2s(x)?); + } + Ok(res) + }*/ + + fn is_safe_curve() -> bool{ + // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. + (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()).is_nonzero() + } +} + +/// A point on a short Weierstrass curve, represented in affine coordinates. +#[derive(Copy, Clone, Debug)] +pub struct AffinePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub zero: bool, +} + +impl AffinePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ZERO, + zero: true, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self { + let point = Self { x, y, zero: false }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + let Self { x, y, zero } = *self; + zero || y.square() == x.cube() + C::A * x + C::B + } + + pub fn to_projective(&self) -> ProjectivePoint { + let Self { x, y, zero } = *self; + ProjectivePoint { + x, + y, + z: C::BaseField::ONE, + zero, + } + } + + pub fn batch_to_projective(affine_points: &[Self]) -> Vec> { + affine_points.iter().map(Self::to_projective).collect() + } + + pub fn double(&self) -> Self { + let AffinePoint { + x: x1, + y: y1, + zero, + } = *self; + + if zero { + return AffinePoint::ZERO; + } + + let double_y = y1.double(); + let inv_double_y = double_y.inverse(); // (2y)^(-1) + let triple_xx = x1.square().triple(); // 3x^2 + let lambda = (triple_xx + C::A) * inv_double_y; + let x3 = lambda.square() - self.x.double(); + let y3 = lambda * (x1 - x3) - y1; + + Self { + x: x3, + y: y3, + zero: false, + } + } + +} + +impl PartialEq for AffinePoint { + fn eq(&self, other: &Self) -> bool { + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = *self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = *other; + if zero1 || zero2 { + return zero1 == zero2; + } + x1 == x2 && y1 == y2 + } +} + +impl Eq for AffinePoint {} + +/// A point on a short Weierstrass curve, represented in projective coordinates. +#[derive(Copy, Clone, Debug)] +pub struct ProjectivePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub z: C::BaseField, + pub zero: bool, +} + +impl ProjectivePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ZERO, + z: C::BaseField::ZERO, + zero: true, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self { + let point = Self { + x, + y, + z, + zero: false, + }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + self.to_affine().is_valid() + } + + pub fn to_affine(&self) -> AffinePoint { + let Self { x, y, z, zero } = *self; + if zero { + AffinePoint::ZERO + } else { + let z_inv = z.inverse(); + AffinePoint::nonzero(x * z_inv, y * z_inv) + } + } + + pub fn batch_to_affine(proj_points: &[Self]) -> Vec> { + let n = proj_points.len(); + let zs: Vec = proj_points.iter().map(|pp| pp.z).collect(); + let z_invs = C::BaseField::batch_multiplicative_inverse(&zs); + + let mut result = Vec::with_capacity(n); + for i in 0..n { + let Self { x, y, z: _, zero } = proj_points[i]; + result.push(if zero { + AffinePoint::ZERO + } else { + let z_inv = z_invs[i]; + AffinePoint::nonzero(x * z_inv, y * z_inv) + }); + } + result + } + + pub fn double(&self) -> Self { + let Self { x, y, z, zero } = *self; + if zero { + return ProjectivePoint::ZERO; + } + + let xx = x.square(); + let zz = z.square(); + let mut w = xx.triple(); + if C::A.is_nonzero() { + w = w + C::A * zz; + } + let s = y.double() * z; + let r = y * s; + let rr = r.square(); + let b = (x + r).square() - (xx + rr); + let h = w.square() - b.double(); + let x3 = h * s; + let y3 = w * (b - h) - rr.double(); + let z3 = s.cube(); + Self { + x: x3, + y: y3, + z: z3, + zero: false, + } + } + + pub fn add_slices(a: &[Self], b: &[Self]) -> Vec { + assert_eq!(a.len(), b.len()); + a.iter() + .zip(b.iter()) + .map(|(&a_i, &b_i)| a_i + b_i) + .collect() + } + + pub fn neg(&self) -> Self { + Self { + x: self.x, + y: -self.y, + z: self.z, + zero: self.zero, + } + } +} + +impl PartialEq for ProjectivePoint { + fn eq(&self, other: &Self) -> bool { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + zero: zero1, + } = *self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + zero: zero2, + } = *other; + if zero1 || zero2 { + return zero1 == zero2; + } + + // We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2). + // But to avoid field division, it is better to compare (x1*z2, y1*z2) == (x2*z1, y2*z1). + x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1 + } +} + +impl Eq for ProjectivePoint {} + +impl Neg for AffinePoint { + type Output = AffinePoint; + + fn neg(self) -> Self::Output { + let AffinePoint { x, y, zero } = self; + AffinePoint { x, y: -y, zero } + } +} + +impl Neg for ProjectivePoint { + type Output = ProjectivePoint; + + fn neg(self) -> Self::Output { + let ProjectivePoint { x, y, z, zero } = self; + ProjectivePoint { x, y: -y, z, zero } + } +} diff --git a/src/curve/mod.rs b/src/curve/mod.rs new file mode 100644 index 00000000..1e536564 --- /dev/null +++ b/src/curve/mod.rs @@ -0,0 +1,2 @@ +pub mod curve_adds; +pub mod curve_types; diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 036793cc..250338fb 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -91,6 +91,14 @@ pub trait Field: self.square() * *self } + fn double(&self) -> Self { + *self * Self::TWO + } + + fn triple(&self) -> Self { + *self * (Self::ONE + Self::TWO) + } + /// Compute the multiplicative inverse of this field element. fn try_inverse(&self) -> Option; diff --git a/src/lib.rs b/src/lib.rs index e76e312c..b0158d7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ #![feature(specialization)] #![feature(stdsimd)] +pub mod curve; pub mod field; pub mod fri; pub mod gadgets;