Skip to content

Commit

Permalink
Added unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Apr 29, 2024
1 parent 8972716 commit 8463836
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions tests/collections/common/test_lhotse_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,51 @@ def nemo_manifest_path(cutset_path: Path):
return p


@pytest.fixture(scope="session")
def mc_cutset_path(tmp_path_factory) -> Path:
"""10 two-channel utterances of length 1s as a Lhotse CutSet."""
from lhotse import CutSet, MultiCut
from lhotse.testing.dummies import DummyManifest

num_examples = 10 # number of examples
num_channels = 2 # number of channels per example

# create a dummy manifest with single-channel examples
sc_cuts = DummyManifest(CutSet, begin_id=0, end_id=num_examples * num_channels, with_data=True)
mc_cuts = []

for n in range(num_examples):
# sources for individual channels
mc_sources = []
for channel in range(num_channels):
source = sc_cuts[n * num_channels + channel].recording.sources[0]
source.channels = [channel]
mc_sources.append(source)

# merge recordings
rec = Recording(
sources=mc_sources,
id=f'mc-dummy-recording-{n:02d}',
num_samples=sc_cuts[0].num_samples,
duration=sc_cuts[0].duration,
sampling_rate=sc_cuts[0].sampling_rate,
)

# multi-channel cut
cut = MultiCut(
recording=rec, id=f'mc-dummy-cut-{n:02d}', start=0, duration=1.0, channel=list(range(num_channels))
)
mc_cuts.append(cut)

mc_cuts = CutSet.from_cuts(mc_cuts)

tmp_path = tmp_path_factory.mktemp("data")
p = tmp_path / "mc_cuts.jsonl.gz"
pa = tmp_path / "mc_audio"
mc_cuts.save_audios(pa).to_file(p)
return p


@pytest.fixture(scope="session")
def nemo_tarred_manifest_path(nemo_manifest_path: Path) -> Tuple[str, str]:
"""10 utterances of length 1s as a NeMo tarred manifest."""
Expand Down Expand Up @@ -247,6 +292,61 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path):
# exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts


def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path):
# Dataloader without channel selector
config = OmegaConf.create(
{
"cuts_path": mc_cutset_path,
"sample_rate": 16000,
"shuffle": True,
"use_lhotse": True,
"num_workers": 0,
"batch_size": 4,
"seed": 0,
}
)

dl = get_lhotse_dataloader_from_config(
config=config, global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
)
batches = [b for b in dl]
assert len(batches) == 3

# 1.0s = 16000 samples, two channels, note the constant duration and batch size
assert batches[0]["audio"].shape == (4, 2, 16000)
assert batches[1]["audio"].shape == (4, 2, 16000)
assert batches[2]["audio"].shape == (2, 2, 16000)
# exactly 10 cuts were used

# Apply channel selector
for channel_selector in [None, 0, 1]:

config_cs = OmegaConf.create(
{
"cuts_path": mc_cutset_path,
"channel_selector": channel_selector,
"sample_rate": 16000,
"shuffle": True,
"use_lhotse": True,
"num_workers": 0,
"batch_size": 4,
"seed": 0,
}
)

dl_cs = get_lhotse_dataloader_from_config(
config=config_cs, global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
)

for n, b_cs in enumerate(dl_cs):
if channel_selector is None:
# no channel selector, needs to match the original dataset
assert torch.equal(b_cs["audio"], batches[n]["audio"])
else:
# channel selector, needs to match the selected channel
assert torch.equal(b_cs["audio"], batches[n]["audio"][:, channel_selector, :])


@requires_torchaudio
def test_dataloader_from_lhotse_shar_cuts(cutset_shar_path: Path):
config = OmegaConf.create(
Expand Down

0 comments on commit 8463836

Please sign in to comment.