cl/ptx: get balance commitments working

This commit is contained in:
David Rusu 2024-05-29 12:50:24 +04:00
parent e25051b582
commit 5cae33a95a
6 changed files with 81 additions and 45 deletions

View File

@ -1,4 +1,4 @@
from keum import grumpkin, PrimeFiniteField from keum import grumpkin
import poseidon import poseidon
@ -9,50 +9,46 @@ Point = grumpkin.AffineWeierstrass
Field = grumpkin.Fq Field = grumpkin.Fq
class Field(PrimeFiniteField): def fake_algebraic_hash(data) -> Field:
ORDER = poseidon.prime_64 """
HACK: we'll fake the algebraic hash using sha256(data) % Field.ORDER
"""
assert all(isinstance(d, Field) for d in data), f"{data}\n{[type(d) for d in data]}"
data = b"".join(d.v.to_bytes(256 // 8) for d in data)
from hashlib import sha256
return Field(int(sha256(data).hexdigest(), 16))
def poseidon_grumpkin_field(): def build_poseidon():
# TODO: These parameters are made up. h = poseidon.Poseidon(
# return poseidon.Poseidon( p=Field.ORDER,
# p=Field.ORDER, security_level=128,
# security_level=128, alpha=5,
# alpha=5, input_rate=3,
# input_rate=3, t=9,
# t=9, )
# )
h, _ = poseidon.case_simple()
# h, _ = poseidon.case_neptune()
# h = poseidon.Poseidon(
# p=Field.ORDER,
# security_level=128,
# alpha=5,
# input_rate=3,
# t=9,
# )
# TODO: this is hacks on hacks to make poseidon take in arbitrary input length. # TODO: this is hacks on hacks to make poseidon take in arbitrary input length.
# Fix is to implement a sponge as described in section 2.1 of # Fix is to implement a sponge as described in section 2.1 of
# https://eprint.iacr.org/2019/458.pdf # https://eprint.iacr.org/2019/458.pdf
def inner(data): def inner(data):
assert all(
isinstance(d, Field) for d in data
), f"{data}\n{[type(d) for d in data]}"
data = [d.v for d in data]
digest = 0 digest = 0
for i in range(0, len(data), h.input_rate - 1): for i in range(0, len(data), h.input_rate - 1):
digest = h.run_hash([digest, *data[i : i + h.input_rate - 1]]) digest = h.run_hash([digest, *data[i : i + h.input_rate - 1]])
return digest return Field(int(digest))
return inner return inner
POSEIDON = poseidon_grumpkin_field() # HASH = build_poseidon()
HASH = fake_algebraic_hash
def prf(domain, *elements) -> Field: def prf(domain, *elements) -> Field:
return Field(int(POSEIDON([*_str_to_vec(domain), *elements]))) return HASH([*_str_to_vec(domain), *elements])
def hash_to_curve(domain, *elements) -> Point: def hash_to_curve(domain, *elements) -> Point:

View File

@ -90,14 +90,20 @@ class PublicNote:
"""Blinding factor used in balance commitments""" """Blinding factor used in balance commitments"""
return prf("CL_NOTE_BAL_BLIND", tx_rand, self.note.nonce, self.nf_pk) return prf("CL_NOTE_BAL_BLIND", tx_rand, self.note.nonce, self.nf_pk)
def balance(self, rand): def balance(self, tx_rand):
""" """
Returns the pederson commitment to the notes value. Returns the pederson commitment to the notes value.
""" """
return balance_commitment( return balance_commitment(
self.note.value, self.note.value, self.blinding(tx_rand), self.note.fungibility_domain
self.blinding(rand), )
self.note.fungibility_domain,
def zero(self, tx_rand):
"""
Returns the pederson commitment to the notes value.
"""
return balance_commitment(
Field.zero(), self.blinding(tx_rand), self.note.fungibility_domain
) )
def commit(self) -> Field: def commit(self) -> Field:
@ -125,7 +131,7 @@ class SecretNote:
note: InnerNote note: InnerNote
nf_sk: Field nf_sk: Field
def to_public_note(self) -> PublicNote: def to_public(self) -> PublicNote:
return PublicNote(note=self.note, nf_pk=nf_pk(self.nf_sk)) return PublicNote(note=self.note, nf_pk=nf_pk(self.nf_sk))
def nullifier(self): def nullifier(self):
@ -136,6 +142,7 @@ class SecretNote:
""" """
return prf("NULLIFIER", self.nonce, self.nf_sk) return prf("NULLIFIER", self.nonce, self.nf_sk)
# TODO: is this used?
def zero(self, rand): def zero(self, rand):
""" """
Returns the pederson commitment to zero using the same blinding as the balance Returns the pederson commitment to zero using the same blinding as the balance

View File

@ -19,13 +19,30 @@ class PartialTransaction:
outputs: list[Output] outputs: list[Output]
rand: Field rand: Field
def balance(self) -> Point: def verify(self) -> bool:
output_balance = sum(n.balance for n in self.outputs) raise NotImplementedError()
input_balance = sum(n.note.balance() for n in self.inputs)
return output_balance - input_balance
def balance(self) -> Point:
output_balance = sum((n.balance for n in self.outputs), start=Point.zero())
# TODO: once again just mentioning this inefficiency. we are converting our private
# inputs to public inputs to compute the balance, so we don't need an Output class,
# we can directly compute the balance commitment from the public output notes.
input_balance = sum(
(n.to_public().balance(self.rand) for n in self.inputs), start=Point.zero()
)
return output_balance + input_balance.negate()
# TODO: do we need this?
def blinding(self) -> Field: def blinding(self) -> Field:
return sum(outputs.blinding(self.rand)) - sum(outputs.blinding(self.rand)) return sum(outputs.blinding(self.rand)) - sum(outputs.blinding(self.rand))
def zero(self) -> Field: def zero(self) -> Field:
return sum(outputs.note.zero(self.rand)) - sum(inputs.zero(self.rand)) output_zero = sum((n.zero for n in self.outputs), start=Point.zero())
# TODO: once again just mentioning this inefficiency. we are converting our private
# inputs to public inputs to compute the zero commitment, so we don't need an Output class,
# we can directly compute the zero commitment from the public output notes.
input_zero = sum(
(n.to_public().zero(self.rand) for n in self.inputs), start=Point.zero()
)
return output_zero + input_zero.negate()

View File

@ -6,16 +6,28 @@ the basic behaviour that we need.
from unittest import TestCase from unittest import TestCase
from crypto import hash_to_curve, Field from crypto import Field, Point, hash_to_curve, prf
class TestCrypto(TestCase): class TestCrypto(TestCase):
def test_hash_to_curve(self): def test_hash_to_curve(self):
p1 = hash_to_curve(Field(0), Field(1), Field(2)) p1 = hash_to_curve("TEST", Field(0), Field(1), Field(2))
p2 = hash_to_curve(Field(0), Field(1), Field(2)) p2 = hash_to_curve("TEST", Field(0), Field(1), Field(2))
assert isinstance(p1, Point)
assert p1 == p2 assert p1 == p2
p3 = hash_to_curve(Field(0), Field(1), Field(3)) p3 = hash_to_curve("TEST", Field(0), Field(1), Field(3))
assert p1 != p3 assert p1 != p3
def test_prf(self):
r1 = prf("TEST", Field(0), Field(1), Field(2))
r2 = prf("TEST", Field(0), Field(1), Field(2))
assert isinstance(r1, Field)
assert r1 == r2
r3 = prf("TEST", Field(0), Field(1), Field(3))
assert r1 != r3

View File

@ -60,7 +60,7 @@ class TestTransfer(TestCase):
) )
ptx = PartialTransaction( ptx = PartialTransaction(
inputs=[alices_note], outputs=[alices_note], rand=tx_rand inputs=[alices_note], outputs=[tx_output], rand=tx_rand
) )
bundle = TransactionBundle(bundle=[ptx]) bundle = TransactionBundle(bundle=[ptx])

View File

@ -1,7 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from partial_transaction import PartialTransaction from partial_transaction import PartialTransaction
from crypto import Field from crypto import Field, Point
@dataclass @dataclass
@ -10,7 +10,11 @@ class TransactionBundle:
def is_balanced(self) -> bool: def is_balanced(self) -> bool:
# TODO: move this to a NOIR constraint # TODO: move this to a NOIR constraint
return Field.zero() == sum(ptx.balance() - ptx.zero() for ptx in self.bundle) balance_commitment = sum(
(ptx.balance() + ptx.zero().negate() for ptx in self.bundle),
start=Point.zero(),
)
return Point.zero() == balance_commitment
def verify(self) -> bool: def verify(self) -> bool:
return self.is_balanced() and all(ptx.verify() for ptx in self.bundle) return self.is_balanced() and all(ptx.verify() for ptx in self.bundle)