Skip to content

Commit

Permalink
Displaying progress in logs
Browse files Browse the repository at this point in the history
- Added mdc dependency
  • Loading branch information
cedricvidal committed Apr 25, 2024
1 parent 4fea62d commit a7ab3b3
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 18 deletions.
19 changes: 19 additions & 0 deletions raft/logconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,22 @@ def log_setup():

# Configure the logging module with the config file
logging.config.dictConfig(config)

install_default_record_field(logging, 'progress', '')


def install_default_record_field(logging, field, value):
    """
    Wrap the log record factory so every new LogRecord carries a default
    value for `field` when the caller did not supply one via `extra=`.

    Required to avoid a KeyError from formatters that reference the field
    (e.g. '%(progress)s'), such as when logging from a different thread
    that never sets it.

    Args:
        logging: the logging module (passed in; note this shadows the
            stdlib module name — kept for backward compatibility).
        field: name of the attribute to guarantee on each record.
        value: default value assigned when the attribute is absent.
    """
    old_factory = logging.getLogRecordFactory()

    def record_factory(*args, **kwargs):
        record = old_factory(*args, **kwargs)
        if not hasattr(record, field):
            # Bug fix: assign the *requested* field, not a hard-coded
            # 'progress' attribute, so the helper works for any field name.
            setattr(record, field, value)
        return record

    logging.setLogRecordFactory(record_factory)
10 changes: 7 additions & 3 deletions raft/logging.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ disable_existing_loggers: False

formatters:
simple:
format: '%(asctime)s %(levelname)5s %(name)s %(message)s'
format: '%(asctime)s %(levelname)5s [%(progress)4s] %(name)s %(message)s'
colored:
format: '%(asctime)s %(levelname)5s %(name)s %(message)s'
format: "%(asctime)s %(levelname)5s [%(progress)4s] %(name)s %(message)s"
class: coloredlogs.ColoredFormatter

handlers:
Expand All @@ -22,9 +22,13 @@ handlers:
filename: raft.log

root:
level: DEBUG
level: INFO
handlers: [console, file]

loggers:
raft:
level: INFO
langchain_community.utils.math:
level: INFO
httpx:
level: WARN
41 changes: 26 additions & 15 deletions raft/raft.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import mdc
from mdc import MDC
from logconf import log_setup
import logging
from typing import Literal, Any
Expand All @@ -11,12 +13,11 @@
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai.embeddings import OpenAIEmbeddings
from client_utils import build_openai_client, build_langchain_embeddings
from math import ceil

log_setup()

logger_raft = logging.getLogger("raft")
logger_chunking = logging.getLogger("chunking")
logger_gen = logging.getLogger("gen")
logger = logging.getLogger("raft")

DocType = Literal["api", "pdf", "json", "txt"]

Expand Down Expand Up @@ -54,7 +55,7 @@ def get_chunks(
"""
chunks = []

logger_chunking.info(f"Retrieving chunks from {file_path} of type {doctype}")
logger.info(f"Retrieving chunks from {file_path} of type {doctype}")

if doctype == "api":
with open(file_path) as f:
Expand Down Expand Up @@ -86,17 +87,17 @@ def get_chunks(
else:
raise TypeError("Document is not one of the accepted types: api, pdf, json, txt")

num_chunks = len(text) / chunk_size
logger_chunking.info(f"Splitting text into {num_chunks} chunks using the {model} model.")
num_chunks = ceil(len(text) / chunk_size)
logger.info(f"Splitting text into {num_chunks} chunks using the {model} model.")

embeddings = build_langchain_embeddings(openai_api_key=OPENAPI_API_KEY, model=model)
embeddings = build_langchain_embeddings(openai_api_key=openai_key, model=model)
text_splitter = SemanticChunker(embeddings, number_of_chunks=num_chunks)
chunks = text_splitter.create_documents([text])
chunks = [chunk.page_content for chunk in chunks]

return chunks

def generate_instructions(api_call: Any, x=5, model: str = None) -> list[str]:
def generate_instructions(client: OpenAI, api_call: Any, x=5, model: str = None) -> list[str]:
"""
Generates `x` questions / use cases for `api_call`. Used when the input document is of type `api`.
"""
Expand All @@ -117,7 +118,7 @@ def generate_instructions(api_call: Any, x=5, model: str = None) -> list[str]:

return queries

def generate_instructions_gen(chunk: Any, x: int = 5, model: str = None) -> list[str]:
def generate_instructions_gen(client: OpenAI, chunk: Any, x: int = 5, model: str = None) -> list[str]:
"""
Generates `x` questions / use cases for `chunk`. Used when the input document is of general types
`pdf`, `json`, or `txt`.
Expand Down Expand Up @@ -182,7 +183,7 @@ def encode_question_gen(question: str, chunk: Any) -> list[str]:
prompts.append({"role": "user", "content": prompt})
return prompts

def generate_label(question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None:
def generate_label(client: OpenAI, question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None:
"""
Generates the label / answer to `question` using `context` and GPT-4.
"""
Expand All @@ -197,6 +198,7 @@ def generate_label(question: str, context: Any, doctype: DocType = "pdf", model:
return response

def add_chunk_to_dataset(
client: OpenAI,
chunks: list[str],
chunk: str,
doctype: DocType = "api",
Expand All @@ -210,7 +212,7 @@ def add_chunk_to_dataset(
"""
global ds
i = chunks.index(chunk)
qs = generate_instructions(chunk, x, model) if doctype == "api" else generate_instructions_gen(chunk, x, model)
qs = generate_instructions(client, chunk, x, model) if doctype == "api" else generate_instructions_gen(client, chunk, x, model)
for q in qs:
datapt = {
"id": None,
Expand Down Expand Up @@ -248,7 +250,7 @@ def add_chunk_to_dataset(
datapt["oracle_context"] = chunk

# add answer to q
datapt["cot_answer"] = generate_label(q, chunk, doctype, model=model)
datapt["cot_answer"] = generate_label(client, q, chunk, doctype, model=model)

# construct model instruction
context = ""
Expand All @@ -271,8 +273,9 @@ def add_chunk_to_dataset(
else:
ds = ds.add_item(datapt)

def main():
global ds

if __name__ == "__main__":
# run code
args = get_args()

Expand All @@ -289,11 +292,19 @@ def add_chunk_to_dataset(

ds = None

for chunk in chunks:
add_chunk_to_dataset(chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)
num_chunks = len(chunks)
for i, chunk in enumerate(chunks):
perc = ceil(i / num_chunks)
with MDC(progress=f"{perc}%"):
logger.info(f"Adding chunk {i}/{num_chunks}")
add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)

# Save as .arrow format
ds.save_to_disk(args.output)

# Save as .jsonl format
ds.to_json(args.output + ".jsonl")

if __name__ == "__main__":
with MDC(progress="0%"):
main()
Binary file modified raft/requirements.txt
Binary file not shown.

0 comments on commit a7ab3b3

Please sign in to comment.