From 15ddef48dad2540c2c1127e310061bf99e3264a7 Mon Sep 17 00:00:00 2001 From: Alejandro Cabeza Romero Date: Tue, 21 Apr 2026 16:20:22 +0200 Subject: [PATCH] Adapt circom functions to memory-based. --- src/circom_adapter.cpp | 191 +++++++++++++++++++++++++++++++++++++++++ src/circom_adapter.hpp | 13 +++ src/poq/ffi.cpp | 25 ++++-- src/types.hpp | 19 ++-- 4 files changed, 234 insertions(+), 14 deletions(-) create mode 100644 src/circom_adapter.cpp create mode 100644 src/circom_adapter.hpp diff --git a/src/circom_adapter.cpp b/src/circom_adapter.cpp new file mode 100644 index 0000000..c2a590d --- /dev/null +++ b/src/circom_adapter.cpp @@ -0,0 +1,191 @@ +#include "circom_adapter.hpp" +#include "circom_fwd.hpp" + +#include +#include +#include + +Circom_Circuit* loadCircuit(const ConstBytes& circuit_bytes) { + Circom_Circuit* circuit = new Circom_Circuit; + + circuit->InputHashMap = new HashSignalInfo[get_size_of_input_hashmap()]; + uint dsize = get_size_of_input_hashmap() * sizeof(HashSignalInfo); + memcpy((void*)(circuit->InputHashMap), (void*)circuit_bytes.data, dsize); + + circuit->witness2SignalList = new u64[get_size_of_witness()]; + uint inisize = dsize; + dsize = get_size_of_witness() * sizeof(u64); + memcpy((void*)(circuit->witness2SignalList), (void*)(circuit_bytes.data + inisize), dsize); + + circuit->circuitConstants = new FrElement[get_size_of_constants()]; + if (get_size_of_constants() > 0) { + inisize += dsize; + dsize = get_size_of_constants() * sizeof(FrElement); + memcpy((void*)(circuit->circuitConstants), (void*)(circuit_bytes.data + inisize), dsize); + } + + std::map templateInsId2IOSignalInfo1; + IOFieldDefPair* busInsId2FieldInfo1 = nullptr; + if (get_size_of_io_map() > 0) { + u32 index[get_size_of_io_map()]; + inisize += dsize; + dsize = get_size_of_io_map() * sizeof(u32); + memcpy((void*)index, (void*)(circuit_bytes.data + inisize), dsize); + inisize += dsize; + assert(inisize % sizeof(u32) == 0); + assert(circuit_bytes.size % sizeof(u32) == 0); + u32 dataiomap[(circuit_bytes.size - inisize) / sizeof(u32)]; + memcpy((void*)dataiomap, (void*)(circuit_bytes.data + inisize), circuit_bytes.size - inisize); + u32* pu32 = dataiomap; + for (int i = 0; i < get_size_of_io_map(); i++) { + u32 n = *pu32; + IOFieldDefPair p; + p.len = n; + IOFieldDef defs[n]; + pu32 += 1; + for (u32 j = 0; j < n; j++) { + defs[j].offset = *pu32; + u32 len = *(pu32 + 1); + defs[j].len = len; + defs[j].lengths = new u32[len]; + memcpy((void*)defs[j].lengths, (void*)(pu32 + 2), len * sizeof(u32)); + pu32 += len + 2; + defs[j].size = *pu32; + defs[j].busId = *(pu32 + 1); + pu32 += 2; + } + p.defs = (IOFieldDef*)calloc(p.len, sizeof(IOFieldDef)); + for (u32 j = 0; j < p.len; j++) { + p.defs[j] = defs[j]; + } + templateInsId2IOSignalInfo1[index[i]] = p; + } + busInsId2FieldInfo1 = (IOFieldDefPair*)calloc(get_size_of_bus_field_map(), sizeof(IOFieldDefPair)); + for (int i = 0; i < get_size_of_bus_field_map(); i++) { + u32 n = *pu32; + IOFieldDefPair p; + p.len = n; + IOFieldDef defs[n]; + pu32 += 1; + for (u32 j = 0; j < n; j++) { + defs[j].offset = *pu32; + u32 len = *(pu32 + 1); + defs[j].len = len; + defs[j].lengths = new u32[len]; + memcpy((void*)defs[j].lengths, (void*)(pu32 + 2), len * sizeof(u32)); + pu32 += len + 2; + defs[j].size = *pu32; + defs[j].busId = *(pu32 + 1); + pu32 += 2; + } + p.defs = (IOFieldDef*)calloc(10, sizeof(IOFieldDef)); + for (u32 j = 0; j < p.len; j++) { + p.defs[j] = defs[j]; + } + busInsId2FieldInfo1[i] = p; + } + } + circuit->templateInsId2IOSignalInfo = move(templateInsId2IOSignalInfo1); + circuit->busInsId2FieldInfo = busInsId2FieldInfo1; + + return circuit; +} + +void loadJson(Circom_CalcWit *ctx, const char* inputs_json) { + json jin = json::parse(inputs_json); + json j; + + //std::cout << jin << std::endl; + std::string prefix = ""; + qualify_input(prefix, jin, j); + //std::cout << j << std::endl; + + u64 nItems = j.size(); + // printf("Items : %llu\n",nItems); + if (nItems == 0){ + ctx->tryRunCircuit(); + } + for (json::iterator it = j.begin(); it != j.end(); ++it) { + // std::cout << it.key() << " => " << it.value() << '\n'; + u64 h = fnv1a(it.key()); + std::vector v; + json2FrElements(it.value(),v); + uint signalSize = ctx->getInputSignalSize(h); + if (v.size() < signalSize) { + std::ostringstream errStrStream; + errStrStream << "Error loading signal " << it.key() << ": Not enough values\n"; + throw std::runtime_error(errStrStream.str() ); + } + if (v.size() > signalSize) { + std::ostringstream errStrStream; + errStrStream << "Error loading signal " << it.key() << ": Too many values\n"; + throw std::runtime_error(errStrStream.str() ); + } + for (uint i = 0; i " << Fr_element2str(&(v[i])) << '\n'; + ctx->setInputSignal(h,i,v[i]); + } catch (std::runtime_error e) { + std::ostringstream errStrStream; + errStrStream << "Error setting signal: " << it.key() << "\n" << e.what(); + throw std::runtime_error(errStrStream.str() ); + } + } + } +} + +void writeBinWitness(Circom_CalcWit *ctx, Bytes* output_witness) { + std::vector buf; + + auto write = [&](const void* data, size_t size) { + const uint8_t* p = (const uint8_t*)data; + buf.insert(buf.end(), p, p + size); + }; + + write("wtns", 4); + + u32 version = 2; + write(&version, 4); + + u32 nSections = 2; + write(&nSections, 4); + + // Header + u32 idSection1 = 1; + write(&idSection1, 4); + + u32 n8 = Fr_N64*8; + + u64 idSection1length = 8 + n8; + write(&idSection1length, 8); + + write(&n8, 4); + + write(Fr_q.longVal, Fr_N64*8); + + uint Nwtns = get_size_of_witness(); + + u32 nVars = (u32)Nwtns; + write(&nVars, 4); + + // Data + u32 idSection2 = 2; + write(&idSection2, 4); + + u64 idSection2length = (u64)n8*(u64)Nwtns; + write(&idSection2length, 8); + + FrElement v; + + for (int i=0;igetWitness(i, &v); + Fr_toLongNormal(&v, &v); + write(v.longVal, Fr_N64*8); + } + + size_t size = buf.size(); + output_witness->data = static_cast(malloc(size)); + if (output_witness->data == nullptr) return; + output_witness->size = size; + memcpy(output_witness->data, buf.data(), size); +} diff --git a/src/circom_adapter.hpp b/src/circom_adapter.hpp new file mode 100644 index 0000000..e90e0c6 --- /dev/null +++ b/src/circom_adapter.hpp @@ -0,0 +1,13 @@ +#ifndef CIRCOM_ADAPTER_HPP +#define CIRCOM_ADAPTER_HPP + +#include "types.hpp" +#include "calcwit.hpp" +#include "circom.hpp" + +// Return value +Circom_Circuit* loadCircuit(const ConstBytes& circuit); +void loadJson(Circom_CalcWit *ctx, const char* inputs_json); +void writeBinWitness(Circom_CalcWit *ctx, Bytes* output_witness); + +#endif diff --git a/src/poq/ffi.cpp b/src/poq/ffi.cpp index f78c632..eb5ebc9 100644 --- a/src/poq/ffi.cpp +++ b/src/poq/ffi.cpp @@ -1,9 +1,12 @@ #include "poq/ffi.hpp" #include "circom_fwd.hpp" +#include "circom_adapter.hpp" #include #include +#include "../types.hpp" + template static Status exceptions_into_status(T&& func) { try { @@ -84,18 +87,22 @@ static Status validate_witness_arguments(const WitnessInput* input, const Bytes* } static Status generate_witness_impl(const WitnessInput* input, Bytes* output) { - // TODO: Implement the actual witness generation logic using the provided input data. - const uint8_t dummy_witness[] = {0, 1, 2, 3}; + const ConstBytes& circuit_bytes = input->dat; - const size_t witness_size = sizeof(dummy_witness); - uint8_t* witness_data = static_cast(malloc(witness_size)); - if (witness_data == nullptr) { - return status_new(StatusCode_OutOfMemory, "Failed to allocate witness memory."); + Circom_Circuit* circuit = loadCircuit(circuit_bytes); + Circom_CalcWit* ctx = new Circom_CalcWit(circuit); + + loadJson(ctx, input->inputs_json); + if (ctx->getRemaingInputsToBeSet()!=0) { + const std::string message = "Not all inputs have been set. Only " + std::to_string(get_main_input_signal_no()-ctx->getRemaingInputsToBeSet()) + " out of " + std::to_string(get_main_input_signal_no()) + "."; + delete ctx; + delete circuit; + return status_new(StatusCode_InvalidInput, message.c_str()); } - std::copy(dummy_witness, dummy_witness + witness_size, witness_data); - output->data = witness_data; - output->size = witness_size; + writeBinWitness(ctx, output); + delete ctx; + delete circuit; return status_ok(); } diff --git a/src/types.hpp b/src/types.hpp index 81ce778..da24e25 100644 --- a/src/types.hpp +++ b/src/types.hpp @@ -52,11 +52,19 @@ static inline Status status_new(const StatusCode code, const char* message) { } return status; } -static inline Status status_from_code(const StatusCode code) { return status_new(code, NULL); } -static inline Status status_ok() { return status_from_code(StatusCode_Ok); } +static inline Status status_from_code(const StatusCode code) { + return status_new(code, NULL); +} +static inline Status status_ok() { + return status_from_code(StatusCode_Ok); +} -static inline bool status_is_ok(const Status status) { return status_code_is_ok(status.code); } -static inline bool status_is_error(const Status status) { return status_code_is_error(status.code); } +static inline bool status_is_ok(const Status status) { + return status_code_is_ok(status.code); +} +static inline bool status_is_error(const Status status) { + return status_code_is_error(status.code); +} /// Inputs for witness generation. typedef struct WitnessInput { @@ -67,7 +75,8 @@ typedef struct WitnessInput { } WitnessInput; static inline void free_bytes(Bytes* bytes) { - if (bytes == NULL) return; + if (bytes == NULL) + return; free(bytes->data); bytes->data = NULL; bytes->size = 0;