Allow wallet to be used as a reference type

This commit is contained in:
Mark Spanbroek 2021-04-19 16:14:27 +02:00
parent 50d59ce48a
commit 2a8e4e5bf4
3 changed files with 90 additions and 17 deletions

46
nitro/wallet/deref.nim Normal file
View File

@ -0,0 +1,46 @@
import std/macros
func identName(identDefs: NimNode): NimNode =
identDefs.expectKind(nnkIdentDefs)
identDefs[0]
func identType(identDefs: NimNode): NimNode =
identDefs.expectKind(nnkIdentDefs)
identDefs[^2]
func `identType=`(identDefs: NimNode, identType: NimNode) =
identDefs.expectKind(nnkIdentDefs)
identDefs[^2] = identType
func insertRef(function: NimNode) =
function.expectKind(nnkFuncDef)
var paramType = function.params[1].identType
if paramType.kind == nnkVarTy:
paramType = paramType[0]
function.params[1].identType = newNimNode(nnkRefTy, paramType).add(paramType)
func paramNames(function: NimNode): seq[NimNode] =
function.expectKind(nnkFuncDef)
for i in 1..<function.params.len:
result.add(function.params[i].identName)
func insertDeref(params: var seq[NimNode]) =
params[0] = newNimNode(nnkBracketExpr, params[0]).add(params[0])
func derefOverload(function: NimNode): NimNode =
function.expectKind(nnkFuncDef)
var arguments = function.paramNames
arguments.insertDeref()
result = function.copyNimTree()
result.insertRef()
result.body = newCall(function.name, arguments)
macro deref*(function: untyped{nkFuncDef}): untyped =
## Creates an overload that dereferences the first argument of the function
## call. Roughly equivalent to the `implicitDeref` experimental feature of
## Nim.
let overload = derefOverload(function)
quote do:
`function`
`overload`

View File

@ -6,6 +6,7 @@ import ./signedstate
import ./ledger import ./ledger
import ./balances import ./balances
import ./nonces import ./nonces
import ./deref
push: {.upraises:[].} push: {.upraises:[].}
@ -19,21 +20,25 @@ type
key: EthPrivateKey key: EthPrivateKey
channels: Table[ChannelId, SignedState] channels: Table[ChannelId, SignedState]
nonces: Nonces nonces: Nonces
WalletRef* = ref Wallet
ChannelId* = Destination ChannelId* = Destination
Payment* = tuple Payment* = tuple
destination: Destination destination: Destination
amount: UInt256 amount: UInt256
func init*(_: type Wallet, key: EthPrivateKey): Wallet = func init*(_: type Wallet, key: EthPrivateKey): Wallet =
result.key = key Wallet(key: key)
func publicKey*(wallet: Wallet): EthPublicKey = func new*(_: type WalletRef, key: EthPrivateKey): WalletRef =
WalletRef(key: key)
func publicKey*(wallet: Wallet): EthPublicKey {.deref.} =
wallet.key.toPublicKey wallet.key.toPublicKey
func address*(wallet: Wallet): EthAddress = func address*(wallet: Wallet): EthAddress {.deref.} =
wallet.publicKey.toAddress wallet.publicKey.toAddress
func destination*(wallet: Wallet): Destination = func destination*(wallet: Wallet): Destination {.deref.}=
wallet.address.toDestination wallet.address.toDestination
func sign(wallet: Wallet, state: SignedState): SignedState = func sign(wallet: Wallet, state: SignedState): SignedState =
@ -63,7 +68,7 @@ func openLedgerChannel*(wallet: var Wallet,
chainId: UInt256, chainId: UInt256,
nonce: UInt48, nonce: UInt48,
asset: EthAddress, asset: EthAddress,
amount: UInt256): ?!ChannelId = amount: UInt256): ?!ChannelId {.deref.} =
let state = startLedger(wallet.address, hub, chainId, nonce, asset, amount) let state = startLedger(wallet.address, hub, chainId, nonce, asset, amount)
wallet.createChannel(state) wallet.createChannel(state)
@ -71,11 +76,12 @@ func openLedgerChannel*(wallet: var Wallet,
hub: EthAddress, hub: EthAddress,
chainId: UInt256, chainId: UInt256,
asset: EthAddress, asset: EthAddress,
amount: UInt256): ?!ChannelId = amount: UInt256): ?!ChannelId {.deref.} =
let nonce = wallet.nonces.getNonce(chainId, wallet.address, hub) let nonce = wallet.nonces.getNonce(chainId, wallet.address, hub)
openLedgerChannel(wallet, hub, chainId, nonce, asset, amount) openLedgerChannel(wallet, hub, chainId, nonce, asset, amount)
func acceptChannel*(wallet: var Wallet, signed: SignedState): ?!ChannelId = func acceptChannel*(wallet: var Wallet,
signed: SignedState): ?!ChannelId {.deref.} =
if not signed.hasParticipant(wallet.address): if not signed.hasParticipant(wallet.address):
return failure "wallet owner is not a participant" return failure "wallet owner is not a participant"
@ -84,18 +90,21 @@ func acceptChannel*(wallet: var Wallet, signed: SignedState): ?!ChannelId =
wallet.createChannel(signed) wallet.createChannel(signed)
func latestSignedState*(wallet: Wallet, channel: ChannelId): ?SignedState = func latestSignedState*(wallet: Wallet,
channel: ChannelId): ?SignedState {.deref.} =
wallet.channels.?[channel] wallet.channels.?[channel]
func state*(wallet: Wallet, channel: ChannelId): ?State = func state*(wallet: Wallet,
channel: ChannelId): ?State {.deref.} =
wallet.latestSignedState(channel).?state wallet.latestSignedState(channel).?state
func signatures*(wallet: Wallet, channel: ChannelId): ?seq[Signature] = func signatures*(wallet: Wallet,
channel: ChannelId): ?seq[Signature] {.deref.} =
wallet.latestSignedState(channel).?signatures wallet.latestSignedState(channel).?signatures
func signature*(wallet: Wallet, func signature*(wallet: Wallet,
channel: ChannelId, channel: ChannelId,
address: EthAddress): ?Signature = address: EthAddress): ?Signature {.deref.} =
if signed =? wallet.latestSignedState(channel): if signed =? wallet.latestSignedState(channel):
for signature in signed.signatures: for signature in signed.signatures:
if signer =? signature.recover(signed.state): if signer =? signature.recover(signed.state):
@ -117,7 +126,7 @@ func balance(state: State,
func balance*(wallet: Wallet, func balance*(wallet: Wallet,
channel: ChannelId, channel: ChannelId,
asset: EthAddress, asset: EthAddress,
destination: Destination): UInt256 = destination: Destination): UInt256 {.deref.} =
if state =? wallet.state(channel): if state =? wallet.state(channel):
state.balance(asset, destination) state.balance(asset, destination)
else: else:
@ -126,10 +135,12 @@ func balance*(wallet: Wallet,
func balance*(wallet: Wallet, func balance*(wallet: Wallet,
channel: ChannelId, channel: ChannelId,
asset: EthAddress, asset: EthAddress,
address: EthAddress): UInt256 = address: EthAddress): UInt256 {.deref.} =
wallet.balance(channel, asset, address.toDestination) wallet.balance(channel, asset, address.toDestination)
func balance*(wallet: Wallet, channel: ChannelId, asset: EthAddress): UInt256 = func balance*(wallet: Wallet,
channel: ChannelId,
asset: EthAddress): UInt256 {.deref.} =
wallet.balance(channel, asset, wallet.address) wallet.balance(channel, asset, wallet.address)
func total(state: State, asset: EthAddress): UInt256 = func total(state: State, asset: EthAddress): UInt256 =
@ -149,7 +160,7 @@ func pay*(wallet: var Wallet,
channel: ChannelId, channel: ChannelId,
asset: EthAddress, asset: EthAddress,
receiver: Destination, receiver: Destination,
amount: UInt256): ?!SignedState = amount: UInt256): ?!SignedState {.deref.} =
without var state =? wallet.state(channel): without var state =? wallet.state(channel):
return failure "channel not found" return failure "channel not found"
@ -165,14 +176,14 @@ func pay*(wallet: var Wallet,
channel: ChannelId, channel: ChannelId,
asset: EthAddress, asset: EthAddress,
receiver: EthAddress, receiver: EthAddress,
amount: UInt256): ?!SignedState = amount: UInt256): ?!SignedState {.deref.} =
wallet.pay(channel, asset, receiver.toDestination, amount) wallet.pay(channel, asset, receiver.toDestination, amount)
func acceptPayment*(wallet: var Wallet, func acceptPayment*(wallet: var Wallet,
channel: ChannelId, channel: ChannelId,
asset: EthAddress, asset: EthAddress,
sender: EthAddress, sender: EthAddress,
payment: SignedState): ?!void = payment: SignedState): ?!void {.deref.} =
if not wallet.channels.contains(channel): if not wallet.channels.contains(channel):
return failure "unknown channel" return failure "unknown channel"

View File

@ -213,3 +213,19 @@ suite "wallet: accepting payments":
payment.state.appDefinition = EthAddress.example payment.state.appDefinition = EthAddress.example
payment.signatures = @[payerKey.sign(payment.state)] payment.signatures = @[payerKey.sign(payment.state)]
check receiver.acceptPayment(channel, asset, payer.address, payment).isFailure check receiver.acceptPayment(channel, asset, payer.address, payment).isFailure
suite "wallet reference type":
let asset = EthAddress.example
let amount = 42.u256
let chainId = UInt256.example
test "wallet can also be used as a reference type":
let wallet1 = WalletRef.new(EthPrivateKey.random())
let wallet2 = WalletRef.new(EthPrivateKey.random())
let address1 = wallet1.address
let address2 = wallet2.address
let channel = !wallet1.openLedgerChannel(address2, chainId, asset, amount)
check !wallet2.acceptChannel(!wallet1.latestSignedState(channel)) == channel
let payment = !wallet1.pay(channel, asset, address2, amount)
check wallet2.acceptPayment(channel, asset, address1, payment).isSuccess