From f1b84f606dd5c24c8c94b2f50ddcd0c3a1db5270 Mon Sep 17 00:00:00 2001
From: Cedric Vidal
Date: Sun, 25 Aug 2024 15:05:39 +0000
Subject: [PATCH] Using questions to track progress if qa_threshold is set

---
 raft/checkpointing.py |  5 +++++
 raft/raft.py          | 35 +++++++++++++++++++++++++++++------
 2 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/raft/checkpointing.py b/raft/checkpointing.py
index 661306b78..e4434f265 100644
--- a/raft/checkpointing.py
+++ b/raft/checkpointing.py
@@ -44,12 +44,17 @@ def load_checkpoint(self, num: int):
 
     def get_checkpoints(self) -> List[Checkpoint]:
         checkpoints = []
+        if not self.checkpoints_dir.exists():
+            return checkpoints
         for dir_path in self.checkpoints_dir.iterdir():
             if dir_path.is_dir() and dir_path.name.startswith("checkpoint-"):
                 num = int(dir_path.name.split("-")[1])
                 checkpoints.append(Checkpoint(dir_path, num))
         return checkpoints
 
+    def has_checkpoints(self) -> bool:
+        return len(self.get_checkpoints()) > 0
+
     def collect_checkpoints(self) -> Dataset:
         ds_list = list([checkpoint.load() for checkpoint in self.get_checkpoints()])
         ds = concatenate_datasets(ds_list)
diff --git a/raft/raft.py b/raft/raft.py
index 66e2615b4..faa29fbe3 100644
--- a/raft/raft.py
+++ b/raft/raft.py
@@ -547,13 +547,29 @@ def process_chunk(i):
 
     # we set the tqdm's initial position to avoid having cached data skew the stats
     missing_chunks = answers_checkpointing.missing_checkpoints(num_chunks)
-    ds = answers_checkpointing.collect_checkpoints()
-    gen_questions_count = len(ds)
+    gen_questions_count = 0
+    if answers_checkpointing.has_checkpoints():
+        ds = answers_checkpointing.collect_checkpoints()
+        gen_questions_count = len(ds)
+
     done_chunks = num_chunks - len(missing_chunks)
-    logger.info(f"Resuming generation from chunk {done_chunks}/{num_chunks} and {gen_questions_count} questions")
+    if done_chunks > 0 or gen_questions_count > 0:
+        logger.info(f"Resuming generation from chunk {done_chunks}/{num_chunks} and {gen_questions_count} questions")
+
+    # If we have a QA threshold, it makes more sense to keep track of the number of questions generated
+    # Otherwise, track chunks
+    track_questions = qa_threshold is not None
+
+    if qa_threshold:
+        logger.info(f"Will stop early as soon as the QA threshold is met: {qa_threshold}")
+
+    if track_questions:
+        tqdm_args = {"total": qa_threshold, "unit": "qa", "initial": gen_questions_count}
+    else:
+        tqdm_args = {"total": num_chunks, "unit": "chunk", "initial": done_chunks}
 
     tps = 0
-    with tqdm(total=num_chunks, desc="Generating", unit="chunk", initial=done_chunks) as pbar:
+    with tqdm(desc="Generating", **tqdm_args) as pbar:
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             for i in missing_chunks:
                 futures.append(executor.submit(process_chunk, i))
@@ -564,13 +580,20 @@ def process_chunk(i):
                     break
                 answers_ds = future.result()
                 answers_ds_list.append(answers_ds)
+                increment = min(len(answers_ds), qa_threshold - gen_questions_count) if track_questions else 1
                 gen_questions_count += len(answers_ds)
+                done_chunks += 1
                 stats = chat_completer.get_stats_and_reset()
                 if stats:
                     tps = stats.total_tokens / stats.duration
                     usage_stats += stats
-                pbar.set_postfix({'qa': gen_questions_count, 'last tok/s': tps, 'avg tok/s': usage_stats.total_tokens / usage_stats.duration if usage_stats.duration > 0 else 0})
-                pbar.update(1)
+                postfix = {'last tok/s': tps, 'avg tok/s': usage_stats.total_tokens / usage_stats.duration if usage_stats.duration > 0 else 0}
+                if track_questions:
+                    postfix['chunks'] = done_chunks
+                else:
+                    postfix['qa'] = gen_questions_count
+                pbar.set_postfix(postfix)
+                pbar.update(increment)
 
     ds = answers_checkpointing.collect_checkpoints()
     ds = ds.select(range(qa_threshold)) if qa_threshold else ds
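
Below is a minimal, self-contained sketch of the progress-tracking switch this patch introduces, built only on tqdm. It is not the raft.py code itself: the fake_generate helper and the run() signature are hypothetical stand-ins for the per-chunk QA generation step, kept just to show how the bar counts QA pairs when a threshold is set and chunks otherwise.

import random
from typing import List, Optional

from tqdm import tqdm


def fake_generate(chunk_id: int) -> List[str]:
    # Hypothetical stand-in: each chunk yields a variable number of QA pairs.
    return [f"qa-{chunk_id}-{i}" for i in range(random.randint(1, 5))]


def run(num_chunks: int, qa_threshold: Optional[int] = None) -> int:
    # When a QA threshold is set, track generated QA pairs; otherwise track chunks.
    track_questions = qa_threshold is not None
    gen_questions_count = 0
    done_chunks = 0

    if track_questions:
        tqdm_args = {"total": qa_threshold, "unit": "qa"}
    else:
        tqdm_args = {"total": num_chunks, "unit": "chunk"}

    with tqdm(desc="Generating", **tqdm_args) as pbar:
        for i in range(num_chunks):
            answers = fake_generate(i)
            # Cap the increment so the bar never advances past the threshold,
            # even if the last chunk overshoots it.
            increment = min(len(answers), qa_threshold - gen_questions_count) if track_questions else 1
            gen_questions_count += len(answers)
            done_chunks += 1
            # Show the quantity that the bar itself does not already count.
            pbar.set_postfix({"chunks": done_chunks} if track_questions else {"qa": gen_questions_count})
            pbar.update(increment)
            # Stop early once the threshold is met, mirroring the patch's intent.
            if track_questions and gen_questions_count >= qa_threshold:
                break
    return gen_questions_count


if __name__ == "__main__":
    run(num_chunks=20, qa_threshold=30)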