diff --git a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingClient.scala b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingClient.scala
index 36fa7842f..a470b87af 100644
--- a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingClient.scala
+++ b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingClient.scala
@@ -16,6 +16,7 @@ package io.delta.sharing.spark
 
+import java.io.{BufferedReader, InputStream, InputStreamReader}
 import java.net.{URL, URLEncoder}
 import java.nio.charset.StandardCharsets.UTF_8
 import java.sql.Timestamp
@@ -23,9 +24,10 @@ import java.time.LocalDateTime
 import java.time.format.DateTimeFormatter.ISO_DATE_TIME
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 
 import org.apache.commons.io.IOUtils
+import org.apache.commons.io.input.BoundedInputStream
 import org.apache.hadoop.util.VersionInfo
 import org.apache.http.{HttpHeaders, HttpHost, HttpStatus}
 import org.apache.http.client.config.RequestConfig
@@ -188,7 +190,7 @@ private[spark] class DeltaSharingRestClient(
     val target = getTargetUrl(s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/" +
       s"$encodedTableName/version$encodedParam")
-    val (version, _) = getResponse(new HttpGet(target), true)
+    val (version, _) = getResponse(new HttpGet(target), true, true)
     version.getOrElse {
       throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
     }
@@ -324,14 +326,14 @@ private[spark] class DeltaSharingRestClient(
   }
 
   private def getNDJson(target: String, requireVersion: Boolean = true): (Long, Seq[String]) = {
-    val (version, response) = getResponse(new HttpGet(target))
+    val (version, lines) = getResponse(new HttpGet(target))
     version.getOrElse {
       if (requireVersion) {
         throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
       } else {
         0L
       }
-    } -> response.split("[\n\r]+")
+    } -> lines
   }
 
   private def getNDJson[T: Manifest](target: String, data: T): (Long, Seq[String]) = {
@@ -339,15 +341,20 @@
     val json = JsonUtils.toJson(data)
     httpPost.setHeader("Content-type", "application/json")
     httpPost.setEntity(new StringEntity(json, UTF_8))
-    val (version, response) = getResponse(httpPost)
+    val (version, lines) = getResponse(httpPost)
     version.getOrElse {
       throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
-    } -> response.split("[\n\r]+")
+    } -> lines
   }
 
   private def getJson[R: Manifest](target: String): R = {
-    val (_, response) = getResponse(new HttpGet(target))
-    JsonUtils.fromJson[R](response)
+    val (_, response) = getResponse(new HttpGet(target), false, true)
+    if (response.size != 1) {
+      throw new IllegalStateException(
+        "Unexpected response for target: " + target + ", response=" + response
+      )
+    }
+    JsonUtils.fromJson[R](response(0))
   }
 
   private def getHttpHost(endpoint: String): HttpHost = {
@@ -393,11 +400,17 @@ private[spark] class DeltaSharingRestClient(
   /**
    * Send the http request and return the table version in the header if any, and the response
    * content.
+   *
+   * The response can be:
+   *   - empty, if allowNoContent is true.
+   *   - a single string, if fetchAsOneString is true.
+   *   - a multi-line response (typically, one line per action). This is the default.
    */
   private def getResponse(
       httpRequest: HttpRequestBase,
-      allowNoContent: Boolean = false
-  ): (Option[Long], String) =
+      allowNoContent: Boolean = false,
+      fetchAsOneString: Boolean = false
+  ): (Option[Long], Seq[String]) = {
     RetryUtils.runWithExponentialBackoff(numRetries) {
       val profile = profileProvider.getProfile
       val response = client.execute(
@@ -408,12 +421,26 @@ private[spark] class DeltaSharingRestClient(
       try {
         val status = response.getStatusLine()
         val entity = response.getEntity()
-        val body = if (entity == null) {
-          ""
+        val lines = if (entity == null) {
+          List("")
         } else {
           val input = entity.getContent()
           try {
-            IOUtils.toString(input, UTF_8)
+            if (fetchAsOneString) {
+              Seq(IOUtils.toString(input, UTF_8))
+            } else {
+              val reader = new BufferedReader(
+                new InputStreamReader(new BoundedInputStream(input), UTF_8)
+              )
+              var line: Option[String] = None
+              val lineBuffer = ListBuffer[String]()
+              while ({
+                line = Option(reader.readLine()); line.isDefined
+              }) {
+                lineBuffer += line.get
+              }
+              lineBuffer.toList
+            }
           } finally {
             input.close()
           }
@@ -427,15 +454,18 @@
           additionalErrorInfo = s"It may be caused by an expired token as it has expired " +
             s"at ${profile.expirationTime}"
         }
+        // Only show the last 100 lines in the error to keep it contained.
+        val responseToShow = lines.drop(lines.size - 100).mkString("\n")
         throw new UnexpectedHttpStatus(
-          s"HTTP request failed with status: $status $body. $additionalErrorInfo",
+          s"HTTP request failed with status: $status $responseToShow. $additionalErrorInfo",
           statusCode)
       }
-      Option(response.getFirstHeader("Delta-Table-Version")).map(_.getValue.toLong) -> body
+      Option(response.getFirstHeader("Delta-Table-Version")).map(_.getValue.toLong) -> lines
     } finally {
       response.close()
     }
   }
+  }
 
   // Add SparkStructuredStreaming in the USER_AGENT header, in order for the delta sharing server
   // to recognize the request for streaming, and take corresponding actions.
diff --git a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala
index 8eccf2cad..ed55b5218 100644
--- a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala
+++ b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala
@@ -132,6 +132,9 @@ case class DeltaSharingSource(
 
   private val tableId = initSnapshot.metadata.id
 
+  private val refreshPresignedUrls = spark.sessionState.conf.getConfString(
+    "spark.delta.sharing.source.refreshPresignedUrls.enabled", "true").toBoolean
+
   // Records until which offset the delta sharing source has been processing the table files.
   private var previousOffset: DeltaSharingSourceOffset = null
 
@@ -149,6 +152,11 @@ case class DeltaSharingSource(
   // a variable to be used by the CachedTableManager to refresh the presigned urls if the query
   // runs for a long time.
   private var latestRefreshFunc = () => { Map.empty[String, String] }
+  // The latest timestamp in milliseconds, recording the time of the last rpc sent to the server
+  // to fetch the pre-signed urls.
+  // This is used to track whether the pre-signed urls stored in sortedFetchedFiles are going to
+  // expire and need a refresh.
+  private var lastQueryTableTimestamp: Long = -1
 
   // Check the latest table version from the delta sharing server through the client.getTableVersion
   // RPC. Adding a minimum interval of QUERY_TABLE_VERSION_INTERVAL_MILLIS between two consecutive
@@ -231,6 +239,7 @@ case class DeltaSharingSource(
       fromIndex: Long,
       isStartingVersion: Boolean,
       currentLatestVersion: Long): Unit = {
+    lastQueryTableTimestamp = System.currentTimeMillis()
     if (isStartingVersion) {
       // If isStartingVersion is true, it means to fetch the snapshot at the fromVersion, which may
       // include table changes from previous versions.
@@ -307,6 +316,7 @@ case class DeltaSharingSource(
       fromVersion: Long,
       fromIndex: Long,
       currentLatestVersion: Long): Unit = {
+    lastQueryTableTimestamp = System.currentTimeMillis()
     val tableFiles = deltaLog.client.getCDFFiles(
       deltaLog.table, Map(DeltaSharingOptions.CDF_START_VERSION -> fromVersion.toString), true)
     latestRefreshFunc = () => {
@@ -459,6 +469,51 @@ case class DeltaSharingSource(
       endOffset: DeltaSharingSourceOffset): DataFrame = {
     maybeGetFileChanges(startVersion, startIndex, isStartingVersion)
 
+    if (refreshPresignedUrls &&
+      (CachedTableManager.INSTANCE.preSignedUrlExpirationMs + lastQueryTableTimestamp -
+        System.currentTimeMillis() < CachedTableManager.INSTANCE.refreshThresholdMs)) {
+      // force a refresh if needed.
+      lastQueryTableTimestamp = System.currentTimeMillis()
+      val newIdToUrl = latestRefreshFunc()
+      sortedFetchedFiles = sortedFetchedFiles.map { indexedFile =>
+        IndexedFile(
+          version = indexedFile.version,
+          index = indexedFile.index,
+          add = if (indexedFile.add == null) {
+            null
+          } else {
+            val newUrl = newIdToUrl.getOrElse(
+              indexedFile.add.id,
+              throw new IllegalStateException(s"cannot find url for id ${indexedFile.add.id} " +
+                s"when refreshing table ${deltaLog.path}")
+            )
+            indexedFile.add.copy(url = newUrl)
+          },
+          remove = if (indexedFile.remove == null) {
+            null
+          } else {
+            val newUrl = newIdToUrl.getOrElse(
+              indexedFile.remove.id,
+              throw new IllegalStateException(s"cannot find url for id ${indexedFile.remove.id} " +
+                s"when refreshing table ${deltaLog.path}")
+            )
+            indexedFile.remove.copy(url = newUrl)
+          },
+          cdc = if (indexedFile.cdc == null) {
+            null
+          } else {
+            val newUrl = newIdToUrl.getOrElse(
+              indexedFile.cdc.id,
+              throw new IllegalStateException(s"cannot find url for id ${indexedFile.cdc.id} " +
+                s"when refreshing table ${deltaLog.path}")
+            )
+            indexedFile.cdc.copy(url = newUrl)
+          },
+          isLast = indexedFile.isLast
+        )
+      }
+    }
+
     val fileActions = sortedFetchedFiles.takeWhile {
       case IndexedFile(version, index, _, _, _, _) =>
         version < endOffset.tableVersion ||
@@ -545,7 +600,8 @@ case class DeltaSharingSource(
       removeFiles,
       schema,
       isStreaming = true,
-      latestRefreshFunc
+      latestRefreshFunc,
+      lastQueryTableTimestamp
     )
   }
 
diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala
index 9fb3fa6d4..994cbeb16 100644
--- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala
+++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala
@@ -71,7 +71,9 @@ object DeltaSharingCDFReader {
       removeFiles: Seq[RemoveFile],
       schema: StructType,
       isStreaming: Boolean,
-      refresher: () => Map[String, String]): DataFrame = {
+      refresher: () => Map[String, String],
+      lastQueryTableTimestamp: Long = System.currentTimeMillis()
+  ): DataFrame = {
     val dfs = ListBuffer[DataFrame]()
     val refs = ListBuffer[WeakReference[AnyRef]]()
 
@@ -92,7 +94,8 @@ object DeltaSharingCDFReader {
       getIdToUrl(addFiles, cdfFiles, removeFiles),
       refs,
       params.profileProvider,
-      refresher
+      refresher,
+      lastQueryTableTimestamp
     )
 
     dfs.reduce((df1, df2) => df1.unionAll(df2))
diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala
index e4c26047f..10464459a 100644
--- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala
+++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala
@@ -111,7 +111,7 @@ private[sharing] abstract class RemoteDeltaFileIndexBase(
   // not perform json predicate based filtering.
   protected def convertToJsonPredicate(partitionFilters: Seq[Expression]) : Option[String] = {
     if (!params.spark.sessionState.conf.getConfString(
-      "spark.delta.sharing.jsonPredicateHints.enabled", "false").toBoolean) {
+      "spark.delta.sharing.jsonPredicateHints.enabled", "true").toBoolean) {
       return None
     }
     try {
diff --git a/spark/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala b/spark/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala
index f5f003e79..2e52128cb 100644
--- a/spark/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala
+++ b/spark/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala
@@ -45,9 +45,9 @@ class CachedTable(
     val refresher: () => Map[String, String])
 
 class CachedTableManager(
-    preSignedUrlExpirationMs: Long,
+    val preSignedUrlExpirationMs: Long,
     refreshCheckIntervalMs: Long,
-    refreshThresholdMs: Long,
+    val refreshThresholdMs: Long,
     expireAfterAccessMs: Long) extends Logging {
 
   private val cache = new java.util.concurrent.ConcurrentHashMap[String, CachedTable]()
@@ -134,19 +134,35 @@ class CachedTableManager(
    * signed url cache of this table form the cache.
    * @param profileProvider a profile Provider that can provide customized refresher function.
    * @param refresher A function to re-generate pre signed urls for the table.
+   * @param lastQueryTableTimestamp A timestamp to indicate the last time the idToUrl mapping was
+   *                                generated, used to refresh the urls in time based on it.
    */
   def register(
       tablePath: String,
       idToUrl: Map[String, String],
       refs: Seq[WeakReference[AnyRef]],
       profileProvider: DeltaSharingProfileProvider,
-      refresher: () => Map[String, String]): Unit = {
+      refresher: () => Map[String, String],
+      lastQueryTableTimestamp: Long = System.currentTimeMillis()): Unit = {
     val customTablePath = profileProvider.getCustomTablePath(tablePath)
     val customRefresher = profileProvider.getCustomRefresher(refresher)
 
     val cachedTable = new CachedTable(
-      preSignedUrlExpirationMs + System.currentTimeMillis(),
-      idToUrl,
+      if (preSignedUrlExpirationMs + lastQueryTableTimestamp - System.currentTimeMillis() <
+        refreshThresholdMs) {
+        // If there is a refresh, start counting from now.
+        preSignedUrlExpirationMs + System.currentTimeMillis()
+      } else {
+        // Otherwise, start counting from lastQueryTableTimestamp.
+        preSignedUrlExpirationMs + lastQueryTableTimestamp
+      },
+      idToUrl = if (preSignedUrlExpirationMs + lastQueryTableTimestamp - System.currentTimeMillis()
+        < refreshThresholdMs) {
+        // force a refresh upon register
+        customRefresher()
+      } else {
+        idToUrl
+      },
       refs,
       System.currentTimeMillis(),
       customRefresher
diff --git a/spark/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala b/spark/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala
index 1522e0c56..51c9be90c 100644
--- a/spark/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala
@@ -81,6 +81,22 @@ class CachedTableManagerSuite extends SparkFunSuite {
       intercept[IllegalStateException](manager.getPreSignedUrl(
         provider.getCustomTablePath("test-table-path3"), "id1"))
       }
+
+      manager.register(
+        "test-table-path4",
+        Map("id1" -> "url1", "id2" -> "url2"),
+        Seq(new WeakReference(ref)),
+        provider,
+        () => {
+          Map("id1" -> "url3", "id2" -> "url4")
+        },
+        -1
+      )
+      // We should get new urls immediately because it's refreshed upon register
+      assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path4"),
+        "id1")._1 == "url3")
+      assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path4"),
+        "id2")._1 == "url4")
     } finally {
       manager.stop()
     }
@@ -108,7 +124,8 @@ class CachedTableManagerSuite extends SparkFunSuite {
       Thread.sleep(1000)
       // We should remove the cached table when it's not accessed
      intercept[IllegalStateException](manager.getPreSignedUrl(
-        provider.getCustomTablePath("test-table-path"), "id1"))
+        provider.getCustomTablePath("test-table-path"), "id1")
+      )
     } finally {
       manager.stop()
    }
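
The core of this patch is the register-time arithmetic in PreSignedUrlCache.scala: given when the urls were generated (lastQueryTableTimestamp), how long they stay valid (preSignedUrlExpirationMs), and the refresh threshold (refreshThresholdMs), the manager either trusts the supplied urls or refreshes immediately. A minimal, self-contained sketch of that decision; the parameter names mirror the diff, but the object and method are illustrative, not the actual CachedTableManager API:

```scala
// Illustrative sketch of the register-time refresh decision; not the real code.
object RefreshDecision {
  // Returns (shouldRefreshNow, expirationTimestampMs).
  def decide(
      preSignedUrlExpirationMs: Long,
      refreshThresholdMs: Long,
      lastQueryTableTimestamp: Long,
      nowMs: Long): (Boolean, Long) = {
    val remainingMs = preSignedUrlExpirationMs + lastQueryTableTimestamp - nowMs
    if (remainingMs < refreshThresholdMs) {
      // Urls are close to expiring: refresh now and count the window from `now`.
      (true, preSignedUrlExpirationMs + nowMs)
    } else {
      // Urls are still fresh: count the window from when they were generated.
      (false, preSignedUrlExpirationMs + lastQueryTableTimestamp)
    }
  }

  def main(args: Array[String]): Unit = {
    val now = System.currentTimeMillis()
    // Urls generated an hour ago with a 1-hour validity window trigger a refresh...
    println(decide(3600 * 1000L, 60 * 1000L, now - 3600 * 1000L, now))
    // ...while urls generated just now do not.
    println(decide(3600 * 1000L, 60 * 1000L, now, now))
  }
}
```

This is also why the test above can pass `-1` as lastQueryTableTimestamp: it makes the remaining validity negative, so the register call refreshes unconditionally and returns "url3"/"url4".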
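The getResponse change in DeltaSharingClient.scala switches from one big IOUtils.toString followed by a regex split to streaming the NDJSON body line by line. A rough standalone sketch of that reading pattern, assuming a plain InputStream (the BoundedInputStream wrapping and HTTP plumbing from the diff are omitted, and the names here are illustrative):

```scala
import java.io.{BufferedReader, ByteArrayInputStream, InputStream, InputStreamReader}
import java.nio.charset.StandardCharsets.UTF_8

import scala.collection.mutable.ListBuffer

object NdjsonReader {
  // Reads an NDJSON stream into one string per line, without buffering the
  // whole body as a single string first.
  def readLines(input: InputStream): Seq[String] = {
    val reader = new BufferedReader(new InputStreamReader(input, UTF_8))
    val lines = ListBuffer[String]()
    try {
      var line = reader.readLine()
      while (line != null) {
        lines += line // typically one action per line
        line = reader.readLine()
      }
    } finally {
      reader.close()
    }
    lines.toList
  }

  def main(args: Array[String]): Unit = {
    val body = "{\"protocol\":{}}\n{\"metaData\":{}}\n{\"file\":{}}"
    readLines(new ByteArrayInputStream(body.getBytes(UTF_8))).foreach(println)
  }
}
```

Keeping the response as a Seq[String] also lets the error path cap what it logs (the "last 100 lines" truncation above) instead of dumping an arbitrarily large body into the exception message.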
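Two session confs are touched above: the new streaming url-refresh flag and the jsonPredicateHints flag, whose default flips to "true". A hypothetical snippet showing how either could be disabled per session; the keys and defaults come from the diff, the session setup is illustrative:

```scala
import org.apache.spark.sql.SparkSession

object ConfExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("delta-sharing-conf")
      .getOrCreate()
    // Streaming pre-signed url refresh is on by default; opt out explicitly.
    spark.conf.set("spark.delta.sharing.source.refreshPresignedUrls.enabled", "false")
    // Json predicate hints now also default to true; opt out the same way.
    spark.conf.set("spark.delta.sharing.jsonPredicateHints.enabled", "false")
    spark.stop()
  }
}
```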