{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ur8xi4C7S06n"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAPoU8Sm5E6e"
      },
      "source": [
        "# Vertex AI RAG Engine with Vertex AI Search\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/rag-engine/rag_engine_vertex_ai_search.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Frag-engine%2Frag_engine_vertex_ai_search.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/rag-engine/rag_engine_vertex_ai_search.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/rag-engine/rag_engine_vertex_ai_search.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84f0f73a0f76"
      },
      "source": [
        "| | |\n",
        "|-|-|\n",
        "| Author(s) | [Alex Dorozhkin](https://github.com/galexdor) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvgnzT1CKxrO"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This notebook illustrates how to use [Vertex AI RAG Engine](https://cloud.google.com/vertex-ai/generative-ai/docs/rag-overview) with [Vertex AI Search](https://cloud.google.com/enterprise-search) as a retrieval backend. Vertex AI Search's ability to handle large datasets, provide low-latency retrieval, and improve scalability makes it a powerful tool for enhancing RAG applications.  By integrating Vertex AI Search, you can ensure that your RAG applications can efficiently access and process the necessary information for generating high-quality and contextually relevant responses.\n",
        "\n",
        "For more information, refer to the [official documentation](https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction).\n",
        "\n",
        "For more details on RAG corpus/file management and detailed support please visit [Vertex AI RAG Engine API](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/rag-api)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61RBz8LLbxCR"
      },
      "source": [
        "## Get started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "No17Cw5hgx12"
      },
      "source": [
        "### Install Vertex AI SDK and other required packages\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFy3H3aPgx12"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet google-cloud-aiplatform google-cloud-discoveryengine"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R5Xep4W9lq-Z"
      },
      "source": [
        "### Restart runtime\n",
        "\n",
        "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
        "\n",
        "The restart might take a minute or longer. After it's restarted, continue to the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XRvKdaPDTznN"
      },
      "outputs": [],
      "source": [
        "import IPython\n",
        "\n",
        "app = IPython.Application.instance()\n",
        "app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SbmM4z7FOBpM"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>\n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF4l8DTdWgPY"
      },
      "source": [
        "### Set Google Cloud project information and initialize Vertex AI SDK\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "Nqwi-5ufWp_B"
      },
      "outputs": [],
      "source": [
        "# Use the environment variable if the user doesn't provide Project ID.\n",
        "import os\n",
        "\n",
        "import vertexai\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type:\"string\", isTemplate: true}\n",
        "if PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
        "\n",
        "vertexai.init(project=PROJECT_ID, location=LOCATION)"
      ]
    },
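    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that `str(os.environ.get(...))` evaluates to the string `\"None\"` when the variable is unset, so a missing project only surfaces later as an API error. A minimal sketch of a stricter check follows; `resolve_project_id` and its `env` parameter are hypothetical names used for illustration, not part of the Vertex AI SDK.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "PLACEHOLDER = '[your-project-id]'\n",
        "\n",
        "\n",
        "def resolve_project_id(explicit, env=None):\n",
        "    # Hypothetical helper: prefer an explicit ID, then the environment variable.\n",
        "    env = os.environ if env is None else env\n",
        "    if explicit and explicit != PLACEHOLDER:\n",
        "        return explicit\n",
        "    project_id = env.get('GOOGLE_CLOUD_PROJECT')\n",
        "    if not project_id:\n",
        "        # Fail early instead of passing the string 'None' to the SDK.\n",
        "        raise ValueError(\n",
        "            'Set PROJECT_ID or the GOOGLE_CLOUD_PROJECT environment variable.'\n",
        "        )\n",
        "    return project_id\n",
        "```\n",
        "\n",
        "With this helper, `PROJECT_ID = resolve_project_id(PROJECT_ID)` could replace the placeholder check before `vertexai.init` is called."
      ]
    },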
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RQ-QWIZk6Rqb"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you're running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "FsX2KKvx7tm4"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user(project_id=PROJECT_ID)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N4QCqfqJ36LR"
      },
      "source": [
        "## (Optional) Setup Vertex AI Search Datastore and Engine"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VxksV2zm6L1V"
      },
      "source": [
        "In this section, we have some helper methods to help you setup your Vertex AI Search. These methods handle the creation of resources like Data Stores and Engines, which can take a few minutes.\n",
        "\n",
        "This section is not required if you already have a Vertex AI Search engine ready to use.\n",
        "\n",
        "To get started using Vertex AI Search, you must have an existing Google Cloud project and [enable the Discovery Engine API](https://console.cloud.google.com/flows/enableapi?apiid=discoveryengine.googleapis.com)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XmvYhLurAta4"
      },
      "source": [
        "### Initialize Vertex AI Search SDK"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "UKX-J8OKAz3J"
      },
      "outputs": [],
      "source": [
        "from google.api_core.client_options import ClientOptions\n",
        "from google.cloud import discoveryengine\n",
        "\n",
        "VERTEX_AI_SEARCH_LOCATION = \"global\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kfkQMK1yBnN6"
      },
      "source": [
        "### Create and Populate a Datastore"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "MwUPv8h5BqEJ"
      },
      "outputs": [],
      "source": [
        "def create_data_store(\n",
        "    project_id: str, location: str, data_store_name: str, data_store_id: str\n",
        "):\n",
        "    # Create a client\n",
        "    client_options = (\n",
        "        ClientOptions(api_endpoint=f\"{location}-discoveryengine.googleapis.com\")\n",
        "        if location != \"global\"\n",
        "        else None\n",
        "    )\n",
        "    client = discoveryengine.DataStoreServiceClient(client_options=client_options)\n",
        "\n",
        "    # Initialize request argument(s)\n",
        "    data_store = discoveryengine.DataStore(\n",
        "        display_name=data_store_name,\n",
        "        industry_vertical=discoveryengine.IndustryVertical.GENERIC,\n",
        "        content_config=discoveryengine.DataStore.ContentConfig.CONTENT_REQUIRED,\n",
        "    )\n",
        "\n",
        "    operation = client.create_data_store(\n",
        "        request=discoveryengine.CreateDataStoreRequest(\n",
        "            parent=client.collection_path(project_id, location, \"default_collection\"),\n",
        "            data_store=data_store,\n",
        "            data_store_id=data_store_id,\n",
        "        )\n",
        "    )\n",
        "\n",
        "    # Make the request\n",
        "    response = operation.result(timeout=90)\n",
        "    return response.name"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ABXeS6jCBs8H"
      },
      "outputs": [],
      "source": [
        "# The datastore name can only contain lowercase letters, numbers, and hyphens\n",
        "DATASTORE_NAME = \"alphabet-contracts\"  # @param {type:\"string\", isTemplate: true}\n",
        "DATASTORE_ID = f\"{DATASTORE_NAME}-id\"\n",
        "\n",
        "create_data_store(PROJECT_ID, VERTEX_AI_SEARCH_LOCATION, DATASTORE_NAME, DATASTORE_ID)"
      ]
    },
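    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The naming rule in the comment above can be checked locally before calling the API. The sketch below encodes only that stated rule (lowercase letters, numbers, and hyphens); `is_valid_datastore_name` is a hypothetical helper, and the service may enforce further limits, such as length, that it does not cover.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "# Encodes only the rule stated above: lowercase letters, numbers, and hyphens.\n",
        "_ALLOWED = re.compile('[a-z0-9-]+')\n",
        "\n",
        "\n",
        "def is_valid_datastore_name(name):\n",
        "    # fullmatch requires the whole string to match, so '' and 'A_b' are rejected.\n",
        "    return _ALLOWED.fullmatch(name) is not None\n",
        "\n",
        "\n",
        "print(is_valid_datastore_name('alphabet-contracts'))\n",
        "```"
      ]
    },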
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "jRMWKUJXBwhu"
      },
      "outputs": [],
      "source": [
        "def import_documents(\n",
        "    project_id: str,\n",
        "    location: str,\n",
        "    data_store_id: str,\n",
        "    gcs_uri: str,\n",
        "):\n",
        "    # Create a client\n",
        "    client_options = (\n",
        "        ClientOptions(api_endpoint=f\"{location}-discoveryengine.googleapis.com\")\n",
        "        if location != \"global\"\n",
        "        else None\n",
        "    )\n",
        "    client = discoveryengine.DocumentServiceClient(client_options=client_options)\n",
        "\n",
        "    # The full resource name of the search engine branch.\n",
        "    # e.g. projects/{project}/locations/{location}/dataStores/{data_store_id}/branches/{branch}\n",
        "    parent = client.branch_path(\n",
        "        project=project_id,\n",
        "        location=location,\n",
        "        data_store=data_store_id,\n",
        "        branch=\"default_branch\",\n",
        "    )\n",
        "\n",
        "    source_documents = [f\"{gcs_uri}/*\"]\n",
        "\n",
        "    request = discoveryengine.ImportDocumentsRequest(\n",
        "        parent=parent,\n",
        "        gcs_source=discoveryengine.GcsSource(\n",
        "            input_uris=source_documents, data_schema=\"content\"\n",
        "        ),\n",
        "        # Options: `FULL`, `INCREMENTAL`\n",
        "        reconciliation_mode=discoveryengine.ImportDocumentsRequest.ReconciliationMode.INCREMENTAL,\n",
        "    )\n",
        "\n",
        "    # Make the request\n",
        "    operation = client.import_documents(request=request)\n",
        "\n",
        "    response = operation.result()\n",
        "\n",
        "    # Once the operation is complete,\n",
        "    # get information from operation metadata\n",
        "    metadata = discoveryengine.ImportDocumentsMetadata(operation.metadata)\n",
        "\n",
        "    # Handle the response\n",
        "    return operation.operation.name"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fJlFkhe7BznS"
      },
      "outputs": [],
      "source": [
        "GCS_BUCKET = \"gs://cloud-samples-data/gen-app-builder/search/alphabet-investor-pdfs\"  # @param {type:\"string\", isTemplate: true}\n",
        "\n",
        "import_documents(PROJECT_ID, VERTEX_AI_SEARCH_LOCATION, DATASTORE_ID, GCS_BUCKET)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gvrbLtDj6zCv"
      },
      "source": [
        "### Create a Search Engine"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "GHaRWVvZ6vOc"
      },
      "outputs": [],
      "source": [
        "def create_engine(\n",
        "    project_id: str, location: str, engine_name: str, engine_id: str, data_store_id: str\n",
        "):\n",
        "    # Create a client\n",
        "    client_options = (\n",
        "        ClientOptions(api_endpoint=f\"{location}-discoveryengine.googleapis.com\")\n",
        "        if location != \"global\"\n",
        "        else None\n",
        "    )\n",
        "    client = discoveryengine.EngineServiceClient(client_options=client_options)\n",
        "\n",
        "    # Initialize request argument(s)\n",
        "    engine = discoveryengine.Engine(\n",
        "        display_name=engine_name,\n",
        "        solution_type=discoveryengine.SolutionType.SOLUTION_TYPE_SEARCH,\n",
        "        industry_vertical=discoveryengine.IndustryVertical.GENERIC,\n",
        "        data_store_ids=[data_store_id],\n",
        "        search_engine_config=discoveryengine.Engine.SearchEngineConfig(\n",
        "            search_tier=discoveryengine.SearchTier.SEARCH_TIER_ENTERPRISE,\n",
        "        ),\n",
        "    )\n",
        "\n",
        "    request = discoveryengine.CreateEngineRequest(\n",
        "        parent=client.collection_path(project_id, location, \"default_collection\"),\n",
        "        engine=engine,\n",
        "        engine_id=engine.display_name,\n",
        "    )\n",
        "\n",
        "    # Make the request\n",
        "    operation = client.create_engine(request=request)\n",
        "    response = operation.result(timeout=90)\n",
        "    return response.name"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Z21T3Pm8ngv"
      },
      "outputs": [],
      "source": [
        "ENGINE_NAME = DATASTORE_NAME\n",
        "ENGINE_ID = DATASTORE_ID\n",
        "create_engine(\n",
        "    PROJECT_ID, VERTEX_AI_SEARCH_LOCATION, ENGINE_NAME, ENGINE_ID, DATASTORE_ID\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EdvJRUWRNGHE"
      },
      "source": [
        "## Create a RAG corpus using Vertex AI Search Engine as the retrieval backend"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5303c05f7aa6"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "6fc324893334"
      },
      "outputs": [],
      "source": [
        "from vertexai.preview import rag\n",
        "from vertexai.preview.generative_models import GenerativeModel, Tool"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e43229f3ad4f"
      },
      "source": [
        "### Create RAG Config\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "cf93d5f0ce00"
      },
      "outputs": [],
      "source": [
        "# Name your corpus\n",
        "DISPLAY_NAME = \"\"  # @param {type:\"string\", \"placeholder\": \"your-corpus-name\"}\n",
        "\n",
        "# Vertex AI Search name\n",
        "ENGINE_NAME = \"\"  # @param {type:\"string\", \"placeholder\": \"your-engine-name\"}\n",
        "vertex_ai_search_config = rag.VertexAiSearchConfig(\n",
        "    serving_config=f\"{ENGINE_NAME}/servingConfigs/default_search\",\n",
        ")\n",
        "\n",
        "rag_corpus = rag.create_corpus(\n",
        "    display_name=DISPLAY_NAME,\n",
        "    vertex_ai_search_config=vertex_ai_search_config,\n",
        ")"
      ]
    },
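    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you only know the short engine ID, the serving config path can be assembled by hand. This is a sketch that assumes the `default_collection` layout used by the helper functions earlier in this notebook; `engine_resource_name` and `serving_config_path` are hypothetical helpers, and the project and engine values are placeholders.\n",
        "\n",
        "```python\n",
        "def engine_resource_name(project_id, location, engine_id):\n",
        "    # Assumes the engine lives in the default collection.\n",
        "    return (\n",
        "        f'projects/{project_id}/locations/{location}'\n",
        "        f'/collections/default_collection/engines/{engine_id}'\n",
        "    )\n",
        "\n",
        "\n",
        "def serving_config_path(engine_name):\n",
        "    # Same suffix as the f-string passed to rag.VertexAiSearchConfig above.\n",
        "    return f'{engine_name}/servingConfigs/default_search'\n",
        "\n",
        "\n",
        "engine_name = engine_resource_name('my-project', 'global', 'my-engine')\n",
        "print(serving_config_path(engine_name))\n",
        "```"
      ]
    },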
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2834a2721633"
      },
      "outputs": [],
      "source": [
        "# Check the corpus just created\n",
        "new_corpus = rag.get_corpus(name=rag_corpus.name)\n",
        "new_corpus"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e1c1c7944dc8"
      },
      "source": [
        "## Using Gemini GenerateContent API with Rag Retrieval Tool\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "id": "d3afdf13dbbb"
      },
      "outputs": [],
      "source": [
        "rag_resource = rag.RagResource(\n",
        "    rag_corpus=rag_corpus.name,\n",
        ")\n",
        "\n",
        "rag_retrieval_tool = Tool.from_retrieval(\n",
        "    retrieval=rag.Retrieval(\n",
        "        source=rag.VertexRagStore(\n",
        "            rag_resources=[rag_resource],  # Currently only 1 corpus is allowed.\n",
        "            similarity_top_k=10,\n",
        "        ),\n",
        "    )\n",
        ")\n",
        "\n",
        "rag_model = GenerativeModel(\"gemini-1.5-flash\", tools=[rag_retrieval_tool])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hLAzPNJOOKuI"
      },
      "source": [
        "Note: The Vertex AI Search engine will take some time to be ready to query.\n",
        "\n",
        "If you recently created an engine and you receive an error similar to:\n",
        "\n",
        "`404 Engine {ENGINE_NAME} is not found`\n",
        "\n",
        "Then wait a few minutes and try your query again."
      ]
    },
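    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Rather than re-running the query by hand, the wait-and-retry advice above can be automated. The sketch below is a generic retry loop with a doubling delay; `call_with_retry` and its parameters are hypothetical names, and the default delay is an arbitrary choice rather than a documented recommendation.\n",
        "\n",
        "```python\n",
        "import time\n",
        "\n",
        "\n",
        "def call_with_retry(fn, retry_on, attempts=5, initial_delay=30.0, sleep=time.sleep):\n",
        "    # Call fn() and retry on the given exception types,\n",
        "    # doubling the wait between attempts.\n",
        "    delay = initial_delay\n",
        "    for attempt in range(1, attempts + 1):\n",
        "        try:\n",
        "            return fn()\n",
        "        except retry_on:\n",
        "            if attempt == attempts:\n",
        "                raise  # Out of attempts: surface the last error.\n",
        "            sleep(delay)\n",
        "            delay *= 2\n",
        "```\n",
        "\n",
        "For example, assuming the 404 surfaces as `google.api_core.exceptions.NotFound`, the query in the next cell could be wrapped as `call_with_retry(lambda: rag_model.generate_content(GENERATE_CONTENT_PROMPT), retry_on=(NotFound,))`."
      ]
    },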
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "484f5242dcae"
      },
      "outputs": [],
      "source": [
        "GENERATE_CONTENT_PROMPT = (\n",
        "    \"Who is CFO of Google?\"  # @param {type:\"string\", isTemplate: true}\n",
        ")\n",
        "\n",
        "response = rag_model.generate_content(GENERATE_CONTENT_PROMPT)\n",
        "\n",
        "response"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "287a90fed14f"
      },
      "source": [
        "## Using other generation API with Rag Retrieval Tool\n",
        "\n",
        "The retrieved contexts can be passed to any SDK or model generation API to generate final results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "428921dea97d"
      },
      "outputs": [],
      "source": [
        "RETRIEVAL_QUERY = \"Who is CFO of Google?\"  # @param {type:\"string\", isTemplate: true}\n",
        "\n",
        "rag_resource = rag.RagResource(rag_corpus=rag_corpus.name)\n",
        "\n",
        "response = rag.retrieval_query(\n",
        "    rag_resources=[rag_resource],  # Currently only 1 corpus is allowed.\n",
        "    text=RETRIEVAL_QUERY,\n",
        "    similarity_top_k=10,\n",
        ")\n",
        "\n",
        "# The retrieved context can be passed to any SDK or model generation API to generate final results.\n",
        "retrieved_context = \" \".join(\n",
        "    [context.text for context in response.contexts.contexts]\n",
        ").replace(\"\\n\", \"\")\n",
        "\n",
        "retrieved_context"
      ]
    },
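    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to hand the retrieved context to another model is to fold it into the prompt. The sketch below is one possible prompt layout, not a prescribed format; `build_grounded_prompt` is a hypothetical helper that works on plain strings, so it does not depend on any particular SDK.\n",
        "\n",
        "```python\n",
        "def build_grounded_prompt(question, contexts):\n",
        "    # Number each retrieved passage so the model can refer to it.\n",
        "    numbered = '\\n'.join(\n",
        "        f'[{i}] {text.strip()}' for i, text in enumerate(contexts, start=1)\n",
        "    )\n",
        "    return (\n",
        "        'Answer the question using only the context below.\\n\\n'\n",
        "        f'Context:\\n{numbered}\\n\\n'\n",
        "        f'Question: {question}'\n",
        "    )\n",
        "```\n",
        "\n",
        "With the response from the previous cell, the call could look like `build_grounded_prompt(RETRIEVAL_QUERY, [c.text for c in response.contexts.contexts])`."
      ]
    },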
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2a4e033321ad"
      },
      "source": [
        "## Cleaning up\n",
        "\n",
        "Clean up RAG resources created in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a105caffd9e7"
      },
      "outputs": [],
      "source": [
        "delete_rag_corpus = False  # @param {type:\"boolean\"}\n",
        "\n",
        "if delete_rag_corpus:\n",
        "    rag.delete_corpus(name=rag_corpus.name)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "rag_engine_vertex_ai_search.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}