Colored logging configuration + displaying progress in logs (ShishirPatil#384)

**Logging config**
- Colored logging using the `coloredlogs` package
- Logging configuration loaded from a YAML config file, `logging.yaml` by default
- Logging configuration file location overridable with the `LOGGING_CONFIG` env var (see the sketch below)
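
A minimal usage sketch (not part of this diff; the override value and log message are illustrative):

```python
# Hypothetical usage sketch: log_setup() reads the YAML file named by the
# LOGGING_CONFIG env var (default: logging.yaml) and applies it via dictConfig.
import os
os.environ["LOGGING_CONFIG"] = "my-logging.yaml"  # illustrative override

import logging
from logconf import log_setup

log_setup()  # loads the YAML config and installs the default 'progress' field

logger = logging.getLogger("raft")
logger.info("colored, YAML-configured logging is active")
```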

**Displaying progress in logs**
- Added the `mdc` dependency
- Progress is attached to the MDC and included in the logging format message
- A default empty progress value is provided to avoid a KeyError when the progress field is not set, such as when logging from a different thread (see the sketch below)
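
A minimal sketch of the mechanism (not part of this diff; the percentage and messages are illustrative):

```python
# Hypothetical sketch: MDC attaches a 'progress' value to log records so the
# %(progress)s field in the configured format string can render it.
import logging
from mdc import MDC
from logconf import log_setup

log_setup()
logger = logging.getLogger("raft")

with MDC(progress="42%"):
    logger.info("processing a chunk")  # rendered with [ 42%] by the format

logger.info("between chunks")  # default empty progress value, no KeyError
```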

Here is what it looks like when we hit quota limits and get retries; progress and colored logs are displayed:

![Screenshot 2024-04-23 at 8 33 07 PM](https://github.com/ShishirPatil/gorilla/assets/33618/591a8093-0b02-451a-96bb-9b639d0a8fc5)

Here is what it looks like when it runs smoothly:
![Screenshot 2024-04-25 at 8 53 02 PM](https://github.com/ShishirPatil/gorilla/assets/33618/d022c38d-a2c5-4b02-8d0f-06b8d8275c12)

**Tests**

- Tested for non-regression against the OpenAI API
- Tested with an Azure AI resource and real-time `gpt-3.5-turbo` and `text-embedding-ada-002` deployments

**Note**: this PR depends on ShishirPatil#381
cedricvidal authored Apr 26, 2024
1 parent 6acc440 commit 0d632da
Showing 5 changed files with 107 additions and 13 deletions.
7 changes: 5 additions & 2 deletions raft/client_utils.py
@@ -3,6 +3,9 @@
 from dotenv import load_dotenv
 from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
 from openai import AzureOpenAI, OpenAI
+import logging
+
+logger = logging.getLogger("client_utils")
 
 load_dotenv()  # take environment variables from .env.
 
@@ -29,7 +32,7 @@ def build_langchain_embeddings(**kwargs: Any) -> OpenAIEmbeddings:
 def is_azure():
     azure = "AZURE_OPENAI_ENDPOINT" in env or "AZURE_OPENAI_KEY" in env or "AZURE_OPENAI_AD_TOKEN" in env
     if azure:
-        print("Using Azure OpenAI environment variables")
+        logger.debug("Using Azure OpenAI environment variables")
     else:
-        print("Using OpenAI environment variables.")
+        logger.debug("Using OpenAI environment variables")
     return azure
35 changes: 35 additions & 0 deletions raft/logconf.py
@@ -0,0 +1,35 @@
import logging
import logging.config
import os
import yaml

def log_setup():
    """
    Set up basic console logging. Root logger level can be set with ROOT_LOG_LEVEL environment variable.
    """

    # Load the config file
    with open(os.getenv('LOGGING_CONFIG', 'logging.yaml'), 'rt') as f:
        config = yaml.safe_load(f.read())

    # Configure the logging module with the config file
    logging.config.dictConfig(config)

    install_default_record_field(logging, 'progress', '')


def install_default_record_field(logging, field, value):
    """
    Wraps the log record factory to add a default progress field value.
    Required to avoid a KeyError when the progress field is not set,
    such as when logging from a different thread.
    """
    old_factory = logging.getLogRecordFactory()

    def record_factory(*args, **kwargs):
        record = old_factory(*args, **kwargs)
        if not hasattr(record, field):
            record.progress = value
        return record

    logging.setLogRecordFactory(record_factory)
34 changes: 34 additions & 0 deletions raft/logging.yaml
@@ -0,0 +1,34 @@
version: 1
disable_existing_loggers: False

formatters:
  simple:
    format: '%(asctime)s %(levelname)5s [%(progress)4s] %(name)s %(message)s'
  colored:
    format: "%(asctime)s %(levelname)5s [%(progress)4s] %(name)s %(message)s"
    class: coloredlogs.ColoredFormatter

handlers:
  console:
    class: logging.StreamHandler
    level: INFO
    formatter: colored
    stream: ext://sys.stdout

  file:
    class: logging.FileHandler
    level: DEBUG
    formatter: simple
    filename: raft.log

root:
  level: INFO
  handlers: [console, file]

loggers:
  raft:
    level: INFO
  langchain_community.utils.math:
    level: INFO
  httpx:
    level: WARN
44 changes: 33 additions & 11 deletions raft/raft.py
@@ -1,3 +1,7 @@
+import mdc
+from mdc import MDC
+from logconf import log_setup
+import logging
 from typing import Literal, Any
 import argparse
 from openai import OpenAI
@@ -9,6 +13,11 @@
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain_openai.embeddings import OpenAIEmbeddings
 from client_utils import build_openai_client, build_langchain_embeddings
+from math import ceil
+
+log_setup()
+
+logger = logging.getLogger("raft")
 
 DocType = Literal["api", "pdf", "json", "txt"]
 
@@ -45,7 +54,9 @@ def get_chunks(
     `chunk_size`, and returns the chunks.
     """
     chunks = []
+
+    logger.info(f"Retrieving chunks from {file_path} of type {doctype}")
 
     if doctype == "api":
         with open(file_path) as f:
             api_docs_json = json.load(f)
@@ -76,16 +87,17 @@ def get_chunks(
     else:
         raise TypeError("Document is not one of the accepted types: api, pdf, json, txt")
 
-    num_chunks = len(text) / chunk_size
+    num_chunks = ceil(len(text) / chunk_size)
+    logger.info(f"Splitting text into {num_chunks} chunks using the {model} model.")
 
-    embeddings = build_langchain_embeddings(openai_api_key=OPENAPI_API_KEY, model=model)
+    embeddings = build_langchain_embeddings(openai_api_key=openai_key, model=model)
     text_splitter = SemanticChunker(embeddings, number_of_chunks=num_chunks)
    chunks = text_splitter.create_documents([text])
     chunks = [chunk.page_content for chunk in chunks]
 
     return chunks
 
-def generate_instructions(api_call: Any, x=5, model: str = None) -> list[str]:
+def generate_instructions(client: OpenAI, api_call: Any, x=5, model: str = None) -> list[str]:
     """
     Generates `x` questions / use cases for `api_call`. Used when the input document is of type `api`.
     """
@@ -106,7 +118,7 @@ def generate_instructions(api_call: Any, x=5, model: str = None) -> list[str]:
 
     return queries
 
-def generate_instructions_gen(chunk: Any, x: int = 5, model: str = None) -> list[str]:
+def generate_instructions_gen(client: OpenAI, chunk: Any, x: int = 5, model: str = None) -> list[str]:
     """
     Generates `x` questions / use cases for `chunk`. Used when the input document is of general types
     `pdf`, `json`, or `txt`.
@@ -171,7 +183,7 @@ def encode_question_gen(question: str, chunk: Any) -> list[str]:
     prompts.append({"role": "user", "content": prompt})
     return prompts
 
-def generate_label(question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None:
+def generate_label(client: OpenAI, question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None:
     """
     Generates the label / answer to `question` using `context` and GPT-4.
     """
@@ -186,6 +198,7 @@ def generate_label(question: str, context: Any, doctype: DocType = "pdf", model:
     return response
 
 def add_chunk_to_dataset(
+    client: OpenAI,
     chunks: list[str],
     chunk: str,
     doctype: DocType = "api",
@@ -199,7 +212,7 @@ def add_chunk_to_dataset(
     """
     global ds
     i = chunks.index(chunk)
-    qs = generate_instructions(chunk, x, model) if doctype == "api" else generate_instructions_gen(chunk, x, model)
+    qs = generate_instructions(client, chunk, x, model) if doctype == "api" else generate_instructions_gen(client, chunk, x, model)
     for q in qs:
         datapt = {
             "id": None,
@@ -237,7 +250,7 @@ def add_chunk_to_dataset(
         datapt["oracle_context"] = chunk
 
         # add answer to q
-        datapt["cot_answer"] = generate_label(q, chunk, doctype, model=model)
+        datapt["cot_answer"] = generate_label(client, q, chunk, doctype, model=model)
 
         # construct model instruction
         context = ""
@@ -260,8 +273,9 @@ def add_chunk_to_dataset(
         else:
             ds = ds.add_item(datapt)
 
-if __name__ == "__main__":
-    # run code
+def main():
+    global ds
+
     args = get_args()
 
@@ -278,11 +292,19 @@
 
     ds = None
 
-    for chunk in chunks:
-        add_chunk_to_dataset(chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)
+    num_chunks = len(chunks)
+    for i, chunk in enumerate(chunks):
+        perc = ceil(i / num_chunks * 100)
+        with MDC(progress=f"{perc}%"):
+            logger.info(f"Adding chunk {i}/{num_chunks}")
+            add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)
 
     # Save as .arrow format
     ds.save_to_disk(args.output)
 
     # Save as .jsonl format
     ds.to_json(args.output + ".jsonl")
 
+if __name__ == "__main__":
+    with MDC(progress="0%"):
+        main()
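
For reference, a self-contained sketch of the progress loop introduced in `main()` above (the chunk list is a placeholder and the dataset work is elided):

```python
# Standalone sketch of the progress-reporting loop, runnable on its own
# given logconf.py and logging.yaml from this commit.
import logging
from math import ceil
from mdc import MDC
from logconf import log_setup

log_setup()
logger = logging.getLogger("raft")

chunks = ["chunk-a", "chunk-b", "chunk-c", "chunk-d"]  # placeholder chunks
num_chunks = len(chunks)
for i, chunk in enumerate(chunks):
    perc = ceil(i / num_chunks * 100)  # 0%, 25%, 50%, 75%
    with MDC(progress=f"{perc}%"):
        logger.info(f"Adding chunk {i}/{num_chunks}")
        # add_chunk_to_dataset(...) would do the real work here
```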
Binary file modified raft/requirements.txt
