Skip to content

Commit

Permalink
Merge pull request #328 from zhelenskiy/dev
Browse files Browse the repository at this point in the history
Karatsuba added, 2 bugs are fixed
  • Loading branch information
altavir authored May 14, 2021
2 parents a86e8eb + bdb9ce6 commit c1b94ff
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,19 @@ import kotlinx.benchmark.Blackhole
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import space.kscience.kmath.operations.BigInt
import space.kscience.kmath.operations.BigIntField
import space.kscience.kmath.operations.JBigIntegerField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import java.math.BigInteger

private fun BigInt.pow(power: Int): BigInt = modPow(BigIntField.number(power), BigInt.ZERO)

@UnstableKMathAPI
@State(Scope.Benchmark)
internal class BigIntBenchmark {

val kmNumber = BigIntField.number(Int.MAX_VALUE)
val jvmNumber = JBigIntegerField.number(Int.MAX_VALUE)
val largeKmNumber = BigIntField { number(11).pow(100_000) }
val largeJvmNumber = JBigIntegerField { number(11).pow(100_000) }
val largeKmNumber = BigIntField { number(11).pow(100_000U) }
val largeJvmNumber: BigInteger = JBigIntegerField { number(11).pow(100_000) }
val bigExponent = 50_000

@Benchmark
Expand All @@ -36,6 +35,16 @@ internal class BigIntBenchmark {
blackhole.consume(jvmNumber + jvmNumber + jvmNumber)
}

@Benchmark
fun kmAddLarge(blackhole: Blackhole) = BigIntField {
blackhole.consume(largeKmNumber + largeKmNumber + largeKmNumber)
}

@Benchmark
fun jvmAddLarge(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume(largeJvmNumber + largeJvmNumber + largeJvmNumber)
}

@Benchmark
fun kmMultiply(blackhole: Blackhole) = BigIntField {
blackhole.consume(kmNumber * kmNumber * kmNumber)
Expand All @@ -56,13 +65,33 @@ internal class BigIntBenchmark {
blackhole.consume(largeJvmNumber*largeJvmNumber)
}

// @Benchmark
// fun kmPower(blackhole: Blackhole) = BigIntField {
// blackhole.consume(kmNumber.pow(bigExponent))
// }
//
// @Benchmark
// fun jvmPower(blackhole: Blackhole) = JBigIntegerField {
// blackhole.consume(jvmNumber.pow(bigExponent))
// }
@Benchmark
fun kmPower(blackhole: Blackhole) = BigIntField {
blackhole.consume(kmNumber.pow(bigExponent.toUInt()))
}

@Benchmark
fun jvmPower(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume(jvmNumber.pow(bigExponent))
}

@Benchmark
fun kmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("0x7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".parseBigInteger())
}

@Benchmark
fun kmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".parseBigInteger())
}

@Benchmark
fun jvmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".toBigInteger(10))
}

@Benchmark
fun jvmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".toBigInteger(16))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package space.kscience.kmath.operations

import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI

/**
* Stub for DSL the [Algebra] is.
Expand Down Expand Up @@ -247,7 +248,7 @@ public interface RingOperations<T> : GroupOperations<T> {
*/
public interface Ring<T> : Group<T>, RingOperations<T> {
/**
* neutral operation for multiplication
* The neutral element of multiplication
*/
public val one: T
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ public class BigInt internal constructor(
else -> sign * compareMagnitudes(magnitude, other.magnitude)
}

public override fun equals(other: Any?): Boolean =
if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")
public override fun equals(other: Any?): Boolean = other is BigInt && compareTo(other) == 0

public override fun hashCode(): Int = magnitude.hashCode() + sign

Expand Down Expand Up @@ -87,20 +86,25 @@ public class BigInt internal constructor(
public operator fun times(b: BigInt): BigInt = when {
this.sign == 0.toByte() -> ZERO
b.sign == 0.toByte() -> ZERO
// TODO: Karatsuba
b.magnitude.size == 1 -> this * b.magnitude[0] * b.sign.toInt()
this.magnitude.size == 1 -> b * this.magnitude[0] * this.sign.toInt()
else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude))
}

public operator fun times(other: UInt): BigInt = when {
sign == 0.toByte() -> ZERO
other == 0U -> ZERO
other == 1U -> this
else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other))
}

public operator fun times(other: Int): BigInt = if (other > 0)
this * kotlin.math.abs(other).toUInt()
else
-this * kotlin.math.abs(other).toUInt()
public fun pow(exponent: UInt): BigInt = BigIntField.power(this@BigInt, exponent)

public operator fun times(other: Int): BigInt = when {
other > 0 -> this * kotlin.math.abs(other).toUInt()
other != Int.MIN_VALUE -> -this * kotlin.math.abs(other).toUInt()
else -> times(other.toBigInt())
}

public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other))

Expand Down Expand Up @@ -238,6 +242,7 @@ public class BigInt internal constructor(
public const val BASE_SIZE: Int = 32
public val ZERO: BigInt = BigInt(0, uintArrayOf())
public val ONE: BigInt = BigInt(1, uintArrayOf(1u))
private const val KARATSUBA_THRESHOLD = 80

private val hexMapping: HashMap<UInt, String> = hashMapOf(
0U to "0", 1U to "1", 2U to "2", 3U to "3",
Expand Down Expand Up @@ -276,7 +281,7 @@ public class BigInt internal constructor(
}

result[i] = (res and BASE).toUInt()
carry = (res shr BASE_SIZE)
carry = res shr BASE_SIZE
}

result[resultLength - 1] = carry.toUInt()
Expand Down Expand Up @@ -318,7 +323,14 @@ public class BigInt internal constructor(
return stripLeadingZeros(result)
}

private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
internal fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude = when {
mag1.size + mag2.size < KARATSUBA_THRESHOLD || mag1.isEmpty() || mag2.isEmpty() ->
naiveMultiplyMagnitudes(mag1, mag2)
// TODO implement Fourier
else -> karatsubaMultiplyMagnitudes(mag1, mag2)
}

internal fun naiveMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
val resultLength = mag1.size + mag2.size
val result = Magnitude(resultLength)

Expand All @@ -337,6 +349,21 @@ public class BigInt internal constructor(
return stripLeadingZeros(result)
}

internal fun karatsubaMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
//https://en.wikipedia.org/wiki/Karatsuba_algorithm
val halfSize = min(mag1.size, mag2.size) / 2
val x0 = mag1.sliceArray(0 until halfSize).toBigInt(1)
val x1 = mag1.sliceArray(halfSize until mag1.size).toBigInt(1)
val y0 = mag2.sliceArray(0 until halfSize).toBigInt(1)
val y1 = mag2.sliceArray(halfSize until mag2.size).toBigInt(1)

val z0 = x0 * y0
val z2 = x1 * y1
val z1 = (x0 - x1) * (y1 - y0) + z0 + z2

return (z2.shl(2 * halfSize * BASE_SIZE) + z1.shl(halfSize * BASE_SIZE) + z0).magnitude
}

private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
val resultLength = mag.size
val result = Magnitude(resultLength)
Expand Down Expand Up @@ -414,58 +441,90 @@ public fun UIntArray.toBigInt(sign: Byte): BigInt {
return BigInt(sign, copyOf())
}

private val hexChToInt: MutableMap<Char, Int> = hashMapOf(
'0' to 0, '1' to 1, '2' to 2, '3' to 3,
'4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15
)

/**
* Returns null if a valid number can not be read from a string
*/
public fun String.parseBigInteger(): BigInt? {
if (this.isEmpty()) return null
val sign: Int
val sPositive: String

when {
this[0] == '+' -> {
val positivePartIndex = when (this[0]) {
'+' -> {
sign = +1
sPositive = this.substring(1)
1
}
this[0] == '-' -> {
'-' -> {
sign = -1
sPositive = this.substring(1)
1
}
else -> {
sPositive = this
sign = +1
0
}
}

var res = BigInt.ZERO
var digitValue = BigInt.ONE
val sPositiveUpper = sPositive.uppercase()
var isEmpty = true

if (sPositiveUpper.startsWith("0X")) { // hex representation
val sHex = sPositiveUpper.substring(2)
return if (this.startsWith("0X", startIndex = positivePartIndex, ignoreCase = true)) {
// hex representation

val uInts = ArrayList<UInt>(length).apply { add(0U) }
var offset = 0
fun addDigit(value: UInt) {
uInts[uInts.lastIndex] += value shl offset
offset += 4
if (offset == 32) {
uInts.add(0U)
offset = 0
}
}

for (ch in sHex.reversed()) {
if (ch == '_') continue
res += digitValue * (hexChToInt[ch] ?: return null)
digitValue *= 16.toBigInt()
for (index in lastIndex downTo positivePartIndex + 2) {
when (val ch = this[index]) {
'_' -> continue
in '0'..'9' -> addDigit((ch - '0').toUInt())
in 'A'..'F' -> addDigit((ch - 'A').toUInt() + 10U)
in 'a'..'f' -> addDigit((ch - 'a').toUInt() + 10U)
else -> return null
}
isEmpty = false
}
} else for (ch in sPositiveUpper.reversed()) {

while (uInts.isNotEmpty() && uInts.last() == 0U)
uInts.removeLast()

if (isEmpty) null else BigInt(sign.toByte(), uInts.toUIntArray())
} else {
// decimal representation
if (ch == '_') continue
if (ch !in '0'..'9') {
return null

val positivePart = buildList(length) {
for (index in positivePartIndex until length)
when (val a = this@parseBigInteger[index]) {
'_' -> continue
in '0'..'9' -> add(a)
else -> return null
}
}
res += digitValue * (ch.code - '0'.code)
digitValue *= 10.toBigInt()
}

return res * sign
val offset = positivePart.size % 9
isEmpty = offset == 0

fun parseUInt(fromIndex: Int, toIndex: Int): UInt? {
var res = 0U
for (i in fromIndex until toIndex) {
res = res * 10U + (positivePart[i].digitToIntOrNull()?.toUInt() ?: return null)
}
return res
}

var res = parseUInt(0, offset)?.toBigInt() ?: return null

for (index in offset..positivePart.lastIndex step 9) {
isEmpty = false
res = res * 1_000_000_000U + (parseUInt(index, index + 9) ?: return null).toBigInt()
}
if (isEmpty) null else res * sign
}
}

public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public interface PowerOperations<T> : Algebra<T> {
}

/**
* Raises this element to the power [pow].
* Raises this element to the power [power].
*
* @receiver the base.
* @param power the exponent.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,34 +97,45 @@ public fun <T, S> Sequence<T>.averageWith(space: S): T where S : Ring<T>, S : Sc
//TODO optimized power operation

/**
* Raises [arg] to the natural power [power].
* Raises [arg] to the non-negative integer power [power].
*
* Special case: 0 ^ 0 is 1.
*
* @receiver the algebra to provide multiplication.
* @param arg the base.
* @param power the exponent.
* @return the base raised to the power.
* @author Evgeniy Zhelenskiy
*/
public fun <T> Ring<T>.power(arg: T, power: Int): T {
require(power >= 0) { "The power can't be negative." }
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one
var res = arg
repeat(power - 1) { res *= arg }
return res
public fun <T> Ring<T>.power(arg: T, power: UInt): T = when {
arg == zero && power > 0U -> zero
arg == one -> arg
arg == -one -> powWithoutOptimization(arg, power % 2U)
else -> powWithoutOptimization(arg, power)
}

private fun <T> Ring<T>.powWithoutOptimization(base: T, exponent: UInt): T = when (exponent) {
0U -> one
1U -> base
else -> {
val pre = powWithoutOptimization(base, exponent shr 1).let { it * it }
if (exponent and 1U == 0U) pre else pre * base
}
}


/**
* Raises [arg] to the integer power [power].
*
* Special case: 0 ^ 0 is 1.
*
* @receiver the algebra to provide multiplication and division.
* @param arg the base.
* @param power the exponent.
* @return the base raised to the power.
* @author Iaroslav Postovalov
* @author Iaroslav Postovalov, Evgeniy Zhelenskiy
*/
public fun <T> Field<T>.power(arg: T, power: Int): T {
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one
if (power < 0) return one / (this as Ring<T>).power(arg, -power)
return (this as Ring<T>).power(arg, power)
public fun <T> Field<T>.power(arg: T, power: UInt): T = when {
power < 0 -> one / (this as Ring<T>).power(arg, power)
else -> (this as Ring<T>).power(arg, power)
}
Loading

0 comments on commit c1b94ff

Please sign in to comment.