Skip to content

Commit

Permalink
[AN-333] Prevent infinite DRS download retries (#7679)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucymcnatt authored Jan 23, 2025
1 parent 4614c07 commit 2c134ec
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,33 +57,29 @@ case class GcsUriDownloader(gcsUrl: String,
downloadAttempt: Int = 0
): IO[DownloadResult] = {

def maybeRetryForDownloadFailure(t: Throwable): IO[DownloadResult] =
if (downloadAttempt < downloadRetries) {
backoff foreach { b => Thread.sleep(b.backoffMillis) }
logger.warn(s"Attempting download retry $downloadAttempt of $downloadRetries for a GCS url", t)
downloadWithRetries(downloadRetries,
backoff map {
_.next
},
downloadAttempt + 1
)
} else {
IO.raiseError(new RuntimeException(s"Exhausted $downloadRetries resolution retries to download GCS file", t))
}

runDownloadCommand.redeemWith(
recover = maybeRetryForDownloadFailure,
bind = {
case s: DownloadSuccess.type =>
IO.pure(s)
case _: RecognizedRetryableDownloadFailure =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
case _: UnrecognizedRetryableDownloadFailure =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
case _ =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
}
)
// Necessary function to handle the throwable when trying to recover a failed download
def handleDownloadFailure(t: Throwable): IO[DownloadResult] =
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)

if (downloadAttempt < downloadRetries) {
backoff foreach { b => Thread.sleep(b.backoffMillis) }
logger.info(s"Attempting download attempt $downloadAttempt of $downloadRetries for a GCS url")
runDownloadCommand.redeemWith(
recover = handleDownloadFailure,
bind = {
case s: DownloadSuccess.type =>
IO.pure(s)
case _: RecognizedRetryableDownloadFailure =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
case _: UnrecognizedRetryableDownloadFailure =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
case _ =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
}
)
} else {
IO.raiseError(new RuntimeException(s"Exhausted $downloadRetries resolution retries to download GCS file"))
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package drs.localizer.downloaders

import common.assertion.CromwellTimeoutSpec
import org.mockito.Mockito.{spy, times, verify}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

Expand Down Expand Up @@ -96,4 +97,27 @@ class GcsUriDownloaderSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat

downloader.generateDownloadScript(gcsUrl, Option(fakeSAJsonPath)) shouldBe expectedDownloadScript
}

it should "fail to download GCS URL after 5 attempts" in {
val gcsUrl = "gs://foo/bar.bam"
val downloader = spy(
new GcsUriDownloader(
gcsUrl = gcsUrl,
downloadLoc = fakeDownloadLocation,
requesterPaysProjectIdOption = Option(fakeRequesterPaysId),
serviceAccountJson = None
)
)

val result = downloader.downloadWithRetries(5, None).attempt.unsafeRunSync()

result.isLeft shouldBe true
// attempts to download the 1st time and the 5th time, but doesn't attempt a 6th
verify(downloader, times(1)).downloadWithRetries(5, None, 1)
verify(downloader, times(1)).downloadWithRetries(5, None, 5)
verify(downloader, times(0)).downloadWithRetries(5, None, 6)
// attempts the actual download command 5 times
verify(downloader, times(5)).runDownloadCommand

}
}

0 comments on commit 2c134ec

Please sign in to comment.