From d9ce2c61e407714ccbb996e8d25d52818860e732 Mon Sep 17 00:00:00 2001 From: xphade <18196286+xphade@users.noreply.github.com> Date: Tue, 16 May 2023 22:13:04 +0200 Subject: [PATCH] Implement dynamic amount of tokens for change With the recent update to NUT-08, we can ensure that the amount of blank outputs is always enough to cover any overpaid lightning fees. This change implements this functionality for both the wallet and the mint. The mint updateis backwards-compatible with respect to old wallets. --- cashu/core/helpers.py | 11 ++++++ cashu/mint/ledger.py | 32 +++++++++--------- cashu/wallet/api/router.py | 2 +- cashu/wallet/cli/cli.py | 2 +- cashu/wallet/wallet.py | 12 +++---- tests/test_core.py | 26 +++++++++++++++ tests/test_mint.py | 68 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 130 insertions(+), 23 deletions(-) diff --git a/cashu/core/helpers.py b/cashu/core/helpers.py index 1d586d2b..58a5d0a5 100644 --- a/cashu/core/helpers.py +++ b/cashu/core/helpers.py @@ -1,4 +1,5 @@ import asyncio +import math from functools import partial, wraps from typing import List @@ -42,3 +43,13 @@ def fee_reserve(amount_msat: int, internal=False) -> int: int(settings.lightning_reserve_fee_min), int(amount_msat * settings.lightning_fee_percent / 100.0), ) + + +def calculate_number_of_blank_outputs(fee_reserve_sat: int): + """Calculates the number of blank outputs used for returning overpaid fees. + + The formula ensures that any overpaid fees can be represented by the blank outputs, + see NUT-08 for details. + """ + assert fee_reserve_sat > 0, "Fee reserve has to be positive." + return max(math.ceil(math.log2(fee_reserve_sat)), 1) diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index dfa14047..4b099069 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -515,7 +515,7 @@ async def _generate_change_promises( total_provided: int, invoice_amount: int, ln_fee_msat: int, - outputs: List[BlindedMessage], + outputs: Optional[List[BlindedMessage]], keyset: Optional[MintKeyset] = None, ): """Generates a set of new promises (blinded signatures) from a set of blank outputs @@ -523,15 +523,15 @@ async def _generate_change_promises( fee reserve provided by the wallet and the actual Lightning fee paid by the mint. If there is a positive difference, produces maximum `n_return_outputs` new outputs - with values close or equal to the fee difference. We can't be sure that we hit the - fee perfectly because we can only work with a limited set of blanket outputs and - their values are limited to 2^n. + with values close or equal to the fee difference. If the given number of `outputs` matches + the equation defined in NUT-08, we can be sure to return the overpaid fee perfectly. + Otherwise, a smaller amount will be returned. Args: total_provided (int): Amount of the proofs provided by the wallet. invoice_amount (int): Amount of the invoice to be paid. ln_fee_msat (int): Actually paid Lightning network fees. - outputs (List[BlindedMessage]): Outputs to sign for returning the overpaid fees. + outputs (Optional[List[BlindedMessage]]): Outputs to sign for returning the overpaid fees. Raises: Exception: Output validation failed. @@ -541,22 +541,24 @@ async def _generate_change_promises( """ # we make sure that the fee is positive ln_fee_msat = abs(ln_fee_msat) - # maximum number of change outputs (must be in consensus with wallet) - n_return_outputs = 4 + ln_fee_sat = math.ceil(ln_fee_msat / 1000) user_paid_fee_sat = total_provided - invoice_amount + overpaid_fee_sat = user_paid_fee_sat - ln_fee_sat logger.debug( - f"Lightning fee was: {ln_fee_sat}. User paid: {user_paid_fee_sat}. Returning difference." + f"Lightning fee was: {ln_fee_sat}. User paid: {user_paid_fee_sat}. " + f"Returning difference: {overpaid_fee_sat}." ) - if user_paid_fee_sat - ln_fee_sat > 0 and outputs is not None: - # we will only accept at maximum n_return_outputs outputs - assert len(outputs) <= n_return_outputs, Exception( - "too many change outputs provided" - ) - return_amounts = amount_split(user_paid_fee_sat - ln_fee_sat) + if overpaid_fee_sat > 0 and outputs is not None: + return_amounts = amount_split(overpaid_fee_sat) + + # We return at most as many outputs as were provided or as many as are + # required to pay back the overpaid fee. + n_return_outputs = min(len(outputs), len(return_amounts)) + # we only need as many outputs as we have change to return - outputs = outputs[: len(return_amounts)] + outputs = outputs[:n_return_outputs] # we sort the return_amounts in descending order so we only # take the largest values in the next step return_amounts_sorted = sorted(return_amounts, reverse=True) diff --git a/cashu/wallet/api/router.py b/cashu/wallet/api/router.py index 23cf106a..fcf26270 100644 --- a/cashu/wallet/api/router.py +++ b/cashu/wallet/api/router.py @@ -75,7 +75,7 @@ async def pay( status_code=status.HTTP_400_BAD_REQUEST, detail="balance is too low." ) _, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount) - await wallet.pay_lightning(send_proofs, invoice) + await wallet.pay_lightning(send_proofs, invoice, fee_reserve_sat) await wallet.load_proofs() return { "amount": total_amount - fee_reserve_sat, diff --git a/cashu/wallet/cli/cli.py b/cashu/wallet/cli/cli.py index 61e4fcd0..d3801a54 100644 --- a/cashu/wallet/cli/cli.py +++ b/cashu/wallet/cli/cli.py @@ -115,7 +115,7 @@ async def pay(ctx: Context, invoice: str, yes: bool): print("Error: Balance too low.") return _, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount) - await wallet.pay_lightning(send_proofs, invoice) + await wallet.pay_lightning(send_proofs, invoice, fee_reserve_sat) await wallet.load_proofs() wallet.status() diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index da69bc90..0419a133 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -38,7 +38,7 @@ from ..core.crypto import b_dhke from ..core.crypto.secp import PrivateKey, PublicKey from ..core.db import Database -from ..core.helpers import sum_proofs +from ..core.helpers import calculate_number_of_blank_outputs, sum_proofs from ..core.script import ( step0_carol_checksig_redeemscrip, step0_carol_privkey, @@ -586,13 +586,13 @@ async def split( await invalidate_proof(proof, db=self.db) return frst_proofs, scnd_proofs - async def pay_lightning(self, proofs: List[Proof], invoice: str): + async def pay_lightning(self, proofs: List[Proof], invoice: str, fee_reserve: int): """Pays a lightning invoice""" - # generate outputs for the change for overpaid fees - # we will generate four blanked outputs that the mint will - # imprint with value depending on the fees we overpaid - n_return_outputs = 4 + # Generate a number of blank outputs for any overpaid fees. As described in + # NUT-08, the mint will imprint these outputs with a value depending on the + # amount of fees we overpaid. + n_return_outputs = calculate_number_of_blank_outputs(fee_reserve) secrets = [self._generate_secret() for _ in range(n_return_outputs)] outputs, rs = self._construct_outputs(n_return_outputs * [1], secrets) diff --git a/tests/test_core.py b/tests/test_core.py index 882cae19..3782aa9a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,4 +1,7 @@ +import pytest + from cashu.core.base import TokenV3 +from cashu.core.helpers import calculate_number_of_blank_outputs from cashu.core.split import amount_split @@ -22,3 +25,26 @@ def test_tokenv3_deserialize_serialize(): token_str = "cashuAeyJ0b2tlbiI6IFt7InByb29mcyI6IFt7ImlkIjogIkplaFpMVTZuQ3BSZCIsICJhbW91bnQiOiAyLCAic2VjcmV0IjogIjBFN2lDazRkVmxSZjVQRjFnNFpWMnciLCAiQyI6ICIwM2FiNTgwYWQ5NTc3OGVkNTI5NmY4YmVlNjU1ZGJkN2Q2NDJmNWQzMmRlOGUyNDg0NzdlMGI0ZDZhYTg2M2ZjZDUifSwgeyJpZCI6ICJKZWhaTFU2bkNwUmQiLCAiYW1vdW50IjogOCwgInNlY3JldCI6ICJzNklwZXh3SGNxcXVLZDZYbW9qTDJnIiwgIkMiOiAiMDIyZDAwNGY5ZWMxNmE1OGFkOTAxNGMyNTliNmQ2MTRlZDM2ODgyOWYwMmMzODc3M2M0NzIyMWY0OTYxY2UzZjIzIn1dLCAibWludCI6ICJodHRwOi8vbG9jYWxob3N0OjMzMzgifV19" token = TokenV3.deserialize(token_str) assert token.serialize() == token_str + + +def test_calculate_number_of_blank_outputs(): + # Example from NUT-08 specification. + fee_reserve_sat = 1000 + expected_n_blank_outputs = 10 + n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve_sat) + assert n_blank_outputs == expected_n_blank_outputs + + +def test_calculate_number_of_blank_outputs_for_small_fee_reserve(): + # There should always be at least one blank output. + fee_reserve_sat = 1 + expected_n_blank_outputs = 1 + n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve_sat) + assert n_blank_outputs == expected_n_blank_outputs + + +def test_calculate_number_of_blank_outputs_fails_for_negative_fee_reserve(): + # Negative fee reserve is not supported. + fee_reserve_sat = 0 + with pytest.raises(AssertionError): + _ = calculate_number_of_blank_outputs(fee_reserve_sat) diff --git a/tests/test_mint.py b/tests/test_mint.py index 005d5595..03c61829 100644 --- a/tests/test_mint.py +++ b/tests/test_mint.py @@ -3,6 +3,8 @@ import pytest from cashu.core.base import BlindedMessage, Proof +from cashu.core.crypto.b_dhke import step1_alice +from cashu.core.helpers import calculate_number_of_blank_outputs from cashu.core.migrations import migrate_databases SERVER_ENDPOINT = "http://localhost:3338" @@ -110,3 +112,69 @@ async def test_generate_promises(ledger: Ledger): promises[0].C_ == "037074c4f53e326ee14ed67125f387d160e0e729351471b69ad41f7d5d21071e15" ) + + +@pytest.mark.asyncio +async def test_generate_change_promises(ledger: Ledger): + # Example slightly adapted from NUT-08 because we want to ensure the dynamic change + # token amount works: `n_blank_outputs != n_returned_promises != 4`. + invoice_amount = 100_000 + fee_reserve = 2_000 + total_provided = invoice_amount + fee_reserve + actual_fee_msat = 100_000 + + expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024] + expected_returned_fees = 1900 + + n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve) + blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] + outputs = [ + BlindedMessage(amount=1, B_=b.serialize().hex()) for b, _ in blinded_msgs + ] + + promises = await ledger._generate_change_promises( + total_provided, invoice_amount, actual_fee_msat, outputs + ) + + assert len(promises) == expected_returned_promises + assert sum([promise.amount for promise in promises]) == expected_returned_fees + + +@pytest.mark.asyncio +async def test_generate_change_promises_legacy_wallet(ledger: Ledger): + # Check if mint handles a legacy wallet implementation (always sends 4 blank + # outputs) as well. + invoice_amount = 100_000 + fee_reserve = 2_000 + total_provided = invoice_amount + fee_reserve + actual_fee_msat = 100_000 + + expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024] + expected_returned_fees = 1856 + + n_blank_outputs = 4 + blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] + outputs = [ + BlindedMessage(amount=1, B_=b.serialize().hex()) for b, _ in blinded_msgs + ] + + promises = await ledger._generate_change_promises( + total_provided, invoice_amount, actual_fee_msat, outputs + ) + + assert len(promises) == expected_returned_promises + assert sum([promise.amount for promise in promises]) == expected_returned_fees + + +@pytest.mark.asyncio +async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger): + invoice_amount = 100_000 + fee_reserve = 1_000 + total_provided = invoice_amount + fee_reserve + actual_fee_msat = 100_000 + outputs = None + + promises = await ledger._generate_change_promises( + total_provided, invoice_amount, actual_fee_msat, outputs + ) + assert len(promises) == 0