Colored logging configuration + displaying progress in logs #384

Merged: 2 commits, Apr 26, 2024
7 changes: 5 additions & 2 deletions raft/client_utils.py
@@ -3,6 +3,9 @@
 from dotenv import load_dotenv
 from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
 from openai import AzureOpenAI, OpenAI
+import logging
+
+logger = logging.getLogger("client_utils")

 load_dotenv()  # take environment variables from .env.

@@ -29,7 +32,7 @@ def build_langchain_embeddings(**kwargs: Any) -> OpenAIEmbeddings:
 def is_azure():
     azure = "AZURE_OPENAI_ENDPOINT" in env or "AZURE_OPENAI_KEY" in env or "AZURE_OPENAI_AD_TOKEN" in env
     if azure:
-        print("Using Azure OpenAI environment variables")
+        logger.debug("Using Azure OpenAI environment variables")
     else:
-        print("Using OpenAI environment variables.")
+        logger.debug("Using OpenAI environment variables")
     return azure
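
A side effect of demoting these messages from print to logger.debug (an editorial note, not part of the diff): with the logging.yaml introduced below, the client_utils logger inherits the root INFO level, so these DEBUG records are dropped at the logger before they ever reach the DEBUG-level file handler. A minimal sketch of how to surface them, assuming log_setup() from logconf.py below has already run:

import logging

# DEBUG records must clear two gates: the logger's effective level and each
# handler's level. client_utils inherits root's INFO, so lower it first.
logging.getLogger("client_utils").setLevel(logging.DEBUG)

# The file handler is already at DEBUG; lower the console handler too if
# the messages should reach stdout. (FileHandler subclasses StreamHandler,
# hence the explicit exclusion.)
for handler in logging.getLogger().handlers:
    if isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler):
        handler.setLevel(logging.DEBUG)
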
35 changes: 35 additions & 0 deletions raft/logconf.py
@@ -0,0 +1,35 @@
import logging
import logging.config
import os
import yaml

def log_setup():
    """
    Set up basic console logging. The config file path can be overridden
    with the LOGGING_CONFIG environment variable (default: logging.yaml).
    """

    # Load the config file
    with open(os.getenv('LOGGING_CONFIG', 'logging.yaml'), 'rt') as f:
        config = yaml.safe_load(f.read())

    # Configure the logging module with the config file
    logging.config.dictConfig(config)

    install_default_record_field(logging, 'progress', '')


def install_default_record_field(logging, field, value):
    """
    Wraps the log record factory to add a default value for `field`.
    Required to avoid a KeyError when the progress field is not set,
    such as when logging from a different thread.
    """
    old_factory = logging.getLogRecordFactory()

    def record_factory(*args, **kwargs):
        record = old_factory(*args, **kwargs)
        if not hasattr(record, field):
            setattr(record, field, value)
        return record

    logging.setLogRecordFactory(record_factory)
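
For reference, a minimal usage sketch (not part of this PR) of why the default record field matters: the formatters in logging.yaml below reference %(progress)4s, which would break formatting for any record that does not carry a progress attribute.

import logging
from logconf import log_setup

log_setup()
log = logging.getLogger("raft")

# No progress value in scope: the wrapped record factory fills in '',
# so the '%(progress)4s' placeholder still resolves.
log.info("starting up")

# A per-record value can be attached with `extra`; the factory then
# leaves it untouched.
log.info("halfway there", extra={"progress": "50%"})
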
34 changes: 34 additions & 0 deletions raft/logging.yaml
@@ -0,0 +1,34 @@
version: 1
disable_existing_loggers: False

formatters:
  simple:
    format: '%(asctime)s %(levelname)5s [%(progress)4s] %(name)s %(message)s'
  colored:
    format: "%(asctime)s %(levelname)5s [%(progress)4s] %(name)s %(message)s"
    class: coloredlogs.ColoredFormatter

handlers:
  console:
    class: logging.StreamHandler
    level: INFO
    formatter: colored
    stream: ext://sys.stdout

  file:
    class: logging.FileHandler
    level: DEBUG
    formatter: simple
    filename: raft.log

root:
  level: INFO
  handlers: [console, file]

loggers:
  raft:
    level: INFO
  langchain_community.utils.math:
    level: INFO
  httpx:
    level: WARN
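
A quick sketch (not part of this PR; custom-logging.yaml is a hypothetical file name) of swapping in an alternate config, which logconf.py supports through the LOGGING_CONFIG environment variable. With the default config above, a record logged inside an MDC progress context renders roughly as `2024-04-26 10:00:00,123  INFO [ 50%] raft Adding chunk 25/50`.

import os

# Must be set before log_setup() reads it.
os.environ["LOGGING_CONFIG"] = "custom-logging.yaml"  # hypothetical file

from logconf import log_setup

log_setup()  # now loads custom-logging.yaml instead of logging.yaml
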
44 changes: 33 additions & 11 deletions raft/raft.py
@@ -1,3 +1,7 @@
+import mdc
+from mdc import MDC
+from logconf import log_setup
+import logging
 from typing import Literal, Any
 import argparse
 from openai import OpenAI
@@ -9,6 +13,11 @@
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain_openai.embeddings import OpenAIEmbeddings
 from client_utils import build_openai_client, build_langchain_embeddings
+from math import ceil
+
+log_setup()
+
+logger = logging.getLogger("raft")

 DocType = Literal["api", "pdf", "json", "txt"]

@@ -45,7 +54,9 @@ def get_chunks(
     `chunk_size`, and returns the chunks.
     """
     chunks = []
+
+    logger.info(f"Retrieving chunks from {file_path} of type {doctype}")

     if doctype == "api":
         with open(file_path) as f:
             api_docs_json = json.load(f)
@@ -76,16 +87,17 @@ def get_chunks(
     else:
         raise TypeError("Document is not one of the accepted types: api, pdf, json, txt")

-    num_chunks = len(text) / chunk_size
+    num_chunks = ceil(len(text) / chunk_size)
+    logger.info(f"Splitting text into {num_chunks} chunks using the {model} model.")

-    embeddings = build_langchain_embeddings(openai_api_key=OPENAPI_API_KEY, model=model)
+    embeddings = build_langchain_embeddings(openai_api_key=openai_key, model=model)
     text_splitter = SemanticChunker(embeddings, number_of_chunks=num_chunks)
     chunks = text_splitter.create_documents([text])
     chunks = [chunk.page_content for chunk in chunks]

     return chunks

-def generate_instructions(api_call: Any, x=5, model: str = None) -> list[str]:
+def generate_instructions(client: OpenAI, api_call: Any, x=5, model: str = None) -> list[str]:
     """
     Generates `x` questions / use cases for `api_call`. Used when the input document is of type `api`.
     """
@@ -106,7 +118,7 @@ def generate_instructions(api_call: Any, x=5, model: str = None) -> list[str]:

     return queries

-def generate_instructions_gen(chunk: Any, x: int = 5, model: str = None) -> list[str]:
+def generate_instructions_gen(client: OpenAI, chunk: Any, x: int = 5, model: str = None) -> list[str]:
     """
     Generates `x` questions / use cases for `chunk`. Used when the input document is of general types
     `pdf`, `json`, or `txt`.
@@ -171,7 +183,7 @@ def encode_question_gen(question: str, chunk: Any) -> list[str]:
     prompts.append({"role": "user", "content": prompt})
     return prompts

-def generate_label(question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None:
+def generate_label(client: OpenAI, question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None:
     """
     Generates the label / answer to `question` using `context` and GPT-4.
     """
@@ -186,6 +198,7 @@ def generate_label(question: str, context: Any, doctype: DocType = "pdf", model:
     return response

 def add_chunk_to_dataset(
+    client: OpenAI,
     chunks: list[str],
     chunk: str,
     doctype: DocType = "api",
@@ -199,7 +212,7 @@ def add_chunk_to_dataset(
     """
     global ds
     i = chunks.index(chunk)
-    qs = generate_instructions(chunk, x, model) if doctype == "api" else generate_instructions_gen(chunk, x, model)
+    qs = generate_instructions(client, chunk, x, model) if doctype == "api" else generate_instructions_gen(client, chunk, x, model)
     for q in qs:
         datapt = {
             "id": None,
@@ -237,7 +250,7 @@ def add_chunk_to_dataset(
         datapt["oracle_context"] = chunk

         # add answer to q
-        datapt["cot_answer"] = generate_label(q, chunk, doctype, model=model)
+        datapt["cot_answer"] = generate_label(client, q, chunk, doctype, model=model)

         # construct model instruction
         context = ""
@@ -260,8 +273,9 @@ def add_chunk_to_dataset(
     else:
         ds = ds.add_item(datapt)

-if __name__ == "__main__":
+def main():
+    global ds
+
     # run code
     args = get_args()
@@ -278,11 +292,19 @@ def add_chunk_to_dataset(

     ds = None

-    for chunk in chunks:
-        add_chunk_to_dataset(chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)
+    num_chunks = len(chunks)
+    for i, chunk in enumerate(chunks):
+        perc = ceil(i / num_chunks * 100)
+        with MDC(progress=f"{perc}%"):
+            logger.info(f"Adding chunk {i}/{num_chunks}")
+            add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)

     # Save as .arrow format
     ds.save_to_disk(args.output)

     # Save as .jsonl format
     ds.to_json(args.output + ".jsonl")
+
+if __name__ == "__main__":
+    with MDC(progress="0%"):
+        main()
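
For reference, a minimal sketch (not part of this PR, and assuming the mdc package behaves as the loop above relies on) of the progress-reporting pattern: fields passed to the MDC context manager are attached to every record logged inside the context, including records emitted by called functions, and surface through the %(progress)4s field in both formatters.

import logging
from math import ceil

from mdc import MDC
from logconf import log_setup

log_setup()
logger = logging.getLogger("raft")

def process(i: int, total: int) -> None:
    # Picks up the ambient progress value, e.g. " INFO [ 40%] raft Adding chunk 2/5"
    logger.info(f"Adding chunk {i}/{total}")

total = 5
for i in range(total):
    with MDC(progress=f"{ceil(i / total * 100)}%"):
        process(i, total)
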
Binary file modified raft/requirements.txt
Binary file not shown.