Skip to content

Commit

Permalink
remove fn_kwargs from Filter and Mapper datapipes (#5113)
Browse files — browse the repository at this point in the history
* remove fn_kwargs from Filter and Mapper datapipes

* fix leftovers
  (branch information)
pmeier authored Dec 19, 2021
1 parent 40be657 commit 1efb567
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 19 deletions.
5 changes: 3 additions & 2 deletions torchvision/prototype/datasets/_builtin/caltech.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
import pathlib
import re
Expand Down Expand Up @@ -132,7 +133,7 @@ def _make_datapipe(
buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True,
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
Expand Down Expand Up @@ -185,7 +186,7 @@ def _make_datapipe(
dp = Filter(dp, self._is_not_rogue_file)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_builtin/celeba.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence

Expand Down Expand Up @@ -26,7 +27,6 @@
hint_shuffling,
)


csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)


Expand Down Expand Up @@ -155,7 +155,7 @@ def _make_datapipe(
splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
splits_dp = hint_sharding(splits_dp)
splits_dp = hint_shuffling(splits_dp)

Expand All @@ -181,4 +181,4 @@ def _make_datapipe(
keep_key=True,
)
dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_builtin/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _make_datapipe(
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
Expand Down
16 changes: 11 additions & 5 deletions torchvision/prototype/datasets/_builtin/coco.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
import pathlib
import re
Expand Down Expand Up @@ -183,12 +184,16 @@ def _make_datapipe(
if config.annotations is None:
dp = hint_sharding(images_dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder))

meta_dp = Filter(
meta_dp,
self._filter_meta_files,
fn_kwargs=dict(split=config.split, year=config.year, annotations=config.annotations),
functools.partial(
self._filter_meta_files,
split=config.split,
year=config.year,
annotations=config.annotations,
),
)
meta_dp = JsonParser(meta_dp)
meta_dp = Mapper(meta_dp, getitem(1))
Expand Down Expand Up @@ -226,7 +231,7 @@ def _make_datapipe(
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(
dp, self._collate_and_decode_sample, fn_kwargs=dict(annotations=config.annotations, decoder=decoder)
dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder)
)

def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
Expand All @@ -235,7 +240,8 @@ def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:

dp = resources[1].load(pathlib.Path(root) / self.name)
dp = Filter(
dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances")
dp,
functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
)
dp = JsonParser(dp)

Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/imagenet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
import pathlib
import re
Expand Down Expand Up @@ -165,7 +166,7 @@ def _make_datapipe(
dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_test_data)

return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_builtin/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _make_datapipe(
dp = Zipper(images_dp, labels_dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder))


class MNIST(_MNISTBase):
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/sbd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
import pathlib
import re
Expand Down Expand Up @@ -152,7 +153,7 @@ def _make_datapipe(
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/semeion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -65,5 +66,5 @@ def _make_datapipe(
dp = CSVParser(dp, delimiter=" ")
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return dp
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _make_datapipe(
buffer_size=INFINITE_BUFFER_SIZE,
)

split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task]))
split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task]))
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp)
Expand All @@ -142,4 +142,4 @@ def _make_datapipe(
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
5 changes: 3 additions & 2 deletions torchvision/prototype/datasets/_folder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
import os
import os.path
Expand Down Expand Up @@ -50,12 +51,12 @@ def from_data_folder(
categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
dp = FileLister(str(root), recursive=recursive, masks=masks)
dp: IterDataPipe = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = FileLoader(dp)
return (
Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)),
categories,
)

Expand Down

0 comments on commit 1efb567

Please sign in to comment.