
Implementing KNN classifier with Inference #583

Closed
sapan opened this issue Jan 31, 2023 · 3 comments
Labels
question A general question about the library

Comments

sapan commented Jan 31, 2023

Excellent library, Kevin - thanks a lot.

In my hyperparameter search, I am trying to perform KNN classification with query labels in my 'test' set and reference labels in my 'train' set. I thought of doing the following:

  1. Add a custom accuracy metric by subclassing AccuracyCalculator
  2. Use splits_to_eval to include {'test': 'train'}
    I thought that with a custom function and splits_to_eval I would effectively get a KNN classifier. However, this is not going to work, because the accuracy calculator class compares the class labels of query and reference embeddings -- it does not predict a class label for each query. Is this correct?

The InferenceModel class has all the required ingredients for doing KNN, I believe. It would really help if you could provide some idea of how to go about this. (I think having a KNNClassifier in a future version would be a really good feature.)

@KevinMusgrave (Owner)

I'll come back to this tomorrow, but I just wanted to say that it might be worth trying scikit-learn's KNeighborsClassifier if you haven't already. You just need to convert your tensors to numpy.
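For example, something along these lines (train_embeddings, train_labels, and test_embeddings are placeholders for tensors you have already computed):

```python
from sklearn.neighbors import KNeighborsClassifier

# Convert the embedding tensors to numpy arrays for scikit-learn.
# The train/test embedding and label tensors are placeholders for
# whatever you have already computed.
train_X = train_embeddings.cpu().numpy()
train_y = train_labels.cpu().numpy().ravel()
test_X = test_embeddings.cpu().numpy()

# Fit on the reference (train) embeddings, then predict labels
# for the query (test) embeddings.
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train_X, train_y)
predicted = knn.predict(test_X)
```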

sapan commented Jan 31, 2023

Thanks for the suggestion. Let me try that.
I am thinking of first calling base_tester.get_all_embeddings(dataset, trunk, embedder, collate_fn, ...) to compute embeddings for both the train and test sets, and then using those in sklearn.

sapan commented Jan 31, 2023

This is working for me. I am adding the overall flow here in case it helps others (a consolidated sketch follows the list).

  1. trainer.train()  # the trainer class in pytorch-metric-learning
  2. best_trunk_model = load the best trunk model from 'example_saved_models' (the checkpoint whose filename matches the 'best' pattern)
  3. tester = GlobalEmbeddingSpaceTester()
  4. KNN classification:
    4.1 Compute embeddings for the train and test data by calling, for each split,
    embeddings, labels = tester.get_all_embeddings(dataset, best_trunk_model, embedder, collate_fn, result_as_numpy=True)
    labels = labels.reshape(labels.shape[0])
    4.2 from sklearn.neighbors import KNeighborsClassifier
    from sklearn.metrics import accuracy_score
    knnmod = KNeighborsClassifier(n_neighbors=20, weights='distance')
    knnmod.fit(train_embs, train_labels)
    predicted = knnmod.predict(test_embs)
    accuracy_score(test_labels, predicted)
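Putting the steps above together, here is a consolidated sketch. The train/test dataset variables, the trunk and embedder models, and the collate_fn are placeholders for whatever you already have from training, and the get_all_embeddings call follows the usage above -- double-check the exact signature against your installed version.

```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from pytorch_metric_learning import testers

tester = testers.GlobalEmbeddingSpaceTester()

# Compute embeddings for the reference (train) and query (test) splits.
# train_dataset, test_dataset, best_trunk_model, embedder, and collate_fn
# are placeholders; the argument order follows the call used above.
train_embs, train_labels = tester.get_all_embeddings(
    train_dataset, best_trunk_model, embedder, collate_fn, result_as_numpy=True
)
test_embs, test_labels = tester.get_all_embeddings(
    test_dataset, best_trunk_model, embedder, collate_fn, result_as_numpy=True
)

# Flatten the label arrays from shape (N, 1) to (N,).
train_labels = train_labels.reshape(train_labels.shape[0])
test_labels = test_labels.reshape(test_labels.shape[0])

# Fit a KNN classifier on the train embeddings and score it on the test split.
knnmod = KNeighborsClassifier(n_neighbors=20, weights='distance')
knnmod.fit(train_embs, train_labels)
predicted = knnmod.predict(test_embs)
print(accuracy_score(test_labels, predicted))
```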
