diff --git a/cashu/core/helpers.py b/cashu/core/helpers.py index 1d586d2b..58a5d0a5 100644 --- a/cashu/core/helpers.py +++ b/cashu/core/helpers.py @@ -1,4 +1,5 @@ import asyncio +import math from functools import partial, wraps from typing import List @@ -42,3 +43,13 @@ def fee_reserve(amount_msat: int, internal=False) -> int: int(settings.lightning_reserve_fee_min), int(amount_msat * settings.lightning_fee_percent / 100.0), ) + + +def calculate_number_of_blank_outputs(fee_reserve_sat: int): + """Calculates the number of blank outputs used for returning overpaid fees. + + The formula ensures that any overpaid fees can be represented by the blank outputs, + see NUT-08 for details. + """ + assert fee_reserve_sat > 0, "Fee reserve has to be positive." + return max(math.ceil(math.log2(fee_reserve_sat)), 1) diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index dfa14047..4b099069 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -515,7 +515,7 @@ async def _generate_change_promises( total_provided: int, invoice_amount: int, ln_fee_msat: int, - outputs: List[BlindedMessage], + outputs: Optional[List[BlindedMessage]], keyset: Optional[MintKeyset] = None, ): """Generates a set of new promises (blinded signatures) from a set of blank outputs @@ -523,15 +523,15 @@ async def _generate_change_promises( fee reserve provided by the wallet and the actual Lightning fee paid by the mint. If there is a positive difference, produces maximum `n_return_outputs` new outputs - with values close or equal to the fee difference. We can't be sure that we hit the - fee perfectly because we can only work with a limited set of blanket outputs and - their values are limited to 2^n. + with values close or equal to the fee difference. If the given number of `outputs` matches + the equation defined in NUT-08, we can be sure to return the overpaid fee perfectly. + Otherwise, a smaller amount will be returned. Args: total_provided (int): Amount of the proofs provided by the wallet. invoice_amount (int): Amount of the invoice to be paid. ln_fee_msat (int): Actually paid Lightning network fees. - outputs (List[BlindedMessage]): Outputs to sign for returning the overpaid fees. + outputs (Optional[List[BlindedMessage]]): Outputs to sign for returning the overpaid fees. Raises: Exception: Output validation failed. @@ -541,22 +541,24 @@ async def _generate_change_promises( """ # we make sure that the fee is positive ln_fee_msat = abs(ln_fee_msat) - # maximum number of change outputs (must be in consensus with wallet) - n_return_outputs = 4 + ln_fee_sat = math.ceil(ln_fee_msat / 1000) user_paid_fee_sat = total_provided - invoice_amount + overpaid_fee_sat = user_paid_fee_sat - ln_fee_sat logger.debug( - f"Lightning fee was: {ln_fee_sat}. User paid: {user_paid_fee_sat}. Returning difference." + f"Lightning fee was: {ln_fee_sat}. User paid: {user_paid_fee_sat}. " + f"Returning difference: {overpaid_fee_sat}." ) - if user_paid_fee_sat - ln_fee_sat > 0 and outputs is not None: - # we will only accept at maximum n_return_outputs outputs - assert len(outputs) <= n_return_outputs, Exception( - "too many change outputs provided" - ) - return_amounts = amount_split(user_paid_fee_sat - ln_fee_sat) + if overpaid_fee_sat > 0 and outputs is not None: + return_amounts = amount_split(overpaid_fee_sat) + + # We return at most as many outputs as were provided or as many as are + # required to pay back the overpaid fee. + n_return_outputs = min(len(outputs), len(return_amounts)) + # we only need as many outputs as we have change to return - outputs = outputs[: len(return_amounts)] + outputs = outputs[:n_return_outputs] # we sort the return_amounts in descending order so we only # take the largest values in the next step return_amounts_sorted = sorted(return_amounts, reverse=True) diff --git a/cashu/wallet/api/router.py b/cashu/wallet/api/router.py index 23cf106a..fcf26270 100644 --- a/cashu/wallet/api/router.py +++ b/cashu/wallet/api/router.py @@ -75,7 +75,7 @@ async def pay( status_code=status.HTTP_400_BAD_REQUEST, detail="balance is too low." ) _, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount) - await wallet.pay_lightning(send_proofs, invoice) + await wallet.pay_lightning(send_proofs, invoice, fee_reserve_sat) await wallet.load_proofs() return { "amount": total_amount - fee_reserve_sat, diff --git a/cashu/wallet/cli/cli.py b/cashu/wallet/cli/cli.py index 61e4fcd0..d3801a54 100644 --- a/cashu/wallet/cli/cli.py +++ b/cashu/wallet/cli/cli.py @@ -115,7 +115,7 @@ async def pay(ctx: Context, invoice: str, yes: bool): print("Error: Balance too low.") return _, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount) - await wallet.pay_lightning(send_proofs, invoice) + await wallet.pay_lightning(send_proofs, invoice, fee_reserve_sat) await wallet.load_proofs() wallet.status() diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index da69bc90..0419a133 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -38,7 +38,7 @@ from ..core.crypto import b_dhke from ..core.crypto.secp import PrivateKey, PublicKey from ..core.db import Database -from ..core.helpers import sum_proofs +from ..core.helpers import calculate_number_of_blank_outputs, sum_proofs from ..core.script import ( step0_carol_checksig_redeemscrip, step0_carol_privkey, @@ -586,13 +586,13 @@ async def split( await invalidate_proof(proof, db=self.db) return frst_proofs, scnd_proofs - async def pay_lightning(self, proofs: List[Proof], invoice: str): + async def pay_lightning(self, proofs: List[Proof], invoice: str, fee_reserve: int): """Pays a lightning invoice""" - # generate outputs for the change for overpaid fees - # we will generate four blanked outputs that the mint will - # imprint with value depending on the fees we overpaid - n_return_outputs = 4 + # Generate a number of blank outputs for any overpaid fees. As described in + # NUT-08, the mint will imprint these outputs with a value depending on the + # amount of fees we overpaid. + n_return_outputs = calculate_number_of_blank_outputs(fee_reserve) secrets = [self._generate_secret() for _ in range(n_return_outputs)] outputs, rs = self._construct_outputs(n_return_outputs * [1], secrets) diff --git a/tests/test_core.py b/tests/test_core.py index 882cae19..3782aa9a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,4 +1,7 @@ +import pytest + from cashu.core.base import TokenV3 +from cashu.core.helpers import calculate_number_of_blank_outputs from cashu.core.split import amount_split @@ -22,3 +25,26 @@ def test_tokenv3_deserialize_serialize(): token_str = "cashuAeyJ0b2tlbiI6IFt7InByb29mcyI6IFt7ImlkIjogIkplaFpMVTZuQ3BSZCIsICJhbW91bnQiOiAyLCAic2VjcmV0IjogIjBFN2lDazRkVmxSZjVQRjFnNFpWMnciLCAiQyI6ICIwM2FiNTgwYWQ5NTc3OGVkNTI5NmY4YmVlNjU1ZGJkN2Q2NDJmNWQzMmRlOGUyNDg0NzdlMGI0ZDZhYTg2M2ZjZDUifSwgeyJpZCI6ICJKZWhaTFU2bkNwUmQiLCAiYW1vdW50IjogOCwgInNlY3JldCI6ICJzNklwZXh3SGNxcXVLZDZYbW9qTDJnIiwgIkMiOiAiMDIyZDAwNGY5ZWMxNmE1OGFkOTAxNGMyNTliNmQ2MTRlZDM2ODgyOWYwMmMzODc3M2M0NzIyMWY0OTYxY2UzZjIzIn1dLCAibWludCI6ICJodHRwOi8vbG9jYWxob3N0OjMzMzgifV19" token = TokenV3.deserialize(token_str) assert token.serialize() == token_str + + +def test_calculate_number_of_blank_outputs(): + # Example from NUT-08 specification. + fee_reserve_sat = 1000 + expected_n_blank_outputs = 10 + n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve_sat) + assert n_blank_outputs == expected_n_blank_outputs + + +def test_calculate_number_of_blank_outputs_for_small_fee_reserve(): + # There should always be at least one blank output. + fee_reserve_sat = 1 + expected_n_blank_outputs = 1 + n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve_sat) + assert n_blank_outputs == expected_n_blank_outputs + + +def test_calculate_number_of_blank_outputs_fails_for_negative_fee_reserve(): + # Negative fee reserve is not supported. + fee_reserve_sat = 0 + with pytest.raises(AssertionError): + _ = calculate_number_of_blank_outputs(fee_reserve_sat) diff --git a/tests/test_mint.py b/tests/test_mint.py index 005d5595..03c61829 100644 --- a/tests/test_mint.py +++ b/tests/test_mint.py @@ -3,6 +3,8 @@ import pytest from cashu.core.base import BlindedMessage, Proof +from cashu.core.crypto.b_dhke import step1_alice +from cashu.core.helpers import calculate_number_of_blank_outputs from cashu.core.migrations import migrate_databases SERVER_ENDPOINT = "http://localhost:3338" @@ -110,3 +112,69 @@ async def test_generate_promises(ledger: Ledger): promises[0].C_ == "037074c4f53e326ee14ed67125f387d160e0e729351471b69ad41f7d5d21071e15" ) + + +@pytest.mark.asyncio +async def test_generate_change_promises(ledger: Ledger): + # Example slightly adapted from NUT-08 because we want to ensure the dynamic change + # token amount works: `n_blank_outputs != n_returned_promises != 4`. + invoice_amount = 100_000 + fee_reserve = 2_000 + total_provided = invoice_amount + fee_reserve + actual_fee_msat = 100_000 + + expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024] + expected_returned_fees = 1900 + + n_blank_outputs = calculate_number_of_blank_outputs(fee_reserve) + blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] + outputs = [ + BlindedMessage(amount=1, B_=b.serialize().hex()) for b, _ in blinded_msgs + ] + + promises = await ledger._generate_change_promises( + total_provided, invoice_amount, actual_fee_msat, outputs + ) + + assert len(promises) == expected_returned_promises + assert sum([promise.amount for promise in promises]) == expected_returned_fees + + +@pytest.mark.asyncio +async def test_generate_change_promises_legacy_wallet(ledger: Ledger): + # Check if mint handles a legacy wallet implementation (always sends 4 blank + # outputs) as well. + invoice_amount = 100_000 + fee_reserve = 2_000 + total_provided = invoice_amount + fee_reserve + actual_fee_msat = 100_000 + + expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024] + expected_returned_fees = 1856 + + n_blank_outputs = 4 + blinded_msgs = [step1_alice(str(n)) for n in range(n_blank_outputs)] + outputs = [ + BlindedMessage(amount=1, B_=b.serialize().hex()) for b, _ in blinded_msgs + ] + + promises = await ledger._generate_change_promises( + total_provided, invoice_amount, actual_fee_msat, outputs + ) + + assert len(promises) == expected_returned_promises + assert sum([promise.amount for promise in promises]) == expected_returned_fees + + +@pytest.mark.asyncio +async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger): + invoice_amount = 100_000 + fee_reserve = 1_000 + total_provided = invoice_amount + fee_reserve + actual_fee_msat = 100_000 + outputs = None + + promises = await ledger._generate_change_promises( + total_provided, invoice_amount, actual_fee_msat, outputs + ) + assert len(promises) == 0