Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding noise to GDBSCAN result #44

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions src/main/scala/nak/cluster/GDBSCAN.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package nak.cluster

import scala.language.{ implicitConversions, postfixOps }

import breeze.numerics._
import breeze.linalg._
import breeze.util._
Expand All @@ -26,13 +28,14 @@ class GDBSCAN[T](
* @param data - each row is treated as a feature vector
* @return clusters - a list of clusters with
*/
def cluster(data: DenseMatrix[T]): Seq[Cluster[T]] = {
def cluster(data: DenseMatrix[T]): Seq[AbstractCluster[T]] = {
// Visited - using row indices
val visited = MutableSet[Point[T]]()
val clustered = MutableSet[Point[T]]()

// Init points
val points = for (row <- 0 until data.rows) yield Point(row)(data(row, ::).inner)
val noise = Noise[T]()

// Start clustering
points.collect {
Expand All @@ -44,12 +47,13 @@ class GDBSCAN[T](
expand(point, neighbours, cluster)(points, visited, clustered)
Some(cluster)
} else {
noise add point
None // noise
}
}.flatten // remove noise
}.flatten :+ noise
}

private def expand(point: Point[T], neighbours: Seq[Point[T]], cluster: Cluster[T])(implicit points: Seq[Point[T]], visited: MutableSet[Point[T]], clustered: MutableSet[Point[T]]) {
private def expand(point: Point[T], neighbours: Seq[Point[T]], cluster: AbstractCluster[T])(points: Seq[Point[T]], visited: MutableSet[Point[T]], clustered: MutableSet[Point[T]]) {
cluster add point
clustered add point
neighbours.foldLeft(neighbours) {
Expand Down Expand Up @@ -80,19 +84,25 @@ object GDBSCAN {
override def toString() = s"[$row]: $value"
}

/** Cluster description */
case class Cluster[T](id: Long) {
private var _points = ListBuffer[Point[T]]()

abstract class AbstractCluster[T] {
val id: Long

protected var _points = ListBuffer[Point[T]]()

def add(p: Point[T]) {
_points += p
}

def points: Seq[Point[T]] = Seq(_points: _*)
}

/** Cluster description */
case class Cluster[T](id: Long) extends AbstractCluster[T] {
override def toString() = s"Cluster [$id]\t:\t${_points.size} points\t${_points mkString "|"}"
}

case class Noise[T](id: Long = -1) extends AbstractCluster[T]

}

/**
Expand Down Expand Up @@ -122,4 +132,4 @@ object DBSCAN {
def isCorePoint(minPoints: Double)(point: Point[Double], neighbours: Seq[Point[Double]]): Boolean = {
neighbours.size >= minPoints
}
}
}
8 changes: 5 additions & 3 deletions src/test/scala/nak/cluster/GDBSCANTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ class GDBSCANTest extends FlatSpec with Matchers {
val cluster = gdbscan cluster input
val clusterPoints = cluster.map(_.points.map(_.value.toArray))

cluster.size shouldBe 2
cluster.size shouldBe 3
clusterPoints(0) should contain only (Array(0.9, 1.0), Array(1.0, 1.0), Array(1.0, 1.1))
clusterPoints(1) should contain only (Array(15.0, 15.0), Array(15.0, 14.1), Array(15.3, 15.0))
clusterPoints(2) should contain only (Array(5.0,5.0))
}

it should "work with custom predicates" in {
Expand All @@ -55,8 +56,9 @@ class GDBSCANTest extends FlatSpec with Matchers {
val cluster = gdbscan cluster input
val clusterPoints = cluster.map(_.points.map(_.value.toArray))

cluster.size shouldBe 2
cluster.size shouldBe 3
clusterPoints(0) should contain only (Array(1.0), Array(3.0))
clusterPoints(1) should contain only (Array(2.0), Array(4.0))
clusterPoints(2) shouldBe empty
}
}
}