Skip to content

Commit

Permalink
Added LabeledPoint class in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
mateiz committed Apr 15, 2014
1 parent 889dde8 commit 74eefe7
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,24 @@ class PythonMLLibAPI extends Serializable {
private val DENSE_VECTOR_MAGIC: Byte = 1
private val SPARSE_VECTOR_MAGIC: Byte = 2
private val DENSE_MATRIX_MAGIC: Byte = 3
private val LABELED_POINT_MAGIC: Byte = 4

private def deserializeDoubleVector(bytes: Array[Byte]): Vector = {
require(bytes.length >= 5, "Byte array too short")
val magic = bytes(0)
private def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
require(bytes.length - offset >= 5, "Byte array too short")
val magic = bytes(offset)
if (magic == DENSE_VECTOR_MAGIC) {
deserializeDenseVector(bytes)
deserializeDenseVector(bytes, offset)
} else if (magic == SPARSE_VECTOR_MAGIC) {
deserializeSparseVector(bytes)
deserializeSparseVector(bytes, offset)
} else {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
}

private def deserializeDenseVector(bytes: Array[Byte]): Vector = {
val packetLength = bytes.length
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short")
val bb = ByteBuffer.wrap(bytes)
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
Expand All @@ -67,10 +68,10 @@ class PythonMLLibAPI extends Serializable {
Vectors.dense(ans)
}

private def deserializeSparseVector(bytes: Array[Byte]): Vector = {
val packetLength = bytes.length
private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 9, "Byte array too short")
val bb = ByteBuffer.wrap(bytes)
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
Expand Down Expand Up @@ -166,14 +167,23 @@ class PythonMLLibAPI extends Serializable {
bytes
}

private def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
require(bytes.length >= 9, "Byte array too short")
val magic = bytes(0)
if (magic != LABELED_POINT_MAGIC) {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
val labelBytes = ByteBuffer.wrap(bytes, 1, 8)
labelBytes.order(ByteOrder.nativeOrder())
val label = labelBytes.asDoubleBuffer().get(0)
LabeledPoint(label, deserializeDoubleVector(bytes, 9))
}

private def trainRegressionModel(
trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel,
dataBytesJRDD: JavaRDD[Array[Byte]],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes)
LabeledPoint(x(0), x.slice(1, x.size))
})
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
val initialWeights = deserializeDoubleVector(initialWeightsBA)
val model = trainFunc(data, initialWeights)
val ret = new java.util.LinkedList[java.lang.Object]()
Expand Down Expand Up @@ -299,10 +309,7 @@ class PythonMLLibAPI extends Serializable {
def trainNaiveBayes(
dataBytesJRDD: JavaRDD[Array[Byte]],
lambda: Double): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes)
LabeledPoint(x(0), x.slice(1, x.size))
})
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
val model = NaiveBayes.train(data, lambda)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleVector(Vectors.dense(model.labels)))
Expand All @@ -320,7 +327,7 @@ class PythonMLLibAPI extends Serializable {
maxIterations: Int,
runs: Int,
initializationMode: String): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(deserializeDoubleVector)
val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes))
val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
Expand Down
Loading

0 comments on commit 74eefe7

Please sign in to comment.