Skip to content
This repository has been archived by the owner on Nov 23, 2024. It is now read-only.

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
style: apply automated linter fixes
Browse files Browse the repository at this point in the history
megalinter-bot committed May 8, 2024
1 parent 5782725 commit 5190242
Showing 2 changed files with 22 additions and 9 deletions.
4 changes: 3 additions & 1 deletion src/safeds_datasets/image/_mnist/_mnist.py
Original file line number Diff line number Diff line change
@@ -204,7 +204,9 @@ def _load_mnist_like(
[labels[label_index] for label_index in array("B", label_file.read())],
)
else:
test_labels = Column(file_name, [labels[label_index] for label_index in array("B", label_file.read())])
test_labels = Column(
file_name, [labels[label_index] for label_index in array("B", label_file.read())],
)
else:
with gzip.open(path / file_path, mode="rb") as image_file:
magic, size, rows, cols = struct.unpack(">IIII", image_file.read(16))
27 changes: 19 additions & 8 deletions tests/safeds_datasets/image/_mnist/test_mnist.py
Original file line number Diff line number Diff line change
@@ -4,8 +4,7 @@

import pytest
from safeds.data.labeled.containers import ImageDataset

from safeds_datasets.image import load_mnist, _mnist, load_fashion_mnist, load_kmnist
from safeds_datasets.image import _mnist, load_fashion_mnist, load_kmnist, load_mnist


class TestMNIST:
@@ -22,11 +21,15 @@ def test_should_download_and_return_mnist(self) -> None:
assert len(test) == 10_000
train_output = train.get_output()
test_output = test.get_output()
assert set(train_output.get_unique_values()) == set(test_output.get_unique_values()) == set(_mnist._mnist._mnist_labels.values())
assert (
set(train_output.get_unique_values())
== set(test_output.get_unique_values())
== set(_mnist._mnist._mnist_labels.values())
)

def test_should_raise_if_file_not_found(self) -> None:
with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
load_mnist(tmpdirname, download=False)
load_mnist(tmpdirname, download=False)


class TestFashionMNIST:
@@ -43,11 +46,15 @@ def test_should_download_and_return_mnist(self) -> None:
assert len(test) == 10_000
train_output = train.get_output()
test_output = test.get_output()
assert set(train_output.get_unique_values()) == set(test_output.get_unique_values()) == set(_mnist._mnist._fashion_mnist_labels.values())
assert (
set(train_output.get_unique_values())
== set(test_output.get_unique_values())
== set(_mnist._mnist._fashion_mnist_labels.values())
)

def test_should_raise_if_file_not_found(self) -> None:
with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
load_fashion_mnist(tmpdirname, download=False)
load_fashion_mnist(tmpdirname, download=False)


class TestKMNIST:
@@ -64,8 +71,12 @@ def test_should_download_and_return_mnist(self) -> None:
assert len(test) == 10_000
train_output = train.get_output()
test_output = test.get_output()
assert set(train_output.get_unique_values()) == set(test_output.get_unique_values()) == set(_mnist._mnist._kuzushiji_mnist_labels.values())
assert (
set(train_output.get_unique_values())
== set(test_output.get_unique_values())
== set(_mnist._mnist._kuzushiji_mnist_labels.values())
)

def test_should_raise_if_file_not_found(self) -> None:
with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
load_kmnist(tmpdirname, download=False)
load_kmnist(tmpdirname, download=False)

0 comments on commit 5190242

Please sign in to comment.