Add tests and fix bug
sryza committed Jan 26, 2015
1 parent 0d504f1 commit 864514b
Showing 5 changed files with 111 additions and 19 deletions.
@@ -133,10 +133,9 @@ class SparkHadoopUtil extends Logging {
* statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
* Returns None if the required method can't be found.
*/
-private[spark] def getFSBytesReadOnThreadCallback(conf: Configuration)
-: Option[() => Long] = {
+private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = {
try {
-val threadStats = getFileSystemThreadStatistics(conf)
+val threadStats = getFileSystemThreadStatistics()
val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesRead = f()
@@ -156,10 +155,9 @@ class SparkHadoopUtil extends Logging {
* statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
* Returns None if the required method can't be found.
*/
-private[spark] def getFSBytesWrittenOnThreadCallback(conf: Configuration)
-: Option[() => Long] = {
+private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = {
try {
-val threadStats = getFileSystemThreadStatistics(conf)
+val threadStats = getFileSystemThreadStatistics()
val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesWritten = f()
@@ -172,7 +170,7 @@
}
}

-private def getFileSystemThreadStatistics(conf: Configuration): Seq[AnyRef] = {
+private def getFileSystemThreadStatistics(): Seq[AnyRef] = {
val stats = FileSystem.getAllStatistics()
stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
}
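For context, a minimal sketch (not part of this commit) of how the now zero-argument bytes-read callback can be exercised. Everything below other than SparkHadoopUtil and getFSBytesReadOnThreadCallback is an illustrative assumption: the package, object, and file path are made up, and because the method is private[spark] the sketch has to live somewhere under the org.apache.spark package. The callback is None on Hadoop versions before 2.5 (HADOOP-10688), so callers must handle the Option.

// Sketch only; the package, object name, and input path are assumptions.
package org.apache.spark.sketches

import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.deploy.SparkHadoopUtil

object BytesReadCallbackSketch {
  def main(args: Array[String]): Unit = {
    // Obtain the thread-level bytes-read callback once; it is None on Hadoop < 2.5,
    // where per-thread FileSystem statistics are not available.
    val bytesReadCallback: Option[() => Long] =
      SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()

    // Any FileSystem reads performed on this thread after the callback was created
    // are counted. The local path below is purely illustrative.
    val fs = FileSystem.get(new URI("file:///"), new Configuration())
    val in = fs.open(new Path("/tmp/some-input.txt"))
    try {
      val buf = new Array[Byte](4096)
      while (in.read(buf) != -1) {}
    } finally {
      in.close()
    }

    bytesReadCallback match {
      case Some(getBytesRead) => println(s"bytes read on this thread: ${getBytesRead()}")
      case None => println("per-thread FileSystem read statistics not available")
    }
  }
}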
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -222,7 +222,7 @@ class HadoopRDD[K, V](
val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
split.inputSplit.value match {
case _: FileSplit | _: CombineFileSplit =>
-SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(jobConf)
+SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
case _ => None
}
}
@@ -255,7 +255,8 @@ class HadoopRDD[K, V](
reader.close()
if (bytesReadCallback.isDefined) {
inputMetrics.updateBytesRead()
-} else if (split.inputSplit.value.isInstanceOf[FileSplit]) {
+} else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
+split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
7 changes: 4 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -34,7 +34,7 @@ import org.apache.spark.Logging
import org.apache.spark.Partition
import org.apache.spark.SerializableWritable
import org.apache.spark.{SparkContext, TaskContext}
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.Utils
@@ -117,7 +117,7 @@ class NewHadoopRDD[K, V](
val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
split.serializableHadoopSplit.value match {
case _: FileSplit | _: CombineFileSplit =>
-SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(conf)
+SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
case _ => None
}
}
@@ -163,7 +163,8 @@ class NewHadoopRDD[K, V](
reader.close()
if (bytesReadCallback.isDefined) {
inputMetrics.updateBytesRead()
-} else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
+} else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
+split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
@@ -990,7 +990,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)

-val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)

val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
try {
@@ -1061,7 +1061,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt

-val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)

writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
@@ -1086,9 +1086,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.commitJob()
}

-private def initHadoopOutputMetrics(context: TaskContext, config: Configuration)
-: (OutputMetrics, Option[() => Long]) = {
-val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(config)
+private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = {
+val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
if (bytesWrittenCallback.isDefined) {
context.taskMetrics.outputMetrics = Some(outputMetrics)
@@ -26,7 +26,16 @@ import org.scalatest.FunSuite
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.{LongWritable, Text}
-import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
+import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, JobConf,
+LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, Reporter,
+TextInputFormat => OldTextInputFormat}
+import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat,
+CombineFileSplit => OldCombineFileSplit, CombineFileRecordReader => OldCombineFileRecordReader}
+import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader,
+TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat,
+CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit,
+FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat}

import org.apache.spark.SharedSparkContext
import org.apache.spark.deploy.SparkHadoopUtil
@@ -202,7 +211,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext {
val fs = FileSystem.getLocal(new Configuration())
val outPath = new Path(fs.getWorkingDirectory, "outdir")

-if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(fs.getConf).isDefined) {
+if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) {
val taskBytesWritten = new ArrayBuffer[Long]()
sc.addSparkListener(new SparkListener() {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
@@ -225,4 +234,88 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext {
}
}
}

test("input metrics with old CombineFileInputFormat") {
val bytesRead = runAndReturnBytesRead {
sc.hadoopFile(tmpFilePath, classOf[OldCombineTextInputFormat], classOf[LongWritable],
classOf[Text], 2).count()
}
assert(bytesRead >= tmpFile.length())
}

test("input metrics with new CombineFileInputFormat") {
val bytesRead = runAndReturnBytesRead {
sc.newAPIHadoopFile(tmpFilePath, classOf[NewCombineTextInputFormat], classOf[LongWritable],
classOf[Text], new Configuration()).count()
}
assert(bytesRead >= tmpFile.length())
}
}

/**
* Hadoop 2 has a version of this, but we can't use it for backwards compatibility
*/
class OldCombineTextInputFormat extends OldCombineFileInputFormat[LongWritable, Text] {
override def getRecordReader(split: OldInputSplit, conf: JobConf, reporter: Reporter)
: OldRecordReader[LongWritable, Text] = {
new OldCombineFileRecordReader[LongWritable, Text](conf,
split.asInstanceOf[OldCombineFileSplit], reporter, classOf[OldCombineTextRecordReaderWrapper]
.asInstanceOf[Class[OldRecordReader[LongWritable, Text]]])
}
}

class OldCombineTextRecordReaderWrapper(
split: OldCombineFileSplit,
conf: Configuration,
reporter: Reporter,
idx: Integer) extends OldRecordReader[LongWritable, Text] {

val fileSplit = new OldFileSplit(split.getPath(idx),
split.getOffset(idx),
split.getLength(idx),
split.getLocations())

val delegate: OldLineRecordReader = new OldTextInputFormat().getRecordReader(fileSplit,
conf.asInstanceOf[JobConf], reporter).asInstanceOf[OldLineRecordReader]

override def next(key: LongWritable, value: Text): Boolean = delegate.next(key, value)
override def createKey(): LongWritable = delegate.createKey()
override def createValue(): Text = delegate.createValue()
override def getPos(): Long = delegate.getPos
override def close(): Unit = delegate.close()
override def getProgress(): Float = delegate.getProgress
}

/**
* Hadoop 2 has a version of this, but we can't use it for backwards compatibility
*/
class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] {
def createRecordReader(split: NewInputSplit, context: TaskAttemptContext)
: NewRecordReader[LongWritable, Text] = {
new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit],
context, classOf[NewCombineTextRecordReaderWrapper])
}
}

class NewCombineTextRecordReaderWrapper(
split: NewCombineFileSplit,
context: TaskAttemptContext,
idx: Integer) extends NewRecordReader[LongWritable, Text] {

val fileSplit = new NewFileSplit(split.getPath(idx),
split.getOffset(idx),
split.getLength(idx),
split.getLocations())

val delegate = new NewTextInputFormat().createRecordReader(fileSplit, context)

override def initialize(split: NewInputSplit, context: TaskAttemptContext): Unit = {
delegate.initialize(fileSplit, context)
}

override def nextKeyValue(): Boolean = delegate.nextKeyValue()
override def getCurrentKey(): LongWritable = delegate.getCurrentKey
override def getCurrentValue(): Text = delegate.getCurrentValue
override def getProgress(): Float = delegate.getProgress
override def close(): Unit = delegate.close()
}
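
The two new tests call a runAndReturnBytesRead helper that sits outside the hunks shown above. A plausible sketch of such a helper follows, assuming the suite's sc from SharedSparkContext and the same listener pattern already used in the output-metrics test earlier in this file; the body is an assumption, not the suite's verbatim code.

// Sketch only: would live inside InputOutputMetricsSuite, which already uses
// ArrayBuffer, SparkListener and SparkListenerTaskEnd above; `sc` comes from
// SharedSparkContext.
private def runAndReturnBytesRead(job: => Unit): Long = {
  val taskBytesRead = new ArrayBuffer[Long]()
  sc.addSparkListener(new SparkListener() {
    override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
      taskBytesRead += taskEnd.taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L)
    }
  })

  job

  // Listener events are posted asynchronously; a real test would wait for the
  // listener bus to drain before summing (simplified here).
  taskBytesRead.sum
}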
