Skip to content

Commit

Permalink
[SPARK-49506][SQL] Optimize ArrayBinarySearch for foldable array
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
The pr aims to
- optimize `ArrayBinarySearch` for `foldable` array.
- fix a bug in the original implementation.

### Why are the changes needed?
The changes improve performance of the `array_binary_search()` function.
- create an instance of `foldable{DataType}ArrayData` only once at the initialization ( avoid frequent calls to `ArrayData.to{DataType}Array()` ), and reuse it inside of `replacement` in the case when the `array` parameter is foldable.

Before:
```
Running benchmark: array binary search
  Running case: no foldable optimize
  Stopped after 100 iterations, 93668 ms

OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 14.6.1
Apple M2
array binary search:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
no foldable optimize                                916            937          24         10.9          91.6       1.0X
```

After:
```
Running benchmark: array binary search
  Running case: has foldable optimize
  Stopped after 100 iterations, 17206 ms

OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 14.6.1
Apple M2
array binary search:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
has foldable optimize                               164            172          22         61.1          16.4       1.0X
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- Update existed UT.
- Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48225 from panbingkun/SPARK-49506_FOLLOWUP.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
panbingkun authored and cloud-fan committed Oct 29, 2024
1 parent 4d30048 commit 6a36c43
Show file tree
Hide file tree
Showing 6 changed files with 534 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,158 +19,190 @@
import java.util.Arrays;
import java.util.Comparator;

import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.SQLOrderingUtil;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.BooleanType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.ShortType$;

public class ArrayExpressionUtils {

private static final Comparator<Object> booleanComp = (o1, o2) -> {
// comparator
// Boolean ascending nullable comparator
private static final Comparator<Boolean> booleanComp = (o1, o2) -> {
if (o1 == null && o2 == null) {
return 0;
} else if (o1 == null) {
return -1;
} else if (o2 == null) {
return 1;
}
boolean c1 = (Boolean) o1, c2 = (Boolean) o2;
return c1 == c2 ? 0 : (c1 ? 1 : -1);
return o1.equals(o2) ? 0 : (o1 ? 1 : -1);
};

private static final Comparator<Object> byteComp = (o1, o2) -> {
// Byte ascending nullable comparator
private static final Comparator<Byte> byteComp = (o1, o2) -> {
if (o1 == null && o2 == null) {
return 0;
} else if (o1 == null) {
return -1;
} else if (o2 == null) {
return 1;
}
byte c1 = (Byte) o1, c2 = (Byte) o2;
return Byte.compare(c1, c2);
return Byte.compare(o1, o2);
};

private static final Comparator<Object> shortComp = (o1, o2) -> {
// Short ascending nullable comparator
private static final Comparator<Short> shortComp = (o1, o2) -> {
if (o1 == null && o2 == null) {
return 0;
} else if (o1 == null) {
return -1;
} else if (o2 == null) {
return 1;
}
short c1 = (Short) o1, c2 = (Short) o2;
return Short.compare(c1, c2);
return Short.compare(o1, o2);
};

private static final Comparator<Object> integerComp = (o1, o2) -> {
// Integer ascending nullable comparator
private static final Comparator<Integer> integerComp = (o1, o2) -> {
if (o1 == null && o2 == null) {
return 0;
} else if (o1 == null) {
return -1;
} else if (o2 == null) {
return 1;
}
int c1 = (Integer) o1, c2 = (Integer) o2;
return Integer.compare(c1, c2);
return Integer.compare(o1, o2);
};

private static final Comparator<Object> longComp = (o1, o2) -> {
// Long ascending nullable comparator
private static final Comparator<Long> longComp = (o1, o2) -> {
if (o1 == null && o2 == null) {
return 0;
} else if (o1 == null) {
return -1;
} else if (o2 == null) {
return 1;
}
long c1 = (Long) o1, c2 = (Long) o2;
return Long.compare(c1, c2);
return Long.compare(o1, o2);
};

private static final Comparator<Object> floatComp = (o1, o2) -> {
// Float ascending nullable comparator
private static final Comparator<Float> floatComp = (o1, o2) -> {
if (o1 == null && o2 == null) {
return 0;
} else if (o1 == null) {
return -1;
} else if (o2 == null) {
return 1;
}
float c1 = (Float) o1, c2 = (Float) o2;
return SQLOrderingUtil.compareFloats(c1, c2);
return SQLOrderingUtil.compareFloats(o1, o2);
};

private static final Comparator<Object> doubleComp = (o1, o2) -> {
// Double ascending nullable comparator
private static final Comparator<Double> doubleComp = (o1, o2) -> {
if (o1 == null && o2 == null) {
return 0;
} else if (o1 == null) {
return -1;
} else if (o2 == null) {
return 1;
}
double c1 = (Double) o1, c2 = (Double) o2;
return SQLOrderingUtil.compareDoubles(c1, c2);
return SQLOrderingUtil.compareDoubles(o1, o2);
};

public static int binarySearchNullSafe(ArrayData data, Boolean value) {
return Arrays.binarySearch(data.toObjectArray(BooleanType$.MODULE$), value, booleanComp);
// boolean
// boolean non-nullable
public static int binarySearch(boolean[] data, boolean value) {
int low = 0;
int high = data.length - 1;

while (low <= high) {
int mid = (low + high) >>> 1;
boolean midVal = data[mid];

if (value == midVal) {
return mid; // key found
} else if (value) {
low = mid + 1;
} else {
high = mid - 1;
}
}

return -(low + 1); // key not found.
}

// Boolean nullable
public static int binarySearch(Boolean[] data, Boolean value) {
return Arrays.binarySearch(data, value, booleanComp);
}

public static int binarySearch(ArrayData data, byte value) {
return Arrays.binarySearch(data.toByteArray(), value);
// byte
// byte non-nullable
public static int binarySearch(byte[] data, byte value) {
return Arrays.binarySearch(data, value);
}

public static int binarySearchNullSafe(ArrayData data, Byte value) {
return Arrays.binarySearch(data.toObjectArray(ByteType$.MODULE$), value, byteComp);
// Byte nullable
public static int binarySearch(Byte[] data, Byte value) {
return Arrays.binarySearch(data, value, byteComp);
}

public static int binarySearch(ArrayData data, short value) {
return Arrays.binarySearch(data.toShortArray(), value);
// short
// short non-nullable
public static int binarySearch(short[] data, short value) {
return Arrays.binarySearch(data, value);
}

public static int binarySearchNullSafe(ArrayData data, Short value) {
return Arrays.binarySearch(data.toObjectArray(ShortType$.MODULE$), value, shortComp);
// Short nullable
public static int binarySearch(Short[] data, Short value) {
return Arrays.binarySearch(data, value, shortComp);
}

public static int binarySearch(ArrayData data, int value) {
return Arrays.binarySearch(data.toIntArray(), value);
// int
// int non-nullable
public static int binarySearch(int[] data, int value) {
return Arrays.binarySearch(data, value);
}

public static int binarySearchNullSafe(ArrayData data, Integer value) {
return Arrays.binarySearch(data.toObjectArray(IntegerType$.MODULE$), value, integerComp);
// Integer nullable
public static int binarySearch(Integer[] data, Integer value) {
return Arrays.binarySearch(data, value, integerComp);
}

public static int binarySearch(ArrayData data, long value) {
return Arrays.binarySearch(data.toLongArray(), value);
// long
// long non-nullable
public static int binarySearch(long[] data, long value) {
return Arrays.binarySearch(data, value);
}

public static int binarySearchNullSafe(ArrayData data, Long value) {
return Arrays.binarySearch(data.toObjectArray(LongType$.MODULE$), value, longComp);
// Long nullable
public static int binarySearch(Long[] data, Long value) {
return Arrays.binarySearch(data, value, longComp);
}

public static int binarySearch(ArrayData data, float value) {
return Arrays.binarySearch(data.toFloatArray(), value);
// float
// float non-nullable
public static int binarySearch(float[] data, float value) {
return Arrays.binarySearch(data, value);
}

public static int binarySearchNullSafe(ArrayData data, Float value) {
return Arrays.binarySearch(data.toObjectArray(FloatType$.MODULE$), value, floatComp);
// Float nullable
public static int binarySearch(Float[] data, Float value) {
return Arrays.binarySearch(data, value, floatComp);
}

public static int binarySearch(ArrayData data, double value) {
return Arrays.binarySearch(data.toDoubleArray(), value);
// double
// double non-nullable
public static int binarySearch(double[] data, double value) {
return Arrays.binarySearch(data, value);
}

public static int binarySearchNullSafe(ArrayData data, Double value) {
return Arrays.binarySearch(data.toObjectArray(DoubleType$.MODULE$), value, doubleComp);
// Double nullable
public static int binarySearch(Double[] data, Double value) {
return Arrays.binarySearch(data, value, doubleComp);
}

public static int binarySearch(
DataType elementType, Comparator<Object> comp, ArrayData data, Object value) {
Object[] array = data.toObjectArray(elementType);
return Arrays.binarySearch(array, value, comp);
// Object
public static int binarySearch(Object[] data, Object value, Comparator<Object> comp) {
return Arrays.binarySearch(data, value, comp);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions;

import scala.reflect.ClassTag$;

import org.apache.spark.sql.catalyst.util.ArrayData;

import static org.apache.spark.sql.types.DataTypes.BooleanType;
import static org.apache.spark.sql.types.DataTypes.ByteType;
import static org.apache.spark.sql.types.DataTypes.DoubleType;
import static org.apache.spark.sql.types.DataTypes.FloatType;
import static org.apache.spark.sql.types.DataTypes.IntegerType;
import static org.apache.spark.sql.types.DataTypes.LongType;
import static org.apache.spark.sql.types.DataTypes.ShortType;

public class ToJavaArrayUtils {

// boolean
// boolean non-nullable
public static boolean[] toBooleanArray(ArrayData arrayData) {
return arrayData.toBooleanArray();
}

// Boolean nullable
public static Boolean[] toBoxedBooleanArray(ArrayData arrayData) {
return (Boolean[]) arrayData.toArray(BooleanType,
ClassTag$.MODULE$.apply(java.lang.Boolean.class));
}

// byte
// byte non-nullable
public static byte[] toByteArray(ArrayData arrayData) {
return arrayData.toByteArray();
}

// Byte nullable
public static Byte[] toBoxedByteArray(ArrayData arrayData) {
return (Byte[]) arrayData.toArray(ByteType, ClassTag$.MODULE$.apply(java.lang.Byte.class));
}

// short
// short non-nullable
public static short[] toShortArray(ArrayData arrayData) {
return arrayData.toShortArray();
}

// Short nullable
public static Short[] toBoxedShortArray(ArrayData arrayData) {
return (Short[]) arrayData.toArray(ShortType, ClassTag$.MODULE$.apply(java.lang.Short.class));
}

// int
// int non-nullable
public static int[] toIntegerArray(ArrayData arrayData) {
return arrayData.toIntArray();
}

// Integer nullable
public static Integer[] toBoxedIntegerArray(ArrayData arrayData) {
return (Integer[]) arrayData.toArray(IntegerType,
ClassTag$.MODULE$.apply(java.lang.Integer.class));
}

// long
// long non-nullable
public static long[] toLongArray(ArrayData arrayData) {
return arrayData.toLongArray();
}

// Long nullable
public static Long[] toBoxedLongArray(ArrayData arrayData) {
return (Long[]) arrayData.toArray(LongType, ClassTag$.MODULE$.apply(java.lang.Long.class));
}

// float
// float non-nullable
public static float[] toFloatArray(ArrayData arrayData) {
return arrayData.toFloatArray();
}

// Float nullable
public static Float[] toBoxedFloatArray(ArrayData arrayData) {
return (Float[]) arrayData.toArray(FloatType, ClassTag$.MODULE$.apply(java.lang.Float.class));
}

// double
// double non-nullable
public static double[] toDoubleArray(ArrayData arrayData) {
return arrayData.toDoubleArray();
}

// Double nullable
public static Double[] toBoxedDoubleArray(ArrayData arrayData) {
return (Double[]) arrayData.toArray(DoubleType,
ClassTag$.MODULE$.apply(java.lang.Double.class));
}
}
Loading

0 comments on commit 6a36c43

Please sign in to comment.