Skip to content

Commit

Permalink
[TTS] Scale sampler steps by number of devices
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman committed Nov 27, 2023
1 parent 79bc929 commit 3c1eacc
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 10 deletions.
7 changes: 5 additions & 2 deletions nemo/collections/tts/data/text_to_speech_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,15 @@ def __init__(
self.data_samples += samples
self.sample_weights += weights

def get_sampler(self, batch_size: int) -> Optional[torch.utils.data.Sampler]:
def get_sampler(self, batch_size: int, world_size: int) -> Optional[torch.utils.data.Sampler]:
if not self.weighted_sampling_steps_per_epoch:
return None

sampler = get_weighted_sampler(
sample_weights=self.sample_weights, batch_size=batch_size, num_steps=self.weighted_sampling_steps_per_epoch
sample_weights=self.sample_weights,
batch_size=batch_size,
world_size=world_size,
num_steps=self.weighted_sampling_steps_per_epoch,
)
return sampler

Expand Down
9 changes: 6 additions & 3 deletions nemo/collections/tts/data/vocoder_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,15 @@ def __init__(
self.data_samples += samples
self.sample_weights += weights

def get_sampler(self, batch_size: int) -> Optional[torch.utils.data.Sampler]:
def get_sampler(self, batch_size: int, world_size: int) -> Optional[torch.utils.data.Sampler]:
if not self.weighted_sampling_steps_per_epoch:
return None

sampler = get_weighted_sampler(
sample_weights=self.sample_weights, batch_size=batch_size, num_steps=self.weighted_sampling_steps_per_epoch
sample_weights=self.sample_weights,
batch_size=batch_size,
world_size=world_size,
num_steps=self.weighted_sampling_steps_per_epoch,
)
return sampler

Expand Down Expand Up @@ -410,7 +413,7 @@ def _build_sample(self, tup):

return example

def get_sampler(self, batch_size: int = 16):
def get_sampler(self, batch_size: int, world_size: int):
"""
Currently sampler is not supported for tarred dataset.
"""
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/tts/models/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def get_dataset(self, cfg):

dataset = instantiate(cfg.dataset)

sampler = dataset.get_sampler(cfg.dataloader_params.batch_size)
sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
return dataset, sampler

def _setup_train_dataloader(self, cfg):
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def _setup_train_dataloader(self, cfg):
with phon_mode:
dataset = instantiate(cfg.dataset, text_tokenizer=self.vocab,)

sampler = dataset.get_sampler(cfg.dataloader_params.batch_size)
sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
return torch.utils.data.DataLoader(
dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params
)
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/tts/models/hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def istft(mags, phase):

def _setup_train_dataloader(self, cfg):
dataset = instantiate(cfg.dataset)
sampler = dataset.get_sampler(cfg.dataloader_params.batch_size)
sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params
)
Expand Down
5 changes: 3 additions & 2 deletions nemo/collections/tts/parts/utils/tts_dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,21 +218,22 @@ def filter_dataset_by_duration(entries: List[Dict[str, Any]], min_duration: floa


def get_weighted_sampler(
sample_weights: List[float], batch_size: int, num_steps: int
sample_weights: List[float], batch_size: int, world_size: int, num_steps: int
) -> torch.utils.data.WeightedRandomSampler:
"""
Create pytorch sampler for doing weighted random sampling.
Args:
sample_weights: List of sampling weights for all elements in the dataset.
batch_size: Batch size to sample.
world_size: Number of devices being used.
num_steps: Number of steps to be considered an epoch.
Returns:
Pytorch sampler
"""
weights = torch.tensor(sample_weights, dtype=torch.float64)
num_samples = batch_size * num_steps
num_samples = batch_size * world_size * num_steps
sampler = torch.utils.data.WeightedRandomSampler(weights=weights, num_samples=num_samples)
return sampler

Expand Down

0 comments on commit 3c1eacc

Please sign in to comment.