Skip to content

Commit

Permalink
[SPARK-7983] [MLLIB] Add require for one-based indices in loadLibSVMFile
Browse files Browse the repository at this point in the history
jira: https://issues.apache.org/jira/browse/SPARK-7983

Customers frequently use zero-based indices in their LIBSVM files. No warnings or errors from Spark will be reported during their computation afterwards, and usually it will lead to wired result for many algorithms (like GBDT).

add a quick check.

Author: Yuhao Yang <[email protected]>

Closes apache#6538 from hhbyyh/loadSVM and squashes the following commits:

79d9c11 [Yuhao Yang] optimization as respond to comments
4310710 [Yuhao Yang] merge conflict
96460f1 [Yuhao Yang] merge conflict
20a2811 [Yuhao Yang] use require
6e4f8ca [Yuhao Yang] add check for ascending order
9956365 [Yuhao Yang] add ut for 0-based loadlibsvm exception
5bd1f9a [Yuhao Yang] add require for one-based in loadLIBSVM
  • Loading branch information
hhbyyh authored and jeanlyn committed Jun 12, 2015
1 parent 9ff596c commit 6d06cbb
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
12 changes: 12 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ object MLUtils {
val value = indexAndValue(1).toDouble
(index, value)
}.unzip

// check if indices are one-based and in ascending order
var previous = -1
var i = 0
val indicesLength = indices.length
while (i < indicesLength) {
val current = indices(i)
require(current > previous, "indices should be one-based and in ascending order" )
previous = current
i += 1
}

(label, indices.toArray, values.toArray)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
Expand Down Expand Up @@ -108,6 +109,40 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
Utils.deleteRecursively(tempDir)
}

test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
val lines =
"""
|0
|0 0:4.0 4:5.0 6:6.0
""".stripMargin
val tempDir = Utils.createTempDir()
val file = new File(tempDir.getPath, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString

intercept[SparkException] {
loadLibSVMFile(sc, path).collect()
}
Utils.deleteRecursively(tempDir)
}

test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
val lines =
"""
|0
|0 3:4.0 2:5.0 6:6.0
""".stripMargin
val tempDir = Utils.createTempDir()
val file = new File(tempDir.getPath, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString

intercept[SparkException] {
loadLibSVMFile(sc, path).collect()
}
Utils.deleteRecursively(tempDir)
}

test("saveAsLibSVMFile") {
val examples = sc.parallelize(Seq(
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),
Expand Down

0 comments on commit 6d06cbb

Please sign in to comment.