Skip to content

Commit

Permalink
Using questions to track progress if qa_threshold is set
Browse files Browse the repository at this point in the history
  • Loading branch information
cedricvidal committed Aug 25, 2024
1 parent e95d6f4 commit f1b84f6
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
5 changes: 5 additions & 0 deletions raft/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@ def load_checkpoint(self, num: int):

def get_checkpoints(self) -> List[Checkpoint]:
    """Return the checkpoints stored under ``self.checkpoints_dir``.

    A checkpoint is any sub-directory whose name starts with
    ``checkpoint-`` and whose first dash-separated field after that
    prefix is an integer, e.g. ``checkpoint-12``.

    Returns:
        One ``Checkpoint(dir_path, num)`` per matching directory, ordered
        by ascending ``num``. The list is empty when the checkpoints
        directory does not exist or holds no matching directory.
    """
    if not self.checkpoints_dir.exists():
        return []

    numbered_dirs = []
    for dir_path in self.checkpoints_dir.iterdir():
        if not (dir_path.is_dir() and dir_path.name.startswith("checkpoint-")):
            continue
        try:
            num = int(dir_path.name.split("-")[1])
        except ValueError:
            # A stray directory such as "checkpoint-tmp" is not a checkpoint;
            # skip it instead of failing the whole listing.
            continue
        numbered_dirs.append((num, dir_path))

    # iterdir() yields entries in arbitrary order; sort by checkpoint number
    # so callers get the same order on every run and platform.
    numbered_dirs.sort(key=lambda pair: pair[0])
    return [Checkpoint(dir_path, num) for num, dir_path in numbered_dirs]

def has_checkpoints(self) -> bool:
    """Tell whether ``get_checkpoints()`` returns at least one checkpoint."""
    found = self.get_checkpoints()
    # A non-empty list is truthy, an empty one is falsy.
    return bool(found)

def collect_checkpoints(self) -> Dataset:
ds_list = list([checkpoint.load() for checkpoint in self.get_checkpoints()])
ds = concatenate_datasets(ds_list)
Expand Down
35 changes: 29 additions & 6 deletions raft/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,13 +547,29 @@ def process_chunk(i):
# we set the tqdm's initial position to avoid having cached data skew the stats
missing_chunks = answers_checkpointing.missing_checkpoints(num_chunks)

ds = answers_checkpointing.collect_checkpoints()
gen_questions_count = len(ds)
gen_questions_count = 0
if answers_checkpointing.has_checkpoints():
ds = answers_checkpointing.collect_checkpoints()
gen_questions_count = len(ds)

done_chunks = num_chunks - len(missing_chunks)
logger.info(f"Resuming generation from chunk {done_chunks}/{num_chunks} and {gen_questions_count} questions")
if done_chunks > 0 or gen_questions_count > 0:
logger.info(f"Resuming generation from chunk {done_chunks}/{num_chunks} and {gen_questions_count} questions")

# If we have a QA threshold, it makes more sense to keep track of the number of questions generated
# Otherwise, track chunks
track_questions = qa_threshold is not None

if qa_threshold:
logger.info(f"Will stop early as soon as the QA threshold is met: {qa_threshold}")

if track_questions:
tqdm_args = {"total": qa_threshold, "unit": "qa", "initial": gen_questions_count}
else:
tqdm_args = {"total": num_chunks, "unit": "chunk", "initial": done_chunks}

tps = 0
with tqdm(total=num_chunks, desc="Generating", unit="chunk", initial=done_chunks) as pbar:
with tqdm(desc="Generating", **tqdm_args) as pbar:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for i in missing_chunks:
futures.append(executor.submit(process_chunk, i))
Expand All @@ -564,13 +580,20 @@ def process_chunk(i):
break
answers_ds = future.result()
answers_ds_list.append(answers_ds)
increment = min(len(answers_ds), qa_threshold - gen_questions_count) if track_questions else 1
gen_questions_count += len(answers_ds)
done_chunks += 1
stats = chat_completer.get_stats_and_reset()
if stats:
tps = stats.total_tokens / stats.duration
usage_stats += stats
pbar.set_postfix({'qa': gen_questions_count, 'last tok/s': tps, 'avg tok/s': usage_stats.total_tokens / usage_stats.duration if usage_stats.duration > 0 else 0})
pbar.update(1)
postfix = {'last tok/s': tps, 'avg tok/s': usage_stats.total_tokens / usage_stats.duration if usage_stats.duration > 0 else 0}
if track_questions:
postfix['chunks'] = done_chunks
else:
postfix['qa'] = gen_questions_count
pbar.set_postfix(postfix)
pbar.update(increment)

ds = answers_checkpointing.collect_checkpoints()
ds = ds.select(range(qa_threshold)) if qa_threshold else ds
Expand Down

0 comments on commit f1b84f6

Please sign in to comment.