Skip to content

Commit

Permalink
Support empty string for fields in text embedding processor (#1041)
Browse files Browse the repository at this point in the history
* Allow empty string for field in field map

Signed-off-by: Yizhe Liu <[email protected]>

* Allow empty string when validation

Signed-off-by: Yizhe Liu <[email protected]>

* Add to change log

Signed-off-by: Yizhe Liu <[email protected]>

* Update CHANGELOG to: Support empty string for fields in text embedding processor

Signed-off-by: Yizhe Liu <[email protected]>

---------

Signed-off-by: Yizhe Liu <[email protected]>
  • Loading branch information
yizheliu-amazon authored Dec 27, 2024
1 parent 22ba5d3 commit ee24b1c
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 86 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ Map<String, Object> buildMapWithTargetKeys(IngestDocument ingestDocument) {
buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes);
mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
} else {
mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
mapWithProcessorKeys.put(String.valueOf(targetKey), normalizeSourceValue(sourceAndMetadataMap.get(originalKey)));
}
}
return mapWithProcessorKeys;
Expand All @@ -333,19 +333,34 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> s
} else if (sourceAndMetadataMap.get(parentKey) instanceof List) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
List<Map<String, Object>> list = (List<Map<String, Object>>) sourceAndMetadataMap.get(parentKey);
List<Object> listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList());
List<Object> listOfStrings = list.stream().map(x -> {
Object nestedSourceValue = x.get(nestedFieldMapEntry.getKey());
return normalizeSourceValue(nestedSourceValue);
}).collect(Collectors.toList());
Map<String, Object> map = new LinkedHashMap<>();
map.put(nestedFieldMapEntry.getKey(), listOfStrings);
buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
}
}
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
} else {
Object parentValue = sourceAndMetadataMap.get(parentKey);
String key = String.valueOf(processorKey);
treeRes.put(key, sourceAndMetadataMap.get(parentKey));
treeRes.put(key, normalizeSourceValue(parentValue));
}
}

private boolean isBlankString(Object object) {
return object instanceof String && StringUtils.isBlank((String) object);
}

private Object normalizeSourceValue(Object value) {
if (isBlankString(value)) {
return null;
}
return value;
}

/**
* Process the nested key, such as "a.b.c" to "a", "b.c"
* @param nestedFieldMapEntry
Expand Down Expand Up @@ -376,7 +391,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
indexName,
clusterService,
environment,
false
true
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,28 +66,6 @@ public void test_batchExecute_emptyInput() {
verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any());
}

public void test_batchExecuteWithEmpty_allFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
Consumer resultHandler = mock(Consumer.class);
processor.batchExecute(wrapperList, resultHandler);
ArgumentCaptor<List<IngestDocumentWrapper>> captor = ArgumentCaptor.forClass(List.class);
verify(resultHandler).accept(captor.capture());
assertEquals(docCount, captor.getValue().size());
for (int i = 0; i < docCount; ++i) {
assertNotNull(captor.getValue().get(i).getException());
assertEquals(
"list type field [key1] has empty string, cannot process it",
captor.getValue().get(i).getException().getMessage()
);
assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument());
}
verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any());
}

public void test_batchExecuteWithNull_allFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
Expand All @@ -107,27 +85,6 @@ public void test_batchExecuteWithNull_allFailedValidation() {
verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any());
}

public void test_batchExecute_partialFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4"));
Consumer resultHandler = mock(Consumer.class);
processor.batchExecute(wrapperList, resultHandler);
ArgumentCaptor<List<IngestDocumentWrapper>> captor = ArgumentCaptor.forClass(List.class);
verify(resultHandler).accept(captor.capture());
assertEquals(docCount, captor.getValue().size());
assertNotNull(captor.getValue().get(0).getException());
assertNull(captor.getValue().get(1).getException());
for (int i = 0; i < docCount; ++i) {
assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument());
}
ArgumentCaptor<List<String>> inferenceTextCaptor = ArgumentCaptor.forClass(List.class);
verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any());
assertEquals(2, inferenceTextCaptor.getValue().size());
}

public void test_batchExecute_happyCase() {
final int docCount = 2;
List<List<Float>> inferenceResults = createMockVectorWithLength(6);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.function.Consumer;
import java.util.function.Supplier;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand Down Expand Up @@ -240,31 +241,6 @@ public void testExecute_withListTypeInput_successful() {
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
sourceAndMetadata.put("key1", " ");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig();

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() {
List<String> list1 = ImmutableList.of("", "test2", "test3");
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
sourceAndMetadata.put("key1", list1);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig();

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasNonStringValue_throwIllegalArgumentException() {
List<Integer> list2 = ImmutableList.of(1, 2, 3);
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand Down Expand Up @@ -549,20 +525,6 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() {
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() {
Map<String, String> map1 = ImmutableMap.of("test1", "test2");
Map<String, String> map2 = ImmutableMap.of("test3", " ");
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
sourceAndMetadata.put("key1", map1);
sourceAndMetadata.put("key2", map2);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig();
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() {
Map<String, Object> ret = createMaxDepthLimitExceedMap(() -> 1);
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand Down Expand Up @@ -785,6 +747,79 @@ public void testBuildVectorOutput_withNestedListHasNotForEmbeddingField_Level2_s
assertNotNull(nestedObj.get(1).get("vectorField"));
}

@SuppressWarnings("unchecked")
public void testBuildVectorOutput_withPlainString_EmptyString_skipped() {
Map<String, Object> config = createPlainStringConfiguration();
IngestDocument ingestDocument = createPlainIngestDocument();
Map<String, Object> sourceAndMetadata = ingestDocument.getSourceAndMetadata();
sourceAndMetadata.put("oriKey1", StringUtils.EMPTY);

TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
Map<String, Object> knnMap = processor.buildMapWithTargetKeys(ingestDocument);
List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(6, 100, 0.0f, 1.0f);
processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList);

/** IngestDocument
* "oriKey1": "",
* "oriKey2": "oriValue2",
* "oriKey3": "oriValue3",
* "oriKey4": "oriValue4",
* "oriKey5": "oriValue5",
* "oriKey6": [
* "oriValue6",
* "oriValue7"
* ]
*
*/
assertEquals(11, sourceAndMetadata.size());
assertFalse(sourceAndMetadata.containsKey("oriKey1_knn"));
}

@SuppressWarnings("unchecked")
public void testBuildVectorOutput_withNestedField_EmptyString_skipped() {
Map<String, Object> config = createNestedMapConfiguration();
IngestDocument ingestDocument = createNestedMapIngestDocument();
Map<String, Object> favorites = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get("favorites");
Map<String, Object> favorite = (Map<String, Object>) favorites.get("favorite");
favorite.put("movie", StringUtils.EMPTY);

TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
Map<String, Object> knnMap = processor.buildMapWithTargetKeys(ingestDocument);
List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f);
processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata());

/**
* "favorites": {
* "favorite": {
* "movie": "",
* "actor": "Charlie Chaplin",
* "games" : {
* "adventure": {
* "action": "overwatch",
* "rpg": "elden ring"
* }
* }
* }
* }
*/
Map<String, Object> favoritesMap = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get("favorites");
assertNotNull(favoritesMap);
Map<String, Object> favoriteMap = (Map<String, Object>) favoritesMap.get("favorite");
assertNotNull(favoriteMap);

Map<String, Object> favoriteGames = (Map<String, Object>) favoriteMap.get("games");
assertNotNull(favoriteGames);
Map<String, Object> adventure = (Map<String, Object>) favoriteGames.get("adventure");
List<Float> adventureKnnVector = (List<Float>) adventure.get("with_action_knn");
assertNotNull(adventureKnnVector);
assertEquals(100, adventureKnnVector.size());
for (float vector : adventureKnnVector) {
assertTrue(vector >= 0.0f && vector <= 1.0f);
}

assertFalse(favoriteMap.containsKey("favorite_movie_knn"));
}

public void test_updateDocument_appendVectorFieldsToDocument_successful() {
Map<String, Object> config = createPlainStringConfiguration();
IngestDocument ingestDocument = createPlainIngestDocument();
Expand Down

0 comments on commit ee24b1c

Please sign in to comment.