diff --git a/raft/client_utils.py b/raft/client_utils.py
index 4726036ade..af4c0d74f8 100644
--- a/raft/client_utils.py
+++ b/raft/client_utils.py
@@ -3,6 +3,9 @@
 from dotenv import load_dotenv
 from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
 from openai import AzureOpenAI, OpenAI
+import logging
+
+logger = logging.getLogger("client_utils")
 
 load_dotenv()  # take environment variables from .env.
 
@@ -29,7 +32,7 @@ def build_langchain_embeddings(**kwargs: Any) -> OpenAIEmbeddings:
 def is_azure():
     azure = "AZURE_OPENAI_ENDPOINT" in env or "AZURE_OPENAI_KEY" in env or "AZURE_OPENAI_AD_TOKEN" in env
     if azure:
-        print("Using Azure OpenAI environment variables")
+        logger.debug("Using Azure OpenAI environment variables")
     else:
-        print("Using OpenAI environment variables.")
+        logger.debug("Using OpenAI environment variables")
     return azure
diff --git a/raft/logconf.py b/raft/logconf.py
new file mode 100644
index 0000000000..15259c92de
--- /dev/null
+++ b/raft/logconf.py
@@ -0,0 +1,35 @@
+import logging
+import logging.config
+import os
+import yaml
+
+def log_setup():
+    """
+    Set up logging from the YAML config file named by the LOGGING_CONFIG environment variable (defaults to logging.yaml).
+    """
+
+    # Load the config file
+    with open(os.getenv('LOGGING_CONFIG', 'logging.yaml'), 'rt') as f:
+        config = yaml.safe_load(f.read())
+
+    # Configure the logging module with the config file
+    logging.config.dictConfig(config)
+
+    install_default_record_field(logging, 'progress', '')
+
+
+def install_default_record_field(logging, field, value):
+    """
+    Wraps the log record factory to add a default value for the given field.
+    Avoids a KeyError when the field is not set on a record,
+    such as when logging from a thread that is not inside an MDC context.
+    """
+    old_factory = logging.getLogRecordFactory()
+
+    def record_factory(*args, **kwargs):
+        record = old_factory(*args, **kwargs)
+        if not hasattr(record, field):
+            setattr(record, field, value)
+        return record
+
+    logging.setLogRecordFactory(record_factory)
diff --git a/raft/logging.yaml b/raft/logging.yaml
new file mode 100644
index 0000000000..a13ee5f8b2
--- /dev/null
+++ b/raft/logging.yaml
@@ -0,0 +1,34 @@
+version: 1
+disable_existing_loggers: False
+
+formatters:
+  simple:
+    format: '%(asctime)s %(levelname)5s [%(progress)4s] %(name)s %(message)s'
+  colored:
+    format: "%(asctime)s %(levelname)5s [%(progress)4s] %(name)s %(message)s"
+    class: coloredlogs.ColoredFormatter
+
+handlers:
+  console:
+    class: logging.StreamHandler
+    level: INFO
+    formatter: colored
+    stream: ext://sys.stdout
+
+  file:
+    class: logging.FileHandler
+    level: DEBUG
+    formatter: simple
+    filename: raft.log
+
+root:
+  level: INFO
+  handlers: [console, file]
+
+loggers:
+  raft:
+    level: INFO
+  langchain_community.utils.math:
+    level: INFO
+  httpx:
+    level: WARN
diff --git a/raft/raft.py b/raft/raft.py
index 3ac931c45d..d055d53699 100644
--- a/raft/raft.py
+++ b/raft/raft.py
@@ -1,3 +1,7 @@
+import mdc
+from mdc import MDC
+from logconf import log_setup
+import logging
 from typing import Literal, Any
 import argparse
 from openai import OpenAI
@@ -9,6 +13,11 @@
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain_openai.embeddings import OpenAIEmbeddings
 from client_utils import build_openai_client, build_langchain_embeddings
+from math import ceil
+
+log_setup()
+
+logger = logging.getLogger("raft")
 
 DocType = Literal["api", "pdf", "json", "txt"]
 
@@ -45,7 +54,9 @@ def get_chunks(
     `chunk_size`, and returns the chunks.
""" chunks = [] - + + logger.info(f"Retrieving chunks from {file_path} of type {doctype}") + if doctype == "api": with open(file_path) as f: api_docs_json = json.load(f) @@ -76,16 +87,17 @@ def get_chunks( else: raise TypeError("Document is not one of the accepted types: api, pdf, json, txt") - num_chunks = len(text) / chunk_size + num_chunks = ceil(len(text) / chunk_size) + logger.info(f"Splitting text into {num_chunks} chunks using the {model} model.") - embeddings = build_langchain_embeddings(openai_api_key=OPENAPI_API_KEY, model=model) + embeddings = build_langchain_embeddings(openai_api_key=openai_key, model=model) text_splitter = SemanticChunker(embeddings, number_of_chunks=num_chunks) chunks = text_splitter.create_documents([text]) chunks = [chunk.page_content for chunk in chunks] return chunks -def generate_instructions(api_call: Any, x=5, model: str = None) -> list[str]: +def generate_instructions(client: OpenAI, api_call: Any, x=5, model: str = None) -> list[str]: """ Generates `x` questions / use cases for `api_call`. Used when the input document is of type `api`. """ @@ -106,7 +118,7 @@ def generate_instructions(api_call: Any, x=5, model: str = None) -> list[str]: return queries -def generate_instructions_gen(chunk: Any, x: int = 5, model: str = None) -> list[str]: +def generate_instructions_gen(client: OpenAI, chunk: Any, x: int = 5, model: str = None) -> list[str]: """ Generates `x` questions / use cases for `chunk`. Used when the input document is of general types `pdf`, `json`, or `txt`. @@ -171,7 +183,7 @@ def encode_question_gen(question: str, chunk: Any) -> list[str]: prompts.append({"role": "user", "content": prompt}) return prompts -def generate_label(question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None: +def generate_label(client: OpenAI, question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None: """ Generates the label / answer to `question` using `context` and GPT-4. 
""" @@ -186,6 +198,7 @@ def generate_label(question: str, context: Any, doctype: DocType = "pdf", model: return response def add_chunk_to_dataset( + client: OpenAI, chunks: list[str], chunk: str, doctype: DocType = "api", @@ -199,7 +212,7 @@ def add_chunk_to_dataset( """ global ds i = chunks.index(chunk) - qs = generate_instructions(chunk, x, model) if doctype == "api" else generate_instructions_gen(chunk, x, model) + qs = generate_instructions(client, chunk, x, model) if doctype == "api" else generate_instructions_gen(client, chunk, x, model) for q in qs: datapt = { "id": None, @@ -237,7 +250,7 @@ def add_chunk_to_dataset( datapt["oracle_context"] = chunk # add answer to q - datapt["cot_answer"] = generate_label(q, chunk, doctype, model=model) + datapt["cot_answer"] = generate_label(client, q, chunk, doctype, model=model) # construct model instruction context = "" @@ -260,8 +273,9 @@ def add_chunk_to_dataset( else: ds = ds.add_item(datapt) +def main(): + global ds -if __name__ == "__main__": # run code args = get_args() @@ -278,11 +292,19 @@ def add_chunk_to_dataset( ds = None - for chunk in chunks: - add_chunk_to_dataset(chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model) + num_chunks = len(chunks) + for i, chunk in enumerate(chunks): + perc = ceil(i / num_chunks * 100) + with MDC(progress=f"{perc}%"): + logger.info(f"Adding chunk {i}/{num_chunks}") + add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model) # Save as .arrow format ds.save_to_disk(args.output) # Save as .jsonl format ds.to_json(args.output + ".jsonl") + +if __name__ == "__main__": + with MDC(progress="0%"): + main() diff --git a/raft/requirements.txt b/raft/requirements.txt index 02da64aa65..a70cfa4d3c 100644 Binary files a/raft/requirements.txt and b/raft/requirements.txt differ