Skip to content

Commit

Permalink
Faster addAll/putAll implementation (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
belyaev-mikhail authored Sep 17, 2020
1 parent 0501885 commit 690343a
Show file tree
Hide file tree
Showing 10 changed files with 485 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

package kotlinx.collections.immutable.implementations.immutableMap

import kotlinx.collections.immutable.*
import kotlinx.collections.immutable.ImmutableCollection
import kotlinx.collections.immutable.ImmutableSet
import kotlinx.collections.immutable.PersistentMap
import kotlinx.collections.immutable.mutate

internal class PersistentHashMap<K, V>(internal val node: TrieNode<K, V>,
override val size: Int): AbstractMap<K, V>(), PersistentMap<K, V> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package kotlinx.collections.immutable.implementations.immutableMap

import kotlinx.collections.immutable.PersistentMap
import kotlinx.collections.immutable.internal.DeltaCounter
import kotlinx.collections.immutable.internal.MutabilityOwnership

internal class PersistentHashMapBuilder<K, V>(private var map: PersistentHashMap<K, V>) : PersistentMap.Builder<K, V>, AbstractMutableMap<K, V>() {
Expand Down Expand Up @@ -61,6 +62,18 @@ internal class PersistentHashMapBuilder<K, V>(private var map: PersistentHashMap
return operationResult
}

override fun putAll(from: Map<out K, V>) {
val map = from as? PersistentHashMap ?: (from as? PersistentHashMapBuilder)?.map
if (map != null) @Suppress("UNCHECKED_CAST") {
val intersectionCounter = DeltaCounter()
val oldSize = this.size
node = node.mutablePutAll(map.node as TrieNode<K, V>, 0, intersectionCounter, this)
val newSize = oldSize + map.size - intersectionCounter.count
if(oldSize != newSize) this.size = newSize
}
else super.putAll(from)
}

override fun remove(key: K): V? {
operationResult = null
@Suppress("UNCHECKED_CAST")
Expand Down
181 changes: 181 additions & 0 deletions core/commonMain/src/implementations/immutableMap/TrieNode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

package kotlinx.collections.immutable.implementations.immutableMap

import kotlinx.collections.immutable.internal.DeltaCounter
import kotlinx.collections.immutable.internal.MutabilityOwnership
import kotlinx.collections.immutable.internal.assert
import kotlinx.collections.immutable.internal.forEachOneBit


internal const val MAX_BRANCHING_FACTOR = 32
Expand Down Expand Up @@ -436,6 +439,87 @@ internal class TrieNode<K, V>(
return this
}

private fun mutableCollisionPutAll(otherNode: TrieNode<K, V>,
intersectionCounter: DeltaCounter,
owner: MutabilityOwnership): TrieNode<K, V> {
assert(nodeMap == 0)
assert(dataMap == 0)
assert(otherNode.nodeMap == 0)
assert(otherNode.dataMap == 0)
val tempBuffer = this.buffer.copyOf(newSize = this.buffer.size + otherNode.buffer.size)
var i = this.buffer.size
for (j in 0 until otherNode.buffer.size step ENTRY_SIZE) {
@Suppress("UNCHECKED_CAST")
if (!this.collisionContainsKey(otherNode.buffer[j] as K)) {
tempBuffer[i] = otherNode.buffer[j]
tempBuffer[i + 1] = otherNode.buffer[j + 1]
i += ENTRY_SIZE
} else intersectionCounter.count++
}

return when (val newSize = i) {
this.buffer.size -> this
otherNode.buffer.size -> otherNode
tempBuffer.size -> TrieNode(0, 0, tempBuffer, owner)
else -> TrieNode(0, 0, tempBuffer.copyOf(newSize), owner)
}
}

private fun mutablePutAllFromOtherNodeCell(other: TrieNode<K, V>,
positionMask: Int,
shift: Int,
intersectionCounter: DeltaCounter,
mutator: PersistentHashMapBuilder<K, V>): TrieNode<K, V> {
return when {
other.hasNodeAt(positionMask) -> {
mutablePutAll(
other.nodeAtIndex(other.nodeIndex(positionMask)),
shift + LOG_MAX_BRANCHING_FACTOR,
intersectionCounter,
mutator
)
}
other.hasEntryAt(positionMask) -> {
val keyIndex = other.entryKeyIndex(positionMask)
val key = other.keyAtIndex(keyIndex)
val value = other.valueAtKeyIndex(keyIndex)
val oldSize = mutator.size
val newNode = mutablePut(
key.hashCode(),
key,
value,
shift + LOG_MAX_BRANCHING_FACTOR,
mutator
)
if (mutator.size == oldSize) {
intersectionCounter.count++
}
newNode
}
else -> this
}
}

private fun calculateSize(): Int {
if (nodeMap == 0) return buffer.size / ENTRY_SIZE
val numValues = dataMap.countOneBits()
var result = numValues
for(i in (numValues * ENTRY_SIZE) until buffer.size) {
result += nodeAtIndex(i).calculateSize()
}
return result
}

private fun elementsIdentityEquals(otherNode: TrieNode<K, V>): Boolean {
if (this === otherNode) return true
if (nodeMap != otherNode.nodeMap) return false
if (dataMap != otherNode.dataMap) return false
for (i in 0 until buffer.size) {
if(buffer[i] !== otherNode.buffer[i]) return false
}
return true
}

fun containsKey(keyHash: Int, key: K, shift: Int): Boolean {
val keyPositionMask = 1 shl indexSegment(keyHash, shift)

Expand Down Expand Up @@ -477,6 +561,103 @@ internal class TrieNode<K, V>(
return null
}

fun mutablePutAll(otherNode: TrieNode<K, V>,
shift: Int,
intersectionCounter: DeltaCounter,
mutator: PersistentHashMapBuilder<K, V>): TrieNode<K, V> {
if (this === otherNode) {
intersectionCounter += calculateSize()
return this
}
// the collision case
if (shift > MAX_SHIFT) {
return mutableCollisionPutAll(otherNode, intersectionCounter, mutator.ownership)
}

// new nodes are where either of the old ones were
var newNodeMap = nodeMap or otherNode.nodeMap
// entries stay being entries only if one bits were in exactly one of input nodes
// but not in the new data nodes
var newDataMap = dataMap xor otherNode.dataMap and newNodeMap.inv()
// (**) now, this is tricky: we have a number of entry-entry pairs and we don't know yet whether
// they result in an entry (if they are equal) or a new node (if they are not)
// but we want to keep it to single allocation, so we check and mark equal ones here
(dataMap and otherNode.dataMap).forEachOneBit { positionMask, _ ->
val leftKey = this.keyAtIndex(this.entryKeyIndex(positionMask))
val rightKey = otherNode.keyAtIndex(otherNode.entryKeyIndex(positionMask))
// if they are equal, put them in the data map
if (leftKey == rightKey) newDataMap = newDataMap or positionMask
// if they are not, put them in the node map
else newNodeMap = newNodeMap or positionMask
// we can use this later to skip calling equals() again
}
assert(newNodeMap and newDataMap == 0)
val mutableNode = when {
this.ownedBy == mutator.ownership && this.dataMap == newDataMap && this.nodeMap == newNodeMap -> this
else -> {
val newBuffer = arrayOfNulls<Any>(newDataMap.countOneBits() * ENTRY_SIZE + newNodeMap.countOneBits())
TrieNode(newDataMap, newNodeMap, newBuffer)
}
}
newNodeMap.forEachOneBit { positionMask, index ->
val newNodeIndex = mutableNode.buffer.size - 1 - index
mutableNode.buffer[newNodeIndex] = when {
hasNodeAt(positionMask) -> {
val before = nodeAtIndex(nodeIndex(positionMask))
before.mutablePutAllFromOtherNodeCell(otherNode, positionMask, shift, intersectionCounter, mutator)
}

otherNode.hasNodeAt(positionMask) -> {
val before = otherNode.nodeAtIndex(otherNode.nodeIndex(positionMask))
before.mutablePutAllFromOtherNodeCell(this, positionMask, shift, intersectionCounter, mutator)
}

else -> { // two entries, and they are not equal by key (see ** above)
val thisKeyIndex = this.entryKeyIndex(positionMask)
val thisKey = this.keyAtIndex(thisKeyIndex)
val thisValue = this.valueAtKeyIndex(thisKeyIndex)
val otherKeyIndex = otherNode.entryKeyIndex(positionMask)
val otherKey = otherNode.keyAtIndex(otherKeyIndex)
val otherValue = otherNode.valueAtKeyIndex(otherKeyIndex)
makeNode(
thisKey.hashCode(),
thisKey,
thisValue,
otherKey.hashCode(),
otherKey,
otherValue,
shift + LOG_MAX_BRANCHING_FACTOR,
mutator.ownership
)
}
}
}
newDataMap.forEachOneBit { positionMask, index ->
val newKeyIndex = index * ENTRY_SIZE
when {
!otherNode.hasEntryAt(positionMask) -> {
val oldKeyIndex = this.entryKeyIndex(positionMask)
mutableNode.buffer[newKeyIndex] = this.keyAtIndex(oldKeyIndex)
mutableNode.buffer[newKeyIndex + 1] = this.valueAtKeyIndex(oldKeyIndex)
}
// there is either only one entry in otherNode, or
// both entries are here => they are equal, see ** above
// so just overwrite that
else -> {
val oldKeyIndex = otherNode.entryKeyIndex(positionMask)
mutableNode.buffer[newKeyIndex] = otherNode.keyAtIndex(oldKeyIndex)
mutableNode.buffer[newKeyIndex + 1] = otherNode.valueAtKeyIndex(oldKeyIndex)
if (this.hasEntryAt(positionMask)) intersectionCounter.count++
}
}
}
return when {
this.elementsIdentityEquals(mutableNode) -> this
otherNode.elementsIdentityEquals(mutableNode) -> otherNode
else -> mutableNode
}
}

fun put(keyHash: Int, key: K, value: @UnsafeVariance V, shift: Int): ModificationResult<K, V>? {
val keyPositionMask = 1 shl indexSegment(keyHash, shift)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package kotlinx.collections.immutable.implementations.immutableSet

import kotlinx.collections.immutable.PersistentSet
import kotlinx.collections.immutable.internal.DeltaCounter
import kotlinx.collections.immutable.mutate

internal class PersistentHashSet<E>(internal val node: TrieNode<E>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package kotlinx.collections.immutable.implementations.immutableSet

import kotlinx.collections.immutable.PersistentSet
import kotlinx.collections.immutable.internal.DeltaCounter
import kotlinx.collections.immutable.internal.MutabilityOwnership

internal class PersistentHashSetBuilder<E>(private var set: PersistentHashSet<E>) : AbstractMutableSet<E>(), PersistentSet.Builder<E> {
Expand Down Expand Up @@ -43,6 +44,21 @@ internal class PersistentHashSetBuilder<E>(private var set: PersistentHashSet<E>
return size != this.size
}

override fun addAll(elements: Collection<E>): Boolean {
val set = elements as? PersistentHashSet ?: (elements as? PersistentHashSetBuilder)?.set
if (set !== null) {
val deltaCounter = DeltaCounter()
val oldSize = this.size
node = node.mutableAddAll(set.node, 0, deltaCounter, this)
val newSize = oldSize + elements.size - deltaCounter.count
if (oldSize != newSize) {
this.size = newSize
}
return oldSize != this.size
}
return super.addAll(elements)
}

override fun remove(element: E): Boolean {
val size = this.size
@Suppress("UNCHECKED_CAST")
Expand Down
Loading

0 comments on commit 690343a

Please sign in to comment.