Mergeback 1.9.0 to develop #1604

Merged
merged 6 commits into from
Sep 13, 2024
17 changes: 17 additions & 0 deletions 3rd-party.txt
Original file line number Diff line number Diff line change
@@ -7518,5 +7518,22 @@ Apache-2.0
See the License for the specific language governing permissions and
limitations under the License.
-------------------------------------------------------------
portalocker

BSD-3-Clause

Copyright 2022 Rick van Hattem

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-------------------------------------------------------------

* Other names and brands may be claimed as the property of others.
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -5,16 +5,22 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## \[unreleased\]
## \[Q3 2024 Release 1.9.0\]
### New features
- Add a new CLI command: datum format
(<https://github.com/openvinotoolkit/datumaro/pull/1570>)
- Support language dataset for DmTorchDataset
(<https://github.com/openvinotoolkit/datumaro/pull/1592>)

### Enhancements
- Change _Shape to Shape and add comments for subclasses of Shape
(<https://github.com/openvinotoolkit/datumaro/pull/1568>)
- Fix `kitti_raw` importer and exporter for dimensions (height, width, length) in meters
(<https://github.com/openvinotoolkit/datumaro/pull/1596>)

### Bug fixes
- Fix KITTI-3D importer and exporter
(<https://github.com/openvinotoolkit/datumaro/pull/1596>)

## Q3 2024 Release 1.8.0
### New features
16 changes: 16 additions & 0 deletions docs/source/docs/release_notes.rst
@@ -4,6 +4,22 @@ Release Notes
 .. toctree::
    :maxdepth: 1

+v1.9.0 (2024 Q3)
+----------------
+
+New features
+^^^^^^^^^^^^
+- Add a new CLI command: datum format
+- Support language dataset for DmTorchDataset
+
+Enhancements
+^^^^^^^^^^^^
+- Change _Shape to Shape and add comments for subclasses of Shape
+
+Bug fixes
+^^^^^^^^^
+- Fix KITTI-3D importer and exporter
+
 v1.8.0 (2024 Q3)
 ----------------

3 changes: 3 additions & 0 deletions requirements-core.txt
@@ -64,3 +64,6 @@ json-stream

 # TabularValidator
 nltk
+
+# torch converter for language
+portalocker
2 changes: 1 addition & 1 deletion setup.py
@@ -85,7 +85,7 @@ def parse_requirements(filename=CORE_REQUIREMENTS_FILE):
     extras_require={
         "tf": ["tensorflow"],
         "tfds": ["tensorflow-datasets<4.9.3"],
-        "torch": ["torch", "torchvision"],
+        "torch": ["torch", "torchvision", "torchtext==0.16.0"],
         "default": DEFAULT_REQUIREMENTS,
     },
     ext_modules=ext_modules,
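
The `torch` extra now pins `torchtext==0.16.0`, so code on the language path presumably has to tolerate environments installed without that extra. A minimal sketch of such a guard, using only the standard library; `optional_module_available` is a hypothetical helper and not part of Datumaro:

```python
import importlib.util


def optional_module_available(name: str) -> bool:
    """Return True if the top-level module `name` can be located (hypothetical helper)."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec can raise for malformed or partially initialised modules
        return False


# Same idea as the TORCH_AVAILABLE flag used by the tests in this PR.
TORCHTEXT_AVAILABLE = optional_module_available("torchtext")
```

Checking with `find_spec` avoids importing the package just to learn whether it exists; the tests in this PR use a `try`/`except ImportError` block instead, which works equally well.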
4 changes: 2 additions & 2 deletions src/datumaro/plugins/data_formats/kitti_raw/base.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023 Intel Corporation
+# Copyright (C) 2021-2024 Intel Corporation
 #
 # SPDX-License-Identifier: MIT

@@ -182,7 +182,7 @@ def _parse_attr(cls, value):
     @classmethod
     def _parse_track(cls, track_id, track, categories):
         common_attrs = {k: cls._parse_attr(v) for k, v in track["attributes"].items()}
-        scale = [track["scale"][k] for k in ["w", "h", "l"]]
+        scale = [track["scale"][k] for k in ["h", "w", "l"]]
         label = categories[AnnotationType.label].find(track["label"])[0]

         kf_occluded = False
8 changes: 4 additions & 4 deletions src/datumaro/plugins/data_formats/kitti_raw/exporter.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021 Intel Corporation
+# Copyright (C) 2021-2024 Intel Corporation
 #
 # SPDX-License-Identifier: MIT

@@ -339,16 +339,16 @@ def _create_tracklets(self, subset):
                 if not track:
                     track = {
                         "objectType": label,
-                        "h": ann.scale[1],
-                        "w": ann.scale[0],
+                        "h": ann.scale[0],
+                        "w": ann.scale[1],
                         "l": ann.scale[2],
                         "first_frame": frame_id,
                         "poses": [],
                         "finished": 1,  # keep last
                     }
                     tracks[track_id] = track
                 else:
-                    if [track["w"], track["h"], track["l"]] != ann.scale:
+                    if [track["h"], track["w"], track["l"]] != ann.scale:
                         # Tracks have fixed scale in the format
                         raise DatasetExportError(
                             "Item %s: mismatching track shapes, "
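
The two `kitti_raw` hunks have to agree on a single ordering: after this change the scale list is read as `[h, w, l]` by the importer, and the exporter writes index 0 to `h` and index 1 to `w`. A minimal round-trip sketch with plain dictionaries; `parse_scale` and `export_scale` are hypothetical stand-ins for the two code paths, not the Datumaro functions:

```python
from typing import Dict, List

SCALE_KEYS = ("h", "w", "l")  # order after this PR (previously "w", "h", "l")


def parse_scale(track_scale: Dict[str, float]) -> List[float]:
    """Importer side: per-track scale dict -> [h, w, l] list."""
    return [track_scale[k] for k in SCALE_KEYS]


def export_scale(scale: List[float]) -> Dict[str, float]:
    """Exporter side: [h, w, l] list -> per-track scale dict."""
    return dict(zip(SCALE_KEYS, scale))


original = {"h": -3.62, "w": 7.95, "l": -1.03}
scale = parse_scale(original)
round_tripped = export_scale(scale)
```

With the old `["w", "h", "l"]` key order the same input produced `[7.95, -3.62, -1.03]`, which is why the expected `Cuboid3d` scales in the tests swap their first two entries.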
51 changes: 49 additions & 2 deletions src/datumaro/plugins/framework_converter.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
 #
 # SPDX-License-Identifier: MIT

@@ -17,6 +17,7 @@
     "detection": AnnotationType.bbox,
     "instance_segmentation": AnnotationType.polygon,
     "semantic_segmentation": AnnotationType.mask,
+    "tabular": [AnnotationType.label, AnnotationType.caption],
 }


@@ -88,7 +89,10 @@
                 if ann.type == TASK_ANN_TYPE[self.task]
             ]
             label = mask_tools.merge_masks((mask, label_id) for mask, label_id in masks)
-
+        elif self.task == "tabular":
+            label = [
+                ann.as_dict() for ann in item.annotations if ann.type in TASK_ANN_TYPE[self.task]
+            ]
         return image, label


@@ -103,15 +107,58 @@
         task: str,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        target: Optional[str] = None,
+        tokenizer: Optional[tuple[Callable, Callable]] = None,
+        vocab: Optional[tuple[Callable, Callable]] = None,
     ):
         super().__init__(dataset=dataset, subset=subset, task=task)

         self.transform = transform
         self.target_transform = target_transform

+        if self.task == "tabular":
+            if not isinstance(target, dict):
+                raise ValueError(
+                    "Target should be a dictionary with 'input' and 'output' keys."
+                )
+            self.input_target = target.get("input")
+            self.output_target = target.get("output")
+            if not self.input_target:
+                raise ValueError(
+                    "Please provide target column for tabular task which is used for input"
+                )
+
+            if not (tokenizer and vocab):
+                raise ValueError("Both tokenizer and vocab must be provided for tabular task")
+            self.tokenizer = tokenizer
+            self.vocab = vocab
+
     def __getitem__(self, idx):
         image, label = self._gen_item(idx)

+        if self.task == "tabular":
+            text = image()[self.input_target]
+
+            if self.output_target:
+                src_tokenizer, tgt_tokenizer = self.tokenizer
+                src_vocab, tgt_vocab = self.vocab
+                src_tokens = src_tokenizer(text)
+                src_token_ids = src_vocab(src_tokens)
+
+                label_text = label[0]["caption"].split(f"{self.output_target}:")[-1]
+                tgt_tokens = tgt_tokenizer(label_text)
+                tgt_token_ids = tgt_vocab(tgt_tokens)
+
+                return torch.tensor(src_token_ids, dtype=torch.long), torch.tensor(
+                    tgt_token_ids, dtype=torch.long
+                )
+            else:
+                tokens = self.tokenizer(text)
+                token_ids = self.vocab(tokens)
+                return torch.tensor(token_ids, dtype=torch.long), torch.tensor(
+                    label[0]["label"], dtype=torch.long
+                )
+
         if len(image.shape) == 2:
             image = np.expand_dims(image, axis=-1)
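
The tabular branch above can be read as three steps: pick the input column from the row, tokenize it, and map the tokens to ids; when an `output` target is given, the caption goes through the same steps after its `<column>:` prefix is stripped. Note that `target` is annotated `Optional[str]` in the signature but validated as a `dict`. The sketch below mirrors that flow in plain Python under several assumptions: a dict-backed vocabulary and lists stand in for the torchtext vocabulary and `torch.tensor`, one tokenizer and vocabulary serve both sides (the diff takes a pair of each when `output` is set), and `make_vocab` and `convert_row` are hypothetical names.

```python
from typing import Callable, Dict, List, Tuple

Tokenizer = Callable[[str], List[str]]
Vocab = Callable[[List[str]], List[int]]


def make_vocab(tokens: List[str], unk: str = "<unk>") -> Vocab:
    """Build a callable vocabulary; unseen tokens map to the <unk> id (0)."""
    index: Dict[str, int] = {unk: 0}
    for tok in tokens:
        index.setdefault(tok, len(index))
    return lambda toks: [index.get(t, index[unk]) for t in toks]


def convert_row(
    row: Dict[str, str],
    annotation: Dict[str, object],
    target: Dict[str, str],
    tokenizer: Tokenizer,
    vocab: Vocab,
) -> Tuple[List[int], object]:
    """Return (input ids, label) or (input ids, caption ids) for one table row."""
    if not isinstance(target, dict):
        raise ValueError("target should be a dict with 'input' and optional 'output'")
    input_col = target.get("input")
    output_col = target.get("output")
    if not input_col:
        raise ValueError("target must name an 'input' column")

    src_ids = vocab(tokenizer(row[input_col]))
    if output_col:
        # Captions are stored as "<column>:<text>"; keep the text part.
        caption = str(annotation["caption"]).split(f"{output_col}:")[-1]
        return src_ids, vocab(tokenizer(caption))
    return src_ids, annotation["label"]


tokenizer = str.split
vocab = make_vocab(tokenizer("This is a sample text"))

# Label path: only an input column is named.
ids, label = convert_row(
    {"text": "This is another text"}, {"label": 0}, {"input": "text"}, tokenizer, vocab
)

# Caption path: the output column selects the caption text.
src_ids, tgt_ids = convert_row(
    {"source": "a text"},
    {"caption": "target:sample text"},
    {"input": "source", "output": "target"},
    tokenizer,
    vocab,
)
```

In the label call, `another` is not in the vocabulary, so it falls back to the `<unk>` id, which matches what `vocab.set_default_index(vocab["<unk>"])` arranges in the test fixtures.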

2 changes: 1 addition & 1 deletion src/datumaro/version.py
@@ -1 +1 @@
-__version__ = "1.9.0rc0"
+__version__ = "1.9.0"
8 changes: 4 additions & 4 deletions tests/integration/cli/test_kitti_raw_format.py
@@ -33,13 +33,13 @@ def test_can_convert_to_kitti_raw(self):
annotations=[
Cuboid3d(
position=[1, 2, 3],
scale=[7.95, -3.62, -1.03],
scale=[-3.62, 7.95, -1.03],
label=1,
attributes={"occluded": False, "track_id": 1},
),
Cuboid3d(
position=[1, 1, 0],
scale=[8.34, 23.01, -0.76],
scale=[23.01, 8.34, -0.76],
label=0,
attributes={"occluded": False, "track_id": 2},
),
@@ -65,7 +65,7 @@ def test_can_convert_to_kitti_raw(self):
annotations=[
Cuboid3d(
position=[0, 1, 0],
scale=[8.34, 23.01, -0.76],
scale=[23.01, 8.34, -0.76],
rotation=[1, 1, 3],
label=0,
attributes={"occluded": True, "track_id": 2},
@@ -92,7 +92,7 @@ def test_can_convert_to_kitti_raw(self):
annotations=[
Cuboid3d(
position=[1, 2, 3],
scale=[-9.41, 13.54, 0.24],
scale=[13.54, -9.41, 0.24],
label=1,
attributes={"occluded": False, "track_id": 3},
)
244 changes: 236 additions & 8 deletions tests/unit/test_framework_converter.py
@@ -1,4 +1,4 @@
# Copyright (C) 2023 Intel Corporation
# Copyright (C) 2023-2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

@@ -13,14 +13,16 @@
from datumaro.components.annotation import (
AnnotationType,
Bbox,
Caption,
Label,
LabelCategories,
Mask,
Polygon,
Tabular,
)
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.media import Image
from datumaro.components.media import Image, Table, TableRow
from datumaro.plugins.framework_converter import (
TASK_ANN_TYPE,
DmTfDataset,
@@ -36,6 +38,8 @@

try:
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchvision import datasets, transforms
except ImportError:
TORCH_AVAILABLE = False
@@ -142,6 +146,89 @@ def fxt_dataset():
)


@pytest.fixture
def fxt_tabular_label_dataset():
table = Table.from_list(
[
{
"label": 1,
"text": "I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "
"controversial"
" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it's not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.<br /><br />I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn't have much of a plot.",
}
]
)
return Dataset.from_iterable(
[
DatasetItem(
id=0,
subset="train",
media=TableRow(table=table, index=0),
annotations=[Label(id=0, attributes={}, group=0, object_id=-1, label=0)],
)
],
categories={
AnnotationType.label: LabelCategories.from_iterable(
[("label:1", "label"), ("label:2", "label")]
)
},
media_type=TableRow,
)


@pytest.fixture
def fxt_tabular_caption_dataset():
table = Table.from_list(
[
{
"source": "Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.",
"target": "Two young, White males are outside near many bushes.",
}
]
)
return Dataset.from_iterable(
[
DatasetItem(
id=0,
subset="train",
media=TableRow(table=table, index=0),
annotations=[
Caption("target:Two young, White males are outside near many bushes.")
],
)
],
categories={},
media_type=TableRow,
)


@pytest.fixture
def fxt_dummy_tokenizer():
def dummy_tokenizer(text):
return text.split()

return dummy_tokenizer


@pytest.fixture
def data_iter():
return [(1, "This is a sample text"), (2, "Another sample text")]


@pytest.fixture
def fxt_dummy_vocab(fxt_dummy_tokenizer, data_iter):
vocab = build_vocab_from_iterator(
map(fxt_dummy_tokenizer, (text for _, text in data_iter)), specials=["<unk>"]
)
vocab.set_default_index(vocab["<unk>"])
return vocab


@pytest.fixture
def fxt_tabular_fixture(fxt_dummy_tokenizer, fxt_dummy_vocab):
return {"target": {"input": "text"}, "tokenizer": fxt_dummy_tokenizer, "vocab": fxt_dummy_vocab}


@pytest.mark.new
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
class FrameworkConverterFactoryTest(TestCase):
@@ -173,38 +260,49 @@ def test_create_converter_tf_importerror(self):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
class MultiframeworkConverterTest:
@pytest.mark.parametrize(
"fxt_subset,fxt_task",
"fxt_dataset_type,fxt_subset,fxt_task",
[
(
"fxt_dataset",
"train",
"classification",
),
(
"fxt_dataset",
"val",
"multilabel_classification",
),
(
"fxt_dataset",
"train",
"detection",
),
(
"fxt_dataset",
"val",
"instance_segmentation",
),
(
"fxt_dataset",
"train",
"semantic_segmentation",
),
("fxt_tabular_label_dataset", "train", "tabular"),
],
)
def test_multi_framework_dataset(self, fxt_dataset: Dataset, fxt_subset: str, fxt_task: str):
def test_multi_framework_dataset(
self, fxt_dataset_type: str, fxt_subset: str, fxt_task: str, request
):
dataset = request.getfixturevalue(fxt_dataset_type)
dm_multi_framework_dataset = _MultiFrameworkDataset(
dataset=fxt_dataset, subset=fxt_subset, task=fxt_task
dataset=dataset, subset=fxt_subset, task=fxt_task
)

for idx in range(len(dm_multi_framework_dataset)):
image, label = dm_multi_framework_dataset._gen_item(idx)
assert isinstance(image, np.ndarray)
if fxt_task == "tabular":
image = image()
assert isinstance(image, (np.ndarray, dict))
if fxt_task == "classification":
assert isinstance(label, int)
elif fxt_task == "multilabel_classification":
@@ -213,6 +311,8 @@ def test_multi_framework_dataset(self, fxt_dataset: Dataset, fxt_subset: str, fx
assert isinstance(label, list)
if fxt_task == "semantic_segmentation":
assert isinstance(label, np.ndarray)
elif fxt_task == "tabular":
assert isinstance(label, list)

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch is not installed")
@pytest.mark.parametrize(
@@ -261,7 +361,6 @@ def test_can_convert_torch_framework(
fxt_subset: str,
fxt_task: str,
fxt_convert_kwargs: Dict[str, Any],
request: pytest.FixtureRequest,
):
multi_framework_dataset = FrameworkConverter(fxt_dataset, subset=fxt_subset, task=fxt_task)

@@ -294,7 +393,12 @@ def test_can_convert_torch_framework(
if ann.type == TASK_ANN_TYPE[fxt_task]
]
label = np.sum(masks, axis=0, dtype=np.uint8)

elif fxt_task == "tabular":
label = [
ann.as_dict()
for ann in exp_item.annotations
if ann.type in TASK_ANN_TYPE[fxt_task]
]
if fxt_convert_kwargs.get("transform", None):
actual = dm_torch_item[0].permute(1, 2, 0).mul(255.0).to(torch.uint8).numpy()
assert np.array_equal(image, actual)
@@ -374,6 +478,130 @@ def test_can_convert_torch_framework_detection(self):
assert torch_ann["bbox"] == [x1, y1, x2 - x1, y2 - y1]
assert torch_ann["iscrowd"] == dm_ann["attributes"]["is_crowd"]

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch is not installed")
def test_can_convert_torch_framework_tabular_label(self, fxt_tabular_label_dataset):
class IMDBDataset(Dataset):
def __init__(self, data_iter, vocab, transform=None):
self.data = list(data_iter)
self.vocab = vocab
self.transform = transform
self.tokenizer = get_tokenizer("basic_english")

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
label, text = self.data[idx]
token_ids = [self.vocab[token] for token in self.tokenizer(text)]

if self.transform:
token_ids = self.transform(token_ids)

return torch.tensor(token_ids, dtype=torch.long), torch.tensor(
label, dtype=torch.long
)

# Prepare data and tokenizer
# First item of IMDB
first_item = (
1,
"I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered \"controversial\" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it's not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.<br /><br />I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn't have much of a plot.",
)
tokenizer = get_tokenizer("basic_english")

# Build vocabulary
vocab = build_vocab_from_iterator([tokenizer(first_item[1])], specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# Create torch dataset
torch_dataset = IMDBDataset(iter([first_item]), vocab)

# Convert to dm_torch_dataset
dm_dataset = fxt_tabular_label_dataset
multi_framework_dataset = FrameworkConverter(dm_dataset, subset="train", task="tabular")
dm_torch_dataset = multi_framework_dataset.to_framework(
framework="torch", target={"input": "text"}, tokenizer=tokenizer, vocab=vocab
)

# Verify equality of items in torch_dataset and dm_torch_dataset
label_indices = dm_dataset.categories().get(AnnotationType.label)._indices
torch_item = torch_dataset[0]
dm_item = dm_torch_dataset[0]
assert torch.equal(torch_item[0], dm_item[0]), "Token IDs do not match"

# Extract and compare labels
torch_item_label = str(torch_item[1].item())
dm_item_label = list(label_indices.keys())[list(label_indices.values()).index(0)].split(
":"
)[-1]
assert torch_item_label == dm_item_label, "Labels do not match"

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch is not installed")
def test_can_convert_torch_framework_tabular_caption(self, fxt_tabular_caption_dataset):
class Multi30kDataset(Dataset):
def __init__(self, dataset, src_tokenizer, tgt_tokenizer, src_vocab, tgt_vocab):
self.dataset = list(dataset)
self.src_tokenizer = src_tokenizer
self.tgt_tokenizer = tgt_tokenizer
self.src_vocab = src_vocab
self.tgt_vocab = tgt_vocab

def __len__(self):
return len(self.dataset)

def _data_process(self, text, tokenizer, vocab):
tokens = tokenizer(text)
token_ids = [vocab[token] for token in tokens]
return torch.tensor(token_ids, dtype=torch.long)

def __getitem__(self, idx):
src, tgt = self.dataset[idx]
src_tensor = self._data_process(src, self.src_tokenizer, self.src_vocab)
tgt_tensor = self._data_process(tgt, self.tgt_tokenizer, self.tgt_vocab)
return src_tensor, tgt_tensor

# Prepare data and tokenizer
# First item of Multi30k
first_item = (
"Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.",
"Two young, White males are outside near many bushes.",
)

dummy_tokenizer = str.split

def build_single_vocab(item, tokenizer, specials):
tokens = tokenizer(item)
vocab = build_vocab_from_iterator([tokens], specials=specials)
vocab.set_default_index(vocab["<unk>"])
return vocab

# Build vocabularies
specials = ["<unk>", "<pad>", "<bos>", "<eos>"]
src_vocab = build_single_vocab(first_item[0], dummy_tokenizer, specials)
tgt_vocab = build_single_vocab(first_item[1], dummy_tokenizer, specials)

# Create torch dataset
torch_dataset = Multi30kDataset(
iter([first_item]), dummy_tokenizer, dummy_tokenizer, src_vocab, tgt_vocab
)

# Convert to dm_torch_dataset
dm_dataset = fxt_tabular_caption_dataset
multi_framework_dataset = FrameworkConverter(dm_dataset, subset="train", task="tabular")
dm_torch_dataset = multi_framework_dataset.to_framework(
framework="torch",
target={"input": "source", "output": "target"},
tokenizer=(dummy_tokenizer, dummy_tokenizer),
vocab=(src_vocab, tgt_vocab),
)

# Verify equality of items in torch_dataset and dm_torch_dataset
torch_item = torch_dataset[0]
dm_item = dm_torch_dataset[0]

assert torch.equal(torch_item[0], dm_item[0]), "Token IDs for de do not match"
assert torch.equal(torch_item[1], dm_item[1]), "Token IDs for en do not match"
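
The caption fixture stores its text as `target:<sentence>`, and the converter recovers the sentence with `split(f"{output_target}:")[-1]`. That expression keeps only what follows the last occurrence of the marker, which is fine for the Multi30k sample here but would truncate a caption that happens to contain `target:` again. A small sketch of the difference; `strip_prefix` is a hypothetical helper that removes only the leading marker:

```python
def strip_prefix(caption: str, column: str) -> str:
    """Remove a leading "<column>:" marker and leave the rest of the caption intact."""
    prefix = f"{column}:"
    return caption[len(prefix):] if caption.startswith(prefix) else caption


caption = "target:Note: the target: label is optional."

# Behaviour of the expression used in the converter: text after the LAST marker.
last_part = caption.split("target:")[-1]

# Behaviour of the hypothetical helper: only the leading marker is dropped.
whole_text = strip_prefix(caption, "target")
```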

@pytest.mark.skipif(not TF_AVAILABLE, reason="Tensorflow is not installed")
@pytest.mark.parametrize(
"fxt_subset,fxt_task,fxt_convert_kwargs",
20 changes: 10 additions & 10 deletions tests/unit/test_kitti_raw_format.py
@@ -52,13 +52,13 @@ def test_can_load(self):
annotations=[
Cuboid3d(
position=[1, 2, 3],
scale=[7.95, -3.62, -1.03],
scale=[-3.62, 7.95, -1.03],
label=1,
attributes={"occluded": False, "track_id": 1},
),
Cuboid3d(
position=[1, 1, 0],
scale=[8.34, 23.01, -0.76],
scale=[23.01, 8.34, -0.76],
label=0,
attributes={"occluded": False, "track_id": 2},
),
@@ -71,7 +71,7 @@ def test_can_load(self):
annotations=[
Cuboid3d(
position=[0, 1, 0],
scale=[8.34, 23.01, -0.76],
scale=[23.01, 8.34, -0.76],
rotation=[1, 1, 3],
label=0,
attributes={"occluded": True, "track_id": 2},
@@ -85,7 +85,7 @@ def test_can_load(self):
annotations=[
Cuboid3d(
position=[1, 2, 3],
scale=[-9.41, 13.54, 0.24],
scale=[13.54, -9.41, 0.24],
label=1,
attributes={"occluded": False, "track_id": 3},
)
@@ -161,7 +161,7 @@ def test_can_save_and_load(self):
Cuboid3d(position=[1.4, 2.1, 1.4], label=1, attributes={"track_id": 2}),
Cuboid3d(
position=[11.4, -0.1, 4.2],
scale=[2, 1, 2],
scale=[1, 2, 2],
label=0,
attributes={"track_id": 3},
),
@@ -172,7 +172,7 @@ def test_can_save_and_load(self):
annotations=[
Cuboid3d(
position=[0.4, -1, 2.24],
scale=[2, 1, 2],
scale=[1, 2, 2],
label=0,
attributes={"track_id": 3},
),
@@ -185,7 +185,7 @@ def test_can_save_and_load(self):
annotations=[
Cuboid3d(
position=[0.4, -1, 3.24],
scale=[2, 1, 2],
scale=[1, 2, 2],
label=0,
attributes={"track_id": 3},
),
@@ -244,7 +244,7 @@ def test_can_save_and_load(self):
),
Cuboid3d(
position=[11.4, -0.1, 4.2],
scale=[2, 1, 2],
scale=[1, 2, 2],
label=0,
attributes={"occluded": False, "track_id": 3},
),
@@ -256,7 +256,7 @@ def test_can_save_and_load(self):
annotations=[
Cuboid3d(
position=[0.4, -1, 2.24],
scale=[2, 1, 2],
scale=[1, 2, 2],
label=0,
attributes={"occluded": False, "track_id": 3},
),
@@ -271,7 +271,7 @@ def test_can_save_and_load(self):
annotations=[
Cuboid3d(
position=[0.4, -1, 3.24],
scale=[2, 1, 2],
scale=[1, 2, 2],
label=0,
attributes={"occluded": False, "track_id": 3},
),