Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Reduced code duplication in classify method in Classifier.scala
Browse files Browse the repository at this point in the history
  • Loading branch information
piyushghai committed Dec 28, 2018
1 parent d5595a0 commit 926871e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ object MX_PRIMITIVES {
def unary_- : MX_PRIMITIVE_TYPE
}

trait MXPrimitiveOrdering extends Ordering[MX_PRIMITIVE_TYPE] {

def compare(x: MX_PRIMITIVE_TYPE, y: MX_PRIMITIVE_TYPE) = x.compare(y)

}

implicit object MX_PRIMITIVE_TYPE extends MXPrimitiveOrdering

/**
* Mimics Float in Scala.
* @param data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ package org.apache.mxnet.infer
import org.apache.mxnet._
import java.io.File

import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
import org.slf4j.LoggerFactory

import scala.io
import scala.collection.mutable.ListBuffer
import scala.collection.parallel.mutable.ParArray
Expand Down Expand Up @@ -88,40 +90,24 @@ class Classifier(modelPathPrefix: String,
// considering only the first output
val result = input(0)(0) match {
case d: Double => {
classifyWithDoubleImpl(input.asInstanceOf[IndexedSeq[Array[Double]]], topK)
classifyImpl(input.asInstanceOf[IndexedSeq[Array[Double]]], topK)
}
case _ => {
classifyWithFloatImpl(input.asInstanceOf[IndexedSeq[Array[Float]]], topK)
classifyImpl(input.asInstanceOf[IndexedSeq[Array[Float]]], topK)
}
}

result.asInstanceOf[IndexedSeq[(String, T)]]
}

private def classifyWithFloatImpl(input: IndexedSeq[Array[Float]], topK: Option[Int] = None)
: IndexedSeq[(String, Float)] = {

// considering only the first output
val predictResult = predictor.predict(input)(0)

var result: IndexedSeq[(String, Float)] = IndexedSeq.empty

if (topK.isDefined) {
val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
result = sortedIndex.map(i => (synset(i), predictResult(i))).toIndexedSeq
} else {
result = synset.zip(predictResult).toIndexedSeq
}
result
}

private def classifyWithDoubleImpl(input: IndexedSeq[Array[Double]], topK: Option[Int] = None)
: IndexedSeq[(String, Double)] = {
private def classifyImpl[B, A <: MX_PRIMITIVE_TYPE]
(input: IndexedSeq[Array[B]], topK: Option[Int] = None)(implicit ev: B => A)
: IndexedSeq[(String, B)] = {

// considering only the first output
val predictResult = predictor.predict(input)(0)

var result: IndexedSeq[(String, Double)] = IndexedSeq.empty
var result: IndexedSeq[(String, B)] = IndexedSeq.empty

if (topK.isDefined) {
val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
Expand Down

0 comments on commit 926871e

Please sign in to comment.