Skip to content

Commit

Permalink
Send batch_size on commit_sig retransmit (#2809)
Browse files Browse the repository at this point in the history
If we get disconnected after sending `commit_sig`, we will retransmit
that message when reconnecting if our peer has not received it.

When we have multiple commitments, we need to retransmit one message
per commitment, and include the `batch_size` tlv.

We incorrectly stored `commit_sig` without the `batch_size` tlv in the
next remote commit field, and thus retransmitted without that tlv.
  • Loading branch information
t-bast authored Jan 10, 2024
1 parent a9b5903 commit 3bd3d07
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ case class Commitment(fundingTxIndex: Long,
Right(())
}

def sendCommit(keyManager: ChannelKeyManager, params: ChannelParams, changes: CommitmentChanges, remoteNextPerCommitmentPoint: PublicKey)(implicit log: LoggingAdapter): (Commitment, CommitSig) = {
def sendCommit(keyManager: ChannelKeyManager, params: ChannelParams, changes: CommitmentChanges, remoteNextPerCommitmentPoint: PublicKey, batchSize: Int)(implicit log: LoggingAdapter): (Commitment, CommitSig) = {
// remote commitment will include all local proposed changes + remote acked changes
val spec = CommitmentSpec.reduce(remoteCommit.spec, changes.remoteChanges.acked, changes.localChanges.proposed)
val (remoteCommitTx, htlcTxs) = Commitment.makeRemoteTxs(keyManager, params.channelConfig, params.channelFeatures, remoteCommit.index + 1, params.localParams, params.remoteParams, fundingTxIndex, remoteFundingPubKey, commitInput, remoteNextPerCommitmentPoint, spec)
Expand All @@ -630,7 +630,9 @@ case class Commitment(fundingTxIndex: Long,
log.info(s"built remote commit number=${remoteCommit.index + 1} toLocalMsat=${spec.toLocal.toLong} toRemoteMsat=${spec.toRemote.toLong} htlc_in={} htlc_out={} feeratePerKw=${spec.commitTxFeerate} txid=${remoteCommitTx.tx.txid} fundingTxId=$fundingTxId", spec.htlcs.collect(DirectedHtlc.outgoing).map(_.id).mkString(","), spec.htlcs.collect(DirectedHtlc.incoming).map(_.id).mkString(","))
Metrics.recordHtlcsInFlight(spec, remoteCommit.spec)

val commitSig = CommitSig(params.channelId, sig, htlcSigs.toList)
val commitSig = CommitSig(params.channelId, sig, htlcSigs.toList, TlvStream(Set(
if (batchSize > 1) Some(CommitSigTlv.BatchTlv(batchSize)) else None
).flatten[CommitSigTlv]))
val nextRemoteCommit = NextRemoteCommit(commitSig, RemoteCommit(remoteCommit.index + 1, spec, remoteCommitTx.tx.txid, remoteNextPerCommitmentPoint))
(copy(nextRemoteCommit_opt = Some(nextRemoteCommit)), commitSig)
}
Expand Down Expand Up @@ -987,7 +989,7 @@ case class Commitments(params: ChannelParams,
remoteNextCommitInfo match {
case Right(_) if !changes.localHasChanges => Left(CannotSignWithoutChanges(channelId))
case Right(remoteNextPerCommitmentPoint) =>
val (active1, sigs) = active.map(_.sendCommit(keyManager, params, changes, remoteNextPerCommitmentPoint)).unzip
val (active1, sigs) = active.map(_.sendCommit(keyManager, params, changes, remoteNextPerCommitmentPoint, active.size)).unzip
val commitments1 = copy(
changes = changes.copy(
localChanges = changes.localChanges.copy(proposed = Nil, signed = changes.localChanges.proposed),
Expand All @@ -996,13 +998,7 @@ case class Commitments(params: ChannelParams,
active = active1,
remoteNextCommitInfo = Left(WaitForRev(localCommitIndex))
)
val sigs1 = if (sigs.size > 1) {
// if there are more than one sig, we add a tlv to tell the receiver how many sigs are to be expected
sigs.map { sig => sig.modify(_.tlvStream.records).using(_ + CommitSigTlv.BatchTlv(sigs.size)) }
} else {
sigs
}
Right(commitments1, sigs1)
Right(commitments1, sigs)
case Left(_) => Left(CannotSignBeforeRevocation(channelId))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -826,22 +826,56 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik
import f._
initiateSplice(f, spliceIn_opt = Some(SpliceIn(500_000 sat)))
val sender = TestProbe()
alice ! CMD_ADD_HTLC(sender.ref, 500000 msat, randomBytes32(), CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket, None, localOrigin(sender.ref))
alice ! CMD_ADD_HTLC(sender.ref, 500_000 msat, randomBytes32(), CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket, None, localOrigin(sender.ref))
sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]]
alice2bob.expectMsgType[UpdateAddHtlc]
alice2bob.forward(bob)
alice ! CMD_SIGN()
val sig1 = alice2bob.expectMsgType[CommitSig]
assert(sig1.batchSize == 2)
val sigA1 = alice2bob.expectMsgType[CommitSig]
assert(sigA1.batchSize == 2)
alice2bob.forward(bob)
val sig2 = alice2bob.expectMsgType[CommitSig]
assert(sig2.batchSize == 2)
val sigA2 = alice2bob.expectMsgType[CommitSig]
assert(sigA2.batchSize == 2)
alice2bob.forward(bob)
bob2alice.expectMsgType[RevokeAndAck]
bob2alice.forward(alice)
bob2alice.expectMsgType[CommitSig]
val sigB1 = bob2alice.expectMsgType[CommitSig]
assert(sigB1.batchSize == 2)
bob2alice.forward(alice)
bob2alice.expectMsgType[CommitSig]
val sigB2 = bob2alice.expectMsgType[CommitSig]
assert(sigB2.batchSize == 2)
bob2alice.forward(alice)
alice2bob.expectMsgType[RevokeAndAck]
alice2bob.forward(bob)
awaitCond(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1))
awaitCond(bob.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1))
}

test("recv CMD_ADD_HTLC with multiple commitments and reconnect") { f =>
import f._
initiateSplice(f, spliceIn_opt = Some(SpliceIn(500_000 sat)))
val sender = TestProbe()
alice ! CMD_ADD_HTLC(sender.ref, 500_000 msat, randomBytes32(), CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket, None, localOrigin(sender.ref))
sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]]
alice2bob.expectMsgType[UpdateAddHtlc]
alice2bob.forward(bob)
alice ! CMD_SIGN()
assert(alice2bob.expectMsgType[CommitSig].batchSize == 2)
assert(alice2bob.expectMsgType[CommitSig].batchSize == 2)
// Bob disconnects before receiving Alice's commit_sig.
disconnect(f)
reconnect(f, interceptFundingDeeplyBuried = false)
alice2bob.expectMsgType[UpdateAddHtlc]
alice2bob.forward(bob)
assert(alice2bob.expectMsgType[CommitSig].batchSize == 2)
alice2bob.forward(bob)
assert(alice2bob.expectMsgType[CommitSig].batchSize == 2)
alice2bob.forward(bob)
bob2alice.expectMsgType[RevokeAndAck]
bob2alice.forward(alice)
assert(bob2alice.expectMsgType[CommitSig].batchSize == 2)
bob2alice.forward(alice)
assert(bob2alice.expectMsgType[CommitSig].batchSize == 2)
bob2alice.forward(alice)
alice2bob.expectMsgType[RevokeAndAck]
alice2bob.forward(bob)
Expand Down

0 comments on commit 3bd3d07

Please sign in to comment.