Skip to content

Commit

Permalink
Trampoline to blinded (types only) (#2813)
Browse files Browse the repository at this point in the history
Add types needed for trampoline to pay a list of blinded paths instead of a node id.
  • Loading branch information
thomash-acinq authored Jan 23, 2024
1 parent 5fb9fef commit e66e6d2
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC, CannotExtractShared
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.payment.send.Recipient
import fr.acinq.eclair.router.Router.{BlindedHop, Route}
import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.OutgoingBlindedPaths
import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, PerHopPayload}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, TimestampMilli, UInt64, randomKey}
Expand Down Expand Up @@ -147,7 +148,12 @@ object IncomingPaymentPacket {
// blinding point and use it to derive the decryption key for the blinded trampoline onion.
decryptOnion(add.paymentHash, privateKey, trampolinePacket).flatMap {
case DecodedOnionPacket(innerPayload, Some(next)) => validateNodeRelay(add, payload, innerPayload, next)
case DecodedOnionPacket(innerPayload, None) => validateTrampolineFinalPayload(add, payload, innerPayload)
case DecodedOnionPacket(innerPayload, None) =>
if (innerPayload.get[OutgoingBlindedPaths].isDefined) {
Left(InvalidOnionPayload(UInt64(66102), 0)) // Trampoline to blinded paths is not yet supported.
} else {
validateTrampolineFinalPayload(add, payload, innerPayload)
}
}
case None => validateFinalPayload(add, payload)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,15 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn

}

private def buildTrampolineRecipient(r: SendRequestedPayment, trampolineHop: NodeHop): Try[ClearTrampolineRecipient] = {
private def buildTrampolineRecipient(r: SendRequestedPayment, trampolineHop: NodeHop): Try[TrampolineRecipient] = {
// We generate a random secret for the payment to the trampoline node.
val trampolineSecret = r match {
case r: SendPaymentToRoute => r.trampoline_opt.map(_.paymentSecret).getOrElse(randomBytes32())
case _ => randomBytes32()
}
val finalExpiry = r.finalExpiry(nodeParams)
r.invoice match {
case invoice: Bolt11Invoice => Success(ClearTrampolineRecipient(invoice, r.recipientAmount, finalExpiry, trampolineHop, trampolineSecret))
case invoice: Bolt11Invoice => Success(TrampolineRecipient(invoice, r.recipientAmount, finalExpiry, trampolineHop, trampolineSecret))
case _: Bolt12Invoice => Failure(new IllegalArgumentException("trampoline blinded payments are not supported yet"))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.payment.Invoice.ExtraEdge
import fr.acinq.eclair.payment.OutgoingPaymentPacket._
import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, OutgoingPaymentPacket, PaymentBlindedRoute}
import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, Invoice, OutgoingPaymentPacket, PaymentBlindedRoute}
import fr.acinq.eclair.router.Router._
import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, OutgoingBlindedPerHopPayload}
import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionRoutingPacket}
Expand Down Expand Up @@ -190,12 +190,12 @@ object BlindedRecipient {
* Note that we don't need to support the case where we'd use multiple trampoline hops in the same route: since we have
* access to the network graph, it's always more efficient to find a channel route to the last trampoline node.
*/
case class ClearTrampolineRecipient(invoice: Bolt11Invoice,
totalAmount: MilliSatoshi,
expiry: CltvExpiry,
trampolineHop: NodeHop,
trampolinePaymentSecret: ByteVector32,
customTlvs: Set[GenericTlv] = Set.empty) extends Recipient {
case class TrampolineRecipient(invoice: Invoice,
totalAmount: MilliSatoshi,
expiry: CltvExpiry,
trampolineHop: NodeHop,
trampolinePaymentSecret: ByteVector32,
customTlvs: Set[GenericTlv] = Set.empty) extends Recipient {
require(trampolineHop.nextNodeId == invoice.nodeId, "trampoline hop must end at the recipient")

val trampolineNodeId = trampolineHop.nodeId
Expand Down Expand Up @@ -225,19 +225,25 @@ case class ClearTrampolineRecipient(invoice: Bolt11Invoice,
}

private def createTrampolinePacket(paymentHash: ByteVector32, trampolineHop: NodeHop): Either[OutgoingPaymentError, Sphinx.PacketAndSecrets] = {
if (invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) {
// This is the payload the final recipient will receive, so we use the invoice's payment secret.
val finalPayload = NodePayload(nodeId, FinalPayload.Standard.createPayload(totalAmount, totalAmount, expiry, invoice.paymentSecret, invoice.paymentMetadata, customTlvs))
val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard(totalAmount, expiry, nodeId))
val payloads = Seq(trampolinePayload, finalPayload)
OutgoingPaymentPacket.buildOnion(payloads, paymentHash, packetPayloadLength_opt = None)
} else {
// The recipient doesn't support trampoline: the trampoline node will convert the payment to a non-trampoline payment.
// The final payload will thus never reach the recipient, so we create the smallest payload possible to avoid overflowing the trampoline onion size.
val dummyFinalPayload = NodePayload(nodeId, IntermediatePayload.ChannelRelay.Standard(ShortChannelId(0), 0 msat, CltvExpiry(0)))
val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard.createNodeRelayToNonTrampolinePayload(totalAmount, totalAmount, expiry, nodeId, invoice))
val payloads = Seq(trampolinePayload, dummyFinalPayload)
OutgoingPaymentPacket.buildOnion(payloads, paymentHash, packetPayloadLength_opt = None)
invoice match {
case invoice: Bolt11Invoice =>
if (invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) {
// This is the payload the final recipient will receive, so we use the invoice's payment secret.
val finalPayload = NodePayload(nodeId, FinalPayload.Standard.createPayload(totalAmount, totalAmount, expiry, invoice.paymentSecret, invoice.paymentMetadata, customTlvs))
val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard(totalAmount, expiry, nodeId))
val payloads = Seq(trampolinePayload, finalPayload)
OutgoingPaymentPacket.buildOnion(payloads, paymentHash, packetPayloadLength_opt = None)
} else {
// The recipient doesn't support trampoline: the trampoline node will convert the payment to a non-trampoline payment.
// The final payload will thus never reach the recipient, so we create the smallest payload possible to avoid overflowing the trampoline onion size.
val dummyFinalPayload = NodePayload(nodeId, IntermediatePayload.ChannelRelay.Standard(ShortChannelId(0), 0 msat, CltvExpiry(0)))
val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.Standard.createNodeRelayToNonTrampolinePayload(totalAmount, totalAmount, expiry, nodeId, invoice))
val payloads = Seq(trampolinePayload, dummyFinalPayload)
OutgoingPaymentPacket.buildOnion(payloads, paymentHash, packetPayloadLength_opt = None)
}
case invoice: Bolt12Invoice =>
val trampolinePayload = NodePayload(trampolineHop.nodeId, IntermediatePayload.NodeRelay.ToBlindedPaths(totalAmount, expiry, invoice))
OutgoingPaymentPacket.buildOnion(Seq(trampolinePayload), paymentHash, packetPayloadLength_opt = None)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ object RouteCalculation {
// In that case, we will slightly over-estimate the fee we're paying, but at least we won't exceed our fee budget.
val maxFee = totalMaxFee - pendingChannelFee - r.pendingPayments.map(_.blindedFee).sum
(targetNodeId, amountToSend, maxFee, extraEdges)
case recipient: ClearTrampolineRecipient =>
case recipient: TrampolineRecipient =>
// Trampoline payments require finding routes to the trampoline node, not the final recipient.
// This also ensures that we correctly take the trampoline fee into account only once, even when using MPP to
// reach the trampoline node (which will aggregate the incoming MPP payment and re-split as necessary).
Expand All @@ -180,7 +180,7 @@ object RouteCalculation {
recipient match {
case _: ClearRecipient => Some(route)
case _: SpontaneousRecipient => Some(route)
case recipient: ClearTrampolineRecipient => Some(route.copy(finalHop_opt = Some(recipient.trampolineHop)))
case recipient: TrampolineRecipient => Some(route.copy(finalHop_opt = Some(recipient.trampolineHop)))
case recipient: BlindedRecipient =>
route.hops.lastOption.flatMap {
hop => recipient.blindedHops.find(_.dummyId == hop.shortChannelId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ object OfferCodecs {

private val invoicePaths: Codec[InvoicePaths] = tlvField(list(pathCodec).xmap[Seq[BlindedContactInfo]](_.toSeq, _.toList))

private val paymentInfo: Codec[PaymentInfo] =
val paymentInfo: Codec[PaymentInfo] =
(("fee_base_msat" | millisatoshi32) ::
("fee_proportional_millionths" | uint32) ::
("cltv_expiry_delta" | cltvExpiryDelta) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package fr.acinq.eclair.wire.protocol

import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.payment.Bolt11Invoice
import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, PaymentBlindedContactInfo}
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv}
import fr.acinq.eclair.wire.protocol.TlvCodecs._
Expand Down Expand Up @@ -184,32 +184,35 @@ object OnionPaymentPayloadTlv {

/** Only included for intermediate trampoline nodes that should wait before forwarding this payment */
case class AsyncPayment() extends OnionPaymentPayloadTlv

/** Blinded paths to relay the payment to */
case class OutgoingBlindedPaths(paths: Seq[PaymentBlindedContactInfo]) extends OnionPaymentPayloadTlv
}

object PaymentOnion {

import OnionPaymentPayloadTlv._

/*
* PerHopPayload
* |
* |
* +------------------------------+-----------------------------+
* | | |
* | | |
* IntermediatePayload FinalPayload OutgoingBlindedPerHopPayload
* | |
* | |
* +---------+---------+ +------+------+
* | | | |
* | | | |
* ChannelRelay NodeRelay Standard Blinded
* | |
* | |
* +------+------+ |
* | | |
* | | |
* Standard Blinded Standard
* PerHopPayload
* |
* |
* +---------------------------------+-----------------------------+
* | | |
* | | |
* IntermediatePayload FinalPayload OutgoingBlindedPerHopPayload
* | |
* | |
* +------------+-------------+ +------+------+
* | | | |
* | | | |
* ChannelRelay NodeRelay Standard Blinded
* | |
* | |
* +------+------+ +----------------+
* | | | |
* | | | |
* Standard Blinded Standard ToBlindedPaths
*/

/** Per-hop payload from an HTLC's payment onion (after decryption and decoding). */
Expand Down Expand Up @@ -287,13 +290,12 @@ object PaymentOnion {
}

sealed trait NodeRelay extends IntermediatePayload {
def outgoingNodeId: PublicKey
val amountToForward = records.get[AmountToForward].get.amount
val outgoingCltv = records.get[OutgoingCltv].get.cltv
}

object NodeRelay {
case class Standard(records: TlvStream[OnionPaymentPayloadTlv]) extends NodeRelay {
val amountToForward = records.get[AmountToForward].get.amount
val outgoingCltv = records.get[OutgoingCltv].get.cltv
val outgoingNodeId = records.get[OutgoingNodeId].get.nodeId
// The following fields are only included in the trampoline-to-legacy case.
val totalAmount = records.get[PaymentData].map(_.totalAmount match {
Expand Down Expand Up @@ -323,7 +325,6 @@ object PaymentOnion {
}

/** Create a trampoline inner payload instructing the trampoline node to relay via a non-trampoline payment. */
// TODO: Allow sending blinded routes to trampoline nodes instead of routing hints to support BOLT12Invoice
def createNodeRelayToNonTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: Bolt11Invoice): Standard = {
val tlvs: Set[OnionPaymentPayloadTlv] = Set(
Some(AmountToForward(amount)),
Expand All @@ -342,6 +343,33 @@ object PaymentOnion {
Standard(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), OutgoingNodeId(nextNodeId), AsyncPayment()))
}
}

case class ToBlindedPaths(records: TlvStream[OnionPaymentPayloadTlv]) extends NodeRelay {
val outgoingBlindedPaths = records.get[OutgoingBlindedPaths].get.paths
val invoiceFeatures = records.get[InvoiceFeatures].get.features
}

object ToBlindedPaths {
def apply(amount: MilliSatoshi, expiry: CltvExpiry, invoice: Bolt12Invoice): ToBlindedPaths = {
val tlvs: Set[OnionPaymentPayloadTlv] = Set(
Some(AmountToForward(amount)),
Some(OutgoingCltv(expiry)),
Some(OutgoingBlindedPaths(invoice.blindedPaths)),
Some(InvoiceFeatures(invoice.features.toByteVector)),
).flatten
ToBlindedPaths(TlvStream(tlvs))
}

def validate(records: TlvStream[OnionPaymentPayloadTlv]): Either[InvalidTlvPayload, ToBlindedPaths] = {
if (records.get[AmountToForward].isEmpty) return Left(MissingRequiredTlv(UInt64(2)))
if (records.get[OutgoingCltv].isEmpty) return Left(MissingRequiredTlv(UInt64(4)))
if (records.get[OutgoingBlindedPaths].isEmpty) return Left(MissingRequiredTlv(UInt64(66102)))
if (records.get[InvoiceFeatures].isEmpty) return Left(MissingRequiredTlv(UInt64(66097)))
if (records.get[EncryptedRecipientData].nonEmpty) return Left(ForbiddenTlv(UInt64(10)))
if (records.get[BlindingPoint].nonEmpty) return Left(ForbiddenTlv(UInt64(12)))
Right(ToBlindedPaths(records))
}
}
}
}

Expand Down Expand Up @@ -509,6 +537,13 @@ object PaymentOnionCodecs {

private val trampolineOnion: Codec[TrampolineOnion] = tlvField(OnionRoutingCodecs.variableSizeOnionRoutingPacketCodec)

private val paymentBlindedContactInfo: Codec[PaymentBlindedContactInfo] =
(("route" | OfferCodecs.pathCodec) ::
("paymentInfo" | OfferCodecs.paymentInfo)).as[PaymentBlindedContactInfo]

private val outgoingBlindedPaths: Codec[OutgoingBlindedPaths] =
tlvField(list(paymentBlindedContactInfo).xmap[Seq[PaymentBlindedContactInfo]](_.toSeq, _.toList))

private val keySend: Codec[KeySend] = tlvField(bytes32)

private val asyncPayment: Codec[AsyncPayment] = tlvField(provide(AsyncPayment()))
Expand All @@ -527,6 +562,7 @@ object PaymentOnionCodecs {
.typecase(UInt64(66098), outgoingNodeId)
.typecase(UInt64(66099), invoiceRoutingInfo)
.typecase(UInt64(66100), trampolineOnion)
.typecase(UInt64(66102), outgoingBlindedPaths)
.typecase(UInt64(181324718L), asyncPayment)
.typecase(UInt64(5482373484L), keySend)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
assert(payFsm.stateName == WAIT_FOR_PAYMENT_REQUEST)
val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), randomBytes32(), randomKey(), Left("invoice"), CltvExpiryDelta(12))
val trampolineHop = NodeHop(e, invoice.nodeId, CltvExpiryDelta(50), 1000 msat)
val recipient = ClearTrampolineRecipient(invoice, finalAmount, expiry, trampolineHop, randomBytes32())
val recipient = TrampolineRecipient(invoice, finalAmount, expiry, trampolineHop, randomBytes32())
val payment = SendMultiPartPayment(sender.ref, recipient, 1, routeParams)
sender.send(payFsm, payment)

Expand Down
Loading

0 comments on commit e66e6d2

Please sign in to comment.