Skip to content

Commit

Permalink
Relay onion messages to compact node id (#2821)
Browse files Browse the repository at this point in the history
To save space, blinded routes may use a compact node id (scid + direction instead of public key) as an introduction node.
When using such a compact route, the sender must use it's knowledge of the network to convert that to a public key, however trampoline users don't have that knowledge, they must transmit the compact route to the trempoline provider.
We extend the spec to allow compact node ids in the `next_node_id` field.


Co-authored-by: t-bast <[email protected]>
  • Loading branch information
thomash-acinq and t-bast authored Feb 19, 2024
1 parent 62b739a commit 86c4837
Show file tree
Hide file tree
Showing 23 changed files with 267 additions and 116 deletions.
19 changes: 19 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/EncodedNodeId.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package fr.acinq.eclair

import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey

sealed trait EncodedNodeId

object EncodedNodeId {
/** Nodes are usually identified by their public key. */
case class Plain(publicKey: PublicKey) extends EncodedNodeId {
override def toString: String = publicKey.toString
}

/** For compactness, nodes may be identified by the shortChannelId of one of their public channels. */
case class ShortChannelIdDir(isNode1: Boolean, scid: RealShortChannelId) extends EncodedNodeId {
override def toString: String = if (isNode1) s"<-$scid" else s"$scid->"
}

def apply(publicKey: PublicKey): EncodedNodeId = Plain(publicKey)
}
2 changes: 1 addition & 1 deletion eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ class Setup(val datadir: File,
txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcher, bitcoinClient)
channelFactory = Peer.SimpleChannelFactory(nodeParams, watcher, relayer, bitcoinClient, txPublisherFactory)
pendingChannelsRateLimiter = system.spawn(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, channels)).onFailure(typed.SupervisorStrategy.resume), name = "pending-channels-rate-limiter")
peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter, register)
peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter, register, router.toTyped)

switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, peerFactory), "switchboard", SupervisorStrategy.Resume))
_ = switchboard ! Switchboard.Init(channels)
Expand Down
123 changes: 80 additions & 43 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,27 @@
package fr.acinq.eclair.io

import akka.actor.typed.Behavior
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.{ActorRef, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.channel.Register
import fr.acinq.eclair.io.Peer.{PeerInfo, PeerInfoResponse}
import fr.acinq.eclair.io.Switchboard.GetPeerInfo
import fr.acinq.eclair.message.OnionMessages
import fr.acinq.eclair.message.OnionMessages.DropReason
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.OnionMessage
import fr.acinq.eclair.{EncodedNodeId, NodeParams, ShortChannelId}

object MessageRelay {
// @formatter:off
sealed trait Command
case class RelayMessage(messageId: ByteVector32,
switchboard: ActorRef,
register: ActorRef,
prevNodeId: PublicKey,
nextNode: Either[ShortChannelId, PublicKey],
nextNode: Either[ShortChannelId, EncodedNodeId],
msg: OnionMessage,
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]) extends Command
Expand All @@ -60,66 +62,101 @@ object MessageRelay {
case class UnknownOutgoingChannel(messageId: ByteVector32, outgoingChannelId: ShortChannelId) extends Failure {
override def toString: String = s"Unknown outgoing channel: $outgoingChannelId"
}
case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure {
override def toString: String = s"Message dropped: $reason"
}

sealed trait RelayPolicy
case object RelayChannelsOnly extends RelayPolicy
case object RelayAll extends RelayPolicy
// @formatter:on

def apply(): Behavior[Command] = {
Behaviors.receivePartial {
case (context, RelayMessage(messageId, switchboard, register, prevNodeId, Left(outgoingChannelId), msg, policy, replyTo_opt)) =>
def apply(nodeParams: NodeParams,
switchboard: ActorRef,
register: ActorRef,
router: typed.ActorRef[Router.GetNodeId]): Behavior[Command] = {
Behaviors.setup { context =>
Behaviors.receiveMessagePartial {
case RelayMessage(messageId, prevNodeId, nextNode, msg, policy, replyTo_opt) =>
val relay = new MessageRelay(nodeParams, messageId, prevNodeId, policy, switchboard, register, router, replyTo_opt, context)
relay.queryNextNodeId(msg, nextNode)
}
}
}
}

private class MessageRelay(nodeParams: NodeParams,
messageId: ByteVector32,
prevNodeId: PublicKey,
policy: MessageRelay.RelayPolicy,
switchboard: ActorRef,
register: ActorRef,
router: typed.ActorRef[Router.GetNodeId],
replyTo_opt: Option[typed.ActorRef[MessageRelay.Status]],
context: ActorContext[MessageRelay.Command]) {

import MessageRelay._

def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = {
nextNode match {
case Left(outgoingChannelId) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId)
waitForNextNodeId(messageId, switchboard, prevNodeId, outgoingChannelId, msg, policy, replyTo_opt)
case (context, RelayMessage(messageId, switchboard, _, prevNodeId, Right(nextNodeId), msg, policy, replyTo_opt)) =>
withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt)
waitForNextNodeId(msg, outgoingChannelId)
case Right(EncodedNodeId.ShortChannelIdDir(isNode1, scid)) =>
router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1)
waitForNextNodeId(msg, scid)
case Right(EncodedNodeId.Plain(nextNodeId)) =>
withNextNodeId(msg, nextNodeId)
}
}

def waitForNextNodeId(messageId: ByteVector32,
switchboard: ActorRef,
prevNodeId: PublicKey,
outgoingChannelId: ShortChannelId,
msg: OnionMessage,
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] =
Behaviors.receivePartial {
case (_, WrappedOptionalNodeId(None)) =>
private def waitForNextNodeId(msg: OnionMessage, outgoingChannelId: ShortChannelId): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedOptionalNodeId(None) =>
replyTo_opt.foreach(_ ! UnknownOutgoingChannel(messageId, outgoingChannelId))
Behaviors.stopped
case (context, WrappedOptionalNodeId(Some(nextNodeId))) =>
withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt)
case WrappedOptionalNodeId(Some(nextNodeId)) =>
withNextNodeId(msg, nextNodeId)
}

def withNextNodeId(context: ActorContext[Command],
messageId: ByteVector32,
switchboard: ActorRef,
prevNodeId: PublicKey,
nextNodeId: PublicKey,
msg: OnionMessage,
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] =
policy match {
case RelayChannelsOnly =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
waitForPreviousPeer(messageId, switchboard, nextNodeId, msg, replyTo_opt)
case RelayAll =>
switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
waitForConnection(messageId, msg, replyTo_opt)
}
private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
if (nextNodeId == nodeParams.nodeId) {
OnionMessages.process(nodeParams.privateKey, msg) match {
case OnionMessages.DropMessage(reason) =>
replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason))
Behaviors.stopped
case OnionMessages.SendMessage(nextNode, nextMessage) =>
// We need to repeat the process until we identify the (real) next node, or find out that we're the recipient.
queryNextNodeId(nextMessage, nextNode)
case received: OnionMessages.ReceiveMessage =>
context.system.eventStream ! EventStream.Publish(received)
replyTo_opt.foreach(_ ! Sent(messageId))
Behaviors.stopped
}
} else {
policy match {
case RelayChannelsOnly =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
waitForPreviousPeerForPolicyCheck(msg, nextNodeId)
case RelayAll =>
switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
waitForConnection(msg)
}
}
}

def waitForPreviousPeer(messageId: ByteVector32, switchboard: ActorRef, nextNodeId: PublicKey, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = {
Behaviors.receivePartial {
case (context, WrappedPeerInfo(PeerInfo(_, _, _, _, channels))) if channels.nonEmpty =>
private def waitForPreviousPeerForPolicyCheck(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedPeerInfo(PeerInfo(_, _, _, _, channels)) if channels.nonEmpty =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), nextNodeId)
waitForNextPeer(messageId, msg, replyTo_opt)
waitForNextPeerForPolicyCheck(msg)
case _ =>
replyTo_opt.foreach(_ ! AgainstPolicy(messageId, RelayChannelsOnly))
Behaviors.stopped
}
}

def waitForNextPeer(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = {
private def waitForNextPeerForPolicyCheck(msg: OnionMessage): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedPeerInfo(PeerInfo(peer, _, _, _, channels)) if channels.nonEmpty =>
peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt)
Expand All @@ -130,7 +167,7 @@ object MessageRelay {
}
}

def waitForConnection(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = {
private def waitForConnection(msg: OnionMessage): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedConnectionResult(r: PeerConnection.ConnectionResult.HasConnection) =>
r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt)
Expand Down
17 changes: 13 additions & 4 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import fr.acinq.eclair.io.OpenChannelInterceptor.{OpenChannelInitiator, OpenChan
import fr.acinq.eclair.io.PeerConnection.KillReason
import fr.acinq.eclair.message.OnionMessages
import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol
import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, NodeAddress, OnionMessage, RoutingMessage, UnknownMessage, Warning}

Expand All @@ -51,7 +52,14 @@ import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId
*
* Created by PM on 26/08/2016.
*/
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {
class Peer(val nodeParams: NodeParams,
remoteNodeId: PublicKey,
wallet: OnchainPubkeyCache,
channelFactory: Peer.ChannelFactory,
switchboard: ActorRef,
register: ActorRef,
router: typed.ActorRef[Router.GetNodeId],
pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {

import Peer._

Expand Down Expand Up @@ -279,8 +287,8 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainP
log.debug("dropping message from {}: {}", remoteNodeId.value.toHex, reason.toString)
case OnionMessages.SendMessage(nextNode, message) if nodeParams.features.hasFeature(Features.OnionMessages) =>
val messageId = randomBytes32()
val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, switchboard, register, remoteNodeId, nextNode, message, nodeParams.onionMessageConfig.relayPolicy, None)
val relay = context.spawn(Behaviors.supervise(MessageRelay(nodeParams, switchboard, register, router)).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, remoteNodeId, nextNode, message, nodeParams.onionMessageConfig.relayPolicy, None)
case OnionMessages.SendMessage(_, _) =>
log.debug("dropping message from {}: relaying onion messages is disabled", remoteNodeId.value.toHex)
case received: OnionMessages.ReceiveMessage =>
Expand Down Expand Up @@ -458,7 +466,8 @@ object Peer {
context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, txPublisherFactory))
}

def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, register, pendingChannelsRateLimiter))
def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, register: ActorRef, router: typed.ActorRef[Router.GetNodeId], pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props =
Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, register, router, pendingChannelsRateLimiter))

// @formatter:off

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import fr.acinq.eclair.channel._
import fr.acinq.eclair.io.IncomingConnectionsTracker.TrackIncomingConnection
import fr.acinq.eclair.io.Peer.{PeerInfoResponse, PeerNotFound}
import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.router.Router.RouterConf
import fr.acinq.eclair.{NodeParams, SubscriptionsComplete}

Expand Down Expand Up @@ -159,9 +160,9 @@ object Switchboard {
def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef
}

case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command], register: ActorRef) extends PeerFactory {
case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command], register: ActorRef, router: typed.ActorRef[Router.GetNodeId]) extends PeerFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef =
context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, register, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId))
context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, register, router, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId))
}

def props(nodeParams: NodeParams, peerFactory: PeerFactory) = Props(new Switchboard(nodeParams, peerFactory))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package fr.acinq.eclair.message

import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.{EncodedNodeId, ShortChannelId}
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.io.MessageRelay.RelayPolicy
import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, IntermediatePayload}
Expand Down Expand Up @@ -105,9 +105,9 @@ object OnionMessages {
case Left(_) => None
case Right(decoded) =>
decoded.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId] match {
case None => None
case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(nextNodeId)) =>
case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.Plain(nextNodeId))) =>
Some(Sphinx.RouteBlinding.BlindedRoute(nextNodeId, decoded.nextBlinding, route.blindedNodes.tail))
case _ => None // TODO: allow compact node id and OutgoingChannelId
}
}
case BlindedPath(route) if intermediateNodes.isEmpty => Some(route)
Expand Down Expand Up @@ -165,7 +165,7 @@ object OnionMessages {
// @formatter:off
sealed trait Action
case class DropMessage(reason: DropReason) extends Action
case class SendMessage(nextNode: Either[ShortChannelId, PublicKey], message: OnionMessage) extends Action
case class SendMessage(nextNode: Either[ShortChannelId, EncodedNodeId], message: OnionMessage) extends Action
case class ReceiveMessage(finalPayload: FinalPayload) extends Action

sealed trait DropReason
Expand Down Expand Up @@ -211,8 +211,8 @@ object OnionMessages {
case Left(f) => DropMessage(f)
case Right(DecodedEncryptedData(blindedPayload, nextBlinding)) => nextPacket_opt match {
case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket) match {
case SendMessage(Right(nextNodeId), nextMsg) if nextNodeId == privateKey.publicKey => process(privateKey, nextMsg)
case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg)
case SendMessage(Right(EncodedNodeId.Plain(publicKey)), nextMsg) if publicKey == privateKey.publicKey => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay
case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay
case action => action
}
case None => validateFinalPayload(payload, blindedPayload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteNotFound, Messag
import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, InvoiceRequestPayload}
import fr.acinq.eclair.wire.protocol.OfferTypes.{CompactBlindedPath, ContactInfo}
import fr.acinq.eclair.wire.protocol.{OfferTypes, OnionMessagePayloadTlv, TlvStream}
import fr.acinq.eclair.{NodeParams, randomBytes32, randomKey}
import fr.acinq.eclair.{EncodedNodeId, NodeParams, randomBytes32, randomKey}

import scala.collection.mutable

Expand Down Expand Up @@ -214,8 +214,8 @@ private class SendingMessage(nodeParams: NodeParams,
replyTo ! Postman.MessageFailed(failure.toString)
Behaviors.stopped
case Right((nextNodeId, message)) =>
val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, switchboard, register, nodeParams.nodeId, Right(nextNodeId), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus)))
val relay = context.spawn(Behaviors.supervise(MessageRelay(nodeParams, switchboard, register, router)).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, nodeParams.nodeId, Right(EncodedNodeId(nextNodeId)), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus)))
waitForSent()
}
}
Expand Down
Loading

0 comments on commit 86c4837

Please sign in to comment.