Added S3 metadata reader to estimate the number of partitions based on a desired partition size in bytes.

Signed-off-by: Grisha Pomadchin <[email protected]>
pomadchin committed Jul 25, 2017
1 parent c9db248 commit 18c3176
Showing 5 changed files with 92 additions and 19 deletions.
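
The change in brief: the new S3GeoTiffMetadataReader lists the GeoTiff keys under a bucket/prefix, reads each file's TIFF tags through an S3RangeReader, and combines the average bytes per pixel with the number of read windows to estimate how many partitions are needed for a desired partition size in bytes. A minimal usage sketch, using only the API added in this commit (the bucket and prefix names are hypothetical):

import geotrellis.spark.io.s3.S3GeoTiffMetadataReader

// Point the reader at a hypothetical S3 location holding GeoTiffs.
val metadataReader = S3GeoTiffMetadataReader(
  bucket = "my-imagery-bucket",
  prefix = "scenes/2017/"
)

// Estimate how many partitions are needed for ~128 MB partitions and 256x256 windows.
// Returns None when no GeoTiffs are found or when the estimate rounds down to zero.
val estimated: Option[Int] =
  metadataReader.estimatePartitionsNumber(partitionBytes = 128L * 1024 * 1024, maxTileSize = Some(256))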

@@ -18,16 +18,15 @@ package geotrellis.spark.io.s3.testkit

import geotrellis.spark.io.s3._
import geotrellis.util.LazyLogging

import com.amazonaws.services.s3.model._
import com.amazonaws.services.s3.internal.AmazonS3ExceptionBuilder
import org.apache.commons.io.IOUtils

import java.io.ByteArrayInputStream
import java.util.concurrent.ConcurrentHashMap

import scala.collection.immutable.TreeMap
import scala.collection.JavaConverters._
import scala.collection.mutable

class MockS3Client() extends S3Client with LazyLogging {
import MockS3Client._
@@ -112,7 +111,18 @@ class MockS3Client() extends S3Client with LazyLogging {
}
}

def listKeys(listObjectsRequest: ListObjectsRequest): Seq[String] = ???
def listKeys(listObjectsRequest: ListObjectsRequest): Seq[String] = {
var listing: ObjectListing = null
val result = mutable.ListBuffer[String]()
do {
listing = listObjects(listObjectsRequest)
// avoid including "directories" in the input split, can cause 403 errors on GET
result ++= listing.getObjectSummaries.asScala.map(_.getKey).filterNot(_ endsWith "/")
listObjectsRequest.setMarker(listing.getNextMarker)
} while (listing.isTruncated)

result
}

def readBytes(getObjectRequest: GetObjectRequest): Array[Byte] = {
val obj = getObject(getObjectRequest)

@@ -54,7 +54,7 @@ class AmazonS3Client(s3client: AWSAmazonS3Client) extends S3Client {
listObjectsRequest.setMarker(listing.getNextMarker)
} while (listing.isTruncated)

result.toSeq
result
}

def getObject(getObjectRequest: GetObjectRequest): S3Object =

2 changes: 1 addition & 1 deletion s3/src/main/scala/geotrellis/spark/io/s3/S3Client.scala

@@ -54,7 +54,7 @@ trait S3Client extends LazyLogging {
.toList

// Empty listings cause malformed XML to be sent to AWS and lead to unhelpful exceptions
if (! listings.isEmpty) {
if (listings.nonEmpty) {
deleteObjects(bucket, listings)
if (listing.isTruncated) deleteListing(bucket, listNextBatchOfObjects(listing))
}

@@ -0,0 +1,63 @@
package geotrellis.spark.io.s3

import geotrellis.raster.io.geotiff.reader.TiffTagsReader
import geotrellis.raster.io.geotiff.tags.TiffTags
import geotrellis.spark.io.RasterReader
import geotrellis.spark.io.s3.util.S3RangeReader
import geotrellis.util.LazyLogging

import com.amazonaws.services.s3.model.ListObjectsRequest

case class S3GeoTiffMetadataReader(
bucket: String,
prefix: String,
getS3Client: () => S3Client = () => S3Client.DEFAULT,
delimiter: Option[String] = None
) extends LazyLogging {
lazy val tiffTags: List[TiffTags] = {
val s3Client = getS3Client()

val listObjectsRequest =
delimiter
.fold(new ListObjectsRequest(bucket, prefix, null, null, null))(new ListObjectsRequest(bucket, prefix, null, _, null))

s3Client
.listKeys(listObjectsRequest)
.map(key => TiffTagsReader.read(S3RangeReader(bucket, key, s3Client)))
.toList
}

lazy val averagePixelSize: Option[Int] =
if(tiffTags.nonEmpty) {
Some((tiffTags.map(_.bytesPerPixel.toLong).sum / tiffTags.length).toInt)
} else {
logger.error(s"No tiff tags in $bucket/$prefix")
None
}

def windowsCount(maxTileSize: Option[Int] = None): Int =
tiffTags
.flatMap { tiffTag => RasterReader.listWindows(tiffTag.cols, tiffTag.rows, maxTileSize) }
.length

def estimatePartitionsNumber(partitionBytes: Long, maxTileSize: Option[Int] = None): Option[Int] = {
(maxTileSize, averagePixelSize) match {
case (Some(tileSize), Some(pixelSize)) =>
val numPartitions = (tileSize * pixelSize * windowsCount(maxTileSize) / partitionBytes).toInt
logger.info(s"Estimated partitions number: $numPartitions")
if (numPartitions > 0) Some(numPartitions)
else None
case _ =>
logger.error("Can't estimate partitions number")
None
}
}
}

object S3GeoTiffMetadataReader {
def apply(
bucket: String,
prefix: String,
options: S3GeoTiffRDD.Options
): S3GeoTiffMetadataReader = S3GeoTiffMetadataReader(bucket, prefix, options.getS3Client, options.delimiter)
}
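
For reference, the estimate above mirrors the expression (tileSize * pixelSize * windowsCount(maxTileSize) / partitionBytes).toInt, so the result depends on the requested window size, the average bytes per pixel, and the total window count. A small sketch of the same arithmetic with hypothetical numbers (none of them come from the commit):

// Hypothetical inputs, for illustration only.
val tileSize       = 1024                // options.maxTileSize
val pixelSize      = 8                   // average bytes per pixel across the listed GeoTiffs
val windows        = 200000L             // total windows produced by RasterReader.listWindows
val partitionBytes = 128L * 1024 * 1024  // desired partition size in bytes

// (1024 * 8 * 200000) / 134217728 ≈ 12 partitions
val numPartitions = (tileSize * pixelSize * windows / partitionBytes).toInt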
28 changes: 14 additions & 14 deletions s3/src/main/scala/geotrellis/spark/io/s3/S3GeoTiffRDD.scala

@@ -22,7 +22,7 @@ import geotrellis.raster.io.geotiff.tags.TiffTags
import geotrellis.spark._
import geotrellis.spark.io.RasterReader
import geotrellis.spark.io.s3.util.S3RangeReader
import geotrellis.util.{StreamingByteReader, LazyLogging}
import geotrellis.util.{LazyLogging, StreamingByteReader}
import geotrellis.vector._

import org.apache.hadoop.conf.Configuration
@@ -50,9 +50,11 @@ object S3GeoTiffRDD extends LazyLogging {
* @param maxTileSize Maximum allowed size of each tile in the output RDD.
* May result in one input GeoTiff being split amongst multiple records if it exceeds this size.
* If no maximum tile size is specified, then each file is read fully.
* 1024 by default.
* @param numPartitions How many partitions Spark should create when it repartitions the data.
* @param partitionBytes Desired partition size in bytes; at least one item per partition will be assigned.
* This option is incompatible with the maxTileSize option.
* 128 MB by default.
* @param chunkSize How many bytes should be read in at a time.
* @param delimiter Delimiter to use for S3 object listings. See
* @param getS3Client A function to instantiate an S3Client. Must be serializable.
@@ -62,9 +64,9 @@
crs: Option[CRS] = None,
timeTag: String = GEOTIFF_TIME_TAG_DEFAULT,
timeFormat: String = GEOTIFF_TIME_FORMAT_DEFAULT,
maxTileSize: Option[Int] = None,
maxTileSize: Option[Int] = Some(1024),
numPartitions: Option[Int] = None,
partitionBytes: Option[Long] = None,
partitionBytes: Option[Long] = Some(128l * 1024 * 1024),
chunkSize: Option[Int] = None,
delimiter: Option[String] = None,
getS3Client: () => S3Client = () => S3Client.DEFAULT
@@ -105,9 +107,10 @@ object S3GeoTiffRDD extends LazyLogging {
(implicit sc: SparkContext, rr: RasterReader[Options, (I, V)]): RDD[(K, V)] = {

val conf = configuration(bucket, prefix, options)
lazy val sourceMetadata = S3GeoTiffMetadataReader(bucket, prefix, options)

options.maxTileSize match {
case Some(tileSize) =>
case Some(_) =>
val objectRequestsToDimensions: RDD[(GetObjectRequest, (Int, Int))] =
sc.newAPIHadoopRDD(
conf,
@@ -116,7 +119,7 @@
classOf[TiffTags]
).mapValues { tiffTags => (tiffTags.cols, tiffTags.rows) }

apply[I, K, V](objectRequestsToDimensions, uriToKey, options)
apply[I, K, V](objectRequestsToDimensions, uriToKey, options, sourceMetadata)
case None =>
sc.newAPIHadoopRDD(
conf,
@@ -151,7 +154,7 @@
* @param uriToKey Function to transform the input key based on the URI information.
* @param options An instance of [[Options]] that contains any user defined or default settings.
*/
def apply[I, K, V](objectRequestsToDimensions: RDD[(GetObjectRequest, (Int, Int))], uriToKey: (URI, I) => K, options: Options)
def apply[I, K, V](objectRequestsToDimensions: RDD[(GetObjectRequest, (Int, Int))], uriToKey: (URI, I) => K, options: Options, sourceMetadata: => S3GeoTiffMetadataReader)
(implicit rr: RasterReader[Options, (I, V)]): RDD[(K, V)] = {

val windows =
@@ -167,14 +170,11 @@
case None =>
options.partitionBytes match {
case Some(byteCount) =>
// Because we do not have cell type information, we cannot
// perform the necessary estimates for the partition bytes.
logger.warn(
s"${classOf[Options].getName}.partitionBytes set with maxTileSize, " +
"cannot perform partitioning based on byte count. Option ignored. " +
"Use numPartitions instead.")
windows
case None =>
sourceMetadata.estimatePartitionsNumber(byteCount, options.maxTileSize) match {
case Some(numPartitions) if numPartitions != windows.partitions.length => windows.repartition(numPartitions)
case _ => windows
}
case _ =>
windows
}
}
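
With the new defaults (maxTileSize = Some(1024), partitionBytes = Some(128 MB)), the windowed RDD is now repartitioned to the estimated partition count whenever that estimate differs from the number of windows. A hedged usage sketch, assuming the existing S3GeoTiffRDD.spatial entry point and a hypothetical bucket/prefix:

import geotrellis.raster.Tile
import geotrellis.spark.io.s3.S3GeoTiffRDD
import geotrellis.vector.ProjectedExtent
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

implicit val sc: SparkContext = ???  // provided by the surrounding Spark application

val options = S3GeoTiffRDD.Options(
  maxTileSize    = Some(512),                // window size; Some(1024) is now the default
  partitionBytes = Some(256L * 1024 * 1024)  // target ~256 MB partitions instead of the 128 MB default
)

// Hypothetical S3 location; yields (ProjectedExtent, Tile) records windowed to maxTileSize,
// repartitioned according to the byte-count estimate when it differs from the window count.
val rdd: RDD[(ProjectedExtent, Tile)] =
  S3GeoTiffRDD.spatial("my-imagery-bucket", "scenes/2017/", options)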
