diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index 9a70628a0a..a49a4ce93c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -23,6 +23,7 @@ import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC, CannotExtractShared import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.send.Recipient import fr.acinq.eclair.router.Router.{BlindedHop, Route} +import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.OutgoingBlindedPaths import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, PerHopPayload} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, TimestampMilli, UInt64, randomKey} @@ -147,7 +148,12 @@ object IncomingPaymentPacket { // blinding point and use it to derive the decryption key for the blinded trampoline onion. decryptOnion(add.paymentHash, privateKey, trampolinePacket).flatMap { case DecodedOnionPacket(innerPayload, Some(next)) => validateNodeRelay(add, payload, innerPayload, next) - case DecodedOnionPacket(innerPayload, None) => validateTrampolineFinalPayload(add, payload, innerPayload) + case DecodedOnionPacket(innerPayload, None) => + if (innerPayload.get[OutgoingBlindedPaths].isDefined) { + Left(InvalidOnionPayload(UInt64(66102), 0)) // Trampoline to blinded paths is not yet supported. + } else { + validateTrampolineFinalPayload(add, payload, innerPayload) + } } case None => validateFinalPayload(add, payload) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 806ec7fdba..4a3e436770 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -198,7 +198,7 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn } - private def buildTrampolineRecipient(r: SendRequestedPayment, trampolineHop: NodeHop): Try[ClearTrampolineRecipient] = { + private def buildTrampolineRecipient(r: SendRequestedPayment, trampolineHop: NodeHop): Try[TrampolineRecipient] = { // We generate a random secret for the payment to the trampoline node. val trampolineSecret = r match { case r: SendPaymentToRoute => r.trampoline_opt.map(_.paymentSecret).getOrElse(randomBytes32()) @@ -206,7 +206,7 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn } val finalExpiry = r.finalExpiry(nodeParams) r.invoice match { - case invoice: Bolt11Invoice => Success(ClearTrampolineRecipient(invoice, r.recipientAmount, finalExpiry, trampolineHop, trampolineSecret)) + case invoice: Bolt11Invoice => Success(TrampolineRecipient(invoice, r.recipientAmount, finalExpiry, trampolineHop, trampolineSecret)) case _: Bolt12Invoice => Failure(new IllegalArgumentException("trampoline blinded payments are not supported yet")) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala index 2b6b216611..9069546d4b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala @@ -21,7 +21,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.OutgoingPaymentPacket._ -import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, OutgoingPaymentPacket, PaymentBlindedRoute} +import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, Invoice, OutgoingPaymentPacket, PaymentBlindedRoute} import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, OutgoingBlindedPerHopPayload} import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionRoutingPacket} @@ -190,12 +190,12 @@ object BlindedRecipient { * Note that we don't need to support the case where we'd use multiple trampoline hops in the same route: since we have * access to the network graph, it's always more efficient to find a channel route to the last trampoline node. */ -case class ClearTrampolineRecipient(invoice: Bolt11Invoice, - totalAmount: MilliSatoshi, - expiry: CltvExpiry, - trampolineHop: NodeHop, - trampolinePaymentSecret: ByteVector32, - customTlvs: Set[GenericTlv] = Set.empty) extends Recipient { +case class TrampolineRecipient(invoice: Invoice, + totalAmount: MilliSatoshi, + expiry: CltvExpiry, + trampolineHop: NodeHop, + trampolinePaymentSecret: ByteVector32, + customTlvs: Set[GenericTlv] = Set.empty) extends Recipient { require(trampolineHop.nextNodeId == invoice.nodeId, "trampoline hop must end at the recipient") val trampolineNodeId = trampolineHop.nodeId @@ -225,19 +225,25 @@ case class ClearTrampolineRecipient(invoice: Bolt11Invoice, } private def createTrampolinePacket(paymentHash: ByteVector32, trampolineHop: NodeHop): Either[OutgoingPaymentError, Sphinx.PacketAndSecrets] = { - if (invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) { - // This is the payload the final recipient will receive, so we use the invoice's payment secret. - val finalPayload = NodePayload(nodeId, FinalPayload.Standard.createPayload(totalAmount, totalAmount, expiry, invoice.paymentSecret, invoice.paymentMetadata, customTlvs)) - val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard(totalAmount, expiry, nodeId)) - val payloads = Seq(trampolinePayload, finalPayload) - OutgoingPaymentPacket.buildOnion(payloads, paymentHash, packetPayloadLength_opt = None) - } else { - // The recipient doesn't support trampoline: the trampoline node will convert the payment to a non-trampoline payment. - // The final payload will thus never reach the recipient, so we create the smallest payload possible to avoid overflowing the trampoline onion size. - val dummyFinalPayload = NodePayload(nodeId, IntermediatePayload.ChannelRelay.Standard(ShortChannelId(0), 0 msat, CltvExpiry(0))) - val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard.createNodeRelayToNonTrampolinePayload(totalAmount, totalAmount, expiry, nodeId, invoice)) - val payloads = Seq(trampolinePayload, dummyFinalPayload) - OutgoingPaymentPacket.buildOnion(payloads, paymentHash, packetPayloadLength_opt = None) + invoice match { + case invoice: Bolt11Invoice => + if (invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) { + // This is the payload the final recipient will receive, so we use the invoice's payment secret. + val finalPayload = NodePayload(nodeId, FinalPayload.Standard.createPayload(totalAmount, totalAmount, expiry, invoice.paymentSecret, invoice.paymentMetadata, customTlvs)) + val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard(totalAmount, expiry, nodeId)) + val payloads = Seq(trampolinePayload, finalPayload) + OutgoingPaymentPacket.buildOnion(payloads, paymentHash, packetPayloadLength_opt = None) + } else { + // The recipient doesn't support trampoline: the trampoline node will convert the payment to a non-trampoline payment. + // The final payload will thus never reach the recipient, so we create the smallest payload possible to avoid overflowing the trampoline onion size. + val dummyFinalPayload = NodePayload(nodeId, IntermediatePayload.ChannelRelay.Standard(ShortChannelId(0), 0 msat, CltvExpiry(0))) + val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard.createNodeRelayToNonTrampolinePayload(totalAmount, totalAmount, expiry, nodeId, invoice)) + val payloads = Seq(trampolinePayload, dummyFinalPayload) + OutgoingPaymentPacket.buildOnion(payloads, paymentHash, packetPayloadLength_opt = None) + } + case invoice: Bolt12Invoice => + val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.ToBlindedPaths(totalAmount, expiry, invoice)) + OutgoingPaymentPacket.buildOnion(Seq(trampolinePayload), paymentHash, packetPayloadLength_opt = None) } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index b0b1e4c4c9..8c8e524afb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -164,7 +164,7 @@ object RouteCalculation { // In that case, we will slightly over-estimate the fee we're paying, but at least we won't exceed our fee budget. val maxFee = totalMaxFee - pendingChannelFee - r.pendingPayments.map(_.blindedFee).sum (targetNodeId, amountToSend, maxFee, extraEdges) - case recipient: ClearTrampolineRecipient => + case recipient: TrampolineRecipient => // Trampoline payments require finding routes to the trampoline node, not the final recipient. // This also ensures that we correctly take the trampoline fee into account only once, even when using MPP to // reach the trampoline node (which will aggregate the incoming MPP payment and re-split as necessary). @@ -180,7 +180,7 @@ object RouteCalculation { recipient match { case _: ClearRecipient => Some(route) case _: SpontaneousRecipient => Some(route) - case recipient: ClearTrampolineRecipient => Some(route.copy(finalHop_opt = Some(recipient.trampolineHop))) + case recipient: TrampolineRecipient => Some(route.copy(finalHop_opt = Some(recipient.trampolineHop))) case recipient: BlindedRecipient => route.hops.lastOption.flatMap { hop => recipient.blindedHops.find(_.dummyId == hop.shortChannelId) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala index f8d8a6abae..5bee31ab1e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala @@ -131,7 +131,7 @@ object OfferCodecs { private val invoicePaths: Codec[InvoicePaths] = tlvField(list(pathCodec).xmap[Seq[BlindedContactInfo]](_.toSeq, _.toList)) - private val paymentInfo: Codec[PaymentInfo] = + val paymentInfo: Codec[PaymentInfo] = (("fee_base_msat" | millisatoshi32) :: ("fee_proportional_millionths" | uint32) :: ("cltv_expiry_delta" | cltvExpiryDelta) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index 6cb9608914..084358a9c1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.eclair.payment.Bolt11Invoice +import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, PaymentBlindedContactInfo} import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs._ @@ -184,6 +184,9 @@ object OnionPaymentPayloadTlv { /** Only included for intermediate trampoline nodes that should wait before forwarding this payment */ case class AsyncPayment() extends OnionPaymentPayloadTlv + + /** Blinded paths to relay the payment to */ + case class OutgoingBlindedPaths(paths: Seq[PaymentBlindedContactInfo]) extends OnionPaymentPayloadTlv } object PaymentOnion { @@ -191,25 +194,25 @@ object PaymentOnion { import OnionPaymentPayloadTlv._ /* - * PerHopPayload - * | - * | - * +------------------------------+-----------------------------+ - * | | | - * | | | - * IntermediatePayload FinalPayload OutgoingBlindedPerHopPayload - * | | - * | | - * +---------+---------+ +------+------+ - * | | | | - * | | | | - * ChannelRelay NodeRelay Standard Blinded - * | | - * | | - * +------+------+ | - * | | | - * | | | - * Standard Blinded Standard + * PerHopPayload + * | + * | + * +---------------------------------+-----------------------------+ + * | | | + * | | | + * IntermediatePayload FinalPayload OutgoingBlindedPerHopPayload + * | | + * | | + * +------------+-------------+ +------+------+ + * | | | | + * | | | | + * ChannelRelay NodeRelay Standard Blinded + * | | + * | | + * +------+------+ +----------------+ + * | | | | + * | | | | + * Standard Blinded Standard ToBlindedPaths */ /** Per-hop payload from an HTLC's payment onion (after decryption and decoding). */ @@ -287,13 +290,12 @@ object PaymentOnion { } sealed trait NodeRelay extends IntermediatePayload { - def outgoingNodeId: PublicKey + val amountToForward = records.get[AmountToForward].get.amount + val outgoingCltv = records.get[OutgoingCltv].get.cltv } object NodeRelay { case class Standard(records: TlvStream[OnionPaymentPayloadTlv]) extends NodeRelay { - val amountToForward = records.get[AmountToForward].get.amount - val outgoingCltv = records.get[OutgoingCltv].get.cltv val outgoingNodeId = records.get[OutgoingNodeId].get.nodeId // The following fields are only included in the trampoline-to-legacy case. val totalAmount = records.get[PaymentData].map(_.totalAmount match { @@ -323,7 +325,6 @@ object PaymentOnion { } /** Create a trampoline inner payload instructing the trampoline node to relay via a non-trampoline payment. */ - // TODO: Allow sending blinded routes to trampoline nodes instead of routing hints to support BOLT12Invoice def createNodeRelayToNonTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: Bolt11Invoice): Standard = { val tlvs: Set[OnionPaymentPayloadTlv] = Set( Some(AmountToForward(amount)), @@ -342,6 +343,33 @@ object PaymentOnion { Standard(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), OutgoingNodeId(nextNodeId), AsyncPayment())) } } + + case class ToBlindedPaths(records: TlvStream[OnionPaymentPayloadTlv]) extends NodeRelay { + val outgoingBlindedPaths = records.get[OutgoingBlindedPaths].get.paths + val invoiceFeatures = records.get[InvoiceFeatures].get.features + } + + object ToBlindedPaths { + def apply(amount: MilliSatoshi, expiry: CltvExpiry, invoice: Bolt12Invoice): ToBlindedPaths = { + val tlvs: Set[OnionPaymentPayloadTlv] = Set( + Some(AmountToForward(amount)), + Some(OutgoingCltv(expiry)), + Some(OutgoingBlindedPaths(invoice.blindedPaths)), + Some(InvoiceFeatures(invoice.features.toByteVector)), + ).flatten + ToBlindedPaths(TlvStream(tlvs)) + } + + def validate(records: TlvStream[OnionPaymentPayloadTlv]): Either[InvalidTlvPayload, ToBlindedPaths] = { + if (records.get[AmountToForward].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) + if (records.get[OutgoingCltv].isEmpty) return Left(MissingRequiredTlv(UInt64(4))) + if (records.get[OutgoingBlindedPaths].isEmpty) return Left(MissingRequiredTlv(UInt64(66102))) + if (records.get[InvoiceFeatures].isEmpty) return Left(MissingRequiredTlv(UInt64(66097))) + if (records.get[EncryptedRecipientData].nonEmpty) return Left(ForbiddenTlv(UInt64(10))) + if (records.get[BlindingPoint].nonEmpty) return Left(ForbiddenTlv(UInt64(12))) + Right(ToBlindedPaths(records)) + } + } } } @@ -509,6 +537,13 @@ object PaymentOnionCodecs { private val trampolineOnion: Codec[TrampolineOnion] = tlvField(OnionRoutingCodecs.variableSizeOnionRoutingPacketCodec) + private val paymentBlindedContactInfo: Codec[PaymentBlindedContactInfo] = + (("route" | OfferCodecs.pathCodec) :: + ("paymentInfo" | OfferCodecs.paymentInfo)).as[PaymentBlindedContactInfo] + + private val outgoingBlindedPaths: Codec[OutgoingBlindedPaths] = + tlvField(list(paymentBlindedContactInfo).xmap[Seq[PaymentBlindedContactInfo]](_.toSeq, _.toList)) + private val keySend: Codec[KeySend] = tlvField(bytes32) private val asyncPayment: Codec[AsyncPayment] = tlvField(provide(AsyncPayment())) @@ -527,6 +562,7 @@ object PaymentOnionCodecs { .typecase(UInt64(66098), outgoingNodeId) .typecase(UInt64(66099), invoiceRoutingInfo) .typecase(UInt64(66100), trampolineOnion) + .typecase(UInt64(66102), outgoingBlindedPaths) .typecase(UInt64(181324718L), asyncPayment) .typecase(UInt64(5482373484L), keySend) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index ee5da5b8ff..d86a9db2ff 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -147,7 +147,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS assert(payFsm.stateName == WAIT_FOR_PAYMENT_REQUEST) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), randomBytes32(), randomKey(), Left("invoice"), CltvExpiryDelta(12)) val trampolineHop = NodeHop(e, invoice.nodeId, CltvExpiryDelta(50), 1000 msat) - val recipient = ClearTrampolineRecipient(invoice, finalAmount, expiry, trampolineHop, randomBytes32()) + val recipient = TrampolineRecipient(invoice, finalAmount, expiry, trampolineHop, randomBytes32()) val payment = SendMultiPartPayment(sender.ref, recipient, 1, routeParams) sender.send(payFsm, payment) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index ef8fd8df15..32a8be4738 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -394,11 +394,11 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(msg.recipient.totalAmount == finalAmount) assert(msg.recipient.expiry.toLong == currentBlockCount + 9 + 1) assert(msg.recipient.features.hasFeature(Features.TrampolinePaymentPrototype)) - assert(msg.recipient.isInstanceOf[ClearTrampolineRecipient]) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineNodeId == b) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + trampolineFees) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineExpiry == CltvExpiry(currentBlockCount + 9 + 1 + 12)) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolinePaymentSecret != invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node + assert(msg.recipient.isInstanceOf[TrampolineRecipient]) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolineNodeId == b) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolineAmount == finalAmount + trampolineFees) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolineExpiry == CltvExpiry(currentBlockCount + 9 + 1 + 12)) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolinePaymentSecret != invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node assert(msg.maxAttempts == nodeParams.maxPaymentAttempts) } @@ -416,11 +416,11 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(msg.recipient.totalAmount == finalAmount) assert(msg.recipient.expiry.toLong == currentBlockCount + 9 + 1) assert(!msg.recipient.features.hasFeature(Features.TrampolinePaymentPrototype)) - assert(msg.recipient.isInstanceOf[ClearTrampolineRecipient]) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineNodeId == b) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + trampolineFees) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineExpiry == CltvExpiry(currentBlockCount + 9 + 1 + 12)) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolinePaymentSecret != invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node + assert(msg.recipient.isInstanceOf[TrampolineRecipient]) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolineNodeId == b) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolineAmount == finalAmount + trampolineFees) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolineExpiry == CltvExpiry(currentBlockCount + 9 + 1 + 12)) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolinePaymentSecret != invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node assert(msg.maxAttempts == nodeParams.maxPaymentAttempts) } @@ -454,7 +454,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val msg1 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] assert(msg1.recipient.totalAmount == finalAmount) - assert(msg1.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 21_000.msat) + assert(msg1.recipient.asInstanceOf[TrampolineRecipient].trampolineAmount == finalAmount + 21_000.msat) sender.send(initiator, GetPayment(PaymentIdentifier.PaymentUUID(id))) sender.expectMsgType[PaymentIsPending] @@ -464,7 +464,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] assert(msg2.recipient.totalAmount == finalAmount) - assert(msg2.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 25_000.msat) + assert(msg2.recipient.asInstanceOf[TrampolineRecipient].trampolineAmount == finalAmount + 25_000.msat) // Simulate success which should publish the event and respond to the original sender. val success = PaymentSent(cfg.parentId, invoice.paymentHash, randomBytes32(), finalAmount, c, Seq(PaymentSent.PartialPayment(UUID.randomUUID(), 1000 msat, 500 msat, randomBytes32(), None))) @@ -491,14 +491,14 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val msg1 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] assert(msg1.recipient.totalAmount == finalAmount) - assert(msg1.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 21_000.msat) + assert(msg1.recipient.asInstanceOf[TrampolineRecipient].trampolineAmount == finalAmount + 21_000.msat) // Simulate a failure which should trigger a retry. multiPartPayFsm.send(initiator, PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg1.recipient.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient()))))) multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] assert(msg2.recipient.totalAmount == finalAmount) - assert(msg2.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 25_000.msat) + assert(msg2.recipient.asInstanceOf[TrampolineRecipient].trampolineAmount == finalAmount + 25_000.msat) // Simulate a failure that exhausts payment attempts. val failed = PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg2.recipient.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TemporaryNodeFailure())))) @@ -519,13 +519,13 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val cfg = multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg1 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg1.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 21_000.msat) + assert(msg1.recipient.asInstanceOf[TrampolineRecipient].trampolineAmount == finalAmount + 21_000.msat) // Trampoline node couldn't find a route for the given fee. val failed = PaymentFailed(cfg.parentId, invoice.paymentHash, Seq(RemoteFailure(msg1.recipient.totalAmount, Nil, Sphinx.DecryptedFailurePacket(b, TrampolineFeeInsufficient())))) multiPartPayFsm.send(initiator, failed) multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg2 = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg2.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + 25_000.msat) + assert(msg2.recipient.asInstanceOf[TrampolineRecipient].trampolineAmount == finalAmount + 25_000.msat) // Trampoline node couldn't find a route even with the increased fee. multiPartPayFsm.send(initiator, failed) @@ -550,9 +550,9 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(msg.route == Left(route)) assert(msg.amount == finalAmount + trampolineAttempt.fees) assert(msg.recipient.totalAmount == finalAmount) - assert(msg.recipient.isInstanceOf[ClearTrampolineRecipient]) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolineAmount == finalAmount + trampolineAttempt.fees) - assert(msg.recipient.asInstanceOf[ClearTrampolineRecipient].trampolinePaymentSecret == payment.trampolineSecret.get) + assert(msg.recipient.isInstanceOf[TrampolineRecipient]) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolineAmount == finalAmount + trampolineAttempt.fees) + assert(msg.recipient.asInstanceOf[TrampolineRecipient].trampolinePaymentSecret == payment.trampolineSecret.get) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index eecfa5d933..64c42ae372 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.crypto.{ShaChain, Sphinx} import fr.acinq.eclair.payment.IncomingPaymentPacket.{ChannelRelayPacket, FinalPacket, NodeRelayPacket, decrypt} import fr.acinq.eclair.payment.OutgoingPaymentPacket._ -import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient, ClearTrampolineRecipient} +import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient, TrampolineRecipient} import fr.acinq.eclair.router.BaseRouterSpec.{blindedRouteFromHops, channelHopFromUpdate} import fr.acinq.eclair.router.BlindedRouteCreation import fr.acinq.eclair.router.Router.{NodeHop, Route} @@ -286,7 +286,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // a -> b -> c e val invoiceFeatures = Features[Bolt11Feature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) - val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val recipient = TrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) assert(recipient.trampolineAmount == amount_bc) assert(recipient.trampolineExpiry == expiry_bc) val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) @@ -339,7 +339,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val routingHints = List(List(Bolt11Invoice.ExtraHop(randomKey().publicKey, ShortChannelId(42), 10 msat, 100, CltvExpiryDelta(144)))) val invoiceFeatures = Features[Bolt11Feature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_e.privateKey, Left("#reckless"), CltvExpiryDelta(18), extraHops = routingHints, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) - val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val recipient = TrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) assert(recipient.trampolineAmount == amount_bc) assert(recipient.trampolineExpiry == expiry_bc) val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) @@ -393,7 +393,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val paymentMetadata = ByteVector.fromValidHex("2a" * 450) val invoiceFeatures = Features[Bolt11Feature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_e.privateKey, Left("#reckless"), CltvExpiryDelta(18), extraHops = routingHints, features = invoiceFeatures, paymentMetadata = Some(paymentMetadata)) - val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val recipient = TrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) @@ -432,7 +432,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("fail to build outgoing trampoline payment with invalid route") { val invoiceFeatures = Features[Bolt11Feature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) - val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val recipient = TrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) val route = Route(finalAmount, trampolineChannelHops, None) // missing trampoline hop val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) assert(failure == MissingTrampolineHop(c)) @@ -457,7 +457,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("fail to decrypt when the trampoline onion is invalid") { val invoiceFeatures = Features[Bolt11Feature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) - val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val recipient = TrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) @@ -599,7 +599,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { def createIntermediateTrampolinePayment(): UpdateAddHtlc = { val invoiceFeatures = Features[Bolt11Feature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) - val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val recipient = TrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala index 33d6c3560a..c0c4a9fa42 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala @@ -33,7 +33,7 @@ import fr.acinq.eclair.payment.IncomingPaymentPacket.FinalPacket import fr.acinq.eclair.payment.OutgoingPaymentPacket.{NodePayload, Upstream, buildOnion, buildOutgoingPayment} import fr.acinq.eclair.payment.PaymentPacketSpec._ import fr.acinq.eclair.payment.relay.Relayer._ -import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient} +import fr.acinq.eclair.payment.send.{ClearRecipient, TrampolineRecipient} import fr.acinq.eclair.router.BaseRouterSpec.{blindedRouteFromHops, channelHopFromUpdate} import fr.acinq.eclair.router.Router.{NodeHop, Route} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload @@ -198,7 +198,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val invoiceFeatures = Features[Bolt11Feature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_c.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) val trampolineHop = NodeHop(b, c, channelUpdate_bc.cltvExpiryDelta, fee_b) - val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val recipient = TrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), Some(trampolineHop)), recipient) // and then manually build an htlc diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index a8c82aa9be..89f79247b6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -29,7 +29,7 @@ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Invoice.ExtraEdge -import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient, SpontaneousRecipient} +import fr.acinq.eclair.payment.send.{ClearRecipient, TrampolineRecipient, SpontaneousRecipient} import fr.acinq.eclair.payment.{Bolt11Invoice, Invoice} import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router.BaseRouterSpec.{blindedRoutesFromPaths, channelAnnouncement} @@ -515,7 +515,7 @@ class RouterSpec extends BaseRouterSpec { val recipientKey = randomKey() val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, randomBytes32(), recipientKey, Left("invoice"), CltvExpiryDelta(6)) val trampolineHop = NodeHop(c, recipientKey.publicKey, CltvExpiryDelta(100), 25_000 msat) - val recipient = ClearTrampolineRecipient(invoice, 725_000 msat, DEFAULT_EXPIRY, trampolineHop, randomBytes32()) + val recipient = TrampolineRecipient(invoice, 725_000 msat, DEFAULT_EXPIRY, trampolineHop, randomBytes32()) sender.send(router, RouteRequest(a, recipient, routeParams)) val route1 = sender.expectMsgType[RouteResponse].routes.head assert(route1.amount == 750_000.msat) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index d92a7d1bca..5707cc6a1f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -19,12 +19,14 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.UInt64.Conversions._ +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop +import fr.acinq.eclair.payment.PaymentBlindedContactInfo import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.PaymentOnion._ import fr.acinq.eclair.wire.protocol.PaymentOnionCodecs._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, UInt64, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshiLong, RealShortChannelId, ShortChannelId, UInt64, randomKey} import org.scalatest.funsuite.AnyFunSuite import scodec.bits.{ByteVector, HexStringSyntax} @@ -162,6 +164,29 @@ class PaymentOnionSpec extends AnyFunSuite { assert(encoded == bin) } + test("encode/decode node relay to blinded paths per-hop payload") { + val features = Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional).toByteVector + val blindedRoute = OfferTypes.CompactBlindedPath( + OfferTypes.ShortChannelIdDir(isNode1 = false, RealShortChannelId(468)), + PublicKey(hex"0232882c4982576e00f0d6bd4998f5b3e92d47ecc8fbad5b6a5e7521819d891d9e"), + Seq(RouteBlinding.BlindedNode(PublicKey(hex"03823aa560d631e9d7b686be4a9227e577009afb5173023b458a6a6aff056ac980"), hex"")) + ) + val path = PaymentBlindedContactInfo(blindedRoute, OfferTypes.PaymentInfo(1000 msat, 678, CltvExpiryDelta(82), 300 msat, 4000000 msat, Features.empty)) + val expected = TlvStream[OnionPaymentPayloadTlv](AmountToForward(341 msat), OutgoingCltv(CltvExpiry(826483)), OutgoingBlindedPaths(Seq(path)), InvoiceFeatures(features)) + val bin = hex"82 02020155 04030c9c73 fe0001023103020000 fe000102366a0100000000000001d40232882c4982576e00f0d6bd4998f5b3e92d47ecc8fbad5b6a5e7521819d891d9e0103823aa560d631e9d7b686be4a9227e577009afb5173023b458a6a6aff056ac9800000000003e8000002a60052000000000000012c00000000003d09000000" + + val decoded = perHopPayloadCodec.decode(bin.bits).require.value + assert(decoded == expected) + val Right(payload) = IntermediatePayload.NodeRelay.ToBlindedPaths.validate(decoded) + assert(payload.amountToForward == 341.msat) + assert(payload.outgoingCltv == CltvExpiry(826483)) + assert(payload.outgoingBlindedPaths == Seq(path)) + assert(payload.invoiceFeatures == features) + + val encoded = perHopPayloadCodec.encode(expected).require.bytes + assert(encoded == bin) + } + test("encode/decode final per-hop payload") { val testCases = Map( TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), PaymentData(ByteVector32(hex"eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0 msat)) -> hex"29 02020231 04012a 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619",