[SPARK-11622][MLLIB] Make LibSVMRelation extends HadoopFsRelation and… #9595

Closed · wants to merge 9 commits
Changes from all commits
mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -17,16 +17,21 @@

 package org.apache.spark.ml.source.libsvm

+import java.io.IOException
+
 import com.google.common.base.Objects
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{NullWritable, Text}
+import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

-import org.apache.spark.Logging
 import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types._

 /**
  * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
@@ -37,14 +42,10 @@
  */
 private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
     (@transient val sqlContext: SQLContext)
-  extends BaseRelation with TableScan with Logging with Serializable {
+  extends HadoopFsRelation with Serializable {

-  override def schema: StructType = StructType(
-    StructField("label", DoubleType, nullable = false) ::
-      StructField("features", new VectorUDT(), nullable = false) :: Nil
-  )
-
-  override def buildScan(): RDD[Row] = {
+  override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus])
+    : RDD[Row] = {
     val sc = sqlContext.sparkContext
     val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
     val sparse = vectorType == "sparse"
@@ -66,8 +67,63 @@
     case _ =>
       false
   }
+
+  override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job):
+    _root_.org.apache.spark.sql.sources.OutputWriterFactory = {
Comment (Contributor): Can we use just OutputWriterFactory here, since org.apache.spark.sql.sources._ is already imported?
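Concretely, with the wildcard import in scope (and assuming org.apache.hadoop.mapreduce.Job were imported as well, which the diff does not do), the method could be declared without the _root_ qualifiers; a sketch of the suggestion:

// Hedged sketch of the reviewer's suggestion; the Job import is an assumption.
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
  new OutputWriterFactory {
    override def newInstance(
        path: String,
        dataSchema: StructType,
        context: TaskAttemptContext): OutputWriter = {
      new LibSVMOutputWriter(path, dataSchema, context)
    }
  }
}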

+    new OutputWriterFactory {
+      override def newInstance(
+          path: String,
+          dataSchema: StructType,
+          context: TaskAttemptContext): OutputWriter = {
+        new LibSVMOutputWriter(path, dataSchema, context)
+      }
+    }
+  }

+  override def paths: Array[String] = Array(path)
+
+  override def dataSchema: StructType = StructType(
+    StructField("label", DoubleType, nullable = false) ::
+      StructField("features", new VectorUDT(), nullable = false) :: Nil)
+}


+private[libsvm] class LibSVMOutputWriter(
+    path: String,
+    dataSchema: StructType,
+    context: TaskAttemptContext)
+  extends OutputWriter {
+
+  private[this] val buffer = new Text()
+
+  private val recordWriter: RecordWriter[NullWritable, Text] = {
+    new TextOutputFormat[NullWritable, Text]() {
+      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+        val configuration = context.getConfiguration
+        val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID")
+        val taskAttemptId = context.getTaskAttemptID
+        val split = taskAttemptId.getTaskID.getId
+        new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+      }
+    }.getRecordWriter(context)
+  }
+
+  override def write(row: Row): Unit = {
+    val label = row.get(0)
+    val vector = row.get(1).asInstanceOf[Vector]
+    val sb = new StringBuilder(label.toString)
+    vector.foreachActive { case (i, v) =>
+      sb += ' '
+      sb ++= s"${i + 1}:$v"
+    }
+    buffer.set(sb.mkString)
+    recordWriter.write(NullWritable.get(), buffer)
+  }
+
+  override def close(): Unit = {
+    recordWriter.close(context)
+  }
+}
 /**
  * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]].
  * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and
@@ -99,16 +155,32 @@
  * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
  */
 @Since("1.6.0")
-class DefaultSource extends RelationProvider with DataSourceRegister {
+class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {

   @Since("1.6.0")
   override def shortName(): String = "libsvm"

-  @Since("1.6.0")
-  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
-    : BaseRelation = {
-    val path = parameters.getOrElse("path",
-      throw new IllegalArgumentException("'path' must be specified"))
+  private def verifySchema(dataSchema: StructType): Unit = {
+    if (dataSchema.size != 2 ||
Comment (Contributor): It may be necessary to attach an @Since annotation to verifySchema, e.g. @Since("1.6.0").

Comment (Contributor, follow-up): I'm sorry, it's private. Never mind.

+      (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
+        || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
+      throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
+    }
+  }

+  override def createRelation(
+      sqlContext: SQLContext,
+      paths: Array[String],
+      dataSchema: Option[StructType],
+      partitionColumns: Option[StructType],
+      parameters: Map[String, String]): HadoopFsRelation = {
+    val path = if (paths.length == 1) paths(0)
+      else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data")
+      else throw new IOException("Multiple input paths are not supported for libsvm data")
+    if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) {
+      throw new IOException("Partition is not supported for libsvm data")
+    }
+    dataSchema.foreach(verifySchema(_))
+    val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
+    val vectorType = parameters.getOrElse("vectorType", "sparse")
+    new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
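For context, the reworked source reads and writes through the standard DataFrame API. A minimal read sketch (not part of the diff; the path and option values are illustrative):

// Hedged sketch: "libsvm" is the shortName registered by DefaultSource;
// the path and option values below are illustrative.
val df = sqlContext.read.format("libsvm")
  .option("numFeatures", "780")    // optional; defaults to "-1"
  .option("vectorType", "dense")   // optional; defaults to "sparse"
  .load("data/mllib/sample_libsvm_data.txt")
df.printSchema()                   // label: double, features: vector

On the write side, LibSVMOutputWriter.write emits one record per line: the label, then space-separated index:value pairs for the active vector entries, with indices shifted to 1-based (the `${i + 1}:$v` in the code above). For example, the row (1.0, Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) is written as the line:

1.0 1:1.0 3:2.0 5:3.0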
mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -17,14 +17,15 @@

 package org.apache.spark.ml.source.libsvm

-import java.io.File
+import java.io.{File, IOException}

 import com.google.common.base.Charsets
 import com.google.common.io.Files

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.util.Utils

 class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {

@@ -82,4 +83,24 @@
     val v = row1.getAs[SparseVector](1)
     assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }
+
+  test("write libsvm data and read it again") {
+    val df = sqlContext.read.format("libsvm").load(path)
+    val tempDir2 = Utils.createTempDir()
+    val writepath = tempDir2.toURI.toString
+    df.write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
+
+    val df2 = sqlContext.read.format("libsvm").load(writepath)
+    val row1 = df2.first()
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+  }
+
+  test("write libsvm data failed due to invalid schema") {
+    val df = sqlContext.read.format("text").load(path)
+    val e = intercept[IOException] {
+      df.write.format("libsvm").save(path + "_2")
+    }
+    assert(e.getMessage.contains("Illegal schema for libsvm data"))
+  }
 }
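A note on the invalid-schema test above: the text source yields a single string column, so verifySchema rejects the write on the column count before any type check runs. A sketch of the two schemas being compared (illustrative; the text source's "value" column is standard Spark behavior, not part of this diff):

// What the "text" source provides vs. what the libsvm writer requires.
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types._

val textSchema = StructType(StructField("value", StringType, nullable = true) :: Nil)  // size 1: rejected
val libsvmSchema = StructType(
  StructField("label", DoubleType, nullable = false) ::
    StructField("features", new VectorUDT(), nullable = false) :: Nil)                 // size 2: accepted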
4 changes: 4 additions & 0 deletions project/MimaExcludes.scala
@@ -167,6 +167,10 @@
       // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus")
+    ) ++ Seq(
+      // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"),
+      ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation")
     )
     case v if v.startsWith("1.6") =>
       Seq(