Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jsonExtractIndex support array of default values #12748

Merged
merged 2 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/
package org.apache.pinot.core.operator.transform.function;

import com.fasterxml.jackson.databind.JsonNode;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
Expand All @@ -27,6 +29,7 @@
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.segment.spi.index.reader.JsonIndexReader;
import org.apache.pinot.spi.data.FieldSpec.DataType;
import org.apache.pinot.spi.utils.JsonUtils;
import org.roaringbitmap.RoaringBitmap;


Expand Down Expand Up @@ -101,7 +104,24 @@ public void init(List<TransformFunction> arguments, Map<String, ColumnContext> c
if (!(fourthArgument instanceof LiteralTransformFunction)) {
throw new IllegalArgumentException("Default value must be a literal");
}
_defaultValue = dataType.convert(((LiteralTransformFunction) fourthArgument).getStringLiteral());

if (isSingleValue) {
_defaultValue = dataType.convert(((LiteralTransformFunction) fourthArgument).getStringLiteral());
} else {
try {
JsonNode mvArray = JsonUtils.stringToJsonNode(((LiteralTransformFunction) fourthArgument).getStringLiteral());
if (!mvArray.isArray()) {
throw new IllegalArgumentException("Default value must be a valid JSON array");
}
Object[] defaultValues = new Object[mvArray.size()];
for (int i = 0; i < mvArray.size(); i++) {
defaultValues[i] = dataType.convert(mvArray.get(i).asText());
}
_defaultValue = defaultValues;
} catch (IOException e) {
throw new IllegalArgumentException("Default value must be a valid JSON array");
}
}
}

String filterJsonPath = null;
Expand Down Expand Up @@ -267,6 +287,17 @@ public int[][] transformToIntValuesMV(ValueBlock valueBlock) {

for (int i = 0; i < numDocs; i++) {
String[] value = valuesFromIndex[i];
if (value.length == 0) {
if (_defaultValue != null) {
_intValuesMV[i] = new int[((Object[]) (_defaultValue)).length];
for (int j = 0; j < _intValuesMV[i].length; j++) {
_intValuesMV[i][j] = (int) ((Object[]) _defaultValue)[j];
}
continue;
}
throw new RuntimeException(
String.format("Illegal Json Path: [%s], for docId [%s]", _jsonPathString, valueBlock.getDocIds()[i]));
}
_intValuesMV[i] = new int[value.length];
for (int j = 0; j < value.length; j++) {
_intValuesMV[i][j] = Integer.parseInt(value[j]);
Expand All @@ -283,6 +314,17 @@ public long[][] transformToLongValuesMV(ValueBlock valueBlock) {
_valueToMatchingDocsMap);
for (int i = 0; i < numDocs; i++) {
String[] value = valuesFromIndex[i];
if (value.length == 0) {
if (_defaultValue != null) {
_longValuesMV[i] = new long[((Object[]) (_defaultValue)).length];
for (int j = 0; j < _longValuesMV[i].length; j++) {
_longValuesMV[i][j] = (long) ((Object[]) _defaultValue)[j];
}
continue;
}
throw new RuntimeException(
String.format("Illegal Json Path: [%s], for docId [%s]", _jsonPathString, valueBlock.getDocIds()[i]));
}
_longValuesMV[i] = new long[value.length];
for (int j = 0; j < value.length; j++) {
_longValuesMV[i][j] = Long.parseLong(value[j]);
Expand All @@ -299,6 +341,17 @@ public float[][] transformToFloatValuesMV(ValueBlock valueBlock) {
_valueToMatchingDocsMap);
for (int i = 0; i < numDocs; i++) {
String[] value = valuesFromIndex[i];
if (value.length == 0) {
if (_defaultValue != null) {
_floatValuesMV[i] = new float[((Object[]) (_defaultValue)).length];
for (int j = 0; j < _floatValuesMV[i].length; j++) {
_floatValuesMV[i][j] = (float) ((Object[]) _defaultValue)[j];
}
continue;
}
throw new RuntimeException(
String.format("Illegal Json Path: [%s], for docId [%s]", _jsonPathString, valueBlock.getDocIds()[i]));
}
_floatValuesMV[i] = new float[value.length];
for (int j = 0; j < value.length; j++) {
_floatValuesMV[i][j] = Float.parseFloat(value[j]);
Expand All @@ -315,6 +368,17 @@ public double[][] transformToDoubleValuesMV(ValueBlock valueBlock) {
_valueToMatchingDocsMap);
for (int i = 0; i < numDocs; i++) {
String[] value = valuesFromIndex[i];
if (value.length == 0) {
if (_defaultValue != null) {
_doubleValuesMV[i] = new double[((Object[]) (_defaultValue)).length];
for (int j = 0; j < _doubleValuesMV[i].length; j++) {
_doubleValuesMV[i][j] = (double) ((Object[]) _defaultValue)[j];
}
continue;
}
throw new RuntimeException(
String.format("Illegal Json Path: [%s], for docId [%s]", _jsonPathString, valueBlock.getDocIds()[i]));
}
_doubleValuesMV[i] = new double[value.length];
for (int j = 0; j < value.length; j++) {
_doubleValuesMV[i][j] = Double.parseDouble(value[j]);
Expand All @@ -331,6 +395,17 @@ public String[][] transformToStringValuesMV(ValueBlock valueBlock) {
_valueToMatchingDocsMap);
for (int i = 0; i < numDocs; i++) {
String[] value = valuesFromIndex[i];
if (value.length == 0) {
if (_defaultValue != null) {
_stringValuesMV[i] = new String[((Object[]) (_defaultValue)).length];
for (int j = 0; j < _stringValuesMV[i].length; j++) {
_stringValuesMV[i][j] = (String) ((Object[]) _defaultValue)[j];
}
continue;
}
throw new RuntimeException(
String.format("Illegal Json Path: [%s], for docId [%s]", _jsonPathString, valueBlock.getDocIds()[i]));
}
_stringValuesMV[i] = new String[value.length];
System.arraycopy(value, 0, _stringValuesMV[i], 0, value.length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,15 @@ private void addMvTests(List<Object[]> testArguments) {
// MV with filters
testArguments.add(new Object[]{
String.format(
"jsonExtractIndex(%s,'%s','INT_ARRAY', '0', 'REGEXP_LIKE(\"$.arrayField[*].arrStringField\", ''.*y.*'')')",
"jsonExtractIndex(%s,'%s','INT_ARRAY', '[]', 'REGEXP_LIKE(\"$.arrayField[*].arrStringField\", ''.*y.*'')')",
JSON_STRING_SV_COLUMN,
"$.arrayField[*].arrIntField"), "$.arrayField[?(@.arrStringField =~ /.*y.*/)].arrIntField", DataType.INT,
false
});

testArguments.add(new Object[]{
String.format(
"jsonExtractIndex(%s,'%s','STRING_ARRAY', '0', '\"$.arrayField[*].arrIntField\" > 2')",
"jsonExtractIndex(%s,'%s','STRING_ARRAY', '[]', '\"$.arrayField[*].arrIntField\" > 2')",
JSON_STRING_SV_COLUMN,
"$.arrayField[*].arrStringField"), "$.arrayField[?(@.arrIntField > 2)].arrStringField", DataType.STRING,
false
Expand All @@ -268,7 +268,7 @@ private void addMvTests(List<Object[]> testArguments) {

@Test(dataProvider = "testJsonExtractIndexDefaultValue")
public void testJsonExtractIndexDefaultValue(String expressionStr, String jsonPathString, DataType resultsDataType,
boolean isSingleValue) {
boolean isSingleValue, Object expectedDefaultValue) {
ExpressionContext expression = RequestContextUtils.getExpression(expressionStr);
TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
Assert.assertTrue(transformFunction instanceof JsonExtractIndexTransformFunction);
Expand All @@ -281,37 +281,72 @@ public void testJsonExtractIndexDefaultValue(String expressionStr, String jsonPa
case INT:
int[] intValues = transformFunction.transformToIntValuesSV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(intValues[i], 0);
Assert.assertEquals(intValues[i], expectedDefaultValue);
}
break;
case LONG:
long[] longValues = transformFunction.transformToLongValuesSV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(longValues[i], 0L);
Assert.assertEquals(longValues[i], expectedDefaultValue);
}
break;
case FLOAT:
float[] floatValues = transformFunction.transformToFloatValuesSV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(floatValues[i], 0f);
Assert.assertEquals(floatValues[i], expectedDefaultValue);
}
break;
case DOUBLE:
double[] doubleValues = transformFunction.transformToDoubleValuesSV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(doubleValues[i], 0d);
Assert.assertEquals(doubleValues[i], expectedDefaultValue);
}
break;
case BIG_DECIMAL:
BigDecimal[] bigDecimalValues = transformFunction.transformToBigDecimalValuesSV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(bigDecimalValues[i], BigDecimal.ZERO);
Assert.assertEquals(bigDecimalValues[i], expectedDefaultValue);
}
break;
case STRING:
String[] stringValues = transformFunction.transformToStringValuesSV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(stringValues[i], "null");
Assert.assertEquals(stringValues[i], expectedDefaultValue);
}
break;
default:
throw new UnsupportedOperationException("Not support data type - " + resultsDataType);
}
} else {
switch (resultsDataType) {
case INT:
int[][] intValues = transformFunction.transformToIntValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(intValues[i], expectedDefaultValue);
}
break;
case LONG:
long[][] longValues = transformFunction.transformToLongValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(longValues[i], expectedDefaultValue);
}
break;
case FLOAT:
float[][] floatValues = transformFunction.transformToFloatValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(floatValues[i], expectedDefaultValue);
}
break;
case DOUBLE:
double[][] doubleValues = transformFunction.transformToDoubleValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(doubleValues[i], expectedDefaultValue);
}
break;
case STRING:
String[][] stringValues = transformFunction.transformToStringValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(stringValues[i], expectedDefaultValue);
}
break;
default:
Expand All @@ -326,31 +361,56 @@ public Object[][] testJsonExtractIndexDefaultValueDataProvider() {
// With default value
testArguments.add(new Object[]{
String.format("jsonExtractIndex(%s,'%s','INT',0)", JSON_STRING_SV_COLUMN,
"$.noField"), "$.noField", DataType.INT, true
"$.noField"), "$.noField", DataType.INT, true, 0
});
testArguments.add(new Object[]{
String.format("jsonExtractIndex(%s,'%s','LONG',0)", JSON_STRING_SV_COLUMN,
"$.noField"), "$.noField", DataType.LONG, true
"$.noField"), "$.noField", DataType.LONG, true, 0L
});
testArguments.add(new Object[]{
String.format("jsonExtractIndex(%s,'%s','FLOAT',0)", JSON_STRING_SV_COLUMN,
"$.noField"), "$.noField", DataType.FLOAT, true
"$.noField"), "$.noField", DataType.FLOAT, true, (float) 0
});
testArguments.add(new Object[]{
String.format("jsonExtractIndex(%s,'%s','DOUBLE',0)", JSON_STRING_SV_COLUMN,
"$.noField"), "$.noField", DataType.DOUBLE, true
"$.noField"), "$.noField", DataType.DOUBLE, true, (double) 0
});
testArguments.add(new Object[]{
String.format("jsonExtractIndex(%s,'%s','BIG_DECIMAL',0)", JSON_STRING_SV_COLUMN,
"$.noField"), "$.noField", DataType.BIG_DECIMAL, true
"$.noField"), "$.noField", DataType.BIG_DECIMAL, true, new BigDecimal(0)
});
testArguments.add(new Object[]{
String.format("jsonExtractIndex(%s,'%s','STRING','null')", JSON_STRING_SV_COLUMN,
"$.noField"), "$.noField", DataType.STRING, true
"$.noField"), "$.noField", DataType.STRING, true, "null"
});
addMvDefaultValueTests(testArguments);
return testArguments.toArray(new Object[0][]);
}

/**
 * Registers the multi-value (array) default-value cases for {@code jsonExtractIndex}.
 *
 * <p>Each appended row is {@code {expression, jsonPath, resultDataType, isSingleValue, expectedDefault}}. Here
 * {@code isSingleValue} is always {@code false}, and the expected default is a boxed array holding the same
 * elements as the JSON array literal passed as the fourth function argument.
 *
 * @param testArguments accumulator the rows are appended to, in INT, LONG, FLOAT, DOUBLE, STRING order
 */
private void addMvDefaultValueTests(List<Object[]> testArguments) {
  // NOTE(review): "$.noField" is presumably absent from every test document, so the default is always
  // returned - confirm against the test data setup.
  final String jsonPath = "$.noField";

  String intExpression =
      String.format("jsonExtractIndex(%s,'%s','INT_ARRAY', '%s')", JSON_STRING_SV_COLUMN, jsonPath, "[1, 2, 3]");
  Object[] intCase = {intExpression, jsonPath, DataType.INT, false, new Integer[]{1, 2, 3}};
  testArguments.add(intCase);

  String longExpression =
      String.format("jsonExtractIndex(%s,'%s','LONG_ARRAY', '%s')", JSON_STRING_SV_COLUMN, jsonPath, "[1, 5, 6]");
  Object[] longCase = {longExpression, jsonPath, DataType.LONG, false, new Long[]{1L, 5L, 6L}};
  testArguments.add(longCase);

  String floatExpression = String.format("jsonExtractIndex(%s,'%s','FLOAT_ARRAY', '%s')", JSON_STRING_SV_COLUMN,
      jsonPath, "[1.2, 3.1, 1.6]");
  Object[] floatCase = {floatExpression, jsonPath, DataType.FLOAT, false, new Float[]{1.2f, 3.1f, 1.6f}};
  testArguments.add(floatCase);

  String doubleExpression = String.format("jsonExtractIndex(%s,'%s','DOUBLE_ARRAY', '%s')", JSON_STRING_SV_COLUMN,
      jsonPath, "[1.5, 3.4, 1.6]");
  Object[] doubleCase = {doubleExpression, jsonPath, DataType.DOUBLE, false, new Double[]{1.5d, 3.4d, 1.6d}};
  testArguments.add(doubleCase);

  String stringExpression = String.format("jsonExtractIndex(%s,'%s','STRING_ARRAY', '%s')", JSON_STRING_SV_COLUMN,
      jsonPath, "[\"randomString1\", \"randomString2\"]");
  Object[] stringCase =
      {stringExpression, jsonPath, DataType.STRING, false, new String[]{"randomString1", "randomString2"}};
  testArguments.add(stringCase);
}

// get value for key, excluding nested
private String getValueForKey(String blob, JsonPath path) {
Object out = JSON_PARSER_CONTEXT.parse(blob).read(path);
Expand Down
Loading