Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Search Pipeline processors, Remote Inference and HttpConnector to enable Retrieval Augmented Generation (RAG) #1195

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ opensearchplugin {
dependencies {
implementation project(':opensearch-ml-common')
implementation project(':opensearch-ml-algorithms')
implementation project(':opensearch-ml-search-processors')

implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,18 @@
import org.opensearch.monitor.os.OsService;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestHandler;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAParamExtBuilder;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -180,7 +188,7 @@

import lombok.SneakyThrows;

public class MachineLearningPlugin extends Plugin implements ActionPlugin {
public class MachineLearningPlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin {
public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons.";
public static final String GENERAL_THREAD_POOL = "opensearch_ml_general";
public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute";
Expand Down Expand Up @@ -610,4 +618,26 @@ public List<Setting<?>> getSettings() {
);
return settings;
}

@Override
public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
return List
.of(
new SearchPlugin.SearchExtSpec<>(
GenerativeQAParamExtBuilder.NAME,
input -> new GenerativeQAParamExtBuilder(input),
parser -> GenerativeQAParamExtBuilder.parse(parser)
)
);
}

@Override
public Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcessors(Parameters parameters) {
return Map.of(GenerativeQARequestProcessor.TYPE, new GenerativeQARequestProcessor.Factory());
}

@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Parameters parameters) {
return Map.of(GenerativeQAResponseProcessor.TYPE, new GenerativeQAResponseProcessor.Factory(this.client));
}
}
93 changes: 93 additions & 0 deletions search-processors/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# conversational-search-processors
OpenSearch search processors providing conversational search capabilities
=======
# Plugin for Conversations Using Search Processors in OpenSearch
This repo is a WIP plugin for handling conversations in OpenSearch ([Per this RFC](https://github.com/opensearch-project/ml-commons/issues/1150)).

Conversational Retrieval Augmented Generation (RAG) is implemented via Search processors that combine user questions and OpenSearch query results as input to an LLM, e.g. OpenAI, and return answers.

## Creating a search pipeline with the GenerativeQAResponseProcessor

```
PUT /_search/pipeline/<search pipeline name>
{
"response_processors": [
{
"generative_qa": {
"tag": <tag>,
"description": <description>,
"opensearch_model_id": "<model_id>",
austintlee marked this conversation as resolved.
Show resolved Hide resolved
"context_field": <field> (e.g. "text")
}
}
]
}
```

The 'opensearch_model_id' parameter here needs to refer to a model of type REMOTE that has an HttpConnector instance associated with it.

## Making a search request against an index using the above processor
```
GET /<index>/_search\?search_pipeline\=<search pipeline name>
{
"_source": ["title", "text"],
"query" : {
"neural": {
"text_vector": {
"query_text": <query string>,
"k": <integer> (e.g. 10),
"model_id": <model_id>
}
}
},
"ext": {
"generative_qa_parameters": {
"llm_model": <LLM model> (e.g. "gpt-3.5-turbo"),
"llm_question": <question string>
}
austintlee marked this conversation as resolved.
Show resolved Hide resolved
}
}
```

## Retrieval Augmented Generation response
```
{
"took": 3,
"timed_out": false,
"_shards": {
"total": 3,
"successful": 3,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 110,
"relation": "eq"
},
"max_score": 0.55129033,
"hits": [
{
"_index": "...",
"_id": "...",
"_score": 0.55129033,
"_source": {
"text": "...",
"title": "..."
}
},
{
...
}
...
{
...
},
"ext": {
"generative_qa": {
"answer": "..."
}
austintlee marked this conversation as resolved.
Show resolved Hide resolved
}
}
```
The RAG answer is returned as an "ext" to SearchResponse following the "hits" array.
austintlee marked this conversation as resolved.
Show resolved Hide resolved
56 changes: 56 additions & 0 deletions search-processors/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
plugins {
id 'java'
id 'jacoco'
id "io.freefair.lombok"
}

repositories {
mavenCentral()
}
austintlee marked this conversation as resolved.
Show resolved Hide resolved

dependencies {

compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation 'org.apache.commons:commons-lang3:3.12.0'
implementation project(':opensearch-ml-client')
implementation project(':opensearch-ml-common')
implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}"
// https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5
implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1'
implementation("com.google.guava:guava:32.0.1-jre")
implementation group: 'org.json', name: 'json', version: '20230227'
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
testImplementation "org.opensearch.test:framework:${opensearch_version}"
}

test {
include '**/*Tests.class'
systemProperty 'tests.security.manager', 'false'
}

jacocoTestReport {
dependsOn /*integTest,*/ test
reports {
xml.required = true
html.required = true
}
}

jacocoTestCoverageVerification {
violationRules {
rule {
limit {
counter = 'LINE'
minimum = 0.65 //TODO: increase coverage to 0.90
}
limit {
counter = 'BRANCH'
minimum = 0.55 //TODO: increase coverage to 0.85
}
}
}
dependsOn jacocoTestReport
}

check.dependsOn jacocoTestCoverageVerification
//jacocoTestCoverageVerification.dependsOn jacocoTestReport
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.searchpipelines.questionanswering.generative;
austintlee marked this conversation as resolved.
Show resolved Hide resolved
austintlee marked this conversation as resolved.
Show resolved Hide resolved

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchExtBuilder;

import java.io.IOException;
import java.util.Objects;

/**
* This is the extension builder for generative QA search pipelines.
*/
@NoArgsConstructor
austintlee marked this conversation as resolved.
Show resolved Hide resolved
public class GenerativeQAParamExtBuilder extends SearchExtBuilder {

public static final String NAME = "generative_qa_parameters";
austintlee marked this conversation as resolved.
Show resolved Hide resolved

@Setter
@Getter
private GenerativeQAParameters params;

public GenerativeQAParamExtBuilder(StreamInput input) throws IOException {
this.params = new GenerativeQAParameters(input);
}

@Override
public int hashCode() {
return Objects.hash(this.getClass(), this.params);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}

if (!(obj instanceof GenerativeQAParamExtBuilder)) {
return false;
}

return Objects.equals(this.getParams(), ((GenerativeQAParamExtBuilder) obj).getParams());
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
this.params.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.value(params);
}

public static GenerativeQAParamExtBuilder parse(XContentParser parser) throws IOException {
GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder();
GenerativeQAParameters params = GenerativeQAParameters.parse(parser);
builder.setParams(params);
return builder;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.search.SearchExtBuilder;

import java.util.Optional;

/**
* Utility class for extracting generative QA search pipeline parameters from search requests.
*/
public class GenerativeQAParamUtil {
austintlee marked this conversation as resolved.
Show resolved Hide resolved

public static GenerativeQAParameters getGenerativeQAParameters(SearchRequest request) {
GenerativeQAParamExtBuilder builder = null;
if (request.source() != null && request.source().ext() != null && !request.source().ext().isEmpty()) {
Optional<SearchExtBuilder> b = request.source().ext().stream().filter(bldr -> GenerativeQAParamExtBuilder.NAME.equals(bldr.getWriteableName())).findFirst();
if (b.isPresent()) {
builder = (GenerativeQAParamExtBuilder) b.get();
}
}

GenerativeQAParameters params = null;
if (builder != null) {
params = builder.getParams();
}

return params;
}
}
Loading