From cf709a6b7a951fc333ef5a089b24179ca660469b Mon Sep 17 00:00:00 2001
From: impulsivus
Date: Wed, 24 May 2023 21:12:42 +0300
Subject: [PATCH] feat: Get answers using preferred number of chunks

---
 README.md     | 1 +
 example.env   | 3 ++-
 privateGPT.py | 3 ++-
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 9881dec3c..6d2121348 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,7 @@ PERSIST_DIRECTORY: is the folder you want your vectorstore in
 MODEL_PATH: Path to your GPT4All or LlamaCpp supported LLM
 MODEL_N_CTX: Maximum token limit for the LLM model
 EMBEDDINGS_MODEL_NAME: SentenceTransformers embeddings model name (see https://www.sbert.net/docs/pretrained_models.html)
+TARGET_SOURCE_CHUNKS: The number of chunks (sources) that will be used to answer a question
 ```
 
 Note: because of the way `langchain` loads the `SentenceTransformers` embeddings, the first time you run the script it will require internet connection to download the embeddings model itself.
diff --git a/example.env b/example.env
index 829078457..bcf13ebbb 100644
--- a/example.env
+++ b/example.env
@@ -2,4 +2,5 @@ PERSIST_DIRECTORY=db
 MODEL_TYPE=GPT4All
 MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
 EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
-MODEL_N_CTX=1000
\ No newline at end of file
+MODEL_N_CTX=1000
+TARGET_SOURCE_CHUNKS=4
\ No newline at end of file
diff --git a/privateGPT.py b/privateGPT.py
index 7adab52d0..8fa10a6e0 100755
--- a/privateGPT.py
+++ b/privateGPT.py
@@ -16,6 +16,7 @@
 model_type = os.environ.get('MODEL_TYPE')
 model_path = os.environ.get('MODEL_PATH')
 model_n_ctx = os.environ.get('MODEL_N_CTX')
+target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
 
 from constants import CHROMA_SETTINGS
 
@@ -24,7 +25,7 @@ def main():
     args = parse_arguments()
     embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
     db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
-    retriever = db.as_retriever()
+    retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
     # activate/deactivate the streaming StdOut callback for LLMs
     callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
     # Prepare the LLM
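
For context, this is roughly how the new setting flows through the patched privateGPT.py: TARGET_SOURCE_CHUNKS is read from the environment (falling back to 4) and passed to the Chroma retriever as the "k" search argument, so each question is answered from that many retrieved chunks. Below is a minimal standalone sketch, not part of the patch itself; the persist directory, embeddings model name, and query string are illustrative placeholders mirroring example.env.

import os

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

# Number of source chunks to retrieve per question; defaults to 4 as in the patch.
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))

# Placeholder values mirroring example.env; adjust to your own setup.
embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2')
db = Chroma(persist_directory='db', embedding_function=embeddings)

# The retriever now returns `target_source_chunks` chunks instead of the library default.
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
docs = retriever.get_relevant_documents("example question")  # hypothetical query
print(f"Retrieved {len(docs)} chunks")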