Skip to content

Commit

Permalink
Use dspy.Embedding for KNN (#1822)
Browse files Browse the repository at this point in the history
* Use dspy.Embedding for KNN

* Remove type from KNNFewShot

---------

Co-authored-by: Cyrus Nouroozi <[email protected]>
  • Loading branch information
CyrusNuevoDia and CyrusNuevoDia authored Nov 19, 2024
1 parent 62526fc commit 76a9b27
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
34 changes: 24 additions & 10 deletions dspy/predict/knn.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,39 @@
from typing import List, Optional
from typing import List

import numpy as np

import dsp


class KNN:
    """Retrieve the training examples most similar to a query by embedding dot product."""

    def __init__(self, k: int, trainset: List["dsp.Example"], vectorizer=None):
        """
        A k-nearest neighbors retriever that finds similar examples from a training set.

        Args:
            k: Number of nearest neighbors to retrieve.
            trainset: List of training examples to search through. Only the
                fields listed in each example's ``_input_keys`` are embedded.
            vectorizer: Optional dspy.Embedding (or any callable mapping a list
                of strings to a 2-D array-like of vectors) used to compute
                embeddings. If None, uses sentence-transformers.

        Example:
            >>> trainset = [dsp.Example(input="hello", output="world"), ...]
            >>> knn = KNN(k=3, trainset=trainset)
            >>> similar_examples = knn(input="hello")
        """
        if vectorizer is None:
            # Imported lazily so the default embedder is only resolved when it
            # is actually needed (presumably this also avoids a circular import
            # between dspy and this module -- TODO confirm).
            import dspy

            vectorizer = dspy.Embedding(dsp.SentenceTransformersVectorizer())

        self.k = k
        self.trainset = trainset
        self.embedding = vectorizer

        # Serialise every example's input fields as "key: value | key: value".
        trainset_casted_to_vectorize = [
            " | ".join(f"{key}: {value}" for key, value in example.items() if key in example._input_keys)
            for example in self.trainset
        ]
        # np.asarray (rather than .astype) also accepts embedders that return
        # plain nested lists instead of an ndarray.
        self.trainset_vectors = np.asarray(self.embedding(trainset_casted_to_vectorize), dtype=np.float32)

    def __call__(self, **kwargs) -> List["dsp.Example"]:
        """
        Return up to ``k`` training examples ranked by similarity to ``kwargs``.

        Args:
            **kwargs: Field names and values of the query; they are serialised
                the same way as the training inputs before being embedded.

        Returns:
            The matching training examples, highest dot-product score first.
            An empty list when ``k`` is not positive or the training set is empty.
        """
        top_k = min(self.k, len(self.trainset))
        if top_k <= 0:
            # A slice of [-0:] would select *every* element, so handle this explicitly.
            return []

        query_text = " | ".join(f"{key}: {val}" for key, val in kwargs.items())
        input_example_vector = np.asarray(self.embedding([query_text]))
        # reshape(-1) keeps scores 1-D even when the training set has one example
        # (squeeze() would collapse it to a 0-d array).
        scores = np.dot(self.trainset_vectors, input_example_vector.T).reshape(-1)
        nearest_samples_idxs = scores.argsort()[-top_k:][::-1]
        return [self.trainset[cur_idx] for cur_idx in nearest_samples_idxs]
4 changes: 2 additions & 2 deletions dspy/teleprompt/knn_fewshot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import types
from typing import List, Optional
from typing import List

import dsp
from dspy.predict.knn import KNN
Expand All @@ -9,7 +9,7 @@


class KNNFewShot(Teleprompter):
def __init__(self, k: int, trainset: List[dsp.Example], vectorizer=None, **few_shot_bootstrap_args):
    """
    Build the KNN retriever used by this teleprompter.

    Args:
        k: Number of nearest neighbors the retriever returns.
        trainset: Training examples handed to the KNN retriever.
        vectorizer: Optional embedder forwarded unchanged to ``KNN``;
            ``None`` lets ``KNN`` pick its own default.
        **few_shot_bootstrap_args: Extra keyword arguments stored as-is on
            the instance (presumably forwarded to the few-shot bootstrap
            step later in this class -- not visible here, confirm).
    """
    self.KNN = KNN(k, trainset, vectorizer=vectorizer)
    self.few_shot_bootstrap_args = few_shot_bootstrap_args

Expand Down

0 comments on commit 76a9b27

Please sign in to comment.