[Backport 2.x] Fix bug where document embedding fails to be generated because the document has a dot in its field name (#1076)

* Fix bug where document embedding fails to be generated because the document has a dot in its field name (#1062)

* Fix bug where document embedding fails to be generated because the document has a dot in its field name

Signed-off-by: Yizhe Liu <[email protected]>

* Address comments

Signed-off-by: Yizhe Liu <[email protected]>

---------

Signed-off-by: Yizhe Liu <[email protected]>

* Clean up unused validateFieldName() and use existing methods for TextEmbeddingProcessorIT (#1074)

Signed-off-by: Yizhe Liu <[email protected]>

---------

Signed-off-by: Yizhe Liu <[email protected]>
yizheliu-amazon authored Jan 10, 2025
1 parent 38e1f30 commit 19ea370
Showing 7 changed files with 629 additions and 96 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
### Infrastructure
- Update batch related tests to use batch_size in processor & refactor BWC version check ([#852](https://github.com/opensearch-project/neural-search/pull/852))
- Fix CI for JDK upgrade towards 21 ([#835](https://github.com/opensearch-project/neural-search/pull/835))
InferenceProcessor.java
@@ -137,6 +137,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
preprocessIngestDocument(ingestDocument);
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> processMap = buildMapWithTargetKeys(ingestDocument);
List<String> inferenceList = createInferenceList(processMap);
@@ -150,6 +151,15 @@
}
}

@VisibleForTesting
void preprocessIngestDocument(IngestDocument ingestDocument) {
if (ingestDocument == null || ingestDocument.getSourceAndMetadata() == null) return;
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
Map<String, Object> unflattened = ProcessorDocumentUtils.unflattenJson(sourceAndMetadataMap);
unflattened.forEach(ingestDocument::setFieldValue);
sourceAndMetadataMap.keySet().removeIf(key -> key.contains("."));
}
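
For illustration, a minimal sketch of what preprocessIngestDocument does to a source map whose keys contain dots; the field names and the standalone wrapper are hypothetical, plain maps stand in for the IngestDocument, and ProcessorDocumentUtils is assumed to live under org.opensearch.neuralsearch.util:

```java
import java.util.HashMap;
import java.util.Map;

import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

public class PreprocessSketch {
    public static void main(String[] args) {
        // Hypothetical document source ingested with a dotted field name.
        Map<String, Object> source = new HashMap<>();
        source.put("title.en", "hello world"); // dotted key that mismatches a nested field_map
        source.put("author", "alice");

        // unflattenJson nests the dotted key: {title={en=hello world}, author=alice}
        Map<String, Object> unflattened = ProcessorDocumentUtils.unflattenJson(source);

        // The processor writes each top-level entry back (ingestDocument::setFieldValue
        // in the real code) and then drops the original dotted keys.
        source.putAll(unflattened);
        source.keySet().removeIf(key -> key.contains("."));

        System.out.println(source); // {author=alice, title={en=hello world}}
    }
}
```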

/**
* This is the function which does actual inference work for batchExecute interface.
* @param inferenceList a list of String for inference.
@@ -244,12 +254,14 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
Map<String, Object> processMap = null;
List<String> inferenceList = null;
+ IngestDocument ingestDocument = ingestDocumentWrapper.getIngestDocument();
try {
- validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
- processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument());
+ preprocessIngestDocument(ingestDocument);
+ validateEmbeddingFieldsValue(ingestDocument);
+ processMap = buildMapWithTargetKeys(ingestDocument);
inferenceList = createInferenceList(processMap);
} catch (Exception e) {
- ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
+ ingestDocumentWrapper.update(ingestDocument, e);
} finally {
dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList));
}
@@ -333,13 +345,14 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> sourceAndMetadataMap, Map<String, Object> treeRes) {
} else if (sourceAndMetadataMap.get(parentKey) instanceof List) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
List<Map<String, Object>> list = (List<Map<String, Object>>) sourceAndMetadataMap.get(parentKey);
+ Pair<String, Object> processedNestedKey = processNestedKey(nestedFieldMapEntry);
List<Object> listOfStrings = list.stream().map(x -> {
- Object nestedSourceValue = x.get(nestedFieldMapEntry.getKey());
+ Object nestedSourceValue = x.get(processedNestedKey.getKey());
return normalizeSourceValue(nestedSourceValue);
}).collect(Collectors.toList());
Map<String, Object> map = new LinkedHashMap<>();
- map.put(nestedFieldMapEntry.getKey(), listOfStrings);
- buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
+ map.put(processedNestedKey.getKey(), listOfStrings);
+ buildNestedMap(processedNestedKey.getKey(), processedNestedKey.getValue(), map, next);
}
}
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
@@ -387,7 +400,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
ProcessorDocumentUtils.validateMapTypeValue(
FIELD_MAP_FIELD,
sourceAndMetadataMap,
- fieldMap,
+ ProcessorDocumentUtils.unflattenJson(fieldMap),
indexName,
clusterService,
environment,
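The validateEmbeddingFieldsValue change above applies the same normalization to the configured fieldMap, so a mapping written with dotted keys lines up with the unflattened document. A small sketch, with a hypothetical mapping (same package assumption as above):

```java
import java.util.Map;

import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

public class FieldMapSketch {
    public static void main(String[] args) {
        // Hypothetical text_embedding field_map configured with a dotted key.
        Map<String, Object> fieldMap = Map.of("title.en", "title_embedding");

        // validateMapTypeValue now receives the unflattened form, which matches
        // the document shape produced by preprocessIngestDocument.
        Map<String, Object> nested = ProcessorDocumentUtils.unflattenJson(fieldMap);
        System.out.println(nested); // {title={en=title_embedding}}
    }
}
```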
ProcessorDocumentUtils.java
@@ -12,11 +12,14 @@
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;

/**
* This class is used to accommodate the common code pieces of parsing, validating and processing the document for multiple
@@ -178,4 +181,142 @@ private static void validateDepth(
);
}
}

/**
* Unflatten a JSON object represented as a {@code Map<String, Object>}, possibly with dots in field names,
* into a nested {@code Map<String, Object>}.
* "Object" can be either a {@code Map<String, Object>}, a {@code List<Object>}, or simply a String.
* For example, input is {"a.b": "c"}, output is {"a":{"b": "c"}}.
* Another example:
* input is {"a": [{"b.c": "d"}, {"b.c": "e"}]},
* output is {"a": [{"b": {"c": "d"}}, {"b": {"c": "e"}}]}
* @param originalJsonMap the original JSON object represented as a {@code Map<String, Object>}
* @return the nested JSON object represented as a nested {@code Map<String, Object>}
* @throws IllegalArgumentException if the originalJsonMap is null or has invalid dot usage in field name
*/
public static Map<String, Object> unflattenJson(Map<String, Object> originalJsonMap) {
if (originalJsonMap == null) {
throw new IllegalArgumentException("originalJsonMap cannot be null");
}
Map<String, Object> result = new HashMap<>();
Stack<ProcessJsonObjectItem> stack = new Stack<>();

// Push initial items to stack
for (Map.Entry<String, Object> entry : originalJsonMap.entrySet()) {
stack.push(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), result));
}

// Process items until stack is empty
while (!stack.isEmpty()) {
ProcessJsonObjectItem item = stack.pop();
String key = item.key;
Object value = item.value;
Map<String, Object> currentMap = item.targetMap;

// Handle nested value
if (value instanceof Map) {
Map<String, Object> nestedMap = new HashMap<>();
for (Map.Entry<String, Object> entry : ((Map<String, Object>) value).entrySet()) {
stack.push(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), nestedMap));
}
value = nestedMap;
} else if (value instanceof List) {
value = handleList((List<Object>) value);
}

// If key contains dot, split and create nested structure
unflattenSingleItem(key, value, currentMap);
}

return result;
}

private static List<Object> handleList(List<Object> list) {
List<Object> result = new ArrayList<>();
Stack<ProcessJsonListItem> stack = new Stack<>();

// Push initial items to stack
for (int i = list.size() - 1; i >= 0; i--) {
stack.push(new ProcessJsonListItem(list.get(i), result));
}

// Process items until stack is empty
while (!stack.isEmpty()) {
ProcessJsonListItem item = stack.pop();
Object value = item.value;
List<Object> targetList = item.targetList;

if (value instanceof Map) {
Map<String, Object> nestedMap = new HashMap<>();
Map<String, Object> sourceMap = (Map<String, Object>) value;
for (Map.Entry<String, Object> entry : sourceMap.entrySet()) {
stack.push(new ProcessJsonListItem(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), nestedMap), targetList));
}
targetList.add(nestedMap);
} else if (value instanceof List) {
List<Object> nestedList = new ArrayList<>();
for (Object listItem : (List<Object>) value) {
stack.push(new ProcessJsonListItem(listItem, nestedList));
}
targetList.add(nestedList);
} else if (value instanceof ProcessJsonObjectItem) {
ProcessJsonObjectItem processJsonObjectItem = (ProcessJsonObjectItem) value;
Map<String, Object> tempMap = new HashMap<>();
unflattenSingleItem(processJsonObjectItem.key, processJsonObjectItem.value, tempMap);
targetList.set(targetList.size() - 1, tempMap);
} else {
targetList.add(value);
}
}

return result;
}

private static void unflattenSingleItem(String key, Object value, Map<String, Object> result) {
if (StringUtils.isBlank(key)) {
throw new IllegalArgumentException("Field name cannot be null or empty");
}
if (key.contains(".")) {
// Use split with -1 limit to preserve trailing empty strings
String[] parts = key.split("\\.", -1);
Map<String, Object> current = result;

for (int i = 0; i < parts.length; i++) {
if (StringUtils.isBlank(parts[i])) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field name '%s' contains invalid dot usage", key));
}
if (i == parts.length - 1) {
current.put(parts[i], value);
continue;
}
current = (Map<String, Object>) current.computeIfAbsent(parts[i], k -> new HashMap<>());
}
} else {
result.put(key, value);
}
}

// Helper classes to maintain state during iteration
private static class ProcessJsonObjectItem {
String key;
Object value;
Map<String, Object> targetMap;

ProcessJsonObjectItem(String key, Object value, Map<String, Object> targetMap) {
this.key = key;
this.value = value;
this.targetMap = targetMap;
}
}

private static class ProcessJsonListItem {
Object value;
List<Object> targetList;

ProcessJsonListItem(Object value, List<Object> targetList) {
this.value = value;
this.targetList = targetList;
}
}

}
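
As a usage sketch of the new helper (inputs hypothetical, same package assumption as above): unflattenJson nests plain dotted keys, merges keys sharing a prefix, recurses into lists of objects, and rejects malformed dot usage.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

public class UnflattenSketch {
    public static void main(String[] args) {
        Map<String, Object> flat = new HashMap<>();
        flat.put("a.b", "c");
        flat.put("a.d", "e"); // shares the "a" prefix, merged into the same nested map
        flat.put("list", List.of(Map.of("b.c", "d"), Map.of("b.c", "e")));

        Map<String, Object> nested = ProcessorDocumentUtils.unflattenJson(flat);
        System.out.println(nested);
        // {a={b=c, d=e}, list=[{b={c=d}}, {b={c=e}}]} (key order may vary)

        // Leading, trailing, or consecutive dots produce blank segments, because
        // split("\\.", -1) preserves empty strings, and are rejected:
        try {
            ProcessorDocumentUtils.unflattenJson(Map.of("a..b", "v"));
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // Field name 'a..b' contains invalid dot usage
        }
    }
}
```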
TextEmbeddingProcessorIT.java
@@ -50,6 +50,7 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT {
private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI()));
private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI()));
private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI()));
private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI()));
private final String BULK_ITEM_TEMPLATE = Files.readString(
Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI())
);
@@ -168,6 +169,23 @@ private void assertDoc(Map<String, Object> sourceMap, String textFieldValue, Opt
}
}

private void assertDocWithLevel2AsList(Map<String, Object> sourceMap) {
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
assertTrue(sourceMap.get(LEVEL_1_FIELD) instanceof List);
List<Map<String, Object>> nestedPassages = (List<Map<String, Object>>) sourceMap.get(LEVEL_1_FIELD);
nestedPassages.forEach(nestedPassage -> {
assertTrue(nestedPassage.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassage.get(LEVEL_2_FIELD);
Map<String, Object> level3 = (Map<String, Object>) level2.get(LEVEL_3_FIELD_CONTAINER);
List<Double> embeddings = (List<Double>) level3.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}
});
}

public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception {
String modelId = null;
try {
@@ -232,6 +250,44 @@ public void testTextEmbeddingProcessor_withFailureAndSkip() throws Exception {
}
}

@SuppressWarnings("unchecked")
public void testNestedFieldMapping_whenDocumentInListIngested_thenSuccessful() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING);
createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME);
ingestDocument(INDEX_NAME, INGEST_DOC5, "5");

assertDocWithLevel2AsList((Map<String, Object>) getDocById(INDEX_NAME, "5").get("_source"));

NeuralQueryBuilder neuralQueryBuilderQuery = NeuralQueryBuilder.builder()
.fieldName(LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING)
.queryText(QUERY_TEXT)
.modelId(modelId)
.k(10)
.build();

QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD,
neuralQueryBuilderQuery,
ScoreMode.Total
);
QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total);

Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2);
assertNotNull(searchResponseAsMap);

assertEquals(1, getHitCount(searchResponseAsMap));

Map<String, Object> innerHitDetails = getFirstInnerHit(searchResponseAsMap);
assertEquals("5", innerHitDetails.get("_id"));
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

private String uploadTextEmbeddingModel() throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
return registerModelGroupAndUploadModel(requestBody);