Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement dynamic amount of tokens for change #223

Merged
merged 1 commit into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cashu/core/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import math
from functools import partial, wraps
from typing import List

Expand Down Expand Up @@ -42,3 +43,13 @@ def fee_reserve(amount_msat: int, internal=False) -> int:
int(settings.lightning_reserve_fee_min),
int(amount_msat * settings.lightning_fee_percent / 100.0),
)


def calculate_number_of_blank_outputs(fee_reserve_sat: int):
callebtc marked this conversation as resolved.
Show resolved Hide resolved
"""Calculates the number of blank outputs used for returning overpaid fees.

The formula ensures that any overpaid fees can be represented by the blank outputs,
see NUT-08 for details.
"""
assert fee_reserve_sat > 0, "Fee reserve has to be positive."
return max(math.ceil(math.log2(fee_reserve_sat)), 1)
Comment on lines +48 to +55
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I really should put it in here, or rather into the wallet directly.

32 changes: 17 additions & 15 deletions cashu/mint/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,23 +515,23 @@ async def _generate_change_promises(
total_provided: int,
invoice_amount: int,
ln_fee_msat: int,
outputs: List[BlindedMessage],
outputs: Optional[List[BlindedMessage]],
keyset: Optional[MintKeyset] = None,
):
callebtc marked this conversation as resolved.
Show resolved Hide resolved
"""Generates a set of new promises (blinded signatures) from a set of blank outputs
(outputs with no or ignored amount) by looking at the difference between the Lightning
fee reserve provided by the wallet and the actual Lightning fee paid by the mint.

If there is a positive difference, produces maximum `n_return_outputs` new outputs
with values close or equal to the fee difference. We can't be sure that we hit the
fee perfectly because we can only work with a limited set of blanket outputs and
their values are limited to 2^n.
with values close or equal to the fee difference. If the given number of `outputs` matches
the equation defined in NUT-08, we can be sure to return the overpaid fee perfectly.
Otherwise, a smaller amount will be returned.

Args:
total_provided (int): Amount of the proofs provided by the wallet.
invoice_amount (int): Amount of the invoice to be paid.
ln_fee_msat (int): Actually paid Lightning network fees.
outputs (List[BlindedMessage]): Outputs to sign for returning the overpaid fees.
outputs (Optional[List[BlindedMessage]]): Outputs to sign for returning the overpaid fees.

Raises:
Exception: Output validation failed.
Expand All @@ -541,22 +541,24 @@ async def _generate_change_promises(
"""
# we make sure that the fee is positive
ln_fee_msat = abs(ln_fee_msat)
# maximum number of change outputs (must be in consensus with wallet)
n_return_outputs = 4

ln_fee_sat = math.ceil(ln_fee_msat / 1000)
user_paid_fee_sat = total_provided - invoice_amount
overpaid_fee_sat = user_paid_fee_sat - ln_fee_sat
logger.debug(
f"Lightning fee was: {ln_fee_sat}. User paid: {user_paid_fee_sat}. Returning difference."
f"Lightning fee was: {ln_fee_sat}. User paid: {user_paid_fee_sat}. "
f"Returning difference: {overpaid_fee_sat}."
)
if user_paid_fee_sat - ln_fee_sat > 0 and outputs is not None:
# we will only accept at maximum n_return_outputs outputs
assert len(outputs) <= n_return_outputs, Exception(
"too many change outputs provided"
)
Comment on lines -551 to -555
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although not strictly necessary, for DOS safety, we could add a check here using calculate_number_of_blank_outputs to check whether the user has generated too many blank outputs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, can do. What do you want to happen in that case? Right now, if a wallet sends too many blank outputs, I'm just ignoring the one that are not necessary, see line 558.


return_amounts = amount_split(user_paid_fee_sat - ln_fee_sat)
if overpaid_fee_sat > 0 and outputs is not None:
return_amounts = amount_split(overpaid_fee_sat)

# We return at most as many outputs as were provided or as many as are
# required to pay back the overpaid fee.
n_return_outputs = min(len(outputs), len(return_amounts))

# we only need as many outputs as we have change to return
outputs = outputs[: len(return_amounts)]
outputs = outputs[:n_return_outputs]
# we sort the return_amounts in descending order so we only
# take the largest values in the next step
return_amounts_sorted = sorted(return_amounts, reverse=True)
Expand Down
2 changes: 1 addition & 1 deletion cashu/wallet/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def pay(
status_code=status.HTTP_400_BAD_REQUEST, detail="balance is too low."
)
_, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount)
await wallet.pay_lightning(send_proofs, invoice)
await wallet.pay_lightning(send_proofs, invoice, fee_reserve_sat)
await wallet.load_proofs()
return {
"amount": total_amount - fee_reserve_sat,
Expand Down
2 changes: 1 addition & 1 deletion cashu/wallet/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def pay(ctx: Context, invoice: str, yes: bool):
print("Error: Balance too low.")
return
_, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount)
await wallet.pay_lightning(send_proofs, invoice)
await wallet.pay_lightning(send_proofs, invoice, fee_reserve_sat)
await wallet.load_proofs()
wallet.status()

Expand Down
12 changes: 6 additions & 6 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from ..core.crypto import b_dhke
from ..core.crypto.secp import PrivateKey, PublicKey
from ..core.db import Database
from ..core.helpers import sum_proofs
from ..core.helpers import calculate_number_of_blank_outputs, sum_proofs
from ..core.script import (
step0_carol_checksig_redeemscrip,
step0_carol_privkey,
Expand Down Expand Up @@ -586,13 +586,13 @@ async def split(
await invalidate_proof(proof, db=self.db)
return frst_proofs, scnd_proofs

async def pay_lightning(self, proofs: List[Proof], invoice: str):
async def pay_lightning(self, proofs: List[Proof], invoice: str, fee_reserve: int):
"""Pays a lightning invoice"""

# generate outputs for the change for overpaid fees
# we will generate four blanked outputs that the mint will
# imprint with value depending on the fees we overpaid
n_return_outputs = 4
# Generate a number of blank outputs for any overpaid fees. As described in
# NUT-08, the mint will imprint these outputs with a value depending on the
# amount of fees we overpaid.
n_return_outputs = calculate_number_of_blank_outputs(fee_reserve)
secrets = [self._generate_secret() for _ in range(n_return_outputs)]
outputs, rs = self._construct_outputs(n_return_outputs * [1], secrets)

Expand Down
26 changes: 26 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import pytest

from cashu.core.base import TokenV3
from cashu.core.helpers import calculate_number_of_blank_outputs
from cashu.core.split import amount_split


Expand All @@ -22,3 +25,26 @@ def test_tokenv3_deserialize_serialize():
token_str = "cashuAeyJ0b2tlbiI6IFt7InByb29mcyI6IFt7ImlkIjogIkplaFpMVTZuQ3BSZCIsICJhbW91bnQiOiAyLCAic2VjcmV0IjogIjBFN2lDazRkVmxSZjVQRjFnNFpWMnciLCAiQyI6ICIwM2FiNTgwYWQ5NTc3OGVkNTI5NmY4YmVlNjU1ZGJkN2Q2NDJmNWQzMmRlOGUyNDg0NzdlMGI0ZDZhYTg2M2ZjZDUifSwgeyJpZCI6ICJKZWhaTFU2bkNwUmQiLCAiYW1vdW50IjogOCwgInNlY3JldCI6ICJzNklwZXh3SGNxcXVLZDZYbW9qTDJnIiwgIkMiOiAiMDIyZDAwNGY5ZWMxNmE1OGFkOTAxNGMyNTliNmQ2MTRlZDM2ODgyOWYwMmMzODc3M2M0NzIyMWY0OTYxY2UzZjIzIn1dLCAibWludCI6ICJodHRwOi8vbG9jYWxob3N0OjMzMzgifV19"
token = TokenV3.deserialize(token_str)
assert token.serialize() == token_str


def test_calculate_number_of_blank_outputs():
# Example from NUT-08 specification.
fee_reserve_sat = 1000
expected_n_blank_outputs = 10
n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve_sat)
assert n_blank_outputs == expected_n_blank_outputs


def test_calculate_number_of_blank_outputs_for_small_fee_reserve():
# There should always be at least one blank output.
fee_reserve_sat = 1
expected_n_blank_outputs = 1
n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve_sat)
assert n_blank_outputs == expected_n_blank_outputs


def test_calculate_number_of_blank_outputs_fails_for_negative_fee_reserve():
# Negative fee reserve is not supported.
fee_reserve_sat = 0
with pytest.raises(AssertionError):
_ = calculate_number_of_blank_outputs(fee_reserve_sat)
68 changes: 68 additions & 0 deletions tests/test_mint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest

from cashu.core.base import BlindedMessage, Proof
from cashu.core.crypto.b_dhke import step1_alice
from cashu.core.helpers import calculate_number_of_blank_outputs
from cashu.core.migrations import migrate_databases

SERVER_ENDPOINT = "http://localhost:3338"
Expand Down Expand Up @@ -110,3 +112,69 @@ async def test_generate_promises(ledger: Ledger):
promises[0].C_
== "037074c4f53e326ee14ed67125f387d160e0e729351471b69ad41f7d5d21071e15"
)


@pytest.mark.asyncio
async def test_generate_change_promises(ledger: Ledger):
# Example slightly adapted from NUT-08 because we want to ensure the dynamic change
# token amount works: `n_blank_outputs != n_returned_promises != 4`.
invoice_amount = 100_000
fee_reserve = 2_000
total_provided = invoice_amount + fee_reserve
actual_fee_msat = 100_000

expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024]
expected_returned_fees = 1900

n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve)
blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we rather use pre-computed blinded messages here? We don't use them for anything in this test, but maybe better to have it as deterministic as possible.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use a predetermined blinding_factor when calling step1_alice, then the outcome is determinstic!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, now I'm thinking this might break legacy wallets which always send 4 outputs (or we add an exception to that but that would be extra ugly).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use a predetermined blinding_factor when calling step1_alice, then the outcome is determinstic!

Yeah, that was my proposal. I guess that's a yes from your side, so I'll add predefined blinding factors. 👍

Actually, now I'm thinking this might break legacy wallets which always send 4 outputs (or we add an exception to that but that would be extra ugly).

Sorry, I'm a bit confused here, can you expand on this? How does it break legacy wallets? As you can see below, I also added a test for it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I meant: Legacy wallets send 4 outputs, but if our new algo says "we only need 3" then legacy wallets would get an error. We can leave it as you did: just ignore outputs that are unnecessary.

I think it's safe to ignore this for now. There is more than just one place in the code where we can DoS the mint.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I meant: Legacy wallets send 4 outputs, but if our new algo says "we only need 3" then legacy wallets would get an error.

Ah, I think I get what you mean now. However, I think it was already the case before that less outputs could be sent back (e.g. if the return output fits into 3). See line 557 and following in the old version:

return_amounts = amount_split(user_paid_fee_sat - ln_fee_sat)
# we only need as many outputs as we have change to return
outputs = outputs[: len(return_amounts)]

At least if I understand that part correctly. 🙂

outputs = [
BlindedMessage(amount=1, B_=b.serialize().hex()) for b, _ in blinded_msgs
]

promises = await ledger._generate_change_promises(
total_provided, invoice_amount, actual_fee_msat, outputs
)

assert len(promises) == expected_returned_promises
assert sum([promise.amount for promise in promises]) == expected_returned_fees


@pytest.mark.asyncio
async def test_generate_change_promises_legacy_wallet(ledger: Ledger):
# Check if mint handles a legacy wallet implementation (always sends 4 blank
# outputs) as well.
Comment on lines +144 to +146
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent! Very thoughtful!

invoice_amount = 100_000
fee_reserve = 2_000
total_provided = invoice_amount + fee_reserve
actual_fee_msat = 100_000

expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024]
expected_returned_fees = 1856

n_blank_outputs = 4
blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)]
outputs = [
BlindedMessage(amount=1, B_=b.serialize().hex()) for b, _ in blinded_msgs
]

promises = await ledger._generate_change_promises(
total_provided, invoice_amount, actual_fee_msat, outputs
)

assert len(promises) == expected_returned_promises
assert sum([promise.amount for promise in promises]) == expected_returned_fees


@pytest.mark.asyncio
async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger):
invoice_amount = 100_000
fee_reserve = 1_000
total_provided = invoice_amount + fee_reserve
actual_fee_msat = 100_000
outputs = None

promises = await ledger._generate_change_promises(
total_provided, invoice_amount, actual_fee_msat, outputs
)
assert len(promises) == 0