Skip to content

Commit

Permalink
Fill with zeros if the error to relay is too short
Browse files Browse the repository at this point in the history
  • Loading branch information
thomash-acinq committed Nov 14, 2023
1 parent bfdb39e commit df0dab1
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
23 changes: 9 additions & 14 deletions eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto}
import fr.acinq.eclair.wire.protocol._
import grizzled.slf4j.Logging
import scodec.Attempt
import scodec.{Attempt, DecodeResult}
import scodec.bits.ByteVector

import scala.annotation.tailrec
Expand Down Expand Up @@ -347,10 +347,10 @@ object Sphinx extends Logging {

def create(sharedSecret: ByteVector32, failure: FailureMessage, holdTime: FiniteDuration): ByteVector = {
val failurePayload = FailureMessageCodecs.failureOnionPayload(payloadAndPadLength).encode(failure).require.toByteVector
val zeroPayloads = Seq.fill(maxNumHop)(ByteVector.fill(hopPayloadLength)(0))
val zeroHmacs = (maxNumHop.to(1, -1)).map(Seq.fill(_)(ByteVector.low(4)))
val zeroPayloads = Seq.fill(maxNumHop)(ByteVector.low(hopPayloadLength))
val zeroHmacs = maxNumHop.to(1, -1).map(Seq.fill(_)(ByteVector.low(4)))
val plainError = attributableErrorCodec(totalLength, hopPayloadLength, maxNumHop).encode(AttributableError(failurePayload, zeroPayloads, zeroHmacs)).require.bytes
wrap(plainError, sharedSecret, holdTime, isSource = true).get
wrap(plainError, sharedSecret, holdTime, isSource = true)
}

private def computeHmacs(mac: Mac32, failurePayload: ByteVector, hopPayloads: Seq[ByteVector], hmacs: Seq[Seq[ByteVector]], minNumHop: Int): Seq[ByteVector] = {
Expand All @@ -363,9 +363,12 @@ object Sphinx extends Logging {
newHmacs
}

def wrap(errorPacket: ByteVector, sharedSecret: ByteVector32, holdTime: FiniteDuration, isSource: Boolean): Try[ByteVector] = Try {
def wrap(errorPacket: ByteVector, sharedSecret: ByteVector32, holdTime: FiniteDuration, isSource: Boolean): ByteVector = {
val um = generateKey("um", sharedSecret)
val error = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits).require.value
val error = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits) match {
case Attempt.Successful(DecodeResult(value, _)) => value
case Attempt.Failure(_) => AttributableError.zero(payloadAndPadLength, hopPayloadLength, maxNumHop)
}
val hopPayloads = hopPayloadCodec.encode(HopPayload(isSource, holdTime)).require.bytes +: error.hopPayloads.dropRight(1)
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, hopPayloads, error.hmacs.map(_.drop(1)), 0) +: error.hmacs.dropRight(1).map(_.drop(1))
val newError = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).encode(AttributableError(error.failurePayload, hopPayloads, hmacs)).require.bytes
Expand All @@ -374,14 +377,6 @@ object Sphinx extends Logging {
newError xor stream
}

def wrapOrCreate(errorPacket: ByteVector, sharedSecret: ByteVector32, holdTime: FiniteDuration): ByteVector =
wrap(errorPacket, sharedSecret, holdTime, isSource = false) match {
case Failure(_) =>
// There is no failure message for this use-case, using TemporaryNodeFailure instead.
create(sharedSecret, TemporaryNodeFailure(), holdTime)
case Success(value) => value
}

private def unwrap(errorPacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Try[(ByteVector, HopPayload)] = Try {
val key = generateKey("ammag", sharedSecret)
val stream = generateStream(key, errorPacket.length.toInt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ object OutgoingPaymentPacket {
Sphinx.peel(nodeSecret, Some(add.paymentHash), add.onionRoutingPacket) match {
case Right(Sphinx.DecryptedPacket(_, _, sharedSecret)) =>
val encryptedReason = reason match {
case Left(forwarded) if useAttributableErrors => Sphinx.AttributableErrorPacket.wrapOrCreate(forwarded, sharedSecret, holdTime)
case Left(forwarded) if useAttributableErrors => Sphinx.AttributableErrorPacket.wrap(forwarded, sharedSecret, holdTime, isSource = false)
case Right(failure) if useAttributableErrors => Sphinx.AttributableErrorPacket.create(sharedSecret, failure, holdTime)
case Left(forwarded) => Sphinx.FailurePacket.wrap(forwarded, sharedSecret)
case Right(failure) => Sphinx.FailurePacket.create(sharedSecret, failure)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,10 @@ object AttributableError {
(("failure_payload" | bytes(totalLength - metadataLength)) ::
("hop_payloads" | listOfN(provide(maxNumHop), bytes(hopPayloadLength)).xmap[Seq[ByteVector]](_.toSeq, _.toList)) ::
("hmacs" | hmacsCodec(maxNumHop))).as[AttributableError].complete}

def zero(payloadAndPadLength: Int, hopPayloadLength: Int, maxNumHop: Int): AttributableError =
AttributableError(
ByteVector.low(payloadAndPadLength),
Seq.fill(maxNumHop)(ByteVector.low(hopPayloadLength)),
maxNumHop.to(1, -1).map(Seq.fill(_)(ByteVector.low(4))))
}
Original file line number Diff line number Diff line change
Expand Up @@ -414,13 +414,13 @@ class SphinxSpec extends AnyFunSuite {
val Right(decrypted1) = AttributableErrorPacket.decrypt(packet1, (2 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
assert(decrypted1 == expected)

val Success(packet2) = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 5 millis, isSource = false)
val packet2 = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 5 millis, isSource = false)
assert(packet2.length == 1200)

val Right(decrypted2) = AttributableErrorPacket.decrypt(packet2, (1 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
assert(decrypted2 == expected)

val Success(packet3) = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 9 millis, isSource = false)
val packet3 = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 9 millis, isSource = false)
assert(packet3.length == 1200)

val Right(decrypted3) = AttributableErrorPacket.decrypt(packet3, (0 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
Expand All @@ -440,11 +440,11 @@ class SphinxSpec extends AnyFunSuite {
val packet1 = randomBytes(1200)

val hopPayload2 = AttributableError.HopPayload(isPayloadSource = false, 50 millis)
val Success(packet2) = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 50 millis, isSource = false)
val packet2 = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 50 millis, isSource = false)
assert(packet2.length == 1200)

val hopPayload3 = AttributableError.HopPayload(isPayloadSource = false, 100 millis)
val Success(packet3) = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 100 millis, isSource = false)
val packet3 = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 100 millis, isSource = false)
assert(packet3.length == 1200)

val Left(decryptionError) = AttributableErrorPacket.decrypt(packet3, (0 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
Expand Down

0 comments on commit df0dab1

Please sign in to comment.