Skip to content

Commit

Permalink
Improve traverseViaChain API (#3535)
Browse files Browse the repository at this point in the history
* Improve traverseViaChain API

* fix alleycats

* remove unneeded private method

* fix 2.13 compilation

* be more paranoid about mutability
  • Loading branch information
johnynek authored Aug 3, 2020
1 parent 201e834 commit 763847f
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 124 deletions.
13 changes: 11 additions & 2 deletions alleycats-core/src/main/scala/alleycats/std/map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package std

import cats._
import cats.data.Chain
import cats.kernel.instances.StaticMethods.wrapMutableIndexedSeq

object map extends MapInstances

Expand All @@ -15,7 +16,11 @@ trait MapInstances {
def traverse[G[_], A, B](fa: Map[K, A])(f: A => G[B])(implicit G: Applicative[G]): G[Map[K, B]] =
if (fa.isEmpty) G.pure(Map.empty[K, B])
else
G.map(Chain.traverseViaChain(fa.iterator) {
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[(K, A)]()
as ++= fa
wrapMutableIndexedSeq(as)
} {
case (k, a) => G.map(f(a))((k, _))
}) { chain => chain.foldLeft(Map.empty[K, B]) { case (m, (k, b)) => m.updated(k, b) } }

Expand Down Expand Up @@ -62,7 +67,11 @@ trait MapInstances {
def traverseFilter[G[_], A, B](fa: Map[K, A])(f: A => G[Option[B]])(implicit G: Applicative[G]): G[Map[K, B]] =
if (fa.isEmpty) G.pure(Map.empty[K, B])
else
G.map(Chain.traverseFilterViaChain(fa.iterator) {
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[(K, A)]()
as ++= fa
wrapMutableIndexedSeq(as)
} {
case (k, a) =>
G.map(f(a)) { optB =>
if (optB.isDefined) Some((k, optB.get))
Expand Down
42 changes: 21 additions & 21 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Chain._
import cats.kernel.instances.StaticMethods

import scala.annotation.tailrec
import scala.collection.immutable.{SortedMap, TreeSet}
import scala.collection.immutable.{SortedMap, TreeSet, IndexedSeq => ImIndexedSeq}
import scala.collection.mutable.ListBuffer

/**
Expand Down Expand Up @@ -382,14 +382,6 @@ sealed abstract class Chain[+A] {
go(this, Chain.nil)
}

/**
* Applies the supplied function to each element, left to right.
*/
final private def foreach(f: A => Unit): Unit =
foreachUntil { a =>
f(a); false
}

/**
* Applies the supplied function to each element, left to right, but stops when true is returned
*/
Expand Down Expand Up @@ -508,11 +500,11 @@ sealed abstract class Chain[+A] {
val builder = new StringBuilder("Chain(")
var first = true

foreach { a =>
foreachUntil { a =>
if (first) {
builder ++= AA.show(a); first = false
} else builder ++= ", " + AA.show(a)
()
false
}
builder += ')'
builder.result()
Expand Down Expand Up @@ -629,13 +621,13 @@ object Chain extends ChainInstances {
def apply[A](as: A*): Chain[A] =
fromSeq(as)

def traverseViaChain[G[_], A, B](iter: Iterator[A])(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] =
if (!iter.hasNext) G.pure(Chain.nil)
def traverseViaChain[G[_], A, B](
as: ImIndexedSeq[A]
)(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] =
if (as.isEmpty) G.pure(Chain.nil)
else {
// we branch out by this factor
val width = 128
val as = collection.mutable.Buffer[A]()
as ++= iter
// By making a tree here we don't blow the stack
// even if the List is very long
// by construction, this is never called with start == end
Expand Down Expand Up @@ -676,14 +668,12 @@ object Chain extends ChainInstances {
}

def traverseFilterViaChain[G[_], A, B](
iter: Iterator[A]
as: ImIndexedSeq[A]
)(f: A => G[Option[B]])(implicit G: Applicative[G]): G[Chain[B]] =
if (!iter.hasNext) G.pure(Chain.nil)
if (as.isEmpty) G.pure(Chain.nil)
else {
// we branch out by this factor
val width = 128
val as = collection.mutable.Buffer[A]()
as ++= iter
// By making a tree here we don't blow the stack
// even if the List is very long
// by construction, this is never called with start == end
Expand Down Expand Up @@ -862,7 +852,12 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {

def traverse[G[_], A, B](fa: Chain[A])(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] =
if (fa.isEmpty) G.pure(Chain.nil)
else traverseViaChain(fa.iterator)(f)
else
traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
StaticMethods.wrapMutableIndexedSeq(as)
}(f)

def empty[A]: Chain[A] = Chain.nil
def combineK[A](c: Chain[A], c2: Chain[A]): Chain[A] = Chain.concat(c, c2)
Expand Down Expand Up @@ -963,7 +958,12 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {

def traverseFilter[G[_], A, B](fa: Chain[A])(f: A => G[Option[B]])(implicit G: Applicative[G]): G[Chain[B]] =
if (fa.isEmpty) G.pure(Chain.nil)
else traverseFilterViaChain(fa.iterator)(f)
else
traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
StaticMethods.wrapMutableIndexedSeq(as)
}(f)

override def filterA[G[_], A](fa: Chain[A])(f: A => G[Boolean])(implicit G: Applicative[G]): G[Chain[A]] =
traverse
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/scala/cats/instances/list.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cats
package instances

import cats.data.{Chain, ZipList}
import cats.kernel.instances.StaticMethods.wrapMutableIndexedSeq
import cats.syntax.show._

import scala.annotation.tailrec
Expand Down Expand Up @@ -87,7 +88,12 @@ trait ListInstances extends cats.kernel.instances.ListInstances {

def traverse[G[_], A, B](fa: List[A])(f: A => G[B])(implicit G: Applicative[G]): G[List[B]] =
if (fa.isEmpty) G.pure(Nil)
else G.map(Chain.traverseViaChain(fa.iterator)(f))(_.toList)
else
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa
wrapMutableIndexedSeq(as)
}(f))(_.toList)

def functor: Functor[List] = this

Expand Down Expand Up @@ -212,7 +218,12 @@ private[instances] trait ListInstancesBinCompat0 {

def traverseFilter[G[_], A, B](fa: List[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[List[B]] =
if (fa.isEmpty) G.pure(Nil)
else G.map(Chain.traverseFilterViaChain(fa.iterator)(f))(_.toList)
else
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa
wrapMutableIndexedSeq(as)
}(f))(_.toList)

override def filterA[G[_], A](fa: List[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[List[A]] =
traverse
Expand Down
13 changes: 11 additions & 2 deletions core/src/main/scala/cats/instances/queue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cats
package instances

import cats.data.Chain
import cats.kernel.instances.StaticMethods.wrapMutableIndexedSeq
import cats.syntax.show._
import scala.annotation.tailrec
import scala.collection.immutable.Queue
Expand Down Expand Up @@ -82,7 +83,11 @@ trait QueueInstances extends cats.kernel.instances.QueueInstances {
def traverse[G[_], A, B](fa: Queue[A])(f: A => G[B])(implicit G: Applicative[G]): G[Queue[B]] =
if (fa.isEmpty) G.pure(Queue.empty[B])
else
G.map(Chain.traverseViaChain(fa.iterator)(f)) { chain =>
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa
wrapMutableIndexedSeq(as)
}(f)) { chain =>
chain.foldLeft(Queue.empty[B])(_ :+ _)
}

Expand Down Expand Up @@ -177,7 +182,11 @@ private object QueueInstances {
def traverseFilter[G[_], A, B](fa: Queue[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[Queue[B]] =
if (fa.isEmpty) G.pure(Queue.empty[B])
else
G.map(Chain.traverseFilterViaChain(fa.iterator)(f)) { chain =>
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa
wrapMutableIndexedSeq(as)
}(f)) { chain =>
chain.foldLeft(Queue.empty[B])(_ :+ _)
}

Expand Down
13 changes: 11 additions & 2 deletions core/src/main/scala/cats/instances/sortedMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cats.instances
import cats._
import cats.data.{Chain, Ior}
import cats.kernel.{CommutativeMonoid, CommutativeSemigroup}
import cats.kernel.instances.StaticMethods.wrapMutableIndexedSeq

import scala.annotation.tailrec
import scala.collection.immutable.SortedMap
Expand Down Expand Up @@ -35,7 +36,11 @@ trait SortedMapInstances extends SortedMapInstances2 {
implicit val ordering: Ordering[K] = fa.ordering
if (fa.isEmpty) G.pure(SortedMap.empty[K, B])
else
G.map(Chain.traverseViaChain(fa.iterator) {
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[(K, A)]()
as ++= fa
wrapMutableIndexedSeq(as)
} {
case (k, a) => G.map(f(a))((k, _))
}) { chain => chain.foldLeft(SortedMap.empty[K, B]) { case (m, (k, b)) => m.updated(k, b) } }
}
Expand Down Expand Up @@ -194,7 +199,11 @@ private[instances] trait SortedMapInstancesBinCompat0 {
implicit val ordering: Ordering[K] = fa.ordering
if (fa.isEmpty) G.pure(SortedMap.empty[K, B])
else
G.map(Chain.traverseFilterViaChain(fa.iterator) {
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[(K, A)]()
as ++= fa
wrapMutableIndexedSeq(as)
} {
case (k, a) =>
G.map(f(a)) { optB =>
if (optB.isDefined) Some((k, optB.get))
Expand Down
97 changes: 2 additions & 95 deletions core/src/main/scala/cats/instances/vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,52 +93,7 @@ trait VectorInstances extends cats.kernel.instances.VectorInstances {
}

final override def traverse[G[_], A, B](fa: Vector[A])(f: A => G[B])(implicit G: Applicative[G]): G[Vector[B]] =
if (fa.isEmpty) G.pure(empty)
else {
// this is a specialized version of Chain.traverseViaChain since
// we don't need to materialize the Vector first

// we branch out by this factor
val width = 128
// By making a tree here we don't blow the stack
// even if the List is very long
// by construction, this is never called with start == end
def loop(start: Int, end: Int): Eval[G[Chain[B]]] =
if (end - start <= width) {
// Here we are at the leafs of the trees
// we don't use map2Eval since it is always
// at most width in size.
val aend = fa(end - 1)
var flist = Eval.later(G.map(f(aend))(_ :: Nil))
var idx = end - 2
while (start <= idx) {
val a = fa(idx)
// don't capture a var in the defer
val right = flist
flist = Eval.defer(G.map2Eval(f(a), right)(_ :: _))
idx = idx - 1
}
flist.map { glist => G.map(glist)(Chain.fromSeq(_)) }
} else {
// we have width + 1 or more nodes left
val step = (end - start) / width

var fchain = Eval.defer(loop(start, start + step))
var start0 = start + step
var end0 = start0 + step

while (start0 < end) {
val end1 = math.min(end, end0)
val right = loop(start0, end1)
fchain = fchain.flatMap(G.map2Eval(_, right)(_.concat(_)))
start0 = start0 + step
end0 = end0 + step
}
fchain
}

G.map(loop(0, fa.size).value)(_.toVector)
}
G.map(Chain.traverseViaChain(fa)(f))(_.toVector)

override def mapWithIndex[A, B](fa: Vector[A])(f: (A, Int) => B): Vector[B] =
fa.iterator.zipWithIndex.map(ai => f(ai._1, ai._2)).toVector
Expand Down Expand Up @@ -225,55 +180,7 @@ private[instances] trait VectorInstancesBinCompat0 {
override def flattenOption[A](fa: Vector[Option[A]]): Vector[A] = fa.flatten

def traverseFilter[G[_], A, B](fa: Vector[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[Vector[B]] =
if (fa.isEmpty) G.pure(Vector.empty[B])
else {
// we branch out by this factor
val width = 128
// By making a tree here we don't blow the stack
// even if the List is very long
// by construction, this is never called with start == end
def loop(start: Int, end: Int): Eval[G[Chain[B]]] =
if (end - start <= width) {
// Here we are at the leafs of the trees
// we don't use map2Eval since it is always
// at most width in size.
val aend = fa(end - 1)
var flist = Eval.later(G.map(f(aend)) { optB =>
if (optB.isDefined) optB.get :: Nil
else Nil
})
var idx = end - 2
while (start <= idx) {
val a = fa(idx)
// don't capture a var in the defer
val right = flist
flist = Eval.defer(G.map2Eval(f(a), right) { (optB, tail) =>
if (optB.isDefined) optB.get :: tail
else tail
})
idx = idx - 1
}
flist.map { glist => G.map(glist)(Chain.fromSeq(_)) }
} else {
// we have width + 1 or more nodes left
val step = (end - start) / width

var fchain = Eval.defer(loop(start, start + step))
var start0 = start + step
var end0 = start0 + step

while (start0 < end) {
val end1 = math.min(end, end0)
val right = loop(start0, end1)
fchain = fchain.flatMap(G.map2Eval(_, right)(_.concat(_)))
start0 = start0 + step
end0 = end0 + step
}
fchain
}

G.map(loop(0, fa.size).value)(_.toVector)
}
G.map(Chain.traverseFilterViaChain(fa)(f))(_.toVector)

override def filterA[G[_], A](fa: Vector[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[Vector[A]] =
traverse
Expand Down
18 changes: 18 additions & 0 deletions kernel/src/main/scala/cats/kernel/instances/StaticMethods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package cats
package kernel
package instances

import scala.collection.immutable.{IndexedSeq => ImIndexedSeq}
import scala.collection.mutable
import compat.scalaVersionSpecific._

@suppressUnusedImportWarningForScalaVersionSpecific
object StaticMethods extends cats.kernel.compat.HashCompat {

Expand All @@ -17,6 +19,22 @@ object StaticMethods extends cats.kernel.compat.HashCompat {
def iterator: Iterator[(K, V)] = m.iterator
}

/**
* When you "own" this m, and will not mutate it again, this
* is safe to call. It is unsafe to call this, then mutate
* the original collection.
*
* You are giving up ownership when calling this method
*/
def wrapMutableIndexedSeq[A](m: mutable.IndexedSeq[A]): ImIndexedSeq[A] =
new WrappedIndexedSeq(m)

private[kernel] class WrappedIndexedSeq[A](m: mutable.IndexedSeq[A]) extends ImIndexedSeq[A] {
override def length: Int = m.length
override def apply(i: Int): A = m(i)
override def iterator: Iterator[A] = m.iterator
}

// scalastyle:off return
def iteratorCompare[A](xs: Iterator[A], ys: Iterator[A])(implicit ev: Order[A]): Int = {
while (true) {
Expand Down

0 comments on commit 763847f

Please sign in to comment.