cl/ptx: get balance commitments working

This commit is contained in:
David Rusu 2024-05-29 12:50:24 +04:00
parent e25051b582
commit 5cae33a95a
6 changed files with 81 additions and 45 deletions

View File

@ -1,4 +1,4 @@
from keum import grumpkin, PrimeFiniteField
from keum import grumpkin
import poseidon
@ -9,50 +9,46 @@ Point = grumpkin.AffineWeierstrass
Field = grumpkin.Fq
class Field(PrimeFiniteField):
ORDER = poseidon.prime_64
def fake_algebraic_hash(data) -> Field:
"""
HACK: we'll fake the algebraic hash using sha256(data) % Field.ORDER
"""
assert all(isinstance(d, Field) for d in data), f"{data}\n{[type(d) for d in data]}"
data = b"".join(d.v.to_bytes(256 // 8) for d in data)
from hashlib import sha256
return Field(int(sha256(data).hexdigest(), 16))
def poseidon_grumpkin_field():
# TODO: These parameters are made up.
# return poseidon.Poseidon(
# p=Field.ORDER,
# security_level=128,
# alpha=5,
# input_rate=3,
# t=9,
# )
h, _ = poseidon.case_simple()
# h, _ = poseidon.case_neptune()
# h = poseidon.Poseidon(
# p=Field.ORDER,
# security_level=128,
# alpha=5,
# input_rate=3,
# t=9,
# )
def build_poseidon():
h = poseidon.Poseidon(
p=Field.ORDER,
security_level=128,
alpha=5,
input_rate=3,
t=9,
)
# TODO: this is hacks on hacks to make poseidon take in arbitrary input length.
# Fix is to implement a sponge as described in section 2.1 of
# https://eprint.iacr.org/2019/458.pdf
def inner(data):
assert all(
isinstance(d, Field) for d in data
), f"{data}\n{[type(d) for d in data]}"
data = [d.v for d in data]
digest = 0
for i in range(0, len(data), h.input_rate - 1):
digest = h.run_hash([digest, *data[i : i + h.input_rate - 1]])
return digest
return Field(int(digest))
return inner
POSEIDON = poseidon_grumpkin_field()
# HASH = build_poseidon()
HASH = fake_algebraic_hash
def prf(domain, *elements) -> Field:
return Field(int(POSEIDON([*_str_to_vec(domain), *elements])))
return HASH([*_str_to_vec(domain), *elements])
def hash_to_curve(domain, *elements) -> Point:

View File

@ -90,14 +90,20 @@ class PublicNote:
"""Blinding factor used in balance commitments"""
return prf("CL_NOTE_BAL_BLIND", tx_rand, self.note.nonce, self.nf_pk)
def balance(self, rand):
def balance(self, tx_rand):
"""
Returns the pederson commitment to the notes value.
"""
return balance_commitment(
self.note.value,
self.blinding(rand),
self.note.fungibility_domain,
self.note.value, self.blinding(tx_rand), self.note.fungibility_domain
)
def zero(self, tx_rand):
"""
Returns the pederson commitment to the notes value.
"""
return balance_commitment(
Field.zero(), self.blinding(tx_rand), self.note.fungibility_domain
)
def commit(self) -> Field:
@ -125,7 +131,7 @@ class SecretNote:
note: InnerNote
nf_sk: Field
def to_public_note(self) -> PublicNote:
def to_public(self) -> PublicNote:
return PublicNote(note=self.note, nf_pk=nf_pk(self.nf_sk))
def nullifier(self):
@ -136,6 +142,7 @@ class SecretNote:
"""
return prf("NULLIFIER", self.nonce, self.nf_sk)
# TODO: is this used?
def zero(self, rand):
"""
Returns the pederson commitment to zero using the same blinding as the balance

View File

@ -19,13 +19,30 @@ class PartialTransaction:
outputs: list[Output]
rand: Field
def balance(self) -> Point:
output_balance = sum(n.balance for n in self.outputs)
input_balance = sum(n.note.balance() for n in self.inputs)
return output_balance - input_balance
def verify(self) -> bool:
raise NotImplementedError()
def balance(self) -> Point:
output_balance = sum((n.balance for n in self.outputs), start=Point.zero())
# TODO: once again just mentioning this inefficiency. we are converting our private
# inputs to public inputs to compute the balance, so we don't need an Output class,
# we can directly compute the balance commitment from the public output notes.
input_balance = sum(
(n.to_public().balance(self.rand) for n in self.inputs), start=Point.zero()
)
return output_balance + input_balance.negate()
# TODO: do we need this?
def blinding(self) -> Field:
return sum(outputs.blinding(self.rand)) - sum(outputs.blinding(self.rand))
def zero(self) -> Field:
return sum(outputs.note.zero(self.rand)) - sum(inputs.zero(self.rand))
output_zero = sum((n.zero for n in self.outputs), start=Point.zero())
# TODO: once again just mentioning this inefficiency. we are converting our private
# inputs to public inputs to compute the zero commitment, so we don't need an Output class,
# we can directly compute the zero commitment from the public output notes.
input_zero = sum(
(n.to_public().zero(self.rand) for n in self.inputs), start=Point.zero()
)
return output_zero + input_zero.negate()

View File

@ -6,16 +6,28 @@ the basic behaviour that we need.
from unittest import TestCase
from crypto import hash_to_curve, Field
from crypto import Field, Point, hash_to_curve, prf
class TestCrypto(TestCase):
def test_hash_to_curve(self):
p1 = hash_to_curve(Field(0), Field(1), Field(2))
p2 = hash_to_curve(Field(0), Field(1), Field(2))
p1 = hash_to_curve("TEST", Field(0), Field(1), Field(2))
p2 = hash_to_curve("TEST", Field(0), Field(1), Field(2))
assert isinstance(p1, Point)
assert p1 == p2
p3 = hash_to_curve(Field(0), Field(1), Field(3))
p3 = hash_to_curve("TEST", Field(0), Field(1), Field(3))
assert p1 != p3
def test_prf(self):
r1 = prf("TEST", Field(0), Field(1), Field(2))
r2 = prf("TEST", Field(0), Field(1), Field(2))
assert isinstance(r1, Field)
assert r1 == r2
r3 = prf("TEST", Field(0), Field(1), Field(3))
assert r1 != r3

View File

@ -60,7 +60,7 @@ class TestTransfer(TestCase):
)
ptx = PartialTransaction(
inputs=[alices_note], outputs=[alices_note], rand=tx_rand
inputs=[alices_note], outputs=[tx_output], rand=tx_rand
)
bundle = TransactionBundle(bundle=[ptx])

View File

@ -1,7 +1,7 @@
from dataclasses import dataclass
from partial_transaction import PartialTransaction
from crypto import Field
from crypto import Field, Point
@dataclass
@ -10,7 +10,11 @@ class TransactionBundle:
def is_balanced(self) -> bool:
# TODO: move this to a NOIR constraint
return Field.zero() == sum(ptx.balance() - ptx.zero() for ptx in self.bundle)
balance_commitment = sum(
(ptx.balance() + ptx.zero().negate() for ptx in self.bundle),
start=Point.zero(),
)
return Point.zero() == balance_commitment
def verify(self) -> bool:
return self.is_balanced() and all(ptx.verify() for ptx in self.bundle)