Skip to content

Commit

Permalink
Add metric_threshold argument to BootstrapFewShot
Browse files Browse the repository at this point in the history
  • Loading branch information
CShorten committed Feb 19, 2024
1 parent 7a6421a commit 68fe064
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions dspy/teleprompt/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@


class BootstrapFewShot(Teleprompter):
def __init__(self, metric=None, teacher_settings={}, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, max_errors=5):
def __init__(self, metric=None, metric_threshold=None, teacher_settings={}, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, max_errors=5):
self.metric = metric
self.metric_threshold = metric_threshold
self.teacher_settings = teacher_settings

self.max_bootstrapped_demos = max_bootstrapped_demos
Expand Down Expand Up @@ -147,8 +148,11 @@ def _bootstrap_one_example(self, example, round_idx=0):

for name, predictor in teacher.named_predictors():
predictor.demos = predictor_cache[name]

success = (self.metric is None) or self.metric(example, prediction, trace)

if self.metric and self.metric_threshold:
success = self.metric(example, prediction, trace) > self.metric_threshold
else:
success = True
# print(success, example, prediction)
except Exception as e:
success = False
Expand Down

0 comments on commit 68fe064

Please sign in to comment.