diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 036a2ac9a12..c7bdff274f0 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -50,6 +50,7 @@ You can also create your own datasets using the provided :ref:`base classes `_ Dataset. + + The dataset contains 10,200 images of aircraft, with 100 images for each of 102 + different aircraft model variants, most of which are airplanes. + Aircraft models are organized in a three-levels hierarchy. The three levels, from + finer to coarser, are: + + - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually + indistinguishable into one class. The dataset comprises 102 different variants. + - ``family``, e.g. Boeing 737. The dataset comprises 70 different families. + - ``manufacturer``, e.g. Boeing. The dataset comprises 41 different manufacturers. + + Args: + root (string): Root directory of the FGVC Aircraft dataset. + split (string, optional): The dataset split, supports ``train``, ``val``, + ``trainval`` and ``test``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + annotation_level (str, optional): The annotation level, supports ``variant``, + ``family`` and ``manufacturer``. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz" + + def __init__( + self, + root: str, + split: str = "trainval", + download: bool = False, + annotation_level: str = "variant", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test")) + self._annotation_level = verify_str_arg( + annotation_level, "annotation_level", ("variant", "family", "manufacturer") + ) + + self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b") + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + annotation_file = os.path.join( + self._data_path, + "data", + { + "variant": "variants.txt", + "family": "families.txt", + "manufacturer": "manufacturers.txt", + }[self._annotation_level], + ) + with open(annotation_file, "r") as f: + self.classes = [line.strip() for line in f] + + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) + + image_data_folder = os.path.join(self._data_path, "data", "images") + labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt") + + self._image_files = [] + self._labels = [] + + with open(labels_file, "r") as f: + for line in f: + image_name, label_name = line.strip().split(" ", 1) + self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg")) + self._labels.append(self.class_to_idx[label_name]) + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx) -> Tuple[Any, Any]: + image_file, label = self._image_files[idx], self._labels[idx] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def _download(self) -> None: + """ + Download the FGVC Aircraft dataset archive and extract it under root. + """ + if self._check_exists(): + return + download_and_extract_archive(self._URL, self.root) + + def _check_exists(self) -> bool: + return os.path.exists(self._data_path) and os.path.isdir(self._data_path)