Skip to content

Commit

Permalink
loader should sample from files without replacement (#73)
Browse files Browse the repository at this point in the history
* loader should sample from files without replacement

* lint

* codecov

* lint

* lint

* fixing lint

* Adding docstrings

---------

Co-authored-by: Hugo Flores Garcia <[email protected]>
Co-authored-by: pseeth <[email protected]>
  • Loading branch information
3 people authored Feb 23, 2023
1 parent 3ec92c5 commit b6a22b3
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 2 deletions.
2 changes: 1 addition & 1 deletion audiotools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.6.1"
__version__ = "0.6.2"
from .core import AudioSignal
from .core import STFTParams
from .core import Meter
Expand Down
30 changes: 30 additions & 0 deletions audiotools/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class AudioLoader:
List of extensions to find audio within each source by. Can
also be a file name (e.g. "vocals.wav"), by default
``['.wav', '.flac', '.mp3', '.mp4']``.
shuffle: bool
Whether to shuffle the files within the dataloader. Defaults to True.
shuffle_state: int
State to use to seed the shuffle of the files.
"""

def __init__(
Expand All @@ -44,10 +48,22 @@ def __init__(
transform: Callable = None,
relative_path: str = "",
ext: List[str] = util.AUDIO_EXTENSIONS,
shuffle: bool = True,
shuffle_state: int = 0,
):
self.audio_lists = util.read_sources(
sources, relative_path=relative_path, ext=ext
)

self.audio_indices = [
(src_idx, item_idx)
for src_idx, src in enumerate(self.audio_lists)
for item_idx in range(len(src))
]
if shuffle:
state = util.random_state(shuffle_state)
state.shuffle(self.audio_indices)

self.sources = sources
self.weights = weights
self.transform = transform
Expand All @@ -62,12 +78,18 @@ def __call__(
offset: float = None,
source_idx: int = None,
item_idx: int = None,
global_idx: int = None,
):
if source_idx is not None and item_idx is not None:
try:
audio_info = self.audio_lists[source_idx][item_idx]
except:
audio_info = {"path": "none"}
elif global_idx is not None:
source_idx, item_idx = self.audio_indices[
global_idx % len(self.audio_indices)
]
audio_info = self.audio_lists[source_idx][item_idx]
else:
audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
state, self.audio_lists, p=self.weights
Expand Down Expand Up @@ -169,6 +191,11 @@ class AudioDataset:
offset, duration, and matched file name), by default False
shuffle_loaders : bool, optional
Whether to shuffle the loaders before sampling from them, by default False
matcher : Callable
How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
by default uses the parent directory of each file.
without_replacement : bool
If True, files are chosen without replacement (selected by the dataset index); if False, they are drawn with replacement. By default True.
Examples
Expand Down Expand Up @@ -341,6 +368,7 @@ def __init__(
aligned: bool = False,
shuffle_loaders: bool = False,
matcher: Callable = default_matcher,
without_replacement: bool = True,
):
# Internally we convert loaders to a dictionary
if isinstance(loaders, list):
Expand All @@ -359,6 +387,7 @@ def __init__(
self.offset = offset
self.aligned = aligned
self.shuffle_loaders = shuffle_loaders
self.without_replacement = without_replacement

if aligned:
loaders_list = list(loaders.values())
Expand All @@ -382,6 +411,7 @@ def __getitem__(self, idx):
"duration": self.duration,
"loudness_cutoff": self.loudness_cutoff,
"num_channels": self.num_channels,
"global_idx": idx if self.without_replacement else None,
}

# Draw item from first loader
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="audiotools",
version="0.6.1",
version="0.6.2",
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
Expand Down
59 changes: 59 additions & 0 deletions tests/data/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,65 @@ def test_aligned_audio_dataset():
assert np.all(col == col[0])


def test_loader_without_replacement():
    """Check that an unshuffled loader hands out files in index order.

    A loader built with ``shuffle=False`` is wrapped in a dataset using its
    default settings; asking the dataset for index ``i`` must then yield the
    item whose ``item_idx`` equals ``i``.
    """
    total = 100
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        # Generate a small on-disk dataset to load from.
        audiotools.util.generate_chord_dataset(
            max_voices=1,
            num_items=total,
            output_dir=root,
            duration=0.01,
        )
        loader = audiotools.data.datasets.AudioLoader([root], shuffle=False)
        dataset = audiotools.data.datasets.AudioDataset(loader, 44100)

        for position in range(total):
            assert dataset[position]["item_idx"] == position


def test_loader_with_replacement():
    """Smoke test: the dataset can be indexed with ``without_replacement=False``.

    No property of the returned items is asserted; the test only verifies
    that fetching every index completes without raising.
    """
    total = 100
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        # Generate a small on-disk dataset to load from.
        audiotools.util.generate_chord_dataset(
            max_voices=1,
            num_items=total,
            output_dir=root,
            duration=0.01,
        )
        loader = audiotools.data.datasets.AudioLoader([root])
        dataset = audiotools.data.datasets.AudioDataset(
            loader,
            44100,
            without_replacement=False,
        )

        for position in range(total):
            # The fetched item is intentionally discarded.
            dataset[position]


def test_loader_out_of_range():
    """Check that an out-of-range explicit index yields the "none" path.

    The loader is called directly with ``source_idx=0`` and an ``item_idx``
    beyond the number of generated files; the returned item is expected to
    carry ``"none"`` as its ``"path"`` instead of raising.
    """
    total = 100
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        # Generate a small on-disk dataset to load from.
        audiotools.util.generate_chord_dataset(
            max_voices=1,
            num_items=total,
            output_dir=root,
            duration=0.01,
        )
        loader = audiotools.data.datasets.AudioLoader([root])

        # item_idx=101 is past the end of the 100 generated items.
        call_kwargs = dict(
            sample_rate=44100,
            duration=0.01,
            state=audiotools.util.random_state(0),
            source_idx=0,
            item_idx=101,
        )
        item = loader(**call_kwargs)
        assert item["path"] == "none"


def test_dataset_pipeline():
transform = tfm.Compose(
[
Expand Down

0 comments on commit b6a22b3

Please sign in to comment.