Skip to content

Commit

Permalink
Merge branch 'dialberg_delta_sharing_python_client_pull_request_12M23…
Browse files Browse the repository at this point in the history
…' of https://github.com/dialberg/delta-sharing into dialberg_delta_sharing_python_client_pull_request_12M23
  • Loading branch information
dialberg committed May 12, 2023
2 parents 1159228 + 24771a3 commit c0c5522
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@

package io.delta.sharing.spark

import java.io.{BufferedReader, InputStream, InputStreamReader}
import java.net.{URL, URLEncoder}
import java.nio.charset.StandardCharsets.UTF_8
import java.sql.Timestamp
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter.ISO_DATE_TIME

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

import org.apache.commons.io.IOUtils
import org.apache.commons.io.input.BoundedInputStream
import org.apache.hadoop.util.VersionInfo
import org.apache.http.{HttpHeaders, HttpHost, HttpStatus}
import org.apache.http.client.config.RequestConfig
Expand Down Expand Up @@ -188,7 +190,7 @@ private[spark] class DeltaSharingRestClient(
val target =
getTargetUrl(s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/" +
s"$encodedTableName/version$encodedParam")
val (version, _) = getResponse(new HttpGet(target), true)
val (version, _) = getResponse(new HttpGet(target), true, true)
version.getOrElse {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
}
Expand Down Expand Up @@ -324,30 +326,35 @@ private[spark] class DeltaSharingRestClient(
}

private def getNDJson(target: String, requireVersion: Boolean = true): (Long, Seq[String]) = {
val (version, response) = getResponse(new HttpGet(target))
val (version, lines) = getResponse(new HttpGet(target))
version.getOrElse {
if (requireVersion) {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
} else {
0L
}
} -> response.split("[\n\r]+")
} -> lines
}

private def getNDJson[T: Manifest](target: String, data: T): (Long, Seq[String]) = {
val httpPost = new HttpPost(target)
val json = JsonUtils.toJson(data)
httpPost.setHeader("Content-type", "application/json")
httpPost.setEntity(new StringEntity(json, UTF_8))
val (version, response) = getResponse(httpPost)
val (version, lines) = getResponse(httpPost)
version.getOrElse {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
} -> response.split("[\n\r]+")
} -> lines
}

private def getJson[R: Manifest](target: String): R = {
val (_, response) = getResponse(new HttpGet(target))
JsonUtils.fromJson[R](response)
val (_, response) = getResponse(new HttpGet(target), false, true)
if (response.size != 1) {
throw new IllegalStateException(
"Unexpected response for target: " + target + ", response=" + response
)
}
JsonUtils.fromJson[R](response(0))
}

private def getHttpHost(endpoint: String): HttpHost = {
Expand Down Expand Up @@ -393,11 +400,17 @@ private[spark] class DeltaSharingRestClient(
/**
* Send the http request and return the table version in the header if any, and the response
* content.
*
* The response can be:
* - empty if allowNoContent is true.
* - single string, if fetchAsOneString is true.
* - multi-line response (typically, one per action). This is the default.
*/
private def getResponse(
httpRequest: HttpRequestBase,
allowNoContent: Boolean = false
): (Option[Long], String) =
allowNoContent: Boolean = false,
fetchAsOneString: Boolean = false
): (Option[Long], Seq[String]) = {
RetryUtils.runWithExponentialBackoff(numRetries) {
val profile = profileProvider.getProfile
val response = client.execute(
Expand All @@ -408,12 +421,26 @@ private[spark] class DeltaSharingRestClient(
try {
val status = response.getStatusLine()
val entity = response.getEntity()
val body = if (entity == null) {
""
val lines = if (entity == null) {
List("")
} else {
val input = entity.getContent()
try {
IOUtils.toString(input, UTF_8)
if (fetchAsOneString) {
Seq(IOUtils.toString(input, UTF_8))
} else {
val reader = new BufferedReader(
new InputStreamReader(new BoundedInputStream(input), UTF_8)
)
var line: Option[String] = None
val lineBuffer = ListBuffer[String]()
while ({
line = Option(reader.readLine()); line.isDefined
}) {
lineBuffer += line.get
}
lineBuffer.toList
}
} finally {
input.close()
}
Expand All @@ -427,15 +454,18 @@ private[spark] class DeltaSharingRestClient(
additionalErrorInfo = s"It may be caused by an expired token as it has expired " +
s"at ${profile.expirationTime}"
}
// Only show the last 100 lines in the error to keep it contained.
val responseToShow = lines.drop(lines.size - 100).mkString("\n")
throw new UnexpectedHttpStatus(
s"HTTP request failed with status: $status $body. $additionalErrorInfo",
s"HTTP request failed with status: $status $responseToShow. $additionalErrorInfo",
statusCode)
}
Option(response.getFirstHeader("Delta-Table-Version")).map(_.getValue.toLong) -> body
Option(response.getFirstHeader("Delta-Table-Version")).map(_.getValue.toLong) -> lines
} finally {
response.close()
}
}
}

// Add SparkStructuredStreaming in the USER_AGENT header, in order for the delta sharing server
// to recognize the request for streaming, and take corresponding actions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ case class DeltaSharingSource(

private val tableId = initSnapshot.metadata.id

private val refreshPresignedUrls = spark.sessionState.conf.getConfString(
"spark.delta.sharing.source.refreshPresignedUrls.enabled", "true").toBoolean

// Records until which offset the delta sharing source has been processing the table files.
private var previousOffset: DeltaSharingSourceOffset = null

Expand All @@ -149,6 +152,11 @@ case class DeltaSharingSource(
// a variable to be used by the CachedTableManager to refresh the presigned urls if the query
// runs for a long time.
private var latestRefreshFunc = () => { Map.empty[String, String] }
// The latest timestamp in millisecond, records the time of the last rpc sent to the server to
// fetch the pre-signed urls.
// This is used to track whether the pre-signed urls stored in sortedFetchedFiles are going to
// expire and need a refresh.
private var lastQueryTableTimestamp: Long = -1

// Check the latest table version from the delta sharing server through the client.getTableVersion
// RPC. Adding a minimum interval of QUERY_TABLE_VERSION_INTERVAL_MILLIS between two consecutive
Expand Down Expand Up @@ -231,6 +239,7 @@ case class DeltaSharingSource(
fromIndex: Long,
isStartingVersion: Boolean,
currentLatestVersion: Long): Unit = {
lastQueryTableTimestamp = System.currentTimeMillis()
if (isStartingVersion) {
// If isStartingVersion is true, it means to fetch the snapshot at the fromVersion, which may
// include table changes from previous versions.
Expand Down Expand Up @@ -307,6 +316,7 @@ case class DeltaSharingSource(
fromVersion: Long,
fromIndex: Long,
currentLatestVersion: Long): Unit = {
lastQueryTableTimestamp = System.currentTimeMillis()
val tableFiles = deltaLog.client.getCDFFiles(
deltaLog.table, Map(DeltaSharingOptions.CDF_START_VERSION -> fromVersion.toString), true)
latestRefreshFunc = () => {
Expand Down Expand Up @@ -459,6 +469,51 @@ case class DeltaSharingSource(
endOffset: DeltaSharingSourceOffset): DataFrame = {
maybeGetFileChanges(startVersion, startIndex, isStartingVersion)

if (refreshPresignedUrls &&
(CachedTableManager.INSTANCE.preSignedUrlExpirationMs + lastQueryTableTimestamp -
System.currentTimeMillis() < CachedTableManager.INSTANCE.refreshThresholdMs)) {
// force a refresh if needed.
lastQueryTableTimestamp = System.currentTimeMillis()
val newIdToUrl = latestRefreshFunc()
sortedFetchedFiles = sortedFetchedFiles.map { indexedFile =>
IndexedFile(
version = indexedFile.version,
index = indexedFile.index,
add = if (indexedFile.add == null) {
null
} else {
val newUrl = newIdToUrl.getOrElse(
indexedFile.add.id,
throw new IllegalStateException(s"cannot find url for id ${indexedFile.add.id} " +
s"when refreshing table ${deltaLog.path}")
)
indexedFile.add.copy(url = newUrl)
},
remove = if (indexedFile.remove == null) {
null
} else {
val newUrl = newIdToUrl.getOrElse(
indexedFile.remove.id,
throw new IllegalStateException(s"cannot find url for id ${indexedFile.remove.id} " +
s"when refreshing table ${deltaLog.path}")
)
indexedFile.remove.copy(url = newUrl)
},
cdc = if (indexedFile.cdc == null) {
null
} else {
val newUrl = newIdToUrl.getOrElse(
indexedFile.cdc.id,
throw new IllegalStateException(s"cannot find url for id ${indexedFile.cdc.id} " +
s"when refreshing table ${deltaLog.path}")
)
indexedFile.cdc.copy(url = newUrl)
},
isLast = indexedFile.isLast
)
}
}

val fileActions = sortedFetchedFiles.takeWhile {
case IndexedFile(version, index, _, _, _, _) =>
version < endOffset.tableVersion ||
Expand Down Expand Up @@ -545,7 +600,8 @@ case class DeltaSharingSource(
removeFiles,
schema,
isStreaming = true,
latestRefreshFunc
latestRefreshFunc,
lastQueryTableTimestamp
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ object DeltaSharingCDFReader {
removeFiles: Seq[RemoveFile],
schema: StructType,
isStreaming: Boolean,
refresher: () => Map[String, String]): DataFrame = {
refresher: () => Map[String, String],
lastQueryTableTimestamp: Long = System.currentTimeMillis()
): DataFrame = {
val dfs = ListBuffer[DataFrame]()
val refs = ListBuffer[WeakReference[AnyRef]]()

Expand All @@ -92,7 +94,8 @@ object DeltaSharingCDFReader {
getIdToUrl(addFiles, cdfFiles, removeFiles),
refs,
params.profileProvider,
refresher
refresher,
lastQueryTableTimestamp
)

dfs.reduce((df1, df2) => df1.unionAll(df2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ private[sharing] abstract class RemoteDeltaFileIndexBase(
// not perform json predicate based filtering.
protected def convertToJsonPredicate(partitionFilters: Seq[Expression]) : Option[String] = {
if (!params.spark.sessionState.conf.getConfString(
"spark.delta.sharing.jsonPredicateHints.enabled", "false").toBoolean) {
"spark.delta.sharing.jsonPredicateHints.enabled", "true").toBoolean) {
return None
}
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ class CachedTable(
val refresher: () => Map[String, String])

class CachedTableManager(
preSignedUrlExpirationMs: Long,
val preSignedUrlExpirationMs: Long,
refreshCheckIntervalMs: Long,
refreshThresholdMs: Long,
val refreshThresholdMs: Long,
expireAfterAccessMs: Long) extends Logging {

private val cache = new java.util.concurrent.ConcurrentHashMap[String, CachedTable]()
Expand Down Expand Up @@ -134,19 +134,35 @@ class CachedTableManager(
* signed url cache of this table form the cache.
* @param profileProvider a profile Provider that can provide customized refresher function.
* @param refresher A function to re-generate pre signed urls for the table.
* @param lastQueryTableTimestamp A timestamp to indicate the last time the idToUrl mapping is
* generated, to refresh the urls in time based on it.
*/
def register(
tablePath: String,
idToUrl: Map[String, String],
refs: Seq[WeakReference[AnyRef]],
profileProvider: DeltaSharingProfileProvider,
refresher: () => Map[String, String]): Unit = {
refresher: () => Map[String, String],
lastQueryTableTimestamp: Long = System.currentTimeMillis()): Unit = {
val customTablePath = profileProvider.getCustomTablePath(tablePath)
val customRefresher = profileProvider.getCustomRefresher(refresher)

val cachedTable = new CachedTable(
preSignedUrlExpirationMs + System.currentTimeMillis(),
idToUrl,
if (preSignedUrlExpirationMs + lastQueryTableTimestamp - System.currentTimeMillis() <
refreshThresholdMs) {
// If there is a refresh, start counting from now.
preSignedUrlExpirationMs + System.currentTimeMillis()
} else {
// Otherwise, start counting from lastQueryTableTimestamp.
preSignedUrlExpirationMs + lastQueryTableTimestamp
},
idToUrl = if (preSignedUrlExpirationMs + lastQueryTableTimestamp - System.currentTimeMillis()
< refreshThresholdMs) {
// force a refresh upon register
customRefresher()
} else {
idToUrl
},
refs,
System.currentTimeMillis(),
customRefresher
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ class CachedTableManagerSuite extends SparkFunSuite {
intercept[IllegalStateException](manager.getPreSignedUrl(
provider.getCustomTablePath("test-table-path3"), "id1"))
}

manager.register(
"test-table-path4",
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(ref)),
provider,
() => {
Map("id1" -> "url3", "id2" -> "url4")
},
-1
)
// We should get new urls immediately because it's refreshed upon register
assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path4"),
"id1")._1 == "url3")
assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path4"),
"id2")._1 == "url4")
} finally {
manager.stop()
}
Expand Down Expand Up @@ -108,7 +124,8 @@ class CachedTableManagerSuite extends SparkFunSuite {
Thread.sleep(1000)
// We should remove the cached table when it's not accessed
intercept[IllegalStateException](manager.getPreSignedUrl(
provider.getCustomTablePath("test-table-path"), "id1"))
provider.getCustomTablePath("test-table-path"), "id1")
)
} finally {
manager.stop()
}
Expand Down

0 comments on commit c0c5522

Please sign in to comment.