train_on_dataset much slower when using ActiveLearningDataset compared to torch Dataset #264

Closed
arthur-thuy opened this issue Jun 22, 2023 · 6 comments
Labels
bug Something isn't working

Comments

@arthur-thuy
Contributor

arthur-thuy commented Jun 22, 2023

Describe the bug
When the entire pool is labelled (i.e. training on the entire training set), the train_on_dataset function is much slower when using an ActiveLearningDataset as compared to using a regular torch Dataset. In the MNIST experiment below, it is 17x slower (!!!).

I suspect the discrepancy grows with the size of the labelled pool, since there is no difference when only 20 samples are labelled.

To Reproduce
In this gist, a LeNet-5 model with MC Dropout is trained on the entire MNIST training set for 1 epoch. Note that this script does not perform any active learning: no acquisitions are done and the pool set is empty. The script is intended to compare training times across AL packages.

The script has an option --use-ald, which passes an ActiveLearningDataset to the train_on_dataset function instead of the regular torch Dataset. Please refer to lines 83-94 in the gist for the relevant code.

Results are as follows:

  • python baal_mnist.py --use-ald => "Elapsed training time: 0:1:23"
  • python baal_mnist.py => "Elapsed training time: 0:0:5"

Here is the full output:

> python baal_mnist.py --use-ald
Use GPU: NVIDIA RTX A5000 for training
labelling 60000 observations
[1538876-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-06-22T09:35:08.962235Z [info     ] Starting training              dataset=60000 epoch=1
[1538876-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-06-22T09:36:32.093871Z [info     ] Training complete              train_loss=0.21927499771118164
Elapsed training time: 0:1:23
[1538876-MainThread] [baal.modelwrapper:test_on_dataset:123] 2023-06-22T09:36:32.101867Z [info     ] Starting evaluating            dataset=10000
[1538876-MainThread] [baal.modelwrapper:test_on_dataset:133] 2023-06-22T09:36:35.126586Z [info     ] Evaluation complete            test_loss=0.04848730191588402
{'dataset_size': 60000,
 'test_accuracy': 0.9842716455459595,
 'test_loss': 0.04848730191588402,
 'train_accuracy': 0.9318974018096924,
 'train_loss': 0.21927499771118164}
Elapsed total time: 0:1:27
> python baal_mnist.py
Use GPU: NVIDIA RTX A5000 for training
[1538621-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-06-22T09:34:51.757774Z [info     ] Starting training              dataset=60000 epoch=1
[1538621-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-06-22T09:34:56.868404Z [info     ] Training complete              train_loss=0.21591344475746155
Elapsed training time: 0:0:5
[1538621-MainThread] [baal.modelwrapper:test_on_dataset:123] 2023-06-22T09:34:56.874050Z [info     ] Starting evaluating            dataset=10000
[1538621-MainThread] [baal.modelwrapper:test_on_dataset:133] 2023-06-22T09:34:59.894939Z [info     ] Evaluation complete            test_loss=0.04452119022607803
{'dataset_size': 60000,
 'test_accuracy': 0.985236644744873,
 'test_loss': 0.04452119022607803,
 'train_accuracy': 0.9333688616752625,
 'train_loss': 0.21591344475746155}
Elapsed total time: 0:0:9
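
As a cross-check, the two training durations can be recomputed from the train_on_dataset timestamps in the logs above. This is a minimal sketch using only the standard library; the helper name elapsed_seconds is hypothetical, and the timestamps are copied from the "Starting training" and "Training complete" lines.

```python
from datetime import datetime

# Timestamp layout used in the log lines (trailing "Z" matched literally).
FMT = "%Y-%m-%dT%H:%M:%S.%fZ"


def elapsed_seconds(start, end):
    """Seconds between two log timestamps (hypothetical helper)."""
    delta = datetime.strptime(end, FMT) - datetime.strptime(start, FMT)
    return delta.total_seconds()


# Run with --use-ald (ActiveLearningDataset).
with_ald = elapsed_seconds("2023-06-22T09:35:08.962235Z",
                           "2023-06-22T09:36:32.093871Z")
# Run without --use-ald (regular torch Dataset).
without_ald = elapsed_seconds("2023-06-22T09:34:51.757774Z",
                              "2023-06-22T09:34:56.868404Z")

slowdown = with_ald / without_ald
print(f"{with_ald:.1f} s vs {without_ald:.1f} s, slowdown {slowdown:.1f}x")
```

The raw timestamps give a slowdown of about 16x, in line with the roughly 17x obtained from the rounded elapsed times (83 s vs 5 s).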

Expected behavior
I would expect the training time with ActiveLearningDataset to be a few percent slower, but not 17x slower.

Version (please complete the following information):

  • OS: Ubuntu 20.04
  • Python: 3.9.16
  • Baal version: 1.7.0

Additional context
I want to use active learning in my experiments, so just using the torch Dataset is not an appropriate solution.

Any ideas why this is the case and whether this could be fixed?
Thank you!

@arthur-thuy arthur-thuy added the bug Something isn't working label Jun 22, 2023
@arthur-thuy arthur-thuy changed the title from "train_on_dataset much slower when using ActiveLearningDataset compared to torch.Dataset" to "train_on_dataset much slower when using ActiveLearningDataset compared to torch Dataset" Jun 22, 2023
@Dref360
Member

Dref360 commented Jun 23, 2023

Hello!

I was able to reproduce this with the following example:

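# Check that iterating over an ActiveLearningDataset is fast.
# Run in IPython or Jupyter: %timeit is an IPython magic, not plain Python.
from torchvision.datasets import CIFAR10
from baal.active.dataset import ActiveLearningDataset

dataset = CIFAR10(root='/tmp', train=True, download=True)
al_dataset = ActiveLearningDataset(dataset)
# Label the whole pool so both datasets yield the same number of items.
al_dataset.label_randomly(len(dataset))
# Time one full pass over the wrapped dataset, then over the plain one.
%timeit [x for x in al_dataset]
%timeit [x for x in dataset]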
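
Outside IPython, where %timeit is not available, the same comparison could be sketched with time.perf_counter. The helper name time_iteration and its repeat parameter are hypothetical; it accepts any iterable dataset and is not part of Baal or torchvision.

```python
import time


def time_iteration(dataset, repeat=3):
    """Return the best wall-clock time, in seconds, for one full pass over dataset."""
    best = float("inf")
    for _attempt in range(repeat):
        start = time.perf_counter()
        for _item in dataset:
            pass  # only the cost of fetching items is measured
        best = min(best, time.perf_counter() - start)
    return best


if __name__ == "__main__":
    # Stand-in data; replace with al_dataset and dataset from the snippet above.
    toy = list(range(100_000))
    print(f"one pass: {time_iteration(toy):.4f} s")
```

Calling time_iteration(al_dataset) and time_iteration(dataset) and dividing the two results gives the slowdown factor.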

I have a possible fix where we cache the result of ActiveLearningDataset.get_indices_for_active_step.
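
To illustrate the idea only, here is a minimal sketch of caching the labelled indices. The two toy classes are hypothetical and do not reproduce Baal's code; the actual change is in #265 and may differ. The assumption is that recomputing the list of labelled indices on every item access costs O(n), so a full pass costs O(n^2), while a cache that is invalidated whenever labels change brings each access back to O(1).

```python
import numpy as np


class NaiveActiveDataset:
    """Toy stand-in: recomputes the labelled indices on every item access."""

    def __init__(self, data):
        self.data = data
        self.labelled = np.zeros(len(data), dtype=bool)

    def label(self, indices):
        self.labelled[indices] = True

    def _labelled_indices(self):
        # O(n) scan of the boolean mask, repeated for each __getitem__ call.
        return np.flatnonzero(self.labelled)

    def __len__(self):
        return int(self.labelled.sum())

    def __getitem__(self, i):
        return self.data[self._labelled_indices()[i]]


class CachedActiveDataset(NaiveActiveDataset):
    """Same behaviour, but the index array is computed once per labelling state."""

    def __init__(self, data):
        super().__init__(data)
        self._cache = None

    def label(self, indices):
        super().label(indices)
        self._cache = None  # labels changed, so the cached indices are stale

    def _labelled_indices(self):
        if self._cache is None:
            self._cache = np.flatnonzero(self.labelled)
        return self._cache
```

The important design point is the invalidation in label: without it, newly labelled samples would never show up in the training set.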

I'll try to merge this quickly and make a release, but I'm away for the long weekend and will be back on Monday.

@Dref360
Member

Dref360 commented Jun 23, 2023

I opened #265. I'm not super happy with the solution, but it's the best I can do for now: it is now "only" 2x slower. I will revisit it next week, but feel free to use the branch fix/al_dataset_speed for your experiments.

@arthur-thuy
Contributor Author

Thank you for the fix! I would be happy with the "only 2x slower" training time.

@Dref360
Member

Dref360 commented Jul 4, 2023

Hello!
Are you comfortable installing Baal from source on the branch fix/al_dataset_speed?

I want to be sure that we are not blocking you. If we are, I'll merge immediately and deploy a minor release as soon as possible.

@arthur-thuy
Contributor Author

I’m currently on holiday and need to work on a paper revision (not related to Baal) when I return to the office. As such, I won’t be working with Baal for the next 4 weeks, so the fix is not urgent for me.

If the minor release is not done by then, I’ll install it from source. Thank you for your message.

@Dref360
Member

Dref360 commented Jul 17, 2023

Fixed in #265

@Dref360 Dref360 closed this as completed Jul 17, 2023