323 lines
6.8 KiB
C++
Raw Normal View History

2023-11-20 21:15:41 -07:00
#include "fr.hpp"
#include <stdio.h>
#include <gmp.h>
#include <assert.h>
#include <string>
// File-local GMP state shared by the helpers below; all of it is set up in Fr_init().
// Field modulus, imported from Fr_q.longVal in Fr_init().
static mpz_t q;
// Constant 0 (initialised in Fr_init()).
static mpz_t zero;
// Constant 1 (initialised in Fr_init()).
static mpz_t one;
// 2^nBits - 1, computed in Fr_init().
static mpz_t mask;
// Number of bits of q, as reported by mpz_sizeinbase(q, 2).
static size_t nBits;
// Guards Fr_init() so that the GMP state is only set up once.
static bool initialized = false;
/**
 * Loads the value of a field element into a GMP integer.
 *
 * The element is first passed through Fr_toNormal(). A long element is
 * imported limb by limb (Fr_N64 words of 8 bytes, least significant first);
 * a short element is taken from its signed shortVal, with q added when the
 * value is negative.
 *
 * @param r   Already-initialised GMP integer that receives the value.
 * @param pE  Element to convert; it is not modified.
 */
void Fr_toMpz(mpz_t r, PFrElement pE) {
    FrElement normal;
    Fr_toNormal(&normal, pE);

    if (normal.type & Fr_LONG) {
        mpz_import(r, Fr_N64, -1, 8, -1, 0, (const void *)normal.longVal);
        return;
    }

    mpz_set_si(r, normal.shortVal);
    if (normal.shortVal < 0) {
        // Map a negative short value onto its representative modulo q.
        mpz_add(r, r, q);
    }
}
/**
 * Stores a GMP integer into a field element.
 *
 * Values that fit in a signed int are stored in the short form; anything else
 * is stored in the long form, with the limb array zeroed before the export so
 * that unused high limbs are 0.
 *
 * @param pE  Destination element.
 * @param v   Value to store. No reduction modulo q is performed here.
 */
void Fr_fromMpz(PFrElement pE, mpz_t v) {
    if (!mpz_fits_sint_p(v)) {
        pE->type = Fr_LONG;
        for (int limb = 0; limb < Fr_N64; ++limb) {
            pE->longVal[limb] = 0;
        }
        mpz_export((void *)(pE->longVal), NULL, -1, 8, -1, 0, v);
        return;
    }

    pE->type = Fr_SHORT;
    pE->shortVal = mpz_get_si(v);
}
/**
 * One-time setup of the file-local GMP state (q, zero, one, nBits, mask).
 *
 * @return true when this call performed the initialisation, false when the
 *         state had already been initialised by an earlier call.
 */
bool Fr_init() {
    if (initialized) {
        return false;
    }
    initialized = true;

    // Small constants.
    mpz_init_set_ui(zero, 0);
    mpz_init_set_ui(one, 1);

    // Modulus, read from the limbs of Fr_q.
    mpz_init(q);
    mpz_import(q, Fr_N64, -1, 8, -1, 0, (const void *)Fr_q.longVal);
    nBits = mpz_sizeinbase(q, 2);

    // mask = 2^nBits - 1
    mpz_init(mask);
    mpz_mul_2exp(mask, one, nBits);
    mpz_sub(mask, mask, one);

    return true;
}
/**
 * Parses a number from text and stores it, reduced modulo q, in an element.
 *
 * @param pE    Destination element.
 * @param s     NUL-terminated textual number.
 * @param base  Radix handed to mpz_init_set_str().
 *
 * The return value of mpz_init_set_str() is not inspected, so malformed input
 * is not reported to the caller.
 */
void Fr_str2element(PFrElement pE, char const *s, uint base) {
    mpz_t parsed;
    mpz_init_set_str(parsed, s, base);

    // Floor remainder: with a positive q the result lies in [0, q).
    mpz_fdiv_r(parsed, parsed, q);

    Fr_fromMpz(pE, parsed);
    mpz_clear(parsed);
}
/**
 * Returns the decimal representation of a field element.
 *
 * Short elements are taken from shortVal (negative values have q added);
 * long elements are passed through Fr_toNormal() and imported limb by limb.
 *
 * @param pE  Element to print; it is not modified.
 * @return    NUL-terminated string produced by mpz_get_str(). Every code path
 *            now returns a buffer from that same allocator, so the caller can
 *            release the result uniformly (free() with GMP's default
 *            allocation functions). Previously the non-negative short path
 *            returned a `new char[]` buffer while all other paths returned a
 *            GMP-allocated one, which made correct deallocation impossible.
 */
char *Fr_element2str(PFrElement pE) {
    mpz_t value;

    if (pE->type & Fr_LONG) {
        FrElement normal;
        Fr_toNormal(&normal, pE);
        mpz_init(value);
        mpz_import(value, Fr_N64, -1, 8, -1, 0, (const void *)normal.longVal);
    } else {
        mpz_init_set_si(value, pE->shortVal);
        if (pE->shortVal < 0) {
            // Negative short values are printed as their representative mod q.
            mpz_add(value, value, q);
        }
    }

    char *res = mpz_get_str(0, 10, value);
    mpz_clear(value);
    return res;
}
/**
 * Integer (floor) division of the values of two elements: r = floor(a / b).
 *
 * Both operands are converted with Fr_toMpz(), divided with mpz_fdiv_q() and
 * the quotient is stored back with Fr_fromMpz().
 *
 * @param r  Receives the quotient.
 * @param a  Dividend.
 * @param b  Divisor. No check for a zero divisor is made here.
 */
void Fr_idiv(PFrElement r, PFrElement a, PFrElement b) {
    mpz_t dividend;
    mpz_t divisor;
    mpz_t quotient;

    mpz_init(dividend);
    mpz_init(divisor);
    mpz_init(quotient);

    Fr_toMpz(dividend, a);
    Fr_toMpz(divisor, b);

    mpz_fdiv_q(quotient, dividend, divisor);
    Fr_fromMpz(r, quotient);

    mpz_clear(dividend);
    mpz_clear(divisor);
    mpz_clear(quotient);
}
/**
 * Floor remainder of the values of two elements: r = a mod b.
 *
 * @param r  Receives the remainder computed by mpz_fdiv_r().
 * @param a  Dividend.
 * @param b  Divisor. No check for a zero divisor is made here.
 */
void Fr_mod(PFrElement r, PFrElement a, PFrElement b) {
    mpz_t dividend;
    mpz_t divisor;
    mpz_t remainder;

    mpz_init(dividend);
    mpz_init(divisor);
    mpz_init(remainder);

    Fr_toMpz(dividend, a);
    Fr_toMpz(divisor, b);

    mpz_fdiv_r(remainder, dividend, divisor);
    Fr_fromMpz(r, remainder);

    mpz_clear(dividend);
    mpz_clear(divisor);
    mpz_clear(remainder);
}
/**
 * Modular exponentiation: r = a^b mod q.
 *
 * @param r  Receives the result of mpz_powm().
 * @param a  Base.
 * @param b  Exponent.
 */
void Fr_pow(PFrElement r, PFrElement a, PFrElement b) {
    mpz_t base;
    mpz_t exponent;
    mpz_t power;

    mpz_init(base);
    mpz_init(exponent);
    mpz_init(power);

    Fr_toMpz(base, a);
    Fr_toMpz(exponent, b);

    mpz_powm(power, base, exponent, q);
    Fr_fromMpz(r, power);

    mpz_clear(base);
    mpz_clear(exponent);
    mpz_clear(power);
}
/**
 * Modular inverse: r = a^-1 mod q.
 *
 * @param r  Receives the value produced by mpz_invert().
 * @param a  Element to invert.
 *
 * The return value of mpz_invert() is not inspected, so a non-invertible
 * input is not reported to the caller.
 */
void Fr_inv(PFrElement r, PFrElement a) {
    mpz_t value;
    mpz_t inverse;

    mpz_init(value);
    mpz_init(inverse);

    Fr_toMpz(value, a);
    mpz_invert(inverse, value, q);
    Fr_fromMpz(r, inverse);

    mpz_clear(value);
    mpz_clear(inverse);
}
/**
 * Field division: r = a * b^-1, built from Fr_inv() followed by Fr_mul().
 *
 * @param r  Receives the quotient.
 * @param a  Numerator.
 * @param b  Denominator.
 */
void Fr_div(PFrElement r, PFrElement a, PFrElement b) {
    FrElement inverse;

    Fr_inv(&inverse, b);
    Fr_mul(r, a, &inverse);
}
// Failure hook: trips an assertion unconditionally.
// Note: when NDEBUG is defined, assert() expands to nothing and this
// function simply returns.
void Fr_fail() {
assert(false);
}
// Error hook that forwards to Fr_fail().
// NOTE(review): no caller is visible in this file; presumably invoked by the
// field arithmetic routines declared in fr.hpp -- confirm.
void Fr_longErr()
{
Fr_fail();
}
/**
 * Makes sure the file-local GMP state is initialised, then fills in the
 * cached constants fZero (0), fOne (1) and fNegOne (neg of fOne).
 */
RawFr::RawFr() {
    // Fr_init() is a no-op (returns false) when already initialised.
    Fr_init();

    this->set(this->fZero, 0);
    this->set(this->fOne, 1);
    this->neg(this->fNegOne, this->fOne);
}
/** Destructor: nothing to release. */
RawFr::~RawFr() = default;
/**
 * Parses a number from text, reduces it modulo q and stores it in r after
 * passing the limbs through Fr_rawToMontgomery().
 *
 * @param r      Destination element.
 * @param s      Textual number.
 * @param radix  Radix handed to mpz_init_set_str().
 *
 * The return value of mpz_init_set_str() is not inspected, so malformed input
 * is not reported to the caller.
 */
void RawFr::fromString(Element &r, const std::string &s, uint32_t radix) {
    mpz_t parsed;
    mpz_init_set_str(parsed, s.c_str(), radix);
    mpz_fdiv_r(parsed, parsed, q);

    // Clear all limbs first: mpz_export only writes the significant words.
    for (int limb = 0; limb < Fr_N64; ++limb) {
        r.v[limb] = 0;
    }
    mpz_export((void *)(r.v), NULL, -1, 8, -1, 0, parsed);
    Fr_rawToMontgomery(r.v, r.v);

    mpz_clear(parsed);
}
/**
 * Stores an unsigned integer in r, passing the limbs through
 * Fr_rawToMontgomery().
 *
 * @param r  Destination element.
 * @param v  Value to store.
 */
void RawFr::fromUI(Element &r, unsigned long int v) {
    mpz_t value;
    mpz_init_set_ui(value, v);

    // Clear all limbs first: mpz_export only writes the significant words.
    for (int limb = 0; limb < Fr_N64; ++limb) {
        r.v[limb] = 0;
    }
    mpz_export((void *)(r.v), NULL, -1, 8, -1, 0, value);
    Fr_rawToMontgomery(r.v, r.v);

    mpz_clear(value);
}
/**
 * Convenience overload: builds and returns an element holding `value`.
 *
 * @param value  Signed integer to store.
 * @return       Element filled in by set(Element &, int).
 */
RawFr::Element RawFr::set(int value) {
    Element result;
    set(result, value);
    return result;
}
/**
 * Stores a signed integer in r, passing the limbs through
 * Fr_rawToMontgomery(). Negative values have q added first.
 *
 * @param r      Destination element.
 * @param value  Signed integer to store.
 *
 * The limbs are zeroed before the export because mpz_export only writes the
 * significant words. The original code also exported once before zeroing;
 * that call was redundant (its output was immediately overwritten) and has
 * been removed.
 */
void RawFr::set(Element &r, int value) {
    mpz_t mr;
    mpz_init_set_si(mr, value);
    if (value < 0) {
        mpz_add(mr, mr, q);
    }

    for (int i = 0; i < Fr_N64; i++) {
        r.v[i] = 0;
    }
    mpz_export((void *)(r.v), NULL, -1, 8, -1, 0, mr);
    Fr_rawToMontgomery(r.v, r.v);

    mpz_clear(mr);
}
std::string RawFr::toString(const Element &a, uint32_t radix) {
Element tmp;
mpz_t r;
Fr_rawFromMontgomery(tmp.v, a.v);
mpz_init(r);
mpz_import(r, Fr_N64, -1, 8, -1, 0, (const void *)(tmp.v));
char *res = mpz_get_str (0, radix, r);
mpz_clear(r);
std::string resS(res);
free(res);
return resS;
}
/**
 * Inverse of an element.
 *
 * The stored limbs of `a` are imported as they are, inverted modulo q with
 * mpz_invert(), exported into r and finally multiplied by Fr_R3 with
 * Fr_rawMMul().
 *
 * @param r  Receives the inverse.
 * @param a  Element to invert.
 *
 * The return value of mpz_invert() is not inspected, so a non-invertible
 * input is not reported to the caller.
 */
void RawFr::inv(Element &r, const Element &a) {
    mpz_t value;
    mpz_init(value);

    mpz_import(value, Fr_N64, -1, 8, -1, 0, (const void *)(a.v));
    mpz_invert(value, value, q);

    // Clear all limbs first: mpz_export only writes the significant words.
    for (int limb = 0; limb < Fr_N64; ++limb) {
        r.v[limb] = 0;
    }
    mpz_export((void *)(r.v), NULL, -1, 8, -1, 0, value);
    Fr_rawMMul(r.v, r.v, Fr_R3.longVal);

    mpz_clear(value);
}
/**
 * Division: r = a * b^-1, built from inv() followed by mul().
 *
 * @param r  Receives the quotient.
 * @param a  Numerator.
 * @param b  Denominator.
 */
void RawFr::div(Element &r, const Element &a, const Element &b) {
    Element inverse;

    inv(inverse, b);
    mul(r, a, inverse);
}
// Non-zero when bit `p` of the byte array `s` is set; bit 0 is the least
// significant bit of s[0].
#define BIT_IS_SET(s, p) (s[p>>3] & (1 << (p & 0x7)))

/**
 * Exponentiation by square-and-multiply: r = base^scalar.
 *
 * @param r           Receives the result; fOne when no bit of the scalar is set.
 * @param base        Base element. It is copied first, so r may alias it.
 * @param scalar      Exponent as a byte array, indexed by BIT_IS_SET (bit 0 is
 *                    the least significant bit of scalar[0]).
 * @param scalarSize  Number of bytes in `scalar`.
 */
void RawFr::exp(Element &r, const Element &base, uint8_t* scalar, unsigned int scalarSize) {
    Element baseCopy;
    copy(baseCopy, base);

    // Locate the most significant set bit of the exponent.
    int bit = scalarSize*8-1;
    while (bit >= 0 && !BIT_IS_SET(scalar, bit)) {
        bit--;
    }

    if (bit < 0) {
        // Exponent is zero (or empty).
        copy(r, fOne);
        return;
    }

    // The leading bit contributes the base itself; every lower bit squares
    // the accumulator and multiplies the base in when it is set.
    copy(r, baseCopy);
    for (bit--; bit >= 0; bit--) {
        square(r, r);
        if (BIT_IS_SET(scalar, bit)) {
            mul(r, r, baseCopy);
        }
    }
}
/**
 * Loads an element into a GMP integer.
 *
 * The limbs are passed through Fr_rawFromMontgomery() and then imported as
 * Fr_N64 words of 8 bytes, least significant word first.
 *
 * @param r  Already-initialised GMP integer that receives the value.
 * @param a  Element to convert.
 */
void RawFr::toMpz(mpz_t r, const Element &a) {
    Element normal;

    Fr_rawFromMontgomery(normal.v, a.v);
    mpz_import(r, Fr_N64, -1, 8, -1, 0, (const void *)normal.v);
}
/**
 * Stores a GMP integer in an element, passing the limbs through
 * Fr_rawToMontgomery().
 *
 * @param r  Destination element.
 * @param a  Value to store. No reduction modulo q is performed here.
 */
void RawFr::fromMpz(Element &r, const mpz_t a) {
    // Clear all limbs first: mpz_export only writes the significant words.
    for (int limb = 0; limb < Fr_N64; ++limb) {
        r.v[limb] = 0;
    }

    mpz_export((void *)(r.v), NULL, -1, 8, -1, 0, a);
    Fr_rawToMontgomery(r.v, r.v);
}
int RawFr::toRprBE(const Element &element, uint8_t *data, int bytes)
{
if (bytes < Fr_N64 * 8) {
return -(Fr_N64 * 8);
}
mpz_t r;
mpz_init(r);
toMpz(r, element);
mpz_export(data, NULL, 1, 8, 1, 0, r);
return Fr_N64 * 8;
}
/**
 * Deserialises an element from a fixed-width big-endian byte string.
 *
 * @param element  Receives the decoded value (stored through fromMpz()).
 * @param data     Input buffer; the first Fr_N64 * 8 bytes are read.
 * @param bytes    Number of bytes available in `data`.
 * @return         Fr_N64 * 8 on success, or -(Fr_N64 * 8) when fewer bytes
 *                 are available (element is left untouched in that case).
 *
 * Fixes over the previous version:
 *  - mpz_import was called with order = 0, which is outside the documented
 *    values (1 or -1); most-significant-first (1) is now passed explicitly.
 *  - the temporary GMP integer is now released.
 */
int RawFr::fromRprBE(Element &element, const uint8_t *data, int bytes)
{
    const int size = Fr_N64 * 8;
    if (bytes < size) {
        return -size;
    }

    mpz_t r;
    mpz_init(r);

    // `size` words of one byte each, most significant byte first.
    mpz_import(r, size, 1, 1, 1, 0, data);
    fromMpz(element, r);

    mpz_clear(r);
    return size;
}
// Runs Fr_init() during static initialisation of this translation unit; the
// stored result is not read anywhere in this file.
static bool init = Fr_init();
// Definition of the static RawFr::field member. Its constructor calls
// Fr_init() again, which returns false once the state is already set up.
RawFr RawFr::field;