diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 1e551bcf2020..0addfc9724a0 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -139,12 +139,15 @@ def __init__( self.data_samples += samples self.sample_weights += weights - def get_sampler(self, batch_size: int) -> Optional[torch.utils.data.Sampler]: + def get_sampler(self, batch_size: int, world_size: int) -> Optional[torch.utils.data.Sampler]: if not self.weighted_sampling_steps_per_epoch: return None sampler = get_weighted_sampler( - sample_weights=self.sample_weights, batch_size=batch_size, num_steps=self.weighted_sampling_steps_per_epoch + sample_weights=self.sample_weights, + batch_size=batch_size, + world_size=world_size, + num_steps=self.weighted_sampling_steps_per_epoch, ) return sampler diff --git a/nemo/collections/tts/data/vocoder_dataset.py b/nemo/collections/tts/data/vocoder_dataset.py index 93295750dd43..76dfe1154ae9 100644 --- a/nemo/collections/tts/data/vocoder_dataset.py +++ b/nemo/collections/tts/data/vocoder_dataset.py @@ -166,12 +166,15 @@ def __init__( self.data_samples += samples self.sample_weights += weights - def get_sampler(self, batch_size: int) -> Optional[torch.utils.data.Sampler]: + def get_sampler(self, batch_size: int, world_size: int) -> Optional[torch.utils.data.Sampler]: if not self.weighted_sampling_steps_per_epoch: return None sampler = get_weighted_sampler( - sample_weights=self.sample_weights, batch_size=batch_size, num_steps=self.weighted_sampling_steps_per_epoch + sample_weights=self.sample_weights, + batch_size=batch_size, + world_size=world_size, + num_steps=self.weighted_sampling_steps_per_epoch, ) return sampler @@ -410,7 +413,7 @@ def _build_sample(self, tup): return example - def get_sampler(self, batch_size: int = 16): + def get_sampler(self, batch_size: int, world_size: int): """ Currently sampler is not supported for tarred dataset. """ diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 1bdc383952c3..ccc93d690257 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -506,7 +506,7 @@ def get_dataset(self, cfg): dataset = instantiate(cfg.dataset) - sampler = dataset.get_sampler(cfg.dataloader_params.batch_size) + sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size) return dataset, sampler def _setup_train_dataloader(self, cfg): diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index bfb00e00c4ba..3235a096a04b 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -600,7 +600,7 @@ def _setup_train_dataloader(self, cfg): with phon_mode: dataset = instantiate(cfg.dataset, text_tokenizer=self.vocab,) - sampler = dataset.get_sampler(cfg.dataloader_params.batch_size) + sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size) return torch.utils.data.DataLoader( dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params ) diff --git a/nemo/collections/tts/models/hifigan.py b/nemo/collections/tts/models/hifigan.py index 03205f5685cd..7a9a6d30671f 100644 --- a/nemo/collections/tts/models/hifigan.py +++ b/nemo/collections/tts/models/hifigan.py @@ -337,7 +337,7 @@ def istft(mags, phase): def _setup_train_dataloader(self, cfg): dataset = instantiate(cfg.dataset) - sampler = dataset.get_sampler(cfg.dataloader_params.batch_size) + sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size) data_loader = torch.utils.data.DataLoader( dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params ) diff --git a/nemo/collections/tts/parts/utils/tts_dataset_utils.py b/nemo/collections/tts/parts/utils/tts_dataset_utils.py index 92b50edac143..5f1185c2c399 100644 --- a/nemo/collections/tts/parts/utils/tts_dataset_utils.py +++ b/nemo/collections/tts/parts/utils/tts_dataset_utils.py @@ -218,7 +218,7 @@ def filter_dataset_by_duration(entries: List[Dict[str, Any]], min_duration: floa def get_weighted_sampler( - sample_weights: List[float], batch_size: int, num_steps: int + sample_weights: List[float], batch_size: int, world_size: int, num_steps: int ) -> torch.utils.data.WeightedRandomSampler: """ Create pytorch sampler for doing weighted random sampling. @@ -226,13 +226,14 @@ def get_weighted_sampler( Args: sample_weights: List of sampling weights for all elements in the dataset. batch_size: Batch size to sample. + world_size: Number of devices being used. num_steps: Number of steps to be considered an epoch. Returns: Pytorch sampler """ weights = torch.tensor(sample_weights, dtype=torch.float64) - num_samples = batch_size * num_steps + num_samples = batch_size * world_size * num_steps sampler = torch.utils.data.WeightedRandomSampler(weights=weights, num_samples=num_samples) return sampler