diff --git a/heat/core/io.py b/heat/core/io.py
index b74829ce9..c48c2c3a6 100644
--- a/heat/core/io.py
+++ b/heat/core/io.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 import warnings
+import fnmatch
 
 from typing import Dict, Iterable, List, Optional, Tuple, Union
 
@@ -27,7 +28,15 @@
 __NETCDF_EXTENSIONS = frozenset([".nc", ".nc4", "netcdf"])
 __NETCDF_DIM_TEMPLATE = "{}_dim_{}"
 
-__all__ = ["load", "load_csv", "save_csv", "save", "supports_hdf5", "supports_netcdf"]
+__all__ = [
+    "load",
+    "load_csv",
+    "save_csv",
+    "save",
+    "supports_hdf5",
+    "supports_netcdf",
+    "load_npy_from_path",
+]
 
 try:
     import netCDF4 as nc
@@ -1131,3 +1140,74 @@ def save(
 
 DNDarray.save = lambda self, path, *args, **kwargs: save(self, path, *args, **kwargs)
 DNDarray.save.__doc__ = save.__doc__
+
+
+def load_npy_from_path(
+    path: str,
+    dtype: datatype = types.int32,
+    split: int = 0,
+    device: Optional[str] = None,
+    comm: Optional[Communication] = None,
+) -> DNDarray:
+    """
+    Load multiple .npy files into one DNDarray, concatenated along the ``split`` axis.
+
+    Files are read in sorted filename order so that every process sees the
+    same sequence and the result is deterministic across platforms.
+
+    Parameters
+    ----------
+    path : str
+        Path to the directory in which .npy-files are located.
+    dtype : datatype, optional
+        Data type of the resulting array.
+    split : int
+        Along which axis the loaded arrays should be concatenated.
+    device : str, optional
+        The device id on which to place the data, defaults to globally set default device.
+    comm : Communication, optional
+        The communication to use for the data distribution, default is 'heat.MPI_WORLD'
+
+    Raises
+    ------
+    TypeError
+        If ``path`` is not a str or ``split`` is neither None nor an int.
+    ValueError
+        If no .npy files are found under ``path``.
+    RuntimeError
+        If there are fewer files than participating processes.
+    """
+    if not isinstance(path, str):
+        raise TypeError(f"path must be str, not {type(path)}")
+    elif split is not None and not isinstance(split, int):
+        raise TypeError(f"split must be None or int, not {type(split)}")
+
+    process_number = MPI_WORLD.size
+    # Sort the directory listing: os.listdir order is arbitrary and may differ
+    # between ranks, which would make the per-process file slices inconsistent.
+    file_list = []
+    for file in sorted(os.listdir(path)):
+        if fnmatch.fnmatch(file, "*.npy"):
+            file_list.append(file)
+    n_files = len(file_list)
+
+    if n_files == 0:
+        raise ValueError("No .npy Files were found")
+    if (n_files < process_number) and (process_number > 1):
+        raise RuntimeError("Number of processes can't exceed number of files")
+
+    # Static block distribution: every rank loads n_files // size files,
+    # the last rank additionally takes the remainder.
+    rank = MPI_WORLD.rank
+    n_for_procs = n_files // process_number
+    idx = rank * n_for_procs
+    if rank + 1 == process_number:
+        n_for_procs += n_files % process_number
+    array_list = [
+        np.load(os.path.join(path, element)) for element in file_list[idx : idx + n_for_procs]
+    ]
+    larray = np.concatenate(array_list, split)
+    larray = torch.from_numpy(larray)
+
+    x = factories.array(larray, dtype=dtype, device=device, is_split=split, comm=comm)
+    return x
diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py
index 7909343d6..5b8ece13b 100644
--- a/heat/core/tests/test_io.py
+++ b/heat/core/tests/test_io.py
@@ -2,6 +2,9 @@
 import os
 import torch
 import tempfile
+import random
+import time
+import fnmatch
 
 import heat as ht
 from .test_suites.basic_test import TestCase
@@ -739,3 +742,68 @@ def test_save_netcdf_exception(self):
         #     os.rmdir(os.getcwd() + '/tmp/')
         # except OSError:
         #     pass
+
+    def test_load_npy_int(self):
+        # testing for int arrays; zero-padded names keep creation order == sorted order
+        if ht.MPI_WORLD.rank == 0:
+            crea_array = []
+            for i in range(0, 20):
+                x = np.random.randint(1000, size=(random.randint(0, 30), 6, 11))
+                np.save(os.path.join(os.getcwd(), "heat/datasets", "int_data") + str(i).zfill(2), x)
+                crea_array.append(x)
+            int_array = np.concatenate(crea_array)
+        ht.MPI_WORLD.Barrier()
+
+        load_array = ht.load_npy_from_path(
+            os.path.join(os.getcwd(), "heat/datasets"), dtype=ht.int32, split=0
+        )
+        load_array_npy = load_array.numpy()
+
+        self.assertIsInstance(load_array, ht.DNDarray)
+        self.assertEqual(load_array.dtype, ht.int32)
+        if ht.MPI_WORLD.rank == 0:
+            self.assertTrue((load_array_npy == int_array).all())
+            for file in os.listdir(os.path.join(os.getcwd(), "heat/datasets")):
+                if fnmatch.fnmatch(file, "*.npy"):
+                    os.remove(os.path.join(os.getcwd(), "heat/datasets", file))
+
+    def test_load_npy_float(self):
+        # testing for float arrays and split dimension other than 0
+        if ht.MPI_WORLD.rank == 0:
+            crea_array = []
+            for i in range(0, 20):
+                x = np.random.rand(2, random.randint(1, 10), 11)
+                np.save(os.path.join(os.getcwd(), "heat/datasets", "float_data") + str(i).zfill(2), x)
+                crea_array.append(x)
+            float_array = np.concatenate(crea_array, 1)
+        ht.MPI_WORLD.Barrier()
+
+        load_array = ht.load_npy_from_path(
+            os.path.join(os.getcwd(), "heat/datasets"), dtype=ht.float64, split=1
+        )
+        load_array_npy = load_array.numpy()
+        self.assertIsInstance(load_array, ht.DNDarray)
+        self.assertEqual(load_array.dtype, ht.float64)
+        if ht.MPI_WORLD.rank == 0:
+            self.assertTrue((load_array_npy == float_array).all())
+            for file in os.listdir(os.path.join(os.getcwd(), "heat/datasets")):
+                if fnmatch.fnmatch(file, "*.npy"):
+                    os.remove(os.path.join(os.getcwd(), "heat/datasets", file))
+
+    def test_load_npy_exception(self):
+        with self.assertRaises(TypeError):
+            ht.load_npy_from_path(path=1, split=0)
+        with self.assertRaises(TypeError):
+            ht.load_npy_from_path("heat/datasets", split="ABC")
+        with self.assertRaises(ValueError):
+            ht.load_npy_from_path(path="heat", dtype=ht.int64, split=0)
+        if ht.MPI_WORLD.size > 1:
+            if ht.MPI_WORLD.rank == 0:
+                x = np.random.rand(2, random.randint(1, 10), 11)
+                np.save(os.path.join(os.getcwd(), "heat/datasets", "float_data"), x)
+            ht.MPI_WORLD.Barrier()
+            with self.assertRaises(RuntimeError):
+                ht.load_npy_from_path("heat/datasets", dtype=ht.int64, split=0)
+            ht.MPI_WORLD.Barrier()
+            if ht.MPI_WORLD.rank == 0:
+                os.remove(os.path.join(os.getcwd(), "heat/datasets", "float_data.npy"))