From ee24b1c92b41e9f9f1625e1036f790555d7fba07 Mon Sep 17 00:00:00 2001 From: Yizhe Liu <59710443+yizheliu-amazon@users.noreply.github.com> Date: Fri, 27 Dec 2024 14:35:02 -0800 Subject: [PATCH] Support empty string for fields in text embedding processor (#1041) * Allow empty string for field in field map Signed-off-by: Yizhe Liu * Allow empty string when validation Signed-off-by: Yizhe Liu * Add to change log Signed-off-by: Yizhe Liu * Update CHANGELOG to: Support empty string for fields in text embedding processor Signed-off-by: Yizhe Liu --------- Signed-off-by: Yizhe Liu --- CHANGELOG.md | 1 + .../processor/InferenceProcessor.java | 23 +++- .../processor/InferenceProcessorTests.java | 43 ------- .../TextEmbeddingProcessorTests.java | 113 ++++++++++++------ 4 files changed, 94 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5345d416f..6db95ed90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) - Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013)) - Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) +- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041)) ### Bug Fixes - Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index ae996251d..12608fccb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -307,7 +307,7 @@ Map buildMapWithTargetKeys(IngestDocument ingestDocument) { buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes); mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); } else { - mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); + mapWithProcessorKeys.put(String.valueOf(targetKey), normalizeSourceValue(sourceAndMetadataMap.get(originalKey))); } } return mapWithProcessorKeys; @@ -333,7 +333,10 @@ void buildNestedMap(String parentKey, Object processorKey, Map s } else if (sourceAndMetadataMap.get(parentKey) instanceof List) { for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { List> list = (List>) sourceAndMetadataMap.get(parentKey); - List listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList()); + List listOfStrings = list.stream().map(x -> { + Object nestedSourceValue = x.get(nestedFieldMapEntry.getKey()); + return normalizeSourceValue(nestedSourceValue); + }).collect(Collectors.toList()); Map map = new LinkedHashMap<>(); map.put(nestedFieldMapEntry.getKey(), listOfStrings); buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next); @@ -341,11 +344,23 @@ void buildNestedMap(String parentKey, Object processorKey, Map s } treeRes.merge(parentKey, next, REMAPPING_FUNCTION); } else { + Object parentValue = sourceAndMetadataMap.get(parentKey); String key = String.valueOf(processorKey); - treeRes.put(key, sourceAndMetadataMap.get(parentKey)); + treeRes.put(key, normalizeSourceValue(parentValue)); } } + private boolean isBlankString(Object object) { + return object instanceof String && StringUtils.isBlank((String) object); + } + + private Object normalizeSourceValue(Object value) { + if (isBlankString(value)) { + return null; + } + return value; + } + /** * Process the nested key, such as "a.b.c" to "a", "b.c" * @param nestedFieldMapEntry @@ -376,7 +391,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { indexName, clusterService, environment, - false + true ); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java index dc86975bd..bcc6b74b5 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -66,28 +66,6 @@ public void test_batchExecute_emptyInput() { verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); } - public void test_batchExecuteWithEmpty_allFailedValidation() { - final int docCount = 2; - TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null); - List wrapperList = createIngestDocumentWrappers(docCount); - wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1")); - wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1")); - Consumer resultHandler = mock(Consumer.class); - processor.batchExecute(wrapperList, resultHandler); - ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); - verify(resultHandler).accept(captor.capture()); - assertEquals(docCount, captor.getValue().size()); - for (int i = 0; i < docCount; ++i) { - assertNotNull(captor.getValue().get(i).getException()); - assertEquals( - "list type field [key1] has empty string, cannot process it", - captor.getValue().get(i).getException().getMessage() - ); - assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); - } - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); - } - public void test_batchExecuteWithNull_allFailedValidation() { final int docCount = 2; TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null); @@ -107,27 +85,6 @@ public void test_batchExecuteWithNull_allFailedValidation() { verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); } - public void test_batchExecute_partialFailedValidation() { - final int docCount = 2; - TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null); - List wrapperList = createIngestDocumentWrappers(docCount); - wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1")); - wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4")); - Consumer resultHandler = mock(Consumer.class); - processor.batchExecute(wrapperList, resultHandler); - ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); - verify(resultHandler).accept(captor.capture()); - assertEquals(docCount, captor.getValue().size()); - assertNotNull(captor.getValue().get(0).getException()); - assertNull(captor.getValue().get(1).getException()); - for (int i = 0; i < docCount; ++i) { - assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); - } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(2, inferenceTextCaptor.getValue().size()); - } - public void test_batchExecute_happyCase() { final int docCount = 2; List> inferenceResults = createMockVectorWithLength(6); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 97e85e46e..a0809de23 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -28,6 +28,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -240,31 +241,6 @@ public void testExecute_withListTypeInput_successful() { verify(handler).accept(any(IngestDocument.class), isNull()); } - public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() { - Map sourceAndMetadata = new HashMap<>(); - sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); - sourceAndMetadata.put("key1", " "); - IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); - - BiConsumer handler = mock(BiConsumer.class); - processor.execute(ingestDocument, handler); - verify(handler).accept(isNull(), any(IllegalArgumentException.class)); - } - - public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() { - List list1 = ImmutableList.of("", "test2", "test3"); - Map sourceAndMetadata = new HashMap<>(); - sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); - sourceAndMetadata.put("key1", list1); - IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); - - BiConsumer handler = mock(BiConsumer.class); - processor.execute(ingestDocument, handler); - verify(handler).accept(isNull(), any(IllegalArgumentException.class)); - } - public void testExecute_listHasNonStringValue_throwIllegalArgumentException() { List list2 = ImmutableList.of(1, 2, 3); Map sourceAndMetadata = new HashMap<>(); @@ -549,20 +525,6 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } - public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { - Map map1 = ImmutableMap.of("test1", "test2"); - Map map2 = ImmutableMap.of("test3", " "); - Map sourceAndMetadata = new HashMap<>(); - sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); - sourceAndMetadata.put("key1", map1); - sourceAndMetadata.put("key2", map2); - IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); - BiConsumer handler = mock(BiConsumer.class); - processor.execute(ingestDocument, handler); - verify(handler).accept(isNull(), any(IllegalArgumentException.class)); - } - public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { Map ret = createMaxDepthLimitExceedMap(() -> 1); Map sourceAndMetadata = new HashMap<>(); @@ -785,6 +747,79 @@ public void testBuildVectorOutput_withNestedListHasNotForEmbeddingField_Level2_s assertNotNull(nestedObj.get(1).get("vectorField")); } + @SuppressWarnings("unchecked") + public void testBuildVectorOutput_withPlainString_EmptyString_skipped() { + Map config = createPlainStringConfiguration(); + IngestDocument ingestDocument = createPlainIngestDocument(); + Map sourceAndMetadata = ingestDocument.getSourceAndMetadata(); + sourceAndMetadata.put("oriKey1", StringUtils.EMPTY); + + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); + List> modelTensorList = createRandomOneDimensionalMockVector(6, 100, 0.0f, 1.0f); + processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); + + /** IngestDocument + * "oriKey1": "", + * "oriKey2": "oriValue2", + * "oriKey3": "oriValue3", + * "oriKey4": "oriValue4", + * "oriKey5": "oriValue5", + * "oriKey6": [ + * "oriValue6", + * "oriValue7" + * ] + * + */ + assertEquals(11, sourceAndMetadata.size()); + assertFalse(sourceAndMetadata.containsKey("oriKey1_knn")); + } + + @SuppressWarnings("unchecked") + public void testBuildVectorOutput_withNestedField_EmptyString_skipped() { + Map config = createNestedMapConfiguration(); + IngestDocument ingestDocument = createNestedMapIngestDocument(); + Map favorites = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); + Map favorite = (Map) favorites.get("favorite"); + favorite.put("movie", StringUtils.EMPTY); + + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); + List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); + processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + + /** + * "favorites": { + * "favorite": { + * "movie": "", + * "actor": "Charlie Chaplin", + * "games" : { + * "adventure": { + * "action": "overwatch", + * "rpg": "elden ring" + * } + * } + * } + * } + */ + Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); + assertNotNull(favoritesMap); + Map favoriteMap = (Map) favoritesMap.get("favorite"); + assertNotNull(favoriteMap); + + Map favoriteGames = (Map) favoriteMap.get("games"); + assertNotNull(favoriteGames); + Map adventure = (Map) favoriteGames.get("adventure"); + List adventureKnnVector = (List) adventure.get("with_action_knn"); + assertNotNull(adventureKnnVector); + assertEquals(100, adventureKnnVector.size()); + for (float vector : adventureKnnVector) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } + + assertFalse(favoriteMap.containsKey("favorite_movie_knn")); + } + public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument();