@@ -348,86 +348,102 @@ object Sphinx extends Logging {
348348 HtlcFailure (attribution1_opt.map(n => HoldTime (n._1, ss.remoteNodeId) +: downstreamHoldTimes).getOrElse(Nil ), failure)
349349 }
350350 }
351+ }
352+
353+ /**
354+ * Attribution data is added to the failure packet and prevents a node from evading responsibility for its failures.
355+ * Nodes that relay attribution data can prove that they are not the erring node and in case the erring node tries
356+ * to hide, there will only be at most two nodes that can be the erring node (the last one to send attribution data
357+ * and the one after it). It also adds timing data for each node on the path.
358+ * Attribution data can also be added to fulfilled HTLCs to provide timing data and allow choosing fast nodes for
359+ * future payments.
360+ * https://github.com/lightning/bolts/pull/1044
361+ */
362+ object Attribution {
363+ val maxNumHops = 20
364+ val holdTimeLength = 4
365+ val hmacLength = 4 // HMACs are truncated to 4 bytes to save space
366+ val totalLength = maxNumHops * holdTimeLength + maxNumHops * (maxNumHops + 1 ) / 2 * hmacLength // = 920
367+
368+ private def cipher (bytes : ByteVector , sharedSecret : ByteVector32 ): ByteVector = {
369+ val key = generateKey(" ammagext" , sharedSecret)
370+ val stream = generateStream(key, totalLength)
371+ bytes xor stream
372+ }
351373
352374 /**
353- * Attribution data is added to the failure packet and prevents a node from evading responsibility for its failures.
354- * Nodes that relay attribution data can prove that they are not the erring node and in case the erring node tries
355- * to hide, there will only be at most two nodes that can be the erring node (the last one to send attribution data
356- * and the one after it).
357- * It also adds timing data for each node on the path.
358- * https://github.com/lightning/bolts/pull/1044
375+ * Get the HMACs from the attribution data.
376+ * The layout of the attribution data is as follows (using maxNumHops = 3 for conciseness):
377+ * holdTime(0) ++ holdTime(1) ++ holdTime(2) ++
378+ * hmacs(0)(0) ++ hmacs(0)(1) ++ hmacs(0)(2) ++
379+ * hmacs(1)(0) ++ hmacs(1)(1) ++
380+ * hmacs(2)(0)
381+ *
382+ * Where `hmac(i)(j)` is the hmac added by node `i` (counted from the node that built the attribution data),
383+ * assuming it is `maxNumHops - 1 - i - j` hops away from the erring node.
359384 */
360- object Attribution {
361- val maxNumHops = 20
362- val holdTimeLength = 4
363- val hmacLength = 4 // HMACs are truncated to 4 bytes to save space
364- val totalLength = maxNumHops * holdTimeLength + maxNumHops * (maxNumHops + 1 ) / 2 * hmacLength // = 920
365-
366- private def cipher (bytes : ByteVector , sharedSecret : ByteVector32 ): ByteVector = {
367- val key = generateKey(" ammagext" , sharedSecret)
368- val stream = generateStream(key, totalLength)
369- bytes xor stream
370- }
385+ private def getHmacs (bytes : ByteVector ): Seq [Seq [ByteVector ]] =
386+ (0 until maxNumHops).map(i => (0 until (maxNumHops - i)).map(j => {
387+ val start = maxNumHops * holdTimeLength + (maxNumHops * i - (i * (i - 1 )) / 2 + j) * hmacLength
388+ bytes.slice(start, start + hmacLength)
389+ }))
371390
372- /**
373- * Get the HMACs from the attribution data.
374- * The layout of the attribution data is as follows (using maxNumHops = 3 for conciseness):
375- * holdTime(0) ++ holdTime(1) ++ holdTime(2) ++
376- * hmacs(0)(0) ++ hmacs(0)(1) ++ hmacs(0)(2) ++
377- * hmacs(1)(0) ++ hmacs(1)(1) ++
378- * hmacs(2)(0)
379- *
380- * Where `hmac(i)(j)` is the hmac added by node `i` (counted from the node that built the attribution data),
381- * assuming it is `maxNumHops - 1 - i - j` hops away from the erring node.
382- */
383- private def getHmacs (bytes : ByteVector ): Seq [Seq [ByteVector ]] =
384- (0 until maxNumHops).map(i => (0 until (maxNumHops - i)).map(j => {
385- val start = maxNumHops * holdTimeLength + (maxNumHops * i - (i * (i - 1 )) / 2 + j) * hmacLength
386- bytes.slice(start, start + hmacLength)
387- }))
388-
389- /**
390- * Computes the HMACs for the node that is `minNumHop` hops away from us. Hence we only compute `maxNumHops - minNumHop` HMACs.
391- * HMACs are truncated to 4 bytes to save space. An attacker has only one try to guess the HMAC so 4 bytes should be enough.
392- */
393- private def computeHmacs (mac : Mac32 , failurePacket : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], minNumHop : Int ): Seq [ByteVector ] = {
394- (minNumHop until maxNumHops).map(i => {
395- val y = maxNumHops - i
396- mac.mac(failurePacket ++
397- holdTimes.take(y * holdTimeLength) ++
398- ByteVector .concat((0 until y - 1 ).map(j => hmacs(j)(i)))).bytes.take(hmacLength)
399- })
400- }
391+ /**
392+ * Computes the HMACs for the node that is `minNumHop` hops away from us. Hence we only compute `maxNumHops - minNumHop` HMACs.
393+ * HMACs are truncated to 4 bytes to save space. An attacker has only one try to guess the HMAC so 4 bytes should be enough.
394+ */
395+ private def computeHmacs (mac : Mac32 , failurePacket : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], minNumHop : Int ): Seq [ByteVector ] = {
396+ (minNumHop until maxNumHops).map(i => {
397+ val y = maxNumHops - i
398+ mac.mac(failurePacket ++
399+ holdTimes.take(y * holdTimeLength) ++
400+ ByteVector .concat((0 until y - 1 ).map(j => hmacs(j)(i)))).bytes.take(hmacLength)
401+ })
402+ }
403+
404+ /**
405+ * Create attribution data to send with the failure packet or with a fulfilled HTLC
406+ *
407+ * @param failurePacket_opt the failure packet before being wrapped or `None` for fulfilled HTLCs
408+ */
409+ def create (previousAttribution_opt : Option [ByteVector ], failurePacket_opt : Option [ByteVector ], holdTime : FiniteDuration , sharedSecret : ByteVector32 ): ByteVector = {
410+ val previousAttribution = previousAttribution_opt.getOrElse(ByteVector .low(totalLength))
411+ val previousHmacs = getHmacs(previousAttribution).dropRight(1 ).map(_.drop(1 ))
412+ val mac = Hmac256 (generateKey(" um" , sharedSecret))
413+ val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take((maxNumHops - 1 ) * holdTimeLength)
414+ val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector .empty), holdTimes, previousHmacs, 0 ) +: previousHmacs
415+ cipher(holdTimes ++ ByteVector .concat(hmacs.map(ByteVector .concat(_))), sharedSecret)
416+ }
401417
402- /**
403- * Create attribution data to send with the failure packet
404- *
405- * @param failurePacket the failure packet before being wrapped
406- */
407- def create (previousAttribution_opt : Option [ByteVector ], failurePacket : ByteVector , holdTime : FiniteDuration , sharedSecret : ByteVector32 ): ByteVector = {
408- val previousAttribution = previousAttribution_opt.getOrElse(ByteVector .low(totalLength))
409- val previousHmacs = getHmacs(previousAttribution).dropRight(1 ).map(_.drop(1 ))
410- val mac = Hmac256 (generateKey(" um" , sharedSecret))
411- val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take((maxNumHops - 1 ) * holdTimeLength)
412- val hmacs = computeHmacs(mac, failurePacket, holdTimes, previousHmacs, 0 ) +: previousHmacs
413- cipher(holdTimes ++ ByteVector .concat(hmacs.map(ByteVector .concat(_))), sharedSecret)
418+ /**
419+ * Unwrap one hop of attribution data
420+ * @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid
421+ */
422+ def unwrap (encrypted : ByteVector , failurePacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Option [(FiniteDuration , ByteVector )] = {
423+ val bytes = cipher(encrypted, sharedSecret)
424+ val holdTime = uint32.decode(bytes.take(holdTimeLength).bits).require.value.milliseconds
425+ val hmacs = getHmacs(bytes)
426+ val mac = Hmac256 (generateKey(" um" , sharedSecret))
427+ if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1 ), minNumHop) == hmacs.head.drop(minNumHop)) {
428+ val unwrapped = bytes.slice(holdTimeLength, maxNumHops * holdTimeLength) ++ ByteVector .low(holdTimeLength) ++ ByteVector .concat((hmacs.drop(1 ) :+ Seq ()).map(s => ByteVector .low(hmacLength) ++ ByteVector .concat(s)))
429+ Some (holdTime, unwrapped)
430+ } else {
431+ None
414432 }
433+ }
415434
416- /**
417- * Unwrap one hop of attribution data
418- * @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid
419- */
420- def unwrap (encrypted : ByteVector , failurePacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Option [(FiniteDuration , ByteVector )] = {
421- val bytes = cipher(encrypted, sharedSecret)
422- val holdTime = uint32.decode(bytes.take(holdTimeLength).bits).require.value.milliseconds
423- val hmacs = getHmacs(bytes)
424- val mac = Hmac256 (generateKey(" um" , sharedSecret))
425- if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1 ), minNumHop) == hmacs.head.drop(minNumHop)) {
426- val unwrapped = bytes.slice(holdTimeLength, maxNumHops * holdTimeLength) ++ ByteVector .low(holdTimeLength) ++ ByteVector .concat((hmacs.drop(1 ) :+ Seq ()).map(s => ByteVector .low(hmacLength) ++ ByteVector .concat(s)))
427- Some (holdTime, unwrapped)
428- } else {
429- None
430- }
435+ /**
436+ * Decrypt the hold times from the attribution data of a fulfilled HTLC
437+ */
438+ def fulfillHoldTimes (attribution : ByteVector , sharedSecrets : Seq [SharedSecret ], hopIndex : Int = 0 ): List [HoldTime ] = {
439+ sharedSecrets match {
440+ case Nil => Nil
441+ case ss :: tail =>
442+ unwrap(attribution, ByteVector .empty, ss.secret, hopIndex) match {
443+ case Some ((holdTime, nextAttribution)) =>
444+ HoldTime (holdTime, ss.remoteNodeId) :: fulfillHoldTimes(nextAttribution, tail, hopIndex + 1 )
445+ case None => Nil
446+ }
431447 }
432448 }
433449 }
0 commit comments