{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "5norOZI0mA6s" }, "outputs": [], "source": [ "# Copyright 2023 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "XNPE46X8mJj4" }, "source": [ "# Use Retrieval Augmented Generation (RAG) with Gemini API\n", "\n", "<table align=\"left\">\n", "\n", " <td style=\"text-align: center\">\n", " <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb\">\n", " <img src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n", " </a>\n", " </td>\n", "\n", " <td style=\"text-align: center\">\n", " <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Flanguage%2Fcode%2Fcode_retrieval_augmented_generation.ipynb\">\n", " <img width=\"32px\" src=\"https://cloud.google.com/ml-engine/images/colab-enterprise-logo-32px.png\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n", " </a>\n", " </td>\n", " <td style=\"text-align: center\">\n", " <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb\">\n", " <img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\"><br> View on GitHub\n", " </a>\n", " </td>\n", " <td style=\"text-align: center\">\n", " <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/language/code/code_retrieval_augmented_generation.ipynb\">\n", " <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Workbench\n", " </a>\n", " </td>\n", "</table>\n", "\n", "<div style=\"clear: both;\"></div>\n", "\n", "<b>Share to:</b>\n", "\n", "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n", "</a>\n", "\n", "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n", "</a>\n", "\n", "<a 
href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/53/X_logo_2023_original.svg\" alt=\"X logo\">\n", "</a>\n", "\n", "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n", "</a>\n", "\n", "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n", "</a> " ] }, { "cell_type": "markdown", "metadata": { "id": "VrLtlKPFqSxB" }, "source": [ "| | |\n", "|-|-|\n", "|Author(s) | [Lavi Nigam](https://github.com/lavinigam-gcp), [Polong Lin](https://github.com/polong-lin) |" ] }, { "cell_type": "markdown", "metadata": { "id": "zNAEdYNFmQcP" }, "source": [ "### Objective\n", "\n", "This notebook demonstrates how you augment output from Gemini API by bringing in external knowledge. An example is provided using Code Retrieval Augmented Generation(RAG) pattern using [Google Cloud's Generative AI github repository](https://github.com/GoogleCloudPlatform/generative-ai) as external knowledge. The notebook uses [Gemini API in Vertex AI](https://ai.google.dev/gemini-api), [Embeddings for Text API](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings), FAISS vector store and [LangChain 🦜️🔗](https://python.langchain.com/en/latest/).\n", "\n", "### Overview\n", "\n", "Here is overview of what we'll go over.\n", "\n", "Index Creation:\n", "\n", "1. Recursively list the files(.ipynb) in github repo\n", "2. Extract code and markdown from the files\n", "3. Chunk & generate embeddings for each code strings and add initialize the vector store\n", "\n", "Runtime:\n", "\n", "4. User enters a prompt or asks a question as a prompt\n", "5. Try zero-shot prompt\n", "6. Run prompt using RAG Chain & compare results.To generate response we use **gemini-1.5-pro**\n", "\n", "### Cost\n", "\n", "This tutorial uses billable components of Google Cloud:\n", "\n", "- Gemini API in Vertex AI offered by Google Cloud\n", "\n", "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.\n", "\n", "**Note:** We are using local vector store(FAISS) for this example however recommend managed highly scalable vector store for production usage such as [Vertex AI Vector Search](https://cloud.google.com/vertex-ai/docs/vector-search/overview) or [AlloyDB for PostgreSQL](https://cloud.google.com/alloydb/docs/ai/work-with-embeddings) or [Cloud SQL for PostgreSQL](https://cloud.google.com/sql/docs/postgres/features) using pgvector extension." 
] }, { "cell_type": "markdown", "metadata": { "id": "2cab0c8509c9" }, "source": [ "## Get started" ] }, { "cell_type": "markdown", "metadata": { "id": "b56b5a5d28c1" }, "source": [ "### Install Vertex AI SDK for Python and other required packages\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QHaqV20Csqkt" }, "outputs": [], "source": [ "%pip install --upgrade --user -q google-cloud-aiplatform \\\n", " langchain \\\n", " langchain_google_vertexai \\\n", " langchain-community \\\n", " faiss-cpu \\\n", " nbformat" ] }, { "cell_type": "markdown", "metadata": { "id": "-VUWOgz6M1rZ" }, "source": [ "### Restart runtime (Colab only)\n", "\n", "To use the newly installed packages, you must restart the runtime on Google Colab." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BIS8EYgkMy8T" }, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", " import IPython\n", "\n", " app = IPython.Application.instance()\n", " app.kernel.do_shutdown(True)" ] }, { "cell_type": "markdown", "metadata": { "id": "0af13c10a26a" }, "source": [ "<div class=\"alert alert-block alert-warning\">\n", "<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>\n", "</div>\n" ] }, { "cell_type": "markdown", "metadata": { "id": "uZcP9WBENG0e" }, "source": [ "### Authenticate your notebook environment (Colab only)\n", "\n", "Authenticate your environment on Google Colab.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1S_HgQXQNcbz" }, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", " from google.colab import auth\n", "\n", " auth.authenticate_user()" ] }, { "cell_type": "markdown", "metadata": { "id": "rVmxMr43Nhoo" }, "source": [ "### Import libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L-Tljm5asMBc" }, "outputs": [], "source": [ "import time\n", "\n", "from google.cloud import aiplatform\n", "from langchain.chains import RetrievalQA\n", "from langchain.prompts import PromptTemplate\n", "from langchain.schema.document import Document\n", "from langchain.text_splitter import Language, RecursiveCharacterTextSplitter\n", "from langchain.vectorstores import FAISS\n", "\n", "# LangChain\n", "from langchain_google_vertexai import VertexAI, VertexAIEmbeddings\n", "import nbformat\n", "import requests\n", "\n", "# Vertex AI\n", "import vertexai\n", "\n", "# Print the version of Vertex AI SDK for Python\n", "print(f\"Vertex AI SDK version: {aiplatform.__version__}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4f872cd812d0" }, "source": [ "### Set Google Cloud project information and initialize Vertex AI SDK for Python\n", "\n", "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com). Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eNGEcBKG0iK-" }, "outputs": [], "source": [ "# Initialize project\n", "# Define project information\n", "PROJECT_ID = \"YOUR_PROJECT_ID\" # @param {type:\"string\"}\n", "LOCATION = \"us-central1\" # @param {type:\"string\"}\n", "\n", "vertexai.init(project=PROJECT_ID, location=LOCATION)\n", "\n", "# Code Generation\n", "code_llm = VertexAI(\n", " model_name=\"gemini-1.5-pro\",\n", " max_output_tokens=2048,\n", " temperature=0.1,\n", " verbose=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "o537exyZk9DI" }, "source": [ "Next we need to create a GitHub personal token to be able to list all files in a repository.\n", "\n", "- Follow [this link](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) to create GitHub token with repo->public_repo scope and update `GITHUB_TOKEN` variable below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bt9IVDSqk7y4" }, "outputs": [], "source": [ "# provide GitHub personal access token\n", "GITHUB_TOKEN = \"YOUR_GITHUB_TOKEN\" # @param {type:\"string\"}\n", "GITHUB_REPO = \"GoogleCloudPlatform/generative-ai\" # @param {type:\"string\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "dqq3GeEbOJbU" }, "source": [ "# Index Creation\n", "\n", "We use the Google Cloud Generative AI github repository as the data source. First list all Jupyter Notebook files in the repo and store it in a text file.\n", "\n", "You can skip this step(#1) if you have executed it once and generated the output text file.\n", "\n", "### 1. Recursively list the files(.ipynb) in the github repository" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eTA1Jt0uOX8y" }, "outputs": [], "source": [ "# Crawls a GitHub repository and returns a list of all ipynb files in the repository\n", "\n", "\n", "def crawl_github_repo(url: str, is_sub_dir: bool, access_token: str = GITHUB_TOKEN):\n", " ignore_list = [\"__init__.py\"]\n", "\n", " if not is_sub_dir:\n", " api_url = f\"https://api.github.com/repos/{url}/contents\"\n", "\n", " else:\n", " api_url = url\n", "\n", " headers = {\n", " \"Accept\": \"application/vnd.github.v3+json\",\n", " \"Authorization\": f\"Bearer {access_token}\",\n", " }\n", "\n", " response = requests.get(api_url, headers=headers)\n", " response.raise_for_status() # Check for any request errors\n", "\n", " files = []\n", "\n", " contents = response.json()\n", "\n", " for item in contents:\n", " if (\n", " item[\"type\"] == \"file\"\n", " and item[\"name\"] not in ignore_list\n", " and (item[\"name\"].endswith(\".py\") or item[\"name\"].endswith(\".ipynb\"))\n", " ):\n", " files.append(item[\"html_url\"])\n", " elif item[\"type\"] == \"dir\" and not item[\"name\"].startswith(\".\"):\n", " sub_files = crawl_github_repo(item[\"url\"], True)\n", " time.sleep(0.1)\n", " files.extend(sub_files)\n", "\n", " return files" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5vaKaxcGO_R6" }, "outputs": [], "source": [ "code_files_urls = crawl_github_repo(GITHUB_REPO, False, GITHUB_TOKEN)\n", "\n", "# Write list to a file so you do not have to download each time\n", "with open(\"code_files_urls.txt\", \"w\") as f:\n", " for item in code_files_urls:\n", " f.write(item + \"\\n\")\n", "\n", "len(code_files_urls)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c5hoNYJ5byMJ" }, "outputs": [], "source": [ "code_files_urls[0:10]" ] }, { 
"cell_type": "markdown", "metadata": { "id": "mFNVieLnR8Ie" }, "source": [ "### 2. Extract code from the Jupyter notebooks.\n", "\n", "You could also include .py file, shell scripts etc." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZsM1M4hn4cBu" }, "outputs": [], "source": [ "# Extracts the python code from an ipynb file from github\n", "\n", "\n", "def extract_python_code_from_ipynb(github_url, cell_type=\"code\"):\n", " raw_url = github_url.replace(\"github.com\", \"raw.githubusercontent.com\").replace(\n", " \"/blob/\", \"/\"\n", " )\n", "\n", " response = requests.get(raw_url)\n", " response.raise_for_status() # Check for any request errors\n", "\n", " notebook_content = response.text\n", "\n", " notebook = nbformat.reads(notebook_content, as_version=nbformat.NO_CONVERT)\n", "\n", " python_code = None\n", "\n", " for cell in notebook.cells:\n", " if cell.cell_type == cell_type:\n", " if not python_code:\n", " python_code = cell.source\n", " else:\n", " python_code += \"\\n\" + cell.source\n", "\n", " return python_code\n", "\n", "\n", "def extract_python_code_from_py(github_url):\n", " raw_url = github_url.replace(\"github.com\", \"raw.githubusercontent.com\").replace(\n", " \"/blob/\", \"/\"\n", " )\n", "\n", " response = requests.get(raw_url)\n", " response.raise_for_status() # Check for any request errors\n", "\n", " python_code = response.text\n", "\n", " return python_code" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WCRp5Xtb48is" }, "outputs": [], "source": [ "with open(\"code_files_urls.txt\") as f:\n", " code_files_urls = f.read().splitlines()\n", "len(code_files_urls)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4Y9SMO7H4xgF" }, "outputs": [], "source": [ "code_strings = []\n", "\n", "for i in range(0, len(code_files_urls)):\n", " if code_files_urls[i].endswith(\".ipynb\"):\n", " content = extract_python_code_from_ipynb(code_files_urls[i], \"code\")\n", " doc = Document(\n", " page_content=content, metadata={\"url\": code_files_urls[i], \"file_index\": i}\n", " )\n", " code_strings.append(doc)" ] }, { "cell_type": "markdown", "metadata": { "id": "T1AF3fhBSLOm" }, "source": [ "### 3. Chunk & generate embeddings for each code strings & initialize the vector store\n", "\n", "We need to split code into usable chunks that the LLM can use for code generation. Therefore it's crucial to use the right chunking approach and chunk size." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Rj1cCA2fqx64" }, "outputs": [], "source": [ "# Utility functions for Embeddings API with rate limiting\n", "\n", "\n", "def rate_limit(max_per_minute):\n", " period = 60 / max_per_minute\n", " print(\"Waiting\")\n", " while True:\n", " before = time.time()\n", " yield\n", " after = time.time()\n", " elapsed = after - before\n", " sleep_time = max(0, period - elapsed)\n", " if sleep_time > 0:\n", " print(\".\", end=\"\")\n", " time.sleep(sleep_time)\n", "\n", "\n", "class CustomVertexAIEmbeddings(VertexAIEmbeddings):\n", " requests_per_minute: int\n", " num_instances_per_batch: int\n", " model_name: str\n", "\n", " # Overriding embed_documents method\n", " def embed_documents(\n", " self, texts: list[str], batch_size: int | None = None\n", " ) -> list[list[float]]:\n", " limiter = rate_limit(self.requests_per_minute)\n", " results = []\n", " docs = list(texts)\n", "\n", " while docs:\n", " # Working in batches because the API accepts maximum 5\n", " # documents per request to get embeddings\n", " head, docs = (\n", " docs[: self.num_instances_per_batch],\n", " docs[self.num_instances_per_batch :],\n", " )\n", " chunk = self.client.get_embeddings(head)\n", " results.extend(chunk)\n", " next(limiter)\n", "\n", " return [r.values for r in results]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oae37l-pvzZ6" }, "outputs": [], "source": [ "# Chunk code strings\n", "text_splitter = RecursiveCharacterTextSplitter.from_language(\n", " language=Language.PYTHON, chunk_size=2000, chunk_overlap=200\n", ")\n", "\n", "\n", "texts = text_splitter.split_documents(code_strings)\n", "print(len(texts))\n", "\n", "# Initialize Embedding API\n", "EMBEDDING_QPM = 100\n", "EMBEDDING_NUM_BATCH = 5\n", "embeddings = CustomVertexAIEmbeddings(\n", " requests_per_minute=EMBEDDING_QPM,\n", " num_instances_per_batch=EMBEDDING_NUM_BATCH,\n", " model_name=\"textembedding-gecko@latest\",\n", ")\n", "\n", "# Create Index from embedded code chunks\n", "db = FAISS.from_documents(texts, embeddings)\n", "\n", "# Init your retriever.\n", "retriever = db.as_retriever(\n", " search_type=\"similarity\", # Also test \"similarity\", \"mmr\"\n", " search_kwargs={\"k\": 5},\n", ")\n", "\n", "retriever" ] }, { "cell_type": "markdown", "metadata": { "id": "Q_gn89IyuHIT" }, "source": [ "# Runtime\n", "### 4. User enters a prompt or asks a question as a prompt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1vrvTkO7uFNi" }, "outputs": [], "source": [ "user_question = \"Create a Python function that takes a prompt and predicts using langchain.llms interface with Vertex AI text-bison model\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "azbvOUFRvEp5" }, "outputs": [], "source": [ "# Define prompt templates\n", "\n", "# Zero Shot prompt template\n", "prompt_zero_shot = \"\"\"\n", " You are a proficient python developer. Respond with the syntactically correct & concise code for to the question below.\n", "\n", " Question:\n", " {question}\n", "\n", " Output Code :\n", " \"\"\"\n", "\n", "prompt_prompt_zero_shot = PromptTemplate(\n", " input_variables=[\"question\"],\n", " template=prompt_zero_shot,\n", ")\n", "\n", "\n", "# RAG template\n", "prompt_RAG = \"\"\"\n", " You are a proficient python developer. Respond with the syntactically correct code for to the question below. Make sure you follow these rules:\n", " 1. Use context to understand the APIs and how to use it & apply.\n", " 2. 
Do not add license information to the output code.\n", " 3. Do not include Colab code in the output.\n", " 4. Ensure all the requirements in the question are met.\n", "\n", " Question:\n", " {question}\n", "\n", " Context:\n", " {context}\n", "\n", " Helpful Response:\n", " \"\"\"\n", "\n", "prompt_RAG_template = PromptTemplate(\n", " template=prompt_RAG, input_variables=[\"context\", \"question\"]\n", ")\n", "\n", "qa_chain = RetrievalQA.from_llm(\n", " llm=code_llm,\n", " prompt=prompt_RAG_template,\n", " retriever=retriever,\n", " return_source_documents=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "3NBaObAQSlIv" }, "source": [ "### 5. Try a zero-shot prompt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1svTVwtBS0zP" }, "outputs": [], "source": [ "response = code_llm.invoke(input=user_question, max_output_tokens=2048, temperature=0.1)\n", "print(response)" ] }, { "cell_type": "markdown", "metadata": { "id": "JPm8qdxzwPM0" }, "source": [ "### 6. Run the prompt using the RAG chain & compare results\n", "\n", "To generate the response, we use **gemini-1.5-pro**, the model initialized at the start of this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZMz3nPMyVoj_" }, "outputs": [], "source": [ "results = qa_chain.invoke(input={\"query\": user_question})\n", "print(results[\"result\"])" ] }, { "cell_type": "markdown", "metadata": { "id": "HF3lVWK1wjxe" }, "source": [ "### Let's try another prompt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jel0ON68XiU7" }, "outputs": [], "source": [ "user_question = \"Create a Python function that takes text input and returns embeddings using LangChain with the Vertex AI textembedding-gecko model\"\n", "\n", "\n", "response = code_llm.invoke(input=user_question, max_output_tokens=2048, temperature=0.1)\n", "print(response)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G9bIkqE8sO6P" }, "outputs": [], "source": [ "results = qa_chain.invoke(input={\"query\": user_question})\n", "print(results[\"result\"])" ] } ], "metadata": { "colab": { "name": "code_retrieval_augmented_generation.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }