Skip to content

Commit

Permalink
Merge pull request #199 from mikegagnon/map_monoid
Browse files Browse the repository at this point in the history
scala.collection.Map algebras
  • Loading branch information
johnynek committed Sep 18, 2013
2 parents d887b85 + 6ebbcbe commit e6a97b1
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ object Group extends GeneratedGroupImplicits with ProductGroups {
implicit val jdoubleGroup : Group[JDouble] = JDoubleField
implicit def indexedSeqGroup[T:Group]: Group[IndexedSeq[T]] = new IndexedSeqGroup[T]
implicit def mapGroup[K,V](implicit group : Group[V]) = new MapGroup[K,V]()(group)
implicit def scMapGroup[K,V](implicit group : Group[V]) = new ScMapGroup[K,V]()(group)
}
56 changes: 43 additions & 13 deletions algebird-core/src/main/scala/com/twitter/algebird/MapAlgebra.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,30 @@ limitations under the License.
package com.twitter.algebird

import scala.annotation.tailrec
import scala.collection.{Map => ScMap}

trait MapOperations[K, V, M <: ScMap[K, V]] {
def add(oldMap: M, kv: (K,V)): M
def remove(oldMap: M, k: K): M
}

abstract class GenericMapMonoid[K, V, M <: ScMap[K, V]](implicit val semigroup: Semigroup[V])
extends Monoid[M] with MapOperations[K, V, M] {

/** You can think of this as a Sparse vector monoid
*/
class MapMonoid[K,V](implicit val semigroup: Semigroup[V]) extends Monoid[Map[K,V]] {
val nonZero: (V => Boolean) = semigroup match {
case mon: Monoid[_] => mon.isNonZero(_)
case _ => (_ => true)
}

override def isNonZero(x : Map[K,V]) =
override def isNonZero(x : M) =
!x.isEmpty && (semigroup match {
case mon: Monoid[_] => x.valuesIterator.exists { v =>
mon.isNonZero(v)
}
case _ => true
})

override lazy val zero = Map[K,V]()

override def plus(x : Map[K,V], y : Map[K,V]) = {
override def plus(x : M, y : M) = {
// Scala maps can reuse internal structure, so don't copy just add into the bigger one:
// This really saves computation when adding lots of small maps into big ones (common)
val (big, small, bigOnLeft) = if(x.size > y.size) { (x,y,true) } else { (y,x,false) }
Expand All @@ -50,43 +54,69 @@ class MapMonoid[K,V](implicit val semigroup: Semigroup[V]) extends Monoid[Map[K,
}
.getOrElse(kv._2)
if (nonZero(newV))
oldMap + (kv._1 -> newV)
add(oldMap, (kv._1 -> newV))
else
oldMap - kv._1
remove(oldMap, kv._1)
}
}
}

class MapMonoid[K,V](implicit semigroup: Semigroup[V]) extends GenericMapMonoid[K, V, Map[K,V]] {
override lazy val zero = Map[K,V]()
override def add(oldMap: Map[K,V], kv: (K, V)) = oldMap + kv
override def remove(oldMap: Map[K,V], k: K) = oldMap - k
}

class ScMapMonoid[K,V](implicit semigroup: Semigroup[V]) extends GenericMapMonoid[K, V, ScMap[K,V]] {
override lazy val zero = ScMap[K,V]()
override def add(oldMap: ScMap[K,V], kv: (K, V)) = oldMap + kv
override def remove(oldMap: ScMap[K,V], k: K) = oldMap - k
}

/** You can think of this as a Sparse vector group
*/
class MapGroup[K,V](implicit val group : Group[V]) extends MapMonoid[K,V]()(group)
with Group[Map[K,V]] {
override def negate(kv : Map[K,V]) = kv.mapValues { v => group.negate(v) }
}

class ScMapGroup[K,V](implicit val group : Group[V]) extends ScMapMonoid[K,V]()(group)
with Group[ScMap[K,V]] {
override def negate(kv : ScMap[K,V]) = kv.mapValues { v => group.negate(v) }
}

/** You can think of this as a Sparse vector ring
*/
class MapRing[K,V](implicit val ring : Ring[V]) extends MapGroup[K,V]()(ring) with Ring[Map[K,V]] {
trait GenericMapRing[K, V, M <: ScMap[K, V]] extends Ring[M] with MapOperations[K, V, M] {

implicit def ring : Ring[V]

// It is possible to implement this, but we need a special "identity map" which we
// deal with as if it were map with all possible keys (.get(x) == ring.one for all x).
// Then we have to manage the delta from this map as we add elements. That said, it
// is not actually needed in matrix multiplication, so we are punting on it for now.
override def one = sys.error("multiplicative identity for Map unimplemented")
override def times(x : Map[K,V], y : Map[K,V]) : Map[K,V] = {
override def times(x : M, y : M) : M = {
val (big, small, bigOnLeft) = if(x.size > y.size) { (x,y,true) } else { (y,x,false) }
small.foldLeft(zero) { (oldMap, kv) =>
val bigV = big.getOrElse(kv._1, ring.zero)
val newV = if(bigOnLeft) ring.times(bigV, kv._2) else ring.times(kv._2, bigV)
if (ring.isNonZero(newV)) {
oldMap + (kv._1 -> newV)
add(oldMap, (kv._1 -> newV))
}
else {
oldMap - kv._1
remove(oldMap, kv._1)
}
}
}
}

class MapRing[K,V](implicit val ring : Ring[V]) extends MapGroup[K,V]()(ring)
with GenericMapRing[K, V, Map[K, V]]

class ScMapRing[K,V](implicit val ring : Ring[V]) extends ScMapGroup[K,V]()(ring)
with GenericMapRing[K, V, ScMap[K, V]]

object MapAlgebra {
def rightContainsLeft[K,V: Equiv](l: Map[K, V], r: Map[K, V]): Boolean =
l.forall { case (k, v) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ object Monoid extends GeneratedMonoidImplicits with ProductMonoids {
implicit def jlistMonoid[T] : Monoid[JList[T]] = new JListMonoid[T]
implicit def setMonoid[T] : Monoid[Set[T]] = new SetMonoid[T]
implicit def mapMonoid[K,V: Semigroup] = new MapMonoid[K,V]
implicit def scMapMonoid[K,V: Semigroup] = new ScMapMonoid[K,V]
implicit def jmapMonoid[K,V : Semigroup] = new JMapMonoid[K,V]
implicit def eitherMonoid[L : Semigroup, R : Monoid] = new EitherMonoid[L, R]
implicit def function1Monoid[T] = new Function1Monoid[T]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,5 @@ object Ring extends GeneratedRingImplicits with ProductRings {
implicit val jdoubleRing : Ring[JDouble] = JDoubleField
implicit def indexedSeqRing[T:Ring]: Ring[IndexedSeq[T]] = new IndexedSeqRing[T]
implicit def mapRing[K,V](implicit ring : Ring[V]) = new MapRing[K,V]()(ring)
implicit def scMapRing[K,V](implicit ring : Ring[V]) = new ScMapRing[K,V]()(ring)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import java.lang.{Integer => JInt, Short => JShort, Long => JLong, Float => JFlo
import java.util.{List => JList, Map => JMap}

import scala.collection.mutable.{Map => MMap}
import scala.collection.{Map => ScMap}
import scala.annotation.{implicitNotFound, tailrec}

/**
Expand Down Expand Up @@ -144,6 +145,7 @@ object Semigroup extends GeneratedSemigroupImplicits with ProductSemigroups {
implicit def jlistSemigroup[T] : Semigroup[JList[T]] = new JListMonoid[T]
implicit def setSemigroup[T] : Semigroup[Set[T]] = new SetMonoid[T]
implicit def mapSemigroup[K,V:Semigroup]: Semigroup[Map[K,V]] = new MapMonoid[K,V]
implicit def scMapSemigroup[K,V:Semigroup]: Semigroup[ScMap[K,V]] = new ScMapMonoid[K,V]
implicit def jmapSemigroup[K,V : Semigroup] : Semigroup[JMap[K, V]] = new JMapMonoid[K,V]
implicit def eitherSemigroup[L : Semigroup, R : Semigroup] = new EitherSemigroup[L,R]
implicit def function1Semigroup[T] : Semigroup[Function1[T,T]] = new Function1Monoid[T]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import org.scalacheck.Gen.choose
import org.scalacheck.Properties
import org.scalacheck.Prop._

import scala.collection.{Map => ScMap}

object CollectionSpecification extends Properties("Collections") {
import BaseProperties._

Expand Down Expand Up @@ -60,19 +62,32 @@ object CollectionSpecification extends Properties("Collections") {
.map { _.filter { kv => mv.isNonZero(kv._2) } }
}

property("Map plus/times keys") = forAll { (a : Map[Int,Int], b : Map[Int,Int]) =>
val rng = implicitly[Ring[Map[Int,Int]]]
(rng.zero == Map[Int,Int]()) &&
// Subsets because zeros are removed from the times/plus values
(rng.times(a,b)).keys.toSet.subsetOf((a.keys.toSet & b.keys.toSet)) &&
(rng.plus(a,b)).keys.toSet.subsetOf((a.keys.toSet | b.keys.toSet)) &&
(rng.plus(a,a).keys == (a.filter { kv => (kv._2 + kv._2) != 0 }).keys)
implicit def scMapArb[K : Arbitrary, V : Arbitrary : Monoid] = Arbitrary {
mapArb[K, V]
.arbitrary
.map { map: Map[K,V] => map: ScMap[K,V] }
}

def mapPlusTimesKeys[M <: ScMap[Int, Int]]
(implicit rng: Ring[M], arbMap: Arbitrary[M]) =
forAll { (a: M, b: M) =>
// Subsets because zeros are removed from the times/plus values
(rng.times(a,b)).keys.toSet.subsetOf((a.keys.toSet & b.keys.toSet)) &&
(rng.plus(a,b)).keys.toSet.subsetOf((a.keys.toSet | b.keys.toSet)) &&
(rng.plus(a,a).keys == (a.filter { kv => (kv._2 + kv._2) != 0 }).keys)
}

property("Map plus/times keys") = mapPlusTimesKeys[Map[Int, Int]]
property("ScMap plus/times keys") = mapPlusTimesKeys[ScMap[Int, Int]]
property("Map[Int,Int] Monoid laws") = isAssociative[Map[Int,Int]] && weakZero[Map[Int,Int]]
property("ScMap[Int,Int] Monoid laws") = isAssociative[ScMap[Int,Int]] && weakZero[ScMap[Int,Int]]
property("Map[Int,Int] has -") = hasAdditiveInverses[Map[Int,Int]]
property("ScMap[Int,Int] has -") = hasAdditiveInverses[ScMap[Int,Int]]
property("Map[Int,String] Monoid laws") = isAssociative[Map[Int,String]] && weakZero[Map[Int,String]]
property("ScMap[Int,String] Monoid laws") = isAssociative[ScMap[Int,String]] && weakZero[ScMap[Int,String]]
// We haven't implemented ring.one yet for the Map, so skip the one property
property("Map is distributive") = isDistributive[Map[Int,Int]]
property("ScMap is distributive") = isDistributive[ScMap[Int,Int]]
implicit def arbIndexedSeq[T:Arbitrary] : Arbitrary[IndexedSeq[T]] =
Arbitrary { implicitly[Arbitrary[List[T]]].arbitrary.map { _.toIndexedSeq } }

Expand Down

0 comments on commit e6a97b1

Please sign in to comment.