Adding BM25 and GPT retrievers (ShishirPatil#61)
Gorilla introduces Retriever Aware Training (RAT). Here we include the
code for retrievers. Note that while you can use {prompt + retrieved
API} with the 0-shot trained Gorilla LLMs for inference, for optimal
performance, we recommend you use {prompt + retrieved API} with RAT
trained Gorilla LLM for inference.

More precisely, BM25 refers to `BM25Okapi`, and GPT refers to
`text-embedding-ada-002-v2`.
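For context on what the BM25 retriever computes: `BM25Okapi` ranks documents by term frequency, weighted by inverse document frequency and normalized for document length. Below is a minimal pure-Python sketch of a common Okapi BM25 scoring variant; the corpus and query are made up for illustration, and the `rank_bm25` library used in this commit implements the same idea with its own IDF smoothing.

```python
import math
from collections import Counter

def bm25_scores(query_tokens, corpus_tokens, k1=1.5, b=0.75):
    """Score each document in corpus_tokens against query_tokens (Okapi BM25)."""
    N = len(corpus_tokens)
    avgdl = sum(len(d) for d in corpus_tokens) / N
    # Document frequency: number of docs containing each term.
    df = Counter()
    for doc in corpus_tokens:
        df.update(set(doc))
    scores = []
    for doc in corpus_tokens:
        tf = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if term not in tf:
                continue
            # IDF with the +1 inside the log to keep it non-negative.
            idf = math.log((N - df[term] + 0.5) / (df[term] + 0.5) + 1)
            # Term-frequency saturation (k1) and length normalization (b).
            score += idf * tf[term] * (k1 + 1) / (
                tf[term] + k1 * (1 - b + b * len(doc) / avgdl)
            )
        scores.append(score)
    return scores

corpus = [doc.split() for doc in [
    "load a resnet image classification model",
    "text to speech synthesis api",
    "object detection with torchvision",
]]
query = "image classification model".split()
print(bm25_scores(query, corpus))  # highest score for the first document
```

In the evaluation script, `BM25Okapi` is built over the whitespace-tokenized API corpus, and the top `num_doc` documents are appended to the prompt.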

Directly relevant to ShishirPatil#58 and might be useful for ShishirPatil#32 

+ Removing unused legacy code from `ast_eval_*.py`
+ Updating Apache 2.0 license for all of Gorilla code

---------

Co-authored-by: Tianjun <[email protected]>
ShishirPatil and tianjunz authored Jul 18, 2023
1 parent 29f5ffb commit c849d11
Showing 15 changed files with 823 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -7,7 +7,7 @@

**:rocket: Try Gorilla in 60s** [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DEBPsccVLF_aUnmD0FwPeHFrtdC0QIUP?usp=sharing)

:computer: Use [Gorilla in your CLI](https://github.com/gorilla-llm/gorilla-cli)
:computer: Use [Gorilla in your CLI](https://github.com/gorilla-llm/gorilla-cli) with `pip install gorilla-cli`

**:newspaper_roll: Checkout our paper!** [![arXiv](https://img.shields.io/badge/arXiv-2305.15334-<COLOR>.svg?style=flat-square)](https://arxiv.org/abs/2305.15334)

13 changes: 13 additions & 0 deletions data/README.md
@@ -116,3 +116,16 @@ Submit a json file containing the list of json objects:
**LLM assistance for LLM API dataset :wink:**:

Visit our website [Coming Soon!], where you can type in the API URL and we'll output an API submission for you! You can choose to submit directly or edit before submitting. Easy and quick!

## Citation

If you use Gorilla or APIBench, please cite our paper:

```text
@article{patil2023gorilla,
title={Gorilla: Large Language Model Connected with Massive APIs},
author={Shishir G. Patil and Tianjun Zhang and Xin Wang and Joseph E. Gonzalez},
year={2023},
journal={arXiv preprint arXiv:2305.15334},
}
```
8 changes: 7 additions & 1 deletion eval/README.md
@@ -4,14 +4,20 @@

## Get Started

### Getting GPT-3.5-turbo, GPT-4 and Claude Responses
### Getting GPT-3.5-turbo, GPT-4 and Claude Responses (0-Shot)

To get LLM responses for the API calls, use the following command:

```bash
python get_llm_responses.py --model gpt-3.5-turbo --api_key $API_KEY --output_file gpt-3.5-turbo_torchhub_0_shot.jsonl --question_data eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl --api_name torchhub
```

### Getting Responses with Retrievers (`bm25` or `gpt`)

```bash
python get_llm_responses_retriever.py --retriever bm25 --model gpt-3.5-turbo --api_key $API_KEY --output_file gpt-3.5-turbo_torchhub_0_shot.jsonl --question_data eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl --api_name torchhub --api_dataset ../data/api/torchhub_api.jsonl
```

### Evaluate the Response with AST tree matching

After the LLM responses are generated, we can evaluate them against our dataset:
29 changes: 15 additions & 14 deletions eval/eval-scripts/ast_eval_hf.py
@@ -1,21 +1,22 @@
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
from codebleu.parser import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp
from codebleu.parser import (remove_comments_and_docstrings,
                             tree_to_token_index,
                             index_to_code_token,
                             tree_to_variable_index)

from tree_sitter import Language, Parser

dfg_function = {
    'python': DFG_python,
    'java': DFG_java,
    'ruby': DFG_ruby,
    'go': DFG_go,
    'php': DFG_php,
    'javascript': DFG_javascript,
    'c_sharp': DFG_csharp,
}

# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
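The body of `get_all_sub_trees` is collapsed above; the underlying idea, an iterative depth-first traversal that collects every node of the parse tree, can be sketched with a stand-in node class (this `Node` class is illustrative only, not the tree-sitter API, and the sketch may differ from the actual function body):

```python
class Node:
    """Minimal stand-in for a parse-tree node; tree-sitter nodes expose .children similarly."""
    def __init__(self, kind, children=()):
        self.kind = kind
        self.children = list(children)

def get_all_sub_trees(root_node):
    """Collect every node reachable from root_node via an explicit stack (iterative DFS)."""
    node_stack = [root_node]
    sub_trees = []
    while node_stack:
        node = node_stack.pop()
        sub_trees.append(node)
        node_stack.extend(node.children)
    return sub_trees

tree = Node("module", [Node("call", [Node("name"), Node("args")]), Node("assign")])
print(len(get_all_sub_trees(tree)))  # 5 nodes in total
```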
29 changes: 14 additions & 15 deletions eval/eval-scripts/ast_eval_tf.py
@@ -1,22 +1,21 @@
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
from codebleu.parser import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp
from codebleu.parser import (remove_comments_and_docstrings,
                             tree_to_token_index,
                             index_to_code_token,
                             tree_to_variable_index)
from tree_sitter import Language, Parser

dfg_function = {
    'python': DFG_python,
    'java': DFG_java,
    'ruby': DFG_ruby,
    'go': DFG_go,
    'php': DFG_php,
    'javascript': DFG_javascript,
    'c_sharp': DFG_csharp,
}

# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
node_stack = []
40 changes: 15 additions & 25 deletions eval/eval-scripts/ast_eval_th.py
@@ -1,32 +1,22 @@
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
from codebleu.parser import (
    DFG_python,
    DFG_java,
    DFG_ruby,
    DFG_go,
    DFG_php,
    DFG_javascript,
    DFG_csharp,
)
from codebleu.parser import (
    remove_comments_and_docstrings,
    tree_to_token_index,
    index_to_code_token,
    tree_to_variable_index,
)
from tree_sitter import Language, Parser
import concurrent.futures

dfg_function = {
    "python": DFG_python,
    "java": DFG_java,
    "ruby": DFG_ruby,
    "go": DFG_go,
    "php": DFG_php,
    "javascript": DFG_javascript,
    "c_sharp": DFG_csharp,
}

# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
@@ -130,7 +120,7 @@ def process_response(response, api_database, qa_pairs, ast_database):
try:
output = response["text"]
except:
print("Error: cannot parse line ", response["index"])
print("Error: cannot parse line ", response)
return False, False

# Index the "api_call" domain
14 changes: 14 additions & 0 deletions eval/get_llm_responses.py
@@ -1,3 +1,17 @@
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import sys
import json
165 changes: 165 additions & 0 deletions eval/get_llm_responses_retriever.py
@@ -0,0 +1,165 @@
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import sys
import json
import openai
import anthropic
import multiprocessing as mp
import os
import time
from retrievers import *
from retrievers.build_json_index import JSONLReader

def encode_question(question, api_name):
    """Encode multiple prompt instructions into a single string."""

    prompts = []
    if api_name == "torchhub":
        domains = "1. $DOMAIN is inferred from the task description and should include one of {Classification, Semantic Segmentation, Object Detection, Audio Separation, Video Classification, Text-to-Speech}."
    elif api_name == "huggingface":
        domains = "1. $DOMAIN should include one of {Multimodal Feature Extraction, Multimodal Text-to-Image, Multimodal Image-to-Text, Multimodal Text-to-Video, \
Multimodal Visual Question Answering, Multimodal Document Question Answer, Multimodal Graph Machine Learning, Computer Vision Depth Estimation,\
Computer Vision Image Classification, Computer Vision Object Detection, Computer Vision Image Segmentation, Computer Vision Image-to-Image, \
Computer Vision Unconditional Image Generation, Computer Vision Video Classification, Computer Vision Zero-Shot Image Classification, \
Natural Language Processing Text Classification, Natural Language Processing Token Classification, Natural Language Processing Table Question Answering, \
Natural Language Processing Question Answering, Natural Language Processing Zero-Shot Classification, Natural Language Processing Translation, \
Natural Language Processing Summarization, Natural Language Processing Conversational, Natural Language Processing Text Generation, Natural Language Processing Fill-Mask,\
Natural Language Processing Text2Text Generation, Natural Language Processing Sentence Similarity, Audio Text-to-Speech, Audio Automatic Speech Recognition, \
Audio Audio-to-Audio, Audio Audio Classification, Audio Voice Activity Detection, Tabular Tabular Classification, Tabular Tabular Regression, \
Reinforcement Learning Reinforcement Learning, Reinforcement Learning Robotics }"
    elif api_name == "tensorhub":
        domains = "1. $DOMAIN is inferred from the task description and should include one of {text-sequence-alignment, text-embedding, text-language-model, text-preprocessing, text-classification, text-generation, text-question-answering, text-retrieval-question-answering, text-segmentation, text-to-mel, image-classification, image-feature-vector, image-object-detection, image-segmentation, image-generator, image-pose-detection, image-rnn-agent, image-augmentation, image-classifier, image-style-transfer, image-aesthetic-quality, image-depth-estimation, image-super-resolution, image-deblurring, image-extrapolation, image-text-recognition, image-dehazing, image-deraining, image-enhancemenmt, image-classification-logits, image-frame-interpolation, image-text-detection, image-denoising, image-others, video-classification, video-feature-extraction, video-generation, video-audio-text, video-text, audio-embedding, audio-event-classification, audio-command-detection, audio-paralinguists-classification, audio-speech-to-text, audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
    else:
        print("Error: API name is not supported.")

    prompt = question + "\nWrite a python program in 1 to 2 lines to call API in " + api_name + ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. Here are the requirements:\n" + domains + "\n2. The $API_CALL should have only 1 line of code that calls api.\n3. The $API_PROVIDER should be the programming framework used.\n4. $EXPLANATION should be a step-by-step explanation.\n5. The $CODE is the python code.\n6. Do not repeat the format in your answer."
    prompts.append({"role": "system", "content": "You are a helpful API writer who can write APIs based on requirements."})
    prompts.append({"role": "user", "content": prompt})
    return prompts

def get_response(get_response_input, api_key):
    question, question_id, api_name, model, retrieved_doc = get_response_input
    question = encode_question(question, api_name)
    question[-1]["content"] = question[-1]["content"] + "\nHere are some reference docs:"
    for i, doc in enumerate(retrieved_doc):
        question[-1]["content"] = question[-1]["content"] + "\nAPI " + str(i) + ": " + str(doc)

    try:
        if "gpt" in model:
            openai.api_key = api_key
            responses = openai.ChatCompletion.create(
                model=model,
                messages=question,
                n=1,
                temperature=0,
            )
            response = responses['choices'][0]['message']['content']
        elif "claude" in model:
            client = anthropic.Client(api_key)
            responses = client.completion(
                prompt=f"{anthropic.HUMAN_PROMPT} {question[0]['content']}{question[1]['content']}{anthropic.AI_PROMPT}",
                stop_sequences=[anthropic.HUMAN_PROMPT],
                model="claude-v1",
                max_tokens_to_sample=2048,
            )
            response = responses["completion"].strip()
        else:
            # Unsupported model: return early instead of falling through to an undefined `response`.
            print("Error: Model is not supported.")
            return None
    except Exception as e:
        print("Error:", e)
        return None

    print("=>",)
    return {'text': response, "question_id": question_id, "answer_id": "None", "model_id": model, "metadata": {}}

def process_entry(entry, api_key):
    question, question_id, api_name, model, retriever = entry
    retrieved_doc = retriever.get_relevant_documents(question)
    result = get_response((question, question_id, api_name, model, retrieved_doc), api_key)
    return result

def write_result_to_file(result, output_file):
    global file_write_lock
    with file_write_lock:
        with open(output_file, "a") as outfile:
            json.dump(result, outfile)
            outfile.write("\n")

def callback_with_lock(result, output_file):
    global file_write_lock
    write_result_to_file(result, output_file)  # the lock is acquired inside write_result_to_file

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=None, help="which model you want to use for eval, only support ['gpt*', 'claude*'] now")
    parser.add_argument("--api_key", type=str, default=None, help="the api key provided for calling")
    parser.add_argument("--output_file", type=str, default=None, help="the output file this script writes to")
    parser.add_argument("--question_data", type=str, default=None, help="path to the questions data file")
    parser.add_argument("--api_name", type=str, default=None, help="this will be the api dataset name you are testing, only support ['torchhub', 'tensorhub', 'huggingface'] now")
    parser.add_argument("--retriever", type=str, default="bm25", help="which retriever to use")
    parser.add_argument("--num_doc", type=int, default=1, help="top k docs to use")
    parser.add_argument("--api_dataset", type=str, default=None, help="path to the api data")
    args = parser.parse_args()

    assert args.retriever in ["bm25", "gpt"]
    if args.retriever == "gpt":
        retriever = GPTRetriever(query_kwargs={"similarity_top_k": args.num_doc})
        if os.path.exists(args.retriever + '_dataset_index.json'):
            print('data index already saved')
            os.environ["OPENAI_API_KEY"] = args.api_key
            index = retriever.load_from_disk(args.retriever + '_dataset_index.json')
        else:
            print('data index being created')
            os.environ["OPENAI_API_KEY"] = args.api_key
            documents = JSONLReader().load_data(args.api_dataset)
            index = retriever.from_documents(documents)
            retriever.save_to_disk(index, args.retriever + '_dataset_index.json')
    elif args.retriever == "bm25":
        from rank_bm25 import BM25Okapi
        corpus = []
        with open(args.api_dataset, 'r') as f:
            for line in f:
                corpus.append(json.loads(line))
        tokenized_corpus = [str(doc).split(" ") for doc in corpus]
        bm25 = BM25Okapi(tokenized_corpus)
        retriever = BM25Retriever(index=bm25, corpus=corpus, query_kwargs={"similarity_top_k": args.num_doc})
    else:
        assert False

    start_time = time.time()
    # Read the question file
    questions = []
    question_ids = []
    with open(args.question_data, 'r') as f:
        for line in f:
            entry = json.loads(line)  # parse each line once
            questions.append(entry["text"])
            question_ids.append(entry["question_id"])

    file_write_lock = mp.Lock()
    with mp.Pool(1) as pool:
        results = []
        for idx, (question, question_id) in enumerate(zip(questions, question_ids)):
            result = pool.apply_async(
                process_entry,
                args=((question, question_id, args.api_name, args.model, retriever), args.api_key),
                callback=lambda result: write_result_to_file(result, args.output_file),
            )
            results.append(result)
        pool.close()
        pool.join()

    end_time = time.time()
    print("Total time used: ", end_time - start_time)
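The prompt-assembly step of `get_response` above, appending each retrieved doc to the user message, can be sketched in isolation. Here `append_retrieved_docs` is a hypothetical helper (not part of the script) and the doc content is made up; the string-building mirrors the loop in `get_response`:

```python
def append_retrieved_docs(messages, retrieved_docs):
    """Append retrieved API docs to the last (user) message, as get_response does."""
    messages = [dict(m) for m in messages]  # copy so the caller's messages are untouched
    messages[-1]["content"] += "\nHere are some reference docs:"
    for i, doc in enumerate(retrieved_docs):
        messages[-1]["content"] += "\nAPI " + str(i) + ": " + str(doc)
    return messages

messages = [
    {"role": "system", "content": "You are a helpful API writer."},
    {"role": "user", "content": "I want to classify images."},
]
docs = [{"api_call": "torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)"}]
augmented = append_retrieved_docs(messages, docs)
print(augmented[-1]["content"])
```

With `num_doc > 1`, several docs are appended in retrieval order, so the top-ranked API appears first in the prompt.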
7 changes: 7 additions & 0 deletions eval/retrievers/__init__.py
@@ -0,0 +1,7 @@
from retrievers.gpt import GPTRetriever
from retrievers.bm25 import BM25Retriever

__all__ = [
    "GPTRetriever",
    "BM25Retriever",
]