Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Karatsuba added, 2 bugs are fixed #328

Merged
merged 15 commits into from
May 14, 2021
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,19 @@ import kotlinx.benchmark.Blackhole
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import space.kscience.kmath.operations.BigInt
import space.kscience.kmath.operations.BigIntField
import space.kscience.kmath.operations.JBigIntegerField
import space.kscience.kmath.operations.invoke

private fun BigInt.pow(power: Int): BigInt = modPow(BigIntField.number(power), BigInt.ONE)
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import java.math.BigInteger


@UnstableKMathAPI
@State(Scope.Benchmark)
internal class BigIntBenchmark {

val kmNumber = BigIntField.number(Int.MAX_VALUE)
val jvmNumber = JBigIntegerField.number(Int.MAX_VALUE)
val largeKmNumber = BigIntField { number(11).pow(100_000) }
val largeJvmNumber = JBigIntegerField { number(11).pow(100_000) }
val largeKmNumber = BigIntField { number(11).pow(100_000UL) }
val largeJvmNumber: BigInteger = JBigIntegerField { number(11).pow(100_000) }
val bigExponent = 50_000

@Benchmark
Expand All @@ -37,6 +35,16 @@ internal class BigIntBenchmark {
blackhole.consume(jvmNumber + jvmNumber + jvmNumber)
}

@Benchmark
fun kmAddLarge(blackhole: Blackhole) = BigIntField {
blackhole.consume(largeKmNumber + largeKmNumber + largeKmNumber)
}

@Benchmark
fun jvmAddLarge(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume(largeJvmNumber + largeJvmNumber + largeJvmNumber)
}

@Benchmark
fun kmMultiply(blackhole: Blackhole) = BigIntField {
blackhole.consume(kmNumber * kmNumber * kmNumber)
Expand All @@ -59,11 +67,31 @@ internal class BigIntBenchmark {

@Benchmark
fun kmPower(blackhole: Blackhole) = BigIntField {
blackhole.consume(kmNumber.pow(bigExponent))
blackhole.consume(kmNumber.pow(bigExponent.toULong()))
}

@Benchmark
fun jvmPower(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume(jvmNumber.pow(bigExponent))
}

@Benchmark
fun kmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("0x7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".parseBigInteger())
}

@Benchmark
fun kmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".parseBigInteger())
}

@Benchmark
fun jvmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".toBigInteger(10))
}

@Benchmark
fun jvmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".toBigInteger(16))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package space.kscience.kmath.operations

import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI

/**
* Stub for DSL the [Algebra] is.
Expand Down Expand Up @@ -247,11 +248,31 @@ public interface RingOperations<T> : GroupOperations<T> {
*/
public interface Ring<T> : Group<T>, RingOperations<T> {
/**
* neutral operation for multiplication
* The neutral element of multiplication
*/
public val one: T
}

@UnstableKMathAPI
public fun <T> Ring<T>.pow(base: T, exponent: ULong): T = when {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add KDoc

this == zero && exponent > 0UL -> zero
this == one -> base
this == -one -> powWithoutOptimization(base, exponent % 2UL)
else -> powWithoutOptimization(base, exponent)
}

@UnstableKMathAPI
public fun <T> Ring<T>.pow(base: T, exponent: UInt): T = pow(base, exponent.toULong())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add KDoc


private fun <T> Ring<T>.powWithoutOptimization(base: T, exponent: ULong): T = when (exponent) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move pow operation to algebraExtensions.kt. It actually already have those functions: https://github.com/mipt-npm/kmath/blob/ca02f5406d06d18d07909f3d48508e1a99701fa5/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/algebraExtensions.kt#L107, so you shoud probably just replace the implementation and use UInt instead of Int.

Copy link
Contributor Author

@zhelenskiy zhelenskiy May 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not ULong? Because of being consistent with Field<T>.power which has a stable API?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I see a bug with Field.power:

public fun <T> Field<T>.power(arg: T, power: Int): T {
    require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
    if (power == 0) return one
    if (power < 0) return one / (this as Ring<T>).power(arg, -power)
    return (this as Ring<T>).power(arg, power)
}

If [power] is Int.MIN_VALUE, this will fail

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation could be wrong in corner cases. It is not properly tested. As for ULong, the major bulk of operations on JVM and in kotlin use Int. I can't see a case where people would use more than Int for a power (please show them if you know them) and using Ulong would force users to do an additional unnecessary conversion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am absolutely sure that there is no need to power numbers by ULong ranged numbers, however, this may be ok for some other Rings (albeit I don't know such ones).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, 0^0 is not defined here while it is 1 in java's BigInt. Which convention to choose?

public fun <T> Ring<T>.power(arg: T, power: Int): T {
    require(power >= 0) { "The power can't be negative." }
    require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
    if (power == 0) return one
    var res = arg
    repeat(power - 1) { res *= arg }
    return res
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's stick to UInt for now because I still can't imagine a case for longs. Since we are using extensions, it is easy to add another one if somebody needs it. As for 0^0, please fix it. I don't remember where this method comes from, but it was quite obviously forgotten.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean that 0^0 is expected to be 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I published the changes.

0UL -> one
1UL -> base
else -> {
val pre = powWithoutOptimization(base, exponent shr 1).let { it * it }
if (exponent and 1UL == 0UL) pre else pre * base
}
}

/**
* Represents field without without multiplicative and additive identities, i.e. algebraic structure with associative, binary, commutative operations
* [add] and [multiply]; binary operation [divide] as multiplication of left operand by reciprocal of right one.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ public class BigInt internal constructor(
else -> sign * compareMagnitudes(magnitude, other.magnitude)
}

public override fun equals(other: Any?): Boolean =
if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")
public override fun equals(other: Any?): Boolean = other is BigInt && compareTo(other) == 0

public override fun hashCode(): Int = magnitude.hashCode() + sign

Expand Down Expand Up @@ -87,20 +86,29 @@ public class BigInt internal constructor(
public operator fun times(b: BigInt): BigInt = when {
this.sign == 0.toByte() -> ZERO
b.sign == 0.toByte() -> ZERO
// TODO: Karatsuba
b.magnitude.size == 1 -> this * b.magnitude[0] * b.sign.toInt()
this.magnitude.size == 1 -> b * this.magnitude[0] * this.sign.toInt()
else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude))
}

public operator fun times(other: UInt): BigInt = when {
sign == 0.toByte() -> ZERO
other == 0U -> ZERO
other == 1U -> this
else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other))
}

public operator fun times(other: Int): BigInt = if (other > 0)
this * kotlin.math.abs(other).toUInt()
else
-this * kotlin.math.abs(other).toUInt()
@UnstableKMathAPI
public fun pow(other: ULong): BigInt = BigIntField { pow(this@BigInt, other) }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BigIntField.pow(this, other) should be more concise.


@UnstableKMathAPI
public fun pow(other: UInt): BigInt = BigIntField { pow(this@BigInt, other) }

public operator fun times(other: Int): BigInt = when {
other > 0 -> this * kotlin.math.abs(other).toUInt()
other != Int.MIN_VALUE -> -this * kotlin.math.abs(other).toUInt()
else -> times(other.toBigInt())
}

public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other))

Expand Down Expand Up @@ -238,6 +246,7 @@ public class BigInt internal constructor(
public const val BASE_SIZE: Int = 32
public val ZERO: BigInt = BigInt(0, uintArrayOf())
public val ONE: BigInt = BigInt(1, uintArrayOf(1u))
private const val KARATSUBA_THRESHOLD = 80

private val hexMapping: HashMap<UInt, String> = hashMapOf(
0U to "0", 1U to "1", 2U to "2", 3U to "3",
Expand Down Expand Up @@ -276,7 +285,7 @@ public class BigInt internal constructor(
}

result[i] = (res and BASE).toUInt()
carry = (res shr BASE_SIZE)
carry = res shr BASE_SIZE
}

result[resultLength - 1] = carry.toUInt()
Expand Down Expand Up @@ -318,7 +327,14 @@ public class BigInt internal constructor(
return stripLeadingZeros(result)
}

private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
internal fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude = when {
mag1.size + mag2.size < KARATSUBA_THRESHOLD || mag1.isEmpty() || mag2.isEmpty() ->
naiveMultiplyMagnitudes(mag1, mag2)
// TODO implement Fourier
else -> karatsubaMultiplyMagnitudes(mag1, mag2)
}

internal fun naiveMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
val resultLength = mag1.size + mag2.size
val result = Magnitude(resultLength)

Expand All @@ -337,6 +353,21 @@ public class BigInt internal constructor(
return stripLeadingZeros(result)
}

internal fun karatsubaMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
//https://en.wikipedia.org/wiki/Karatsuba_algorithm
val halfSize = min(mag1.size, mag2.size) / 2
val x0 = mag1.sliceArray(0 until halfSize).toBigInt(1)
val x1 = mag1.sliceArray(halfSize until mag1.size).toBigInt(1)
val y0 = mag2.sliceArray(0 until halfSize).toBigInt(1)
val y1 = mag2.sliceArray(halfSize until mag2.size).toBigInt(1)

val z0 = x0 * y0
val z2 = x1 * y1
val z1 = (x0 - x1) * (y1 - y0) + z0 + z2

return (z2.shl(2 * halfSize * BASE_SIZE) + z1.shl(halfSize * BASE_SIZE) + z0).magnitude
}

private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
val resultLength = mag.size
val result = Magnitude(resultLength)
Expand Down Expand Up @@ -414,58 +445,90 @@ public fun UIntArray.toBigInt(sign: Byte): BigInt {
return BigInt(sign, copyOf())
}

private val hexChToInt: MutableMap<Char, Int> = hashMapOf(
'0' to 0, '1' to 1, '2' to 2, '3' to 3,
'4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15
)

/**
* Returns null if a valid number can not be read from a string
*/
public fun String.parseBigInteger(): BigInt? {
if (this.isEmpty()) return null
val sign: Int
val sPositive: String

when {
this[0] == '+' -> {
val positivePartIndex = when (this[0]) {
'+' -> {
sign = +1
sPositive = this.substring(1)
1
}
this[0] == '-' -> {
'-' -> {
sign = -1
sPositive = this.substring(1)
1
}
else -> {
sPositive = this
sign = +1
0
}
}

var res = BigInt.ZERO
var digitValue = BigInt.ONE
val sPositiveUpper = sPositive.uppercase()
var isEmpty = true

if (sPositiveUpper.startsWith("0X")) { // hex representation
val sHex = sPositiveUpper.substring(2)
return if (this.startsWith("0X", startIndex = positivePartIndex, ignoreCase = true)) {
// hex representation

for (ch in sHex.reversed()) {
if (ch == '_') continue
res += digitValue * (hexChToInt[ch] ?: return null)
digitValue *= 16.toBigInt()
val uInts = ArrayList<UInt>(length).apply { add(0U) }
var offset = 0
fun addDigit(value: UInt) {
uInts[uInts.lastIndex] += value shl offset
offset += 4
if (offset == 32) {
uInts.add(0U)
offset = 0
}
}
} else for (ch in sPositiveUpper.reversed()) {

for (index in lastIndex downTo positivePartIndex + 2) {
when (val ch = this[index]) {
'_' -> continue
in '0'..'9' -> addDigit((ch - '0').toUInt())
in 'A'..'F' -> addDigit((ch - 'A').toUInt() + 10U)
in 'a'..'f' -> addDigit((ch - 'a').toUInt() + 10U)
else -> return null
}
isEmpty = false
}

while (uInts.isNotEmpty() && uInts.last() == 0U)
uInts.removeLast()

if (isEmpty) null else BigInt(sign.toByte(), uInts.toUIntArray())
} else {
// decimal representation
if (ch == '_') continue
if (ch !in '0'..'9') {
return null

val positivePart = buildList(length) {
for (index in positivePartIndex until length)
when (val a = this@parseBigInteger[index]) {
'_' -> continue
in '0'..'9' -> add(a)
else -> return null
}
}
res += digitValue * (ch.code - '0'.code)
digitValue *= 10.toBigInt()
}

return res * sign
val offset = positivePart.size % 9
isEmpty = offset == 0

fun parseUInt(fromIndex: Int, toIndex: Int): UInt? {
var res = 0U
for (i in fromIndex until toIndex) {
res = res * 10U + (positivePart[i].digitToIntOrNull()?.toUInt() ?: return null)
}
return res
}

var res = parseUInt(0, offset)?.toBigInt() ?: return null

for (index in offset..positivePart.lastIndex step 9) {
isEmpty = false
res = res * 1_000_000_000U + (parseUInt(index, index + 9) ?: return null).toBigInt()
}
if (isEmpty) null else res * sign
}
}

public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public interface PowerOperations<T> : Algebra<T> {
}

/**
* Raises this element to the power [pow].
* Raises this element to the power [power].
*
* @receiver the base.
* @param power the exponent.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package space.kscience.kmath.operations

import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.testutils.RingVerifier
import kotlin.math.pow
import kotlin.test.Test
import kotlin.test.assertEquals

Expand All @@ -21,6 +23,24 @@ internal class BigIntAlgebraTest {
assertEquals(res, 1_000_000.toBigInt())
}

@UnstableKMathAPI
@Test
fun testKBigIntegerRingPow() {
for (num in -5..5)
for (exponent in 0U..10U) {
assertEquals(
num.toDouble().pow(exponent.toInt()).toLong().toBigInt(),
num.toBigInt().pow(exponent.toULong()),
"$num ^ $exponent"
)
assertEquals(
num.toDouble().pow(exponent.toInt()).toLong().toBigInt(),
num.toBigInt().pow(exponent),
"$num ^ $exponent"
)
}
}

@Test
fun testKBigIntegerRingSum_100_000_000__100_000_000() {
BigIntField {
Expand Down
Loading