diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py index 0d3722a3..19a0653d 100644 --- a/sae_training/activations_store.py +++ b/sae_training/activations_store.py @@ -272,7 +272,7 @@ def get_data_loader( # 3. put other 50 % in a dataloader dataloader = iter( DataLoader( - mixing_buffer[: mixing_buffer.shape[0] // 2 :], + mixing_buffer[mixing_buffer.shape[0] // 2 :], batch_size=batch_size, shuffle=True, )