Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Karatsuba added, 2 bugs are fixed #328

Merged
merged 15 commits into from
May 14, 2021
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,19 @@ import kotlinx.benchmark.Blackhole
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import space.kscience.kmath.operations.BigInt
import space.kscience.kmath.operations.BigIntField
import space.kscience.kmath.operations.JBigIntegerField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import java.math.BigInteger

private fun BigInt.pow(power: Int): BigInt = modPow(BigIntField.number(power), BigInt.ZERO)

@UnstableKMathAPI
@State(Scope.Benchmark)
internal class BigIntBenchmark {

val kmNumber = BigIntField.number(Int.MAX_VALUE)
val jvmNumber = JBigIntegerField.number(Int.MAX_VALUE)
val largeKmNumber = BigIntField { number(11).pow(100_000) }
val largeJvmNumber = JBigIntegerField { number(11).pow(100_000) }
val largeKmNumber = BigIntField { number(11).pow(100_000U) }
val largeJvmNumber: BigInteger = JBigIntegerField { number(11).pow(100_000) }
val bigExponent = 50_000

@Benchmark
Expand All @@ -36,6 +35,16 @@ internal class BigIntBenchmark {
blackhole.consume(jvmNumber + jvmNumber + jvmNumber)
}

@Benchmark
fun kmAddLarge(blackhole: Blackhole) = BigIntField {
blackhole.consume(largeKmNumber + largeKmNumber + largeKmNumber)
}

@Benchmark
fun jvmAddLarge(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume(largeJvmNumber + largeJvmNumber + largeJvmNumber)
}

@Benchmark
fun kmMultiply(blackhole: Blackhole) = BigIntField {
blackhole.consume(kmNumber * kmNumber * kmNumber)
Expand All @@ -56,13 +65,33 @@ internal class BigIntBenchmark {
blackhole.consume(largeJvmNumber*largeJvmNumber)
}

// @Benchmark
// fun kmPower(blackhole: Blackhole) = BigIntField {
// blackhole.consume(kmNumber.pow(bigExponent))
// }
//
// @Benchmark
// fun jvmPower(blackhole: Blackhole) = JBigIntegerField {
// blackhole.consume(jvmNumber.pow(bigExponent))
// }
@Benchmark
fun kmPower(blackhole: Blackhole) = BigIntField {
blackhole.consume(kmNumber.pow(bigExponent.toUInt()))
}

@Benchmark
fun jvmPower(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume(jvmNumber.pow(bigExponent))
}

@Benchmark
fun kmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("0x7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".parseBigInteger())
}

@Benchmark
fun kmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".parseBigInteger())
}

@Benchmark
fun jvmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".toBigInteger(10))
}

@Benchmark
fun jvmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".toBigInteger(16))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package space.kscience.kmath.operations

import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI

/**
* Stub for DSL the [Algebra] is.
Expand Down Expand Up @@ -247,7 +248,7 @@ public interface RingOperations<T> : GroupOperations<T> {
*/
public interface Ring<T> : Group<T>, RingOperations<T> {
/**
* neutral operation for multiplication
* The neutral element of multiplication
*/
public val one: T
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ public class BigInt internal constructor(
else -> sign * compareMagnitudes(magnitude, other.magnitude)
}

public override fun equals(other: Any?): Boolean =
if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")
public override fun equals(other: Any?): Boolean = other is BigInt && compareTo(other) == 0

public override fun hashCode(): Int = magnitude.hashCode() + sign

Expand Down Expand Up @@ -87,20 +86,25 @@ public class BigInt internal constructor(
public operator fun times(b: BigInt): BigInt = when {
this.sign == 0.toByte() -> ZERO
b.sign == 0.toByte() -> ZERO
// TODO: Karatsuba
b.magnitude.size == 1 -> this * b.magnitude[0] * b.sign.toInt()
this.magnitude.size == 1 -> b * this.magnitude[0] * this.sign.toInt()
else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude))
}

public operator fun times(other: UInt): BigInt = when {
sign == 0.toByte() -> ZERO
other == 0U -> ZERO
other == 1U -> this
else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other))
}

public operator fun times(other: Int): BigInt = if (other > 0)
this * kotlin.math.abs(other).toUInt()
else
-this * kotlin.math.abs(other).toUInt()
public fun pow(exponent: UInt): BigInt = BigIntField.power(this@BigInt, exponent)

public operator fun times(other: Int): BigInt = when {
other > 0 -> this * kotlin.math.abs(other).toUInt()
other != Int.MIN_VALUE -> -this * kotlin.math.abs(other).toUInt()
else -> times(other.toBigInt())
}

public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other))

Expand Down Expand Up @@ -238,6 +242,7 @@ public class BigInt internal constructor(
public const val BASE_SIZE: Int = 32
public val ZERO: BigInt = BigInt(0, uintArrayOf())
public val ONE: BigInt = BigInt(1, uintArrayOf(1u))
private const val KARATSUBA_THRESHOLD = 80

private val hexMapping: HashMap<UInt, String> = hashMapOf(
0U to "0", 1U to "1", 2U to "2", 3U to "3",
Expand Down Expand Up @@ -276,7 +281,7 @@ public class BigInt internal constructor(
}

result[i] = (res and BASE).toUInt()
carry = (res shr BASE_SIZE)
carry = res shr BASE_SIZE
}

result[resultLength - 1] = carry.toUInt()
Expand Down Expand Up @@ -318,7 +323,14 @@ public class BigInt internal constructor(
return stripLeadingZeros(result)
}

private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
internal fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude = when {
mag1.size + mag2.size < KARATSUBA_THRESHOLD || mag1.isEmpty() || mag2.isEmpty() ->
naiveMultiplyMagnitudes(mag1, mag2)
// TODO implement Fourier
else -> karatsubaMultiplyMagnitudes(mag1, mag2)
}

internal fun naiveMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
val resultLength = mag1.size + mag2.size
val result = Magnitude(resultLength)

Expand All @@ -337,6 +349,21 @@ public class BigInt internal constructor(
return stripLeadingZeros(result)
}

internal fun karatsubaMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
//https://en.wikipedia.org/wiki/Karatsuba_algorithm
val halfSize = min(mag1.size, mag2.size) / 2
val x0 = mag1.sliceArray(0 until halfSize).toBigInt(1)
val x1 = mag1.sliceArray(halfSize until mag1.size).toBigInt(1)
val y0 = mag2.sliceArray(0 until halfSize).toBigInt(1)
val y1 = mag2.sliceArray(halfSize until mag2.size).toBigInt(1)

val z0 = x0 * y0
val z2 = x1 * y1
val z1 = (x0 - x1) * (y1 - y0) + z0 + z2

return (z2.shl(2 * halfSize * BASE_SIZE) + z1.shl(halfSize * BASE_SIZE) + z0).magnitude
}

private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
val resultLength = mag.size
val result = Magnitude(resultLength)
Expand Down Expand Up @@ -414,58 +441,90 @@ public fun UIntArray.toBigInt(sign: Byte): BigInt {
return BigInt(sign, copyOf())
}

private val hexChToInt: MutableMap<Char, Int> = hashMapOf(
'0' to 0, '1' to 1, '2' to 2, '3' to 3,
'4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15
)

/**
* Returns null if a valid number can not be read from a string
*/
public fun String.parseBigInteger(): BigInt? {
if (this.isEmpty()) return null
val sign: Int
val sPositive: String

when {
this[0] == '+' -> {
val positivePartIndex = when (this[0]) {
'+' -> {
sign = +1
sPositive = this.substring(1)
1
}
this[0] == '-' -> {
'-' -> {
sign = -1
sPositive = this.substring(1)
1
}
else -> {
sPositive = this
sign = +1
0
}
}

var res = BigInt.ZERO
var digitValue = BigInt.ONE
val sPositiveUpper = sPositive.uppercase()
var isEmpty = true

if (sPositiveUpper.startsWith("0X")) { // hex representation
val sHex = sPositiveUpper.substring(2)
return if (this.startsWith("0X", startIndex = positivePartIndex, ignoreCase = true)) {
// hex representation

val uInts = ArrayList<UInt>(length).apply { add(0U) }
var offset = 0
fun addDigit(value: UInt) {
uInts[uInts.lastIndex] += value shl offset
offset += 4
if (offset == 32) {
uInts.add(0U)
offset = 0
}
}

for (ch in sHex.reversed()) {
if (ch == '_') continue
res += digitValue * (hexChToInt[ch] ?: return null)
digitValue *= 16.toBigInt()
for (index in lastIndex downTo positivePartIndex + 2) {
when (val ch = this[index]) {
'_' -> continue
in '0'..'9' -> addDigit((ch - '0').toUInt())
in 'A'..'F' -> addDigit((ch - 'A').toUInt() + 10U)
in 'a'..'f' -> addDigit((ch - 'a').toUInt() + 10U)
else -> return null
}
isEmpty = false
}
} else for (ch in sPositiveUpper.reversed()) {

while (uInts.isNotEmpty() && uInts.last() == 0U)
uInts.removeLast()

if (isEmpty) null else BigInt(sign.toByte(), uInts.toUIntArray())
} else {
// decimal representation
if (ch == '_') continue
if (ch !in '0'..'9') {
return null

val positivePart = buildList(length) {
for (index in positivePartIndex until length)
when (val a = this@parseBigInteger[index]) {
'_' -> continue
in '0'..'9' -> add(a)
else -> return null
}
}
res += digitValue * (ch.code - '0'.code)
digitValue *= 10.toBigInt()
}

return res * sign
val offset = positivePart.size % 9
isEmpty = offset == 0

fun parseUInt(fromIndex: Int, toIndex: Int): UInt? {
var res = 0U
for (i in fromIndex until toIndex) {
res = res * 10U + (positivePart[i].digitToIntOrNull()?.toUInt() ?: return null)
}
return res
}

var res = parseUInt(0, offset)?.toBigInt() ?: return null

for (index in offset..positivePart.lastIndex step 9) {
isEmpty = false
res = res * 1_000_000_000U + (parseUInt(index, index + 9) ?: return null).toBigInt()
}
if (isEmpty) null else res * sign
}
}

public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public interface PowerOperations<T> : Algebra<T> {
}

/**
* Raises this element to the power [pow].
* Raises this element to the power [power].
*
* @receiver the base.
* @param power the exponent.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,34 +97,45 @@ public fun <T, S> Sequence<T>.averageWith(space: S): T where S : Ring<T>, S : Sc
//TODO optimized power operation

/**
* Raises [arg] to the natural power [power].
* Raises [arg] to the non-negative integer power [power].
*
* Special case: 0 ^ 0 is 1.
*
* @receiver the algebra to provide multiplication.
* @param arg the base.
* @param power the exponent.
* @return the base raised to the power.
* @author Evgeniy Zhelenskiy
*/
public fun <T> Ring<T>.power(arg: T, power: Int): T {
require(power >= 0) { "The power can't be negative." }
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one
var res = arg
repeat(power - 1) { res *= arg }
return res
public fun <T> Ring<T>.power(arg: T, power: UInt): T = when {
arg == zero && power > 0U -> zero
arg == one -> arg
arg == -one -> powWithoutOptimization(arg, power % 2U)
else -> powWithoutOptimization(arg, power)
}

private fun <T> Ring<T>.powWithoutOptimization(base: T, exponent: UInt): T = when (exponent) {
0U -> one
1U -> base
else -> {
val pre = powWithoutOptimization(base, exponent shr 1).let { it * it }
if (exponent and 1U == 0U) pre else pre * base
}
}


/**
* Raises [arg] to the integer power [power].
*
* Special case: 0 ^ 0 is 1.
*
* @receiver the algebra to provide multiplication and division.
* @param arg the base.
* @param power the exponent.
* @return the base raised to the power.
* @author Iaroslav Postovalov
* @author Iaroslav Postovalov, Evgeniy Zhelenskiy
*/
public fun <T> Field<T>.power(arg: T, power: Int): T {
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one
if (power < 0) return one / (this as Ring<T>).power(arg, -power)
return (this as Ring<T>).power(arg, power)
public fun <T> Field<T>.power(arg: T, power: UInt): T = when {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@altavir You are wrong here: Field can have negative power as Field has division. That was even in the previous implementation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is reverted.

power < 0 -> one / (this as Ring<T>).power(arg, power)
else -> (this as Ring<T>).power(arg, power)
}
Loading